From 47fde592b7a54deed8d77d48884306d08f5911ee Mon Sep 17 00:00:00 2001 From: Konstantin Slavnov Date: Sat, 3 Jul 2021 00:00:48 +0300 Subject: [PATCH] Update timm library to 0.4.12 --- requirements.txt | 4 ++-- segmentation_models_pytorch/encoders/timm_efficientnet.py | 8 ++++---- segmentation_models_pytorch/encoders/timm_sknet.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/requirements.txt b/requirements.txt index a88a7a87..49a43b77 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torchvision>=0.3.0 +torchvision>=0.5.0 pretrainedmodels==0.7.4 efficientnet-pytorch==0.6.3 -timm==0.3.2 +timm==0.4.12 diff --git a/segmentation_models_pytorch/encoders/timm_efficientnet.py b/segmentation_models_pytorch/encoders/timm_efficientnet.py index ad4cbb2c..b7bd7785 100644 --- a/segmentation_models_pytorch/encoders/timm_efficientnet.py +++ b/segmentation_models_pytorch/encoders/timm_efficientnet.py @@ -1,3 +1,5 @@ +from functools import partial + import torch import torch.nn as nn @@ -41,9 +43,8 @@ def get_efficientnet_kwargs(channel_multiplier=1.0, depth_multiplier=1.0, drop_r block_args=decode_arch_def(arch_def, depth_multiplier), num_features=round_channels(1280, channel_multiplier, 8, None), stem_size=32, - channel_multiplier=channel_multiplier, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), act_layer=Swish, - norm_kwargs={}, # TODO: check drop_rate=drop_rate, drop_path_rate=0.2, ) @@ -81,9 +82,8 @@ def gen_efficientnet_lite_kwargs(channel_multiplier=1.0, depth_multiplier=1.0, d num_features=1280, stem_size=32, fix_stem=True, - channel_multiplier=channel_multiplier, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), act_layer=nn.ReLU6, - norm_kwargs={}, drop_rate=drop_rate, drop_path_rate=0.2, ) diff --git a/segmentation_models_pytorch/encoders/timm_sknet.py b/segmentation_models_pytorch/encoders/timm_sknet.py index bfb7572d..6118ae19 100644 --- a/segmentation_models_pytorch/encoders/timm_sknet.py +++ b/segmentation_models_pytorch/encoders/timm_sknet.py @@ -74,7 +74,7 @@ def load_state_dict(self, state_dict, **kwargs): 'block': SelectiveKernelBasic, 'layers': [2, 2, 2, 2], 'zero_init_last_bn': False, - 'block_args': {'sk_kwargs': {'min_attn_channels': 16, 'attn_reduction': 8, 'split_input': True}} + 'block_args': {'sk_kwargs': {'rd_ratio': 1/8, 'split_input': True}} } }, 'timm-skresnet34': { @@ -85,7 +85,7 @@ def load_state_dict(self, state_dict, **kwargs): 'block': SelectiveKernelBasic, 'layers': [3, 4, 6, 3], 'zero_init_last_bn': False, - 'block_args': {'sk_kwargs': {'min_attn_channels': 16, 'attn_reduction': 8, 'split_input': True}} + 'block_args': {'sk_kwargs': {'rd_ratio': 1/8, 'split_input': True}} } }, 'timm-skresnext50_32x4d': {