├── LICENSE ├── readme.md ├── resnest ├── __init__.py ├── resnest.py ├── resnet.py └── splat.py └── usage.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Hyungjin Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # ResNeSt: Split-Attention Networks 2 | 3 | This is an implementation of "ResNeSt: Split-Attention Networks" in Keras and TensorFlow. 4 | 5 | The implementation is based on the paper [1] and the [official implementation](https://github.com/zhanghang1989/ResNeSt). 
6 | 7 | ## Model 8 | 9 | - Model 10 | * ResNeSt-50 11 | * ResNeSt-101 12 | * ResNeSt-200 13 | * ResNeSt-269 14 | - Pre-trained weights 15 | * ImageNet (the pre-trained weights are converted from the [official weights](https://github.com/zhanghang1989/ResNeSt)) 16 | 17 | ## Requirements 18 | 19 | - Python 3 20 | - TensorFlow 2 21 | - torch 1.1 or later (required only to load the pre-trained weights) 22 | 23 | ## Reference 24 | 25 | * ResNeSt: Split-Attention Networks, 26 | Hang Zhang, Chongruo Wu, Zhongyue Zhang, Yi Zhu, Zhi Zhang, Haibin Lin, Yue Sun, Tong He, Jonas Mueller, R. Manmatha, Mu Li, Alexander Smola, 27 | https://arxiv.org/abs/2004.08955 28 | 29 | ## Contributor 30 | 31 | * Hyungjin Kim (flslzk@gmail.com) 32 | -------------------------------------------------------------------------------- /resnest/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnest import * 2 | from .splat import rSoftMax -------------------------------------------------------------------------------- /resnest/resnest.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from .resnet import ResNet 4 | 5 | #_url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth' 6 | _url_format = 'https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/{}-{}.pth' 7 | 8 | _model_sha256 = {name: checksum for checksum, name in [ 9 | ('528c19ca', 'resnest50'), 10 | ('22405ba7', 'resnest101'), 11 | ('75117900', 'resnest200'), 12 | ('0cc87c48', 'resnest269'), 13 | ]} 14 | 15 | def short_hash(name): 16 | if name not in _model_sha256: 17 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) 18 | return _model_sha256[name][:8] 19 | 20 | resnest_model_urls = {name: _url_format.format(name, short_hash(name)) for 21 | name in _model_sha256.keys() 22 | } 23 | 24 | def load_weight(keras_model, torch_url, group_size = 2): 25 | """ 26 | 
https://s3.us-west-1.wasabisys.com/resnest/torch/resnest50-528c19ca.pth > https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest50-528c19ca.pth 27 | https://s3.us-west-1.wasabisys.com/resnest/torch/resnest101-22405ba7.pth > https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest101-22405ba7.pth 28 | https://s3.us-west-1.wasabisys.com/resnest/torch/resnest200-75117900.pth > https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest200-75117900.pth 29 | https://s3.us-west-1.wasabisys.com/resnest/torch/resnest269-0cc87c48.pth > https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest269-0cc87c48.pth 30 | """ 31 | try: 32 | import torch 33 | torch_weight = torch.hub.load_state_dict_from_url(torch_url, map_location = "cpu", progress = True, check_hash = True) 34 | except: 35 | print("If you want to use 'ResNeSt Weight', please install 'torch 1.1▲'") 36 | return keras_model 37 | 38 | weight = {} 39 | for k, v in dict(torch_weight).items(): 40 | if k.split(".")[-1] in ["weight", "bias", "running_mean", "running_var"]: 41 | if ("downsample" in k or "conv" in k) and "weight" in k and v.ndim == 4: 42 | v = v.permute(2, 3, 1, 0) 43 | elif "fc.weight" in k: 44 | v = v.t() 45 | weight[k] = v.cpu().data.numpy() 46 | 47 | g = 0 48 | downsample = [] 49 | keras_weight = [] 50 | for i, (torch_name, torch_weight) in enumerate(weight.items()): 51 | if i + g < len(keras_model.weights): 52 | keras_name = keras_model.weights[i + g].name 53 | if "downsample" in torch_name: 54 | downsample.append(torch_weight) 55 | continue 56 | elif "group" in keras_name: 57 | g += (group_size - 1) 58 | torch_weight = tf.split(torch_weight, group_size, axis = -1) 59 | else: 60 | torch_weight = [torch_weight] 61 | keras_weight += torch_weight 62 | 63 | for w in keras_model.weights: 64 | if "downsample" in w.name: 65 | new_w = downsample.pop(0) 66 | else: 67 | new_w = keras_weight.pop(0) 68 | 
tf.keras.backend.set_value(w, new_w) 69 | return keras_model 70 | 71 | def resnest50(input_tensor = None, input_shape = None, classes = 1000, include_top = True, weights = "imagenet"): 72 | if input_tensor is None: 73 | img_input = tf.keras.layers.Input(shape = input_shape) 74 | else: 75 | if not tf.keras.backend.is_keras_tensor(input_tensor): 76 | img_input = tf.keras.layers.Input(tensor = input_tensor, shape = input_shape) 77 | else: 78 | img_input = input_tensor 79 | 80 | out = ResNet(img_input, [3, 4, 6, 3], classes, include_top, radix = 2, group_size = 1, block_width = 64, stem_width = 32, deep_stem = True, avg_down = True, avd = True, avd_first = False) 81 | model = tf.keras.Model(img_input, out) 82 | 83 | if weights == "imagenet": 84 | load_weight(model, resnest_model_urls["resnest50"], group_size = 2 * 1) 85 | elif weights is not None: 86 | model.load_weights(weights) 87 | return model 88 | 89 | 90 | def resnest101(input_tensor = None, input_shape = None, classes = 1000, include_top = True, weights = "imagenet"): 91 | if input_tensor is None: 92 | img_input = tf.keras.layers.Input(shape = input_shape) 93 | else: 94 | if not tf.keras.backend.is_keras_tensor(input_tensor): 95 | img_input = tf.keras.layers.Input(tensor = input_tensor, shape = input_shape) 96 | else: 97 | img_input = input_tensor 98 | 99 | out = ResNet(img_input, [3, 4, 23, 3], classes, include_top, radix = 2, group_size = 1, block_width = 64, stem_width = 64, deep_stem = True, avg_down = True, avd = True, avd_first = False) 100 | model = tf.keras.Model(img_input, out) 101 | 102 | if weights == "imagenet": 103 | load_weight(model, resnest_model_urls["resnest101"], group_size = 2 * 1) 104 | elif weights is not None: 105 | model.load_weights(weights) 106 | return model 107 | 108 | def resnest200(input_tensor = None, input_shape = None, classes = 1000, include_top = True, weights = "imagenet"): 109 | if input_tensor is None: 110 | img_input = tf.keras.layers.Input(shape = input_shape) 111 | else: 
112 | if not tf.keras.backend.is_keras_tensor(input_tensor): 113 | img_input = tf.keras.layers.Input(tensor = input_tensor, shape = input_shape) 114 | else: 115 | img_input = input_tensor 116 | 117 | out = ResNet(img_input, [3, 24, 36, 3], classes, include_top, radix = 2, group_size = 1, block_width = 64, stem_width = 64, deep_stem = True, avg_down = True, avd = True, avd_first = False) 118 | model = tf.keras.Model(img_input, out) 119 | 120 | if weights == "imagenet": 121 | load_weight(model, resnest_model_urls["resnest200"], group_size = 2 * 1) 122 | elif weights is not None: 123 | model.load_weights(weights) 124 | return model 125 | 126 | def resnest269(input_tensor = None, input_shape = None, classes = 1000, include_top = True, weights = "imagenet"): 127 | if input_tensor is None: 128 | img_input = tf.keras.layers.Input(shape = input_shape) 129 | else: 130 | if not tf.keras.backend.is_keras_tensor(input_tensor): 131 | img_input = tf.keras.layers.Input(tensor = input_tensor, shape = input_shape) 132 | else: 133 | img_input = input_tensor 134 | 135 | out = ResNet(img_input, [3, 30, 48, 8], classes, include_top, radix = 2, group_size = 1, block_width = 64, stem_width = 64, deep_stem = True, avg_down = True, avd = True, avd_first = False) 136 | model = tf.keras.Model(img_input, out) 137 | 138 | if weights == "imagenet": 139 | load_weight(model, resnest_model_urls["resnest269"], group_size = 2 * 1) 140 | elif weights is not None: 141 | model.load_weights(weights) 142 | return model 143 | -------------------------------------------------------------------------------- /resnest/resnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from .splat import * 3 | 4 | def resnest_block(x, n_filter, stride_size = 1, dilation = 1, group_size = 1, radix = 1, block_width = 64, avd = False, avd_first = False, downsample = None, dropout_rate = 0., expansion = 4, is_first = False, stage = 1, index = 1): 5 | avd = 
avd and (1 < stride_size or is_first) 6 | group_width = int(n_filter * (block_width / 64)) * group_size 7 | 8 | out = tf.keras.layers.Conv2D(group_width, 1, padding = "same", use_bias = False, kernel_initializer = "he_normal", name = "stage{0}_block{1}_conv1".format(stage, index))(x) 9 | out = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.9, epsilon = 1e-5, name = "stage{0}_block{1}_bn1".format(stage, index))(out) 10 | if 0 < dropout_rate: 11 | out = tf.keras.layers.Dropout(dropout_rate, name = "stage{0}_block{1}_dropout1".format(stage, index))(out) 12 | out = tf.keras.layers.Activation("relu", name = "stage{0}_block{1}_act1".format(stage, index))(out) 13 | 14 | if avd: 15 | avd_layer = tf.keras.layers.AveragePooling2D(3, strides = stride_size, padding = "same", name = "stage{0}_block{1}_avd".format(stage, index)) 16 | stride_size = 1 17 | if avd_first: 18 | out = avd_layer(out) 19 | 20 | if 0 < radix: 21 | out = split_attention_block(out, group_width, 3, stride_size, dilation, group_size, radix, dropout_rate, expansion, prefix = "stage{0}_block{1}".format(stage, index)) 22 | else: 23 | out = tf.keras.layers.Conv2D(group_width, 3, strides = stride_size, dilation_rate = dilation, padding = "same", use_bias = False, kernel_initializer = "he_normal", name = "stage{0}_block{1}_conv2".format(stage, index))(out) 24 | out = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.9, epsilon = 1e-5, name = "stage{0}_block{1}_bn2".format(stage, index))(out) 25 | if 0 < dropout_rate: 26 | out = tf.keras.layers.Dropout(dropout_rate, name = "stage{0}_block{1}_dropout2".format(stage, index))(out) 27 | out = tf.keras.layers.Activation("relu", name = "stage{0}_block{1}_act2".format(stage, index))(out) 28 | 29 | if avd and not avd_first: 30 | out = avd_layer(out) 31 | 32 | out = tf.keras.layers.Conv2D(n_filter * expansion, 1, padding = "same", use_bias = False, kernel_initializer = "he_normal", name = "stage{0}_block{1}_conv3".format(stage, index))(out) 33 | out = 
tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.9, epsilon = 1e-5, name = "stage{0}_block{1}_bn3".format(stage, index))(out) 34 | if 0 < dropout_rate: 35 | out = tf.keras.layers.Dropout(dropout_rate, name = "stage{0}_block{1}_dropout3".format(stage, index))(out) 36 | residual = x 37 | if downsample is not None: 38 | residual = downsample 39 | out = tf.keras.layers.Add(name = "stage{0}_block{1}_shorcut".format(stage, index))([out, residual]) 40 | out = tf.keras.layers.Activation(tf.keras.activations.relu, name = "stage{0}_block{1}_shorcut_act".format(stage, index))(out) 41 | return out 42 | 43 | def resnest_module(x, n_filter, n_block, stride_size = 1, dilation = 1, group_size = 1, radix = 1, block_width = 64, avg_down = True, avd = False, avd_first = False, dropout_rate = 0., expansion = 4, is_first = True, stage = 1): 44 | downsample = None 45 | if stride_size != 1 or tf.keras.backend.int_shape(x)[-1] != (n_filter * expansion): 46 | if avg_down: 47 | if dilation == 1: 48 | downsample = tf.keras.layers.AveragePooling2D(stride_size, strides = stride_size, padding = "same", name = "stage{0}_downsample_avgpool".format(stage))(x) 49 | else: 50 | downsample = tf.keras.layers.AveragePooling2D(1, strides = 1, padding = "same", name = "stage{0}_downsample_avgpool".format(stage))(x) 51 | downsample = tf.keras.layers.Conv2D(n_filter * expansion, 1, padding = "same", use_bias = False, kernel_initializer = "he_normal", name = "stage{0}_downsample_conv1".format(stage))(downsample) 52 | downsample = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.9, epsilon = 1e-5, name = "stage{0}_downsample_bn1".format(stage))(downsample) 53 | else: 54 | downsample = tf.keras.layers.Conv2D(n_filter * expansion, 1, strides = stride_size, padding = "same", use_bias = False, kernel_initializer = "he_normal", name = "stage{0}_downsample_conv1".format(stage))(x) 55 | downsample = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.9, epsilon = 1e-5, name = 
"stage{0}_downsample_bn1".format(stage))(downsample) 56 | 57 | if dilation == 1 or dilation == 2 or dilation == 4: 58 | out = resnest_block(x, n_filter, stride_size, 2 ** (dilation // 4), group_size, radix, block_width, avd, avd_first, downsample, dropout_rate, expansion, is_first, stage = stage) 59 | else: 60 | raise ValueError("unknown dilation size '{0}'".format(dilation)) 61 | 62 | for index in range(1, n_block): 63 | out = resnest_block(out, n_filter, 1, dilation, group_size, radix, block_width, avd, avd_first, dropout_rate = dropout_rate, expansion = expansion, stage = stage, index = index + 1) 64 | return out 65 | 66 | def ResNet(x, stack, n_class = 1000, include_top = True, dilation = 1, group_size =1, radix = 1, block_width = 64, stem_width = 64, deep_stem = False, dilated = False, avg_down = False, avd = False, avd_first = False, dropout_rate = 0., expansion = 4): 67 | #Stem 68 | if deep_stem: 69 | out = tf.keras.layers.Conv2D(stem_width, 3, strides = 2, padding = "same", use_bias = False, kernel_initializer = "he_normal", name = "stem_conv1")(x) 70 | out = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.9, epsilon = 1e-5, name = "stem_bn1")(out) 71 | out = tf.keras.layers.Activation("relu", name = "stem_act1")(out) 72 | out = tf.keras.layers.Conv2D(stem_width, 3, padding = "same", use_bias = False, kernel_initializer = "he_normal", name = "stem_conv2")(out) 73 | out = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.9, epsilon = 1e-5, name = "stem_bn2")(out) 74 | out = tf.keras.layers.Activation("relu", name = "stem_act2")(out) 75 | out = tf.keras.layers.Conv2D(stem_width * 2, 3, padding = "same", use_bias = False, kernel_initializer = "he_normal", name = "stem_conv3")(out) 76 | out = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.9, epsilon = 1e-5, name = "stem_bn3")(out) 77 | out = tf.keras.layers.Activation("relu", name = "stem_act3")(out) 78 | else: 79 | out = tf.keras.layers.Conv2D(64, 7, strides = 2, padding = 
"same", use_bias = False, kernel_initializer = "he_normal", name = "stem_conv1")(x) 80 | out = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.9, epsilon = 1e-5, name = "stem_bn1")(out) 81 | out = tf.keras.layers.Activation("relu", name = "stem_act1")(out) 82 | out = tf.keras.layers.MaxPool2D(3, strides = 2, padding = "same", name = "stem_pooling")(out) 83 | 84 | #Stage 1 85 | out = resnest_module(out, 64, stack[0], 1, 1, group_size, radix, block_width, avg_down, avd, avd_first, expansion = expansion, is_first = False, stage = 1) 86 | #Stage 2 87 | out = resnest_module(out, 128, stack[1], 2, 1, group_size, radix, block_width, avg_down, avd, avd_first, expansion = expansion, stage = 2) 88 | 89 | if dilated or dilation == 4: 90 | dilation = [2, 4] 91 | stride_size = [1, 1] 92 | elif dilation == 2: 93 | dilation = [1, 2] 94 | stride_size = [2, 1] 95 | else: 96 | dilation = [1, 1] 97 | stride_size = [2, 2] 98 | 99 | #Stage 3 100 | out = resnest_module(out, 256, stack[2], stride_size[0], dilation[0], group_size, radix, block_width, avg_down, avd, avd_first, dropout_rate, expansion, stage = 3) 101 | #Stage 4 102 | out = resnest_module(out, 512, stack[3], stride_size[1],dilation[1], group_size, radix, block_width, avg_down, avd, avd_first, dropout_rate, expansion, stage = 4) 103 | 104 | if include_top: 105 | out = tf.keras.layers.GlobalAveragePooling2D(name = "feature_avg_pool")(out) 106 | out = tf.keras.layers.Dense(n_class, activation = tf.keras.activations.softmax, name = "logits")(out) 107 | return out 108 | -------------------------------------------------------------------------------- /resnest/splat.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def group_conv(x, filters = None, kernel_size = 3, **kwargs): 4 | if not isinstance(kernel_size, list): 5 | kernel_size = [kernel_size] 6 | n_group = len(kernel_size) 7 | if n_group == 1: 8 | out = [x] 9 | else: 10 | size = 
tf.keras.backend.int_shape(x)[-1] 11 | split_size = [size // n_group if index != 0 else size // n_group + size % n_group for index in range(n_group)] 12 | out = tf.split(x, split_size, axis = -1) 13 | 14 | name = None 15 | if "name" in kwargs: 16 | name = kwargs["name"] 17 | result = [] 18 | for index in range(n_group): 19 | kwargs["filters"] = filters // n_group 20 | if index == 0: 21 | kwargs["filters"] += filters % n_group 22 | kwargs["kernel_size"] = kernel_size[index] 23 | if name is not None and 1 < n_group: 24 | kwargs["name"] = "{0}_group{1}".format(name, index + 1) 25 | result.append(tf.keras.layers.Conv2D(**kwargs)(out[index])) 26 | if n_group == 1: 27 | out = result[0] 28 | else: 29 | out = tf.keras.layers.Concatenate(axis = -1, name = name)(result) 30 | return out 31 | 32 | def split_attention_block(x, n_filter, kernel_size = 3, stride_size = 1, dilation = 1, group_size = 1, radix = 1, dropout_rate = 0., expansion = 4, prefix = ""): 33 | if len(prefix) != 0: 34 | prefix += "_" 35 | out = group_conv(x, n_filter * radix, [kernel_size] * (group_size * radix), strides = stride_size, dilation_rate = dilation, padding = "same", use_bias = False, kernel_initializer = "he_normal", name = "{0}split_attention_conv1".format(prefix)) 36 | out = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.9, epsilon = 1e-5, name = "{0}split_attention_bn1".format(prefix))(out) 37 | if 0 < dropout_rate: 38 | out = tf.keras.layers.Dropout(dropout_rate, name = "{0}split_attention_dropout1".format(prefix))(out) 39 | out = tf.keras.layers.Activation(tf.keras.activations.relu, name = "{0}split_attention_act1".format(prefix))(out) 40 | 41 | inter_channel = max(tf.keras.backend.int_shape(x)[-1] * radix // expansion, 32) 42 | if 1 < radix: 43 | split = tf.split(out, radix, axis = -1) 44 | out = tf.keras.layers.Add(name = "{0}split_attention_add".format(prefix))(split) 45 | out = tf.keras.layers.GlobalAveragePooling2D(name = "{0}split_attention_gap".format(prefix))(out) 46 | out 
= tf.keras.layers.Reshape([1, 1, n_filter], name = "{0}split_attention_expand_dims".format(prefix))(out) 47 | 48 | out = group_conv(out, inter_channel, [1] * group_size, padding = "same", use_bias = True, kernel_initializer = "he_normal", name = "{0}split_attention_conv2".format(prefix)) 49 | out = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.9, epsilon = 1e-5, name = "{0}split_attention_bn2".format(prefix))(out) 50 | out = tf.keras.layers.Activation("relu", name = "{0}split_attention_act2".format(prefix))(out) 51 | out = group_conv(out, n_filter * radix, [1] * group_size, padding = "same", use_bias = True, kernel_initializer = "he_normal", name = "{0}split_attention_conv3".format(prefix)) 52 | 53 | #attention = rsoftmax(out, n_filter, radix, group_size) 54 | attention = rSoftMax(n_filter, radix, group_size, name = "{0}split_attention_softmax".format(prefix))(out) 55 | if 1 < radix: 56 | attention = tf.split(attention, radix, axis = -1) 57 | out = tf.keras.layers.Add(name = "{0}split_attention_out".format(prefix))([o * a for o, a in zip(split, attention)]) 58 | else: 59 | out = tf.keras.layers.Multiply(name = "{0}split_attention_out".format(prefix))([attention, out]) 60 | return out 61 | 62 | def rsoftmax(x, n_filter, radix, group_size): 63 | if 1 < radix: 64 | out = tf.keras.layers.Reshape([group_size, radix, n_filter // group_size])(x) 65 | out = tf.keras.layers.Permute([2, 1, 3])(out) 66 | out = tf.keras.layers.Lambda(lambda x: tf.nn.softmax(x, axis = 1))(out) 67 | out = tf.keras.layers.Reshape([1, 1, radix * n_filter])(out) 68 | else: 69 | out = tf.keras.layers.Activation(tf.keras.activations.sigmoid)(x) 70 | return out 71 | 72 | class rSoftMax(tf.keras.layers.Layer): 73 | def __init__(self, filters, radix, group_size, **kwargs): 74 | super(rSoftMax, self).__init__(**kwargs) 75 | 76 | self.filters = filters 77 | self.radix = radix 78 | self.group_size = group_size 79 | 80 | if 1 < radix: 81 | self.seq1 = tf.keras.layers.Reshape([group_size, radix, 
filters // group_size]) 82 | self.seq2 = tf.keras.layers.Permute([2, 1, 3]) 83 | self.seq3 = tf.keras.layers.Lambda(lambda x: tf.nn.softmax(x, axis = 1)) 84 | self.seq4 = tf.keras.layers.Reshape([1, 1, radix * filters]) 85 | self.seq = [self.seq1, self.seq2, self.seq3, self.seq4] 86 | else: 87 | self.seq1 = tf.keras.layers.Activation(tf.keras.activations.sigmoid) 88 | self.seq = [self.seq1] 89 | 90 | def call(self, inputs): 91 | out = inputs 92 | for l in self.seq: 93 | out = l(out) 94 | return out 95 | 96 | def get_config(self): 97 | config = super(rSoftMax, self).get_config() 98 | config["filters"] = self.filters 99 | config["radix"] = self.radix 100 | config["group_size"] = self.group_size 101 | return config 102 | -------------------------------------------------------------------------------- /usage.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": "https://localhost:8080/", 9 | "height": 51 10 | }, 11 | "id": "TAtrzndb7Vun", 12 | "outputId": "074362dc-5c41-4895-95b3-e8a11975bf50" 13 | }, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz\n", 20 | "170500096/170498071 [==============================] - 4s 0us/step\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "import tensorflow as tf\n", 26 | "\n", 27 | "def pipe(data, batch_size = 128, shuffle = False):\n", 28 | " dataset = tf.data.Dataset.from_tensor_slices(data)\n", 29 | " if shuffle:\n", 30 | " dataset = dataset.shuffle(buffer_size = batch_size * 10)\n", 31 | " dataset = dataset.batch(batch_size)\n", 32 | " #dataset = dataset.prefetch((batch_size * 2) + 1)\n", 33 | " dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)\n", 34 | " return dataset\n", 35 | "\n", 36 | "(tr_x, tr_y), (te_x, te_y) = 
tf.keras.datasets.cifar10.load_data()\n", 37 | "\n", 38 | "tr_x = tr_x * 1/255\n", 39 | "te_x = te_x * 1/255\n", 40 | "\n", 41 | "batch_size = 128\n", 42 | "\n", 43 | "tr_data = pipe((tr_x, tr_y), batch_size = batch_size, shuffle = True)\n", 44 | "te_data = pipe((te_x, te_y), batch_size = batch_size, shuffle = False)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": { 51 | "colab": { 52 | "base_uri": "https://localhost:8080/", 53 | "height": 83, 54 | "referenced_widgets": [ 55 | "f984345bb7124684a517b4231e5766ec", 56 | "d3656d36be6d4477a5bfd941fed8a886", 57 | "85e0d96918444df3878cc0a03f74beb2", 58 | "3914cdf644284fd98224d0716f63f206", 59 | "43b6d8c059e343899db1070a74e745f5", 60 | "8c066c9507e2410c9ede383e75e2b8ea", 61 | "cd152a0c3ce045ceb3c0391304486abd", 62 | "9be86dc9ed694d84b691b5507631a083" 63 | ] 64 | }, 65 | "id": "ySaeM_5H7Q8A", 66 | "outputId": "1b558a06-9ca4-4c40-fdce-76171393d4d0" 67 | }, 68 | "outputs": [ 69 | { 70 | "name": "stderr", 71 | "output_type": "stream", 72 | "text": [ 73 | "Downloading: \"https://s3.us-west-1.wasabisys.com/resnest/torch/resnest101-22405ba7.pth\" to /root/.cache/torch/hub/checkpoints/resnest101-22405ba7.pth\n" 74 | ] 75 | }, 76 | { 77 | "data": { 78 | "application/vnd.jupyter.widget-view+json": { 79 | "model_id": "f984345bb7124684a517b4231e5766ec", 80 | "version_major": 2, 81 | "version_minor": 0 82 | }, 83 | "text/plain": [ 84 | "HBox(children=(FloatProgress(value=0.0, max=193782911.0), HTML(value='')))" 85 | ] 86 | }, 87 | "metadata": { 88 | "tags": [] 89 | }, 90 | "output_type": "display_data" 91 | }, 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "import resnest\n", 102 | "\n", 103 | "model = resnest.resnest101(input_shape = (32, 32, 3), include_top = False, weights = \"imagenet\")\n", 104 | "\n", 105 | "flatten = tf.keras.layers.GlobalAveragePooling2D()(model.output)\n", 106 | "drop_out = 
tf.keras.layers.Dropout(0.5)(flatten)\n", 107 | "dense = tf.keras.layers.Dense(2048, activation = \"relu\")(drop_out)\n", 108 | "prediction = tf.keras.layers.Dense(10, activation = \"softmax\", name = \"prediction\")(dense)\n", 109 | "model = tf.keras.Model(model.input, prediction)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 3, 115 | "metadata": { 116 | "id": "cka3vwta8pmU" 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "loss = tf.keras.losses.sparse_categorical_crossentropy\n", 121 | "opt = tf.keras.optimizers.Adam(1e-4)\n", 122 | "metric = [tf.keras.metrics.sparse_categorical_accuracy]\n", 123 | "model.compile(loss = loss, optimizer = opt, metrics = metric)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 4, 129 | "metadata": { 130 | "colab": { 131 | "base_uri": "https://localhost:8080/", 132 | "height": 1000 133 | }, 134 | "id": "S2T8gk6z9iBH", 135 | "outputId": "a28287c1-21ea-4170-d840-9cdce6a3783c" 136 | }, 137 | "outputs": [ 138 | { 139 | "name": "stdout", 140 | "output_type": "stream", 141 | "text": [ 142 | "Epoch 1/50\n", 143 | "391/391 [==============================] - 56s 144ms/step - loss: 2.2179 - sparse_categorical_accuracy: 0.2226 - val_loss: 2.2187 - val_sparse_categorical_accuracy: 0.1797\n", 144 | "Epoch 2/50\n", 145 | "391/391 [==============================] - 53s 135ms/step - loss: 1.5344 - sparse_categorical_accuracy: 0.4758 - val_loss: 1.1841 - val_sparse_categorical_accuracy: 0.5939\n", 146 | "Epoch 3/50\n", 147 | "391/391 [==============================] - 52s 134ms/step - loss: 0.9328 - sparse_categorical_accuracy: 0.6882 - val_loss: 1.1895 - val_sparse_categorical_accuracy: 0.7486\n", 148 | "Epoch 4/50\n", 149 | "391/391 [==============================] - 52s 133ms/step - loss: 0.6582 - sparse_categorical_accuracy: 0.7799 - val_loss: 1.6012 - val_sparse_categorical_accuracy: 0.7847\n", 150 | "Epoch 5/50\n", 151 | "391/391 [==============================] - 52s 
133ms/step - loss: 0.4902 - sparse_categorical_accuracy: 0.8364 - val_loss: 2.7330 - val_sparse_categorical_accuracy: 0.8052\n", 152 | "Epoch 6/50\n", 153 | "391/391 [==============================] - 52s 133ms/step - loss: 0.3808 - sparse_categorical_accuracy: 0.8742 - val_loss: 0.9634 - val_sparse_categorical_accuracy: 0.8035\n", 154 | "Epoch 7/50\n", 155 | "391/391 [==============================] - 52s 133ms/step - loss: 0.3214 - sparse_categorical_accuracy: 0.8928 - val_loss: 2.9835 - val_sparse_categorical_accuracy: 0.7432\n", 156 | "Epoch 8/50\n", 157 | "391/391 [==============================] - 52s 133ms/step - loss: 0.3916 - sparse_categorical_accuracy: 0.8708 - val_loss: 0.8159 - val_sparse_categorical_accuracy: 0.8155\n", 158 | "Epoch 9/50\n", 159 | "391/391 [==============================] - 52s 134ms/step - loss: 0.2251 - sparse_categorical_accuracy: 0.9250 - val_loss: 0.6268 - val_sparse_categorical_accuracy: 0.8299\n", 160 | "Epoch 10/50\n", 161 | "391/391 [==============================] - 52s 133ms/step - loss: 0.1781 - sparse_categorical_accuracy: 0.9411 - val_loss: 0.6696 - val_sparse_categorical_accuracy: 0.8247\n", 162 | "Epoch 11/50\n", 163 | "391/391 [==============================] - 52s 133ms/step - loss: 0.1272 - sparse_categorical_accuracy: 0.9585 - val_loss: 0.7185 - val_sparse_categorical_accuracy: 0.8337\n", 164 | "Epoch 12/50\n", 165 | "391/391 [==============================] - 52s 133ms/step - loss: 0.0909 - sparse_categorical_accuracy: 0.9712 - val_loss: 0.7217 - val_sparse_categorical_accuracy: 0.8386\n", 166 | "Epoch 13/50\n", 167 | "391/391 [==============================] - 52s 133ms/step - loss: 0.0674 - sparse_categorical_accuracy: 0.9787 - val_loss: 0.7826 - val_sparse_categorical_accuracy: 0.8295\n", 168 | "Epoch 14/50\n", 169 | "391/391 [==============================] - 52s 133ms/step - loss: 0.0584 - sparse_categorical_accuracy: 0.9813 - val_loss: 0.9486 - val_sparse_categorical_accuracy: 0.8217\n", 170 | "Epoch 
15/50\n", 171 | "391/391 [==============================] - 53s 134ms/step - loss: 0.0929 - sparse_categorical_accuracy: 0.9708 - val_loss: 19.7243 - val_sparse_categorical_accuracy: 0.7750\n", 172 | "Epoch 16/50\n", 173 | "391/391 [==============================] - 52s 133ms/step - loss: 0.0867 - sparse_categorical_accuracy: 0.9736 - val_loss: 0.8127 - val_sparse_categorical_accuracy: 0.8403\n", 174 | "Epoch 17/50\n", 175 | "391/391 [==============================] - 52s 133ms/step - loss: 0.0738 - sparse_categorical_accuracy: 0.9777 - val_loss: 0.7319 - val_sparse_categorical_accuracy: 0.8357\n", 176 | "Epoch 18/50\n", 177 | "391/391 [==============================] - 52s 133ms/step - loss: 0.0698 - sparse_categorical_accuracy: 0.9786 - val_loss: 0.7287 - val_sparse_categorical_accuracy: 0.8392\n", 178 | "Epoch 19/50\n", 179 | "391/391 [==============================] - 52s 133ms/step - loss: 0.0721 - sparse_categorical_accuracy: 0.9777 - val_loss: 3.6693 - val_sparse_categorical_accuracy: 0.8262\n", 180 | "Epoch 20/50\n", 181 | "391/391 [==============================] - 52s 133ms/step - loss: 0.1582 - sparse_categorical_accuracy: 0.9514 - val_loss: 0.6422 - val_sparse_categorical_accuracy: 0.8316\n", 182 | "Epoch 21/50\n", 183 | "391/391 [==============================] - 52s 133ms/step - loss: 0.0749 - sparse_categorical_accuracy: 0.9756 - val_loss: 0.6634 - val_sparse_categorical_accuracy: 0.8520\n", 184 | "Epoch 22/50\n", 185 | "391/391 [==============================] - 52s 133ms/step - loss: 0.0395 - sparse_categorical_accuracy: 0.9870 - val_loss: 0.7186 - val_sparse_categorical_accuracy: 0.8540\n", 186 | "Epoch 23/50\n", 187 | "391/391 [==============================] - 52s 133ms/step - loss: 0.0308 - sparse_categorical_accuracy: 0.9903 - val_loss: 0.6969 - val_sparse_categorical_accuracy: 0.8564\n", 188 | "Epoch 24/50\n", 189 | "391/391 [==============================] - 52s 133ms/step - loss: 0.0293 - sparse_categorical_accuracy: 0.9907 - val_loss: 
0.7245 - val_sparse_categorical_accuracy: 0.8536\n", 190 | "Epoch 25/50\n", 191 | "391/391 [==============================] - 52s 133ms/step - loss: 0.0327 - sparse_categorical_accuracy: 0.9894 - val_loss: 0.7723 - val_sparse_categorical_accuracy: 0.8448\n", 192 | "Epoch 26/50\n", 193 | "391/391 [==============================] - 52s 133ms/step - loss: 0.0476 - sparse_categorical_accuracy: 0.9857 - val_loss: 0.8425 - val_sparse_categorical_accuracy: 0.8447\n", 194 | "Epoch 27/50\n", 195 | "391/391 [==============================] - 52s 133ms/step - loss: 0.0388 - sparse_categorical_accuracy: 0.9875 - val_loss: 0.7098 - val_sparse_categorical_accuracy: 0.8571\n", 196 | "Epoch 28/50\n", 197 | "391/391 [==============================] - 52s 133ms/step - loss: 0.0364 - sparse_categorical_accuracy: 0.9894 - val_loss: 0.6613 - val_sparse_categorical_accuracy: 0.8568\n", 198 | "Epoch 29/50\n", 199 | "391/391 [==============================] - 52s 134ms/step - loss: 0.0628 - sparse_categorical_accuracy: 0.9801 - val_loss: 0.6780 - val_sparse_categorical_accuracy: 0.8460\n", 200 | "Epoch 30/50\n", 201 | "391/391 [==============================] - 53s 135ms/step - loss: 0.0500 - sparse_categorical_accuracy: 0.9857 - val_loss: 0.6705 - val_sparse_categorical_accuracy: 0.8485\n", 202 | "Epoch 31/50\n", 203 | "391/391 [==============================] - 53s 135ms/step - loss: 0.0423 - sparse_categorical_accuracy: 0.9873 - val_loss: 0.6592 - val_sparse_categorical_accuracy: 0.8581\n", 204 | "Epoch 32/50\n", 205 | "391/391 [==============================] - 53s 136ms/step - loss: 0.0304 - sparse_categorical_accuracy: 0.9907 - val_loss: 0.6420 - val_sparse_categorical_accuracy: 0.8572\n", 206 | "Epoch 33/50\n", 207 | "391/391 [==============================] - 53s 136ms/step - loss: 0.0283 - sparse_categorical_accuracy: 0.9912 - val_loss: 0.6432 - val_sparse_categorical_accuracy: 0.8596\n", 208 | "Epoch 34/50\n", 209 | "391/391 [==============================] - 53s 135ms/step - 
loss: 0.0306 - sparse_categorical_accuracy: 0.9906 - val_loss: 0.6413 - val_sparse_categorical_accuracy: 0.8530\n", 210 | "Epoch 35/50\n", 211 | "391/391 [==============================] - 53s 137ms/step - loss: 0.0310 - sparse_categorical_accuracy: 0.9902 - val_loss: 0.6644 - val_sparse_categorical_accuracy: 0.8527\n", 212 | "Epoch 36/50\n", 213 | "391/391 [==============================] - 53s 136ms/step - loss: 0.0327 - sparse_categorical_accuracy: 0.9901 - val_loss: 0.6512 - val_sparse_categorical_accuracy: 0.8576\n", 214 | "Epoch 37/50\n", 215 | "391/391 [==============================] - 53s 137ms/step - loss: 0.0290 - sparse_categorical_accuracy: 0.9908 - val_loss: 0.6451 - val_sparse_categorical_accuracy: 0.8604\n", 216 | "Epoch 38/50\n", 217 | "391/391 [==============================] - 53s 135ms/step - loss: 0.0284 - sparse_categorical_accuracy: 0.9917 - val_loss: 0.6429 - val_sparse_categorical_accuracy: 0.8595\n", 218 | "Epoch 39/50\n", 219 | "391/391 [==============================] - 53s 135ms/step - loss: 0.0324 - sparse_categorical_accuracy: 0.9897 - val_loss: 0.6404 - val_sparse_categorical_accuracy: 0.8556\n", 220 | "Epoch 40/50\n", 221 | "391/391 [==============================] - 53s 136ms/step - loss: 0.0343 - sparse_categorical_accuracy: 0.9903 - val_loss: 0.9350 - val_sparse_categorical_accuracy: 0.8611\n", 222 | "Epoch 41/50\n", 223 | "391/391 [==============================] - 53s 136ms/step - loss: 0.0420 - sparse_categorical_accuracy: 0.9874 - val_loss: 0.6213 - val_sparse_categorical_accuracy: 0.8590\n", 224 | "Epoch 42/50\n", 225 | "391/391 [==============================] - 53s 136ms/step - loss: 0.0284 - sparse_categorical_accuracy: 0.9915 - val_loss: 0.6547 - val_sparse_categorical_accuracy: 0.8583\n", 226 | "Epoch 43/50\n", 227 | "391/391 [==============================] - 53s 135ms/step - loss: 0.0200 - sparse_categorical_accuracy: 0.9937 - val_loss: 0.6317 - val_sparse_categorical_accuracy: 0.8686\n", 228 | "Epoch 44/50\n", 229 | 
"391/391 [==============================] - 53s 135ms/step - loss: 0.0188 - sparse_categorical_accuracy: 0.9942 - val_loss: 0.6605 - val_sparse_categorical_accuracy: 0.8662\n", 230 | "Epoch 45/50\n", 231 | "391/391 [==============================] - 52s 134ms/step - loss: 0.0233 - sparse_categorical_accuracy: 0.9925 - val_loss: 0.7771 - val_sparse_categorical_accuracy: 0.8482\n", 232 | "Epoch 46/50\n", 233 | "391/391 [==============================] - 52s 134ms/step - loss: 0.0260 - sparse_categorical_accuracy: 0.9922 - val_loss: 0.6526 - val_sparse_categorical_accuracy: 0.8621\n", 234 | "Epoch 47/50\n", 235 | "391/391 [==============================] - 52s 134ms/step - loss: 0.0245 - sparse_categorical_accuracy: 0.9926 - val_loss: 0.7785 - val_sparse_categorical_accuracy: 0.8477\n", 236 | "Epoch 48/50\n", 237 | "391/391 [==============================] - 52s 134ms/step - loss: 0.0294 - sparse_categorical_accuracy: 0.9915 - val_loss: 0.6483 - val_sparse_categorical_accuracy: 0.8550\n", 238 | "Epoch 49/50\n", 239 | "391/391 [==============================] - 52s 134ms/step - loss: 0.0250 - sparse_categorical_accuracy: 0.9926 - val_loss: 0.7692 - val_sparse_categorical_accuracy: 0.8661\n", 240 | "Epoch 50/50\n", 241 | "391/391 [==============================] - 52s 134ms/step - loss: 0.0223 - sparse_categorical_accuracy: 0.9931 - val_loss: 0.6120 - val_sparse_categorical_accuracy: 0.8650\n" 242 | ] 243 | }, 244 | { 245 | "data": { 246 | "text/plain": [ 247 | "" 248 | ] 249 | }, 250 | "execution_count": 4, 251 | "metadata": { 252 | "tags": [] 253 | }, 254 | "output_type": "execute_result" 255 | } 256 | ], 257 | "source": [ 258 | "model.fit(tr_data, validation_data = te_data, epochs = 50)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 5, 264 | "metadata": { 265 | "id": "tXn4W1lqhbbf" 266 | }, 267 | "outputs": [], 268 | "source": [ 269 | "with open(\"model.json\", mode = \"w\") as file:\n", 270 | " file.write(model.to_json())\n", 271 | 
"model.save_weights(\"model.h5\")" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 6, 277 | "metadata": { 278 | "id": "9Rx7ssmeh167" 279 | }, 280 | "outputs": [], 281 | "source": [ 282 | "with open(\"model.json\", mode = \"r\") as file:\n", 283 | " model = tf.keras.models.model_from_json(file.read(), {\"rSoftMax\":resnest.rSoftMax})\n", 284 | "model.load_weights(\"model.h5\")" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 7, 290 | "metadata": { 291 | "colab": { 292 | "base_uri": "https://localhost:8080/", 293 | "height": 51 294 | }, 295 | "id": "JTBoj3vrFLca", 296 | "outputId": "1eb30d0c-5961-4019-d946-7c747f4b7461" 297 | }, 298 | "outputs": [ 299 | { 300 | "name": "stdout", 301 | "output_type": "stream", 302 | "text": [ 303 | "79/79 [==============================] - 3s 37ms/step - loss: 0.6120 - sparse_categorical_accuracy: 0.8650\n" 304 | ] 305 | }, 306 | { 307 | "data": { 308 | "text/plain": [ 309 | "[0.6120349168777466, 0.8650000095367432]" 310 | ] 311 | }, 312 | "execution_count": 7, 313 | "metadata": { 314 | "tags": [] 315 | }, 316 | "output_type": "execute_result" 317 | } 318 | ], 319 | "source": [ 320 | "loss = tf.keras.losses.sparse_categorical_crossentropy\n", 321 | "metric = [tf.keras.metrics.sparse_categorical_accuracy]\n", 322 | "model.compile(loss = loss, metrics = metric)\n", 323 | "model.evaluate(te_data)" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": null, 329 | "metadata": { 330 | "id": "MsTu5aaRHY9d" 331 | }, 332 | "outputs": [], 333 | "source": [] 334 | } 335 | ], 336 | "metadata": { 337 | "accelerator": "GPU", 338 | "colab": { 339 | "collapsed_sections": [], 340 | "name": "usage.ipynb", 341 | "provenance": [] 342 | }, 343 | "kernelspec": { 344 | "display_name": "Python 3", 345 | "language": "python", 346 | "name": "python3" 347 | }, 348 | "language_info": { 349 | "codemirror_mode": { 350 | "name": "ipython", 351 | "version": 3 352 | }, 353 | 
"file_extension": ".py", 354 | "mimetype": "text/x-python", 355 | "name": "python", 356 | "nbconvert_exporter": "python", 357 | "pygments_lexer": "ipython3", 358 | "version": "3.7.6" 359 | }, 360 | "widgets": { 361 | "application/vnd.jupyter.widget-state+json": { 362 | "3914cdf644284fd98224d0716f63f206": { 363 | "model_module": "@jupyter-widgets/controls", 364 | "model_name": "HTMLModel", 365 | "state": { 366 | "_dom_classes": [], 367 | "_model_module": "@jupyter-widgets/controls", 368 | "_model_module_version": "1.5.0", 369 | "_model_name": "HTMLModel", 370 | "_view_count": null, 371 | "_view_module": "@jupyter-widgets/controls", 372 | "_view_module_version": "1.5.0", 373 | "_view_name": "HTMLView", 374 | "description": "", 375 | "description_tooltip": null, 376 | "layout": "IPY_MODEL_9be86dc9ed694d84b691b5507631a083", 377 | "placeholder": "​", 378 | "style": "IPY_MODEL_cd152a0c3ce045ceb3c0391304486abd", 379 | "value": " 185M/185M [00:02<00:00, 73.8MB/s]" 380 | } 381 | }, 382 | "43b6d8c059e343899db1070a74e745f5": { 383 | "model_module": "@jupyter-widgets/controls", 384 | "model_name": "ProgressStyleModel", 385 | "state": { 386 | "_model_module": "@jupyter-widgets/controls", 387 | "_model_module_version": "1.5.0", 388 | "_model_name": "ProgressStyleModel", 389 | "_view_count": null, 390 | "_view_module": "@jupyter-widgets/base", 391 | "_view_module_version": "1.2.0", 392 | "_view_name": "StyleView", 393 | "bar_color": null, 394 | "description_width": "initial" 395 | } 396 | }, 397 | "85e0d96918444df3878cc0a03f74beb2": { 398 | "model_module": "@jupyter-widgets/controls", 399 | "model_name": "FloatProgressModel", 400 | "state": { 401 | "_dom_classes": [], 402 | "_model_module": "@jupyter-widgets/controls", 403 | "_model_module_version": "1.5.0", 404 | "_model_name": "FloatProgressModel", 405 | "_view_count": null, 406 | "_view_module": "@jupyter-widgets/controls", 407 | "_view_module_version": "1.5.0", 408 | "_view_name": "ProgressView", 409 | "bar_style": "success", 
410 | "description": "100%", 411 | "description_tooltip": null, 412 | "layout": "IPY_MODEL_8c066c9507e2410c9ede383e75e2b8ea", 413 | "max": 193782911, 414 | "min": 0, 415 | "orientation": "horizontal", 416 | "style": "IPY_MODEL_43b6d8c059e343899db1070a74e745f5", 417 | "value": 193782911 418 | } 419 | }, 420 | "8c066c9507e2410c9ede383e75e2b8ea": { 421 | "model_module": "@jupyter-widgets/base", 422 | "model_name": "LayoutModel", 423 | "state": { 424 | "_model_module": "@jupyter-widgets/base", 425 | "_model_module_version": "1.2.0", 426 | "_model_name": "LayoutModel", 427 | "_view_count": null, 428 | "_view_module": "@jupyter-widgets/base", 429 | "_view_module_version": "1.2.0", 430 | "_view_name": "LayoutView", 431 | "align_content": null, 432 | "align_items": null, 433 | "align_self": null, 434 | "border": null, 435 | "bottom": null, 436 | "display": null, 437 | "flex": null, 438 | "flex_flow": null, 439 | "grid_area": null, 440 | "grid_auto_columns": null, 441 | "grid_auto_flow": null, 442 | "grid_auto_rows": null, 443 | "grid_column": null, 444 | "grid_gap": null, 445 | "grid_row": null, 446 | "grid_template_areas": null, 447 | "grid_template_columns": null, 448 | "grid_template_rows": null, 449 | "height": null, 450 | "justify_content": null, 451 | "justify_items": null, 452 | "left": null, 453 | "margin": null, 454 | "max_height": null, 455 | "max_width": null, 456 | "min_height": null, 457 | "min_width": null, 458 | "object_fit": null, 459 | "object_position": null, 460 | "order": null, 461 | "overflow": null, 462 | "overflow_x": null, 463 | "overflow_y": null, 464 | "padding": null, 465 | "right": null, 466 | "top": null, 467 | "visibility": null, 468 | "width": null 469 | } 470 | }, 471 | "9be86dc9ed694d84b691b5507631a083": { 472 | "model_module": "@jupyter-widgets/base", 473 | "model_name": "LayoutModel", 474 | "state": { 475 | "_model_module": "@jupyter-widgets/base", 476 | "_model_module_version": "1.2.0", 477 | "_model_name": "LayoutModel", 478 | 
"_view_count": null, 479 | "_view_module": "@jupyter-widgets/base", 480 | "_view_module_version": "1.2.0", 481 | "_view_name": "LayoutView", 482 | "align_content": null, 483 | "align_items": null, 484 | "align_self": null, 485 | "border": null, 486 | "bottom": null, 487 | "display": null, 488 | "flex": null, 489 | "flex_flow": null, 490 | "grid_area": null, 491 | "grid_auto_columns": null, 492 | "grid_auto_flow": null, 493 | "grid_auto_rows": null, 494 | "grid_column": null, 495 | "grid_gap": null, 496 | "grid_row": null, 497 | "grid_template_areas": null, 498 | "grid_template_columns": null, 499 | "grid_template_rows": null, 500 | "height": null, 501 | "justify_content": null, 502 | "justify_items": null, 503 | "left": null, 504 | "margin": null, 505 | "max_height": null, 506 | "max_width": null, 507 | "min_height": null, 508 | "min_width": null, 509 | "object_fit": null, 510 | "object_position": null, 511 | "order": null, 512 | "overflow": null, 513 | "overflow_x": null, 514 | "overflow_y": null, 515 | "padding": null, 516 | "right": null, 517 | "top": null, 518 | "visibility": null, 519 | "width": null 520 | } 521 | }, 522 | "cd152a0c3ce045ceb3c0391304486abd": { 523 | "model_module": "@jupyter-widgets/controls", 524 | "model_name": "DescriptionStyleModel", 525 | "state": { 526 | "_model_module": "@jupyter-widgets/controls", 527 | "_model_module_version": "1.5.0", 528 | "_model_name": "DescriptionStyleModel", 529 | "_view_count": null, 530 | "_view_module": "@jupyter-widgets/base", 531 | "_view_module_version": "1.2.0", 532 | "_view_name": "StyleView", 533 | "description_width": "" 534 | } 535 | }, 536 | "d3656d36be6d4477a5bfd941fed8a886": { 537 | "model_module": "@jupyter-widgets/base", 538 | "model_name": "LayoutModel", 539 | "state": { 540 | "_model_module": "@jupyter-widgets/base", 541 | "_model_module_version": "1.2.0", 542 | "_model_name": "LayoutModel", 543 | "_view_count": null, 544 | "_view_module": "@jupyter-widgets/base", 545 | "_view_module_version": 
"1.2.0", 546 | "_view_name": "LayoutView", 547 | "align_content": null, 548 | "align_items": null, 549 | "align_self": null, 550 | "border": null, 551 | "bottom": null, 552 | "display": null, 553 | "flex": null, 554 | "flex_flow": null, 555 | "grid_area": null, 556 | "grid_auto_columns": null, 557 | "grid_auto_flow": null, 558 | "grid_auto_rows": null, 559 | "grid_column": null, 560 | "grid_gap": null, 561 | "grid_row": null, 562 | "grid_template_areas": null, 563 | "grid_template_columns": null, 564 | "grid_template_rows": null, 565 | "height": null, 566 | "justify_content": null, 567 | "justify_items": null, 568 | "left": null, 569 | "margin": null, 570 | "max_height": null, 571 | "max_width": null, 572 | "min_height": null, 573 | "min_width": null, 574 | "object_fit": null, 575 | "object_position": null, 576 | "order": null, 577 | "overflow": null, 578 | "overflow_x": null, 579 | "overflow_y": null, 580 | "padding": null, 581 | "right": null, 582 | "top": null, 583 | "visibility": null, 584 | "width": null 585 | } 586 | }, 587 | "f984345bb7124684a517b4231e5766ec": { 588 | "model_module": "@jupyter-widgets/controls", 589 | "model_name": "HBoxModel", 590 | "state": { 591 | "_dom_classes": [], 592 | "_model_module": "@jupyter-widgets/controls", 593 | "_model_module_version": "1.5.0", 594 | "_model_name": "HBoxModel", 595 | "_view_count": null, 596 | "_view_module": "@jupyter-widgets/controls", 597 | "_view_module_version": "1.5.0", 598 | "_view_name": "HBoxView", 599 | "box_style": "", 600 | "children": [ 601 | "IPY_MODEL_85e0d96918444df3878cc0a03f74beb2", 602 | "IPY_MODEL_3914cdf644284fd98224d0716f63f206" 603 | ], 604 | "layout": "IPY_MODEL_d3656d36be6d4477a5bfd941fed8a886" 605 | } 606 | } 607 | } 608 | } 609 | }, 610 | "nbformat": 4, 611 | "nbformat_minor": 4 612 | } 613 | --------------------------------------------------------------------------------