├── .gitattributes
├── .idea
│   ├── .gitignore
│   ├── PV_Classify.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   └── modules.xml
├── Matrix.py
├── PVF-10.zip
├── Readme.txt
├── class_indices.json
├── fig
│   ├── __pycache__
│   │   ├── chart.cpython-38.pyc
│   │   └── con_mat_fig.cpython-38.pyc
│   ├── chart.py
│   ├── con mat.py
│   ├── con_mat_10.png
│   ├── con_mat_compare_samp_pad.png
│   ├── con_mat_compare_samp_pad531.png
│   ├── con_mat_fig.py
│   ├── figure_temp.py
│   ├── read_jpg.py
│   ├── train_loss_fig.py
│   └── vitslog.csv
├── model
│   ├── 111.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── _builder.cpython-38.pyc
│   │   ├── _efficientnet_blocks.cpython-38.pyc
│   │   ├── _efficientnet_builder.cpython-38.pyc
│   │   ├── _factory.cpython-38.pyc
│   │   ├── _features.cpython-38.pyc
│   │   ├── _features_fx.cpython-38.pyc
│   │   ├── _helpers.cpython-38.pyc
│   │   ├── _hub.cpython-38.pyc
│   │   ├── _manipulate.cpython-38.pyc
│   │   ├── _pretrained.cpython-38.pyc
│   │   ├── _prune.cpython-38.pyc
│   │   ├── _registry.cpython-38.pyc
│   │   ├── coat.cpython-38.pyc
│   │   ├── efficientnet.cpython-38.pyc
│   │   ├── effnetv2.cpython-38.pyc
│   │   ├── resnet.cpython-38.pyc
│   │   ├── swin_transformer_v2.cpython-38.pyc
│   │   └── vision_transformer.cpython-38.pyc
│   ├── _builder.py
│   ├── _efficientnet_blocks.py
│   ├── _efficientnet_builder.py
│   ├── _factory.py
│   ├── _features.py
│   ├── _features_fx.py
│   ├── _helpers.py
│   ├── _hub.py
│   ├── _manipulate.py
│   ├── _pretrained.py
│   ├── _prune.py
│   ├── _registry.py
│   ├── coat.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── auto_augment.cpython-38.pyc
│   │   │   ├── config.cpython-38.pyc
│   │   │   ├── constants.cpython-38.pyc
│   │   │   ├── dataset.cpython-38.pyc
│   │   │   ├── dataset_factory.cpython-38.pyc
│   │   │   ├── dataset_info.cpython-38.pyc
│   │   │   ├── distributed_sampler.cpython-38.pyc
│   │   │   ├── imagenet_info.cpython-38.pyc
│   │   │   ├── loader.cpython-38.pyc
│   │   │   ├── mixup.cpython-38.pyc
│   │   │   ├── random_erasing.cpython-38.pyc
│   │   │   ├── real_labels.cpython-38.pyc
│   │   │   ├── transforms.cpython-38.pyc
│   │   │   └── transforms_factory.cpython-38.pyc
│   │   ├── _info
│   │   │   ├── imagenet12k_synsets.txt
│   │   │   ├── imagenet21k_goog_synsets.txt
│   │   │   ├── imagenet21k_goog_to_12k_indices.txt
│   │   │   ├── imagenet21k_goog_to_22k_indices.txt
│   │   │   ├── imagenet21k_miil_synsets.txt
│   │   │   ├── imagenet21k_miil_w21_synsets.txt
│   │   │   ├── imagenet22k_ms_synsets.txt
│   │   │   ├── imagenet22k_ms_to_12k_indices.txt
│   │   │   ├── imagenet22k_ms_to_22k_indices.txt
│   │   │   ├── imagenet22k_synsets.txt
│   │   │   ├── imagenet22k_to_12k_indices.txt
│   │   │   ├── imagenet_a_indices.txt
│   │   │   ├── imagenet_a_synsets.txt
│   │   │   ├── imagenet_r_indices.txt
│   │   │   ├── imagenet_r_synsets.txt
│   │   │   ├── imagenet_real_labels.json
│   │   │   ├── imagenet_synset_to_definition.txt
│   │   │   ├── imagenet_synset_to_lemma.txt
│   │   │   └── imagenet_synsets.txt
│   │   ├── auto_augment.py
│   │   ├── config.py
│   │   ├── constants.py
│   │   ├── dataset.py
│   │   ├── dataset_factory.py
│   │   ├── dataset_info.py
│   │   ├── distributed_sampler.py
│   │   ├── imagenet_info.py
│   │   ├── loader.py
│   │   ├── mixup.py
│   │   ├── random_erasing.py
│   │   ├── readers
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── class_map.cpython-38.pyc
│   │   │   │   ├── img_extensions.cpython-38.pyc
│   │   │   │   ├── reader.cpython-38.pyc
│   │   │   │   ├── reader_factory.cpython-38.pyc
│   │   │   │   ├── reader_image_folder.cpython-38.pyc
│   │   │   │   └── reader_image_in_tar.cpython-38.pyc
│   │   │   ├── class_map.py
│   │   │   ├── img_extensions.py
│   │   │   ├── reader.py
│   │   │   ├── reader_factory.py
│   │   │   ├── reader_hfds.py
│   │   │   ├── reader_image_folder.py
│   │   │   ├── reader_image_in_tar.py
│   │   │   ├── reader_image_tar.py
│   │   │   ├── reader_tfds.py
│   │   │   ├── reader_wds.py
│   │   │   └── shared_count.py
│   │   ├── real_labels.py
│   │   ├── tf_preprocessing.py
│   │   ├── transforms.py
│   │   └── transforms_factory.py
│   ├── efficientnet.py
│   ├── effnetv2.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── activations.cpython-38.pyc
│   │   │   ├── activations_jit.cpython-38.pyc
│   │   │   ├── activations_me.cpython-38.pyc
│   │   │   ├── adaptive_avgmax_pool.cpython-38.pyc
│   │   │   ├── attention_pool.cpython-38.pyc
│   │   │   ├── attention_pool2d.cpython-38.pyc
│   │   │   ├── blur_pool.cpython-38.pyc
│   │   │   ├── bottleneck_attn.cpython-38.pyc
│   │   │   ├── cbam.cpython-38.pyc
│   │   │   ├── classifier.cpython-38.pyc
│   │   │   ├── cond_conv2d.cpython-38.pyc
│   │   │   ├── config.cpython-38.pyc
│   │   │   ├── conv2d_same.cpython-38.pyc
│   │   │   ├── conv_bn_act.cpython-38.pyc
│   │   │   ├── create_act.cpython-38.pyc
│   │   │   ├── create_attn.cpython-38.pyc
│   │   │   ├── create_conv2d.cpython-38.pyc
│   │   │   ├── create_norm.cpython-38.pyc
│   │   │   ├── create_norm_act.cpython-38.pyc
│   │   │   ├── drop.cpython-38.pyc
│   │   │   ├── eca.cpython-38.pyc
│   │   │   ├── evo_norm.cpython-38.pyc
│   │   │   ├── fast_norm.cpython-38.pyc
│   │   │   ├── filter_response_norm.cpython-38.pyc
│   │   │   ├── format.cpython-38.pyc
│   │   │   ├── gather_excite.cpython-38.pyc
│   │   │   ├── global_context.cpython-38.pyc
│   │   │   ├── grn.cpython-38.pyc
│   │   │   ├── halo_attn.cpython-38.pyc
│   │   │   ├── helpers.cpython-38.pyc
│   │   │   ├── inplace_abn.cpython-38.pyc
│   │   │   ├── interpolate.cpython-38.pyc
│   │   │   ├── lambda_layer.cpython-38.pyc
│   │   │   ├── linear.cpython-38.pyc
│   │   │   ├── mixed_conv2d.cpython-38.pyc
│   │   │   ├── mlp.cpython-38.pyc
│   │   │   ├── non_local_attn.cpython-38.pyc
│   │   │   ├── norm.cpython-38.pyc
│   │   │   ├── norm_act.cpython-38.pyc
│   │   │   ├── padding.cpython-38.pyc
│   │   │   ├── patch_dropout.cpython-38.pyc
│   │   │   ├── patch_embed.cpython-38.pyc
│   │   │   ├── pool2d_same.cpython-38.pyc
│   │   │   ├── pos_embed.cpython-38.pyc
│   │   │   ├── pos_embed_rel.cpython-38.pyc
│   │   │   ├── pos_embed_sincos.cpython-38.pyc
│   │   │   ├── selective_kernel.cpython-38.pyc
│   │   │   ├── separable_conv.cpython-38.pyc
│   │   │   ├── space_to_depth.cpython-38.pyc
│   │   │   ├── split_attn.cpython-38.pyc
│   │   │   ├── split_batchnorm.cpython-38.pyc
│   │   │   ├── squeeze_excite.cpython-38.pyc
│   │   │   ├── std_conv.cpython-38.pyc
│   │   │   ├── test_time_pool.cpython-38.pyc
│   │   │   ├── trace_utils.cpython-38.pyc
│   │   │   ├── typing.cpython-38.pyc
│   │   │   └── weight_init.cpython-38.pyc
│   │   ├── activations.py
│   │   ├── activations_jit.py
│   │   ├── activations_me.py
│   │   ├── adaptive_avgmax_pool.py
│   │   ├── attention_pool.py
│   │   ├── attention_pool2d.py
│   │   ├── blur_pool.py
│   │   ├── bottleneck_attn.py
│   │   ├── cbam.py
│   │   ├── classifier.py
│   │   ├── cond_conv2d.py
│   │   ├── config.py
│   │   ├── conv2d_same.py
│   │   ├── conv_bn_act.py
│   │   ├── create_act.py
│   │   ├── create_attn.py
│   │   ├── create_conv2d.py
│   │   ├── create_norm.py
│   │   ├── create_norm_act.py
│   │   ├── drop.py
│   │   ├── eca.py
│   │   ├── evo_norm.py
│   │   ├── fast_norm.py
│   │   ├── filter_response_norm.py
│   │   ├── format.py
│   │   ├── gather_excite.py
│   │   ├── global_context.py
│   │   ├── grn.py
│   │   ├── halo_attn.py
│   │   ├── helpers.py
│   │   ├── inplace_abn.py
│   │   ├── interpolate.py
│   │   ├── lambda_layer.py
│   │   ├── linear.py
│   │   ├── median_pool.py
│   │   ├── mixed_conv2d.py
│   │   ├── ml_decoder.py
│   │   ├── mlp.py
│   │   ├── non_local_attn.py
│   │   ├── norm.py
│   │   ├── norm_act.py
│   │   ├── padding.py
│   │   ├── patch_dropout.py
│   │   ├── patch_embed.py
│   │   ├── pool2d_same.py
│   │   ├── pos_embed.py
│   │   ├── pos_embed_rel.py
│   │   ├── pos_embed_sincos.py
│   │   ├── selective_kernel.py
│   │   ├── separable_conv.py
│   │   ├── space_to_depth.py
│   │   ├── split_attn.py
│   │   ├── split_batchnorm.py
│   │   ├── squeeze_excite.py
│   │   ├── std_conv.py
│   │   ├── test_time_pool.py
│   │   ├── trace_utils.py
│   │   ├── typing.py
│   │   └── weight_init.py
│   ├── resnet.py
│   ├── swin_transformer_v2.py
│   └── vision_transformer.py
├── my_dataset.py
├── parameter.py
├── predict.py
├── train_val.py
└── utils.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | PVF-10.zip filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/PV_Classify.iml:
--------------------------------------------------------------------------------
(XML content not preserved in this export)
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
(XML content not preserved in this export)
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
(XML content not preserved in this export)
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
(XML content not preserved in this export)
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
(XML content not preserved in this export)
--------------------------------------------------------------------------------
/Matrix.py:
--------------------------------------------------------------------------------
1 | """
2 | @FileName:Matrix.py\n
3 | @Description:混淆矩阵\n
4 | @Author:WBobby\n
5 | @Department:CUG\n
6 | @Time:2023/7/12 10:05\n
7 | """
8 | import csv
9 |
10 | import numpy as np
11 | import pandas as pd
12 | from sklearn.metrics import confusion_matrix
13 |
14 |
15 | def read_csv(infile):
16 | df = pd.read_csv(infile)
17 | data = df.values.tolist()
18 | for row in data:
19 | image = row[1]
20 | row.append(image[:2])
21 | return data
22 |
23 | def matrix(data,outfile):
24 | y_true = [row[2] for row in data]
25 | # print(y_true)
26 | y_pred = [row[4] for row in data]
27 | # print(y_pred)
28 | confusion_mat = confusion_matrix(y_true, y_pred)
29 | # print(confusion_mat)
30 | return confusion_mat
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 | if __name__ == '__main__':
41 | infile = r'C:\Users\Wbobby\Desktop\BB.csv'
42 | outfile = ''
43 | data = read_csv(infile)
44 | matrix(data, outfile)
--------------------------------------------------------------------------------
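Matrix.py returns the raw count matrix; for reporting, the matrix is often row-normalized and labeled with the class names. A minimal sketch of that step, assuming the class_indices.json mapping included in this repo and that all 10 classes appear in the label lists (the helper name and output path are illustrative, not part of the original code):

import json

import pandas as pd
from sklearn.metrics import confusion_matrix


def labeled_confusion_matrix(y_true, y_pred, class_json, outfile):
    # Class names in index order, taken from class_indices.json.
    with open(class_json) as f:
        names = list(json.load(f).values())
    mat = confusion_matrix(y_true, y_pred)
    # Row-normalize so each row shows the per-class recall.
    mat = mat / mat.sum(axis=1, keepdims=True)
    df = pd.DataFrame(mat, index=names, columns=names)
    df.to_csv(outfile)
    return df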
/PVF-10.zip:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:5bbedd95fd99b98198fcea84017c9243af8d4e8aeef307a3bd77200919195189
3 | size 165094105
4 |
--------------------------------------------------------------------------------
/Readme.txt:
--------------------------------------------------------------------------------
1 | # PVF-Dataset (PVF-10)
2 | 1. This repository contains the code and dataset from the paper "Photovoltaic Fault Dataset (PVF-10): A High-resolution UAV Thermal Infrared Image Dataset for Fine-grained Photovoltaic Fault Classification".
3 | 2. The code and the dataset (PVF-10) are on the master branch.
4 | 3. Since the model weight files are too large for the repository, we host them on pan.baidu.com for interested researchers at: https://pan.baidu.com/s/1SfvW7jvkqhF5tu7J-EAzjw; Code: PVFs
5 |    train_val.py is the training script.
6 |    predict.py is the test script.
7 |    The model folder holds the classification models used in the paper.
8 |    The weights folder holds the weight files for the five models in the paper.
9 | 4. Due to the bandwidth limits of GitHub's Large File Storage (LFS), PVF-10 may not download properly from GitHub; in that case please use the Baidu mirror (link: https://pan.baidu.com/s/1LJVe2lkqvYnwTI8cuiESDg; Extract Code: PVFD) or the Google Drive link: https://drive.google.com/file/d/1SQq0hETXi8I3Kdq9tDAEVyZgIsRCbOah/view?usp=sharing
10 |
11 | If you use PVF-10 for related research, please cite our paper "PVF-10: A high-resolution unmanned aerial vehicle thermal infrared image dataset for fine-grained photovoltaic fault classification", with bibtex:
12 |
13 | @article{WANG2024124187,
14 |   title = {PVF-10: A high-resolution unmanned aerial vehicle thermal infrared image dataset for fine-grained photovoltaic fault classification},
15 |   journal = {Applied Energy},
16 |   volume = {376},
17 |   pages = {124187},
18 |   year = {2024},
19 |   issn = {0306-2619},
20 |   doi = {https://doi.org/10.1016/j.apenergy.2024.124187},
21 |   url = {https://www.sciencedirect.com/science/article/pii/S0306261924015708},
22 |   author = {Bo Wang and Qi Chen and Mengmeng Wang and Yuntian Chen and Zhengjia Zhang and Xiuguo Liu and Wei Gao and Yanzhen Zhang and Haoran Zhang},
23 |   keywords = {Photovoltaic fault, Thermal infrared data, Classification, Deep learning, Unmanned aerial vehicle},
24 |   abstract = {Accurate identification of faulty photovoltaic (PV) modules is crucial for the effective operation and maintenance of PV systems. Deep learning (DL) algorithms exhibit promising potential for classifying PV fault (PVF) from thermal infrared (TIR) images captured by unmanned aerial vehicle (UAV), contingent upon the availability of extensive and high-quality labeled data. However, existing TIR PVF datasets are limited by low image resolution and incomplete coverage of fault types. This study proposes a high-resolution TIR PVF dataset with 10 classes, named PVF-10, comprising 5579 cropped images of PV panels collected from 8 PV power plants. These classes are further categorized into two groups according to the repairability of PVF, with 5 repairable and 5 irreparable classes each. Additionally, the circuit mechanisms underlying the TIR image features of typical PVF types are analyzed, supported by high-resolution images, thereby providing comprehensive information for PV operators. Finally, five state-of-the-art DL algorithms are trained and validated based on the PVF-10 dataset using three levels of resampling strategy. The results show that the overall accuracy (OA) of these algorithms exceeds 83%, with the highest OA reaching 93.32%. Moreover, the preprocessing procedure involving resampling and padding strategies are beneficial for improving PVF classification accuracy using PVF-10 datasets. The developed PVF-10 dataset is expected to stimulate further research and innovation in PVF classification.}
25 | }
26 |
27 | Note: Data provided by the School of Geography and Information Engineering, China University of Geosciences.
28 |
--------------------------------------------------------------------------------
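For researchers starting from the downloaded weights, a minimal loading sketch (the architecture name and weight filename here are assumptions; substitute the files obtained from the links above):

import torch

from model import create_model  # timm-style factory exported by model/__init__.py

net = create_model('resnet50', num_classes=10)  # head sized for the 10 PVF-10 classes
state = torch.load('weights/resnet50_pvf10.pth', map_location='cpu')  # hypothetical filename
# If the checkpoint wraps the weights (e.g. under a 'model' key), unwrap it first.
net.load_state_dict(state)
net.eval()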
/class_indices.json:
--------------------------------------------------------------------------------
1 | {
2 |     "0": "01bottom dirt",
3 |     "1": "02break",
4 |     "2": "03Debris cover",
5 |     "3": "04junction box heat",
6 |     "4": "05hot cell",
7 |     "5": "06shadow",
8 |     "6": "07short circuit panel",
9 |     "7": "08string short circuit",
10 |     "8": "09substring open circuit",
11 |     "9": "10healthy panel"
12 | }
--------------------------------------------------------------------------------
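class_indices.json maps prediction indices back to class names. A sketch of how such a mapping is typically consumed at inference time (predict.py itself is not included in this export, so the surrounding code is illustrative):

import json

import torch

with open('class_indices.json') as f:
    class_indices = {int(k): v for k, v in json.load(f).items()}

logits = torch.randn(1, 10)                # stand-in for a model's output
pred = torch.argmax(logits, dim=1).item()
print(pred, class_indices[pred])           # e.g. 9 -> '10healthy panel'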
/fig/__pycache__/chart.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/fig/__pycache__/chart.cpython-38.pyc
--------------------------------------------------------------------------------
/fig/__pycache__/con_mat_fig.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/fig/__pycache__/con_mat_fig.cpython-38.pyc
--------------------------------------------------------------------------------
/fig/chart.py:
--------------------------------------------------------------------------------
1 | """
2 | @FileName:chart.py\n
3 | @Description:\n
4 | @Author:WBobby\n
5 | @Department:CUG\n
6 | @Time:2023/12/18 23:51\n
7 | """
8 | import matplotlib
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 |
12 | fig_name = 'Sample Number of PPF '
13 | figsize = (8, 8)
14 |
15 | matplotlib.rcParams['font.family'] = 'serif'
16 | matplotlib.rcParams['font.serif'] = ['Times New Roman'] + matplotlib.rcParams['font.serif']
17 | # (a) bypass diode heating, (b) substring disconnection, (c) debris covering, (d) panel breaking, (e) dusty covering,
18 | # (f) General hot spot, (g) health panel.
19 | fault_dic = {'a': 595, 'b': 427, 'c': 71, 'd': 410, 'e': 303, 'f': 377, 'g': 131,
20 | 'h': 800, 'i': 946, 'j': 1519}
21 |
22 | fault_names = list(fault_dic.keys())[::]
23 | fault_counts = list(fault_dic.values())[::]
24 |
25 |
26 | # colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink']
27 | colors = plt.cm.Wistia(np.linspace(1, 0.1, len(fault_names)))
28 | print(colors)
29 |
30 | fig, ax = plt.subplots(figsize=(6, 4))
31 | bar_width = 0.6
32 | bars = ax.bar(fault_names, fault_counts, color=colors, width=bar_width)
33 | # bars = ax.bar(fault_counts, fault_names, color=colors, height=bar_width)
34 |
35 | # plt.barh(fault_names, fault_counts, color=colors)
36 | plt.ylabel('Fault Count', fontsize=12, fontweight='bold')
37 | plt.xlabel('Fault Types', fontsize=12, fontweight='bold')
38 | # plt.title(fig_name, fontsize=25)
39 |
40 | for bar, x, count in zip(bars, fault_names, fault_counts):
41 | ax.text(x, bar.get_y() + bar.get_height(), str(count), ha='center', va='top', fontsize=14)
42 |
43 | ax.grid(axis='y', linestyle='--', alpha=0.7)
44 | ax.set_title(fig_name, fontsize=15, fontweight='bold')
45 | plt.yticks(fontsize=11, fontweight='bold')
46 | plt.xticks(fontsize=12, fontweight='bold')
47 | plt.tight_layout() # 调整布局
48 | plt.savefig(r'C:\Users\Wbobby\Documents\TeX_files\PVFC\tu/' + 'Sample3.png', dpi=300)
49 | plt.show()
50 |
--------------------------------------------------------------------------------
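The hard-coded counts in chart.py can be cross-checked against the dataset size stated in the Readme abstract (5579 cropped panel images):

fault_dic = {'a': 595, 'b': 427, 'c': 71, 'd': 410, 'e': 303, 'f': 377, 'g': 131,
             'h': 800, 'i': 946, 'j': 1519}
assert sum(fault_dic.values()) == 5579  # matches the paper's reported total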
/fig/con_mat_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/fig/con_mat_10.png
--------------------------------------------------------------------------------
/fig/con_mat_compare_samp_pad.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/fig/con_mat_compare_samp_pad.png
--------------------------------------------------------------------------------
/fig/con_mat_compare_samp_pad531.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/fig/con_mat_compare_samp_pad531.png
--------------------------------------------------------------------------------
/fig/figure_temp.py:
--------------------------------------------------------------------------------
1 | """
2 | @FileName:figure_temp.py\n
3 | @Description:python画图模板\n
4 | @Author:WBobby\n
5 | @Department:CUG\n
6 | @Time:2024/3/20 20:29\n
7 | """
8 |
9 |
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 |
13 | def plot_figure(feature1, feature2, filename):
14 | # 创建整块画布
15 | '''
16 | plt.figure() 是 matplotlib.pyplot 模块中的一个函数,用于创建一个新的图表画布(Figure对象)。这个函数有几个参数,可以用来配置图表的外观和行为。以下是一些常用的参数:
17 | figsize:一个元组,包含两个数值,分别代表图表的宽度和高度(以英寸为单位)。例如,figsize=(8, 6) 将创建一个宽8英寸和高6英寸的图表。
18 | dpi:图表的分辨率,表示每英寸多少个像素。默认值通常是 100。
19 | facecolor:图表画布的背景颜色。可以是一个颜色名、十六进制颜色码、RGB 或 RGBA 元组,或者预定义的颜色字符串。
20 | edgecolor:图表画布的边缘颜色。
21 | frameon:布尔值,表示是否在图表画布周围显示边框。默认值为 True。
22 | fig:可选的参数,用于指定要修改的现有Figure对象。如果没有提供,将创建一个新的Figure对象。
23 | constrained_layout:布尔值,表示是否使用受约束的布局。如果设置为 True,matplotlib 将尝试优化子图之间的空间分配。
24 | clear:布尔值,表示是否在创建新图表之前清除当前图表。默认值为 False。
25 | :param feature1:
26 | :param feature2:
27 | :return:
28 | '''
29 | fig = plt.figure(figsize=(8, 6))
30 | # 创建子图、子图实在画布上排布的,并且会覆盖
31 | '''
32 | 在 matplotlib 中,Axes 类是用于在 Figure 对象上绘制图表和图形的核心类。每个 Axes 对象代表图表中的一个坐标轴系统,可以用来绘制线条、散点图、柱状图、饼图等多种类型的图形。
33 | Axes 类提供了一系列的方法和属性来控制图形的绘制和样式,以及进行数据处理和分析。以下是一些 Axes 类的主要属性和方法:
34 | 属性:
35 | xaxis 和 yaxis:分别代表 x 轴和 y 轴的对象,可以用来设置轴的属性,如标签、刻度、范围等。
36 | lines:一个 LineCollection 对象,包含了当前 Axes 对象上的所有线条。
37 | collections:一个 Collection 对象,包含了当前 Axes 对象上的所有集合图形(如多边形、散点图等)。
38 | patches:一个 PatchCollection 对象,包含了当前 Axes 对象上的所有填充区域。
39 | images:一个 ImageCollection 对象,包含了当前 Axes 对象上的所有图像。
40 | legend_:一个 Legend 对象,包含了当前 Axes 对象的图例。
41 | 方法:
42 | plot:绘制线条,接受 x 和 y 坐标,以及其他可选参数来控制线条的样式。
43 | scatter:绘制散点图,接受 x 和 y 坐标,以及其他可选参数来控制标记的样式。
44 | bar 和 barh:绘制柱状图,接受 x 和 y 坐标,以及其他可选参数来控制柱状图的样式。
45 | boxplot:绘制箱线图,接受数据集,以及其他可选参数来控制箱线图的样式。
46 | fill_between:填充两个数据集之间的区域。
47 | errorbar:绘制带有误差线的数据点。
48 | legend:添加图例。
49 | set_xlim 和 set_ylim:设置 x 轴和 y 轴的显示范围。
50 | set_xlabel 和 set_ylabel:设置 x 轴和 y 轴的标签。
51 | set_title:设置图表的标题。
52 | grid:显示或隐藏网格。
53 | set_facecolor 和 set_edgecolor:设置坐标轴的背景色和边缘色。
54 | set_alpha:设置坐标轴的透明度。
55 | set_axis_bgcolor 和 set_axis_color:设置坐标轴背景色和坐标轴颜色。
56 | '''
57 | ax1 = fig.add_subplot()
58 | x1 = feature1[0]
59 | y1 = feature1[1]
60 | # 画第一个图
61 | '''
62 | axes.plot 是 matplotlib.axes.Axes 类的一个方法,用于在坐标轴上绘制线条。它的参数可以用来控制线条的外观和行为。以下是一些常用的参数:
63 | x 和 y:这两个参数是必须的,它们分别代表线条的x轴和y轴坐标。x 可以是单个数值或数值数组,y 必须与 x 具有相同的形状。
64 | color:线条的颜色,可以是颜色名(如 ‘red’、‘green’ 等),十六进制颜色码(如 ‘#FF00FF’),RGB 或 RGBA 元组,或者预定义的颜色字符串(如 ‘C0’、‘C1’ 等,其中 ‘C0’、‘C1’ 等表示颜色循环中的颜色)。
65 | linewidth:线条的宽度,默认为 1.0。
66 | linestyle:线条的样式,可以是实线(‘-’)、虚线(‘–’)、点线(‘:’)、点(‘.’)等。
67 | marker:数据点的标记样式,如 ‘o’(圆形)、‘s’(方形)、‘^’(三角形上)、‘<’(三角形下)等。
68 | markersize:数据点的大小。
69 | markeredgewidth:数据点边缘线的宽度。
70 | markeredgecolor:数据点边缘线的颜色。
71 | label:用于图例的标签。
72 | alpha:线条和标记的透明度,范围从 0(完全透明)到 1(完全不透明)。
73 | ax:要绘制的坐标轴对象。
74 | data:如果 x 和 y 参数未提供,则可以提供一个包含 x 和 y 数据的字典或元组。
75 | '''
76 | # ax1.boxplot(x1, y1, 'r')
77 | x2 = feature2[0]
78 | print(x2)
79 | y2 = feature2[1]
80 | ax1.boxplot(x2, notch=False, patch_artist=True, labels=['A'])
81 | # 坐标轴微调
82 | '''
83 | 在 matplotlib 中,Axes 对象有一个名为 axis 的方法,它用于设置和调整坐标轴的各种属性。这个方法通常接受一个参数,表示要调整的坐标轴(x 或 y),以及一些可选的参数来控制轴的显示和行为。
84 | 以下是 axis 方法的一些常见用法:
85 | axis('off'):关闭坐标轴的显示。
86 | axis('on'):开启坐标轴的显示。
87 | axis('equal'):确保两个坐标轴的刻度比例相同,这对于绘制等比例的地图或其他图形非常重要。
88 | axis('scaled'):这是 axis('equal') 的一个别名。
89 | axis('tight'):自动调整坐标轴的范围以适应数据,通常在绘制多个子图时使用。
90 | axis('auto'):自动调整坐标轴的范围,通常是默认行为。
91 | axis([xmin, xmax, ymin, ymax]):手动设置坐标轴的范围。
92 | axis_bgcolor:设置坐标轴的背景色。
93 | axis_color:设置坐标轴的颜色。
94 | '''
95 |
96 | # 显示图表
97 | plt.show()
98 | # 将图表保存为文件。
99 | plt.savefig(filename, format='png', dpi=300)
100 |
101 |
102 | if __name__ == '__main__':
103 | a = np.arange(1, 10)
104 | b = 2 * a
105 | c = 3 * a
106 | # 将a和b合并为一个二维数组
107 | feature1 = np.vstack((a, b))
108 | # feature2 = np.vstack((a, c))
109 | save_name = 'figure.png'
110 | data1 = [2, 2, 3, 4, 4, 4, 5, 5, 6, 6, 7, 8, 9, 9, 10]
111 | data2 = [1, 1, 2, 3, 3, 4, 4, 5, 6, 7, 8, 9, 9, 10, 11, 10, 9]
112 | feature2 = (data1, data2)
113 | plot_figure(feature1, feature2, save_name)
114 |
--------------------------------------------------------------------------------
/fig/read_jpg.py:
--------------------------------------------------------------------------------
1 | """
2 | @FileName:read_jpg.py\n
3 | @Description:\n
4 | @Author:WBobby\n
5 | @Department:CUG\n
6 | @Time:2024/3/26 23:07\n
7 | """
8 | import cv2
9 |
10 | if __name__ == '__main__':
11 |     jpg = r'C:\Users\Wbobby\Desktop\12\DJI_20231225162517_0002_T.JPG'
12 |     img = cv2.imread(jpg)
13 |     cv2.imshow('image', img)
14 |     cv2.waitKey(0)
15 |
--------------------------------------------------------------------------------
/fig/train_loss_fig.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | import pandas as pd
3 | import matplotlib.pyplot as plt
4 | import os
5 |
6 |
7 | def csv2dict(csv_file):
8 |     # Keep every 5th logged row: column 1 is the 0-based epoch, column 3 the
9 |     # validation accuracy, column 5 the training loss.
10 |     data = pd.read_csv(csv_file)
11 |     metrics = {}
12 |     for i in range(len(data)):
13 |         if i % 5 == 0:
14 |             epoch = data.iloc[i, 1] + 1
15 |             val_acc = data.iloc[i, 3]
16 |             loss = data.iloc[i, 5]
17 |             metrics[epoch] = [val_acc, loss]
18 |     return metrics
19 |
20 |
21 | def loss_fig(csv_path):
22 |     matplotlib.rcParams['font.family'] = 'serif'
23 |     matplotlib.rcParams['font.serif'] = ['Times New Roman'] + matplotlib.rcParams['font.serif']
24 |     figsize = (16, 9)
25 |     plt.figure(figsize=figsize)
26 |     csv_files = os.listdir(csv_path)
27 |     model_names = ['Coat-ls', 'Effv2-s', 'Res-50', 'Swinv2-t', 'ViT-s']
28 |     line_styles = ['-', '--', '-.', ':', '--']
29 |     linewidths = [2, 2, 2, 2, 2]
30 |     alphas = [1, 1, 1, 1, 1]
31 |     model_losscolors = ['lightcoral', 'orange', 'lightgreen', 'darkturquoise', 'plum']
32 |     model_valcolors = ['indianred', 'darkorange', 'mediumseagreen', 'c', 'orchid']
33 |     markers = ['o', 'v', '^', 's', '*']
34 |     ax1 = plt.subplot()
35 |     ax2 = plt.twinx()
36 |     zz = zip(model_names, csv_files, model_losscolors, model_valcolors, line_styles, linewidths, alphas, markers)
37 |
38 |     lines = []
39 |     labels = []
40 |     for model_name, csv_file, loss_color, val_color, line_style, linewidth, alpha, marker in zz:
41 |         metrics = csv2dict(os.path.join(csv_path, csv_file))
42 |         epochs = list(metrics.keys())
43 |         val_acc = [item[0] for item in metrics.values()]
44 |         loss = [item[1] for item in metrics.values()]
45 |         loss_line, = ax1.plot(epochs, loss, color=loss_color, linestyle=line_style, linewidth=linewidth,
46 |                               alpha=alpha, marker=marker, markersize=5, label='Loss {}'.format(model_name))
47 |         val_line, = ax2.plot(epochs, val_acc, color=val_color, marker=marker, markersize=5,
48 |                              linestyle=line_style, alpha=alpha, linewidth=linewidth,
49 |                              label='Val Acc {}'.format(model_name))
50 |         lines.append(val_line)
51 |         lines.append(loss_line)
52 |         labels.append(val_line.get_label())
53 |         labels.append(loss_line.get_label())
54 |
55 |     # Create one combined legend from the collected lines and labels.
56 |     plt.legend(lines, labels, loc='center right', fontsize=17)
57 |     plt.title('Model Loss and Validation Accuracy', fontsize=40, fontweight='bold')
58 |     ax1.set_xlabel('Epoch', fontsize=25, fontweight='bold')
59 |     ax1.tick_params(axis='x', labelsize=20)
60 |     ax1.set_ylabel('Train Loss', fontsize=25, fontweight='bold')
61 |     ax1.yaxis.set_tick_params(labelsize=20)
62 |     ax2.set_ylabel('Validation Accuracy', fontsize=25, fontweight='bold')
63 |     ax2.tick_params(axis='y', labelsize=20)
64 |     plt.savefig(r'C:\Users\Wbobby\Documents\TeX_files\PVFC\tu/' + 'loss_val1.png', dpi=300)
65 |     plt.show()
66 |
67 |
68 | if __name__ == '__main__':
69 |     csv_path = r'C:\Users\Wbobby\Desktop\csv'
70 |     loss_fig(csv_path)
71 |
--------------------------------------------------------------------------------
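csv2dict reads the training log purely by column position: after the unnamed index column written by pandas, column 1 must be the epoch, column 3 the validation accuracy, and column 5 the training loss. A hypothetical log with a compatible layout (the column names are assumptions; only the positions matter):

import pandas as pd

df = pd.DataFrame({
    'epoch': range(10),                                  # column 1 after the index
    'train_acc': [0.50 + 0.04 * i for i in range(10)],
    'val_acc': [0.45 + 0.04 * i for i in range(10)],     # column 3
    'lr': [1e-3] * 10,
    'loss': [2.0 - 0.15 * i for i in range(10)],         # column 5
})
df.to_csv('demo_log.csv')  # keep the default index so the positions line up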
/model/111.py:
--------------------------------------------------------------------------------
1 | """
2 | @FileName:111.py\n
3 | @Description:\n
4 | @Author:WBobby\n
5 | @Department:CUG\n
6 | @Time:2024/6/17 22:40\n
7 | """
8 |
9 |
10 | def my_decorator(func):
11 |     def wrapper():
12 |         print("Functionality added by the decorator")
13 |         func()
14 |         func()  # note: the wrapped function is called twice
15 |     return wrapper
16 |
17 | @my_decorator
18 | def say_hello():
19 |     print("Hello!")
20 |
21 |
22 | say_hello()
23 |
--------------------------------------------------------------------------------
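The scratch file above demonstrates the bare decorator pattern; the usual production form also forwards arguments and preserves the wrapped function's metadata:

import functools


def my_decorator(func):
    @functools.wraps(func)  # keep func.__name__, __doc__, etc.
    def wrapper(*args, **kwargs):
        print("Functionality added by the decorator")
        return func(*args, **kwargs)
    return wrapper


@my_decorator
def say_hello():
    print("Hello!")


say_hello()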
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | from .coat import *
3 | from .efficientnet import *
4 | from .swin_transformer_v2 import *
5 | from .vision_transformer import *
6 | from .effnetv2 import *
7 |
8 |
9 | from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \
10 |     set_pretrained_download_progress, set_pretrained_check_hash
11 | from ._factory import create_model, parse_model_name, safe_model_name
12 | from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
13 | from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
14 |     register_notrace_module, is_notrace_module, get_notrace_modules, \
15 |     register_notrace_function, is_notrace_function, get_notrace_functions
16 | from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
17 | from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
18 | from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
19 |     group_modules, group_parameters, checkpoint_seq, adapt_input_conv
20 | from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
21 | from ._prune import adapt_model_from_string
22 | from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
23 |     register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \
24 |     is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
25 |
--------------------------------------------------------------------------------
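The package re-exports timm's registry and factory helpers, so (assuming the vendored code behaves like upstream timm) the bundled architectures can be discovered and instantiated by name:

from model import create_model, list_models

print(list_models('resnet*'))                    # architectures registered in resnet.py
net = create_model('resnet50', num_classes=10)   # head sized for the 10 PVF-10 classes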
/model/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/_builder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_builder.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/_efficientnet_blocks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_efficientnet_blocks.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/_efficientnet_builder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_efficientnet_builder.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/_factory.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_factory.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/_features.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_features.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/_features_fx.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_features_fx.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/_helpers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_helpers.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/_hub.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_hub.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/_manipulate.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_manipulate.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/_pretrained.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_pretrained.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/_prune.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_prune.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/_registry.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_registry.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/coat.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/coat.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/efficientnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/efficientnet.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/effnetv2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/effnetv2.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/swin_transformer_v2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/swin_transformer_v2.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/vision_transformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/vision_transformer.cpython-38.pyc
--------------------------------------------------------------------------------
/model/_pretrained.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from collections import deque, defaultdict
3 | from dataclasses import dataclass, field, replace, asdict
4 | from typing import Any, Deque, Dict, Tuple, Optional, Union
5 |
6 |
7 | __all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg']
8 |
9 |
10 | @dataclass
11 | class PretrainedCfg:
12 |     """Configuration for a set of pretrained weights: where the weights come
13 |     from and the input / head settings associated with them."""
14 |     # weight source locations
15 |     url: Optional[Union[str, Tuple[str, str]]] = None  # remote URL
16 |     file: Optional[str] = None  # local / shared filesystem path
17 |     state_dict: Optional[Dict[str, Any]] = None  # in-memory state dict
18 |     hf_hub_id: Optional[str] = None  # Hugging Face Hub model id ('organization/model')
19 |     hf_hub_filename: Optional[str] = None  # Hugging Face Hub filename (overrides default)
20 |
21 |     source: Optional[str] = None  # source of cfg / weight location used (url, file, hf-hub)
22 |     architecture: Optional[str] = None  # architecture variant can be set when not implicit
23 |     tag: Optional[str] = None  # pretrained tag of source
24 |     custom_load: bool = False  # use custom model specific model.load_pretrained() (ie for npz files)
25 |
26 |     # input / data config
27 |     input_size: Tuple[int, int, int] = (3, 224, 224)
28 |     test_input_size: Optional[Tuple[int, int, int]] = None
29 |     min_input_size: Optional[Tuple[int, int, int]] = None
30 |     fixed_input_size: bool = False
31 |     interpolation: str = 'bicubic'
32 |     crop_pct: float = 0.875
33 |     test_crop_pct: Optional[float] = None
34 |     crop_mode: str = 'center'
35 |     mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
36 |     std: Tuple[float, ...] = (0.229, 0.224, 0.225)
37 |
38 |     # head / classifier config and meta-data
39 |     num_classes: int = 1000
40 |     label_offset: Optional[int] = None
41 |     label_names: Optional[Tuple[str]] = None
42 |     label_descriptions: Optional[Dict[str, str]] = None
43 |
44 |     # model attributes that vary with above or required for pretrained adaptation
45 |     pool_size: Optional[Tuple[int, ...]] = None
46 |     test_pool_size: Optional[Tuple[int, ...]] = None
47 |     first_conv: Optional[str] = None
48 |     classifier: Optional[str] = None
49 |
50 |     license: Optional[str] = None
51 |     description: Optional[str] = None
52 |     origin_url: Optional[str] = None
53 |     paper_name: Optional[str] = None
54 |     paper_ids: Optional[Union[str, Tuple[str]]] = None
55 |     notes: Optional[Tuple[str]] = None
56 |
57 |     @property
58 |     def has_weights(self):
59 |         return self.url or self.file or self.hf_hub_id
60 |
61 |     def to_dict(self, remove_source=False, remove_null=True):
62 |         return filter_pretrained_cfg(
63 |             asdict(self),
64 |             remove_source=remove_source,
65 |             remove_null=remove_null
66 |         )
67 |
68 |
69 | def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
70 |     filtered_cfg = {}
71 |     keep_null = {'pool_size', 'first_conv', 'classifier'}  # always keep these keys, even if none
72 |     for k, v in cfg.items():
73 |         if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_filename', 'source'}:
74 |             continue
75 |         if remove_null and v is None and k not in keep_null:
76 |             continue
77 |         filtered_cfg[k] = v
78 |     return filtered_cfg
79 |
80 |
81 | @dataclass
82 | class DefaultCfg:
83 |     tags: Deque[str] = field(default_factory=deque)  # priority queue of tags (first is default)
84 |     cfgs: Dict[str, PretrainedCfg] = field(default_factory=dict)  # pretrained cfgs by tag
85 |     is_pretrained: bool = False  # at least one of the configs has a pretrained source set
86 |
87 |     @property
88 |     def default(self):
89 |         return self.cfgs[self.tags[0]]
90 |
91 |     @property
92 |     def default_with_tag(self):
93 |         tag = self.tags[0]
94 |         return tag, self.cfgs[tag]
95 |
--------------------------------------------------------------------------------
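A short sketch of how these dataclasses are used, relying only on fields defined above (the URL is a placeholder):

from model import PretrainedCfg

cfg = PretrainedCfg(
    url='https://example.com/weights.pth',  # placeholder weight location
    input_size=(3, 256, 256),
    num_classes=10,
)
assert cfg.has_weights  # true because a url is set
print(cfg.to_dict(remove_source=True))  # source keys dropped, None values filtered out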
/model/_prune.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pkgutil
3 | from copy import deepcopy
4 |
5 | from torch import nn as nn
6 |
7 | from timm.layers import Conv2dSame, BatchNormAct2d, Linear
8 |
9 | __all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']
10 |
11 |
12 | def extract_layer(model, layer):
13 |     layer = layer.split('.')
14 |     module = model
15 |     if hasattr(model, 'module') and layer[0] != 'module':
16 |         module = model.module
17 |     if not hasattr(model, 'module') and layer[0] == 'module':
18 |         layer = layer[1:]
19 |     for l in layer:
20 |         if hasattr(module, l):
21 |             if not l.isdigit():
22 |                 module = getattr(module, l)
23 |             else:
24 |                 module = module[int(l)]
25 |         else:
26 |             return module
27 |     return module
28 |
29 |
30 | def set_layer(model, layer, val):
31 |     layer = layer.split('.')
32 |     module = model
33 |     if hasattr(model, 'module') and layer[0] != 'module':
34 |         module = model.module
35 |     lst_index = 0
36 |     module2 = module
37 |     for l in layer:
38 |         if hasattr(module2, l):
39 |             if not l.isdigit():
40 |                 module2 = getattr(module2, l)
41 |             else:
42 |                 module2 = module2[int(l)]
43 |             lst_index += 1
44 |     lst_index -= 1
45 |     for l in layer[:lst_index]:
46 |         if not l.isdigit():
47 |             module = getattr(module, l)
48 |         else:
49 |             module = module[int(l)]
50 |     l = layer[lst_index]
51 |     setattr(module, l, val)
52 |
53 |
54 | def adapt_model_from_string(parent_module, model_string):
55 |     separator = '***'
56 |     state_dict = {}
57 |     lst_shape = model_string.split(separator)
58 |     for k in lst_shape:
59 |         k = k.split(':')
60 |         key = k[0]
61 |         shape = k[1][1:-1].split(',')
62 |         if shape[0] != '':
63 |             state_dict[key] = [int(i) for i in shape]
64 |
65 |     new_module = deepcopy(parent_module)
66 |     for n, m in parent_module.named_modules():
67 |         old_module = extract_layer(parent_module, n)
68 |         if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
69 |             if isinstance(old_module, Conv2dSame):
70 |                 conv = Conv2dSame
71 |             else:
72 |                 conv = nn.Conv2d
73 |             s = state_dict[n + '.weight']
74 |             in_channels = s[1]
75 |             out_channels = s[0]
76 |             g = 1
77 |             if old_module.groups > 1:
78 |                 in_channels = out_channels
79 |                 g = in_channels
80 |             new_conv = conv(
81 |                 in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
82 |                 bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
83 |                 groups=g, stride=old_module.stride)
84 |             set_layer(new_module, n, new_conv)
85 |         elif isinstance(old_module, BatchNormAct2d):
86 |             new_bn = BatchNormAct2d(
87 |                 state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
88 |                 affine=old_module.affine, track_running_stats=True)
89 |             new_bn.drop = old_module.drop
90 |             new_bn.act = old_module.act
91 |             set_layer(new_module, n, new_bn)
92 |         elif isinstance(old_module, nn.BatchNorm2d):
93 |             new_bn = nn.BatchNorm2d(
94 |                 num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
95 |                 affine=old_module.affine, track_running_stats=True)
96 |             set_layer(new_module, n, new_bn)
97 |         elif isinstance(old_module, nn.Linear):
98 |             # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
99 |             num_features = state_dict[n + '.weight'][1]
100 |             new_fc = Linear(
101 |                 in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
102 |             set_layer(new_module, n, new_fc)
103 |             if hasattr(new_module, 'num_features'):
104 |                 new_module.num_features = num_features
105 |     new_module.eval()
106 |     parent_module.eval()
107 |
108 |     return new_module
109 |
110 |
111 | def adapt_model_from_file(parent_module, model_variant):
112 |     adapt_data = pkgutil.get_data(__name__, os.path.join('_pruned', model_variant + '.txt'))
113 |     return adapt_model_from_string(parent_module, adapt_data.decode('utf-8').strip())
114 |
--------------------------------------------------------------------------------
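adapt_model_from_string expects a '***'-separated list of 'name:[shape]' entries giving the pruned tensor shapes; a toy example of the format it parses (layer names and shapes are illustrative):

model_string = 'conv1.weight:[32, 3, 3, 3]***bn1.weight:[32]***fc.weight:[10, 32]'
for entry in model_string.split('***'):
    name, shape = entry.split(':')
    dims = [int(d) for d in shape[1:-1].split(',')]  # int() tolerates the spaces
    print(name, dims)  # e.g. conv1.weight [32, 3, 3, 3]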
/model/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy, \
2 |     rand_augment_transform, auto_augment_transform
3 | from .config import resolve_data_config, resolve_model_data_config
4 | from .constants import *
5 | from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
6 | from .dataset_factory import create_dataset
7 | from .dataset_info import DatasetInfo, CustomDatasetInfo
8 | from .imagenet_info import ImageNetInfo, infer_imagenet_subset
9 | from .loader import create_loader
10 | from .mixup import Mixup, FastCollateMixup
11 | from .readers import create_reader
12 | from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
13 | from .real_labels import RealLabelsImagenet
14 | from .transforms import *
15 | from .transforms_factory import create_transform
16 |
--------------------------------------------------------------------------------
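Assuming the vendored data package behaves like upstream timm, the exported transform factory composes the standard preprocessing pipeline (sizes here are placeholders):

from model.data import create_transform

# Evaluation-time transform: resize, center crop, tensor conversion, normalization.
eval_tf = create_transform(input_size=(3, 224, 224), is_training=False)
# Training-time transform adds random resized crop / flip augmentation.
train_tf = create_transform(input_size=(3, 224, 224), is_training=True)
print(eval_tf)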
/model/data/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/__pycache__/auto_augment.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/auto_augment.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/__pycache__/constants.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/constants.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/__pycache__/dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/__pycache__/dataset_factory.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/dataset_factory.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/__pycache__/dataset_info.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/dataset_info.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/__pycache__/distributed_sampler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/distributed_sampler.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/__pycache__/imagenet_info.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/imagenet_info.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/__pycache__/loader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/loader.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/__pycache__/mixup.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/mixup.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/__pycache__/random_erasing.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/random_erasing.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/__pycache__/real_labels.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/real_labels.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/__pycache__/transforms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/transforms.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/__pycache__/transforms_factory.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/transforms_factory.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/_info/imagenet_a_indices.txt:
--------------------------------------------------------------------------------
1 | 6
2 | 11
3 | 13
4 | 15
5 | 17
6 | 22
7 | 23
8 | 27
9 | 30
10 | 37
11 | 39
12 | 42
13 | 47
14 | 50
15 | 57
16 | 70
17 | 71
18 | 76
19 | 79
20 | 89
21 | 90
22 | 94
23 | 96
24 | 97
25 | 99
26 | 105
27 | 107
28 | 108
29 | 110
30 | 113
31 | 124
32 | 125
33 | 130
34 | 132
35 | 143
36 | 144
37 | 150
38 | 151
39 | 207
40 | 234
41 | 235
42 | 254
43 | 277
44 | 283
45 | 287
46 | 291
47 | 295
48 | 298
49 | 301
50 | 306
51 | 307
52 | 308
53 | 309
54 | 310
55 | 311
56 | 313
57 | 314
58 | 315
59 | 317
60 | 319
61 | 323
62 | 324
63 | 326
64 | 327
65 | 330
66 | 334
67 | 335
68 | 336
69 | 347
70 | 361
71 | 363
72 | 372
73 | 378
74 | 386
75 | 397
76 | 400
77 | 401
78 | 402
79 | 404
80 | 407
81 | 411
82 | 416
83 | 417
84 | 420
85 | 425
86 | 428
87 | 430
88 | 437
89 | 438
90 | 445
91 | 456
92 | 457
93 | 461
94 | 462
95 | 470
96 | 472
97 | 483
98 | 486
99 | 488
100 | 492
101 | 496
102 | 514
103 | 516
104 | 528
105 | 530
106 | 539
107 | 542
108 | 543
109 | 549
110 | 552
111 | 557
112 | 561
113 | 562
114 | 569
115 | 572
116 | 573
117 | 575
118 | 579
119 | 589
120 | 606
121 | 607
122 | 609
123 | 614
124 | 626
125 | 627
126 | 640
127 | 641
128 | 642
129 | 643
130 | 658
131 | 668
132 | 677
133 | 682
134 | 684
135 | 687
136 | 701
137 | 704
138 | 719
139 | 736
140 | 746
141 | 749
142 | 752
143 | 758
144 | 763
145 | 765
146 | 768
147 | 773
148 | 774
149 | 776
150 | 779
151 | 780
152 | 786
153 | 792
154 | 797
155 | 802
156 | 803
157 | 804
158 | 813
159 | 815
160 | 820
161 | 823
162 | 831
163 | 833
164 | 835
165 | 839
166 | 845
167 | 847
168 | 850
169 | 859
170 | 862
171 | 870
172 | 879
173 | 880
174 | 888
175 | 890
176 | 897
177 | 900
178 | 907
179 | 913
180 | 924
181 | 932
182 | 933
183 | 934
184 | 937
185 | 943
186 | 945
187 | 947
188 | 951
189 | 954
190 | 956
191 | 957
192 | 959
193 | 971
194 | 972
195 | 980
196 | 981
197 | 984
198 | 986
199 | 987
200 | 988
201 |
--------------------------------------------------------------------------------
/model/data/_info/imagenet_a_synsets.txt:
--------------------------------------------------------------------------------
1 | n01498041
2 | n01531178
3 | n01534433
4 | n01558993
5 | n01580077
6 | n01614925
7 | n01616318
8 | n01631663
9 | n01641577
10 | n01669191
11 | n01677366
12 | n01687978
13 | n01694178
14 | n01698640
15 | n01735189
16 | n01770081
17 | n01770393
18 | n01774750
19 | n01784675
20 | n01819313
21 | n01820546
22 | n01833805
23 | n01843383
24 | n01847000
25 | n01855672
26 | n01882714
27 | n01910747
28 | n01914609
29 | n01924916
30 | n01944390
31 | n01985128
32 | n01986214
33 | n02007558
34 | n02009912
35 | n02037110
36 | n02051845
37 | n02077923
38 | n02085620
39 | n02099601
40 | n02106550
41 | n02106662
42 | n02110958
43 | n02119022
44 | n02123394
45 | n02127052
46 | n02129165
47 | n02133161
48 | n02137549
49 | n02165456
50 | n02174001
51 | n02177972
52 | n02190166
53 | n02206856
54 | n02219486
55 | n02226429
56 | n02231487
57 | n02233338
58 | n02236044
59 | n02259212
60 | n02268443
61 | n02279972
62 | n02280649
63 | n02281787
64 | n02317335
65 | n02325366
66 | n02346627
67 | n02356798
68 | n02361337
69 | n02410509
70 | n02445715
71 | n02454379
72 | n02486410
73 | n02492035
74 | n02504458
75 | n02655020
76 | n02669723
77 | n02672831
78 | n02676566
79 | n02690373
80 | n02701002
81 | n02730930
82 | n02777292
83 | n02782093
84 | n02787622
85 | n02793495
86 | n02797295
87 | n02802426
88 | n02814860
89 | n02815834
90 | n02837789
91 | n02879718
92 | n02883205
93 | n02895154
94 | n02906734
95 | n02948072
96 | n02951358
97 | n02980441
98 | n02992211
99 | n02999410
100 | n03014705
101 | n03026506
102 | n03124043
103 | n03125729
104 | n03187595
105 | n03196217
106 | n03223299
107 | n03250847
108 | n03255030
109 | n03291819
110 | n03325584
111 | n03355925
112 | n03384352
113 | n03388043
114 | n03417042
115 | n03443371
116 | n03444034
117 | n03445924
118 | n03452741
119 | n03483316
120 | n03584829
121 | n03590841
122 | n03594945
123 | n03617480
124 | n03666591
125 | n03670208
126 | n03717622
127 | n03720891
128 | n03721384
129 | n03724870
130 | n03775071
131 | n03788195
132 | n03804744
133 | n03837869
134 | n03840681
135 | n03854065
136 | n03888257
137 | n03891332
138 | n03935335
139 | n03982430
140 | n04019541
141 | n04033901
142 | n04039381
143 | n04067472
144 | n04086273
145 | n04099969
146 | n04118538
147 | n04131690
148 | n04133789
149 | n04141076
150 | n04146614
151 | n04147183
152 | n04179913
153 | n04208210
154 | n04235860
155 | n04252077
156 | n04252225
157 | n04254120
158 | n04270147
159 | n04275548
160 | n04310018
161 | n04317175
162 | n04344873
163 | n04347754
164 | n04355338
165 | n04366367
166 | n04376876
167 | n04389033
168 | n04399382
169 | n04442312
170 | n04456115
171 | n04482393
172 | n04507155
173 | n04509417
174 | n04532670
175 | n04540053
176 | n04554684
177 | n04562935
178 | n04591713
179 | n04606251
180 | n07583066
181 | n07695742
182 | n07697313
183 | n07697537
184 | n07714990
185 | n07718472
186 | n07720875
187 | n07734744
188 | n07749582
189 | n07753592
190 | n07760859
191 | n07768694
192 | n07831146
193 | n09229709
194 | n09246464
195 | n09472597
196 | n09835506
197 | n11879895
198 | n12057211
199 | n12144580
200 | n12267677
201 |
--------------------------------------------------------------------------------
/model/data/_info/imagenet_r_indices.txt:
--------------------------------------------------------------------------------
1 | 1
2 | 2
3 | 4
4 | 6
5 | 8
6 | 9
7 | 11
8 | 13
9 | 22
10 | 23
11 | 26
12 | 29
13 | 31
14 | 39
15 | 47
16 | 63
17 | 71
18 | 76
19 | 79
20 | 84
21 | 90
22 | 94
23 | 96
24 | 97
25 | 99
26 | 100
27 | 105
28 | 107
29 | 113
30 | 122
31 | 125
32 | 130
33 | 132
34 | 144
35 | 145
36 | 147
37 | 148
38 | 150
39 | 151
40 | 155
41 | 160
42 | 161
43 | 162
44 | 163
45 | 171
46 | 172
47 | 178
48 | 187
49 | 195
50 | 199
51 | 203
52 | 207
53 | 208
54 | 219
55 | 231
56 | 232
57 | 234
58 | 235
59 | 242
60 | 245
61 | 247
62 | 250
63 | 251
64 | 254
65 | 259
66 | 260
67 | 263
68 | 265
69 | 267
70 | 269
71 | 276
72 | 277
73 | 281
74 | 288
75 | 289
76 | 291
77 | 292
78 | 293
79 | 296
80 | 299
81 | 301
82 | 308
83 | 309
84 | 310
85 | 311
86 | 314
87 | 315
88 | 319
89 | 323
90 | 327
91 | 330
92 | 334
93 | 335
94 | 337
95 | 338
96 | 340
97 | 341
98 | 344
99 | 347
100 | 353
101 | 355
102 | 361
103 | 362
104 | 365
105 | 366
106 | 367
107 | 368
108 | 372
109 | 388
110 | 390
111 | 393
112 | 397
113 | 401
114 | 407
115 | 413
116 | 414
117 | 425
118 | 428
119 | 430
120 | 435
121 | 437
122 | 441
123 | 447
124 | 448
125 | 457
126 | 462
127 | 463
128 | 469
129 | 470
130 | 471
131 | 472
132 | 476
133 | 483
134 | 487
135 | 515
136 | 546
137 | 555
138 | 558
139 | 570
140 | 579
141 | 583
142 | 587
143 | 593
144 | 594
145 | 596
146 | 609
147 | 613
148 | 617
149 | 621
150 | 629
151 | 637
152 | 657
153 | 658
154 | 701
155 | 717
156 | 724
157 | 763
158 | 768
159 | 774
160 | 776
161 | 779
162 | 780
163 | 787
164 | 805
165 | 812
166 | 815
167 | 820
168 | 824
169 | 833
170 | 847
171 | 852
172 | 866
173 | 875
174 | 883
175 | 889
176 | 895
177 | 907
178 | 928
179 | 931
180 | 932
181 | 933
182 | 934
183 | 936
184 | 937
185 | 943
186 | 945
187 | 947
188 | 948
189 | 949
190 | 951
191 | 953
192 | 954
193 | 957
194 | 963
195 | 965
196 | 967
197 | 980
198 | 981
199 | 983
200 | 988
201 |
--------------------------------------------------------------------------------
/model/data/_info/imagenet_r_synsets.txt:
--------------------------------------------------------------------------------
1 | n01443537
2 | n01484850
3 | n01494475
4 | n01498041
5 | n01514859
6 | n01518878
7 | n01531178
8 | n01534433
9 | n01614925
10 | n01616318
11 | n01630670
12 | n01632777
13 | n01644373
14 | n01677366
15 | n01694178
16 | n01748264
17 | n01770393
18 | n01774750
19 | n01784675
20 | n01806143
21 | n01820546
22 | n01833805
23 | n01843383
24 | n01847000
25 | n01855672
26 | n01860187
27 | n01882714
28 | n01910747
29 | n01944390
30 | n01983481
31 | n01986214
32 | n02007558
33 | n02009912
34 | n02051845
35 | n02056570
36 | n02066245
37 | n02071294
38 | n02077923
39 | n02085620
40 | n02086240
41 | n02088094
42 | n02088238
43 | n02088364
44 | n02088466
45 | n02091032
46 | n02091134
47 | n02092339
48 | n02094433
49 | n02096585
50 | n02097298
51 | n02098286
52 | n02099601
53 | n02099712
54 | n02102318
55 | n02106030
56 | n02106166
57 | n02106550
58 | n02106662
59 | n02108089
60 | n02108915
61 | n02109525
62 | n02110185
63 | n02110341
64 | n02110958
65 | n02112018
66 | n02112137
67 | n02113023
68 | n02113624
69 | n02113799
70 | n02114367
71 | n02117135
72 | n02119022
73 | n02123045
74 | n02128385
75 | n02128757
76 | n02129165
77 | n02129604
78 | n02130308
79 | n02134084
80 | n02138441
81 | n02165456
82 | n02190166
83 | n02206856
84 | n02219486
85 | n02226429
86 | n02233338
87 | n02236044
88 | n02268443
89 | n02279972
90 | n02317335
91 | n02325366
92 | n02346627
93 | n02356798
94 | n02363005
95 | n02364673
96 | n02391049
97 | n02395406
98 | n02398521
99 | n02410509
100 | n02423022
101 | n02437616
102 | n02445715
103 | n02447366
104 | n02480495
105 | n02480855
106 | n02481823
107 | n02483362
108 | n02486410
109 | n02510455
110 | n02526121
111 | n02607072
112 | n02655020
113 | n02672831
114 | n02701002
115 | n02749479
116 | n02769748
117 | n02793495
118 | n02797295
119 | n02802426
120 | n02808440
121 | n02814860
122 | n02823750
123 | n02841315
124 | n02843684
125 | n02883205
126 | n02906734
127 | n02909870
128 | n02939185
129 | n02948072
130 | n02950826
131 | n02951358
132 | n02966193
133 | n02980441
134 | n02992529
135 | n03124170
136 | n03272010
137 | n03345487
138 | n03372029
139 | n03424325
140 | n03452741
141 | n03467068
142 | n03481172
143 | n03494278
144 | n03495258
145 | n03498962
146 | n03594945
147 | n03602883
148 | n03630383
149 | n03649909
150 | n03676483
151 | n03710193
152 | n03773504
153 | n03775071
154 | n03888257
155 | n03930630
156 | n03947888
157 | n04086273
158 | n04118538
159 | n04133789
160 | n04141076
161 | n04146614
162 | n04147183
163 | n04192698
164 | n04254680
165 | n04266014
166 | n04275548
167 | n04310018
168 | n04325704
169 | n04347754
170 | n04389033
171 | n04409515
172 | n04465501
173 | n04487394
174 | n04522168
175 | n04536866
176 | n04552348
177 | n04591713
178 | n07614500
179 | n07693725
180 | n07695742
181 | n07697313
182 | n07697537
183 | n07714571
184 | n07714990
185 | n07718472
186 | n07720875
187 | n07734744
188 | n07742313
189 | n07745940
190 | n07749582
191 | n07753275
192 | n07753592
193 | n07768694
194 | n07873807
195 | n07880968
196 | n07920052
197 | n09472597
198 | n09835506
199 | n10565667
200 | n12267677
201 |
--------------------------------------------------------------------------------
/model/data/config.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from .constants import *
3 |
4 |
5 | _logger = logging.getLogger(__name__)
6 |
7 |
8 | def resolve_data_config(
9 | args=None,
10 | pretrained_cfg=None,
11 | model=None,
12 | use_test_size=False,
13 | verbose=False
14 | ):
15 | assert model or args or pretrained_cfg, "At least one of model, args, or pretrained_cfg required for data config."
16 | args = args or {}
17 | pretrained_cfg = pretrained_cfg or {}
18 | if not pretrained_cfg and model is not None and hasattr(model, 'pretrained_cfg'):
19 | pretrained_cfg = model.pretrained_cfg
20 | data_config = {}
21 |
22 | # Resolve input/image size
23 | in_chans = 3
24 | if args.get('in_chans', None) is not None:
25 | in_chans = args['in_chans']
26 | elif args.get('chans', None) is not None:
27 | in_chans = args['chans']
28 |
29 | input_size = (in_chans, 224, 224)
30 | if args.get('input_size', None) is not None:
31 | assert isinstance(args['input_size'], (tuple, list))
32 | assert len(args['input_size']) == 3
33 | input_size = tuple(args['input_size'])
34 | in_chans = input_size[0] # input_size overrides in_chans
35 | elif args.get('img_size', None) is not None:
36 | assert isinstance(args['img_size'], int)
37 | input_size = (in_chans, args['img_size'], args['img_size'])
38 | else:
39 | if use_test_size and pretrained_cfg.get('test_input_size', None) is not None:
40 | input_size = pretrained_cfg['test_input_size']
41 | elif pretrained_cfg.get('input_size', None) is not None:
42 | input_size = pretrained_cfg['input_size']
43 | data_config['input_size'] = input_size
44 |
45 | # resolve interpolation method
46 | data_config['interpolation'] = 'bicubic'
47 | if args.get('interpolation', None):
48 | data_config['interpolation'] = args['interpolation']
49 | elif pretrained_cfg.get('interpolation', None):
50 | data_config['interpolation'] = pretrained_cfg['interpolation']
51 |
52 | # resolve dataset + model mean for normalization
53 | data_config['mean'] = IMAGENET_DEFAULT_MEAN
54 | if args.get('mean', None) is not None:
55 | mean = tuple(args['mean'])
56 | if len(mean) == 1:
57 | mean = tuple(list(mean) * in_chans)
58 | else:
59 | assert len(mean) == in_chans
60 | data_config['mean'] = mean
61 | elif pretrained_cfg.get('mean', None):
62 | data_config['mean'] = pretrained_cfg['mean']
63 |
64 | # resolve dataset + model std deviation for normalization
65 | data_config['std'] = IMAGENET_DEFAULT_STD
66 | if args.get('std', None) is not None:
67 | std = tuple(args['std'])
68 | if len(std) == 1:
69 | std = tuple(list(std) * in_chans)
70 | else:
71 | assert len(std) == in_chans
72 | data_config['std'] = std
73 | elif pretrained_cfg.get('std', None):
74 | data_config['std'] = pretrained_cfg['std']
75 |
76 | # resolve default inference crop
77 | crop_pct = DEFAULT_CROP_PCT
78 | if args.get('crop_pct', None):
79 | crop_pct = args['crop_pct']
80 | else:
81 | if use_test_size and pretrained_cfg.get('test_crop_pct', None):
82 | crop_pct = pretrained_cfg['test_crop_pct']
83 | elif pretrained_cfg.get('crop_pct', None):
84 | crop_pct = pretrained_cfg['crop_pct']
85 | data_config['crop_pct'] = crop_pct
86 |
 87 |     # resolve default crop mode
88 | crop_mode = DEFAULT_CROP_MODE
89 | if args.get('crop_mode', None):
90 | crop_mode = args['crop_mode']
91 | elif pretrained_cfg.get('crop_mode', None):
92 | crop_mode = pretrained_cfg['crop_mode']
93 | data_config['crop_mode'] = crop_mode
94 |
95 | if verbose:
96 | _logger.info('Data processing configuration for current model + dataset:')
97 | for n, v in data_config.items():
98 | _logger.info('\t%s: %s' % (n, str(v)))
99 |
100 | return data_config
101 |
102 |
103 | def resolve_model_data_config(
104 | model,
105 | args=None,
106 | pretrained_cfg=None,
107 | use_test_size=False,
108 | verbose=False,
109 | ):
110 | """ Resolve Model Data Config
111 | This is equivalent to resolve_data_config() but with arguments re-ordered to put model first.
112 |
113 | Args:
114 | model (nn.Module): the model instance
115 | args (dict): command line arguments / configuration in dict form (overrides pretrained_cfg)
116 | pretrained_cfg (dict): pretrained model config (overrides pretrained_cfg attached to model)
117 | use_test_size (bool): use the test time input resolution (if one exists) instead of default train resolution
118 | verbose (bool): enable extra logging of resolved values
119 |
120 | Returns:
121 | dictionary of config
122 | """
123 | return resolve_data_config(
124 | args=args,
125 | pretrained_cfg=pretrained_cfg,
126 | model=model,
127 | use_test_size=use_test_size,
128 | verbose=verbose,
129 | )
130 |
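A minimal usage sketch for resolve_data_config(); the model name and override value below are illustrative, it assumes the timm package is installed and the repository root is on sys.path:

    import timm
    from model.data.config import resolve_data_config

    # any model exposing a pretrained_cfg attribute works; 'resnet50' is illustrative
    model = timm.create_model('resnet50', pretrained=False)
    # explicit args (e.g. parsed CLI flags) take precedence over pretrained_cfg
    cfg = resolve_data_config(args={'img_size': 288}, model=model, verbose=True)
    print(cfg['input_size'])   # (3, 288, 288) -- img_size overrides the 224 default
    print(cfg['interpolation'], cfg['mean'], cfg['std'], cfg['crop_pct'])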
--------------------------------------------------------------------------------
/model/data/constants.py:
--------------------------------------------------------------------------------
1 | DEFAULT_CROP_PCT = 0.875
2 | DEFAULT_CROP_MODE = 'center'
3 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
4 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
5 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
6 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
7 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
8 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
9 | OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
10 | OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
11 |
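A short sketch of where these constants plug in, assuming torchvision is available; the 256/224 resize-crop pairing reproduces DEFAULT_CROP_PCT:

    from torchvision import transforms
    from model.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),   # 224 / 256 = 0.875 = DEFAULT_CROP_PCT
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])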
--------------------------------------------------------------------------------
/model/data/dataset_info.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, List, Optional, Union
3 |
4 |
5 | class DatasetInfo(ABC):
6 |
7 | def __init__(self):
8 | pass
9 |
10 | @abstractmethod
11 | def num_classes(self):
12 | pass
13 |
14 | @abstractmethod
15 | def label_names(self):
16 | pass
17 |
18 | @abstractmethod
19 | def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
20 | pass
21 |
22 | @abstractmethod
23 | def index_to_label_name(self, index) -> str:
24 | pass
25 |
26 | @abstractmethod
27 | def index_to_description(self, index: int, detailed: bool = False) -> str:
28 | pass
29 |
30 | @abstractmethod
31 | def label_name_to_description(self, label: str, detailed: bool = False) -> str:
32 | pass
33 |
34 |
35 | class CustomDatasetInfo(DatasetInfo):
36 | """ DatasetInfo that wraps passed values for custom datasets."""
37 |
38 | def __init__(
39 | self,
40 | label_names: Union[List[str], Dict[int, str]],
41 | label_descriptions: Optional[Dict[str, str]] = None
42 | ):
43 | super().__init__()
44 | assert len(label_names) > 0
45 | self._label_names = label_names # label index => label name mapping
46 | self._label_descriptions = label_descriptions # label name => label description mapping
47 | if self._label_descriptions is not None:
48 | # validate descriptions (label names required)
49 | assert isinstance(self._label_descriptions, dict)
50 | for n in self._label_names:
51 | assert n in self._label_descriptions
52 |
53 | def num_classes(self):
54 | return len(self._label_names)
55 |
56 | def label_names(self):
57 | return self._label_names
58 |
59 | def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
60 | return self._label_descriptions
61 |
62 | def label_name_to_description(self, label: str, detailed: bool = False) -> str:
63 | if self._label_descriptions:
64 | return self._label_descriptions[label]
 65 |         return label  # return the label name itself if no description is present
66 |
67 | def index_to_label_name(self, index) -> str:
68 | assert 0 <= index < len(self._label_names)
69 | return self._label_names[index]
70 |
71 | def index_to_description(self, index: int, detailed: bool = False) -> str:
72 | label = self.index_to_label_name(index)
73 | return self.label_name_to_description(label, detailed=detailed)
74 |
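A minimal sketch of CustomDatasetInfo with hypothetical PV fault classes (the names and descriptions are made up for illustration):

    from model.data.dataset_info import CustomDatasetInfo

    info = CustomDatasetInfo(
        label_names=['cell_crack', 'hot_spot', 'shadowing'],   # hypothetical classes
        label_descriptions={
            'cell_crack': 'visible crack in a PV cell',
            'hot_spot': 'localized overheating region',
            'shadowing': 'partial shading of the module',
        },
    )
    assert info.num_classes() == 3
    assert info.index_to_label_name(1) == 'hot_spot'
    assert info.index_to_description(1) == 'localized overheating region'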
--------------------------------------------------------------------------------
/model/data/imagenet_info.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | import pkgutil
4 | import re
5 | from typing import Dict, List, Optional, Union
6 |
7 | from .dataset_info import DatasetInfo
8 |
9 |
 10 | # NOTE no ambiguity w.r.t. mapping from # classes to ImageNet subset so far, but likely to change
11 | _NUM_CLASSES_TO_SUBSET = {
12 | 1000: 'imagenet-1k',
13 | 11221: 'imagenet-21k-miil', # miil subset of fall11
14 | 11821: 'imagenet-12k', # timm specific 12k subset of fall11
15 | 21841: 'imagenet-22k', # as in fall11.tar
 16 |     21842: 'imagenet-22k-ms',  # a Microsoft (for FocalNet) remapping of 22k that moves the ImageNet-1k classes to the first 1000
17 | 21843: 'imagenet-21k-goog', # Google's ImageNet full has two classes not in fall11
18 | }
19 |
20 | _SUBSETS = {
21 | 'imagenet1k': 'imagenet_synsets.txt',
22 | 'imagenet12k': 'imagenet12k_synsets.txt',
23 | 'imagenet22k': 'imagenet22k_synsets.txt',
24 | 'imagenet21k': 'imagenet21k_goog_synsets.txt',
25 | 'imagenet21kgoog': 'imagenet21k_goog_synsets.txt',
26 | 'imagenet21kmiil': 'imagenet21k_miil_synsets.txt',
27 | 'imagenet22kms': 'imagenet22k_ms_synsets.txt',
28 | }
29 | _LEMMA_FILE = 'imagenet_synset_to_lemma.txt'
30 | _DEFINITION_FILE = 'imagenet_synset_to_definition.txt'
31 |
32 |
33 | def infer_imagenet_subset(model_or_cfg) -> Optional[str]:
34 | if isinstance(model_or_cfg, dict):
35 | num_classes = model_or_cfg.get('num_classes', None)
36 | else:
37 | num_classes = getattr(model_or_cfg, 'num_classes', None)
38 | if not num_classes:
39 | pretrained_cfg = getattr(model_or_cfg, 'pretrained_cfg', {})
40 | # FIXME at some point pretrained_cfg should include dataset-tag,
41 | # which will be more robust than a guess based on num_classes
42 | num_classes = pretrained_cfg.get('num_classes', None)
43 | if not num_classes or num_classes not in _NUM_CLASSES_TO_SUBSET:
44 | return None
45 | return _NUM_CLASSES_TO_SUBSET[num_classes]
46 |
47 |
48 | class ImageNetInfo(DatasetInfo):
49 |
50 | def __init__(self, subset: str = 'imagenet-1k'):
51 | super().__init__()
52 | subset = re.sub(r'[-_\s]', '', subset.lower())
53 | assert subset in _SUBSETS, f'Unknown imagenet subset {subset}.'
54 |
 55 |         # WordNet synsets (part-of-speech + offset) are the unique class label names for ImageNet classifiers
56 | synset_file = _SUBSETS[subset]
57 | synset_data = pkgutil.get_data(__name__, os.path.join('_info', synset_file))
58 | self._synsets = synset_data.decode('utf-8').splitlines()
59 |
60 | # WordNet lemmas (canonical dictionary form of word) and definitions are used to build
61 | # the class descriptions. If detailed=True both are used, otherwise just the lemmas.
62 | lemma_data = pkgutil.get_data(__name__, os.path.join('_info', _LEMMA_FILE))
63 | reader = csv.reader(lemma_data.decode('utf-8').splitlines(), delimiter='\t')
64 | self._lemmas = dict(reader)
65 | definition_data = pkgutil.get_data(__name__, os.path.join('_info', _DEFINITION_FILE))
66 | reader = csv.reader(definition_data.decode('utf-8').splitlines(), delimiter='\t')
67 | self._definitions = dict(reader)
68 |
69 | def num_classes(self):
70 | return len(self._synsets)
71 |
72 | def label_names(self):
73 | return self._synsets
74 |
75 | def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
76 | if as_dict:
77 | return {label: self.label_name_to_description(label, detailed=detailed) for label in self._synsets}
78 | else:
79 | return [self.label_name_to_description(label, detailed=detailed) for label in self._synsets]
80 |
81 | def index_to_label_name(self, index) -> str:
82 | assert 0 <= index < len(self._synsets), \
83 | f'Index ({index}) out of range for dataset with {len(self._synsets)} classes.'
84 | return self._synsets[index]
85 |
86 | def index_to_description(self, index: int, detailed: bool = False) -> str:
87 | label = self.index_to_label_name(index)
88 | return self.label_name_to_description(label, detailed=detailed)
89 |
90 | def label_name_to_description(self, label: str, detailed: bool = False) -> str:
91 | if detailed:
92 | description = f'{self._lemmas[label]}: {self._definitions[label]}'
93 | else:
94 | description = f'{self._lemmas[label]}'
95 | return description
96 |
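A usage sketch, assuming the bundled _info text files are present; index 0 of the 1k subset is the 'tench' synset:

    from model.data.imagenet_info import ImageNetInfo, infer_imagenet_subset

    subset = infer_imagenet_subset({'num_classes': 1000})   # -> 'imagenet-1k'
    info = ImageNetInfo(subset)
    print(info.num_classes())            # 1000
    print(info.index_to_label_name(0))   # 'n01440764'
    print(info.index_to_description(0, detailed=True))  # 'tench, Tinca tinca: ...'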
--------------------------------------------------------------------------------
/model/data/random_erasing.py:
--------------------------------------------------------------------------------
1 | """ Random Erasing (Cutout)
2 |
3 | Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
4 | Copyright Zhun Zhong & Liang Zheng
5 |
6 | Hacked together by / Copyright 2019, Ross Wightman
7 | """
8 | import random
9 | import math
10 |
11 | import torch
12 |
13 |
14 | def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'):
15 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
16 | # paths, flip the order so normal is run on CPU if this becomes a problem
17 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
18 | if per_pixel:
19 | return torch.empty(patch_size, dtype=dtype, device=device).normal_()
20 | elif rand_color:
21 | return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
22 | else:
23 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
24 |
25 |
26 | class RandomErasing:
27 | """ Randomly selects a rectangle region in an image and erases its pixels.
28 | 'Random Erasing Data Augmentation' by Zhong et al.
29 | See https://arxiv.org/pdf/1708.04896.pdf
30 |
31 | This variant of RandomErasing is intended to be applied to either a batch
32 | or single image tensor after it has been normalized by dataset mean and std.
33 | Args:
34 | probability: Probability that the Random Erasing operation will be performed.
35 | min_area: Minimum percentage of erased area wrt input image area.
36 | max_area: Maximum percentage of erased area wrt input image area.
37 | min_aspect: Minimum aspect ratio of erased area.
38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel'
39 | 'const' - erase block is constant color of 0 for all channels
40 | 'rand' - erase block is same per-channel random (normal) color
41 | 'pixel' - erase block is per-pixel random (normal) color
42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count.
 43 |             per-image count is randomly chosen between min_count and this value.
44 | """
45 |
46 | def __init__(
47 | self,
48 | probability=0.5,
49 | min_area=0.02,
50 | max_area=1/3,
51 | min_aspect=0.3,
52 | max_aspect=None,
53 | mode='const',
54 | min_count=1,
55 | max_count=None,
56 | num_splits=0,
57 | device='cuda',
58 | ):
59 | self.probability = probability
60 | self.min_area = min_area
61 | self.max_area = max_area
62 | max_aspect = max_aspect or 1 / min_aspect
63 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
64 | self.min_count = min_count
65 | self.max_count = max_count or min_count
66 | self.num_splits = num_splits
67 | self.mode = mode.lower()
68 | self.rand_color = False
69 | self.per_pixel = False
70 | if self.mode == 'rand':
71 | self.rand_color = True # per block random normal
72 | elif self.mode == 'pixel':
73 | self.per_pixel = True # per pixel random normal
74 | else:
75 | assert not self.mode or self.mode == 'const'
76 | self.device = device
77 |
78 | def _erase(self, img, chan, img_h, img_w, dtype):
79 | if random.random() > self.probability:
80 | return
81 | area = img_h * img_w
82 | count = self.min_count if self.min_count == self.max_count else \
83 | random.randint(self.min_count, self.max_count)
84 | for _ in range(count):
85 | for attempt in range(10):
86 | target_area = random.uniform(self.min_area, self.max_area) * area / count
87 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
88 | h = int(round(math.sqrt(target_area * aspect_ratio)))
89 | w = int(round(math.sqrt(target_area / aspect_ratio)))
90 | if w < img_w and h < img_h:
91 | top = random.randint(0, img_h - h)
92 | left = random.randint(0, img_w - w)
93 | img[:, top:top + h, left:left + w] = _get_pixels(
94 | self.per_pixel,
95 | self.rand_color,
96 | (chan, h, w),
97 | dtype=dtype,
98 | device=self.device,
99 | )
100 | break
101 |
102 | def __call__(self, input):
103 | if len(input.size()) == 3:
104 | self._erase(input, *input.size(), input.dtype)
105 | else:
106 | batch_size, chan, img_h, img_w = input.size()
107 | # skip first slice of batch if num_splits is set (for clean portion of samples)
108 | batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
109 | for i in range(batch_start, batch_size):
110 | self._erase(input[i], chan, img_h, img_w, input.dtype)
111 | return input
112 |
113 | def __repr__(self):
114 | # NOTE simplified state for repr
115 | fs = self.__class__.__name__ + f'(p={self.probability}, mode={self.mode}'
116 | fs += f', count=({self.min_count}, {self.max_count}))'
117 | return fs
118 |
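A minimal sketch applying RandomErasing on CPU; note the default device is 'cuda', so it must be overridden for CPU tensors:

    import torch
    from model.data.random_erasing import RandomErasing

    erase = RandomErasing(probability=1.0, mode='pixel', max_count=2, device='cpu')
    batch = torch.randn(8, 3, 224, 224)   # stand-in for a normalized image batch
    out = erase(batch)                    # erased in place and returned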
--------------------------------------------------------------------------------
/model/data/readers/__init__.py:
--------------------------------------------------------------------------------
1 | from .reader_factory import create_reader
2 | from .img_extensions import *
3 |
--------------------------------------------------------------------------------
/model/data/readers/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/readers/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/readers/__pycache__/class_map.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/readers/__pycache__/class_map.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/readers/__pycache__/img_extensions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/readers/__pycache__/img_extensions.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/readers/__pycache__/reader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/readers/__pycache__/reader.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/readers/__pycache__/reader_factory.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/readers/__pycache__/reader_factory.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/readers/__pycache__/reader_image_folder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/readers/__pycache__/reader_image_folder.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/readers/__pycache__/reader_image_in_tar.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/readers/__pycache__/reader_image_in_tar.cpython-38.pyc
--------------------------------------------------------------------------------
/model/data/readers/class_map.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 |
5 | def load_class_map(map_or_filename, root=''):
6 | if isinstance(map_or_filename, dict):
  7 |         assert map_or_filename, 'class_map dict must be non-empty'
8 | return map_or_filename
9 | class_map_path = map_or_filename
10 | if not os.path.exists(class_map_path):
11 | class_map_path = os.path.join(root, class_map_path)
12 | assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
13 | class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
14 | if class_map_ext == '.txt':
15 | with open(class_map_path) as f:
16 | class_to_idx = {v.strip(): k for k, v in enumerate(f)}
17 | elif class_map_ext == '.pkl':
18 | with open(class_map_path, 'rb') as f:
19 | class_to_idx = pickle.load(f)
20 | else:
21 | assert False, f'Unsupported class map file extension ({class_map_ext}).'
22 | return class_to_idx
23 |
24 |
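A sketch of the .txt format with hypothetical paths: one class name per line, and line order defines the class index:

    from model.data.readers.class_map import load_class_map

    # /data/pv/classes.txt contains, one name per line:
    #   cell_crack
    #   hot_spot
    #   shadowing
    class_to_idx = load_class_map('classes.txt', root='/data/pv')
    # -> {'cell_crack': 0, 'hot_spot': 1, 'shadowing': 2}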
--------------------------------------------------------------------------------
/model/data/readers/img_extensions.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | __all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions']
4 |
5 |
6 | IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use
7 | _IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync
8 |
9 |
10 | def _set_extensions(extensions):
11 | global IMG_EXTENSIONS
12 | global _IMG_EXTENSIONS_SET
13 | dedupe = set() # NOTE de-duping tuple while keeping original order
14 | IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x))
15 | _IMG_EXTENSIONS_SET = set(extensions)
16 |
17 |
18 | def _valid_extension(x: str):
19 | return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.')
20 |
21 |
22 | def is_img_extension(ext):
23 | return ext in _IMG_EXTENSIONS_SET
24 |
25 |
26 | def get_img_extensions(as_set=False):
27 | return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS)
28 |
29 |
30 | def set_img_extensions(extensions):
31 | assert len(extensions)
32 | for x in extensions:
33 | assert _valid_extension(x)
34 | _set_extensions(extensions)
35 |
36 |
37 | def add_img_extensions(ext):
38 | if not isinstance(ext, (list, tuple, set)):
39 | ext = (ext,)
40 | for x in ext:
41 | assert _valid_extension(x)
42 | extensions = IMG_EXTENSIONS + tuple(ext)
43 | _set_extensions(extensions)
44 |
45 |
46 | def del_img_extensions(ext):
47 | if not isinstance(ext, (list, tuple, set)):
48 | ext = (ext,)
49 | extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext)
50 | _set_extensions(extensions)
51 |
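A sketch of the registry in use, e.g. to pick up .bmp exports alongside the defaults before scanning a folder:

    from model.data.readers.img_extensions import (
        add_img_extensions, del_img_extensions, get_img_extensions, is_img_extension)

    add_img_extensions('.bmp')
    assert is_img_extension('.bmp')
    print(get_img_extensions())   # ('.png', '.jpg', '.jpeg', '.bmp')
    del_img_extensions('.bmp')    # restore the default tuple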
--------------------------------------------------------------------------------
/model/data/readers/reader.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 |
3 |
4 | class Reader:
5 | def __init__(self):
6 | pass
7 |
8 | @abstractmethod
9 | def _filename(self, index, basename=False, absolute=False):
10 | pass
11 |
12 | def filename(self, index, basename=False, absolute=False):
13 | return self._filename(index, basename=basename, absolute=absolute)
14 |
15 | def filenames(self, basename=False, absolute=False):
16 | return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))]
17 |
18 |
--------------------------------------------------------------------------------
/model/data/readers/reader_factory.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .reader_image_folder import ReaderImageFolder
4 | from .reader_image_in_tar import ReaderImageInTar
5 |
6 |
7 | def create_reader(name, root, split='train', **kwargs):
8 | name = name.lower()
9 | name = name.split('/', 1)
10 | prefix = ''
11 | if len(name) > 1:
12 | prefix = name[0]
13 | name = name[-1]
14 |
 15 |     # FIXME improve the selection; right now it's just a reader prefix (hfds/tfds/wds) or the
 16 |     #  fallback path, will need options to explicitly select other readers shortly
17 | if prefix == 'hfds':
18 | from .reader_hfds import ReaderHfds # defer tensorflow import
19 | reader = ReaderHfds(root, name, split=split, **kwargs)
20 | elif prefix == 'tfds':
21 | from .reader_tfds import ReaderTfds # defer tensorflow import
22 | reader = ReaderTfds(root, name, split=split, **kwargs)
23 | elif prefix == 'wds':
24 | from .reader_wds import ReaderWds
25 | kwargs.pop('download', False)
26 | reader = ReaderWds(root, name, split=split, **kwargs)
27 | else:
28 | assert os.path.exists(root)
29 | # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
30 | # FIXME support split here or in reader?
31 | if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
32 | reader = ReaderImageInTar(root, **kwargs)
33 | else:
34 | reader = ReaderImageFolder(root, **kwargs)
35 | return reader
36 |
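A sketch of the dispatch paths; the directory and tar locations are hypothetical and must exist for the fallback branch:

    from model.data.readers import create_reader

    # image-folder dataset (one subfolder per class)
    reader = create_reader('', root='/data/pv/train')

    # the same dataset packed into a single tar selects ReaderImageInTar
    reader = create_reader('', root='/data/pv/train.tar')

    # a name prefix selects an alternate backend, e.g. Hugging Face datasets
    reader = create_reader('hfds/mnist', root='/data/hf_cache', split='train')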
--------------------------------------------------------------------------------
/model/data/readers/reader_hfds.py:
--------------------------------------------------------------------------------
1 | """ Dataset reader that wraps Hugging Face datasets
2 |
3 | Hacked together by / Copyright 2022 Ross Wightman
4 | """
5 | import io
6 | import math
7 | import torch
8 | import torch.distributed as dist
9 | from PIL import Image
10 |
11 | try:
12 | import datasets
13 | except ImportError as e:
14 | print("Please install Hugging Face datasets package `pip install datasets`.")
15 | exit(1)
16 | from .class_map import load_class_map
17 | from .reader import Reader
18 |
19 |
20 | def get_class_labels(info, label_key='label'):
 21 |     if label_key not in info.features:
22 | return {}
23 | class_label = info.features[label_key]
24 | class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
25 | return class_to_idx
26 |
27 |
28 | class ReaderHfds(Reader):
29 |
30 | def __init__(
31 | self,
32 | root,
33 | name,
34 | split='train',
35 | class_map=None,
36 | label_key='label',
37 | download=False,
38 | ):
 39 |         """ Wrap a Hugging Face dataset as a Reader; `name` maps to the datasets
 40 |         `path` arg, `root` is the cache dir, and `label_key` selects the target column. """
41 | super().__init__()
42 | self.root = root
43 | self.split = split
44 | self.dataset = datasets.load_dataset(
45 | name, # 'name' maps to path arg in hf datasets
46 | split=split,
47 | cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
48 | )
49 | # leave decode for caller, plus we want easy access to original path names...
50 | self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))
51 |
52 | self.label_key = label_key
53 | self.remap_class = False
54 | if class_map:
55 | self.class_to_idx = load_class_map(class_map)
56 | self.remap_class = True
57 | else:
58 | self.class_to_idx = get_class_labels(self.dataset.info, self.label_key)
59 | self.split_info = self.dataset.info.splits[split]
60 | self.num_samples = self.split_info.num_examples
61 |
62 | def __getitem__(self, index):
63 | item = self.dataset[index]
64 | image = item['image']
65 | if 'bytes' in image and image['bytes']:
66 | image = io.BytesIO(image['bytes'])
67 | else:
68 | assert 'path' in image and image['path']
69 | image = open(image['path'], 'rb')
70 | label = item[self.label_key]
71 | if self.remap_class:
72 | label = self.class_to_idx[label]
73 | return image, label
74 |
75 | def __len__(self):
76 | return len(self.dataset)
77 |
78 | def _filename(self, index, basename=False, absolute=False):
79 | item = self.dataset[index]
80 | return item['image']['path']
81 |
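A usage sketch; 'mnist' is an illustrative public dataset whose schema has 'image' and 'label' columns, and the cache dir is hypothetical:

    from model.data.readers.reader_hfds import ReaderHfds

    reader = ReaderHfds(root='/data/hf_cache', name='mnist', split='train')
    fileobj, label = reader[0]   # undecoded image bytes (file-like) + integer label
    print(len(reader), label)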
--------------------------------------------------------------------------------
/model/data/readers/reader_image_folder.py:
--------------------------------------------------------------------------------
1 | """ A dataset reader that extracts images from folders
2 |
3 | Folders are scanned recursively to find image files. Labels are based
4 | on the folder hierarchy, just leaf folders by default.
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 | import os
9 | from typing import Dict, List, Optional, Set, Tuple, Union
10 |
11 | from timm.utils.misc import natural_key
12 |
13 | from .class_map import load_class_map
14 | from .img_extensions import get_img_extensions
15 | from .reader import Reader
16 |
17 |
18 | def find_images_and_targets(
19 | folder: str,
20 | types: Optional[Union[List, Tuple, Set]] = None,
21 | class_to_idx: Optional[Dict] = None,
22 | leaf_name_only: bool = True,
23 | sort: bool = True
24 | ):
25 | """ Walk folder recursively to discover images and map them to classes by folder names.
26 |
27 | Args:
 28 |         folder: root of folder to recursively search
29 | types: types (file extensions) to search for in path
30 | class_to_idx: specify mapping for class (folder name) to class index if set
31 | leaf_name_only: use only leaf-name of folder walk for class names
32 | sort: re-sort found images by name (for consistent ordering)
33 |
34 | Returns:
35 | A list of image and target tuples, class_to_idx mapping
36 | """
37 | types = get_img_extensions(as_set=True) if not types else set(types)
38 | labels = []
39 | filenames = []
40 | for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
41 | rel_path = os.path.relpath(root, folder) if (root != folder) else ''
42 | label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
43 | for f in files:
44 | base, ext = os.path.splitext(f)
45 | if ext.lower() in types:
46 | filenames.append(os.path.join(root, f))
47 | labels.append(label)
48 | if class_to_idx is None:
49 | # building class index
50 | unique_labels = set(labels)
51 | sorted_labels = list(sorted(unique_labels, key=natural_key))
52 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
53 | images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
54 | if sort:
55 | images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
56 | return images_and_targets, class_to_idx
57 |
58 |
59 | class ReaderImageFolder(Reader):
60 |
61 | def __init__(
62 | self,
63 | root,
64 | class_map=''):
65 | super().__init__()
66 |
67 | self.root = root
68 | class_to_idx = None
69 | if class_map:
70 | class_to_idx = load_class_map(class_map, root)
71 | self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
72 | if len(self.samples) == 0:
73 | raise RuntimeError(
74 | f'Found 0 images in subfolders of {root}. '
75 | f'Supported image extensions are {", ".join(get_img_extensions())}')
76 |
77 | def __getitem__(self, index):
78 | path, target = self.samples[index]
79 | return open(path, 'rb'), target
80 |
81 | def __len__(self):
82 | return len(self.samples)
83 |
84 | def _filename(self, index, basename=False, absolute=False):
85 | filename = self.samples[index][0]
86 | if basename:
87 | filename = os.path.basename(filename)
88 | elif not absolute:
89 | filename = os.path.relpath(filename, self.root)
90 | return filename
91 |
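A sketch over a hypothetical class-per-folder layout (/data/pv/train/<class_name>/*.jpg):

    from model.data.readers.reader_image_folder import ReaderImageFolder

    reader = ReaderImageFolder('/data/pv/train')
    print(reader.class_to_idx)                 # folder name -> class index
    fileobj, target = reader[0]                # open file handle + class index
    print(reader.filename(0, basename=True))   # e.g. 'IMG_0001.jpg' (illustrative)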
--------------------------------------------------------------------------------
/model/data/readers/reader_image_tar.py:
--------------------------------------------------------------------------------
1 | """ A dataset reader that reads single tarfile based datasets
2 |
  3 | This reader can read datasets consisting of a single tarfile containing images.
  4 | I am planning to deprecate it in favour of ReaderImageInTar.
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 | import os
9 | import tarfile
10 |
11 | from timm.utils.misc import natural_key
12 |
13 | from .class_map import load_class_map
14 | from .img_extensions import get_img_extensions
15 | from .reader import Reader
16 |
17 |
 18 | def extract_tarinfo(tf, class_to_idx=None, sort=True):
19 | extensions = get_img_extensions(as_set=True)
20 | files = []
21 | labels = []
 22 |     for ti in tf.getmembers():
23 | if not ti.isfile():
24 | continue
25 | dirname, basename = os.path.split(ti.path)
26 | label = os.path.basename(dirname)
27 | ext = os.path.splitext(basename)[1]
28 | if ext.lower() in extensions:
29 | files.append(ti)
30 | labels.append(label)
31 | if class_to_idx is None:
32 | unique_labels = set(labels)
33 | sorted_labels = list(sorted(unique_labels, key=natural_key))
34 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
35 | tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
36 | if sort:
37 | tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
38 | return tarinfo_and_targets, class_to_idx
39 |
40 |
41 | class ReaderImageTar(Reader):
42 | """ Single tarfile dataset where classes are mapped to folders within tar
43 | NOTE: This class is being deprecated in favour of the more capable ReaderImageInTar that can
44 | operate on folders of tars or tars in tars.
45 | """
46 | def __init__(self, root, class_map=''):
47 | super().__init__()
48 |
49 | class_to_idx = None
50 | if class_map:
51 | class_to_idx = load_class_map(class_map, root)
52 | assert os.path.isfile(root)
53 | self.root = root
54 |
55 | with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later
56 | self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx)
57 | self.imgs = self.samples
58 | self.tarfile = None # lazy init in __getitem__
59 |
60 | def __getitem__(self, index):
61 | if self.tarfile is None:
62 | self.tarfile = tarfile.open(self.root)
63 | tarinfo, target = self.samples[index]
64 | fileobj = self.tarfile.extractfile(tarinfo)
65 | return fileobj, target
66 |
67 | def __len__(self):
68 | return len(self.samples)
69 |
70 | def _filename(self, index, basename=False, absolute=False):
71 | filename = self.samples[index][0].name
72 | if basename:
73 | filename = os.path.basename(filename)
74 | return filename
75 |
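A sketch for this (deprecated) reader; the tar path is hypothetical and members are expected to be laid out as <class>/<image>:

    from model.data.readers.reader_image_tar import ReaderImageTar

    reader = ReaderImageTar('/data/pv/train.tar')
    fileobj, target = reader[0]   # the tar is reopened lazily per worker process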
--------------------------------------------------------------------------------
/model/data/readers/shared_count.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Value
2 |
3 |
4 | class SharedCount:
5 | def __init__(self, epoch: int = 0):
6 | self.shared_epoch = Value('i', epoch)
7 |
8 | @property
9 | def value(self):
10 | return self.shared_epoch.value
11 |
12 | @value.setter
13 | def value(self, epoch):
14 | self.shared_epoch.value = epoch
15 |
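A sketch of the intended pattern: the main process publishes the current epoch through the shared Value so forked DataLoader workers can observe it (e.g. to reseed per epoch):

    from model.data.readers.shared_count import SharedCount

    epoch = SharedCount(0)
    # hand `epoch` to a dataset/reader before workers are forked, then per epoch:
    for e in range(3):
        epoch.value = e   # workers see the update via the shared Value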
--------------------------------------------------------------------------------
/model/data/real_labels.py:
--------------------------------------------------------------------------------
1 | """ Real labels evaluator for ImageNet
2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159
3 | Based on Numpy example at https://github.com/google-research/reassessed-imagenet
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import os
8 | import json
9 | import numpy as np
10 | import pkgutil
11 |
12 |
13 | class RealLabelsImagenet:
14 |
15 | def __init__(self, filenames, real_json=None, topk=(1, 5)):
16 | if real_json is not None:
17 | with open(real_json) as real_labels:
18 | real_labels = json.load(real_labels)
19 | else:
20 | real_labels = json.loads(
21 | pkgutil.get_data(__name__, os.path.join('_info', 'imagenet_real_labels.json')).decode('utf-8'))
22 | real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)}
23 | self.real_labels = real_labels
24 | self.filenames = filenames
25 | assert len(self.filenames) == len(self.real_labels)
26 | self.topk = topk
27 | self.is_correct = {k: [] for k in topk}
28 | self.sample_idx = 0
29 |
30 | def add_result(self, output):
31 | maxk = max(self.topk)
32 | _, pred_batch = output.topk(maxk, 1, True, True)
33 | pred_batch = pred_batch.cpu().numpy()
34 | for pred in pred_batch:
35 | filename = self.filenames[self.sample_idx]
36 | filename = os.path.basename(filename)
37 | if self.real_labels[filename]:
38 | for k in self.topk:
39 | self.is_correct[k].append(
40 | any([p in self.real_labels[filename] for p in pred[:k]]))
41 | self.sample_idx += 1
42 |
43 | def get_accuracy(self, k=None):
44 | if k is None:
45 | return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk}
46 | else:
47 | return float(np.mean(self.is_correct[k])) * 100
48 |
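A sketch of a ReaL evaluation loop; `model` and `loader` are assumed to exist, and the loader must yield the ImageNet val images in the same order as `filenames`:

    from model.data.real_labels import RealLabelsImagenet

    filenames = reader.filenames(basename=True)   # assumption: a reader over the val set
    real = RealLabelsImagenet(filenames, topk=(1, 5))
    for images, _ in loader:
        real.add_result(model(images))            # logits of shape (batch, 1000)
    print(real.get_accuracy())                    # {1: top-1 ReaL acc, 5: top-5 ReaL acc}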
--------------------------------------------------------------------------------
/model/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .activations import *
2 | from .adaptive_avgmax_pool import \
3 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
4 | from .attention_pool import AttentionPoolLatent
5 | from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
6 | from .blur_pool import BlurPool2d
7 | from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
8 | from .cond_conv2d import CondConv2d, get_condconv_initializer
9 | from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
10 | set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn
11 | from .conv2d_same import Conv2dSame, conv2d_same
12 | from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
13 | from .create_act import create_act_layer, get_act_layer, get_act_fn
14 | from .create_attn import get_attn, create_attn
15 | from .create_conv2d import create_conv2d
16 | from .create_norm import get_norm_layer, create_norm_layer
 17 | from .create_norm_act import get_norm_act_layer, create_norm_act_layer
18 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
19 | from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
20 | from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
21 | EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
22 | from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
23 | from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
24 | from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
25 | from .gather_excite import GatherExcite
26 | from .global_context import GlobalContext
27 | from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
28 | from .inplace_abn import InplaceAbn
29 | from .linear import Linear
30 | from .mixed_conv2d import MixedConv2d
31 | from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
32 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn
33 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
34 | from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
35 | SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
36 | from .padding import get_padding, get_same_padding, pad_same
37 | from .patch_dropout import PatchDropout
38 | from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
39 | from .pool2d_same import AvgPool2dSame, create_pool2d
40 | from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
41 | from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \
42 | resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple, resize_rel_pos_bias_table_levit
43 | from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
44 | build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \
45 | FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
46 | from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
47 | from .selective_kernel import SelectiveKernel
48 | from .separable_conv import SeparableConv2d, SeparableConvNormAct
49 | from .space_to_depth import SpaceToDepthModule, SpaceToDepth, DepthToSpace
50 | from .split_attn import SplitAttn
51 | from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
52 | from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
53 | from .test_time_pool import TestTimePoolHead, apply_test_time_pool
54 | from .trace_utils import _assert, _float_to_int
55 | from .typing import LayerType, PadType
56 | from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
57 |
--------------------------------------------------------------------------------
/model/layers/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/activations.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/activations.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/activations_jit.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/activations_jit.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/activations_me.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/activations_me.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/adaptive_avgmax_pool.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/adaptive_avgmax_pool.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/attention_pool.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/attention_pool.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/attention_pool2d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/attention_pool2d.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/blur_pool.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/blur_pool.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/bottleneck_attn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/bottleneck_attn.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/cbam.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/cbam.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/classifier.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/classifier.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/cond_conv2d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/cond_conv2d.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/conv2d_same.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/conv2d_same.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/conv_bn_act.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/conv_bn_act.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/create_act.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/create_act.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/create_attn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/create_attn.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/create_conv2d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/create_conv2d.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/create_norm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/create_norm.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/create_norm_act.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/create_norm_act.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/drop.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/drop.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/eca.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/eca.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/evo_norm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/evo_norm.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/fast_norm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/fast_norm.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/filter_response_norm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/filter_response_norm.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/format.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/format.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/gather_excite.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/gather_excite.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/global_context.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/global_context.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/grn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/grn.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/halo_attn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/halo_attn.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/helpers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/helpers.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/inplace_abn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/inplace_abn.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/interpolate.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/interpolate.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/lambda_layer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/lambda_layer.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/linear.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/linear.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/mixed_conv2d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/mixed_conv2d.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/mlp.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/mlp.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/non_local_attn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/non_local_attn.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/norm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/norm.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/norm_act.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/norm_act.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/padding.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/padding.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/patch_dropout.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/patch_dropout.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/patch_embed.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/patch_embed.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/pool2d_same.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/pool2d_same.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/pos_embed.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/pos_embed.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/pos_embed_rel.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/pos_embed_rel.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/pos_embed_sincos.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/pos_embed_sincos.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/selective_kernel.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/selective_kernel.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/separable_conv.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/separable_conv.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/space_to_depth.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/space_to_depth.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/split_attn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/split_attn.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/split_batchnorm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/split_batchnorm.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/squeeze_excite.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/squeeze_excite.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/std_conv.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/std_conv.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/test_time_pool.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/test_time_pool.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/trace_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/trace_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/typing.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/typing.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/__pycache__/weight_init.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/layers/__pycache__/weight_init.cpython-38.pyc
--------------------------------------------------------------------------------
/model/layers/activations_jit.py:
--------------------------------------------------------------------------------
1 | """ Activations
2 |
3 | A collection of jit-scripted activation functions and modules with a common interface so that they can
4 | easily be swapped. All have an `inplace` arg even if not used.
5 |
6 | All jit-scripted activations intentionally lack in-place variants: scripted kernel fusion does not
7 | currently work across in-place op boundaries, so performance would be equal to or worse than the
8 | non-scripted versions if they contained in-place ops.
9 |
10 | Hacked together by / Copyright 2020 Ross Wightman
11 | """
12 |
13 | import torch
14 | from torch import nn as nn
15 | from torch.nn import functional as F
16 |
17 |
18 | @torch.jit.script
19 | def swish_jit(x, inplace: bool = False):
20 | """Swish - Described in: https://arxiv.org/abs/1710.05941
21 | """
22 | return x.mul(x.sigmoid())
23 |
24 |
25 | @torch.jit.script
26 | def mish_jit(x, _inplace: bool = False):
27 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
28 | """
29 | return x.mul(F.softplus(x).tanh())
30 |
31 |
32 | class SwishJit(nn.Module):
33 | def __init__(self, inplace: bool = False):
34 | super(SwishJit, self).__init__()
35 |
36 | def forward(self, x):
37 | return swish_jit(x)
38 |
39 |
40 | class MishJit(nn.Module):
41 | def __init__(self, inplace: bool = False):
42 | super(MishJit, self).__init__()
43 |
44 | def forward(self, x):
45 | return mish_jit(x)
46 |
47 |
48 | @torch.jit.script
49 | def hard_sigmoid_jit(x, inplace: bool = False):
50 | # return F.relu6(x + 3.) / 6.
51 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
52 |
53 |
54 | class HardSigmoidJit(nn.Module):
55 | def __init__(self, inplace: bool = False):
56 | super(HardSigmoidJit, self).__init__()
57 |
58 | def forward(self, x):
59 | return hard_sigmoid_jit(x)
60 |
61 |
62 | @torch.jit.script
63 | def hard_swish_jit(x, inplace: bool = False):
64 | # return x * (F.relu6(x + 3.) / 6)
65 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
66 |
67 |
68 | class HardSwishJit(nn.Module):
69 | def __init__(self, inplace: bool = False):
70 | super(HardSwishJit, self).__init__()
71 |
72 | def forward(self, x):
73 | return hard_swish_jit(x)
74 |
75 |
76 | @torch.jit.script
77 | def hard_mish_jit(x, inplace: bool = False):
78 | """ Hard Mish
79 | Experimental, based on notes by Mish author Diganta Misra at
80 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
81 | """
82 | return 0.5 * x * (x + 2).clamp(min=0, max=2)
83 |
84 |
85 | class HardMishJit(nn.Module):
86 | def __init__(self, inplace: bool = False):
87 | super(HardMishJit, self).__init__()
88 |
89 | def forward(self, x):
90 | return hard_mish_jit(x)
91 |
--------------------------------------------------------------------------------
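A quick usage sketch (an annotation, not a repository file; the import path `model.layers.activations_jit` is assumed from the repo tree). The scripted functions and their nn.Module wrappers are interchangeable, and `inplace` is accepted purely for interface compatibility:

    import torch
    from model.layers.activations_jit import SwishJit, swish_jit  # assumed import path

    x = torch.randn(2, 8)
    act = SwishJit()
    assert torch.allclose(act(x), swish_jit(x))          # module and functional forms agree
    assert torch.allclose(act(x), x * torch.sigmoid(x))  # Swish definition: x * sigmoid(x)
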
/model/layers/attention_pool.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from .config import use_fused_attn
8 | from .mlp import Mlp
9 | from .weight_init import trunc_normal_tf_
10 |
11 |
12 | class AttentionPoolLatent(nn.Module):
13 | """ Attention pooling w/ latent query
14 | """
15 | fused_attn: torch.jit.Final[bool]
16 |
17 | def __init__(
18 | self,
19 | in_features: int,
20 | out_features: int = None,
21 | embed_dim: int = None,
22 |             num_heads: int = 8, feat_size: Optional[int] = None,
23 | mlp_ratio: float = 4.0,
24 | qkv_bias: bool = True,
25 | qk_norm: bool = False,
26 | latent_len: int = 1,
27 | latent_dim: int = None,
28 | pos_embed: str = '',
29 | pool_type: str = 'token',
30 | norm_layer: Optional[nn.Module] = None,
31 | drop: float = 0.0,
32 | ):
33 | super().__init__()
34 | embed_dim = embed_dim or in_features
35 | out_features = out_features or in_features
36 | assert embed_dim % num_heads == 0
37 | self.num_heads = num_heads
38 | self.head_dim = embed_dim // num_heads
39 | self.scale = self.head_dim ** -0.5
40 | self.pool = pool_type
41 | self.fused_attn = use_fused_attn()
42 |
43 |         if pos_embed == 'abs':
44 |             assert feat_size is not None, 'feat_size must be set when pos_embed == "abs"'
45 |             self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
46 | else:
47 | self.pos_embed = None
48 |
49 | self.latent_dim = latent_dim or embed_dim
50 | self.latent_len = latent_len
51 | self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))
52 |
53 | self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
54 | self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
55 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
56 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
57 | self.proj = nn.Linear(embed_dim, embed_dim)
58 | self.proj_drop = nn.Dropout(drop)
59 |
60 | self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
61 | self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
62 |
63 | self.init_weights()
64 |
65 | def init_weights(self):
66 | if self.pos_embed is not None:
67 | trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
68 | trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)
69 |
70 | def forward(self, x):
71 | B, N, C = x.shape
72 |
73 | if self.pos_embed is not None:
74 | # FIXME interpolate
75 | x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
76 |
77 | q_latent = self.latent.expand(B, -1, -1)
78 | q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)
79 |
80 | kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
81 | k, v = kv.unbind(0)
82 |
83 | q, k = self.q_norm(q), self.k_norm(k)
84 |
85 | if self.fused_attn:
86 | x = F.scaled_dot_product_attention(q, k, v)
87 | else:
88 | q = q * self.scale
89 | attn = q @ k.transpose(-2, -1)
90 | attn = attn.softmax(dim=-1)
91 | x = attn @ v
92 | x = x.transpose(1, 2).reshape(B, self.latent_len, C)
93 | x = self.proj(x)
94 | x = self.proj_drop(x)
95 |
96 | x = x + self.mlp(self.norm(x))
97 |
98 | # optional pool if latent seq_len > 1 and pooled output is desired
99 | if self.pool == 'token':
100 | x = x[:, 0]
101 | elif self.pool == 'avg':
102 | x = x.mean(1)
103 | return x
--------------------------------------------------------------------------------
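Usage sketch (not repository source; import path assumed, default pos_embed so no feat_size is needed). The module pools a (B, N, C) token sequence down to one vector per sample via a learned latent query:

    import torch
    from model.layers.attention_pool import AttentionPoolLatent  # assumed import path

    pool = AttentionPoolLatent(in_features=192, num_heads=8)  # embed_dim defaults to in_features
    tokens = torch.randn(2, 196, 192)  # (batch, seq_len, channels), e.g. 14x14 patch tokens
    print(pool(tokens).shape)          # torch.Size([2, 192]) -- 'token' pooling keeps one latent
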
/model/layers/blur_pool.py:
--------------------------------------------------------------------------------
1 | """
2 | BlurPool layer inspired by
3 | - Kornia's Max_BlurPool2d
4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
5 |
6 | Hacked together by Chris Ha and Ross Wightman
7 | """
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import numpy as np
13 | from .padding import get_padding
14 |
15 |
16 | class BlurPool2d(nn.Module):
17 |     r"""Creates a module that blurs and downsamples a given feature map.
18 |     See :cite:`zhang2019shiftinvar` for more details.
19 |     Corresponds to the Downsample class, which does blurring and subsampling.
20 |
21 |     Args:
22 |         channels (int): number of input channels
23 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
24 | stride (int): downsampling filter stride
25 |
26 | Returns:
27 | torch.Tensor: the transformed tensor.
28 | """
29 | def __init__(self, channels, filt_size=3, stride=2) -> None:
30 | super(BlurPool2d, self).__init__()
31 | assert filt_size > 1
32 | self.channels = channels
33 | self.filt_size = filt_size
34 | self.stride = stride
35 | self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
36 | coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
37 | blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1)
38 | self.register_buffer('filt', blur_filter, persistent=False)
39 |
40 | def forward(self, x: torch.Tensor) -> torch.Tensor:
41 | x = F.pad(x, self.padding, 'reflect')
42 | return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)
43 |
--------------------------------------------------------------------------------
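For illustration (not repository source; import path assumed): anti-aliased 2x downsampling with the default 3-tap binomial filter.

    import torch
    from model.layers.blur_pool import BlurPool2d  # assumed import path

    bp = BlurPool2d(channels=64)    # 3x3 binomial filter, stride 2
    x = torch.randn(1, 64, 32, 32)
    print(bp(x).shape)              # torch.Size([1, 64, 16, 16]) -- blur, then subsample
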
/model/layers/cbam.py:
--------------------------------------------------------------------------------
1 | """ CBAM (sort-of) Attention
2 |
3 | Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521
4 |
5 | WARNING: Results with these attention layers have been mixed. They can significantly reduce performance
6 | on some tasks, especially fine-grained ones it seems. I may end up removing this impl.
7 |
8 | Hacked together by / Copyright 2020 Ross Wightman
9 | """
10 | import torch
11 | from torch import nn as nn
12 | import torch.nn.functional as F
13 |
14 | from .conv_bn_act import ConvNormAct
15 | from .create_act import create_act_layer, get_act_layer
16 | from .helpers import make_divisible
17 |
18 |
19 | class ChannelAttn(nn.Module):
20 | """ Original CBAM channel attention module, currently avg + max pool variant only.
21 | """
22 | def __init__(
23 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
24 | act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
25 | super(ChannelAttn, self).__init__()
26 | if not rd_channels:
27 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
28 | self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias)
29 | self.act = act_layer(inplace=True)
30 | self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias)
31 | self.gate = create_act_layer(gate_layer)
32 |
33 | def forward(self, x):
34 | x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True))))
35 | x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True))))
36 | return x * self.gate(x_avg + x_max)
37 |
38 |
39 | class LightChannelAttn(ChannelAttn):
40 |     """An experimental 'lightweight' variant that sums avg + max pool first
41 | """
42 | def __init__(
43 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
44 | act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
45 | super(LightChannelAttn, self).__init__(
46 | channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias)
47 |
48 | def forward(self, x):
49 | x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True)
50 | x_attn = self.fc2(self.act(self.fc1(x_pool)))
51 |         return x * self.gate(x_attn)  # NOTE was F.sigmoid (deprecated); gate defaults to sigmoid
52 |
53 |
54 | class SpatialAttn(nn.Module):
55 | """ Original CBAM spatial attention module
56 | """
57 | def __init__(self, kernel_size=7, gate_layer='sigmoid'):
58 | super(SpatialAttn, self).__init__()
59 | self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False)
60 | self.gate = create_act_layer(gate_layer)
61 |
62 | def forward(self, x):
63 | x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1)
64 | x_attn = self.conv(x_attn)
65 | return x * self.gate(x_attn)
66 |
67 |
68 | class LightSpatialAttn(nn.Module):
69 | """An experimental 'lightweight' variant that sums avg_pool and max_pool results.
70 | """
71 | def __init__(self, kernel_size=7, gate_layer='sigmoid'):
72 | super(LightSpatialAttn, self).__init__()
73 | self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False)
74 | self.gate = create_act_layer(gate_layer)
75 |
76 | def forward(self, x):
77 | x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True)
78 | x_attn = self.conv(x_attn)
79 | return x * self.gate(x_attn)
80 |
81 |
82 | class CbamModule(nn.Module):
83 | def __init__(
84 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
85 | spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
86 | super(CbamModule, self).__init__()
87 | self.channel = ChannelAttn(
88 | channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
89 | rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
90 | self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer)
91 |
92 | def forward(self, x):
93 | x = self.channel(x)
94 | x = self.spatial(x)
95 | return x
96 |
97 |
98 | class LightCbamModule(nn.Module):
99 | def __init__(
100 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
101 | spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
102 | super(LightCbamModule, self).__init__()
103 | self.channel = LightChannelAttn(
104 | channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
105 | rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
106 |         self.spatial = LightSpatialAttn(spatial_kernel_size, gate_layer=gate_layer)
107 |
108 | def forward(self, x):
109 | x = self.channel(x)
110 | x = self.spatial(x)
111 | return x
112 |
113 |
--------------------------------------------------------------------------------
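Usage sketch (annotation, not repository source; import path assumed): CBAM applies channel attention, then spatial attention, and preserves the input shape.

    import torch
    from model.layers.cbam import CbamModule  # assumed import path

    attn = CbamModule(channels=64)   # rd_ratio=1/16 -> 4-channel bottleneck MLP
    x = torch.randn(1, 64, 32, 32)
    print(attn(x).shape)             # torch.Size([1, 64, 32, 32]) -- same shape, reweighted
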
/model/layers/config.py:
--------------------------------------------------------------------------------
1 | """ Model / Layer Config singleton state
2 | """
3 | import os
4 | import warnings
5 | from typing import Any, Optional
6 |
7 | import torch
8 |
9 | __all__ = [
10 | 'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn',
11 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn'
12 | ]
13 |
14 | # Set to True if you prefer layers with no jit optimization (includes activations)
15 | _NO_JIT = False
16 |
17 | # Set to True if you prefer activation layers with no jit optimization
18 | # NOTE not currently used: no_jit and no_activation_jit behave identically because activations are the
19 | # only layers obeying the jit flags so far. This will change as more layers are updated and/or added.
20 | _NO_ACTIVATION_JIT = False
21 |
22 | # Set to True if exporting a model with Same padding via ONNX
23 | _EXPORTABLE = False
24 |
25 | # Set to True if you want to use torch.jit.script on a model
26 | _SCRIPTABLE = False
27 |
28 |
29 | # use torch.nn.functional.scaled_dot_product_attention where possible
30 | _HAS_FUSED_ATTN = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
31 | if 'TIMM_FUSED_ATTN' in os.environ:
32 | _USE_FUSED_ATTN = int(os.environ['TIMM_FUSED_ATTN'])
33 | else:
34 | _USE_FUSED_ATTN = 1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
35 |
36 |
37 | def is_no_jit():
38 | return _NO_JIT
39 |
40 |
41 | class set_no_jit:
42 | def __init__(self, mode: bool) -> None:
43 | global _NO_JIT
44 | self.prev = _NO_JIT
45 | _NO_JIT = mode
46 |
47 | def __enter__(self) -> None:
48 | pass
49 |
50 | def __exit__(self, *args: Any) -> bool:
51 | global _NO_JIT
52 | _NO_JIT = self.prev
53 | return False
54 |
55 |
56 | def is_exportable():
57 | return _EXPORTABLE
58 |
59 |
60 | class set_exportable:
61 | def __init__(self, mode: bool) -> None:
62 | global _EXPORTABLE
63 | self.prev = _EXPORTABLE
64 | _EXPORTABLE = mode
65 |
66 | def __enter__(self) -> None:
67 | pass
68 |
69 | def __exit__(self, *args: Any) -> bool:
70 | global _EXPORTABLE
71 | _EXPORTABLE = self.prev
72 | return False
73 |
74 |
75 | def is_scriptable():
76 | return _SCRIPTABLE
77 |
78 |
79 | class set_scriptable:
80 | def __init__(self, mode: bool) -> None:
81 | global _SCRIPTABLE
82 | self.prev = _SCRIPTABLE
83 | _SCRIPTABLE = mode
84 |
85 | def __enter__(self) -> None:
86 | pass
87 |
88 | def __exit__(self, *args: Any) -> bool:
89 | global _SCRIPTABLE
90 | _SCRIPTABLE = self.prev
91 | return False
92 |
93 |
94 | class set_layer_config:
95 | """ Layer config context manager that allows setting all layer config flags at once.
96 | If a flag arg is None, it will not change the current value.
97 | """
98 | def __init__(
99 | self,
100 | scriptable: Optional[bool] = None,
101 | exportable: Optional[bool] = None,
102 | no_jit: Optional[bool] = None,
103 | no_activation_jit: Optional[bool] = None):
104 | global _SCRIPTABLE
105 | global _EXPORTABLE
106 | global _NO_JIT
107 | global _NO_ACTIVATION_JIT
108 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
109 | if scriptable is not None:
110 | _SCRIPTABLE = scriptable
111 | if exportable is not None:
112 | _EXPORTABLE = exportable
113 | if no_jit is not None:
114 | _NO_JIT = no_jit
115 | if no_activation_jit is not None:
116 | _NO_ACTIVATION_JIT = no_activation_jit
117 |
118 | def __enter__(self) -> None:
119 | pass
120 |
121 | def __exit__(self, *args: Any) -> bool:
122 | global _SCRIPTABLE
123 | global _EXPORTABLE
124 | global _NO_JIT
125 | global _NO_ACTIVATION_JIT
126 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
127 | return False
128 |
129 |
130 | def use_fused_attn(experimental: bool = False) -> bool:
131 | # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0
132 | if not _HAS_FUSED_ATTN or _EXPORTABLE:
133 | return False
134 | if experimental:
135 | return _USE_FUSED_ATTN > 1
136 | return _USE_FUSED_ATTN > 0
137 |
138 |
139 | def set_fused_attn(enable: bool = True, experimental: bool = False):
140 | global _USE_FUSED_ATTN
141 | if not _HAS_FUSED_ATTN:
142 | warnings.warn('This version of pytorch does not have F.scaled_dot_product_attention, fused_attn flag ignored.')
143 | return
144 | if experimental and enable:
145 | _USE_FUSED_ATTN = 2
146 | elif enable:
147 | _USE_FUSED_ATTN = 1
148 | else:
149 | _USE_FUSED_ATTN = 0
150 |
--------------------------------------------------------------------------------
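A sketch of the flag API (annotation, not repository source; import path assumed). The `set_*` classes double as context managers that restore the previous value on exit, and `use_fused_attn()` depends on the installed PyTorch:

    from model.layers.config import set_layer_config, is_scriptable, use_fused_attn  # assumed path

    print(is_scriptable())          # False by default
    with set_layer_config(scriptable=True, no_jit=True):
        print(is_scriptable())      # True inside the context
    print(is_scriptable())          # restored on exit

    # gated by F.scaled_dot_product_attention support, the _EXPORTABLE flag, and TIMM_FUSED_ATTN
    print(use_fused_attn())         # True on PyTorch >= 2.0 with default settings
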
/model/layers/conv2d_same.py:
--------------------------------------------------------------------------------
1 | """ Conv2d w/ Same Padding
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from typing import Tuple, Optional
9 |
10 | from .config import is_exportable, is_scriptable
11 | from .padding import pad_same, pad_same_arg, get_padding_value
12 |
13 |
14 | _USE_EXPORT_CONV = False
15 |
16 |
17 | def conv2d_same(
18 | x,
19 | weight: torch.Tensor,
20 | bias: Optional[torch.Tensor] = None,
21 | stride: Tuple[int, int] = (1, 1),
22 | padding: Tuple[int, int] = (0, 0),
23 | dilation: Tuple[int, int] = (1, 1),
24 | groups: int = 1,
25 | ):
26 | x = pad_same(x, weight.shape[-2:], stride, dilation)
27 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
28 |
29 |
30 | class Conv2dSame(nn.Conv2d):
31 |     """ Tensorflow-like 'SAME' convolution wrapper for 2D convolutions
32 | """
33 |
34 | def __init__(
35 | self,
36 | in_channels,
37 | out_channels,
38 | kernel_size,
39 | stride=1,
40 | padding=0,
41 | dilation=1,
42 | groups=1,
43 | bias=True,
44 | ):
45 | super(Conv2dSame, self).__init__(
46 | in_channels, out_channels, kernel_size,
47 | stride, 0, dilation, groups, bias,
48 | )
49 |
50 | def forward(self, x):
51 | return conv2d_same(
52 | x, self.weight, self.bias,
53 | self.stride, self.padding, self.dilation, self.groups,
54 | )
55 |
56 |
57 | class Conv2dSameExport(nn.Conv2d):
58 |     """ ONNX export friendly, Tensorflow-like 'SAME' convolution wrapper for 2D convolutions
59 |
60 | NOTE: This does not currently work with torch.jit.script
61 | """
62 |
63 | # pylint: disable=unused-argument
64 | def __init__(
65 | self,
66 | in_channels,
67 | out_channels,
68 | kernel_size,
69 | stride=1,
70 | padding=0,
71 | dilation=1,
72 | groups=1,
73 | bias=True,
74 | ):
75 | super(Conv2dSameExport, self).__init__(
76 | in_channels, out_channels, kernel_size,
77 | stride, 0, dilation, groups, bias,
78 | )
79 | self.pad = None
80 | self.pad_input_size = (0, 0)
81 |
82 | def forward(self, x):
83 | input_size = x.size()[-2:]
84 | if self.pad is None:
85 | pad_arg = pad_same_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation)
86 | self.pad = nn.ZeroPad2d(pad_arg)
87 | self.pad_input_size = input_size
88 |
89 | x = self.pad(x)
90 | return F.conv2d(
91 | x, self.weight, self.bias,
92 | self.stride, self.padding, self.dilation, self.groups,
93 | )
94 |
95 |
96 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
97 | padding = kwargs.pop('padding', '')
98 | kwargs.setdefault('bias', False)
99 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
100 | if is_dynamic:
101 | if _USE_EXPORT_CONV and is_exportable():
102 | # older PyTorch ver needed this to export same padding reasonably
103 | assert not is_scriptable() # Conv2DSameExport does not work with jit
104 | return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs)
105 | else:
106 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
107 | else:
108 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
109 |
110 |
111 |
--------------------------------------------------------------------------------
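Illustration (not repository source; import path assumed): Conv2dSame pads dynamically so the output spatial size is ceil(input / stride), matching TF 'SAME' semantics even when the input is not a multiple of the stride.

    import torch
    from model.layers.conv2d_same import Conv2dSame  # assumed import path

    conv = Conv2dSame(16, 32, kernel_size=3, stride=2)
    print(conv(torch.randn(1, 16, 224, 224)).shape)  # torch.Size([1, 32, 112, 112])
    print(conv(torch.randn(1, 16, 225, 225)).shape)  # torch.Size([1, 32, 113, 113]) -- ceil(225/2)
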
/model/layers/conv_bn_act.py:
--------------------------------------------------------------------------------
1 | """ Conv2d + BN + Act
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import functools
6 | from torch import nn as nn
7 |
8 | from .create_conv2d import create_conv2d
9 | from .create_norm_act import get_norm_act_layer
10 |
11 |
12 | class ConvNormAct(nn.Module):
13 | def __init__(
14 | self,
15 | in_channels,
16 | out_channels,
17 | kernel_size=1,
18 | stride=1,
19 | padding='',
20 | dilation=1,
21 | groups=1,
22 | bias=False,
23 | apply_act=True,
24 | norm_layer=nn.BatchNorm2d,
25 | norm_kwargs=None,
26 | act_layer=nn.ReLU,
27 | act_kwargs=None,
28 | drop_layer=None,
29 | ):
30 | super(ConvNormAct, self).__init__()
31 | norm_kwargs = norm_kwargs or {}
32 | act_kwargs = act_kwargs or {}
33 |
34 | self.conv = create_conv2d(
35 | in_channels, out_channels, kernel_size, stride=stride,
36 | padding=padding, dilation=dilation, groups=groups, bias=bias)
37 |
38 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions
39 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
40 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
41 | if drop_layer:
42 | norm_kwargs['drop_layer'] = drop_layer
43 | self.bn = norm_act_layer(
44 | out_channels,
45 | apply_act=apply_act,
46 | act_kwargs=act_kwargs,
47 | **norm_kwargs,
48 | )
49 |
50 | @property
51 | def in_channels(self):
52 | return self.conv.in_channels
53 |
54 | @property
55 | def out_channels(self):
56 | return self.conv.out_channels
57 |
58 | def forward(self, x):
59 | x = self.conv(x)
60 | x = self.bn(x)
61 | return x
62 |
63 |
64 | ConvBnAct = ConvNormAct
65 |
66 |
67 | def create_aa(aa_layer, channels, stride=2, enable=True):
68 | if not aa_layer or not enable:
69 | return nn.Identity()
70 | if isinstance(aa_layer, functools.partial):
71 | if issubclass(aa_layer.func, nn.AvgPool2d):
72 | return aa_layer()
73 | else:
74 | return aa_layer(channels)
75 | elif issubclass(aa_layer, nn.AvgPool2d):
76 | return aa_layer(stride)
77 | else:
78 | return aa_layer(channels=channels, stride=stride)
79 |
80 |
81 | class ConvNormActAa(nn.Module):
82 | def __init__(
83 | self,
84 | in_channels,
85 | out_channels,
86 | kernel_size=1,
87 | stride=1,
88 | padding='',
89 | dilation=1,
90 | groups=1,
91 | bias=False,
92 | apply_act=True,
93 | norm_layer=nn.BatchNorm2d,
94 | norm_kwargs=None,
95 | act_layer=nn.ReLU,
96 | act_kwargs=None,
97 | aa_layer=None,
98 | drop_layer=None,
99 | ):
100 | super(ConvNormActAa, self).__init__()
101 | use_aa = aa_layer is not None and stride == 2
102 | norm_kwargs = norm_kwargs or {}
103 | act_kwargs = act_kwargs or {}
104 |
105 | self.conv = create_conv2d(
106 | in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
107 | padding=padding, dilation=dilation, groups=groups, bias=bias)
108 |
109 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions
110 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
111 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
112 | if drop_layer:
113 | norm_kwargs['drop_layer'] = drop_layer
114 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, act_kwargs=act_kwargs, **norm_kwargs)
115 | self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
116 |
117 | @property
118 | def in_channels(self):
119 | return self.conv.in_channels
120 |
121 | @property
122 | def out_channels(self):
123 | return self.conv.out_channels
124 |
125 | def forward(self, x):
126 | x = self.conv(x)
127 | x = self.bn(x)
128 | x = self.aa(x)
129 | return x
130 |
--------------------------------------------------------------------------------
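Usage sketch (not repository source; import path assumed): ConvNormAct bundles conv + norm + act, with the fused norm/act living under `.bn` for weight compatibility, as the file's comments state.

    import torch
    from torch import nn
    from model.layers.conv_bn_act import ConvNormAct  # assumed import path

    block = ConvNormAct(16, 32, kernel_size=3, stride=2, act_layer=nn.ReLU)
    print(block(torch.randn(1, 16, 64, 64)).shape)  # torch.Size([1, 32, 32, 32])
    print(type(block.bn).__name__)                  # BatchNormAct2d -- norm + act under `.bn`
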
/model/layers/create_attn.py:
--------------------------------------------------------------------------------
1 | """ Attention Factory
2 |
3 | Hacked together by / Copyright 2021 Ross Wightman
4 | """
5 | import torch
6 | from functools import partial
7 |
8 | from .bottleneck_attn import BottleneckAttn
9 | from .cbam import CbamModule, LightCbamModule
10 | from .eca import EcaModule, CecaModule
11 | from .gather_excite import GatherExcite
12 | from .global_context import GlobalContext
13 | from .halo_attn import HaloAttn
14 | from .lambda_layer import LambdaLayer
15 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn
16 | from .selective_kernel import SelectiveKernel
17 | from .split_attn import SplitAttn
18 | from .squeeze_excite import SEModule, EffectiveSEModule
19 |
20 |
21 | def get_attn(attn_type):
22 | if isinstance(attn_type, torch.nn.Module):
23 | return attn_type
24 | module_cls = None
25 | if attn_type:
26 | if isinstance(attn_type, str):
27 | attn_type = attn_type.lower()
28 | # Lightweight attention modules (channel and/or coarse spatial).
29 | # Typically added to existing network architecture blocks in addition to existing convolutions.
30 | if attn_type == 'se':
31 | module_cls = SEModule
32 | elif attn_type == 'ese':
33 | module_cls = EffectiveSEModule
34 | elif attn_type == 'eca':
35 | module_cls = EcaModule
36 | elif attn_type == 'ecam':
37 | module_cls = partial(EcaModule, use_mlp=True)
38 | elif attn_type == 'ceca':
39 | module_cls = CecaModule
40 | elif attn_type == 'ge':
41 | module_cls = GatherExcite
42 | elif attn_type == 'gc':
43 | module_cls = GlobalContext
44 | elif attn_type == 'gca':
45 | module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False)
46 | elif attn_type == 'cbam':
47 | module_cls = CbamModule
48 | elif attn_type == 'lcbam':
49 | module_cls = LightCbamModule
50 |
51 | # Attention / attention-like modules w/ significant params
52 | # Typically replace some of the existing workhorse convs in a network architecture.
53 | # All of these accept a stride argument and can spatially downsample the input.
54 | elif attn_type == 'sk':
55 | module_cls = SelectiveKernel
56 | elif attn_type == 'splat':
57 | module_cls = SplitAttn
58 |
59 | # Self-attention / attention-like modules w/ significant compute and/or params
60 | # Typically replace some of the existing workhorse convs in a network architecture.
61 | # All of these accept a stride argument and can spatially downsample the input.
62 | elif attn_type == 'lambda':
63 | return LambdaLayer
64 | elif attn_type == 'bottleneck':
65 | return BottleneckAttn
66 | elif attn_type == 'halo':
67 | return HaloAttn
68 | elif attn_type == 'nl':
69 | module_cls = NonLocalAttn
70 | elif attn_type == 'bat':
71 | module_cls = BatNonLocalAttn
72 |
73 | # Woops!
74 | else:
75 | assert False, "Invalid attn module (%s)" % attn_type
76 | elif isinstance(attn_type, bool):
77 | if attn_type:
78 | module_cls = SEModule
79 | else:
80 | module_cls = attn_type
81 | return module_cls
82 |
83 |
84 | def create_attn(attn_type, channels, **kwargs):
85 | module_cls = get_attn(attn_type)
86 | if module_cls is not None:
87 | # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels
88 | return module_cls(channels, **kwargs)
89 | return None
90 |
--------------------------------------------------------------------------------
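Sketch (not repository source; import path assumed): attention modules are created from short string keys, and a falsy `attn_type` yields None so callers can skip attention entirely.

    import torch
    from model.layers.create_attn import create_attn  # assumed import path

    se = create_attn('se', 64)         # Squeeze-and-Excitation over 64 channels
    x = torch.randn(1, 64, 16, 16)
    print(se(x).shape)                 # torch.Size([1, 64, 16, 16])
    print(create_attn(None, 64))       # None -- attention disabled
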
/model/layers/create_conv2d.py:
--------------------------------------------------------------------------------
1 | """ Create Conv2d Factory Method
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 |
6 | from .mixed_conv2d import MixedConv2d
7 | from .cond_conv2d import CondConv2d
8 | from .conv2d_same import create_conv2d_pad
9 |
10 |
11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
12 | """ Select a 2d convolution implementation based on arguments
13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.
14 |
15 | Used extensively by EfficientNet, MobileNetv3 and related networks.
16 | """
17 | if isinstance(kernel_size, list):
18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
19 | if 'groups' in kwargs:
20 | groups = kwargs.pop('groups')
21 | if groups == in_channels:
22 | kwargs['depthwise'] = True
23 | else:
24 | assert groups == 1
25 |         # We're going to use only lists for defining the MixedConv2d kernel groups;
26 |         # ints, tuples and other iterables will continue to pass to normal conv and specify h, w.
27 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
28 | else:
29 | depthwise = kwargs.pop('depthwise', False)
30 | # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0
31 | groups = in_channels if depthwise else kwargs.pop('groups', 1)
32 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
33 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
34 | else:
35 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
36 | return m
37 |
--------------------------------------------------------------------------------
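Sketch (not repository source; import path assumed): the factory dispatches on the `kernel_size` type and the `depthwise`/`num_experts` kwargs.

    import torch
    from model.layers.create_conv2d import create_conv2d  # assumed import path

    conv = create_conv2d(32, 64, kernel_size=3)                 # plain nn.Conv2d, padding=1
    dw = create_conv2d(32, 32, kernel_size=3, depthwise=True)   # groups == in_channels
    mixed = create_conv2d(32, 64, kernel_size=[3, 5])           # MixedConv2d over kernel groups
    x = torch.randn(1, 32, 28, 28)
    print(conv(x).shape, dw(x).shape, mixed(x).shape)           # spatial dims preserved at stride 1
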
/model/layers/create_norm.py:
--------------------------------------------------------------------------------
1 | """ Norm Layer Factory
2 |
3 | Create norm modules by string (to mirror create_act and create_norm_act fns)
4 |
5 | Copyright 2022 Ross Wightman
6 | """
7 | import functools
8 | import types
9 | from typing import Type
10 |
11 | import torch.nn as nn
12 |
13 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
14 | from torchvision.ops.misc import FrozenBatchNorm2d
15 |
16 | _NORM_MAP = dict(
17 | batchnorm=nn.BatchNorm2d,
18 | batchnorm2d=nn.BatchNorm2d,
19 | batchnorm1d=nn.BatchNorm1d,
20 | groupnorm=GroupNorm,
21 | groupnorm1=GroupNorm1,
22 | layernorm=LayerNorm,
23 | layernorm2d=LayerNorm2d,
24 | rmsnorm=RmsNorm,
25 | frozenbatchnorm2d=FrozenBatchNorm2d,
26 | )
27 | _NORM_TYPES = {m for n, m in _NORM_MAP.items()}
28 |
29 |
30 | def create_norm_layer(layer_name, num_features, **kwargs):
31 | layer = get_norm_layer(layer_name)
32 | layer_instance = layer(num_features, **kwargs)
33 | return layer_instance
34 |
35 |
36 | def get_norm_layer(norm_layer):
37 | if norm_layer is None:
38 | return None
39 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
40 | norm_kwargs = {}
41 |
42 | # unbind partial fn, so args can be rebound later
43 | if isinstance(norm_layer, functools.partial):
44 | norm_kwargs.update(norm_layer.keywords)
45 | norm_layer = norm_layer.func
46 |
47 | if isinstance(norm_layer, str):
48 | if not norm_layer:
49 | return None
50 | layer_name = norm_layer.replace('_', '')
51 | norm_layer = _NORM_MAP[layer_name]
52 | else:
53 | norm_layer = norm_layer
54 |
55 | if norm_kwargs:
56 | norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args
57 | return norm_layer
58 |
--------------------------------------------------------------------------------
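Sketch (not repository source; import path assumed): string lookup strips underscores before hitting `_NORM_MAP`, so common spelling variants resolve to the same class.

    import torch
    from model.layers.create_norm import create_norm_layer, get_norm_layer  # assumed path

    norm = create_norm_layer('layernorm2d', 64)   # LayerNorm2d instance over 64 channels
    print(norm(torch.randn(1, 64, 8, 8)).shape)   # torch.Size([1, 64, 8, 8])
    print(get_norm_layer('batch_norm_2d'))        # underscores stripped -> nn.BatchNorm2d
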
/model/layers/create_norm_act.py:
--------------------------------------------------------------------------------
1 | """ NormAct (Normalization + Activation Layer) Factory
2 |
3 | Create norm + act combo modules that attempt to be backwards compatible with separate norm + act
4 | instances in models. Where these are used, it will be possible to swap separate BN + act layers with
5 | combined modules like IABN or EvoNorms.
6 |
7 | Hacked together by / Copyright 2020 Ross Wightman
8 | """
9 | import types
10 | import functools
11 | import torch  # used below for optional jit scripting of the created layer
12 | from .evo_norm import *
13 | from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
14 | from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d
15 | from .inplace_abn import InplaceAbn
16 |
17 | _NORM_ACT_MAP = dict(
18 | batchnorm=BatchNormAct2d,
19 | batchnorm2d=BatchNormAct2d,
20 | groupnorm=GroupNormAct,
21 | groupnorm1=functools.partial(GroupNormAct, num_groups=1),
22 | layernorm=LayerNormAct,
23 | layernorm2d=LayerNormAct2d,
24 | evonormb0=EvoNorm2dB0,
25 | evonormb1=EvoNorm2dB1,
26 | evonormb2=EvoNorm2dB2,
27 | evonorms0=EvoNorm2dS0,
28 | evonorms0a=EvoNorm2dS0a,
29 | evonorms1=EvoNorm2dS1,
30 | evonorms1a=EvoNorm2dS1a,
31 | evonorms2=EvoNorm2dS2,
32 | evonorms2a=EvoNorm2dS2a,
33 | frn=FilterResponseNormAct2d,
34 | frntlu=FilterResponseNormTlu2d,
35 | inplaceabn=InplaceAbn,
36 | iabn=InplaceAbn,
37 | )
38 | _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
39 | # has act_layer arg to define act type
40 | _NORM_ACT_REQUIRES_ARG = {
41 | BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn}
42 |
43 |
44 | def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs):
45 | layer = get_norm_act_layer(layer_name, act_layer=act_layer)
46 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
47 | if jit:
48 | layer_instance = torch.jit.script(layer_instance)
49 | return layer_instance
50 |
51 |
52 | def get_norm_act_layer(norm_layer, act_layer=None):
53 | if norm_layer is None:
54 | return None
55 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
56 | assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
57 | norm_act_kwargs = {}
58 |
59 | # unbind partial fn, so args can be rebound later
60 | if isinstance(norm_layer, functools.partial):
61 | norm_act_kwargs.update(norm_layer.keywords)
62 | norm_layer = norm_layer.func
63 |
64 | if isinstance(norm_layer, str):
65 | if not norm_layer:
66 | return None
67 | layer_name = norm_layer.replace('_', '').lower().split('-')[0]
68 | norm_act_layer = _NORM_ACT_MAP[layer_name]
69 | elif norm_layer in _NORM_ACT_TYPES:
70 | norm_act_layer = norm_layer
71 | elif isinstance(norm_layer, types.FunctionType):
72 | # if function type, must be a lambda/fn that creates a norm_act layer
73 | norm_act_layer = norm_layer
74 | else:
75 | type_name = norm_layer.__name__.lower()
76 | if type_name.startswith('batchnorm'):
77 | norm_act_layer = BatchNormAct2d
78 |         elif type_name.startswith('groupnorm1'):  # check the more specific name first
79 |             norm_act_layer = functools.partial(GroupNormAct, num_groups=1)
80 |         elif type_name.startswith('groupnorm'):
81 |             norm_act_layer = GroupNormAct
82 | elif type_name.startswith('layernorm2d'):
83 | norm_act_layer = LayerNormAct2d
84 | elif type_name.startswith('layernorm'):
85 | norm_act_layer = LayerNormAct
86 | else:
87 | assert False, f"No equivalent norm_act layer for {type_name}"
88 |
89 | if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
90 | # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
91 | # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
92 | norm_act_kwargs.setdefault('act_layer', act_layer)
93 | if norm_act_kwargs:
94 | norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args
95 | return norm_act_layer
96 |
--------------------------------------------------------------------------------
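Sketch (not repository source; import path assumed): plain norm types or strings are mapped to fused norm+act classes, with `act_layer` bound via functools.partial.

    import torch
    from torch import nn
    from model.layers.create_norm_act import create_norm_act_layer, get_norm_act_layer  # assumed path

    bn_act = create_norm_act_layer('batchnorm2d', 64, act_layer=nn.ReLU)
    print(bn_act(torch.randn(2, 64, 8, 8)).shape)   # torch.Size([2, 64, 8, 8]) -- BN + ReLU in one module
    print(get_norm_act_layer(nn.BatchNorm2d, act_layer=nn.ReLU))  # partial of BatchNormAct2d
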
/model/layers/fast_norm.py:
--------------------------------------------------------------------------------
1 | """ 'Fast' Normalization Functions
2 |
3 | For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32.
4 |
5 | Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast)
6 |
7 | Hacked together by / Copyright 2022 Ross Wightman
8 | """
9 | from typing import List, Optional
10 |
11 | import torch
12 | from torch.nn import functional as F
13 |
14 | try:
15 | from apex.normalization.fused_layer_norm import fused_layer_norm_affine
16 | has_apex = True
17 | except ImportError:
18 | has_apex = False
19 |
20 | try:
21 | from apex.normalization.fused_layer_norm import fused_rms_norm_affine, fused_rms_norm
22 | has_apex_rmsnorm = True
23 | except ImportError:
24 | has_apex_rmsnorm = False
25 |
26 |
27 | # fast (i.e. lower precision) norm can be disabled with this flag if issues crop up
28 | _USE_FAST_NORM = False # defaulting to False for now
29 |
30 |
31 | def is_fast_norm():
32 | return _USE_FAST_NORM
33 |
34 |
35 | def set_fast_norm(enable=True):
36 | global _USE_FAST_NORM
37 | _USE_FAST_NORM = enable
38 |
39 |
40 | def fast_group_norm(
41 | x: torch.Tensor,
42 | num_groups: int,
43 | weight: Optional[torch.Tensor] = None,
44 | bias: Optional[torch.Tensor] = None,
45 | eps: float = 1e-5
46 | ) -> torch.Tensor:
47 | if torch.jit.is_scripting():
48 | # currently cannot use is_autocast_enabled within torchscript
49 | return F.group_norm(x, num_groups, weight, bias, eps)
50 |
51 | if torch.is_autocast_enabled():
52 | # normally native AMP casts GN inputs to float32
53 | # here we use the low precision autocast dtype
54 | # FIXME what to do re CPU autocast?
55 | dt = torch.get_autocast_gpu_dtype()
56 | x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
57 |
58 | with torch.cuda.amp.autocast(enabled=False):
59 | return F.group_norm(x, num_groups, weight, bias, eps)
60 |
61 |
62 | def fast_layer_norm(
63 | x: torch.Tensor,
64 | normalized_shape: List[int],
65 | weight: Optional[torch.Tensor] = None,
66 | bias: Optional[torch.Tensor] = None,
67 | eps: float = 1e-5
68 | ) -> torch.Tensor:
69 | if torch.jit.is_scripting():
70 | # currently cannot use is_autocast_enabled within torchscript
71 | return F.layer_norm(x, normalized_shape, weight, bias, eps)
72 |
73 | if has_apex:
74 | return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
75 |
76 | if torch.is_autocast_enabled():
77 | # normally native AMP casts LN inputs to float32
78 | # apex LN does not, this is behaving like Apex
79 | dt = torch.get_autocast_gpu_dtype()
80 | # FIXME what to do re CPU autocast?
81 | x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
82 |
83 | with torch.cuda.amp.autocast(enabled=False):
84 | return F.layer_norm(x, normalized_shape, weight, bias, eps)
85 |
86 |
87 | def rms_norm(
88 | x: torch.Tensor,
89 | normalized_shape: List[int],
90 | weight: Optional[torch.Tensor] = None,
91 | eps: float = 1e-5,
92 | ):
93 | norm_ndim = len(normalized_shape)
94 | if torch.jit.is_scripting():
95 | # ndim = len(x.shape)
96 | # dims = list(range(ndim - norm_ndim, ndim)) # this doesn't work on pytorch <= 1.13.x
97 | # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
98 | assert norm_ndim == 1
99 | v = torch.var(x, dim=-1).unsqueeze(-1) # ts crashes with -ve dim + keepdim=True
100 | else:
101 | dims = tuple(range(-1, -norm_ndim - 1, -1))
102 | v = torch.var(x, dim=dims, keepdim=True)
103 | x = x * torch.rsqrt(v + eps)
104 | if weight is not None:
105 | x = x * weight
106 | return x
107 |
108 |
109 | def fast_rms_norm(
110 | x: torch.Tensor,
111 | normalized_shape: List[int],
112 | weight: Optional[torch.Tensor] = None,
113 | eps: float = 1e-5,
114 | ) -> torch.Tensor:
115 | if torch.jit.is_scripting():
116 | # this must be by itself, cannot merge with has_apex_rmsnorm
117 | return rms_norm(x, normalized_shape, weight, eps)
118 |
119 | if has_apex_rmsnorm:
120 | if weight is None:
121 | return fused_rms_norm(x, normalized_shape, eps)
122 | else:
123 | return fused_rms_norm_affine(x, weight, normalized_shape, eps)
124 |
125 | # fallback
126 | return rms_norm(x, normalized_shape, weight, eps)
127 |
--------------------------------------------------------------------------------
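Sketch (not repository source; import path assumed): `rms_norm` normalizes over the trailing `normalized_shape` dims, and the `fast_*` wrappers pick apex / non-upcasting paths when available.

    import torch
    from model.layers.fast_norm import rms_norm, set_fast_norm, is_fast_norm  # assumed path

    x = torch.randn(2, 196, 384)
    y = rms_norm(x, normalized_shape=[384], weight=torch.ones(384))
    print(y.shape)          # torch.Size([2, 196, 384])

    set_fast_norm(True)     # opt in to the non-upcasting GN/LN code paths
    print(is_fast_norm())   # True
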
/model/layers/filter_response_norm.py:
--------------------------------------------------------------------------------
1 | """ Filter Response Norm in PyTorch
2 |
3 | Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737
4 |
5 | Hacked together by / Copyright 2021 Ross Wightman
6 | """
7 | import torch
8 | import torch.nn as nn
9 |
10 | from .create_act import create_act_layer
11 | from .trace_utils import _assert
12 |
13 |
14 | def inv_instance_rms(x, eps: float = 1e-5):
15 | rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype)
16 | return rms.expand(x.shape)
17 |
18 |
19 | class FilterResponseNormTlu2d(nn.Module):
20 | def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_):
21 | super(FilterResponseNormTlu2d, self).__init__()
22 | self.apply_act = apply_act # apply activation (non-linearity)
23 | self.rms = rms
24 | self.eps = eps
25 | self.weight = nn.Parameter(torch.ones(num_features))
26 | self.bias = nn.Parameter(torch.zeros(num_features))
27 | self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None
28 | self.reset_parameters()
29 |
30 | def reset_parameters(self):
31 | nn.init.ones_(self.weight)
32 | nn.init.zeros_(self.bias)
33 | if self.tau is not None:
34 | nn.init.zeros_(self.tau)
35 |
36 | def forward(self, x):
37 | _assert(x.dim() == 4, 'expected 4D input')
38 | x_dtype = x.dtype
39 | v_shape = (1, -1, 1, 1)
40 | x = x * inv_instance_rms(x, self.eps)
41 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
42 | return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x
43 |
44 |
45 | class FilterResponseNormAct2d(nn.Module):
46 | def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_):
47 | super(FilterResponseNormAct2d, self).__init__()
48 | if act_layer is not None and apply_act:
49 | self.act = create_act_layer(act_layer, inplace=inplace)
50 | else:
51 | self.act = nn.Identity()
52 | self.rms = rms
53 | self.eps = eps
54 | self.weight = nn.Parameter(torch.ones(num_features))
55 | self.bias = nn.Parameter(torch.zeros(num_features))
56 | self.reset_parameters()
57 |
58 | def reset_parameters(self):
59 | nn.init.ones_(self.weight)
60 | nn.init.zeros_(self.bias)
61 |
62 | def forward(self, x):
63 | _assert(x.dim() == 4, 'expected 4D input')
64 | x_dtype = x.dtype
65 | v_shape = (1, -1, 1, 1)
66 | x = x * inv_instance_rms(x, self.eps)
67 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
68 | return self.act(x)
69 |
--------------------------------------------------------------------------------
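Sketch (not repository source; import path assumed): FRN normalizes by per-channel instance RMS; the Tlu variant adds the learned threshold from the paper.

    import torch
    from model.layers.filter_response_norm import FilterResponseNormTlu2d  # assumed path

    frn = FilterResponseNormTlu2d(64)             # FRN + learned TLU threshold
    print(frn(torch.randn(2, 64, 16, 16)).shape)  # torch.Size([2, 64, 16, 16])
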
/model/layers/format.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Union
3 |
4 | import torch
5 |
6 |
7 | class Format(str, Enum):
8 | NCHW = 'NCHW'
9 | NHWC = 'NHWC'
10 | NCL = 'NCL'
11 | NLC = 'NLC'
12 |
13 |
14 | FormatT = Union[str, Format]
15 |
16 |
17 | def get_spatial_dim(fmt: FormatT):
18 | fmt = Format(fmt)
19 | if fmt is Format.NLC:
20 | dim = (1,)
21 | elif fmt is Format.NCL:
22 | dim = (2,)
23 | elif fmt is Format.NHWC:
24 | dim = (1, 2)
25 | else:
26 | dim = (2, 3)
27 | return dim
28 |
29 |
30 | def get_channel_dim(fmt: FormatT):
31 | fmt = Format(fmt)
32 | if fmt is Format.NHWC:
33 | dim = 3
34 | elif fmt is Format.NLC:
35 | dim = 2
36 | else:
37 | dim = 1
38 | return dim
39 |
40 |
41 | def nchw_to(x: torch.Tensor, fmt: Format):
42 | if fmt == Format.NHWC:
43 | x = x.permute(0, 2, 3, 1)
44 | elif fmt == Format.NLC:
45 | x = x.flatten(2).transpose(1, 2)
46 | elif fmt == Format.NCL:
47 | x = x.flatten(2)
48 | return x
49 |
50 |
51 | def nhwc_to(x: torch.Tensor, fmt: Format):
52 | if fmt == Format.NCHW:
53 | x = x.permute(0, 3, 1, 2)
54 | elif fmt == Format.NLC:
55 | x = x.flatten(1, 2)
56 | elif fmt == Format.NCL:
57 | x = x.flatten(1, 2).transpose(1, 2)
58 | return x
59 |
--------------------------------------------------------------------------------
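Sketch (not repository source; import path assumed): the Format tags and converters translate between 2D (NCHW/NHWC) and sequence (NCL/NLC) layouts.

    import torch
    from model.layers.format import Format, nchw_to, get_channel_dim  # assumed path

    x = torch.randn(2, 64, 14, 14)            # NCHW
    print(nchw_to(x, Format.NHWC).shape)      # torch.Size([2, 14, 14, 64])
    print(nchw_to(x, Format.NLC).shape)       # torch.Size([2, 196, 64]) -- flattened tokens
    print(get_channel_dim('NHWC'))            # 3
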
/model/layers/gather_excite.py:
--------------------------------------------------------------------------------
1 | """ Gather-Excite Attention Block
2 |
3 | Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348
4 |
5 | Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet
6 |
7 | I've tried to support all of the extent settings, both w/ and w/o params. I don't believe I've seen another
8 | impl that covers all of the cases.
9 |
10 | NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation
11 |
12 | Hacked together by / Copyright 2021 Ross Wightman
13 | """
14 | import math
15 |
16 | from torch import nn as nn
17 | import torch.nn.functional as F
18 |
19 | from .create_act import create_act_layer, get_act_layer
20 | from .create_conv2d import create_conv2d
21 | from .helpers import make_divisible
22 | from .mlp import ConvMlp
23 |
24 |
25 | class GatherExcite(nn.Module):
26 | """ Gather-Excite Attention Module
27 | """
28 | def __init__(
29 | self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,
30 | rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False,
31 | act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):
32 | super(GatherExcite, self).__init__()
33 | self.add_maxpool = add_maxpool
34 | act_layer = get_act_layer(act_layer)
35 | self.extent = extent
36 | if extra_params:
37 | self.gather = nn.Sequential()
38 | if extent == 0:
39 | assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params'
40 | self.gather.add_module(
41 | 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True))
42 | if norm_layer:
43 |                 self.gather.add_module('norm1', nn.BatchNorm2d(channels))
44 | else:
45 | assert extent % 2 == 0
46 | num_conv = int(math.log2(extent))
47 | for i in range(num_conv):
48 | self.gather.add_module(
49 | f'conv{i + 1}',
50 | create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True))
51 | if norm_layer:
52 | self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels))
53 | if i != num_conv - 1:
54 | self.gather.add_module(f'act{i + 1}', act_layer(inplace=True))
55 | else:
56 | self.gather = None
57 | if self.extent == 0:
58 | self.gk = 0
59 | self.gs = 0
60 | else:
61 | assert extent % 2 == 0
62 | self.gk = self.extent * 2 - 1
63 | self.gs = self.extent
64 |
65 | if not rd_channels:
66 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
67 | self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity()
68 | self.gate = create_act_layer(gate_layer)
69 |
70 | def forward(self, x):
71 | size = x.shape[-2:]
72 | if self.gather is not None:
73 | x_ge = self.gather(x)
74 | else:
75 | if self.extent == 0:
76 | # global extent
77 |                 x_ge = x.mean(dim=(2, 3), keepdim=True)
78 | if self.add_maxpool:
79 | # experimental codepath, may remove or change
80 | x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)
81 | else:
82 | x_ge = F.avg_pool2d(
83 | x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False)
84 | if self.add_maxpool:
85 | # experimental codepath, may remove or change
86 | x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2)
87 | x_ge = self.mlp(x_ge)
88 | if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:
89 | x_ge = F.interpolate(x_ge, size=size)
90 | return x * self.gate(x_ge)
91 |
--------------------------------------------------------------------------------
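Sketch (not repository source; import path assumed): extent=0 gathers globally (SE-equivalent); a positive, even extent uses a strided pooling (or conv) gather and interpolates the gate back to full resolution.

    import torch
    from model.layers.gather_excite import GatherExcite  # assumed import path

    ge_global = GatherExcite(64)             # extent=0 -> global gather, SE-like
    ge_local = GatherExcite(64, extent=4)    # param-free local gather: 7x7 avg pool, stride 4
    x = torch.randn(1, 64, 32, 32)
    print(ge_global(x).shape, ge_local(x).shape)  # both torch.Size([1, 64, 32, 32])
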
/model/layers/global_context.py:
--------------------------------------------------------------------------------
1 | """ Global Context Attention Block
2 |
3 | Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond`
4 | - https://arxiv.org/abs/1904.11492
5 |
6 | Official code consulted as reference: https://github.com/xvjiarui/GCNet
7 |
8 | Hacked together by / Copyright 2021 Ross Wightman
9 | """
10 | from torch import nn as nn
11 | import torch.nn.functional as F
12 |
13 | from .create_act import create_act_layer, get_act_layer
14 | from .helpers import make_divisible
15 | from .mlp import ConvMlp
16 | from .norm import LayerNorm2d
17 |
18 |
19 | class GlobalContext(nn.Module):
20 |
21 | def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False,
22 | rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'):
23 | super(GlobalContext, self).__init__()
24 | act_layer = get_act_layer(act_layer)
25 |
26 | self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None
27 |
28 | if rd_channels is None:
29 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
30 | if fuse_add:
31 | self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
32 | else:
33 | self.mlp_add = None
34 | if fuse_scale:
35 | self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
36 | else:
37 | self.mlp_scale = None
38 |
39 | self.gate = create_act_layer(gate_layer)
40 | self.init_last_zero = init_last_zero
41 | self.reset_parameters()
42 |
43 | def reset_parameters(self):
44 | if self.conv_attn is not None:
45 | nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu')
46 | if self.mlp_add is not None:
47 | nn.init.zeros_(self.mlp_add.fc2.weight)
48 |
49 | def forward(self, x):
50 | B, C, H, W = x.shape
51 |
52 | if self.conv_attn is not None:
53 | attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W)
54 | attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1)
55 | context = x.reshape(B, C, H * W).unsqueeze(1) @ attn
56 | context = context.view(B, C, 1, 1)
57 | else:
58 | context = x.mean(dim=(2, 3), keepdim=True)
59 |
60 | if self.mlp_scale is not None:
61 | mlp_x = self.mlp_scale(context)
62 | x = x * self.gate(mlp_x)
63 | if self.mlp_add is not None:
64 | mlp_x = self.mlp_add(context)
65 | x = x + mlp_x
66 |
67 | return x
68 |
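A minimal shape-check sketch (assuming the GlobalContext class above is in scope): the default configuration pools context with softmax attention and fuses it back as a channel-wise scale, so spatial dims are unchanged.

    import torch

    gc = GlobalContext(64)         # defaults: use_attn=True, fuse_scale=True, fuse_add=False
    x = torch.randn(2, 64, 14, 14)
    assert gc(x).shape == x.shape  # pure channel re-weighting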
--------------------------------------------------------------------------------
/model/layers/grn.py:
--------------------------------------------------------------------------------
1 | """ Global Response Normalization Module
2 |
3 | Based on the GRN layer presented in
4 | `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
5 |
6 | This implementation
7 | * works for both NCHW and NHWC tensor layouts
8 | * uses affine param names matching existing torch norm layers
9 | * slightly improves eager mode performance via fused addcmul
10 |
11 | Hacked together by / Copyright 2023 Ross Wightman
12 | """
13 |
14 | import torch
15 | from torch import nn as nn
16 |
17 |
18 | class GlobalResponseNorm(nn.Module):
19 | """ Global Response Normalization layer
20 | """
21 | def __init__(self, dim, eps=1e-6, channels_last=True):
22 | super().__init__()
23 | self.eps = eps
24 | if channels_last:
25 | self.spatial_dim = (1, 2)
26 | self.channel_dim = -1
27 | self.wb_shape = (1, 1, 1, -1)
28 | else:
29 | self.spatial_dim = (2, 3)
30 | self.channel_dim = 1
31 | self.wb_shape = (1, -1, 1, 1)
32 |
33 | self.weight = nn.Parameter(torch.zeros(dim))
34 | self.bias = nn.Parameter(torch.zeros(dim))
35 |
36 | def forward(self, x):
37 | x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True)
38 | x_n = x_g / (x_g.mean(dim=self.channel_dim, keepdim=True) + self.eps)
39 | return x + torch.addcmul(self.bias.view(self.wb_shape), self.weight.view(self.wb_shape), x * x_n)
40 |
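A small sketch of the two supported layouts (assuming the GlobalResponseNorm class above is in scope). Since weight and bias are zero-initialized, the layer is an exact identity at init, which the asserts below exploit:

    import torch

    grn_cl = GlobalResponseNorm(64, channels_last=True)   # NHWC input
    grn_cf = GlobalResponseNorm(64, channels_last=False)  # NCHW input
    x_nhwc = torch.randn(2, 14, 14, 64)
    x_nchw = x_nhwc.permute(0, 3, 1, 2).contiguous()
    assert torch.allclose(grn_cl(x_nhwc), x_nhwc)  # identity at init (weight = bias = 0)
    assert torch.allclose(grn_cf(x_nchw), x_nchw)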
--------------------------------------------------------------------------------
/model/layers/helpers.py:
--------------------------------------------------------------------------------
1 | """ Layer/Module Helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from itertools import repeat
6 | import collections.abc
7 |
8 |
9 | # From PyTorch internals
10 | def _ntuple(n):
11 | def parse(x):
12 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
13 | return tuple(x)
14 | return tuple(repeat(x, n))
15 | return parse
16 |
17 |
18 | to_1tuple = _ntuple(1)
19 | to_2tuple = _ntuple(2)
20 | to_3tuple = _ntuple(3)
21 | to_4tuple = _ntuple(4)
22 | to_ntuple = _ntuple
23 |
24 |
25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
26 | min_value = min_value or divisor
27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
28 |     # Ensure the round down does not reduce v by more than (1 - round_limit), 10% by default.
29 | if new_v < round_limit * v:
30 | new_v += divisor
31 | return new_v
32 |
33 |
34 | def extend_tuple(x, n):
35 | # pads a tuple to specified n by padding with last value
36 | if not isinstance(x, (tuple, list)):
37 | x = (x,)
38 | else:
39 | x = tuple(x)
40 | pad_n = n - len(x)
41 | if pad_n <= 0:
42 | return x[:n]
43 | return x + (x[-1],) * pad_n
44 |
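Worked examples for the helpers above, with the values traced through the formulas by hand:

    assert make_divisible(17) == 16   # max(8, int(17 + 4) // 8 * 8) = 16; 16 >= 0.9 * 17, so keep
    assert make_divisible(4) == 8     # the min_value floor (= divisor) kicks in
    assert make_divisible(30) == 32   # int(30 + 4) // 8 * 8 = 32
    assert extend_tuple((1, 2), 4) == (1, 2, 2, 2)  # padded out with the last value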
--------------------------------------------------------------------------------
/model/layers/inplace_abn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 |
4 | try:
5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync
6 | has_iabn = True
7 | except ImportError:
8 | has_iabn = False
9 |
10 | def inplace_abn(x, weight, bias, running_mean, running_var,
11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
12 | raise ImportError(
13 | "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'")
14 |
15 | def inplace_abn_sync(**kwargs):
16 | inplace_abn(**kwargs)
17 |
18 |
19 | class InplaceAbn(nn.Module):
20 | """Activated Batch Normalization
21 |
22 | This gathers a BatchNorm and an activation function in a single module
23 |
24 | Parameters
25 | ----------
26 | num_features : int
27 | Number of feature channels in the input and output.
28 | eps : float
29 | Small constant to prevent numerical issues.
30 | momentum : float
31 | Momentum factor applied to compute running statistics.
32 | affine : bool
33 | If `True` apply learned scale and shift transformation after normalization.
34 | act_layer : str or nn.Module type
35 | Name or type of the activation functions, one of: `leaky_relu`, `elu`
36 | act_param : float
37 | Negative slope for the `leaky_relu` activation.
38 | """
39 |
40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True,
41 | act_layer="leaky_relu", act_param=0.01, drop_layer=None):
42 | super(InplaceAbn, self).__init__()
43 | self.num_features = num_features
44 | self.affine = affine
45 | self.eps = eps
46 | self.momentum = momentum
47 | if apply_act:
48 | if isinstance(act_layer, str):
49 | assert act_layer in ('leaky_relu', 'elu', 'identity', '')
50 | self.act_name = act_layer if act_layer else 'identity'
51 | else:
52 | # convert act layer passed as type to string
53 | if act_layer == nn.ELU:
54 | self.act_name = 'elu'
55 | elif act_layer == nn.LeakyReLU:
56 | self.act_name = 'leaky_relu'
57 | elif act_layer is None or act_layer == nn.Identity:
58 | self.act_name = 'identity'
59 | else:
60 | assert False, f'Invalid act layer {act_layer.__name__} for IABN'
61 | else:
62 | self.act_name = 'identity'
63 | self.act_param = act_param
64 | if self.affine:
65 | self.weight = nn.Parameter(torch.ones(num_features))
66 | self.bias = nn.Parameter(torch.zeros(num_features))
67 | else:
68 | self.register_parameter('weight', None)
69 | self.register_parameter('bias', None)
70 | self.register_buffer('running_mean', torch.zeros(num_features))
71 | self.register_buffer('running_var', torch.ones(num_features))
72 | self.reset_parameters()
73 |
74 | def reset_parameters(self):
75 | nn.init.constant_(self.running_mean, 0)
76 | nn.init.constant_(self.running_var, 1)
77 | if self.affine:
78 | nn.init.constant_(self.weight, 1)
79 | nn.init.constant_(self.bias, 0)
80 |
81 | def forward(self, x):
82 | output = inplace_abn(
83 | x, self.weight, self.bias, self.running_mean, self.running_var,
84 | self.training, self.momentum, self.eps, self.act_name, self.act_param)
85 | if isinstance(output, tuple):
86 | output = output[0]
87 | return output
88 |
--------------------------------------------------------------------------------
/model/layers/interpolate.py:
--------------------------------------------------------------------------------
1 | """ Interpolation helpers for timm layers
2 |
3 | RegularGridInterpolator from https://github.com/sbarratt/torch_interpolations
4 | Copyright Shane Barratt, Apache 2.0 license
5 | """
6 | import torch
7 | from itertools import product
8 |
9 |
10 | class RegularGridInterpolator:
11 | """ Interpolate data defined on a rectilinear grid with even or uneven spacing.
12 | Produces similar results to scipy RegularGridInterpolator or interp2d
13 | in 'linear' mode.
14 |
15 | Taken from https://github.com/sbarratt/torch_interpolations
16 | """
17 |
18 | def __init__(self, points, values):
19 | self.points = points
20 | self.values = values
21 |
22 | assert isinstance(self.points, tuple) or isinstance(self.points, list)
23 | assert isinstance(self.values, torch.Tensor)
24 |
25 | self.ms = list(self.values.shape)
26 | self.n = len(self.points)
27 |
28 | assert len(self.ms) == self.n
29 |
30 | for i, p in enumerate(self.points):
31 | assert isinstance(p, torch.Tensor)
32 | assert p.shape[0] == self.values.shape[i]
33 |
34 | def __call__(self, points_to_interp):
35 | assert self.points is not None
36 | assert self.values is not None
37 |
38 | assert len(points_to_interp) == len(self.points)
39 | K = points_to_interp[0].shape[0]
40 | for x in points_to_interp:
41 | assert x.shape[0] == K
42 |
43 | idxs = []
44 | dists = []
45 | overalls = []
46 | for p, x in zip(self.points, points_to_interp):
47 | idx_right = torch.bucketize(x, p)
48 | idx_right[idx_right >= p.shape[0]] = p.shape[0] - 1
49 | idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1)
50 | dist_left = x - p[idx_left]
51 | dist_right = p[idx_right] - x
52 | dist_left[dist_left < 0] = 0.
53 | dist_right[dist_right < 0] = 0.
54 | both_zero = (dist_left == 0) & (dist_right == 0)
55 | dist_left[both_zero] = dist_right[both_zero] = 1.
56 |
57 | idxs.append((idx_left, idx_right))
58 | dists.append((dist_left, dist_right))
59 | overalls.append(dist_left + dist_right)
60 |
61 | numerator = 0.
62 | for indexer in product([0, 1], repeat=self.n):
63 | as_s = [idx[onoff] for onoff, idx in zip(indexer, idxs)]
64 | bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)]
65 | numerator += self.values[as_s] * \
66 | torch.prod(torch.stack(bs_s), dim=0)
67 | denominator = torch.prod(torch.stack(overalls), dim=0)
68 | return numerator / denominator
69 |
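A minimal bilinear-interpolation sketch (assuming the RegularGridInterpolator class above is in scope). The query point sits exactly mid-way between four grid corners, so the result is their average:

    import torch

    points = (torch.tensor([0., 1., 2.]), torch.tensor([0., 10.]))
    values = torch.tensor([[0., 1.], [2., 3.], [4., 5.]])  # shape (3, 2) matches the grid
    interp = RegularGridInterpolator(points, values)
    out = interp([torch.tensor([0.5]), torch.tensor([5.0])])
    assert torch.allclose(out, torch.tensor([1.5]))  # mean of corners 0, 1, 2, 3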
--------------------------------------------------------------------------------
/model/layers/linear.py:
--------------------------------------------------------------------------------
1 | """ Linear layer (alternate definition)
2 | """
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn as nn
6 |
7 |
8 | class Linear(nn.Linear):
9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
10 |
11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
13 | """
14 | def forward(self, input: torch.Tensor) -> torch.Tensor:
15 | if torch.jit.is_scripting():
16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
18 | else:
19 | return F.linear(input, self.weight, self.bias)
20 |
--------------------------------------------------------------------------------
/model/layers/median_pool.py:
--------------------------------------------------------------------------------
1 | """ Median Pool
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from .helpers import to_2tuple, to_4tuple
7 |
8 |
9 | class MedianPool2d(nn.Module):
10 | """ Median pool (usable as median filter when stride=1) module.
11 |
12 | Args:
13 | kernel_size: size of pooling kernel, int or 2-tuple
14 | stride: pool stride, int or 2-tuple
15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
16 | same: override padding and enforce same padding, boolean
17 | """
18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
19 | super(MedianPool2d, self).__init__()
20 | self.k = to_2tuple(kernel_size)
21 | self.stride = to_2tuple(stride)
22 | self.padding = to_4tuple(padding) # convert to l, r, t, b
23 | self.same = same
24 |
25 | def _padding(self, x):
26 | if self.same:
27 | ih, iw = x.size()[2:]
28 | if ih % self.stride[0] == 0:
29 | ph = max(self.k[0] - self.stride[0], 0)
30 | else:
31 | ph = max(self.k[0] - (ih % self.stride[0]), 0)
32 | if iw % self.stride[1] == 0:
33 | pw = max(self.k[1] - self.stride[1], 0)
34 | else:
35 | pw = max(self.k[1] - (iw % self.stride[1]), 0)
36 | pl = pw // 2
37 | pr = pw - pl
38 | pt = ph // 2
39 | pb = ph - pt
40 | padding = (pl, pr, pt, pb)
41 | else:
42 | padding = self.padding
43 | return padding
44 |
45 | def forward(self, x):
46 | x = F.pad(x, self._padding(x), mode='reflect')
47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
49 | return x
50 |
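A minimal sketch of MedianPool2d as an impulse-noise filter: with stride 1 and same=True the spatial size is preserved, and an isolated spike is removed by the 3x3 median:

    import torch

    pool = MedianPool2d(kernel_size=3, stride=1, same=True)
    x = torch.zeros(1, 1, 8, 8)
    x[0, 0, 4, 4] = 100.          # single impulse 'noise' pixel
    y = pool(x)
    assert y.shape == x.shape
    assert y[0, 0, 4, 4] == 0.    # median of one 100 and eight 0s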
--------------------------------------------------------------------------------
/model/layers/mixed_conv2d.py:
--------------------------------------------------------------------------------
1 | """ PyTorch Mixed Convolution
2 |
3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 |
8 | import torch
9 | from torch import nn as nn
10 |
11 | from .conv2d_same import create_conv2d_pad
12 |
13 |
14 | def _split_channels(num_chan, num_groups):
15 | split = [num_chan // num_groups for _ in range(num_groups)]
16 | split[0] += num_chan - sum(split)
17 | return split
18 |
19 |
20 | class MixedConv2d(nn.ModuleDict):
21 | """ Mixed Grouped Convolution
22 |
23 | Based on MDConv and GroupedConv in MixNet impl:
24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
25 | """
26 | def __init__(self, in_channels, out_channels, kernel_size=3,
27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs):
28 | super(MixedConv2d, self).__init__()
29 |
30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
31 | num_groups = len(kernel_size)
32 | in_splits = _split_channels(in_channels, num_groups)
33 | out_splits = _split_channels(out_channels, num_groups)
34 | self.in_channels = sum(in_splits)
35 | self.out_channels = sum(out_splits)
36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
37 | conv_groups = in_ch if depthwise else 1
38 | # use add_module to keep key space clean
39 | self.add_module(
40 | str(idx),
41 | create_conv2d_pad(
42 | in_ch, out_ch, k, stride=stride,
43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
44 | )
45 | self.splits = in_splits
46 |
47 | def forward(self, x):
48 | x_split = torch.split(x, self.splits, 1)
49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
50 | x = torch.cat(x_out, 1)
51 | return x
52 |
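A minimal sketch (assuming the MixedConv2d class above is in scope): three kernel sizes split across 48 channels, depthwise, with the default symmetric padding keeping spatial size at stride 1:

    import torch

    conv = MixedConv2d(48, 48, kernel_size=[3, 5, 7], depthwise=True)
    x = torch.randn(2, 48, 32, 32)
    assert conv(x).shape == (2, 48, 32, 32)  # channels split 16/16/16 across k=3, 5, 7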
--------------------------------------------------------------------------------
/model/layers/padding.py:
--------------------------------------------------------------------------------
1 | """ Padding Helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import math
6 | from typing import List, Tuple
7 |
8 | import torch
9 | import torch.nn.functional as F
10 |
11 |
12 | # Calculate symmetric padding for a convolution
13 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
14 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
15 | return padding
16 |
17 |
18 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
19 | def get_same_padding(x: int, kernel_size: int, stride: int, dilation: int):
20 | if isinstance(x, torch.Tensor):
21 | return torch.clamp(((x / stride).ceil() - 1) * stride + (kernel_size - 1) * dilation + 1 - x, min=0)
22 | else:
23 | return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0)
24 |
25 |
26 | # Can SAME padding for given args be done statically?
27 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
28 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
29 |
30 |
31 | def pad_same_arg(
32 | input_size: List[int],
33 | kernel_size: List[int],
34 | stride: List[int],
35 | dilation: List[int] = (1, 1),
36 | ) -> List[int]:
37 | ih, iw = input_size
38 | kh, kw = kernel_size
39 | pad_h = get_same_padding(ih, kh, stride[0], dilation[0])
40 | pad_w = get_same_padding(iw, kw, stride[1], dilation[1])
41 | return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
42 |
43 |
44 | # Dynamically pad input x with 'SAME' padding for conv with specified args
45 | def pad_same(
46 | x,
47 | kernel_size: List[int],
48 | stride: List[int],
49 | dilation: List[int] = (1, 1),
50 | value: float = 0,
51 | ):
52 | ih, iw = x.size()[-2:]
53 | pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0])
54 | pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1])
55 | x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2), value=value)
56 | return x
57 |
58 |
59 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
60 | dynamic = False
61 | if isinstance(padding, str):
62 | # for any string padding, the padding will be calculated for you, one of three ways
63 | padding = padding.lower()
64 | if padding == 'same':
65 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
66 | if is_static_pad(kernel_size, **kwargs):
67 | # static case, no extra overhead
68 | padding = get_padding(kernel_size, **kwargs)
69 | else:
70 | # dynamic 'SAME' padding, has runtime/GPU memory overhead
71 | padding = 0
72 | dynamic = True
73 | elif padding == 'valid':
74 | # 'VALID' padding, same as padding=0
75 | padding = 0
76 | else:
77 | # Default to PyTorch style 'same'-ish symmetric padding
78 | padding = get_padding(kernel_size, **kwargs)
79 | return padding, dynamic
80 |
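A behavior sketch for get_padding_value, traced through the branches above:

    assert get_padding_value('same', 3, stride=1) == (1, False)  # static SAME -> fixed int pad
    assert get_padding_value('same', 3, stride=2) == (0, True)   # dynamic SAME -> pad at runtime
    assert get_padding_value('valid', 3) == (0, False)           # VALID == no padding
    assert get_padding_value('', 3, stride=1) == (1, False)      # PyTorch-style symmetric default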
--------------------------------------------------------------------------------
/model/layers/patch_dropout.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class PatchDropout(nn.Module):
8 | """
9 | https://arxiv.org/abs/2212.00794
10 | """
11 | return_indices: torch.jit.Final[bool]
12 |
13 | def __init__(
14 | self,
15 | prob: float = 0.5,
16 | num_prefix_tokens: int = 1,
17 | ordered: bool = False,
18 | return_indices: bool = False,
19 | ):
20 | super().__init__()
21 | assert 0 <= prob < 1.
22 | self.prob = prob
23 | self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens)
24 | self.ordered = ordered
25 | self.return_indices = return_indices
26 |
27 | def forward(self, x) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
28 | if not self.training or self.prob == 0.:
29 | if self.return_indices:
30 | return x, None
31 | return x
32 |
33 | if self.num_prefix_tokens:
34 | prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
35 | else:
36 | prefix_tokens = None
37 |
38 | B = x.shape[0]
39 | L = x.shape[1]
40 | num_keep = max(1, int(L * (1. - self.prob)))
41 | keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep]
42 | if self.ordered:
43 | # NOTE does not need to maintain patch order in typical transformer use,
44 | # but possibly useful for debug / visualization
45 | keep_indices = keep_indices.sort(dim=-1)[0]
46 | x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))
47 |
48 | if prefix_tokens is not None:
49 | x = torch.cat((prefix_tokens, x), dim=1)
50 |
51 | if self.return_indices:
52 | return x, keep_indices
53 | return x
54 |
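A minimal training-mode sketch (assuming the PatchDropout class above is in scope), dropping half of 196 ViT patch tokens while keeping the class token:

    import torch

    drop = PatchDropout(prob=0.5, num_prefix_tokens=1, return_indices=True)
    drop.train()
    x = torch.randn(4, 1 + 196, 768)    # CLS token + 14x14 patches at ViT-Base width
    y, keep = drop(x)
    assert y.shape == (4, 1 + 98, 768)  # num_keep = max(1, int(196 * 0.5)) = 98
    assert keep.shape == (4, 98)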
--------------------------------------------------------------------------------
/model/layers/pool2d_same.py:
--------------------------------------------------------------------------------
1 | """ AvgPool2d w/ Same Padding
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from typing import List, Tuple, Optional
9 |
10 | from .helpers import to_2tuple
11 | from .padding import pad_same, get_padding_value
12 |
13 |
14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
15 | ceil_mode: bool = False, count_include_pad: bool = True):
16 | # FIXME how to deal with count_include_pad vs not for external padding?
17 | x = pad_same(x, kernel_size, stride)
18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
19 |
20 |
21 | class AvgPool2dSame(nn.AvgPool2d):
22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling
23 | """
24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
25 | kernel_size = to_2tuple(kernel_size)
26 | stride = to_2tuple(stride)
27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
28 |
29 | def forward(self, x):
30 | x = pad_same(x, self.kernel_size, self.stride)
31 | return F.avg_pool2d(
32 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
33 |
34 |
35 | def max_pool2d_same(
36 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
37 | dilation: List[int] = (1, 1), ceil_mode: bool = False):
38 | x = pad_same(x, kernel_size, stride, value=-float('inf'))
39 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)
40 |
41 |
42 | class MaxPool2dSame(nn.MaxPool2d):
43 | """ Tensorflow like 'SAME' wrapper for 2D max pooling
44 | """
45 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
46 | kernel_size = to_2tuple(kernel_size)
47 | stride = to_2tuple(stride)
48 | dilation = to_2tuple(dilation)
49 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode)
50 |
51 | def forward(self, x):
52 | x = pad_same(x, self.kernel_size, self.stride, value=-float('inf'))
53 | return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode)
54 |
55 |
56 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
57 | stride = stride or kernel_size
58 | padding = kwargs.pop('padding', '')
59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
60 | if is_dynamic:
61 | if pool_type == 'avg':
62 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs)
63 | elif pool_type == 'max':
64 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs)
65 | else:
66 | assert False, f'Unsupported pool type {pool_type}'
67 | else:
68 | if pool_type == 'avg':
69 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
70 | elif pool_type == 'max':
71 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
72 | else:
73 | assert False, f'Unsupported pool type {pool_type}'
74 |
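A minimal sketch contrasting the 'same' and 'valid' resolution behavior of create_pool2d on an odd input size:

    import torch

    x = torch.randn(1, 8, 15, 15)
    mp_same = create_pool2d('max', kernel_size=3, stride=2, padding='same')
    mp_valid = create_pool2d('max', kernel_size=3, stride=2, padding='valid')
    assert mp_same(x).shape == (1, 8, 8, 8)   # SAME: ceil(15 / 2) = 8
    assert mp_valid(x).shape == (1, 8, 7, 7)  # VALID: floor((15 - 3) / 2) + 1 = 7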
--------------------------------------------------------------------------------
/model/layers/pos_embed.py:
--------------------------------------------------------------------------------
1 | """ Position Embedding Utilities
2 |
3 | Hacked together by / Copyright 2022 Ross Wightman
4 | """
5 | import logging
6 | import math
7 | from typing import List, Tuple, Optional, Union
8 |
9 | import torch
10 | import torch.nn.functional as F
11 |
12 | from .helpers import to_2tuple
13 |
14 | _logger = logging.getLogger(__name__)
15 |
16 |
17 | def resample_abs_pos_embed(
18 | posemb,
19 | new_size: List[int],
20 | old_size: Optional[List[int]] = None,
21 | num_prefix_tokens: int = 1,
22 | interpolation: str = 'bicubic',
23 | antialias: bool = True,
24 | verbose: bool = False,
25 | ):
26 | # sort out sizes, assume square if old size not provided
27 | num_pos_tokens = posemb.shape[1]
28 | num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
29 | if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
30 | return posemb
31 |
32 | if old_size is None:
33 | hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
34 | old_size = hw, hw
35 |
36 | if num_prefix_tokens:
37 | posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
38 | else:
39 | posemb_prefix, posemb = None, posemb
40 |
41 | # do the interpolation
42 | embed_dim = posemb.shape[-1]
43 | orig_dtype = posemb.dtype
44 | posemb = posemb.float() # interpolate needs float32
45 | posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
46 | posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
47 | posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
48 | posemb = posemb.to(orig_dtype)
49 |
50 | # add back extra (class, etc) prefix tokens
51 | if posemb_prefix is not None:
52 | posemb = torch.cat([posemb_prefix, posemb], dim=1)
53 |
54 | if not torch.jit.is_scripting() and verbose:
55 | _logger.info(f'Resized position embedding: {old_size} to {new_size}.')
56 |
57 | return posemb
58 |
59 |
60 | def resample_abs_pos_embed_nhwc(
61 | posemb,
62 | new_size: List[int],
63 | interpolation: str = 'bicubic',
64 | antialias: bool = True,
65 | verbose: bool = False,
66 | ):
67 | if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
68 | return posemb
69 |
70 | orig_dtype = posemb.dtype
71 | posemb = posemb.float()
72 | # do the interpolation
73 | posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
74 | posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
75 | posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype)
76 |
77 | if not torch.jit.is_scripting() and verbose:
78 | _logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')
79 |
80 | return posemb
81 |
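A minimal sketch (assuming resample_abs_pos_embed above is in scope): resize a ViT position embedding from a 14x14 to a 16x16 patch grid, carrying the class token through untouched:

    import torch

    posemb = torch.randn(1, 1 + 14 * 14, 768)  # 1 prefix (class) token + 196 patch tokens
    resized = resample_abs_pos_embed(posemb, new_size=[16, 16])
    assert resized.shape == (1, 1 + 16 * 16, 768)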
--------------------------------------------------------------------------------
/model/layers/separable_conv.py:
--------------------------------------------------------------------------------
1 | """ Depthwise Separable Conv Modules
2 |
3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the
4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception.
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 | from torch import nn as nn
9 |
10 | from .create_conv2d import create_conv2d
11 | from .create_norm_act import get_norm_act_layer
12 |
13 |
14 | class SeparableConvNormAct(nn.Module):
15 | """ Separable Conv w/ trailing Norm and Activation
16 | """
17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU,
19 | apply_act=True, drop_layer=None):
20 | super(SeparableConvNormAct, self).__init__()
21 |
22 | self.conv_dw = create_conv2d(
23 | in_channels, int(in_channels * channel_multiplier), kernel_size,
24 | stride=stride, dilation=dilation, padding=padding, depthwise=True)
25 |
26 | self.conv_pw = create_conv2d(
27 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
28 |
29 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
30 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
31 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
32 |
33 | @property
34 | def in_channels(self):
35 | return self.conv_dw.in_channels
36 |
37 | @property
38 | def out_channels(self):
39 | return self.conv_pw.out_channels
40 |
41 | def forward(self, x):
42 | x = self.conv_dw(x)
43 | x = self.conv_pw(x)
44 | x = self.bn(x)
45 | return x
46 |
47 |
48 | SeparableConvBnAct = SeparableConvNormAct
49 |
50 |
51 | class SeparableConv2d(nn.Module):
52 | """ Separable Conv
53 | """
54 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
55 | channel_multiplier=1.0, pw_kernel_size=1):
56 | super(SeparableConv2d, self).__init__()
57 |
58 | self.conv_dw = create_conv2d(
59 | in_channels, int(in_channels * channel_multiplier), kernel_size,
60 | stride=stride, dilation=dilation, padding=padding, depthwise=True)
61 |
62 | self.conv_pw = create_conv2d(
63 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
64 |
65 | @property
66 | def in_channels(self):
67 | return self.conv_dw.in_channels
68 |
69 | @property
70 | def out_channels(self):
71 | return self.conv_pw.out_channels
72 |
73 | def forward(self, x):
74 | x = self.conv_dw(x)
75 | x = self.conv_pw(x)
76 | return x
77 |
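A quick parameter-count sketch (assuming the SeparableConv2d class above is in scope): a depthwise 3x3 plus a pointwise 1x1 uses far fewer weights than a dense 3x3 at the same channel widths:

    import torch
    from torch import nn

    sep = SeparableConv2d(64, 128, kernel_size=3)
    dense = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
    count = lambda m: sum(p.numel() for p in m.parameters())
    assert count(sep) < count(dense)  # roughly 9k vs 74k weights
    assert sep(torch.randn(1, 64, 16, 16)).shape == (1, 128, 16, 16)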
--------------------------------------------------------------------------------
/model/layers/space_to_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class SpaceToDepth(nn.Module):
6 | bs: torch.jit.Final[int]
7 |
8 | def __init__(self, block_size=4):
9 | super().__init__()
10 | assert block_size == 4
11 | self.bs = block_size
12 |
13 | def forward(self, x):
14 | N, C, H, W = x.size()
15 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
16 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
17 | x = x.view(N, C * self.bs * self.bs, H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
18 | return x
19 |
20 |
21 | @torch.jit.script
22 | class SpaceToDepthJit:
23 | def __call__(self, x: torch.Tensor):
24 | # assuming hard-coded that block_size==4 for acceleration
25 | N, C, H, W = x.size()
26 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs)
27 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
28 | x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs)
29 | return x
30 |
31 |
32 | class SpaceToDepthModule(nn.Module):
33 | def __init__(self, no_jit=False):
34 | super().__init__()
35 | if not no_jit:
36 | self.op = SpaceToDepthJit()
37 | else:
38 | self.op = SpaceToDepth()
39 |
40 | def forward(self, x):
41 | return self.op(x)
42 |
43 |
44 | class DepthToSpace(nn.Module):
45 |
46 | def __init__(self, block_size):
47 | super().__init__()
48 | self.bs = block_size
49 |
50 | def forward(self, x):
51 | N, C, H, W = x.size()
52 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W)
53 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
54 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs)
55 | return x
56 |
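A round-trip sketch: SpaceToDepth with the hard-coded block size 4 and DepthToSpace(4) are exact inverses, since both are pure pixel reshuffles:

    import torch

    x = torch.randn(2, 3, 16, 16)
    down = SpaceToDepth()          # (2, 3, 16, 16) -> (2, 48, 4, 4)
    up = DepthToSpace(4)
    y = down(x)
    assert y.shape == (2, 48, 4, 4)
    assert torch.equal(up(y), x)   # exact round trip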
--------------------------------------------------------------------------------
/model/layers/split_attn.py:
--------------------------------------------------------------------------------
1 | """ Split Attention Conv2d (for ResNeSt Models)
2 |
3 | Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955
4 |
5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt
6 |
7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
8 | """
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import nn
12 |
13 | from .helpers import make_divisible
14 |
15 |
16 | class RadixSoftmax(nn.Module):
17 | def __init__(self, radix, cardinality):
18 | super(RadixSoftmax, self).__init__()
19 | self.radix = radix
20 | self.cardinality = cardinality
21 |
22 | def forward(self, x):
23 | batch = x.size(0)
24 | if self.radix > 1:
25 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
26 | x = F.softmax(x, dim=1)
27 | x = x.reshape(batch, -1)
28 | else:
29 | x = torch.sigmoid(x)
30 | return x
31 |
32 |
33 | class SplitAttn(nn.Module):
34 | """Split-Attention (aka Splat)
35 | """
36 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
37 | dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
38 | act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs):
39 | super(SplitAttn, self).__init__()
40 | out_channels = out_channels or in_channels
41 | self.radix = radix
42 | mid_chs = out_channels * radix
43 | if rd_channels is None:
44 | attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
45 | else:
46 | attn_chs = rd_channels * radix
47 |
48 | padding = kernel_size // 2 if padding is None else padding
49 | self.conv = nn.Conv2d(
50 | in_channels, mid_chs, kernel_size, stride, padding, dilation,
51 | groups=groups * radix, bias=bias, **kwargs)
52 | self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
53 | self.drop = drop_layer() if drop_layer is not None else nn.Identity()
54 | self.act0 = act_layer(inplace=True)
55 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
56 | self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
57 | self.act1 = act_layer(inplace=True)
58 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
59 | self.rsoftmax = RadixSoftmax(radix, groups)
60 |
61 | def forward(self, x):
62 | x = self.conv(x)
63 | x = self.bn0(x)
64 | x = self.drop(x)
65 | x = self.act0(x)
66 |
67 | B, RC, H, W = x.shape
68 | if self.radix > 1:
69 | x = x.reshape((B, self.radix, RC // self.radix, H, W))
70 | x_gap = x.sum(dim=1)
71 | else:
72 | x_gap = x
73 | x_gap = x_gap.mean((2, 3), keepdim=True)
74 | x_gap = self.fc1(x_gap)
75 | x_gap = self.bn1(x_gap)
76 | x_gap = self.act1(x_gap)
77 | x_attn = self.fc2(x_gap)
78 |
79 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
80 | if self.radix > 1:
81 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
82 | else:
83 | out = x * x_attn
84 | return out.contiguous()
85 |
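A minimal shape-check sketch (assuming the SplitAttn class above is in scope), with the default radix of 2 and an explicit norm layer:

    import torch
    from torch import nn

    attn = SplitAttn(64, radix=2, norm_layer=nn.BatchNorm2d)
    x = torch.randn(2, 64, 14, 14)
    assert attn(x).shape == (2, 64, 14, 14)  # radix branches re-weighted and summed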
--------------------------------------------------------------------------------
/model/layers/split_batchnorm.py:
--------------------------------------------------------------------------------
1 | """ Split BatchNorm
2 |
3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through
4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias
5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
6 | namespace.
7 |
8 | This allows easily removing the auxiliary BN layers after training to efficiently
9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
10 | 'Disentangled Learning via An Auxiliary BN'
11 |
12 | Hacked together by / Copyright 2020 Ross Wightman
13 | """
14 | import torch
15 | import torch.nn as nn
16 |
17 |
18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d):
19 |
20 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
21 | track_running_stats=True, num_splits=2):
22 | super().__init__(num_features, eps, momentum, affine, track_running_stats)
23 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
24 | self.num_splits = num_splits
25 | self.aux_bn = nn.ModuleList([
26 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)])
27 |
28 | def forward(self, input: torch.Tensor):
29 | if self.training: # aux BN only relevant while training
30 | split_size = input.shape[0] // self.num_splits
31 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
32 | split_input = input.split(split_size)
33 | x = [super().forward(split_input[0])]
34 | for i, a in enumerate(self.aux_bn):
35 | x.append(a(split_input[i + 1]))
36 | return torch.cat(x, dim=0)
37 | else:
38 | return super().forward(input)
39 |
40 |
41 | def convert_splitbn_model(module, num_splits=2):
42 | """
43 | Recursively traverse module and its children to replace all instances of
44 |     ``torch.nn.modules.batchnorm._BatchNorm`` with ``SplitBatchNorm2d``.
45 | Args:
46 | module (torch.nn.Module): input module
47 | num_splits: number of separate batchnorm layers to split input across
48 | Example::
49 | >>> # model is an instance of torch.nn.Module
50 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
51 | """
52 | mod = module
53 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
54 | return module
55 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
56 | mod = SplitBatchNorm2d(
57 | module.num_features, module.eps, module.momentum, module.affine,
58 | module.track_running_stats, num_splits=num_splits)
59 | mod.running_mean = module.running_mean
60 | mod.running_var = module.running_var
61 | mod.num_batches_tracked = module.num_batches_tracked
62 | if module.affine:
63 | mod.weight.data = module.weight.data.clone().detach()
64 | mod.bias.data = module.bias.data.clone().detach()
65 | for aux in mod.aux_bn:
66 | aux.running_mean = module.running_mean.clone()
67 | aux.running_var = module.running_var.clone()
68 | aux.num_batches_tracked = module.num_batches_tracked.clone()
69 | if module.affine:
70 | aux.weight.data = module.weight.data.clone().detach()
71 | aux.bias.data = module.bias.data.clone().detach()
72 | for name, child in module.named_children():
73 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
74 | del module
75 | return mod
76 |
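A minimal conversion sketch (assuming convert_splitbn_model above is in scope); note the training batch must be divisible by num_splits:

    import torch
    from torch import nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
    model = convert_splitbn_model(model, num_splits=2)
    assert type(model[1]).__name__ == 'SplitBatchNorm2d'
    model.train()
    y = model(torch.randn(4, 3, 8, 8))  # batch of 4 splits 2 + 2 across main/aux BN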
--------------------------------------------------------------------------------
/model/layers/squeeze_excite.py:
--------------------------------------------------------------------------------
1 | """ Squeeze-and-Excitation Channel Attention
2 |
3 | An SE implementation originally based on PyTorch SE-Net impl.
4 | Has since evolved with additional functionality / configuration.
5 |
6 | Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507
7 |
8 | Also included is Effective Squeeze-Excitation (ESE).
9 | Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
10 |
11 | Hacked together by / Copyright 2021 Ross Wightman
12 | """
13 | from torch import nn as nn
14 |
15 | from .create_act import create_act_layer
16 | from .helpers import make_divisible
17 |
18 |
19 | class SEModule(nn.Module):
20 | """ SE Module as defined in original SE-Nets with a few additions
21 | Additions include:
22 | * divisor can be specified to keep channels % div == 0 (default: 8)
23 | * reduction channels can be specified directly by arg (if rd_channels is set)
24 | * reduction channels can be specified by float rd_ratio (default: 1/16)
25 | * global max pooling can be added to the squeeze aggregation
26 | * customizable activation, normalization, and gate layer
27 | """
28 | def __init__(
29 | self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
30 | bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
31 | super(SEModule, self).__init__()
32 | self.add_maxpool = add_maxpool
33 | if not rd_channels:
34 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
35 | self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias)
36 | self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()
37 | self.act = create_act_layer(act_layer, inplace=True)
38 | self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias)
39 | self.gate = create_act_layer(gate_layer)
40 |
41 | def forward(self, x):
42 | x_se = x.mean((2, 3), keepdim=True)
43 | if self.add_maxpool:
44 | # experimental codepath, may remove or change
45 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
46 | x_se = self.fc1(x_se)
47 | x_se = self.act(self.bn(x_se))
48 | x_se = self.fc2(x_se)
49 | return x * self.gate(x_se)
50 |
51 |
52 | SqueezeExcite = SEModule # alias
53 |
54 |
55 | class EffectiveSEModule(nn.Module):
56 |     """ Effective Squeeze-Excitation
57 | From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
58 | """
59 | def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_):
60 | super(EffectiveSEModule, self).__init__()
61 | self.add_maxpool = add_maxpool
62 | self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
63 | self.gate = create_act_layer(gate_layer)
64 |
65 | def forward(self, x):
66 | x_se = x.mean((2, 3), keepdim=True)
67 | if self.add_maxpool:
68 | # experimental codepath, may remove or change
69 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
70 | x_se = self.fc(x_se)
71 | return x * self.gate(x_se)
72 |
73 |
74 | EffectiveSqueezeExcite = EffectiveSEModule # alias
75 |
76 |
77 | class SqueezeExciteCl(nn.Module):
78 |     """ SE Module for channels-last (NHWC) tensors, with a few additions
79 |     Additions include:
80 |         * divisor can be specified to keep channels % div == 0 (default: 8)
81 |         * reduction channels can be specified directly by arg (if rd_channels is set)
82 |         * reduction channels can be specified by float rd_ratio (default: 1/16)
83 |         * uses Linear layers on the trailing channel dim instead of 1x1 convs
84 |         * customizable activation and gate layer
85 |     """
86 | def __init__(
87 | self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8,
88 | bias=True, act_layer=nn.ReLU, gate_layer='sigmoid'):
89 | super().__init__()
90 | if not rd_channels:
91 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
92 | self.fc1 = nn.Linear(channels, rd_channels, bias=bias)
93 | self.act = create_act_layer(act_layer, inplace=True)
94 | self.fc2 = nn.Linear(rd_channels, channels, bias=bias)
95 | self.gate = create_act_layer(gate_layer)
96 |
97 | def forward(self, x):
98 |         x_se = x.mean((1, 2), keepdim=True)  # FIXME avg dim [1:n-1], don't assume 2D NHWC
99 | x_se = self.fc1(x_se)
100 | x_se = self.act(x_se)
101 | x_se = self.fc2(x_se)
102 | return x * self.gate(x_se)
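A minimal shape-check sketch for the NCHW modules above; note the default 1/16 reduction of SEModule rounds 64 / 16 = 4 up to the divisor, giving an 8-channel bottleneck:

    import torch

    se = SEModule(64)            # fc1: 64 -> 8 (rd_ratio 1/16 rounded up to divisor 8)
    ese = EffectiveSEModule(64)  # single 64 -> 64 1x1 conv, no reduction
    x = torch.randn(2, 64, 14, 14)
    assert se(x).shape == x.shape
    assert ese(x).shape == x.shape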
--------------------------------------------------------------------------------
/model/layers/test_time_pool.py:
--------------------------------------------------------------------------------
1 | """ Test Time Pooling (Average-Max Pool)
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 |
6 | import logging
7 | from torch import nn
8 | import torch.nn.functional as F
9 |
10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
11 |
12 |
13 | _logger = logging.getLogger(__name__)
14 |
15 |
16 | class TestTimePoolHead(nn.Module):
17 | def __init__(self, base, original_pool=7):
18 | super(TestTimePoolHead, self).__init__()
19 | self.base = base
20 | self.original_pool = original_pool
21 | base_fc = self.base.get_classifier()
22 | if isinstance(base_fc, nn.Conv2d):
23 | self.fc = base_fc
24 | else:
25 | self.fc = nn.Conv2d(
26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True)
27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size()))
28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size()))
29 | self.base.reset_classifier(0) # delete original fc layer
30 |
31 | def forward(self, x):
32 | x = self.base.forward_features(x)
33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)
34 | x = self.fc(x)
35 | x = adaptive_avgmax_pool2d(x, 1)
36 | return x.view(x.size(0), -1)
37 |
38 |
39 | def apply_test_time_pool(model, config, use_test_size=False):
40 | test_time_pool = False
41 | if not hasattr(model, 'default_cfg') or not model.default_cfg:
42 | return model, False
43 | if use_test_size and 'test_input_size' in model.default_cfg:
44 | df_input_size = model.default_cfg['test_input_size']
45 | else:
46 | df_input_size = model.default_cfg['input_size']
47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]:
48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' %
49 | (str(config['input_size'][-2:]), str(df_input_size[-2:])))
50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
51 | test_time_pool = True
52 | return model, test_time_pool
53 |
--------------------------------------------------------------------------------
/model/layers/trace_utils.py:
--------------------------------------------------------------------------------
1 | try:
2 | from torch import _assert
3 | except ImportError:
4 | def _assert(condition: bool, message: str):
5 | assert condition, message
6 |
7 |
8 | def _float_to_int(x: float) -> int:
9 | """
10 | Symbolic tracing helper to substitute for inbuilt `int`.
11 | Hint: Inbuilt `int` can't accept an argument of type `Proxy`
12 | """
13 | return int(x)
14 |
--------------------------------------------------------------------------------
/model/layers/typing.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Tuple, Type, Union
2 |
3 | import torch
4 |
5 |
6 | LayerType = Union[str, Callable, Type[torch.nn.Module]]
7 | PadType = Union[str, int, Tuple[int, int]]
8 |
--------------------------------------------------------------------------------
/model/layers/weight_init.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import warnings
4 |
5 | from torch.nn.init import _calculate_fan_in_and_fan_out
6 |
7 |
8 | def _trunc_normal_(tensor, mean, std, a, b):
9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
11 | def norm_cdf(x):
12 | # Computes standard normal cumulative distribution function
13 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
14 |
15 | if (mean < a - 2 * std) or (mean > b + 2 * std):
16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
17 | "The distribution of values may be incorrect.",
18 | stacklevel=2)
19 |
20 | # Values are generated by using a truncated uniform distribution and
21 | # then using the inverse CDF for the normal distribution.
22 | # Get upper and lower cdf values
23 | l = norm_cdf((a - mean) / std)
24 | u = norm_cdf((b - mean) / std)
25 |
26 | # Uniformly fill tensor with values from [l, u], then translate to
27 | # [2l-1, 2u-1].
28 | tensor.uniform_(2 * l - 1, 2 * u - 1)
29 |
30 | # Use inverse cdf transform for normal distribution to get truncated
31 | # standard normal
32 | tensor.erfinv_()
33 |
34 | # Transform to proper mean, std
35 | tensor.mul_(std * math.sqrt(2.))
36 | tensor.add_(mean)
37 |
38 | # Clamp to ensure it's in the proper range
39 | tensor.clamp_(min=a, max=b)
40 | return tensor
41 |
42 |
43 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
44 | # type: (Tensor, float, float, float, float) -> Tensor
45 | r"""Fills the input Tensor with values drawn from a truncated
46 | normal distribution. The values are effectively drawn from the
47 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
48 | with values outside :math:`[a, b]` redrawn until they are within
49 | the bounds. The method used for generating the random values works
50 | best when :math:`a \leq \text{mean} \leq b`.
51 |
52 | NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
53 | applied while sampling the normal with mean/std applied, therefore a, b args
54 | should be adjusted to match the range of mean, std args.
55 |
56 | Args:
57 | tensor: an n-dimensional `torch.Tensor`
58 | mean: the mean of the normal distribution
59 | std: the standard deviation of the normal distribution
60 | a: the minimum cutoff value
61 | b: the maximum cutoff value
62 | Examples:
63 | >>> w = torch.empty(3, 5)
64 |         >>> trunc_normal_(w)
65 | """
66 | with torch.no_grad():
67 | return _trunc_normal_(tensor, mean, std, a, b)
68 |
69 |
70 | def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
71 | # type: (Tensor, float, float, float, float) -> Tensor
72 | r"""Fills the input Tensor with values drawn from a truncated
73 | normal distribution. The values are effectively drawn from the
74 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
75 | with values outside :math:`[a, b]` redrawn until they are within
76 | the bounds. The method used for generating the random values works
77 | best when :math:`a \leq \text{mean} \leq b`.
78 |
79 | NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
80 | bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
81 |     and the result is subsequently scaled and shifted by the mean and std args.
82 |
83 | Args:
84 | tensor: an n-dimensional `torch.Tensor`
85 | mean: the mean of the normal distribution
86 | std: the standard deviation of the normal distribution
87 | a: the minimum cutoff value
88 | b: the maximum cutoff value
89 | Examples:
90 | >>> w = torch.empty(3, 5)
91 |         >>> trunc_normal_tf_(w)
92 | """
93 | with torch.no_grad():
94 | _trunc_normal_(tensor, 0, 1.0, a, b)
95 | tensor.mul_(std).add_(mean)
96 | return tensor
97 |
98 |
99 | def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
100 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
101 | if mode == 'fan_in':
102 | denom = fan_in
103 | elif mode == 'fan_out':
104 | denom = fan_out
105 | elif mode == 'fan_avg':
106 | denom = (fan_in + fan_out) / 2
107 |
108 | variance = scale / denom
109 |
110 | if distribution == "truncated_normal":
111 | # constant is stddev of standard normal truncated to (-2, 2)
112 | trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978)
113 | elif distribution == "normal":
114 | with torch.no_grad():
115 | tensor.normal_(std=math.sqrt(variance))
116 | elif distribution == "uniform":
117 | bound = math.sqrt(3 * variance)
118 | with torch.no_grad():
119 | tensor.uniform_(-bound, bound)
120 | else:
121 | raise ValueError(f"invalid distribution {distribution}")
122 |
123 |
124 | def lecun_normal_(tensor):
125 | variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
126 |
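A small sketch of how the two truncation conventions above differ: the 'tf' variant clamps in the unit-normal space, so after scaling the values land in std * [a, b]:

    import torch

    w = torch.empty(256, 256)
    trunc_normal_tf_(w, std=0.02)            # sample N(0, 1) clipped to [-2, 2], then * 0.02
    assert w.abs().max() <= 2 * 0.02 + 1e-6  # bounds scale with std
    trunc_normal_(w, std=0.02)               # here [-2, 2] applies in the scaled space
    assert w.abs().max() <= 2.0
    lecun_normal_(w)                         # fan_in variance scaling, truncated normal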
--------------------------------------------------------------------------------
/my_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | @FileName: my_dataset.py
3 | @Description:
4 | @Author: WBobby
5 | @Department: CUG
6 | @Time: 2023/4/29 16:04
7 | """
8 | from PIL import Image
9 | import torch
10 | from torch.utils.data import Dataset
11 |
12 |
13 | class MyDataSet(Dataset):
14 |     """Custom dataset."""
15 |
16 | def __init__(self, images_path: list, images_class: list, transform=None):
17 | self.images_path = images_path
18 | self.images_class = images_class
19 | self.transform = transform
20 |
21 | def __len__(self):
22 | return len(self.images_path)
23 |
24 | def __getitem__(self, item):
25 | img = Image.open(self.images_path[item])
26 |         # mode 'RGB' is a color image, mode 'L' is grayscale
27 | if img.mode != 'RGB':
28 | raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
29 | label = self.images_class[item]
30 |
31 | if self.transform is not None:
32 | img = self.transform(img)
33 |
34 | return img, label
35 |
36 |     @staticmethod
37 |     def collate_fn(batch):
38 |         # For reference, the official default_collate implementation:
39 |         # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
40 |         images, labels = tuple(zip(*batch))
41 |         images = torch.stack(images, dim=0)
42 |         labels = torch.as_tensor(labels)
43 |         return images, labels
44 |
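A minimal usage sketch for MyDataSet with a DataLoader (the file paths and labels here are hypothetical placeholders; iterating the loader requires real images on disk):

    import torch
    from torchvision import transforms
    from my_dataset import MyDataSet

    dataset = MyDataSet(
        images_path=['imgs/a.jpg', 'imgs/b.jpg'],  # hypothetical paths
        images_class=[0, 1],
        transform=transforms.Compose([transforms.Resize((112, 112)), transforms.ToTensor()]))
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=True, collate_fn=MyDataSet.collate_fn)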
--------------------------------------------------------------------------------
/parameter.py:
--------------------------------------------------------------------------------
1 | """
2 | @FileName: parameter.py
3 | @Description:
4 | @Author: WBobby
5 | @Department: CUG
6 | @Time: 2023/6/13 16:10
7 | """
8 | import torch
9 | from coca_pytorch import CoCa
10 | from thop import profile
11 | from torch import nn
12 | from torchsummary import summary
13 | from torchvision.models import vgg16, resnet50
14 | from vit_pytorch.simple_vit_with_patch_dropout import SimpleViT
15 | from vit_pytorch.extractor import Extractor
16 | import timm
17 | from BBnet.model import PVFCNet
18 | from EfficientNET.model import efficientnet_b0
19 | from pytorch.timm.models.mobilevit import mobilevit_s
20 | from timm.models.effnetv2 import effnetv2_s
21 |
22 | x = torch.rand(size=(1, 3, 112, 112))
23 | model = timm.create_model('resnet50', pretrained=False, num_classes=10)
24 | # rewrite the classification head for the chosen network
25 | # model = PVFCNet(11, 5, 10)
26 | # model = effnetv2_s()
27 | # output = model(x)
28 | # print(output.shape)
29 | summary(model, (3, 112, 112), device="cpu")
30 | # print('model_name:coat_small')
31 | # print('flops:{}, params:{}'.format(flops, params))
32 |
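A sketch of the FLOPs/params measurement the commented-out line above intends, using the thop profile function already imported at the top of this file (thop counts multiply-accumulates as 'flops'):

    flops, params = profile(model, inputs=(x,))
    print('flops: {:.3f} G, params: {:.3f} M'.format(flops / 1e9, params / 1e6))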
--------------------------------------------------------------------------------