Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
10ce5d4
Update utils.py
2wins May 13, 2018
151958e
Update utils.py
2wins May 13, 2018
445690e
Merge branch 'master' into patch-5
2wins May 13, 2018
e504c85
Merge branch 'master' into patch-5
DEKHTIARJonathan May 13, 2018
5d6ad68
Merge branch 'master' into patch-5
DEKHTIARJonathan May 13, 2018
ce3ef61
Create test_utils_predict.py
2wins May 14, 2018
d274464
Update utils.py
2wins May 14, 2018
f494d8c
Update test_utils_predict.py
2wins May 14, 2018
619c08f
Update CHANGELOG.md
2wins May 14, 2018
2546d64
Merge remote-tracking branch 'origin/patch-1' into patch-5
2wins May 14, 2018
cb48f7d
Merge branch 'patch-5' of https://github.com/2wins/tensorlayer into p…
2wins May 14, 2018
e259cfe
Update test_utils_predict.py
2wins May 14, 2018
5ccc7a9
Update CHANGELOG.md
2wins May 14, 2018
a904ae7
Update CHANGELOG.md
2wins May 14, 2018
a0be0c9
Update CHANGELOG.md
2wins May 14, 2018
67cce87
Merge branch 'master' into patch-5
zsdonghao May 14, 2018
346f23b
Update test_utils_predict.py
2wins May 14, 2018
a9a0091
Reflect the latest update
2wins May 14, 2018
147e893
Update test_utils_predict.py
2wins May 14, 2018
f9bd69f
Update test_utils_predict.py
2wins May 14, 2018
be46a1b
Update test_utils_predict.py
2wins May 14, 2018
c31bc43
Update test_utils_predict.py
2wins May 14, 2018
532c3b8
Update test_utils_predict.py (fix Bad Coding Style)
2wins May 14, 2018
0cb228f
Update test_utils_predict.py
2wins May 15, 2018
cee22df
Update CHANGELOG.md
2wins May 15, 2018
d3c4346
Update CHANGELOG.md
2wins May 15, 2018
cb1ebb7
Merge branch 'master' into patch-5
2wins May 15, 2018
99fa99b
Update CHANGELOG.md
2wins May 15, 2018
fd4a01c
Update CHANGELOG.md
2wins May 15, 2018
89e8535
Merge branch 'master' into patch-5
DEKHTIARJonathan May 17, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ To release a new version, please update the changelog as followed:
- Tutorials:
- `tutorial_tfslim` has been introduced to show how to use `SlimNetsLayer` (by @2wins in #560).
- Test:
- `test_utils_predict.py` added to reproduce and fix issue #288 (by @2wins in #566)
- `Layer_DeformableConvolution_Test` added to reproduce issue #572 with deformable convolution (by @DEKHTIARJonathan in #573)
- `Array_Op_Alphas_Test` and `Array_Op_Alphas_Like_Test` added to test `tensorlayer/array_ops.py` file (by @DEKHTIARJonathan in #580)
- CI Tool:
Expand All @@ -93,7 +94,8 @@ To release a new version, please update the changelog as followed:
### Fixed
- Issue #498 - Deprecation Warning Fix in `tl.layers.RNNLayer` with `inspect` (by @DEKHTIARJonathan in #574)
- Issue #498 - Deprecation Warning Fix in `tl.files` with truth value of an empty array is ambiguous (by @DEKHTIARJonathan in #575)
- Issue #572 with deformable convolution fixed (by @DEKHTIARJonathan in #573)
- Issue #565 related to `tl.utils.predict` fixed - `np.hstack` problem in which the results for multiple batches are stacked along `axis=1` (by @2wins in #566)
- Issue #572 with `tl.layers.DeformableConv2d` fixed (by @DEKHTIARJonathan in #573)
- Typo of the document of ElementwiseLambdaLayer (by @zsdonghao in #588)

### Security
Expand Down
4 changes: 2 additions & 2 deletions tensorlayer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def predict(sess, network, X, x, y_op, batch_size=None):
if result is None:
result = result_a
else:
result = np.hstack((result, result_a)) # TODO: https://github.com/tensorlayer/tensorlayer/issues/288
result = np.concatenate((result, result_a))
if result is None:
if len(X) % batch_size != 0:
dp_dict = dict_to_one(network.all_drop)
Expand All @@ -338,7 +338,7 @@ def predict(sess, network, X, x, y_op, batch_size=None):
}
feed_dict.update(dp_dict)
result_a = sess.run(y_op, feed_dict=feed_dict)
result = np.hstack((result, result_a)) # TODO: https://github.com/tensorlayer/tensorlayer/issues/288
result = np.concatenate((result, result_a))
return result


Expand Down
51 changes: 51 additions & 0 deletions tests/test_utils_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import unittest

try:
from tests.unittests_helper import CustomTestCase
except ImportError:
from unittests_helper import CustomTestCase

import tensorflow as tf
import tensorlayer as tl
import numpy as np


class Util_Predict_Test(CustomTestCase):

@classmethod
def setUpClass(cls):
cls.x1 = tf.placeholder(tf.float32, [None, 5, 5, 3])
cls.x2 = tf.placeholder(tf.float32, [8, 5, 5, 3])
cls.X1 = np.ones([127, 5, 5, 3])
cls.X2 = np.ones([7, 5, 5, 3])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it normal that they do not use the same batch_size ? It is usual a thin that we don't want to do...

cls.x2 = tf.placeholder(tf.float32, [8, 5, 5, 3])
cls.X2 = np.ones([7, 5, 5, 3])

Copy link
Member

@DEKHTIARJonathan DEKHTIARJonathan May 14, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sometimes we set the shape of tf.placeholder explicitly using a specific batch size.
However, it becomes a problem when the number of samples is not a factor of the batch size.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cls.x2 = tf.placeholder(tf.float32, [8, 5, 5, 3])
cls.X2 = np.ones([7, 5, 5, 3])
# Actually, the same with the above case
cls.x2 = tf.placeholder(tf.float32, [8, 5, 5, 3])
cls.X2 = np.ones([135, 5, 5, 3])

cls.batch_size = 8

@classmethod
def tearDownClass(cls):
tf.reset_default_graph()

def test_case1(self):
with self.assertNotRaises(Exception):
with tf.Session() as sess:
n = tl.layers.InputLayer(self.x1)
y = n.outputs
y_op = tf.nn.softmax(y)
tl.utils.predict(sess, n, self.X1, self.x1, y_op, batch_size=self.batch_size)
sess.close()

def test_case2(self):
with self.assertRaises(Exception):
with tf.Session() as sess:
n = tl.layers.InputLayer(self.x2)
y = n.outputs
y_op = tf.nn.softmax(y)
tl.utils.predict(sess, n, self.X2, self.x2, y_op, batch_size=self.batch_size)
sess.close()


if __name__ == '__main__':

# tf.logging.set_verbosity(tf.logging.INFO)
tf.logging.set_verbosity(tf.logging.DEBUG)

unittest.main()