diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5b22341f0..e393a9018 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -95,6 +95,7 @@ To release a new version, please update the changelog as followed:
 - Support string dtype in InputLayer (#PR 1017)
 - Support Dynamic RNN in RNN (#PR 1023)
 - Add ResNet50 static model (#PR 1030)
+- Add performance test code for static model (#PR 1041)
 
 ### Changed
 
@@ -115,6 +116,7 @@ To release a new version, please update the changelog as followed:
 - Copy original model's `trainable_weights` and `nontrainable_weights` when initializing `LayerList` (#PR 1029)
 - Remove redundant parts in `model.all_layers` (#PR 1029)
 - Replace `tf.image.resize_image_with_crop_or_pad` with `tf.image.resize_with_crop_or_pad` (#PR 1032)
+- Fix a bug in `ResNet50` static model (#PR 1041)
 
 ### Removed
 
@@ -124,7 +126,7 @@ To release a new version, please update the changelog as followed:
 
 - @zsdonghao
 - @ChrisWu1997: #1010 #1015 #1025 #1030
-- @warshallrho: #1017 #1021 #1026 #1029 #1032
+- @warshallrho: #1017 #1021 #1026 #1029 #1032 #1041
 - @ArnoldLIULJ: #1023
 - @JingqingZ: #1023
 
diff --git a/tensorlayer/models/resnet.py b/tensorlayer/models/resnet.py
index 9938fd1cd..87bdc5641 100644
--- a/tensorlayer/models/resnet.py
+++ b/tensorlayer/models/resnet.py
@@ -150,21 +150,21 @@ def ResNet50(pretrained=False, end_with='fc1000', n_classes=1000, name=None):
     n = BatchNorm(name='bn_conv1', act='relu')(n)
     n = MaxPool2d((3, 3), strides=(2, 2), name='max_pool1')(n)
 
-    for i, name in enumerate(block_names):
-        if len(name) == 2:
-            stage = int(name[0])
-            block = name[1]
+    for i, block_name in enumerate(block_names):
+        if len(block_name) == 2:
+            stage = int(block_name[0])
+            block = block_name[1]
             if block == 'a':
                 strides = (1, 1) if stage == 2 else (2, 2)
                 n = conv_block(n, 3, block_filters[stage - 2], stage=stage, block=block, strides=strides)
             else:
                 n = identity_block(n, 3, block_filters[stage - 2], stage=stage, block=block)
-        elif name == 'avg_pool':
+        elif block_name == 'avg_pool':
             n = GlobalMeanPool2d(name='avg_pool')(n)
-        elif name == 'fc1000':
+        elif block_name == 'fc1000':
             n = Dense(n_classes, name='fc1000')(n)
 
-        if name == end_with:
+        if block_name == end_with:
             break
 
     network = Model(inputs=ni, outputs=n, name=name)
diff --git a/tests/performance_test/vgg/tl2-static-autograph.py b/tests/performance_test/vgg/tl2-static-autograph.py
new file mode 100644
index 000000000..0af20adb8
--- /dev/null
+++ b/tests/performance_test/vgg/tl2-static-autograph.py
@@ -0,0 +1,79 @@
+import time
+import os
+import psutil
+import tensorflow as tf
+import tensorlayer as tl
+from exp_config import random_input_generator, MONITOR_INTERVAL, NUM_ITERS, BATCH_SIZE, LERANING_RATE
+
+gpus = tf.config.experimental.list_physical_devices('GPU')
+if gpus:
+    for gpu in gpus:
+        tf.config.experimental.set_memory_growth(gpu, True)
+
+tl.logging.set_verbosity(tl.logging.DEBUG)
+
+# get the whole model
+vgg = tl.models.vgg16(mode='static')
+
+# system monitor
+info = psutil.virtual_memory()
+monitor_interval = MONITOR_INTERVAL
+avg_mem_usage = 0
+max_mem_usage = 0
+count = 0
+total_time = 0
+
+# training setting
+num_iter = NUM_ITERS
+batch_size = BATCH_SIZE
+train_weights = vgg.trainable_weights
+optimizer = tf.optimizers.Adam(learning_rate=LERANING_RATE)
+loss_object = tl.cost.cross_entropy
+
+# data generator
+gen = random_input_generator(num_iter, batch_size)
+
+
+# training function
+@tf.function
+def train_step(x_batch, y_batch):
+    # forward + backward
+    with tf.GradientTape() as tape:
+        ## compute outputs
+        _logits = vgg(x_batch)
+        ## compute loss and update model
+        _loss = loss_object(_logits, y_batch)
+
+    grad = tape.gradient(_loss, train_weights)
+    optimizer.apply_gradients(zip(grad, train_weights))
+
+
+# begin training
+vgg.train()
+
+for idx, data in enumerate(gen):
+    start_time = time.time()
+
+    train_step(data[0], data[1])
+
+    end_time = time.time()
+    consume_time = end_time - start_time
+    total_time += consume_time
+
+    if idx % monitor_interval == 0:
+        cur_usage = psutil.Process(os.getpid()).memory_info().rss
+        max_mem_usage = max(cur_usage, max_mem_usage)
+        avg_mem_usage += cur_usage
+        count += 1
+        tl.logging.info(
+            "[*] {} iteration: memory usage {:.2f}MB, consume time {:.4f}s".format(
+                idx, cur_usage / (1024 * 1024), consume_time
+            )
+        )
+
+print('consumed time:', total_time)
+
+avg_mem_usage = avg_mem_usage / count / (1024 * 1024)
+max_mem_usage = max_mem_usage / (1024 * 1024)
+print('average memory usage: {:.2f}MB'.format(avg_mem_usage))
+print('maximum memory usage: {:.2f}MB'.format(max_mem_usage))
diff --git a/tests/performance_test/vgg/tl2-static-eager.py b/tests/performance_test/vgg/tl2-static-eager.py
new file mode 100644
index 000000000..b6d5287ba
--- /dev/null
+++ b/tests/performance_test/vgg/tl2-static-eager.py
@@ -0,0 +1,79 @@
+import time
+import os
+import psutil
+import tensorflow as tf
+import tensorlayer as tl
+from exp_config import random_input_generator, MONITOR_INTERVAL, NUM_ITERS, BATCH_SIZE, LERANING_RATE
+
+gpus = tf.config.experimental.list_physical_devices('GPU')
+if gpus:
+    for gpu in gpus:
+        tf.config.experimental.set_memory_growth(gpu, True)
+
+tl.logging.set_verbosity(tl.logging.DEBUG)
+
+# get the whole model
+vgg = tl.models.vgg16(mode='static')
+
+# system monitor
+info = psutil.virtual_memory()
+monitor_interval = MONITOR_INTERVAL
+avg_mem_usage = 0
+max_mem_usage = 0
+count = 0
+total_time = 0
+
+# training setting
+num_iter = NUM_ITERS
+batch_size = BATCH_SIZE
+train_weights = vgg.trainable_weights
+optimizer = tf.optimizers.Adam(learning_rate=LERANING_RATE)
+loss_object = tl.cost.cross_entropy
+
+# data generator
+gen = random_input_generator(num_iter, batch_size)
+
+
+# training function
+def train_step(x_batch, y_batch):
+    # forward + backward
+    with tf.GradientTape() as tape:
+        ## compute outputs
+        _logits = vgg(x_batch)
+        ## compute loss and update model
+        _loss = loss_object(_logits, y_batch)
+
+    grad = tape.gradient(_loss, train_weights)
+    optimizer.apply_gradients(zip(grad, train_weights))
+    return _loss
+
+
+# begin training
+vgg.train()
+
+for idx, data in enumerate(gen):
+    start_time = time.time()
+
+    loss = train_step(data[0], data[1])
+
+    end_time = time.time()
+    consume_time = end_time - start_time
+    total_time += consume_time
+
+    if idx % monitor_interval == 0:
+        cur_usage = psutil.Process(os.getpid()).memory_info().rss
+        max_mem_usage = max(cur_usage, max_mem_usage)
+        avg_mem_usage += cur_usage
+        count += 1
+        tl.logging.info(
+            "[*] {} iteration: memory usage {:.2f}MB, consume time {:.4f}s, loss {:.4f}".format(
+                idx, cur_usage / (1024 * 1024), consume_time, loss
+            )
+        )
+
+print('consumed time:', total_time)
+
+avg_mem_usage = avg_mem_usage / count / (1024 * 1024)
+max_mem_usage = max_mem_usage / (1024 * 1024)
+print('average memory usage: {:.2f}MB'.format(avg_mem_usage))
+print('maximum memory usage: {:.2f}MB'.format(max_mem_usage))
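
Both benchmark scripts import their settings from a shared exp_config module (presumably tests/performance_test/vgg/exp_config.py) that is not part of the hunks above. Below is a minimal sketch of what that module could look like: only the imported names (including the LERANING_RATE spelling) come from the scripts themselves, while the concrete constant values, the 224x224x3 input shape, and the 1000-class label range are assumptions for illustration.

import numpy as np

# Assumed benchmark settings; the real values live in the repo's exp_config.py.
MONITOR_INTERVAL = 50     # log memory/time every 50 iterations (assumed value)
NUM_ITERS = 300           # total training iterations (assumed value)
BATCH_SIZE = 32           # images per batch (assumed value)
LERANING_RATE = 0.0001    # learning rate; name spelled exactly as imported by the scripts


def random_input_generator(num_iter, batch_size, input_shape=(224, 224, 3), n_classes=1000):
    """Yield num_iter batches of random float images and integer labels sized for VGG16."""
    for _ in range(num_iter):
        # random images in [0, 1), shaped [batch, height, width, channels]
        x = np.random.random((batch_size,) + input_shape).astype(np.float32)
        # integer class labels, as expected by tl.cost.cross_entropy
        y = np.random.randint(0, n_classes, size=(batch_size,)).astype(np.int64)
        yield x, y

With such a generator, each iteration feeds data[0] (images) and data[1] (labels) to train_step, so the two scripts measure only forward/backward time and resident memory, not data loading.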