4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -95,6 +95,7 @@ To release a new version, please update the changelog as followed:
- Support string dtype in InputLayer (#PR 1017)
- Support Dynamic RNN in RNN (#PR 1023)
- Add ResNet50 static model (#PR 1030)
+- Add performance test code in static model (#PR 1041)

### Changed

@@ -115,6 +116,7 @@ To release a new version, please update the changelog as followed:
- Copy original model's `trainable_weights` and `nontrainable_weights` when initializing `LayerList` (#PR 1029)
- Remove redundant parts in `model.all_layers` (#PR 1029)
- Replace `tf.image.resize_image_with_crop_or_pad` with `tf.image.resize_with_crop_or_pad` (#PR 1032)
+- Fix a bug in `ResNet50` static model (#PR 1041)

### Removed

@@ -124,7 +126,7 @@ To release a new version, please update the changelog as followed:

- @zsdonghao
- @ChrisWu1997: #1010 #1015 #1025 #1030
-- @warshallrho: #1017 #1021 #1026 #1029 #1032
+- @warshallrho: #1017 #1021 #1026 #1029 #1032 #1041
- @ArnoldLIULJ: #1023
- @JingqingZ: #1023

14 changes: 7 additions & 7 deletions tensorlayer/models/resnet.py
@@ -150,21 +150,21 @@ def ResNet50(pretrained=False, end_with='fc1000', n_classes=1000, name=None):
    n = BatchNorm(name='bn_conv1', act='relu')(n)
    n = MaxPool2d((3, 3), strides=(2, 2), name='max_pool1')(n)

-    for i, name in enumerate(block_names):
-        if len(name) == 2:
-            stage = int(name[0])
-            block = name[1]
+    for i, block_name in enumerate(block_names):
+        if len(block_name) == 2:
+            stage = int(block_name[0])
+            block = block_name[1]
            if block == 'a':
                strides = (1, 1) if stage == 2 else (2, 2)
                n = conv_block(n, 3, block_filters[stage - 2], stage=stage, block=block, strides=strides)
            else:
                n = identity_block(n, 3, block_filters[stage - 2], stage=stage, block=block)
-        elif name == 'avg_pool':
+        elif block_name == 'avg_pool':
            n = GlobalMeanPool2d(name='avg_pool')(n)
-        elif name == 'fc1000':
+        elif block_name == 'fc1000':
            n = Dense(n_classes, name='fc1000')(n)

-        if name == end_with:
+        if block_name == end_with:
            break

    network = Model(inputs=ni, outputs=n, name=name)
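
Note on the fix above: the loop variable `name` shadowed the `name` parameter of `ResNet50`, so the final `Model(inputs=ni, outputs=n, name=name)` call received the last block string assigned by the loop instead of the caller's model name. A minimal sketch of the shadowing pattern (hypothetical `build` function, not code from the PR):

def build(name=None):
    # the for loop rebinds the local `name`, clobbering the parameter
    for name in ['2a', '2b', 'avg_pool', 'fc1000']:
        pass
    return name  # last loop value, not the caller's argument

print(build(name='my_resnet'))  # prints 'fc1000', not 'my_resnet'

Renaming the loop variable to `block_name` leaves the parameter untouched for the `Model(...)` call.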
79 changes: 79 additions & 0 deletions tests/performance_test/vgg/tl2-static-autograph.py
@@ -0,0 +1,79 @@
import time
import os
import psutil
import tensorflow as tf
import tensorlayer as tl
from exp_config import random_input_generator, MONITOR_INTERVAL, NUM_ITERS, BATCH_SIZE, LERANING_RATE

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

tl.logging.set_verbosity(tl.logging.DEBUG)

# get the whole model
vgg = tl.models.vgg16(mode='static')

# system monitor
info = psutil.virtual_memory()
monitor_interval = MONITOR_INTERVAL
avg_mem_usage = 0
max_mem_usage = 0
count = 0
total_time = 0

# training setting
num_iter = NUM_ITERS
batch_size = BATCH_SIZE
train_weights = vgg.trainable_weights
optimizer = tf.optimizers.Adam(learning_rate=LERANING_RATE)
loss_object = tl.cost.cross_entropy

# data generator
gen = random_input_generator(num_iter, batch_size)


# training function
@tf.function
def train_step(x_batch, y_batch):
    # forward + backward
    with tf.GradientTape() as tape:
        ## compute outputs
        _logits = vgg(x_batch)
        ## compute loss
        _loss = loss_object(_logits, y_batch)

    ## compute gradients and update weights
    grad = tape.gradient(_loss, train_weights)
    optimizer.apply_gradients(zip(grad, train_weights))


# begin training
vgg.train()

for idx, data in enumerate(gen):
    start_time = time.time()

    train_step(data[0], data[1])

    end_time = time.time()
    consume_time = end_time - start_time
    total_time += consume_time

    if idx % monitor_interval == 0:
        cur_usage = psutil.Process(os.getpid()).memory_info().rss
        max_mem_usage = max(cur_usage, max_mem_usage)
        avg_mem_usage += cur_usage
        count += 1
        tl.logging.info(
            "[*] {} iteration: memory usage {:.2f}MB, consume time {:.4f}s".format(
                idx, cur_usage / (1024 * 1024), consume_time
            )
        )

print('consumed time:', total_time)

avg_mem_usage = avg_mem_usage / count / (1024 * 1024)
max_mem_usage = max_mem_usage / (1024 * 1024)
print('average memory usage: {:.2f}MB'.format(avg_mem_usage))
print('maximum memory usage: {:.2f}MB'.format(max_mem_usage))
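
Both benchmark scripts import their settings and data source from `exp_config`, which is not part of this diff. A sketch of what it plausibly provides, so the script above can be read in isolation; the constant values, input shape, and generator signature below are assumptions, not the repository's actual definitions:

# exp_config.py -- hypothetical sketch, not the file from the repository
import numpy as np

MONITOR_INTERVAL = 50
NUM_ITERS = 300
BATCH_SIZE = 32
LERANING_RATE = 0.0001  # spelling kept as-is to match the import above


def random_input_generator(num_iters, batch_size, input_shape=(224, 224, 3), n_classes=1000):
    """Yield (images, labels) batches of random data, num_iters times."""
    rng = np.random.RandomState(1234)
    for _ in range(num_iters):
        x = rng.uniform(size=(batch_size,) + input_shape).astype(np.float32)
        y = rng.randint(0, n_classes, size=(batch_size,)).astype(np.int64)
        yield x, y

Random inputs keep the benchmark focused on training throughput and memory rather than on data loading.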
79 changes: 79 additions & 0 deletions tests/performance_test/vgg/tl2-static-eager.py
@@ -0,0 +1,79 @@
import time
import os
import psutil
import tensorflow as tf
import tensorlayer as tl
from exp_config import random_input_generator, MONITOR_INTERVAL, NUM_ITERS, BATCH_SIZE, LERANING_RATE

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

tl.logging.set_verbosity(tl.logging.DEBUG)

# get the whole model
vgg = tl.models.vgg16(mode='static')

# system monitor
info = psutil.virtual_memory()
monitor_interval = MONITOR_INTERVAL
avg_mem_usage = 0
max_mem_usage = 0
count = 0
total_time = 0

# training setting
num_iter = NUM_ITERS
batch_size = BATCH_SIZE
train_weights = vgg.trainable_weights
optimizer = tf.optimizers.Adam(learning_rate=LERANING_RATE)
loss_object = tl.cost.cross_entropy

# data generator
gen = random_input_generator(num_iter, batch_size)


# training function
def train_step(x_batch, y_batch):
    # forward + backward
    with tf.GradientTape() as tape:
        ## compute outputs
        _logits = vgg(x_batch)
        ## compute loss
        _loss = loss_object(_logits, y_batch)

    ## compute gradients and update weights
    grad = tape.gradient(_loss, train_weights)
    optimizer.apply_gradients(zip(grad, train_weights))
    return _loss


# begin training
vgg.train()

for idx, data in enumerate(gen):
    start_time = time.time()

    loss = train_step(data[0], data[1])

    end_time = time.time()
    consume_time = end_time - start_time
    total_time += consume_time

    if idx % monitor_interval == 0:
        cur_usage = psutil.Process(os.getpid()).memory_info().rss
        max_mem_usage = max(cur_usage, max_mem_usage)
        avg_mem_usage += cur_usage
        count += 1
        tl.logging.info(
            "[*] {} iteration: memory usage {:.2f}MB, consume time {:.4f}s, loss {:.4f}".format(
                idx, cur_usage / (1024 * 1024), consume_time, float(loss)  # float() so the EagerTensor formats with {:.4f}
            )
        )

print('consumed time:', total_time)

avg_mem_usage = avg_mem_usage / count / (1024 * 1024)
max_mem_usage = max_mem_usage / (1024 * 1024)
print('average memory usage: {:.2f}MB'.format(avg_mem_usage))
print('maximum memory usage: {:.2f}MB'.format(max_mem_usage))
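
The two scripts are identical except that this eager variant omits the `@tf.function` decorator on `train_step`, so every iteration re-executes the Python op by op instead of replaying a traced graph; that decorator is the entire variable being benchmarked. A standalone illustration of the pattern (hypothetical example, not part of the PR):

import time

import tensorflow as tf


def step(x):
    # small stand-in for a forward pass
    for _ in range(10):
        x = tf.matmul(x, x) / 100.0
    return x


graph_step = tf.function(step)  # same Python code, traced into a graph

x = tf.random.normal((256, 256))
graph_step(x)  # warm-up call pays the one-off tracing cost

t0 = time.time()
step(x)
print('eager: {:.4f}s'.format(time.time() - t0))

t0 = time.time()
graph_step(x)
print('graph: {:.4f}s'.format(time.time() - t0))

On a toy workload like this the gap is small; the scripts above instead average full VGG16 training steps over NUM_ITERS iterations, where tracing overhead amortizes away.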