Skip to content

Commit cb70069

Browse files
committed
Add test for inconsistent zero batch output
1 parent d6dd369 commit cb70069

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

test/tests_onnx.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,27 @@ def test_onnx_modelrun_mnist(env):
152152
env.assertEqual(values2, values)
153153

154154

155+
def test_onnx_modelrun_batchdim_mismatch(env):
156+
con = env.getConnection()
157+
158+
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
159+
model_filename = os.path.join(test_data_path, 'batchdim_mismatch.onnx')
160+
161+
with open(model_filename, 'rb') as f:
162+
model_pb = f.read()
163+
164+
ret = con.execute_command('AI.MODELSET', 'm', 'ONNX', DEVICE, 'BLOB', model_pb)
165+
env.assertEqual(ret, b'OK')
166+
167+
ensureSlaveSynced(con, env)
168+
169+
con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 'VALUES', 1, 1)
170+
con.execute_command('AI.TENSORSET', 'b', 'FLOAT', 2, 'VALUES', 1, 1)
171+
172+
con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'a', 'b', 'OUTPUTS', 'c', 'd')
173+
174+
175+
155176
def test_onnx_modelrun_mnist_autobatch(env):
156177
if not TEST_ONNX:
157178
return

test/tests_pytorch.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,26 @@ def test_pytorch_modelrun(env):
177177
env.assertEqual(values2, values)
178178

179179

180+
def test_pytorch_modelrun_batchdim_mismatch(env):
181+
con = env.getConnection()
182+
183+
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
184+
model_filename = os.path.join(test_data_path, 'batchdim_mismatch.pt')
185+
186+
with open(model_filename, 'rb') as f:
187+
model_pb = f.read()
188+
189+
ret = con.execute_command('AI.MODELSET', 'm', 'TORCH', DEVICE, 'BLOB', model_pb)
190+
env.assertEqual(ret, b'OK')
191+
192+
ensureSlaveSynced(con, env)
193+
194+
con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 'VALUES', 1, 1)
195+
con.execute_command('AI.TENSORSET', 'b', 'FLOAT', 2, 'VALUES', 1, 1)
196+
197+
con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'a', 'b', 'OUTPUTS', 'c', 'd')
198+
199+
180200
def test_pytorch_modelrun_autobatch(env):
181201
if not TEST_PT:
182202
return

0 commit comments

Comments
 (0)