├── .gitattributes ├── .idea ├── .gitignore ├── PV_Classify.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml └── modules.xml ├── Matrix.py ├── PVF-10.zip ├── Readme.txt ├── class_indices.json ├── fig ├── __pycache__ │ ├── chart.cpython-38.pyc │ └── con_mat_fig.cpython-38.pyc ├── chart.py ├── con mat.py ├── con_mat_10.png ├── con_mat_compare_samp_pad.png ├── con_mat_compare_samp_pad531.png ├── con_mat_fig.py ├── figure_temp.py ├── read_jpg.py ├── train_loss_fig.py └── vitslog.csv ├── model ├── 111.py ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── _builder.cpython-38.pyc │ ├── _efficientnet_blocks.cpython-38.pyc │ ├── _efficientnet_builder.cpython-38.pyc │ ├── _factory.cpython-38.pyc │ ├── _features.cpython-38.pyc │ ├── _features_fx.cpython-38.pyc │ ├── _helpers.cpython-38.pyc │ ├── _hub.cpython-38.pyc │ ├── _manipulate.cpython-38.pyc │ ├── _pretrained.cpython-38.pyc │ ├── _prune.cpython-38.pyc │ ├── _registry.cpython-38.pyc │ ├── coat.cpython-38.pyc │ ├── efficientnet.cpython-38.pyc │ ├── effnetv2.cpython-38.pyc │ ├── resnet.cpython-38.pyc │ ├── swin_transformer_v2.cpython-38.pyc │ └── vision_transformer.cpython-38.pyc ├── _builder.py ├── _efficientnet_blocks.py ├── _efficientnet_builder.py ├── _factory.py ├── _features.py ├── _features_fx.py ├── _helpers.py ├── _hub.py ├── _manipulate.py ├── _pretrained.py ├── _prune.py ├── _registry.py ├── coat.py ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── auto_augment.cpython-38.pyc │ │ ├── config.cpython-38.pyc │ │ ├── constants.cpython-38.pyc │ │ ├── dataset.cpython-38.pyc │ │ ├── dataset_factory.cpython-38.pyc │ │ ├── dataset_info.cpython-38.pyc │ │ ├── distributed_sampler.cpython-38.pyc │ │ ├── imagenet_info.cpython-38.pyc │ │ ├── loader.cpython-38.pyc │ │ ├── mixup.cpython-38.pyc │ │ ├── random_erasing.cpython-38.pyc │ │ ├── real_labels.cpython-38.pyc │ │ ├── transforms.cpython-38.pyc │ │ └── transforms_factory.cpython-38.pyc │ ├── _info │ │ ├── imagenet12k_synsets.txt │ │ ├── imagenet21k_goog_synsets.txt │ │ ├── imagenet21k_goog_to_12k_indices.txt │ │ ├── imagenet21k_goog_to_22k_indices.txt │ │ ├── imagenet21k_miil_synsets.txt │ │ ├── imagenet21k_miil_w21_synsets.txt │ │ ├── imagenet22k_ms_synsets.txt │ │ ├── imagenet22k_ms_to_12k_indices.txt │ │ ├── imagenet22k_ms_to_22k_indices.txt │ │ ├── imagenet22k_synsets.txt │ │ ├── imagenet22k_to_12k_indices.txt │ │ ├── imagenet_a_indices.txt │ │ ├── imagenet_a_synsets.txt │ │ ├── imagenet_r_indices.txt │ │ ├── imagenet_r_synsets.txt │ │ ├── imagenet_real_labels.json │ │ ├── imagenet_synset_to_definition.txt │ │ ├── imagenet_synset_to_lemma.txt │ │ └── imagenet_synsets.txt │ ├── auto_augment.py │ ├── config.py │ ├── constants.py │ ├── dataset.py │ ├── dataset_factory.py │ ├── dataset_info.py │ ├── distributed_sampler.py │ ├── imagenet_info.py │ ├── loader.py │ ├── mixup.py │ ├── random_erasing.py │ ├── readers │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── class_map.cpython-38.pyc │ │ │ ├── img_extensions.cpython-38.pyc │ │ │ ├── reader.cpython-38.pyc │ │ │ ├── reader_factory.cpython-38.pyc │ │ │ ├── reader_image_folder.cpython-38.pyc │ │ │ └── reader_image_in_tar.cpython-38.pyc │ │ ├── class_map.py │ │ ├── img_extensions.py │ │ ├── reader.py │ │ ├── reader_factory.py │ │ ├── reader_hfds.py │ │ ├── reader_image_folder.py │ │ ├── reader_image_in_tar.py │ │ ├── reader_image_tar.py │ │ ├── reader_tfds.py │ │ ├── reader_wds.py │ │ └── shared_count.py │ ├── 
real_labels.py │ ├── tf_preprocessing.py │ ├── transforms.py │ └── transforms_factory.py ├── efficientnet.py ├── effnetv2.py ├── layers │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── activations.cpython-38.pyc │ │ ├── activations_jit.cpython-38.pyc │ │ ├── activations_me.cpython-38.pyc │ │ ├── adaptive_avgmax_pool.cpython-38.pyc │ │ ├── attention_pool.cpython-38.pyc │ │ ├── attention_pool2d.cpython-38.pyc │ │ ├── blur_pool.cpython-38.pyc │ │ ├── bottleneck_attn.cpython-38.pyc │ │ ├── cbam.cpython-38.pyc │ │ ├── classifier.cpython-38.pyc │ │ ├── cond_conv2d.cpython-38.pyc │ │ ├── config.cpython-38.pyc │ │ ├── conv2d_same.cpython-38.pyc │ │ ├── conv_bn_act.cpython-38.pyc │ │ ├── create_act.cpython-38.pyc │ │ ├── create_attn.cpython-38.pyc │ │ ├── create_conv2d.cpython-38.pyc │ │ ├── create_norm.cpython-38.pyc │ │ ├── create_norm_act.cpython-38.pyc │ │ ├── drop.cpython-38.pyc │ │ ├── eca.cpython-38.pyc │ │ ├── evo_norm.cpython-38.pyc │ │ ├── fast_norm.cpython-38.pyc │ │ ├── filter_response_norm.cpython-38.pyc │ │ ├── format.cpython-38.pyc │ │ ├── gather_excite.cpython-38.pyc │ │ ├── global_context.cpython-38.pyc │ │ ├── grn.cpython-38.pyc │ │ ├── halo_attn.cpython-38.pyc │ │ ├── helpers.cpython-38.pyc │ │ ├── inplace_abn.cpython-38.pyc │ │ ├── interpolate.cpython-38.pyc │ │ ├── lambda_layer.cpython-38.pyc │ │ ├── linear.cpython-38.pyc │ │ ├── mixed_conv2d.cpython-38.pyc │ │ ├── mlp.cpython-38.pyc │ │ ├── non_local_attn.cpython-38.pyc │ │ ├── norm.cpython-38.pyc │ │ ├── norm_act.cpython-38.pyc │ │ ├── padding.cpython-38.pyc │ │ ├── patch_dropout.cpython-38.pyc │ │ ├── patch_embed.cpython-38.pyc │ │ ├── pool2d_same.cpython-38.pyc │ │ ├── pos_embed.cpython-38.pyc │ │ ├── pos_embed_rel.cpython-38.pyc │ │ ├── pos_embed_sincos.cpython-38.pyc │ │ ├── selective_kernel.cpython-38.pyc │ │ ├── separable_conv.cpython-38.pyc │ │ ├── space_to_depth.cpython-38.pyc │ │ ├── split_attn.cpython-38.pyc │ │ ├── split_batchnorm.cpython-38.pyc │ │ ├── squeeze_excite.cpython-38.pyc │ │ ├── std_conv.cpython-38.pyc │ │ ├── test_time_pool.cpython-38.pyc │ │ ├── trace_utils.cpython-38.pyc │ │ ├── typing.cpython-38.pyc │ │ └── weight_init.cpython-38.pyc │ ├── activations.py │ ├── activations_jit.py │ ├── activations_me.py │ ├── adaptive_avgmax_pool.py │ ├── attention_pool.py │ ├── attention_pool2d.py │ ├── blur_pool.py │ ├── bottleneck_attn.py │ ├── cbam.py │ ├── classifier.py │ ├── cond_conv2d.py │ ├── config.py │ ├── conv2d_same.py │ ├── conv_bn_act.py │ ├── create_act.py │ ├── create_attn.py │ ├── create_conv2d.py │ ├── create_norm.py │ ├── create_norm_act.py │ ├── drop.py │ ├── eca.py │ ├── evo_norm.py │ ├── fast_norm.py │ ├── filter_response_norm.py │ ├── format.py │ ├── gather_excite.py │ ├── global_context.py │ ├── grn.py │ ├── halo_attn.py │ ├── helpers.py │ ├── inplace_abn.py │ ├── interpolate.py │ ├── lambda_layer.py │ ├── linear.py │ ├── median_pool.py │ ├── mixed_conv2d.py │ ├── ml_decoder.py │ ├── mlp.py │ ├── non_local_attn.py │ ├── norm.py │ ├── norm_act.py │ ├── padding.py │ ├── patch_dropout.py │ ├── patch_embed.py │ ├── pool2d_same.py │ ├── pos_embed.py │ ├── pos_embed_rel.py │ ├── pos_embed_sincos.py │ ├── selective_kernel.py │ ├── separable_conv.py │ ├── space_to_depth.py │ ├── split_attn.py │ ├── split_batchnorm.py │ ├── squeeze_excite.py │ ├── std_conv.py │ ├── test_time_pool.py │ ├── trace_utils.py │ ├── typing.py │ └── weight_init.py ├── resnet.py ├── swin_transformer_v2.py └── vision_transformer.py ├── my_dataset.py ├── parameter.py ├── predict.py ├── train_val.py └── 
utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | PVF-10.zip filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/PV_Classify.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 24 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /Matrix.py: -------------------------------------------------------------------------------- 1 | """ 2 | @FileName:Matrix.py\n 3 | @Description:Confusion matrix\n 4 | @Author:WBobby\n 5 | @Department:CUG\n 6 | @Time:2023/7/12 10:05\n 7 | """ 8 | import numpy as np 9 | import pandas as pd 10 | from sklearn.metrics import confusion_matrix 11 | 12 | 13 | def read_csv(infile): 14 | df = pd.read_csv(infile) 15 | data = df.values.tolist() 16 | for row in data: 17 | image = row[1] 18 | # the first two characters of the image name encode its class id (cf. class_indices.json) 19 | row.append(image[:2]) 20 | return data 21 | 22 | 23 | def matrix(data, outfile): 24 | # outfile is reserved for saving the matrix but is not used yet 25 | y_true = [row[2] for row in data] 26 | y_pred = [row[4] for row in data] 27 | confusion_mat = confusion_matrix(y_true, y_pred) 28 | return confusion_mat 29 | 30 | 31 | if __name__ == '__main__': 32 | infile = r'C:\Users\Wbobby\Desktop\BB.csv' 33 | outfile = '' 34 | data = read_csv(infile) 35 | matrix(data, outfile) --------------------------------------------------------------------------------
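Matrix.py returns the confusion matrix but never writes it out (outfile is unused). A minimal sketch of two follow-up helpers — hypothetical additions, not part of the repository — that persist the matrix and summarize it per class; the column convention (row[2] = ground-truth label, row[4] = predicted label) follows matrix() above:

import numpy as np

def save_matrix(confusion_mat: np.ndarray, outfile: str) -> None:
    # Persist the raw counts so figures such as fig/con_mat_10.png can be rebuilt later.
    np.savetxt(outfile, confusion_mat, fmt='%d', delimiter=',')

def per_class_accuracy(confusion_mat: np.ndarray) -> np.ndarray:
    # Diagonal entries are correct predictions; row sums are the true samples per class.
    return np.diag(confusion_mat) / confusion_mat.sum(axis=1)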
/PVF-10.zip: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5bbedd95fd99b98198fcea84017c9243af8d4e8aeef307a3bd77200919195189 3 | size 165094105 4 | -------------------------------------------------------------------------------- /Readme.txt: -------------------------------------------------------------------------------- 1 | # PVF-Dataset (PVF-10) 2 | 1. This is the code and dataset from the paper "Photovoltaic Fault Dataset (PVF-10): A High-resolution UAV Thermal Infrared Image Dataset for Fine-grained Photovoltaic Fault Classification". 3 | 2. The code and dataset (PVF-10) are on the master branch. 4 | 3. Since the weight files of the models are too large, we put them on "pan.baidu.com" for interested researchers at the link: https://pan.baidu.com/s/1SfvW7jvkqhF5tu7J-EAzjw; Code: PVFs 5 | train_val.py is the training script. 6 | predict.py is the test script. 7 | The model folder holds the classification models used in the paper. 8 | The weights folder holds the weight files for the five models in the paper. 9 | 4. Due to the bandwidth limits of GitHub's Large File Storage (LFS), PVF-10 may not download properly (the PVF-10.zip in this repository is an LFS pointer file; running "git lfs pull" after cloning fetches the real archive). If the download fails, please use the Baidu netdisk link: https://pan.baidu.com/s/1LJVe2lkqvYnwTI8cuiESDg; Extract Code: PVFD; or the Google Drive link: https://drive.google.com/file/d/1SQq0hETXi8I3Kdq9tDAEVyZgIsRCbOah/view?usp=sharing 10 | 11 | If you use PVF-10 for related research, please cite our paper "PVF-10: A high-resolution unmanned aerial vehicle thermal infrared image dataset for fine-grained photovoltaic fault classification", with bibtex as: 12 | @article{WANG2024124187, 13 | title = {PVF-10: A high-resolution unmanned aerial vehicle thermal infrared image dataset for fine-grained photovoltaic fault classification}, 14 | journal = {Applied Energy}, 15 | volume = {376}, 16 | pages = {124187}, 17 | year = {2024}, 18 | issn = {0306-2619}, 19 | doi = {https://doi.org/10.1016/j.apenergy.2024.124187}, 20 | url = {https://www.sciencedirect.com/science/article/pii/S0306261924015708}, 21 | author = {Bo Wang and Qi Chen and Mengmeng Wang and Yuntian Chen and Zhengjia Zhang and Xiuguo Liu and Wei Gao and Yanzhen Zhang and Haoran Zhang}, 22 | keywords = {Photovoltaic fault, Thermal infrared data, Classification, Deep learning, Unmanned aerial vehicle}, 23 | abstract = {Accurate identification of faulty photovoltaic (PV) modules is crucial for the effective operation and maintenance of PV systems. Deep learning (DL) algorithms exhibit promising potential for classifying PV fault (PVF) from thermal infrared (TIR) images captured by unmanned aerial vehicle (UAV), contingent upon the availability of extensive and high-quality labeled data. However, existing TIR PVF datasets are limited by low image resolution and incomplete coverage of fault types. This study proposes a high-resolution TIR PVF dataset with 10 classes, named PVF-10, comprising 5579 cropped images of PV panels collected from 8 PV power plants. These classes are further categorized into two groups according to the repairability of PVF, with 5 repairable and 5 irreparable classes each. Additionally, the circuit mechanisms underlying the TIR image features of typical PVF types are analyzed, supported by high-resolution images, thereby providing comprehensive information for PV operators. Finally, five state-of-the-art DL algorithms are trained and validated based on the PVF-10 dataset using three levels of resampling strategy. The results show that the overall accuracy (OA) of these algorithms exceeds 83%, with the highest OA reaching 93.32%. Moreover, the preprocessing procedure involving resampling and padding strategies are beneficial for improving PVF classification accuracy using PVF-10 datasets. The developed PVF-10 dataset is expected to stimulate further research and innovation in PVF classification.} 24 | } 25 | 26 | Note: Data provided by the School of Geography and Information Engineering, China University of Geosciences. 27 | --------------------------------------------------------------------------------
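The abstract above credits the resampling and padding preprocessing for part of the accuracy gain. The exact procedure is implemented in the training pipeline rather than shown in this listing; purely as an illustration, one common way to pad a rectangular panel crop to a square canvas before resizing (hypothetical helper, not the paper's code):

from PIL import Image, ImageOps

def pad_to_square(img: Image.Image, fill: int = 0) -> Image.Image:
    # Center the panel on a square canvas so resizing does not distort its aspect ratio.
    side = max(img.size)
    return ImageOps.pad(img, (side, side), color=fill)

panel = Image.open('example_panel.jpg')           # hypothetical input crop
square = pad_to_square(panel).resize((224, 224))  # 224 matches the models' default input size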
/class_indices.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": "01bottom dirt", 3 | "1": "02break", 4 | "2": "03Debris cover", 5 | "3": "04junction box heat", 6 | "4": "05hot cell", 7 | "5": "06shadow", 8 | "6": "07short circuit panel", 9 | "7": "08string short circuit", 10 | "8": "09substring open circuit", 11 | "9": "10healthy panel" 12 | } --------------------------------------------------------------------------------
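class_indices.json maps the classifier's integer output indices to the ten fault-class names. A minimal sketch of how it is consumed at inference time (note that the JSON keys are strings, so the integer index must be converted):

import json

with open('class_indices.json') as f:
    idx_to_name = json.load(f)        # {"0": "01bottom dirt", ..., "9": "10healthy panel"}

pred_idx = 9                          # e.g. the argmax over the 10 class logits
print(idx_to_name[str(pred_idx)])     # -> 10healthy panel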
/fig/__pycache__/chart.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/fig/__pycache__/chart.cpython-38.pyc -------------------------------------------------------------------------------- /fig/__pycache__/con_mat_fig.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/fig/__pycache__/con_mat_fig.cpython-38.pyc -------------------------------------------------------------------------------- /fig/chart.py: -------------------------------------------------------------------------------- 1 | """ 2 | @FileName:chart.py\n 3 | @Description:\n 4 | @Author:WBobby\n 5 | @Department:CUG\n 6 | @Time:2023/12/18 23:51\n 7 | """ 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | 12 | fig_name = 'Sample Number of PVF' 13 | figsize = (8, 8) 14 | 15 | matplotlib.rcParams['font.family'] = 'serif' 16 | matplotlib.rcParams['font.serif'] = ['Times New Roman'] + matplotlib.rcParams['font.serif'] 17 | # (a) bypass diode heating, (b) substring disconnection, (c) debris covering, (d) panel breaking, (e) dusty covering, 18 | # (f) general hot spot, (g) healthy panel; the dictionary below covers all ten classes (a-j). 19 | fault_dic = {'a': 595, 'b': 427, 'c': 71, 'd': 410, 'e': 303, 'f': 377, 'g': 131, 20 | 'h': 800, 'i': 946, 'j': 1519} 21 | 22 | fault_names = list(fault_dic.keys()) 23 | fault_counts = list(fault_dic.values()) 24 | 25 | 26 | # colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink'] 27 | colors = plt.cm.Wistia(np.linspace(1, 0.1, len(fault_names))) 28 | print(colors) 29 | 30 | fig, ax = plt.subplots(figsize=(6, 4)) 31 | bar_width = 0.6 32 | bars = ax.bar(fault_names, fault_counts, color=colors, width=bar_width) 33 | # bars = ax.bar(fault_counts, fault_names, color=colors, height=bar_width) 34 | 35 | # plt.barh(fault_names, fault_counts, color=colors) 36 | plt.ylabel('Fault Count', fontsize=12, fontweight='bold') 37 | plt.xlabel('Fault Types', fontsize=12, fontweight='bold') 38 | # plt.title(fig_name, fontsize=25) 39 | 40 | for bar, x, count in zip(bars, fault_names, fault_counts): 41 | ax.text(x, bar.get_y() + bar.get_height(), str(count), ha='center', va='top', fontsize=14) 42 | 43 | ax.grid(axis='y', linestyle='--', alpha=0.7) 44 | ax.set_title(fig_name, fontsize=15, fontweight='bold') 45 | plt.yticks(fontsize=11, fontweight='bold') 46 | plt.xticks(fontsize=12, fontweight='bold') 47 | plt.tight_layout() # adjust the layout 48 | plt.savefig(r'C:\Users\Wbobby\Documents\TeX_files\PVFC\tu/' + 'Sample3.png', dpi=300) 49 | plt.show() 50 | -------------------------------------------------------------------------------- /fig/con_mat_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/fig/con_mat_10.png -------------------------------------------------------------------------------- /fig/con_mat_compare_samp_pad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/fig/con_mat_compare_samp_pad.png -------------------------------------------------------------------------------- /fig/con_mat_compare_samp_pad531.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/fig/con_mat_compare_samp_pad531.png --------------------------------------------------------------------------------
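The three con_mat_*.png files above are rendered by fig/con_mat_fig.py, whose source is not included in this listing. A generic sketch of drawing a row-normalized confusion matrix with matplotlib — an illustration of the usual approach, not the repository's exact styling:

import matplotlib.pyplot as plt
import numpy as np

def plot_confusion(mat: np.ndarray, class_names, outfile='con_mat.png'):
    norm = mat / mat.sum(axis=1, keepdims=True)   # row-normalize counts to per-class recall
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(norm, cmap='Blues')
    ax.set_xticks(range(len(class_names)), class_names, rotation=90)
    ax.set_yticks(range(len(class_names)), class_names)
    for i in range(mat.shape[0]):
        for j in range(mat.shape[1]):
            # annotate each cell with its normalized value
            ax.text(j, i, f'{norm[i, j]:.2f}', ha='center', va='center', fontsize=8)
    plt.tight_layout()
    plt.savefig(outfile, dpi=300)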
/fig/figure_temp.py: -------------------------------------------------------------------------------- 1 | """ 2 | @FileName:figure_temp.py\n 3 | @Description:Python plotting template\n 4 | @Author:WBobby\n 5 | @Department:CUG\n 6 | @Time:2024/3/20 20:29\n 7 | """ 8 | 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | 13 | def plot_figure(feature1, feature2, filename): 14 | # create the overall figure canvas 15 | ''' 16 | plt.figure() is a function in the matplotlib.pyplot module that creates a new figure canvas (a Figure object). It takes several parameters that configure the figure's appearance and behavior. Commonly used ones: 17 | figsize: a tuple of two numbers, the width and height of the figure in inches. For example, figsize=(8, 6) creates a figure 8 inches wide and 6 inches tall. 18 | dpi: the resolution of the figure in pixels per inch; the default is usually 100. 19 | facecolor: the background color of the canvas; a color name, a hex code, an RGB or RGBA tuple, or a predefined color string. 20 | edgecolor: the edge color of the canvas. 21 | frameon: boolean, whether to draw a frame around the canvas; defaults to True. 22 | fig: optional, an existing Figure object to modify; if omitted, a new Figure is created. 23 | constrained_layout: boolean, whether to use constrained layout; if True, matplotlib tries to optimize the spacing between subplots. 24 | clear: boolean, whether to clear the current figure before creating the new one; defaults to False. 25 | :param feature1: 26 | :param feature2: 27 | :return: 28 | ''' 29 | fig = plt.figure(figsize=(8, 6)) 30 | # create a subplot; subplots are laid out on the canvas and may overlap 31 | ''' 32 | In matplotlib, the Axes class is the core class for drawing charts and graphics on a Figure object. Each Axes object represents one coordinate system in the figure and can draw lines, scatter plots, bar charts, pie charts, and many other graphic types. 33 | The Axes class provides methods and attributes that control drawing and styling as well as data handling and analysis. Some of the main ones: 34 | Attributes: 35 | xaxis and yaxis: the x-axis and y-axis objects, used to set axis properties such as labels, ticks, and limits. 36 | lines: a LineCollection holding all lines on the current Axes. 37 | collections: a Collection holding all collection graphics (polygons, scatter points, etc.) on the current Axes. 38 | patches: a PatchCollection holding all filled regions on the current Axes. 39 | images: an ImageCollection holding all images on the current Axes. 40 | legend_: a Legend object holding the legend of the current Axes. 41 | Methods: 42 | plot: draw lines; takes x and y coordinates plus optional parameters controlling line style. 43 | scatter: draw a scatter plot; takes x and y coordinates plus optional parameters controlling marker style. 44 | bar and barh: draw bar charts; take x and y coordinates plus optional parameters controlling bar style. 45 | boxplot: draw box plots from a dataset, with optional parameters controlling style. 46 | fill_between: fill the region between two data sets. 47 | errorbar: draw data points with error bars. 48 | legend: add a legend. 49 | set_xlim and set_ylim: set the displayed x- and y-axis ranges. 50 | set_xlabel and set_ylabel: set the x- and y-axis labels. 51 | set_title: set the chart title. 52 | grid: show or hide the grid. 53 | set_facecolor and set_edgecolor: set the axes background and edge colors. 54 | set_alpha: set the axes transparency. 55 | set_axis_bgcolor and set_axis_color: set the axes background color and axis color. 56 | ''' 57 | ax1 = fig.add_subplot() 58 | x1 = feature1[0] 59 | y1 = feature1[1] 60 | # draw the first plot 61 | ''' 62 | axes.plot is a method of matplotlib.axes.Axes that draws lines on the axes. Its parameters control the appearance and behavior of the lines. Commonly used ones: 63 | x and y: required; the x- and y-coordinates of the line. x can be a single value or an array of values, and y must have the same shape as x. 64 | color: the line color; a color name (e.g. 'red', 'green'), a hex code (e.g. '#FF00FF'), an RGB or RGBA tuple, or a predefined color string ('C0', 'C1', etc., where 'C0', 'C1' are colors from the color cycle). 65 | linewidth: the line width; defaults to 1.0. 66 | linestyle: the line style; solid ('-'), dashed ('--'), dotted (':'), points ('.'), etc. 67 | marker: the data-point marker style, e.g. 'o' (circle), 's' (square), '^' (triangle up), '<' (triangle left), etc. 68 | markersize: the marker size. 69 | markeredgewidth: the width of the marker edge line. 70 | markeredgecolor: the color of the marker edge line. 71 | label: the label used for the legend. 72 | alpha: the transparency of lines and markers, from 0 (fully transparent) to 1 (fully opaque). 73 | ax: the axes object to draw on. 74 | data: if x and y are not given, a dict or tuple containing the x and y data. 75 | ''' 76 | # ax1.boxplot(x1, y1, 'r') 77 | x2 = feature2[0] 78 | print(x2) 79 | y2 = feature2[1] 80 | ax1.boxplot(x2, notch=False, patch_artist=True, labels=['A']) 81 | # fine-tune the axes 82 | ''' 83 | In matplotlib, Axes objects have a method named axis that sets and adjusts various axis properties. It usually takes one argument naming the axis to adjust (x or y) plus optional parameters controlling display and behavior. 84 | Common usages of the axis method: 85 | axis('off'): hide the axes. 86 | axis('on'): show the axes. 87 | axis('equal'): keep both axes at the same scale, which matters when drawing proportional maps and similar figures. 88 | axis('scaled'): an alias of axis('equal'). 89 | axis('tight'): automatically fit the axis ranges to the data, often used when drawing multiple subplots. 90 | axis('auto'): automatically adjust the axis ranges; usually the default behavior. 91 | axis([xmin, xmax, ymin, ymax]): set the axis ranges manually. 92 | axis_bgcolor: set the axes background color. 93 | axis_color: set the axes color. 94 | ''' 95 | 96 | # save the figure to a file (must happen before plt.show(), which clears the canvas) 97 | plt.savefig(filename, format='png', dpi=300) 98 | # display the figure 99 | plt.show() 100 | 101 | 102 | if __name__ == '__main__': 103 | a = np.arange(1, 10) 104 | b = 2 * a 105 | c = 3 * a 106 | # stack a and b into one 2-D array 107 | feature1 = np.vstack((a, b)) 108 | # feature2 = np.vstack((a, c)) 109 | save_name = 'figure.png' 110 | data1 = [2, 2, 3, 4, 4, 4, 5, 5, 6, 6, 7, 8, 9, 9, 10] 111 | data2 = [1, 1, 2, 3, 3, 4, 4, 5, 6, 7, 8, 9, 9, 10, 11, 10, 9] 112 | feature2 = (data1, data2) 113 | plot_figure(feature1, feature2, save_name) 114 | -------------------------------------------------------------------------------- /fig/read_jpg.py: -------------------------------------------------------------------------------- 1 | """ 2 | @FileName:read_jpg.py\n 3 | @Description:\n 4 | @Author:WBobby\n 5 | @Department:CUG\n 6 | @Time:2024/3/26 23:07\n 7 | """ 8 | import cv2 9 | 10 | if __name__ == '__main__': 11 | jpg = r'C:\Users\Wbobby\Desktop\12\DJI_20231225162517_0002_T.JPG' 12 | img = cv2.imread(jpg) 13 | cv2.imshow('image', img) 14 | cv2.waitKey(0) 15 | cv2.destroyAllWindows() 16 | --------------------------------------------------------------------------------
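train_loss_fig.py below parses training logs such as fig/vitslog.csv purely by column position: csv2dict() reads iloc columns 1, 3, and 5 and keeps every 5th row. The real log layout is not shown in this dump; a synthetic log that satisfies those positional assumptions (the intermediate column names here are placeholders):

import pandas as pd

# csv2dict() expects: column 0 = row index, column 1 = epoch,
# column 3 = validation accuracy, column 5 = training loss.
rows = [{'epoch': i, 'lr': 1e-4, 'val_acc': 0.5 + 0.004 * i,
         'train_acc': 0.6, 'loss': 1.0 / (i + 1)} for i in range(100)]
pd.DataFrame(rows).to_csv('demo_log.csv')   # keep the index column so 'epoch' lands at iloc position 1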
/fig/train_loss_fig.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import os 6 | 7 | 8 | def csv2dict(csv_file): 9 | data = pd.read_csv(csv_file) 10 | metrics = {} 11 | for i in range(len(data)): 12 | if i % 5 == 0: # sample every 5th logged epoch to keep the curves readable 13 | epoch = data.iloc[i, 1] + 1 14 | val_acc = data.iloc[i, 3] 15 | loss = data.iloc[i, 5] 16 | metrics[epoch] = [val_acc, loss] 17 | return metrics 18 | 19 | 20 | def loss_fig(csv_path): 21 | matplotlib.rcParams['font.family'] = 'serif' 22 | matplotlib.rcParams['font.serif'] = ['Times New Roman'] + matplotlib.rcParams['font.serif'] 23 | figsize = (16, 9) 24 | plt.figure(figsize=figsize) 25 | csv_files = os.listdir(csv_path) 26 | csv_names = [os.path.join(csv_path, csv_file) for csv_file in csv_files] 27 | model_names = ['Coat-ls', 'Effv2-s', 'Res-50', 'Swinv2-t', 'ViT-s'] 28 | model_losscolors = plt.cm.coolwarm([0, 0.1, 0.7, 0.8, 0.9]) 29 | model_valcolors = plt.cm.coolwarm([0.02, 0.12, 0.68, 0.78, 0.88]) 30 | line_styles = ['-', '--', '-.', ':', '--'] 31 | # linewidths = [2, 1.8, 1.6, 1.4, 1.4] 32 | linewidths = [2, 2, 2, 2, 2] 33 | # alphas = [1, 0.8, 0.7, 0.6, 0.6] 34 | alphas = [1, 1, 1, 1, 1] 35 | model_losscolors = ['lightcoral', 'orange', 'lightgreen', 'darkturquoise', 'plum'] # named colors override the colormap values above 36 | model_valcolors = ['indianred', 'darkorange', 'mediumseagreen', 'c', 'orchid'] 37 | markers = ['o', 'v', '^', 's', '*'] 38 | ax1 = plt.subplot() 39 | ax2 = plt.twinx() 40 | zz = zip(model_names, csv_names, model_losscolors, model_valcolors, line_styles, linewidths, alphas, markers) 41 | 42 | lines = [] 43 | labels = [] 44 | lines_labels = [] 45 | for model_name, csv_name, model_color, model_valcolor, line_style, linewidth, alpha, marker in zz: 46 | csv_file = csv_name # csv_name is already a full path (joined above) 47 | metrics = csv2dict(csv_file) 48 | epochs = list(metrics.keys()) 49 | val_acc = [item[0] for item in metrics.values()] 50 | loss = [item[1] for item in metrics.values()] 51 | loss_line, = ax1.plot(epochs, loss, color=model_color, linestyle=line_style, linewidth=linewidth, alpha=alpha, 52 | marker=marker, 53 | markersize=5, label='Loss {}'.format(model_name)) 54 | val_line, = ax2.plot(epochs, val_acc, color=model_valcolor, marker=marker, markersize=5, linestyle=line_style, 55 | alpha=alpha, 56 | linewidth=linewidth, label='Val Acc {}'.format(model_name)) 57 | lines.append(val_line) 58 | lines.append(loss_line) 59 | labels.append(val_line.get_label()) 60 | labels.append(loss_line.get_label()) 61 | 62 | # unpack the (line, label) pairs collected above and pass them to one legend 63 | 64 | 65 | # create a unified legend 66 | plt.legend(lines, labels, loc='center right', fontsize=17) 67 | plt.title('Model Loss and Validation Accuracy', fontsize=40, fontweight='bold') 68 | ax1.set_xlabel('Epoch', fontsize=25, fontweight='bold') 69 | ax1.tick_params(axis='x', labelsize=20) 70 | ax1.set_ylabel('Train Loss', fontsize=25, fontweight='bold') 71 | ax1.yaxis.set_tick_params(labelsize=20) 72 | ax2.set_ylabel('Validation Accuracy', fontsize=25, fontweight='bold') 73 | ax2.tick_params(axis='y', labelsize=20) 74 | # plt.legend(loc='center right', fontsize=20) 75 | plt.savefig(r'C:\Users\Wbobby\Documents\TeX_files\PVFC\tu/' + 'loss_val1.png', dpi=300) 76 | plt.show() 77 | 78 | 79 | if __name__ == '__main__': 80 | csv_path = r'C:\Users\Wbobby\Desktop\csv' 81 | loss_fig(csv_path) 82 | -------------------------------------------------------------------------------- /model/111.py: -------------------------------------------------------------------------------- 1 | """ 2 | @FileName:111.py\n 3 | @Description:\n 4 | @Author:WBobby\n 5 | @Department:CUG\n 6 | @Time:2024/6/17 22:40\n 7 | """ 8 | 9 | 10 | def my_decorator(func): 11 | def wrapper(): 12 | print("behavior added by the decorator") 13 | func() 14 | func() 15 | return wrapper 16 | 17 | @my_decorator 18 | def say_hello(): 19 | print("Hello!") 20 | 21 | 22 | say_hello() 23 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .coat import * 3 | from .efficientnet import * 4 | from .swin_transformer_v2 import * 5 | from .vision_transformer import * 6 | from .effnetv2 import * 7 | 8 | 9 | from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \ 10 | set_pretrained_download_progress, set_pretrained_check_hash 11 | from ._factory import create_model, parse_model_name, safe_model_name 12 | from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet 13 | from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \ 14 | register_notrace_module, is_notrace_module, get_notrace_modules, \ 15 | register_notrace_function, is_notrace_function, get_notrace_functions 16 | from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint 17 | from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub 18 | from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \ 19 | group_modules, group_parameters, checkpoint_seq, adapt_input_conv 20 | from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg 21 | from ._prune import adapt_model_from_string 22 | from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \ 23 | register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \ 24 | is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value 25 | --------------------------------------------------------------------------------
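model/__init__.py mirrors timm's public API: each architecture file registers its variants via register_model, and create_model instantiates one by name. A minimal sketch, assuming the vendored registry keeps timm's standard model names (e.g. 'resnet50'):

from model import create_model, list_models

# Discover which architectures the vendored files register.
print(list_models('resnet*'))

# Build a classifier head sized for the 10 PVF-10 classes; weights stay random
# unless a pretrained source is configured for the chosen variant.
net = create_model('resnet50', pretrained=False, num_classes=10)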
/model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/_builder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_builder.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/_efficientnet_blocks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_efficientnet_blocks.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/_efficientnet_builder.cpython-38.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_efficientnet_builder.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/_factory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_factory.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/_features.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_features.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/_features_fx.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_features_fx.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/_helpers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_helpers.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/_hub.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_hub.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/_manipulate.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_manipulate.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/_pretrained.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_pretrained.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/_prune.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_prune.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/_registry.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/_registry.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/coat.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/coat.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/efficientnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/efficientnet.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/effnetv2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/effnetv2.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/swin_transformer_v2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/swin_transformer_v2.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/vision_transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/__pycache__/vision_transformer.cpython-38.pyc -------------------------------------------------------------------------------- /model/_pretrained.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from collections import deque, defaultdict 3 | from dataclasses import dataclass, field, replace, asdict 4 | from typing import Any, Deque, Dict, Tuple, Optional, Union 5 | 6 | 7 | __all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg'] 8 | 9 | 10 | @dataclass 11 | class PretrainedCfg: 12 | """ 13 | """ 14 | # weight source locations 15 | url: Optional[Union[str, Tuple[str, str]]] = None # remote URL 16 | file: Optional[str] = None # local / shared filesystem path 17 | state_dict: Optional[Dict[str, Any]] = None # in-memory state dict 18 | hf_hub_id: Optional[str] = None # Hugging Face Hub model id ('organization/model') 19 | hf_hub_filename: Optional[str] = None # Hugging Face Hub filename (overrides default) 20 | 21 | source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub) 22 | architecture: Optional[str] = None # architecture variant can be set when not implicit 23 | tag: Optional[str] = None # pretrained tag of source 24 | custom_load: bool = False # use custom model specific model.load_pretrained() (ie for npz files) 25 | 26 | # input / data config 27 | input_size: Tuple[int, int, int] = (3, 224, 224) 28 | test_input_size: Optional[Tuple[int, int, int]] = None 29 | min_input_size: Optional[Tuple[int, int, int]] = None 30 | fixed_input_size: bool = False 31 | 
interpolation: str = 'bicubic' 32 | crop_pct: float = 0.875 33 | test_crop_pct: Optional[float] = None 34 | crop_mode: str = 'center' 35 | mean: Tuple[float, ...] = (0.485, 0.456, 0.406) 36 | std: Tuple[float, ...] = (0.229, 0.224, 0.225) 37 | 38 | # head / classifier config and meta-data 39 | num_classes: int = 1000 40 | label_offset: Optional[int] = None 41 | label_names: Optional[Tuple[str]] = None 42 | label_descriptions: Optional[Dict[str, str]] = None 43 | 44 | # model attributes that vary with above or required for pretrained adaptation 45 | pool_size: Optional[Tuple[int, ...]] = None 46 | test_pool_size: Optional[Tuple[int, ...]] = None 47 | first_conv: Optional[str] = None 48 | classifier: Optional[str] = None 49 | 50 | license: Optional[str] = None 51 | description: Optional[str] = None 52 | origin_url: Optional[str] = None 53 | paper_name: Optional[str] = None 54 | paper_ids: Optional[Union[str, Tuple[str]]] = None 55 | notes: Optional[Tuple[str]] = None 56 | 57 | @property 58 | def has_weights(self): 59 | return self.url or self.file or self.hf_hub_id 60 | 61 | def to_dict(self, remove_source=False, remove_null=True): 62 | return filter_pretrained_cfg( 63 | asdict(self), 64 | remove_source=remove_source, 65 | remove_null=remove_null 66 | ) 67 | 68 | 69 | def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True): 70 | filtered_cfg = {} 71 | keep_null = {'pool_size', 'first_conv', 'classifier'} # always keep these keys, even if none 72 | for k, v in cfg.items(): 73 | if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_filename', 'source'}: 74 | continue 75 | if remove_null and v is None and k not in keep_null: 76 | continue 77 | filtered_cfg[k] = v 78 | return filtered_cfg 79 | 80 | 81 | @dataclass 82 | class DefaultCfg: 83 | tags: Deque[str] = field(default_factory=deque) # priority queue of tags (first is default) 84 | cfgs: Dict[str, PretrainedCfg] = field(default_factory=dict) # pretrained cfgs by tag 85 | is_pretrained: bool = False # at least one of the configs has a pretrained source set 86 | 87 | @property 88 | def default(self): 89 | return self.cfgs[self.tags[0]] 90 | 91 | @property 92 | def default_with_tag(self): 93 | tag = self.tags[0] 94 | return tag, self.cfgs[tag] 95 | -------------------------------------------------------------------------------- /model/_prune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pkgutil 3 | from copy import deepcopy 4 | 5 | from torch import nn as nn 6 | 7 | from timm.layers import Conv2dSame, BatchNormAct2d, Linear 8 | 9 | __all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file'] 10 | 11 | 12 | def extract_layer(model, layer): 13 | layer = layer.split('.') 14 | module = model 15 | if hasattr(model, 'module') and layer[0] != 'module': 16 | module = model.module 17 | if not hasattr(model, 'module') and layer[0] == 'module': 18 | layer = layer[1:] 19 | for l in layer: 20 | if hasattr(module, l): 21 | if not l.isdigit(): 22 | module = getattr(module, l) 23 | else: 24 | module = module[int(l)] 25 | else: 26 | return module 27 | return module 28 | 29 | 30 | def set_layer(model, layer, val): 31 | layer = layer.split('.') 32 | module = model 33 | if hasattr(model, 'module') and layer[0] != 'module': 34 | module = model.module 35 | lst_index = 0 36 | module2 = module 37 | for l in layer: 38 | if hasattr(module2, l): 39 | if not l.isdigit(): 40 | module2 = getattr(module2, l) 41 | else: 42 | module2 = 
module2[int(l)] 43 | lst_index += 1 44 | lst_index -= 1 45 | for l in layer[:lst_index]: 46 | if not l.isdigit(): 47 | module = getattr(module, l) 48 | else: 49 | module = module[int(l)] 50 | l = layer[lst_index] 51 | setattr(module, l, val) 52 | 53 | 54 | def adapt_model_from_string(parent_module, model_string): 55 | separator = '***' 56 | state_dict = {} 57 | lst_shape = model_string.split(separator) 58 | for k in lst_shape: 59 | k = k.split(':') 60 | key = k[0] 61 | shape = k[1][1:-1].split(',') 62 | if shape[0] != '': 63 | state_dict[key] = [int(i) for i in shape] 64 | 65 | new_module = deepcopy(parent_module) 66 | for n, m in parent_module.named_modules(): 67 | old_module = extract_layer(parent_module, n) 68 | if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): 69 | if isinstance(old_module, Conv2dSame): 70 | conv = Conv2dSame 71 | else: 72 | conv = nn.Conv2d 73 | s = state_dict[n + '.weight'] 74 | in_channels = s[1] 75 | out_channels = s[0] 76 | g = 1 77 | if old_module.groups > 1: 78 | in_channels = out_channels 79 | g = in_channels 80 | new_conv = conv( 81 | in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, 82 | bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, 83 | groups=g, stride=old_module.stride) 84 | set_layer(new_module, n, new_conv) 85 | elif isinstance(old_module, BatchNormAct2d): 86 | new_bn = BatchNormAct2d( 87 | state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, 88 | affine=old_module.affine, track_running_stats=True) 89 | new_bn.drop = old_module.drop 90 | new_bn.act = old_module.act 91 | set_layer(new_module, n, new_bn) 92 | elif isinstance(old_module, nn.BatchNorm2d): 93 | new_bn = nn.BatchNorm2d( 94 | num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, 95 | affine=old_module.affine, track_running_stats=True) 96 | set_layer(new_module, n, new_bn) 97 | elif isinstance(old_module, nn.Linear): 98 | # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? 
99 | num_features = state_dict[n + '.weight'][1] 100 | new_fc = Linear( 101 | in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) 102 | set_layer(new_module, n, new_fc) 103 | if hasattr(new_module, 'num_features'): 104 | new_module.num_features = num_features 105 | new_module.eval() 106 | parent_module.eval() 107 | 108 | return new_module 109 | 110 | 111 | def adapt_model_from_file(parent_module, model_variant): 112 | adapt_data = pkgutil.get_data(__name__, os.path.join('_pruned', model_variant + '.txt')) 113 | return adapt_model_from_string(parent_module, adapt_data.decode('utf-8').strip()) 114 | -------------------------------------------------------------------------------- /model/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ 2 | rand_augment_transform, auto_augment_transform 3 | from .config import resolve_data_config, resolve_model_data_config 4 | from .constants import * 5 | from .dataset import ImageDataset, IterableImageDataset, AugMixDataset 6 | from .dataset_factory import create_dataset 7 | from .dataset_info import DatasetInfo, CustomDatasetInfo 8 | from .imagenet_info import ImageNetInfo, infer_imagenet_subset 9 | from .loader import create_loader 10 | from .mixup import Mixup, FastCollateMixup 11 | from .readers import create_reader 12 | from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions 13 | from .real_labels import RealLabelsImagenet 14 | from .transforms import * 15 | from .transforms_factory import create_transform 16 | 17 | -------------------------------------------------------------------------------- /model/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/data/__pycache__/auto_augment.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/auto_augment.cpython-38.pyc -------------------------------------------------------------------------------- /model/data/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /model/data/__pycache__/constants.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/constants.cpython-38.pyc -------------------------------------------------------------------------------- /model/data/__pycache__/dataset.cpython-38.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /model/data/__pycache__/dataset_factory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/dataset_factory.cpython-38.pyc -------------------------------------------------------------------------------- /model/data/__pycache__/dataset_info.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/dataset_info.cpython-38.pyc -------------------------------------------------------------------------------- /model/data/__pycache__/distributed_sampler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/distributed_sampler.cpython-38.pyc -------------------------------------------------------------------------------- /model/data/__pycache__/imagenet_info.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/imagenet_info.cpython-38.pyc -------------------------------------------------------------------------------- /model/data/__pycache__/loader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/loader.cpython-38.pyc -------------------------------------------------------------------------------- /model/data/__pycache__/mixup.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/mixup.cpython-38.pyc -------------------------------------------------------------------------------- /model/data/__pycache__/random_erasing.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/random_erasing.cpython-38.pyc -------------------------------------------------------------------------------- /model/data/__pycache__/real_labels.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/real_labels.cpython-38.pyc -------------------------------------------------------------------------------- /model/data/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- 
/model/data/__pycache__/transforms_factory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangbobby1026/PVF-Dataset/5702a07713fa598de7c15baa21b796bf9c8eb5ec/model/data/__pycache__/transforms_factory.cpython-38.pyc -------------------------------------------------------------------------------- /model/data/_info/imagenet_a_indices.txt: -------------------------------------------------------------------------------- 1 | 6 2 | 11 3 | 13 4 | 15 5 | 17 6 | 22 7 | 23 8 | 27 9 | 30 10 | 37 11 | 39 12 | 42 13 | 47 14 | 50 15 | 57 16 | 70 17 | 71 18 | 76 19 | 79 20 | 89 21 | 90 22 | 94 23 | 96 24 | 97 25 | 99 26 | 105 27 | 107 28 | 108 29 | 110 30 | 113 31 | 124 32 | 125 33 | 130 34 | 132 35 | 143 36 | 144 37 | 150 38 | 151 39 | 207 40 | 234 41 | 235 42 | 254 43 | 277 44 | 283 45 | 287 46 | 291 47 | 295 48 | 298 49 | 301 50 | 306 51 | 307 52 | 308 53 | 309 54 | 310 55 | 311 56 | 313 57 | 314 58 | 315 59 | 317 60 | 319 61 | 323 62 | 324 63 | 326 64 | 327 65 | 330 66 | 334 67 | 335 68 | 336 69 | 347 70 | 361 71 | 363 72 | 372 73 | 378 74 | 386 75 | 397 76 | 400 77 | 401 78 | 402 79 | 404 80 | 407 81 | 411 82 | 416 83 | 417 84 | 420 85 | 425 86 | 428 87 | 430 88 | 437 89 | 438 90 | 445 91 | 456 92 | 457 93 | 461 94 | 462 95 | 470 96 | 472 97 | 483 98 | 486 99 | 488 100 | 492 101 | 496 102 | 514 103 | 516 104 | 528 105 | 530 106 | 539 107 | 542 108 | 543 109 | 549 110 | 552 111 | 557 112 | 561 113 | 562 114 | 569 115 | 572 116 | 573 117 | 575 118 | 579 119 | 589 120 | 606 121 | 607 122 | 609 123 | 614 124 | 626 125 | 627 126 | 640 127 | 641 128 | 642 129 | 643 130 | 658 131 | 668 132 | 677 133 | 682 134 | 684 135 | 687 136 | 701 137 | 704 138 | 719 139 | 736 140 | 746 141 | 749 142 | 752 143 | 758 144 | 763 145 | 765 146 | 768 147 | 773 148 | 774 149 | 776 150 | 779 151 | 780 152 | 786 153 | 792 154 | 797 155 | 802 156 | 803 157 | 804 158 | 813 159 | 815 160 | 820 161 | 823 162 | 831 163 | 833 164 | 835 165 | 839 166 | 845 167 | 847 168 | 850 169 | 859 170 | 862 171 | 870 172 | 879 173 | 880 174 | 888 175 | 890 176 | 897 177 | 900 178 | 907 179 | 913 180 | 924 181 | 932 182 | 933 183 | 934 184 | 937 185 | 943 186 | 945 187 | 947 188 | 951 189 | 954 190 | 956 191 | 957 192 | 959 193 | 971 194 | 972 195 | 980 196 | 981 197 | 984 198 | 986 199 | 987 200 | 988 201 | -------------------------------------------------------------------------------- /model/data/_info/imagenet_a_synsets.txt: -------------------------------------------------------------------------------- 1 | n01498041 2 | n01531178 3 | n01534433 4 | n01558993 5 | n01580077 6 | n01614925 7 | n01616318 8 | n01631663 9 | n01641577 10 | n01669191 11 | n01677366 12 | n01687978 13 | n01694178 14 | n01698640 15 | n01735189 16 | n01770081 17 | n01770393 18 | n01774750 19 | n01784675 20 | n01819313 21 | n01820546 22 | n01833805 23 | n01843383 24 | n01847000 25 | n01855672 26 | n01882714 27 | n01910747 28 | n01914609 29 | n01924916 30 | n01944390 31 | n01985128 32 | n01986214 33 | n02007558 34 | n02009912 35 | n02037110 36 | n02051845 37 | n02077923 38 | n02085620 39 | n02099601 40 | n02106550 41 | n02106662 42 | n02110958 43 | n02119022 44 | n02123394 45 | n02127052 46 | n02129165 47 | n02133161 48 | n02137549 49 | n02165456 50 | n02174001 51 | n02177972 52 | n02190166 53 | n02206856 54 | n02219486 55 | n02226429 56 | n02231487 57 | n02233338 58 | n02236044 59 | n02259212 60 | n02268443 61 | n02279972 62 | n02280649 63 | n02281787 64 | n02317335 65 | n02325366 66 | 
n02346627 67 | n02356798 68 | n02361337 69 | n02410509 70 | n02445715 71 | n02454379 72 | n02486410 73 | n02492035 74 | n02504458 75 | n02655020 76 | n02669723 77 | n02672831 78 | n02676566 79 | n02690373 80 | n02701002 81 | n02730930 82 | n02777292 83 | n02782093 84 | n02787622 85 | n02793495 86 | n02797295 87 | n02802426 88 | n02814860 89 | n02815834 90 | n02837789 91 | n02879718 92 | n02883205 93 | n02895154 94 | n02906734 95 | n02948072 96 | n02951358 97 | n02980441 98 | n02992211 99 | n02999410 100 | n03014705 101 | n03026506 102 | n03124043 103 | n03125729 104 | n03187595 105 | n03196217 106 | n03223299 107 | n03250847 108 | n03255030 109 | n03291819 110 | n03325584 111 | n03355925 112 | n03384352 113 | n03388043 114 | n03417042 115 | n03443371 116 | n03444034 117 | n03445924 118 | n03452741 119 | n03483316 120 | n03584829 121 | n03590841 122 | n03594945 123 | n03617480 124 | n03666591 125 | n03670208 126 | n03717622 127 | n03720891 128 | n03721384 129 | n03724870 130 | n03775071 131 | n03788195 132 | n03804744 133 | n03837869 134 | n03840681 135 | n03854065 136 | n03888257 137 | n03891332 138 | n03935335 139 | n03982430 140 | n04019541 141 | n04033901 142 | n04039381 143 | n04067472 144 | n04086273 145 | n04099969 146 | n04118538 147 | n04131690 148 | n04133789 149 | n04141076 150 | n04146614 151 | n04147183 152 | n04179913 153 | n04208210 154 | n04235860 155 | n04252077 156 | n04252225 157 | n04254120 158 | n04270147 159 | n04275548 160 | n04310018 161 | n04317175 162 | n04344873 163 | n04347754 164 | n04355338 165 | n04366367 166 | n04376876 167 | n04389033 168 | n04399382 169 | n04442312 170 | n04456115 171 | n04482393 172 | n04507155 173 | n04509417 174 | n04532670 175 | n04540053 176 | n04554684 177 | n04562935 178 | n04591713 179 | n04606251 180 | n07583066 181 | n07695742 182 | n07697313 183 | n07697537 184 | n07714990 185 | n07718472 186 | n07720875 187 | n07734744 188 | n07749582 189 | n07753592 190 | n07760859 191 | n07768694 192 | n07831146 193 | n09229709 194 | n09246464 195 | n09472597 196 | n09835506 197 | n11879895 198 | n12057211 199 | n12144580 200 | n12267677 201 | -------------------------------------------------------------------------------- /model/data/_info/imagenet_r_indices.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 2 3 | 4 4 | 6 5 | 8 6 | 9 7 | 11 8 | 13 9 | 22 10 | 23 11 | 26 12 | 29 13 | 31 14 | 39 15 | 47 16 | 63 17 | 71 18 | 76 19 | 79 20 | 84 21 | 90 22 | 94 23 | 96 24 | 97 25 | 99 26 | 100 27 | 105 28 | 107 29 | 113 30 | 122 31 | 125 32 | 130 33 | 132 34 | 144 35 | 145 36 | 147 37 | 148 38 | 150 39 | 151 40 | 155 41 | 160 42 | 161 43 | 162 44 | 163 45 | 171 46 | 172 47 | 178 48 | 187 49 | 195 50 | 199 51 | 203 52 | 207 53 | 208 54 | 219 55 | 231 56 | 232 57 | 234 58 | 235 59 | 242 60 | 245 61 | 247 62 | 250 63 | 251 64 | 254 65 | 259 66 | 260 67 | 263 68 | 265 69 | 267 70 | 269 71 | 276 72 | 277 73 | 281 74 | 288 75 | 289 76 | 291 77 | 292 78 | 293 79 | 296 80 | 299 81 | 301 82 | 308 83 | 309 84 | 310 85 | 311 86 | 314 87 | 315 88 | 319 89 | 323 90 | 327 91 | 330 92 | 334 93 | 335 94 | 337 95 | 338 96 | 340 97 | 341 98 | 344 99 | 347 100 | 353 101 | 355 102 | 361 103 | 362 104 | 365 105 | 366 106 | 367 107 | 368 108 | 372 109 | 388 110 | 390 111 | 393 112 | 397 113 | 401 114 | 407 115 | 413 116 | 414 117 | 425 118 | 428 119 | 430 120 | 435 121 | 437 122 | 441 123 | 447 124 | 448 125 | 457 126 | 462 127 | 463 128 | 469 129 | 470 130 | 471 131 | 472 132 | 476 133 | 483 134 | 487 135 | 515 136 | 546 
137 | 555 138 | 558 139 | 570 140 | 579 141 | 583 142 | 587 143 | 593 144 | 594 145 | 596 146 | 609 147 | 613 148 | 617 149 | 621 150 | 629 151 | 637 152 | 657 153 | 658 154 | 701 155 | 717 156 | 724 157 | 763 158 | 768 159 | 774 160 | 776 161 | 779 162 | 780 163 | 787 164 | 805 165 | 812 166 | 815 167 | 820 168 | 824 169 | 833 170 | 847 171 | 852 172 | 866 173 | 875 174 | 883 175 | 889 176 | 895 177 | 907 178 | 928 179 | 931 180 | 932 181 | 933 182 | 934 183 | 936 184 | 937 185 | 943 186 | 945 187 | 947 188 | 948 189 | 949 190 | 951 191 | 953 192 | 954 193 | 957 194 | 963 195 | 965 196 | 967 197 | 980 198 | 981 199 | 983 200 | 988 201 | -------------------------------------------------------------------------------- /model/data/_info/imagenet_r_synsets.txt: -------------------------------------------------------------------------------- 1 | n01443537 2 | n01484850 3 | n01494475 4 | n01498041 5 | n01514859 6 | n01518878 7 | n01531178 8 | n01534433 9 | n01614925 10 | n01616318 11 | n01630670 12 | n01632777 13 | n01644373 14 | n01677366 15 | n01694178 16 | n01748264 17 | n01770393 18 | n01774750 19 | n01784675 20 | n01806143 21 | n01820546 22 | n01833805 23 | n01843383 24 | n01847000 25 | n01855672 26 | n01860187 27 | n01882714 28 | n01910747 29 | n01944390 30 | n01983481 31 | n01986214 32 | n02007558 33 | n02009912 34 | n02051845 35 | n02056570 36 | n02066245 37 | n02071294 38 | n02077923 39 | n02085620 40 | n02086240 41 | n02088094 42 | n02088238 43 | n02088364 44 | n02088466 45 | n02091032 46 | n02091134 47 | n02092339 48 | n02094433 49 | n02096585 50 | n02097298 51 | n02098286 52 | n02099601 53 | n02099712 54 | n02102318 55 | n02106030 56 | n02106166 57 | n02106550 58 | n02106662 59 | n02108089 60 | n02108915 61 | n02109525 62 | n02110185 63 | n02110341 64 | n02110958 65 | n02112018 66 | n02112137 67 | n02113023 68 | n02113624 69 | n02113799 70 | n02114367 71 | n02117135 72 | n02119022 73 | n02123045 74 | n02128385 75 | n02128757 76 | n02129165 77 | n02129604 78 | n02130308 79 | n02134084 80 | n02138441 81 | n02165456 82 | n02190166 83 | n02206856 84 | n02219486 85 | n02226429 86 | n02233338 87 | n02236044 88 | n02268443 89 | n02279972 90 | n02317335 91 | n02325366 92 | n02346627 93 | n02356798 94 | n02363005 95 | n02364673 96 | n02391049 97 | n02395406 98 | n02398521 99 | n02410509 100 | n02423022 101 | n02437616 102 | n02445715 103 | n02447366 104 | n02480495 105 | n02480855 106 | n02481823 107 | n02483362 108 | n02486410 109 | n02510455 110 | n02526121 111 | n02607072 112 | n02655020 113 | n02672831 114 | n02701002 115 | n02749479 116 | n02769748 117 | n02793495 118 | n02797295 119 | n02802426 120 | n02808440 121 | n02814860 122 | n02823750 123 | n02841315 124 | n02843684 125 | n02883205 126 | n02906734 127 | n02909870 128 | n02939185 129 | n02948072 130 | n02950826 131 | n02951358 132 | n02966193 133 | n02980441 134 | n02992529 135 | n03124170 136 | n03272010 137 | n03345487 138 | n03372029 139 | n03424325 140 | n03452741 141 | n03467068 142 | n03481172 143 | n03494278 144 | n03495258 145 | n03498962 146 | n03594945 147 | n03602883 148 | n03630383 149 | n03649909 150 | n03676483 151 | n03710193 152 | n03773504 153 | n03775071 154 | n03888257 155 | n03930630 156 | n03947888 157 | n04086273 158 | n04118538 159 | n04133789 160 | n04141076 161 | n04146614 162 | n04147183 163 | n04192698 164 | n04254680 165 | n04266014 166 | n04275548 167 | n04310018 168 | n04325704 169 | n04347754 170 | n04389033 171 | n04409515 172 | n04465501 173 | n04487394 174 | n04522168 175 | n04536866 176 | 
n04552348 177 | n04591713 178 | n07614500 179 | n07693725 180 | n07695742 181 | n07697313 182 | n07697537 183 | n07714571 184 | n07714990 185 | n07718472 186 | n07720875 187 | n07734744 188 | n07742313 189 | n07745940 190 | n07749582 191 | n07753275 192 | n07753592 193 | n07768694 194 | n07873807 195 | n07880968 196 | n07920052 197 | n09472597 198 | n09835506 199 | n10565667 200 | n12267677 201 | -------------------------------------------------------------------------------- /model/data/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from .constants import * 3 | 4 | 5 | _logger = logging.getLogger(__name__) 6 | 7 | 8 | def resolve_data_config( 9 | args=None, 10 | pretrained_cfg=None, 11 | model=None, 12 | use_test_size=False, 13 | verbose=False 14 | ): 15 | assert model or args or pretrained_cfg, "At least one of model, args, or pretrained_cfg required for data config." 16 | args = args or {} 17 | pretrained_cfg = pretrained_cfg or {} 18 | if not pretrained_cfg and model is not None and hasattr(model, 'pretrained_cfg'): 19 | pretrained_cfg = model.pretrained_cfg 20 | data_config = {} 21 | 22 | # Resolve input/image size 23 | in_chans = 3 24 | if args.get('in_chans', None) is not None: 25 | in_chans = args['in_chans'] 26 | elif args.get('chans', None) is not None: 27 | in_chans = args['chans'] 28 | 29 | input_size = (in_chans, 224, 224) 30 | if args.get('input_size', None) is not None: 31 | assert isinstance(args['input_size'], (tuple, list)) 32 | assert len(args['input_size']) == 3 33 | input_size = tuple(args['input_size']) 34 | in_chans = input_size[0] # input_size overrides in_chans 35 | elif args.get('img_size', None) is not None: 36 | assert isinstance(args['img_size'], int) 37 | input_size = (in_chans, args['img_size'], args['img_size']) 38 | else: 39 | if use_test_size and pretrained_cfg.get('test_input_size', None) is not None: 40 | input_size = pretrained_cfg['test_input_size'] 41 | elif pretrained_cfg.get('input_size', None) is not None: 42 | input_size = pretrained_cfg['input_size'] 43 | data_config['input_size'] = input_size 44 | 45 | # resolve interpolation method 46 | data_config['interpolation'] = 'bicubic' 47 | if args.get('interpolation', None): 48 | data_config['interpolation'] = args['interpolation'] 49 | elif pretrained_cfg.get('interpolation', None): 50 | data_config['interpolation'] = pretrained_cfg['interpolation'] 51 | 52 | # resolve dataset + model mean for normalization 53 | data_config['mean'] = IMAGENET_DEFAULT_MEAN 54 | if args.get('mean', None) is not None: 55 | mean = tuple(args['mean']) 56 | if len(mean) == 1: 57 | mean = tuple(list(mean) * in_chans) 58 | else: 59 | assert len(mean) == in_chans 60 | data_config['mean'] = mean 61 | elif pretrained_cfg.get('mean', None): 62 | data_config['mean'] = pretrained_cfg['mean'] 63 | 64 | # resolve dataset + model std deviation for normalization 65 | data_config['std'] = IMAGENET_DEFAULT_STD 66 | if args.get('std', None) is not None: 67 | std = tuple(args['std']) 68 | if len(std) == 1: 69 | std = tuple(list(std) * in_chans) 70 | else: 71 | assert len(std) == in_chans 72 | data_config['std'] = std 73 | elif pretrained_cfg.get('std', None): 74 | data_config['std'] = pretrained_cfg['std'] 75 | 76 | # resolve default inference crop 77 | crop_pct = DEFAULT_CROP_PCT 78 | if args.get('crop_pct', None): 79 | crop_pct = args['crop_pct'] 80 | else: 81 | if use_test_size and pretrained_cfg.get('test_crop_pct', None): 82 | crop_pct = 
pretrained_cfg['test_crop_pct'] 83 | elif pretrained_cfg.get('crop_pct', None): 84 | crop_pct = pretrained_cfg['crop_pct'] 85 | data_config['crop_pct'] = crop_pct 86 | 87 | # resolve default crop mode 88 | crop_mode = DEFAULT_CROP_MODE 89 | if args.get('crop_mode', None): 90 | crop_mode = args['crop_mode'] 91 | elif pretrained_cfg.get('crop_mode', None): 92 | crop_mode = pretrained_cfg['crop_mode'] 93 | data_config['crop_mode'] = crop_mode 94 | 95 | if verbose: 96 | _logger.info('Data processing configuration for current model + dataset:') 97 | for n, v in data_config.items(): 98 | _logger.info('\t%s: %s' % (n, str(v))) 99 | 100 | return data_config 101 | 102 | 103 | def resolve_model_data_config( 104 | model, 105 | args=None, 106 | pretrained_cfg=None, 107 | use_test_size=False, 108 | verbose=False, 109 | ): 110 | """ Resolve Model Data Config 111 | This is equivalent to resolve_data_config() but with arguments re-ordered to put model first. 112 | 113 | Args: 114 | model (nn.Module): the model instance 115 | args (dict): command line arguments / configuration in dict form (overrides pretrained_cfg) 116 | pretrained_cfg (dict): pretrained model config (overrides pretrained_cfg attached to model) 117 | use_test_size (bool): use the test time input resolution (if one exists) instead of default train resolution 118 | verbose (bool): enable extra logging of resolved values 119 | 120 | Returns: 121 | dictionary of config 122 | """ 123 | return resolve_data_config( 124 | args=args, 125 | pretrained_cfg=pretrained_cfg, 126 | model=model, 127 | use_test_size=use_test_size, 128 | verbose=verbose, 129 | ) 130 | --------------------------------------------------------------------------------
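A quick usage sketch of the two resolvers above (not a repo file; it assumes the package is importable from the repo root and fakes a model carrying a timm-style `pretrained_cfg` dict):

from model.data.config import resolve_data_config, resolve_model_data_config

class _FakeModel:
    # hypothetical stand-in for a model exposing a pretrained_cfg dict
    pretrained_cfg = {'input_size': (3, 256, 256), 'interpolation': 'bilinear', 'crop_pct': 0.95}

model = _FakeModel()
cfg = resolve_data_config(args={'img_size': 224}, model=model)
# explicit args win: input_size -> (3, 224, 224); unset keys fall back to pretrained_cfg
assert cfg['input_size'] == (3, 224, 224) and cfg['interpolation'] == 'bilinear'
cfg2 = resolve_model_data_config(model)  # same resolution, model-first argument order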
/model/data/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_CROP_PCT = 0.875 2 | DEFAULT_CROP_MODE = 'center' 3 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 4 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 5 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 6 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 7 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 8 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 9 | OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) 10 | OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711) 11 | --------------------------------------------------------------------------------
/model/data/dataset_info.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, List, Optional, Union 3 | 4 | 5 | class DatasetInfo(ABC): 6 | 7 | def __init__(self): 8 | pass 9 | 10 | @abstractmethod 11 | def num_classes(self): 12 | pass 13 | 14 | @abstractmethod 15 | def label_names(self): 16 | pass 17 | 18 | @abstractmethod 19 | def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]: 20 | pass 21 | 22 | @abstractmethod 23 | def index_to_label_name(self, index) -> str: 24 | pass 25 | 26 | @abstractmethod 27 | def index_to_description(self, index: int, detailed: bool = False) -> str: 28 | pass 29 | 30 | @abstractmethod 31 | def label_name_to_description(self, label: str, detailed: bool = False) -> str: 32 | pass 33 | 34 | 35 | class CustomDatasetInfo(DatasetInfo): 36 | """ DatasetInfo that wraps passed values for custom datasets.""" 37 | 38 | def __init__( 39 | self, 40 | label_names: Union[List[str], Dict[int, str]], 41 | label_descriptions: Optional[Dict[str, str]] = None 42 | ): 43 | super().__init__() 44 | assert len(label_names) > 0 45 | self._label_names = label_names # label index => label name mapping 46 | self._label_descriptions = label_descriptions # label name => label description mapping 47 | if self._label_descriptions is not None: 48 | # validate descriptions (label names required) 49 | assert isinstance(self._label_descriptions, dict) 50 | for n in self._label_names: 51 | assert n in self._label_descriptions 52 | 53 | def num_classes(self): 54 | return len(self._label_names) 55 | 56 | def label_names(self): 57 | return self._label_names 58 | 59 | def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]: 60 | return self._label_descriptions 61 | 62 | def label_name_to_description(self, label: str, detailed: bool = False) -> str: 63 | if self._label_descriptions: 64 | return self._label_descriptions[label] 65 | return label # return label name itself if a description is not present 66 | 67 | def index_to_label_name(self, index) -> str: 68 | assert 0 <= index < len(self._label_names) 69 | return self._label_names[index] 70 | 71 | def index_to_description(self, index: int, detailed: bool = False) -> str: 72 | label = self.index_to_label_name(index) 73 | return self.label_name_to_description(label, detailed=detailed) 74 | --------------------------------------------------------------------------------
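A minimal sketch of wrapping custom labels with CustomDatasetInfo (not a repo file; the fault-class names below are placeholders, not the dataset's real label set):

from model.data.dataset_info import CustomDatasetInfo

info = CustomDatasetInfo(
    label_names=['clean', 'cracked', 'shadowed'],
    label_descriptions={'clean': 'no visible fault', 'cracked': 'cell cracking', 'shadowed': 'partial shading'},
)
assert info.num_classes() == 3
print(info.index_to_description(1))  # -> 'cell cracking'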
/model/data/imagenet_info.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import pkgutil 4 | import re 5 | from typing import Dict, List, Optional, Union 6 | 7 | from .dataset_info import DatasetInfo 8 | 9 | 10 | # NOTE no ambiguity wrt mapping from # classes to ImageNet subset so far, but likely to change 11 | _NUM_CLASSES_TO_SUBSET = { 12 | 1000: 'imagenet-1k', 13 | 11221: 'imagenet-21k-miil', # miil subset of fall11 14 | 11821: 'imagenet-12k', # timm specific 12k subset of fall11 15 | 21841: 'imagenet-22k', # as in fall11.tar 16 | 21842: 'imagenet-22k-ms', # a Microsoft (for FocalNet) remapping of 22k that moves the ImageNet-1k classes to the first 1000 17 | 21843: 'imagenet-21k-goog', # Google's full ImageNet has two classes not in fall11 18 | } 19 | 20 | _SUBSETS = { 21 | 'imagenet1k': 'imagenet_synsets.txt', 22 | 'imagenet12k': 'imagenet12k_synsets.txt', 23 | 'imagenet22k': 'imagenet22k_synsets.txt', 24 | 'imagenet21k': 'imagenet21k_goog_synsets.txt', 25 | 'imagenet21kgoog': 'imagenet21k_goog_synsets.txt', 26 | 'imagenet21kmiil': 'imagenet21k_miil_synsets.txt', 27 | 'imagenet22kms': 'imagenet22k_ms_synsets.txt', 28 | } 29 | _LEMMA_FILE = 'imagenet_synset_to_lemma.txt' 30 | _DEFINITION_FILE = 'imagenet_synset_to_definition.txt' 31 | 32 | 33 | def infer_imagenet_subset(model_or_cfg) -> Optional[str]: 34 | if isinstance(model_or_cfg, dict): 35 | num_classes = model_or_cfg.get('num_classes', None) 36 | else: 37 | num_classes = getattr(model_or_cfg, 'num_classes', None) 38 | if not num_classes: 39 | pretrained_cfg = getattr(model_or_cfg, 'pretrained_cfg', {}) 40 | # FIXME at some point pretrained_cfg should include dataset-tag, 41 | # which will be more robust than a guess based on num_classes 42 | num_classes = pretrained_cfg.get('num_classes', None) 43 | if not num_classes or num_classes not in _NUM_CLASSES_TO_SUBSET: 44 | return None 45 | return _NUM_CLASSES_TO_SUBSET[num_classes] 46 | 47 | 48 | class ImageNetInfo(DatasetInfo): 49 | 50 | def __init__(self, subset: str = 'imagenet-1k'): 51 | super().__init__() 52 | subset = re.sub(r'[-_\s]', '', subset.lower()) 53 | assert subset in _SUBSETS, f'Unknown imagenet subset {subset}.' 54 | 55 | # WordNet synsets (part-of-speech + offset) are the unique class label names for ImageNet classifiers 56 | synset_file = _SUBSETS[subset] 57 | synset_data = pkgutil.get_data(__name__, os.path.join('_info', synset_file)) 58 | self._synsets = synset_data.decode('utf-8').splitlines() 59 | 60 | # WordNet lemmas (canonical dictionary form of word) and definitions are used to build 61 | # the class descriptions. If detailed=True both are used, otherwise just the lemmas. 62 | lemma_data = pkgutil.get_data(__name__, os.path.join('_info', _LEMMA_FILE)) 63 | reader = csv.reader(lemma_data.decode('utf-8').splitlines(), delimiter='\t') 64 | self._lemmas = dict(reader) 65 | definition_data = pkgutil.get_data(__name__, os.path.join('_info', _DEFINITION_FILE)) 66 | reader = csv.reader(definition_data.decode('utf-8').splitlines(), delimiter='\t') 67 | self._definitions = dict(reader) 68 | 69 | def num_classes(self): 70 | return len(self._synsets) 71 | 72 | def label_names(self): 73 | return self._synsets 74 | 75 | def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]: 76 | if as_dict: 77 | return {label: self.label_name_to_description(label, detailed=detailed) for label in self._synsets} 78 | else: 79 | return [self.label_name_to_description(label, detailed=detailed) for label in self._synsets] 80 | 81 | def index_to_label_name(self, index) -> str: 82 | assert 0 <= index < len(self._synsets), \ 83 | f'Index ({index}) out of range for dataset with {len(self._synsets)} classes.' 84 | return self._synsets[index] 85 | 86 | def index_to_description(self, index: int, detailed: bool = False) -> str: 87 | label = self.index_to_label_name(index) 88 | return self.label_name_to_description(label, detailed=detailed) 89 | 90 | def label_name_to_description(self, label: str, detailed: bool = False) -> str: 91 | if detailed: 92 | description = f'{self._lemmas[label]}: {self._definitions[label]}' 93 | else: 94 | description = f'{self._lemmas[label]}' 95 | return description 96 | --------------------------------------------------------------------------------
/model/data/random_erasing.py: -------------------------------------------------------------------------------- 1 | """ Random Erasing (Cutout) 2 | 3 | Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 4 | Copyright Zhun Zhong & Liang Zheng 5 | 6 | Hacked together by / Copyright 2019, Ross Wightman 7 | """ 8 | import random 9 | import math 10 | 11 | import torch 12 | 13 | 14 | def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'): 15 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 16 | # paths, flip the order so normal is run on CPU if this becomes a problem 17 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 18 | if per_pixel: 19 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 20 | elif rand_color: 21 | return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() 22 | else: 23 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 24 | 25 | 26 | class RandomErasing: 27 | """ Randomly selects a rectangle region in an image and erases its pixels. 28 | 'Random Erasing Data Augmentation' by Zhong et al. 
29 | See https://arxiv.org/pdf/1708.04896.pdf 30 | 31 | This variant of RandomErasing is intended to be applied to either a batch 32 | or single image tensor after it has been normalized by dataset mean and std. 33 | Args: 34 | probability: Probability that the Random Erasing operation will be performed. 35 | min_area: Minimum percentage of erased area wrt input image area. 36 | max_area: Maximum percentage of erased area wrt input image area. 37 | min_aspect: Minimum aspect ratio of erased area. 38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 39 | 'const' - erase block is constant color of 0 for all channels 40 | 'rand' - erase block is same per-channel random (normal) color 41 | 'pixel' - erase block is per-pixel random (normal) color 42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 43 | per-image count is randomly chosen between 1 and this value. 44 | """ 45 | 46 | def __init__( 47 | self, 48 | probability=0.5, 49 | min_area=0.02, 50 | max_area=1/3, 51 | min_aspect=0.3, 52 | max_aspect=None, 53 | mode='const', 54 | min_count=1, 55 | max_count=None, 56 | num_splits=0, 57 | device='cuda', 58 | ): 59 | self.probability = probability 60 | self.min_area = min_area 61 | self.max_area = max_area 62 | max_aspect = max_aspect or 1 / min_aspect 63 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 64 | self.min_count = min_count 65 | self.max_count = max_count or min_count 66 | self.num_splits = num_splits 67 | self.mode = mode.lower() 68 | self.rand_color = False 69 | self.per_pixel = False 70 | if self.mode == 'rand': 71 | self.rand_color = True # per block random normal 72 | elif self.mode == 'pixel': 73 | self.per_pixel = True # per pixel random normal 74 | else: 75 | assert not self.mode or self.mode == 'const' 76 | self.device = device 77 | 78 | def _erase(self, img, chan, img_h, img_w, dtype): 79 | if random.random() > self.probability: 80 | return 81 | area = img_h * img_w 82 | count = self.min_count if self.min_count == self.max_count else \ 83 | random.randint(self.min_count, self.max_count) 84 | for _ in range(count): 85 | for attempt in range(10): 86 | target_area = random.uniform(self.min_area, self.max_area) * area / count 87 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 88 | h = int(round(math.sqrt(target_area * aspect_ratio))) 89 | w = int(round(math.sqrt(target_area / aspect_ratio))) 90 | if w < img_w and h < img_h: 91 | top = random.randint(0, img_h - h) 92 | left = random.randint(0, img_w - w) 93 | img[:, top:top + h, left:left + w] = _get_pixels( 94 | self.per_pixel, 95 | self.rand_color, 96 | (chan, h, w), 97 | dtype=dtype, 98 | device=self.device, 99 | ) 100 | break 101 | 102 | def __call__(self, input): 103 | if len(input.size()) == 3: 104 | self._erase(input, *input.size(), input.dtype) 105 | else: 106 | batch_size, chan, img_h, img_w = input.size() 107 | # skip first slice of batch if num_splits is set (for clean portion of samples) 108 | batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 109 | for i in range(batch_start, batch_size): 110 | self._erase(input[i], chan, img_h, img_w, input.dtype) 111 | return input 112 | 113 | def __repr__(self): 114 | # NOTE simplified state for repr 115 | fs = self.__class__.__name__ + f'(p={self.probability}, mode={self.mode}' 116 | fs += f', count=({self.min_count}, {self.max_count}))' 117 | return fs 118 | -------------------------------------------------------------------------------- 
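Because RandomErasing expects tensors that have already been normalized, a short batch-mode sketch may help (not a repo file; shapes and parameters are arbitrary):

import torch
from model.data.random_erasing import RandomErasing

erase = RandomErasing(probability=1.0, mode='pixel', device='cpu')  # probability=1.0 forces the op for the demo
batch = torch.randn(8, 3, 224, 224)  # stands in for a normalized image batch
out = erase(batch)  # erases a random rectangle per image, modifying the tensor in place
assert out.shape == (8, 3, 224, 224)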
/model/data/readers/__init__.py: -------------------------------------------------------------------------------- 1 | from .reader_factory import create_reader 2 | from .img_extensions import * 3 | --------------------------------------------------------------------------------
/model/data/readers/class_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | 5 | def load_class_map(map_or_filename, root=''): 6 | if isinstance(map_or_filename, dict): 7 | assert map_or_filename, 'class_map dict must be non-empty' 8 | return map_or_filename 9 | class_map_path = map_or_filename 10 | if not os.path.exists(class_map_path): 11 | class_map_path = os.path.join(root, class_map_path) 12 | assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename 13 | class_map_ext = os.path.splitext(map_or_filename)[-1].lower() 14 | if class_map_ext == '.txt': 15 | with open(class_map_path) as f: 16 | class_to_idx = {v.strip(): k for k, v in enumerate(f)} 17 | elif class_map_ext == '.pkl': 18 | with open(class_map_path, 'rb') as f: 19 | class_to_idx = pickle.load(f) 20 | else: 21 | assert False, f'Unsupported class map file extension ({class_map_ext}).' 22 | return class_to_idx 23 | 24 | --------------------------------------------------------------------------------
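load_class_map accepts a pre-built dict, a .txt file with one class name per line, or a pickled dict; a minimal sketch (not a repo file; the file name is hypothetical):

from model.data.readers.class_map import load_class_map

with open('class_map.txt', 'w') as f:
    f.write('clean\ncracked\nshadowed\n')  # line order defines the class indices
class_to_idx = load_class_map('class_map.txt')
assert class_to_idx == {'clean': 0, 'cracked': 1, 'shadowed': 2}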
/model/data/readers/img_extensions.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | __all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions'] 4 | 5 | 6 | IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use 7 | _IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync 8 | 9 | 10 | def _set_extensions(extensions): 11 | global IMG_EXTENSIONS 12 | global _IMG_EXTENSIONS_SET 13 | dedupe = set() # NOTE de-duping tuple while keeping original order 14 | IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x)) 15 | _IMG_EXTENSIONS_SET = set(extensions) 16 | 17 | 18 | def _valid_extension(x: str): 19 | return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.') 20 | 21 | 22 | def is_img_extension(ext): 23 | return ext in _IMG_EXTENSIONS_SET 24 | 25 | 26 | def get_img_extensions(as_set=False): 27 | return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS) 28 | 29 | 30 | def set_img_extensions(extensions): 31 | assert len(extensions) 32 | for x in extensions: 33 | assert _valid_extension(x) 34 | _set_extensions(extensions) 35 | 36 | 37 | def add_img_extensions(ext): 38 | if not isinstance(ext, (list, tuple, set)): 39 | ext = (ext,) 40 | for x in ext: 41 | assert _valid_extension(x) 42 | extensions = IMG_EXTENSIONS + tuple(ext) 43 | _set_extensions(extensions) 44 | 45 | 46 | def del_img_extensions(ext): 47 | if not isinstance(ext, (list, tuple, set)): 48 | ext = (ext,) 49 | extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext) 50 | _set_extensions(extensions) 51 | --------------------------------------------------------------------------------
/model/data/readers/reader.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | 4 | class Reader: 5 | def __init__(self): 6 | pass 7 | 8 | @abstractmethod 9 | def _filename(self, index, basename=False, absolute=False): 10 | pass 11 | 12 | def filename(self, index, basename=False, absolute=False): 13 | return self._filename(index, basename=basename, absolute=absolute) 14 | 15 | def filenames(self, basename=False, absolute=False): 16 | return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))] 17 | 18 | --------------------------------------------------------------------------------
/model/data/readers/reader_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .reader_image_folder import ReaderImageFolder 4 | from .reader_image_in_tar import ReaderImageInTar 5 | 6 | 7 | def create_reader(name, root, split='train', **kwargs): 8 | name = name.lower() 9 | name = name.split('/', 1) 10 | prefix = '' 11 | if len(name) > 1: 12 | prefix = name[0] 13 | name = name[-1] 14 | 15 | # FIXME improve the selection right now just tfds prefix or fallback path, 
will need options to 16 | # explicitly select other readers shortly 17 | if prefix == 'hfds': 18 | from .reader_hfds import ReaderHfds # defer tensorflow import 19 | reader = ReaderHfds(root, name, split=split, **kwargs) 20 | elif prefix == 'tfds': 21 | from .reader_tfds import ReaderTfds # defer tensorflow import 22 | reader = ReaderTfds(root, name, split=split, **kwargs) 23 | elif prefix == 'wds': 24 | from .reader_wds import ReaderWds 25 | kwargs.pop('download', False) 26 | reader = ReaderWds(root, name, split=split, **kwargs) 27 | else: 28 | assert os.path.exists(root) 29 | # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder 30 | # FIXME support split here or in reader? 31 | if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': 32 | reader = ReaderImageInTar(root, **kwargs) 33 | else: 34 | reader = ReaderImageFolder(root, **kwargs) 35 | return reader 36 | -------------------------------------------------------------------------------- /model/data/readers/reader_hfds.py: -------------------------------------------------------------------------------- 1 | """ Dataset reader that wraps Hugging Face datasets 2 | 3 | Hacked together by / Copyright 2022 Ross Wightman 4 | """ 5 | import io 6 | import math 7 | import torch 8 | import torch.distributed as dist 9 | from PIL import Image 10 | 11 | try: 12 | import datasets 13 | except ImportError as e: 14 | print("Please install Hugging Face datasets package `pip install datasets`.") 15 | exit(1) 16 | from .class_map import load_class_map 17 | from .reader import Reader 18 | 19 | 20 | def get_class_labels(info, label_key='label'): 21 | if label_key not in info.features: 22 | return {} 23 | class_label = info.features[label_key] 24 | class_to_idx = {n: class_label.str2int(n) for n in class_label.names} 25 | return class_to_idx 26 | 27 | 28 | class ReaderHfds(Reader): 29 | 30 | def __init__( 31 | self, 32 | root, 33 | name, 34 | split='train', 35 | class_map=None, 36 | label_key='label', 37 | download=False, 38 | ): 39 | """ 40 | """ 41 | super().__init__() 42 | self.root = root 43 | self.split = split 44 | self.dataset = datasets.load_dataset( 45 | name, # 'name' maps to path arg in hf datasets 46 | split=split, 47 | cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path 48 | ) 49 | # leave decode for caller, plus we want easy access to original path names... 
50 | self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False)) 51 | 52 | self.label_key = label_key 53 | self.remap_class = False 54 | if class_map: 55 | self.class_to_idx = load_class_map(class_map) 56 | self.remap_class = True 57 | else: 58 | self.class_to_idx = get_class_labels(self.dataset.info, self.label_key) 59 | self.split_info = self.dataset.info.splits[split] 60 | self.num_samples = self.split_info.num_examples 61 | 62 | def __getitem__(self, index): 63 | item = self.dataset[index] 64 | image = item['image'] 65 | if 'bytes' in image and image['bytes']: 66 | image = io.BytesIO(image['bytes']) 67 | else: 68 | assert 'path' in image and image['path'] 69 | image = open(image['path'], 'rb') 70 | label = item[self.label_key] 71 | if self.remap_class: 72 | label = self.class_to_idx[label] 73 | return image, label 74 | 75 | def __len__(self): 76 | return len(self.dataset) 77 | 78 | def _filename(self, index, basename=False, absolute=False): 79 | item = self.dataset[index] 80 | return item['image']['path'] 81 | -------------------------------------------------------------------------------- /model/data/readers/reader_image_folder.py: -------------------------------------------------------------------------------- 1 | """ A dataset reader that extracts images from folders 2 | 3 | Folders are scanned recursively to find image files. Labels are based 4 | on the folder hierarchy, just leaf folders by default. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | from typing import Dict, List, Optional, Set, Tuple, Union 10 | 11 | from timm.utils.misc import natural_key 12 | 13 | from .class_map import load_class_map 14 | from .img_extensions import get_img_extensions 15 | from .reader import Reader 16 | 17 | 18 | def find_images_and_targets( 19 | folder: str, 20 | types: Optional[Union[List, Tuple, Set]] = None, 21 | class_to_idx: Optional[Dict] = None, 22 | leaf_name_only: bool = True, 23 | sort: bool = True 24 | ): 25 | """ Walk folder recursively to discover images and map them to classes by folder names. 
26 | 27 | Args: 28 | folder: root of folder to recursively search 29 | types: types (file extensions) to search for in path 30 | class_to_idx: specify mapping for class (folder name) to class index if set 31 | leaf_name_only: use only leaf-name of folder walk for class names 32 | sort: re-sort found images by name (for consistent ordering) 33 | 34 | Returns: 35 | A list of image and target tuples, class_to_idx mapping 36 | """ 37 | types = get_img_extensions(as_set=True) if not types else set(types) 38 | labels = [] 39 | filenames = [] 40 | for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): 41 | rel_path = os.path.relpath(root, folder) if (root != folder) else '' 42 | label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') 43 | for f in files: 44 | base, ext = os.path.splitext(f) 45 | if ext.lower() in types: 46 | filenames.append(os.path.join(root, f)) 47 | labels.append(label) 48 | if class_to_idx is None: 49 | # building class index 50 | unique_labels = set(labels) 51 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 52 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 53 | images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx] 54 | if sort: 55 | images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) 56 | return images_and_targets, class_to_idx 57 | 58 | 59 | class ReaderImageFolder(Reader): 60 | 61 | def __init__( 62 | self, 63 | root, 64 | class_map=''): 65 | super().__init__() 66 | 67 | self.root = root 68 | class_to_idx = None 69 | if class_map: 70 | class_to_idx = load_class_map(class_map, root) 71 | self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) 72 | if len(self.samples) == 0: 73 | raise RuntimeError( 74 | f'Found 0 images in subfolders of {root}. ' 75 | f'Supported image extensions are {", ".join(get_img_extensions())}') 76 | 77 | def __getitem__(self, index): 78 | path, target = self.samples[index] 79 | return open(path, 'rb'), target 80 | 81 | def __len__(self): 82 | return len(self.samples) 83 | 84 | def _filename(self, index, basename=False, absolute=False): 85 | filename = self.samples[index][0] 86 | if basename: 87 | filename = os.path.basename(filename) 88 | elif not absolute: 89 | filename = os.path.relpath(filename, self.root) 90 | return filename 91 | --------------------------------------------------------------------------------
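A sketch of the folder reader on a class-per-subfolder layout (not a repo file; the directory name is a placeholder assuming an extracted PVF-10-style tree):

from model.data.readers.reader_image_folder import ReaderImageFolder

reader = ReaderImageFolder('PVF-10/train')   # expects PVF-10/train/<class_name>/<image>.jpg
print(len(reader), reader.class_to_idx)      # sample count and the label -> index mapping
fileobj, target = reader[0]                  # open binary file handle + integer class index
print(reader.filename(0), target)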
/model/data/readers/reader_image_tar.py: -------------------------------------------------------------------------------- 1 | """ A dataset reader that reads single tarfile based datasets 2 | 3 | This reader can read datasets consisting of a single tarfile containing images. 4 | I am planning to deprecate it in favour of ReaderImageInTar. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | import tarfile 10 | 11 | from timm.utils.misc import natural_key 12 | 13 | from .class_map import load_class_map 14 | from .img_extensions import get_img_extensions 15 | from .reader import Reader 16 | 17 | 18 | def extract_tarinfo(tarfile, class_to_idx=None, sort=True): 19 | extensions = get_img_extensions(as_set=True) 20 | files = [] 21 | labels = [] 22 | for ti in tarfile.getmembers(): 23 | if not ti.isfile(): 24 | continue 25 | dirname, basename = os.path.split(ti.path) 26 | label = os.path.basename(dirname) 27 | ext = os.path.splitext(basename)[1] 28 | if ext.lower() in extensions: 29 | files.append(ti) 30 | labels.append(label) 31 | if class_to_idx is None: 32 | unique_labels = set(labels) 33 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 34 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 35 | tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx] 36 | if sort: 37 | tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) 38 | return tarinfo_and_targets, class_to_idx 39 | 40 | 41 | class ReaderImageTar(Reader): 42 | """ Single tarfile dataset where classes are mapped to folders within tar 43 | NOTE: This class is being deprecated in favour of the more capable ReaderImageInTar that can 44 | operate on folders of tars or tars in tars. 45 | """ 46 | def __init__(self, root, class_map=''): 47 | super().__init__() 48 | 49 | class_to_idx = None 50 | if class_map: 51 | class_to_idx = load_class_map(class_map, root) 52 | assert os.path.isfile(root) 53 | self.root = root 54 | 55 | with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later 56 | self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx) 57 | self.imgs = self.samples 58 | self.tarfile = None # lazy init in __getitem__ 59 | 60 | def __getitem__(self, index): 61 | if self.tarfile is None: 62 | self.tarfile = tarfile.open(self.root) 63 | tarinfo, target = self.samples[index] 64 | fileobj = self.tarfile.extractfile(tarinfo) 65 | return fileobj, target 66 | 67 | def __len__(self): 68 | return len(self.samples) 69 | 70 | def _filename(self, index, basename=False, absolute=False): 71 | filename = self.samples[index][0].name 72 | if basename: 73 | filename = os.path.basename(filename) 74 | return filename 75 | --------------------------------------------------------------------------------
/model/data/readers/shared_count.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Value 2 | 3 | 4 | class SharedCount: 5 | def __init__(self, epoch: int = 0): 6 | self.shared_epoch = Value('i', epoch) 7 | 8 | @property 9 | def value(self): 10 | return self.shared_epoch.value 11 | 12 | @value.setter 13 | def value(self, epoch): 14 | self.shared_epoch.value = epoch 15 | --------------------------------------------------------------------------------
/model/data/real_labels.py: -------------------------------------------------------------------------------- 1 | """ Real labels evaluator for ImageNet 2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159 3 | Based on Numpy example at https://github.com/google-research/reassessed-imagenet 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import os 8 | import json 9 | import numpy as np 10 | import pkgutil 11 | 12 | 13 | class RealLabelsImagenet: 14 | 15 | def __init__(self, 
filenames, real_json=None, topk=(1, 5)): 16 | if real_json is not None: 17 | with open(real_json) as real_labels: 18 | real_labels = json.load(real_labels) 19 | else: 20 | real_labels = json.loads( 21 | pkgutil.get_data(__name__, os.path.join('_info', 'imagenet_real_labels.json')).decode('utf-8')) 22 | real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)} 23 | self.real_labels = real_labels 24 | self.filenames = filenames 25 | assert len(self.filenames) == len(self.real_labels) 26 | self.topk = topk 27 | self.is_correct = {k: [] for k in topk} 28 | self.sample_idx = 0 29 | 30 | def add_result(self, output): 31 | maxk = max(self.topk) 32 | _, pred_batch = output.topk(maxk, 1, True, True) 33 | pred_batch = pred_batch.cpu().numpy() 34 | for pred in pred_batch: 35 | filename = self.filenames[self.sample_idx] 36 | filename = os.path.basename(filename) 37 | if self.real_labels[filename]: 38 | for k in self.topk: 39 | self.is_correct[k].append( 40 | any([p in self.real_labels[filename] for p in pred[:k]])) 41 | self.sample_idx += 1 42 | 43 | def get_accuracy(self, k=None): 44 | if k is None: 45 | return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk} 46 | else: 47 | return float(np.mean(self.is_correct[k])) * 100 48 | --------------------------------------------------------------------------------
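The evaluator is fed logits batch by batch in loader order; a sketch under the assumption that the bundled imagenet_real_labels.json is present and that the filename list covers the full 50k-image ImageNet val set (not a repo file):

import torch
from model.data.real_labels import RealLabelsImagenet

filenames = [f'ILSVRC2012_val_{i + 1:08d}.JPEG' for i in range(50000)]
real = RealLabelsImagenet(filenames, topk=(1, 5))
logits = torch.randn(2, 1000)   # stands in for one batch of model outputs
real.add_result(logits)         # records top-1 / top-5 hits for the first two files
print(real.get_accuracy(k=5))   # ReaL accuracy (%) over the samples seen so far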
/model/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import * 2 | from .adaptive_avgmax_pool import \ 3 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d 4 | from .attention_pool import AttentionPoolLatent 5 | from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding 6 | from .blur_pool import BlurPool2d 7 | from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead 8 | from .cond_conv2d import CondConv2d, get_condconv_initializer 9 | from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \ 10 | set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn 11 | from .conv2d_same import Conv2dSame, conv2d_same 12 | from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct 13 | from .create_act import create_act_layer, get_act_layer, get_act_fn 14 | from .create_attn import get_attn, create_attn 15 | from .create_conv2d import create_conv2d 16 | from .create_norm import get_norm_layer, create_norm_layer 17 | from .create_norm_act import get_norm_act_layer, create_norm_act_layer 18 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path 19 | from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn 20 | from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ 21 | EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a 22 | from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm 23 | from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d 24 | from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to 25 | from .gather_excite import GatherExcite 26 | from .global_context import GlobalContext 27 | from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple 28 | from .inplace_abn import InplaceAbn 29 | from .linear import Linear 30 | from .mixed_conv2d import MixedConv2d 31 | from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp 32 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn 33 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm 34 | from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\ 35 | SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d 36 | from .padding import get_padding, get_same_padding, pad_same 37 | from .patch_dropout import PatchDropout 38 | from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed 39 | from .pool2d_same import AvgPool2dSame, create_pool2d 40 | from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc 41 | from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \ 42 | resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple, resize_rel_pos_bias_table_levit 43 | from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \ 44 | build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \ 45 | FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat 46 | from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite 47 | from .selective_kernel import SelectiveKernel 48 | from .separable_conv import SeparableConv2d, SeparableConvNormAct 49 | from .space_to_depth import SpaceToDepthModule, SpaceToDepth, DepthToSpace 50 | from .split_attn import SplitAttn 51 | from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model 52 | from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame 53 | from .test_time_pool import TestTimePoolHead, apply_test_time_pool 54 | from .trace_utils import _assert, _float_to_int 55 | from .typing import LayerType, PadType 56 | from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ 57 | --------------------------------------------------------------------------------
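The package exports the create_*/get_* factories that the model files rely on; a short composition sketch (not a repo file; behavior is assumed to match timm, from which these layers are vendored):

import torch
from model.layers import create_conv2d, create_act_layer, get_norm_act_layer

conv = create_conv2d(16, 32, kernel_size=3, stride=2, padding='same')  # resolves to a 'same'-padded conv here
act = create_act_layer('swish')                          # activation module resolved by name
norm_act_cls = get_norm_act_layer('batchnorm', 'relu')   # returns a class; instantiate with num_features
x = torch.randn(1, 16, 32, 32)
y = act(conv(x))
print(y.shape)  # torch.Size([1, 32, 16, 16])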
/model/layers/activations_jit.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of jit-scripted activation functions and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if it is not used. 5 | 6 | All jit-scripted activations deliberately lack in-place variants: scripted kernel fusion does not 7 | currently work across in-place op boundaries, so performance would be equal to or worse than the non-scripted 8 | versions if they contained in-place ops.
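A quick illustrative check of the swish formula below (an added example, not from the original source; values shown at PyTorch's default print precision):

    >>> import torch
    >>> x = torch.tensor([-1., 0., 1.])
    >>> x.mul(x.sigmoid())   # same math as swish_jit, without the script wrapper
    tensor([-0.2689,  0.0000,  0.7311])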
9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | 13 | import torch 14 | from torch import nn as nn 15 | from torch.nn import functional as F 16 | 17 | 18 | @torch.jit.script 19 | def swish_jit(x, inplace: bool = False): 20 | """Swish - Described in: https://arxiv.org/abs/1710.05941 21 | """ 22 | return x.mul(x.sigmoid()) 23 | 24 | 25 | @torch.jit.script 26 | def mish_jit(x, _inplace: bool = False): 27 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 28 | """ 29 | return x.mul(F.softplus(x).tanh()) 30 | 31 | 32 | class SwishJit(nn.Module): 33 | def __init__(self, inplace: bool = False): 34 | super(SwishJit, self).__init__() 35 | 36 | def forward(self, x): 37 | return swish_jit(x) 38 | 39 | 40 | class MishJit(nn.Module): 41 | def __init__(self, inplace: bool = False): 42 | super(MishJit, self).__init__() 43 | 44 | def forward(self, x): 45 | return mish_jit(x) 46 | 47 | 48 | @torch.jit.script 49 | def hard_sigmoid_jit(x, inplace: bool = False): 50 | # return F.relu6(x + 3.) / 6. 51 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 52 | 53 | 54 | class HardSigmoidJit(nn.Module): 55 | def __init__(self, inplace: bool = False): 56 | super(HardSigmoidJit, self).__init__() 57 | 58 | def forward(self, x): 59 | return hard_sigmoid_jit(x) 60 | 61 | 62 | @torch.jit.script 63 | def hard_swish_jit(x, inplace: bool = False): 64 | # return x * (F.relu6(x + 3.) / 6) 65 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 66 | 67 | 68 | class HardSwishJit(nn.Module): 69 | def __init__(self, inplace: bool = False): 70 | super(HardSwishJit, self).__init__() 71 | 72 | def forward(self, x): 73 | return hard_swish_jit(x) 74 | 75 | 76 | @torch.jit.script 77 | def hard_mish_jit(x, inplace: bool = False): 78 | """ Hard Mish 79 | Experimental, based on notes by Mish author Diganta Misra at 80 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 81 | """ 82 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 83 | 84 | 85 | class HardMishJit(nn.Module): 86 | def __init__(self, inplace: bool = False): 87 | super(HardMishJit, self).__init__() 88 | 89 | def forward(self, x): 90 | return hard_mish_jit(x) 91 | -------------------------------------------------------------------------------- /model/layers/attention_pool.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .config import use_fused_attn 8 | from .mlp import Mlp 9 | from .weight_init import trunc_normal_tf_ 10 | 11 | 12 | class AttentionPoolLatent(nn.Module): 13 | """ Attention pooling w/ latent query 14 | """ 15 | fused_attn: torch.jit.Final[bool] 16 | 17 | def __init__( 18 | self, 19 | in_features: int, 20 | out_features: int = None, 21 | embed_dim: int = None, 22 | num_heads: int = 8, 23 | mlp_ratio: float = 4.0, 24 | qkv_bias: bool = True, 25 | qk_norm: bool = False, 26 | latent_len: int = 1, 27 | latent_dim: int = None, 28 | pos_embed: str = '', 29 | pool_type: str = 'token', 30 | norm_layer: Optional[nn.Module] = None, 31 | drop: float = 0.0, 32 | ): 33 | super().__init__() 34 | embed_dim = embed_dim or in_features 35 | out_features = out_features or in_features 36 | assert embed_dim % num_heads == 0 37 | self.num_heads = num_heads 38 | self.head_dim = embed_dim // num_heads 39 | self.scale = self.head_dim ** 
-0.5 40 | self.pool = pool_type 41 | self.fused_attn = use_fused_attn() 42 | 43 | if pos_embed == 'abs': 44 | spatial_len = self.feat_size 45 | self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features)) 46 | else: 47 | self.pos_embed = None 48 | 49 | self.latent_dim = latent_dim or embed_dim 50 | self.latent_len = latent_len 51 | self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim)) 52 | 53 | self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) 54 | self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias) 55 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 56 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 57 | self.proj = nn.Linear(embed_dim, embed_dim) 58 | self.proj_drop = nn.Dropout(drop) 59 | 60 | self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity() 61 | self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio)) 62 | 63 | self.init_weights() 64 | 65 | def init_weights(self): 66 | if self.pos_embed is not None: 67 | trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) 68 | trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5) 69 | 70 | def forward(self, x): 71 | B, N, C = x.shape 72 | 73 | if self.pos_embed is not None: 74 | # FIXME interpolate 75 | x = x + self.pos_embed.unsqueeze(0).to(x.dtype) 76 | 77 | q_latent = self.latent.expand(B, -1, -1) 78 | q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2) 79 | 80 | kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 81 | k, v = kv.unbind(0) 82 | 83 | q, k = self.q_norm(q), self.k_norm(k) 84 | 85 | if self.fused_attn: 86 | x = F.scaled_dot_product_attention(q, k, v) 87 | else: 88 | q = q * self.scale 89 | attn = q @ k.transpose(-2, -1) 90 | attn = attn.softmax(dim=-1) 91 | x = attn @ v 92 | x = x.transpose(1, 2).reshape(B, self.latent_len, C) 93 | x = self.proj(x) 94 | x = self.proj_drop(x) 95 | 96 | x = x + self.mlp(self.norm(x)) 97 | 98 | # optional pool if latent seq_len > 1 and pooled output is desired 99 | if self.pool == 'token': 100 | x = x[:, 0] 101 | elif self.pool == 'avg': 102 | x = x.mean(1) 103 | return x -------------------------------------------------------------------------------- /model/layers/blur_pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | BlurPool layer inspired by 3 | - Kornia's Max_BlurPool2d 4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` 5 | 6 | Hacked together by Chris Ha and Ross Wightman 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import numpy as np 13 | from .padding import get_padding 14 | 15 | 16 | class BlurPool2d(nn.Module): 17 | r"""Creates a module that computes blurs and downsample a given feature map. 18 | See :cite:`zhang2019shiftinvar` for more details. 19 | Corresponds to the Downsample class, which does blurring and subsampling 20 | 21 | Args: 22 | channels = Number of input channels 23 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. 24 | stride (int): downsampling filter stride 25 | 26 | Returns: 27 | torch.Tensor: the transformed tensor. 
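Example (an added sketch, not from the original source; assumes `torch` is imported):

    >>> pool = BlurPool2d(channels=32, filt_size=3, stride=2)
    >>> pool(torch.randn(1, 32, 16, 16)).shape
    torch.Size([1, 32, 8, 8])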
28 | """ 29 | def __init__(self, channels, filt_size=3, stride=2) -> None: 30 | super(BlurPool2d, self).__init__() 31 | assert filt_size > 1 32 | self.channels = channels 33 | self.filt_size = filt_size 34 | self.stride = stride 35 | self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 36 | coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) 37 | blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) 38 | self.register_buffer('filt', blur_filter, persistent=False) 39 | 40 | def forward(self, x: torch.Tensor) -> torch.Tensor: 41 | x = F.pad(x, self.padding, 'reflect') 42 | return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels) 43 | -------------------------------------------------------------------------------- /model/layers/cbam.py: -------------------------------------------------------------------------------- 1 | """ CBAM (sort-of) Attention 2 | 3 | Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 4 | 5 | WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on 6 | some tasks, especially fine-grained it seems. I may end up removing this impl. 7 | 8 | Hacked together by / Copyright 2020 Ross Wightman 9 | """ 10 | import torch 11 | from torch import nn as nn 12 | import torch.nn.functional as F 13 | 14 | from .conv_bn_act import ConvNormAct 15 | from .create_act import create_act_layer, get_act_layer 16 | from .helpers import make_divisible 17 | 18 | 19 | class ChannelAttn(nn.Module): 20 | """ Original CBAM channel attention module, currently avg + max pool variant only. 21 | """ 22 | def __init__( 23 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 24 | act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 25 | super(ChannelAttn, self).__init__() 26 | if not rd_channels: 27 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
28 | self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias) 29 | self.act = act_layer(inplace=True) 30 | self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias) 31 | self.gate = create_act_layer(gate_layer) 32 | 33 | def forward(self, x): 34 | x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True)))) 35 | x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True)))) 36 | return x * self.gate(x_avg + x_max) 37 | 38 | 39 | class LightChannelAttn(ChannelAttn): 40 | """An experimental 'lightweight' that sums avg + max pool first 41 | """ 42 | def __init__( 43 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 44 | act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 45 | super(LightChannelAttn, self).__init__( 46 | channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias) 47 | 48 | def forward(self, x): 49 | x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True) 50 | x_attn = self.fc2(self.act(self.fc1(x_pool))) 51 | return x * F.sigmoid(x_attn) 52 | 53 | 54 | class SpatialAttn(nn.Module): 55 | """ Original CBAM spatial attention module 56 | """ 57 | def __init__(self, kernel_size=7, gate_layer='sigmoid'): 58 | super(SpatialAttn, self).__init__() 59 | self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False) 60 | self.gate = create_act_layer(gate_layer) 61 | 62 | def forward(self, x): 63 | x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1) 64 | x_attn = self.conv(x_attn) 65 | return x * self.gate(x_attn) 66 | 67 | 68 | class LightSpatialAttn(nn.Module): 69 | """An experimental 'lightweight' variant that sums avg_pool and max_pool results. 70 | """ 71 | def __init__(self, kernel_size=7, gate_layer='sigmoid'): 72 | super(LightSpatialAttn, self).__init__() 73 | self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False) 74 | self.gate = create_act_layer(gate_layer) 75 | 76 | def forward(self, x): 77 | x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True) 78 | x_attn = self.conv(x_attn) 79 | return x * self.gate(x_attn) 80 | 81 | 82 | class CbamModule(nn.Module): 83 | def __init__( 84 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 85 | spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 86 | super(CbamModule, self).__init__() 87 | self.channel = ChannelAttn( 88 | channels, rd_ratio=rd_ratio, rd_channels=rd_channels, 89 | rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) 90 | self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer) 91 | 92 | def forward(self, x): 93 | x = self.channel(x) 94 | x = self.spatial(x) 95 | return x 96 | 97 | 98 | class LightCbamModule(nn.Module): 99 | def __init__( 100 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 101 | spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 102 | super(LightCbamModule, self).__init__() 103 | self.channel = LightChannelAttn( 104 | channels, rd_ratio=rd_ratio, rd_channels=rd_channels, 105 | rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) 106 | self.spatial = LightSpatialAttn(spatial_kernel_size) 107 | 108 | def forward(self, x): 109 | x = self.channel(x) 110 | x = self.spatial(x) 111 | return x 112 | 113 | -------------------------------------------------------------------------------- /model/layers/config.py: -------------------------------------------------------------------------------- 1 | """ Model / 
Layer Config singleton state 2 | """ 3 | import os 4 | import warnings 5 | from typing import Any, Optional 6 | 7 | import torch 8 | 9 | __all__ = [ 10 | 'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn', 11 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn' 12 | ] 13 | 14 | # Set to True if prefer to have layers with no jit optimization (includes activations) 15 | _NO_JIT = False 16 | 17 | # Set to True if prefer to have activation layers with no jit optimization 18 | # NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying 19 | # the jit flags so far are activations. This will change as more layers are updated and/or added. 20 | _NO_ACTIVATION_JIT = False 21 | 22 | # Set to True if exporting a model with Same padding via ONNX 23 | _EXPORTABLE = False 24 | 25 | # Set to True if wanting to use torch.jit.script on a model 26 | _SCRIPTABLE = False 27 | 28 | 29 | # use torch.scaled_dot_product_attention where possible 30 | _HAS_FUSED_ATTN = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 31 | if 'TIMM_FUSED_ATTN' in os.environ: 32 | _USE_FUSED_ATTN = int(os.environ['TIMM_FUSED_ATTN']) 33 | else: 34 | _USE_FUSED_ATTN = 1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use) 35 | 36 | 37 | def is_no_jit(): 38 | return _NO_JIT 39 | 40 | 41 | class set_no_jit: 42 | def __init__(self, mode: bool) -> None: 43 | global _NO_JIT 44 | self.prev = _NO_JIT 45 | _NO_JIT = mode 46 | 47 | def __enter__(self) -> None: 48 | pass 49 | 50 | def __exit__(self, *args: Any) -> bool: 51 | global _NO_JIT 52 | _NO_JIT = self.prev 53 | return False 54 | 55 | 56 | def is_exportable(): 57 | return _EXPORTABLE 58 | 59 | 60 | class set_exportable: 61 | def __init__(self, mode: bool) -> None: 62 | global _EXPORTABLE 63 | self.prev = _EXPORTABLE 64 | _EXPORTABLE = mode 65 | 66 | def __enter__(self) -> None: 67 | pass 68 | 69 | def __exit__(self, *args: Any) -> bool: 70 | global _EXPORTABLE 71 | _EXPORTABLE = self.prev 72 | return False 73 | 74 | 75 | def is_scriptable(): 76 | return _SCRIPTABLE 77 | 78 | 79 | class set_scriptable: 80 | def __init__(self, mode: bool) -> None: 81 | global _SCRIPTABLE 82 | self.prev = _SCRIPTABLE 83 | _SCRIPTABLE = mode 84 | 85 | def __enter__(self) -> None: 86 | pass 87 | 88 | def __exit__(self, *args: Any) -> bool: 89 | global _SCRIPTABLE 90 | _SCRIPTABLE = self.prev 91 | return False 92 | 93 | 94 | class set_layer_config: 95 | """ Layer config context manager that allows setting all layer config flags at once. 96 | If a flag arg is None, it will not change the current value. 
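Example (an added sketch; `build_model` is a hypothetical stand-in for any model constructor):

    with set_layer_config(scriptable=True, no_jit=True):
        model = build_model()   # layers created inside see the temporary flags
    # on exit, the previous global flag values are restored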
97 | """ 98 | def __init__( 99 | self, 100 | scriptable: Optional[bool] = None, 101 | exportable: Optional[bool] = None, 102 | no_jit: Optional[bool] = None, 103 | no_activation_jit: Optional[bool] = None): 104 | global _SCRIPTABLE 105 | global _EXPORTABLE 106 | global _NO_JIT 107 | global _NO_ACTIVATION_JIT 108 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT 109 | if scriptable is not None: 110 | _SCRIPTABLE = scriptable 111 | if exportable is not None: 112 | _EXPORTABLE = exportable 113 | if no_jit is not None: 114 | _NO_JIT = no_jit 115 | if no_activation_jit is not None: 116 | _NO_ACTIVATION_JIT = no_activation_jit 117 | 118 | def __enter__(self) -> None: 119 | pass 120 | 121 | def __exit__(self, *args: Any) -> bool: 122 | global _SCRIPTABLE 123 | global _EXPORTABLE 124 | global _NO_JIT 125 | global _NO_ACTIVATION_JIT 126 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev 127 | return False 128 | 129 | 130 | def use_fused_attn(experimental: bool = False) -> bool: 131 | # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0 132 | if not _HAS_FUSED_ATTN or _EXPORTABLE: 133 | return False 134 | if experimental: 135 | return _USE_FUSED_ATTN > 1 136 | return _USE_FUSED_ATTN > 0 137 | 138 | 139 | def set_fused_attn(enable: bool = True, experimental: bool = False): 140 | global _USE_FUSED_ATTN 141 | if not _HAS_FUSED_ATTN: 142 | warnings.warn('This version of pytorch does not have F.scaled_dot_product_attention, fused_attn flag ignored.') 143 | return 144 | if experimental and enable: 145 | _USE_FUSED_ATTN = 2 146 | elif enable: 147 | _USE_FUSED_ATTN = 1 148 | else: 149 | _USE_FUSED_ATTN = 0 150 | -------------------------------------------------------------------------------- /model/layers/conv2d_same.py: -------------------------------------------------------------------------------- 1 | """ Conv2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import Tuple, Optional 9 | 10 | from .config import is_exportable, is_scriptable 11 | from .padding import pad_same, pad_same_arg, get_padding_value 12 | 13 | 14 | _USE_EXPORT_CONV = False 15 | 16 | 17 | def conv2d_same( 18 | x, 19 | weight: torch.Tensor, 20 | bias: Optional[torch.Tensor] = None, 21 | stride: Tuple[int, int] = (1, 1), 22 | padding: Tuple[int, int] = (0, 0), 23 | dilation: Tuple[int, int] = (1, 1), 24 | groups: int = 1, 25 | ): 26 | x = pad_same(x, weight.shape[-2:], stride, dilation) 27 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 28 | 29 | 30 | class Conv2dSame(nn.Conv2d): 31 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 32 | """ 33 | 34 | def __init__( 35 | self, 36 | in_channels, 37 | out_channels, 38 | kernel_size, 39 | stride=1, 40 | padding=0, 41 | dilation=1, 42 | groups=1, 43 | bias=True, 44 | ): 45 | super(Conv2dSame, self).__init__( 46 | in_channels, out_channels, kernel_size, 47 | stride, 0, dilation, groups, bias, 48 | ) 49 | 50 | def forward(self, x): 51 | return conv2d_same( 52 | x, self.weight, self.bias, 53 | self.stride, self.padding, self.dilation, self.groups, 54 | ) 55 | 56 | 57 | class Conv2dSameExport(nn.Conv2d): 58 | """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions 59 | 60 | NOTE: This does not currently work with torch.jit.script 61 | """ 62 | 63 | # pylint: disable=unused-argument 64 | def __init__( 65 | self, 66 | in_channels, 67 | 
out_channels, 68 | kernel_size, 69 | stride=1, 70 | padding=0, 71 | dilation=1, 72 | groups=1, 73 | bias=True, 74 | ): 75 | super(Conv2dSameExport, self).__init__( 76 | in_channels, out_channels, kernel_size, 77 | stride, 0, dilation, groups, bias, 78 | ) 79 | self.pad = None 80 | self.pad_input_size = (0, 0) 81 | 82 | def forward(self, x): 83 | input_size = x.size()[-2:] 84 | if self.pad is None: 85 | pad_arg = pad_same_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation) 86 | self.pad = nn.ZeroPad2d(pad_arg) 87 | self.pad_input_size = input_size 88 | 89 | x = self.pad(x) 90 | return F.conv2d( 91 | x, self.weight, self.bias, 92 | self.stride, self.padding, self.dilation, self.groups, 93 | ) 94 | 95 | 96 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 97 | padding = kwargs.pop('padding', '') 98 | kwargs.setdefault('bias', False) 99 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 100 | if is_dynamic: 101 | if _USE_EXPORT_CONV and is_exportable(): 102 | # older PyTorch ver needed this to export same padding reasonably 103 | assert not is_scriptable() # Conv2DSameExport does not work with jit 104 | return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs) 105 | else: 106 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 107 | else: 108 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 109 | 110 | 111 | -------------------------------------------------------------------------------- /model/layers/conv_bn_act.py: -------------------------------------------------------------------------------- 1 | """ Conv2d + BN + Act 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import functools 6 | from torch import nn as nn 7 | 8 | from .create_conv2d import create_conv2d 9 | from .create_norm_act import get_norm_act_layer 10 | 11 | 12 | class ConvNormAct(nn.Module): 13 | def __init__( 14 | self, 15 | in_channels, 16 | out_channels, 17 | kernel_size=1, 18 | stride=1, 19 | padding='', 20 | dilation=1, 21 | groups=1, 22 | bias=False, 23 | apply_act=True, 24 | norm_layer=nn.BatchNorm2d, 25 | norm_kwargs=None, 26 | act_layer=nn.ReLU, 27 | act_kwargs=None, 28 | drop_layer=None, 29 | ): 30 | super(ConvNormAct, self).__init__() 31 | norm_kwargs = norm_kwargs or {} 32 | act_kwargs = act_kwargs or {} 33 | 34 | self.conv = create_conv2d( 35 | in_channels, out_channels, kernel_size, stride=stride, 36 | padding=padding, dilation=dilation, groups=groups, bias=bias) 37 | 38 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 39 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 40 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` 41 | if drop_layer: 42 | norm_kwargs['drop_layer'] = drop_layer 43 | self.bn = norm_act_layer( 44 | out_channels, 45 | apply_act=apply_act, 46 | act_kwargs=act_kwargs, 47 | **norm_kwargs, 48 | ) 49 | 50 | @property 51 | def in_channels(self): 52 | return self.conv.in_channels 53 | 54 | @property 55 | def out_channels(self): 56 | return self.conv.out_channels 57 | 58 | def forward(self, x): 59 | x = self.conv(x) 60 | x = self.bn(x) 61 | return x 62 | 63 | 64 | ConvBnAct = ConvNormAct 65 | 66 | 67 | def create_aa(aa_layer, channels, stride=2, enable=True): 68 | if not aa_layer or not enable: 69 | return nn.Identity() 70 | if isinstance(aa_layer, functools.partial): 71 | if issubclass(aa_layer.func, nn.AvgPool2d): 72 | return aa_layer() 73 | else: 74 | return aa_layer(channels) 75 | elif 
issubclass(aa_layer, nn.AvgPool2d): 76 | return aa_layer(stride) 77 | else: 78 | return aa_layer(channels=channels, stride=stride) 79 | 80 | 81 | class ConvNormActAa(nn.Module): 82 | def __init__( 83 | self, 84 | in_channels, 85 | out_channels, 86 | kernel_size=1, 87 | stride=1, 88 | padding='', 89 | dilation=1, 90 | groups=1, 91 | bias=False, 92 | apply_act=True, 93 | norm_layer=nn.BatchNorm2d, 94 | norm_kwargs=None, 95 | act_layer=nn.ReLU, 96 | act_kwargs=None, 97 | aa_layer=None, 98 | drop_layer=None, 99 | ): 100 | super(ConvNormActAa, self).__init__() 101 | use_aa = aa_layer is not None and stride == 2 102 | norm_kwargs = norm_kwargs or {} 103 | act_kwargs = act_kwargs or {} 104 | 105 | self.conv = create_conv2d( 106 | in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, 107 | padding=padding, dilation=dilation, groups=groups, bias=bias) 108 | 109 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 110 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 111 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` 112 | if drop_layer: 113 | norm_kwargs['drop_layer'] = drop_layer 114 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, act_kwargs=act_kwargs, **norm_kwargs) 115 | self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa) 116 | 117 | @property 118 | def in_channels(self): 119 | return self.conv.in_channels 120 | 121 | @property 122 | def out_channels(self): 123 | return self.conv.out_channels 124 | 125 | def forward(self, x): 126 | x = self.conv(x) 127 | x = self.bn(x) 128 | x = self.aa(x) 129 | return x 130 | -------------------------------------------------------------------------------- /model/layers/create_attn.py: -------------------------------------------------------------------------------- 1 | """ Attention Factory 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | import torch 6 | from functools import partial 7 | 8 | from .bottleneck_attn import BottleneckAttn 9 | from .cbam import CbamModule, LightCbamModule 10 | from .eca import EcaModule, CecaModule 11 | from .gather_excite import GatherExcite 12 | from .global_context import GlobalContext 13 | from .halo_attn import HaloAttn 14 | from .lambda_layer import LambdaLayer 15 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn 16 | from .selective_kernel import SelectiveKernel 17 | from .split_attn import SplitAttn 18 | from .squeeze_excite import SEModule, EffectiveSEModule 19 | 20 | 21 | def get_attn(attn_type): 22 | if isinstance(attn_type, torch.nn.Module): 23 | return attn_type 24 | module_cls = None 25 | if attn_type: 26 | if isinstance(attn_type, str): 27 | attn_type = attn_type.lower() 28 | # Lightweight attention modules (channel and/or coarse spatial). 29 | # Typically added to existing network architecture blocks in addition to existing convolutions. 
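# A minimal usage sketch (added example; channel count and shapes are illustrative):
#   attn = create_attn('se', 64)           # SEModule acting on 64 channels
#   y = attn(torch.randn(2, 64, 8, 8))     # -> (2, 64, 8, 8), channel-reweighted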
30 | if attn_type == 'se': 31 | module_cls = SEModule 32 | elif attn_type == 'ese': 33 | module_cls = EffectiveSEModule 34 | elif attn_type == 'eca': 35 | module_cls = EcaModule 36 | elif attn_type == 'ecam': 37 | module_cls = partial(EcaModule, use_mlp=True) 38 | elif attn_type == 'ceca': 39 | module_cls = CecaModule 40 | elif attn_type == 'ge': 41 | module_cls = GatherExcite 42 | elif attn_type == 'gc': 43 | module_cls = GlobalContext 44 | elif attn_type == 'gca': 45 | module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False) 46 | elif attn_type == 'cbam': 47 | module_cls = CbamModule 48 | elif attn_type == 'lcbam': 49 | module_cls = LightCbamModule 50 | 51 | # Attention / attention-like modules w/ significant params 52 | # Typically replace some of the existing workhorse convs in a network architecture. 53 | # All of these accept a stride argument and can spatially downsample the input. 54 | elif attn_type == 'sk': 55 | module_cls = SelectiveKernel 56 | elif attn_type == 'splat': 57 | module_cls = SplitAttn 58 | 59 | # Self-attention / attention-like modules w/ significant compute and/or params 60 | # Typically replace some of the existing workhorse convs in a network architecture. 61 | # All of these accept a stride argument and can spatially downsample the input. 62 | elif attn_type == 'lambda': 63 | return LambdaLayer 64 | elif attn_type == 'bottleneck': 65 | return BottleneckAttn 66 | elif attn_type == 'halo': 67 | return HaloAttn 68 | elif attn_type == 'nl': 69 | module_cls = NonLocalAttn 70 | elif attn_type == 'bat': 71 | module_cls = BatNonLocalAttn 72 | 73 | # Woops! 74 | else: 75 | assert False, "Invalid attn module (%s)" % attn_type 76 | elif isinstance(attn_type, bool): 77 | if attn_type: 78 | module_cls = SEModule 79 | else: 80 | module_cls = attn_type 81 | return module_cls 82 | 83 | 84 | def create_attn(attn_type, channels, **kwargs): 85 | module_cls = get_attn(attn_type) 86 | if module_cls is not None: 87 | # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels 88 | return module_cls(channels, **kwargs) 89 | return None 90 | -------------------------------------------------------------------------------- /model/layers/create_conv2d.py: -------------------------------------------------------------------------------- 1 | """ Create Conv2d Factory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | from .mixed_conv2d import MixedConv2d 7 | from .cond_conv2d import CondConv2d 8 | from .conv2d_same import create_conv2d_pad 9 | 10 | 11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): 12 | """ Select a 2d convolution implementation based on arguments 13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. 14 | 15 | Used extensively by EfficientNet, MobileNetv3 and related networks. 16 | """ 17 | if isinstance(kernel_size, list): 18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently 19 | if 'groups' in kwargs: 20 | groups = kwargs.pop('groups') 21 | if groups == in_channels: 22 | kwargs['depthwise'] = True 23 | else: 24 | assert groups == 1 25 | # We're going to use only lists for defining the MixedConv2d kernel groups, 26 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 
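# Added example (hypothetical values): create_conv2d(32, 64, [3, 5, 7]) returns a
# MixedConv2d that splits the channels into one group per kernel size (3x3, 5x5, 7x7).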
27 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) 28 | else: 29 | depthwise = kwargs.pop('depthwise', False) 30 | # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 31 | groups = in_channels if depthwise else kwargs.pop('groups', 1) 32 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 33 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 34 | else: 35 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 36 | return m 37 | -------------------------------------------------------------------------------- /model/layers/create_norm.py: -------------------------------------------------------------------------------- 1 | """ Norm Layer Factory 2 | 3 | Create norm modules by string (to mirror create_act and create_norm_act fns) 4 | 5 | Copyright 2022 Ross Wightman 6 | """ 7 | import functools 8 | import types 9 | from typing import Type 10 | 11 | import torch.nn as nn 12 | 13 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm 14 | from torchvision.ops.misc import FrozenBatchNorm2d 15 | 16 | _NORM_MAP = dict( 17 | batchnorm=nn.BatchNorm2d, 18 | batchnorm2d=nn.BatchNorm2d, 19 | batchnorm1d=nn.BatchNorm1d, 20 | groupnorm=GroupNorm, 21 | groupnorm1=GroupNorm1, 22 | layernorm=LayerNorm, 23 | layernorm2d=LayerNorm2d, 24 | rmsnorm=RmsNorm, 25 | frozenbatchnorm2d=FrozenBatchNorm2d, 26 | ) 27 | _NORM_TYPES = {m for n, m in _NORM_MAP.items()} 28 | 29 | 30 | def create_norm_layer(layer_name, num_features, **kwargs): 31 | layer = get_norm_layer(layer_name) 32 | layer_instance = layer(num_features, **kwargs) 33 | return layer_instance 34 | 35 | 36 | def get_norm_layer(norm_layer): 37 | if norm_layer is None: 38 | return None 39 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 40 | norm_kwargs = {} 41 | 42 | # unbind partial fn, so args can be rebound later 43 | if isinstance(norm_layer, functools.partial): 44 | norm_kwargs.update(norm_layer.keywords) 45 | norm_layer = norm_layer.func 46 | 47 | if isinstance(norm_layer, str): 48 | if not norm_layer: 49 | return None 50 | layer_name = norm_layer.replace('_', '') 51 | norm_layer = _NORM_MAP[layer_name] 52 | else: 53 | norm_layer = norm_layer 54 | 55 | if norm_kwargs: 56 | norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args 57 | return norm_layer 58 | -------------------------------------------------------------------------------- /model/layers/create_norm_act.py: -------------------------------------------------------------------------------- 1 | """ NormAct (Normalization + Activation Layer) Factory 2 | 3 | Create norm + act combo modules that attempt to be backwards compatible with separate norm + act 4 | instances in models. Where these are used it will be possible to swap separate BN + act layers with 5 | combined modules like IABN or EvoNorms.
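For example (an added sketch, not from the original source):

    bn_act = create_norm_act_layer('batchnorm', 64, act_layer='relu')
    # behaves like BatchNorm2d(64) followed by an inline ReLU, as a single module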
6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | import types 10 | import functools 11 | 12 | from .evo_norm import * 13 | from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d 14 | from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d 15 | from .inplace_abn import InplaceAbn 16 | 17 | _NORM_ACT_MAP = dict( 18 | batchnorm=BatchNormAct2d, 19 | batchnorm2d=BatchNormAct2d, 20 | groupnorm=GroupNormAct, 21 | groupnorm1=functools.partial(GroupNormAct, num_groups=1), 22 | layernorm=LayerNormAct, 23 | layernorm2d=LayerNormAct2d, 24 | evonormb0=EvoNorm2dB0, 25 | evonormb1=EvoNorm2dB1, 26 | evonormb2=EvoNorm2dB2, 27 | evonorms0=EvoNorm2dS0, 28 | evonorms0a=EvoNorm2dS0a, 29 | evonorms1=EvoNorm2dS1, 30 | evonorms1a=EvoNorm2dS1a, 31 | evonorms2=EvoNorm2dS2, 32 | evonorms2a=EvoNorm2dS2a, 33 | frn=FilterResponseNormAct2d, 34 | frntlu=FilterResponseNormTlu2d, 35 | inplaceabn=InplaceAbn, 36 | iabn=InplaceAbn, 37 | ) 38 | _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()} 39 | # has act_layer arg to define act type 40 | _NORM_ACT_REQUIRES_ARG = { 41 | BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn} 42 | 43 | 44 | def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs): 45 | layer = get_norm_act_layer(layer_name, act_layer=act_layer) 46 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs) 47 | if jit: 48 | layer_instance = torch.jit.script(layer_instance) 49 | return layer_instance 50 | 51 | 52 | def get_norm_act_layer(norm_layer, act_layer=None): 53 | if norm_layer is None: 54 | return None 55 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 56 | assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) 57 | norm_act_kwargs = {} 58 | 59 | # unbind partial fn, so args can be rebound later 60 | if isinstance(norm_layer, functools.partial): 61 | norm_act_kwargs.update(norm_layer.keywords) 62 | norm_layer = norm_layer.func 63 | 64 | if isinstance(norm_layer, str): 65 | if not norm_layer: 66 | return None 67 | layer_name = norm_layer.replace('_', '').lower().split('-')[0] 68 | norm_act_layer = _NORM_ACT_MAP[layer_name] 69 | elif norm_layer in _NORM_ACT_TYPES: 70 | norm_act_layer = norm_layer 71 | elif isinstance(norm_layer, types.FunctionType): 72 | # if function type, must be a lambda/fn that creates a norm_act layer 73 | norm_act_layer = norm_layer 74 | else: 75 | type_name = norm_layer.__name__.lower() 76 | if type_name.startswith('batchnorm'): 77 | norm_act_layer = BatchNormAct2d 78 | elif type_name.startswith('groupnorm'): 79 | norm_act_layer = GroupNormAct 80 | elif type_name.startswith('groupnorm1'): 81 | norm_act_layer = functools.partial(GroupNormAct, num_groups=1) 82 | elif type_name.startswith('layernorm2d'): 83 | norm_act_layer = LayerNormAct2d 84 | elif type_name.startswith('layernorm'): 85 | norm_act_layer = LayerNormAct 86 | else: 87 | assert False, f"No equivalent norm_act layer for {type_name}" 88 | 89 | if norm_act_layer in _NORM_ACT_REQUIRES_ARG: 90 | # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. 
91 | # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types 92 | norm_act_kwargs.setdefault('act_layer', act_layer) 93 | if norm_act_kwargs: 94 | norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args 95 | return norm_act_layer 96 | -------------------------------------------------------------------------------- /model/layers/fast_norm.py: -------------------------------------------------------------------------------- 1 | """ 'Fast' Normalization Functions 2 | 3 | For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32. 4 | 5 | Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast) 6 | 7 | Hacked together by / Copyright 2022 Ross Wightman 8 | """ 9 | from typing import List, Optional 10 | 11 | import torch 12 | from torch.nn import functional as F 13 | 14 | try: 15 | from apex.normalization.fused_layer_norm import fused_layer_norm_affine 16 | has_apex = True 17 | except ImportError: 18 | has_apex = False 19 | 20 | try: 21 | from apex.normalization.fused_layer_norm import fused_rms_norm_affine, fused_rms_norm 22 | has_apex_rmsnorm = True 23 | except ImportError: 24 | has_apex_rmsnorm = False 25 | 26 | 27 | # fast (ie lower precision LN) can be disabled with this flag if issues crop up 28 | _USE_FAST_NORM = False # defaulting to False for now 29 | 30 | 31 | def is_fast_norm(): 32 | return _USE_FAST_NORM 33 | 34 | 35 | def set_fast_norm(enable=True): 36 | global _USE_FAST_NORM 37 | _USE_FAST_NORM = enable 38 | 39 | 40 | def fast_group_norm( 41 | x: torch.Tensor, 42 | num_groups: int, 43 | weight: Optional[torch.Tensor] = None, 44 | bias: Optional[torch.Tensor] = None, 45 | eps: float = 1e-5 46 | ) -> torch.Tensor: 47 | if torch.jit.is_scripting(): 48 | # currently cannot use is_autocast_enabled within torchscript 49 | return F.group_norm(x, num_groups, weight, bias, eps) 50 | 51 | if torch.is_autocast_enabled(): 52 | # normally native AMP casts GN inputs to float32 53 | # here we use the low precision autocast dtype 54 | # FIXME what to do re CPU autocast? 55 | dt = torch.get_autocast_gpu_dtype() 56 | x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None 57 | 58 | with torch.cuda.amp.autocast(enabled=False): 59 | return F.group_norm(x, num_groups, weight, bias, eps) 60 | 61 | 62 | def fast_layer_norm( 63 | x: torch.Tensor, 64 | normalized_shape: List[int], 65 | weight: Optional[torch.Tensor] = None, 66 | bias: Optional[torch.Tensor] = None, 67 | eps: float = 1e-5 68 | ) -> torch.Tensor: 69 | if torch.jit.is_scripting(): 70 | # currently cannot use is_autocast_enabled within torchscript 71 | return F.layer_norm(x, normalized_shape, weight, bias, eps) 72 | 73 | if has_apex: 74 | return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps) 75 | 76 | if torch.is_autocast_enabled(): 77 | # normally native AMP casts LN inputs to float32 78 | # apex LN does not, this is behaving like Apex 79 | dt = torch.get_autocast_gpu_dtype() 80 | # FIXME what to do re CPU autocast? 
81 | x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None 82 | 83 | with torch.cuda.amp.autocast(enabled=False): 84 | return F.layer_norm(x, normalized_shape, weight, bias, eps) 85 | 86 | 87 | def rms_norm( 88 | x: torch.Tensor, 89 | normalized_shape: List[int], 90 | weight: Optional[torch.Tensor] = None, 91 | eps: float = 1e-5, 92 | ): 93 | norm_ndim = len(normalized_shape) 94 | if torch.jit.is_scripting(): 95 | # ndim = len(x.shape) 96 | # dims = list(range(ndim - norm_ndim, ndim)) # this doesn't work on pytorch <= 1.13.x 97 | # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around 98 | assert norm_ndim == 1 99 | v = torch.var(x, dim=-1).unsqueeze(-1) # ts crashes with -ve dim + keepdim=True 100 | else: 101 | dims = tuple(range(-1, -norm_ndim - 1, -1)) 102 | v = torch.var(x, dim=dims, keepdim=True) 103 | x = x * torch.rsqrt(v + eps) 104 | if weight is not None: 105 | x = x * weight 106 | return x 107 | 108 | 109 | def fast_rms_norm( 110 | x: torch.Tensor, 111 | normalized_shape: List[int], 112 | weight: Optional[torch.Tensor] = None, 113 | eps: float = 1e-5, 114 | ) -> torch.Tensor: 115 | if torch.jit.is_scripting(): 116 | # this must be by itself, cannot merge with has_apex_rmsnorm 117 | return rms_norm(x, normalized_shape, weight, eps) 118 | 119 | if has_apex_rmsnorm: 120 | if weight is None: 121 | return fused_rms_norm(x, normalized_shape, eps) 122 | else: 123 | return fused_rms_norm_affine(x, weight, normalized_shape, eps) 124 | 125 | # fallback 126 | return rms_norm(x, normalized_shape, weight, eps) 127 | -------------------------------------------------------------------------------- /model/layers/filter_response_norm.py: -------------------------------------------------------------------------------- 1 | """ Filter Response Norm in PyTorch 2 | 3 | Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737 4 | 5 | Hacked together by / Copyright 2021 Ross Wightman 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .create_act import create_act_layer 11 | from .trace_utils import _assert 12 | 13 | 14 | def inv_instance_rms(x, eps: float = 1e-5): 15 | rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype) 16 | return rms.expand(x.shape) 17 | 18 | 19 | class FilterResponseNormTlu2d(nn.Module): 20 | def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_): 21 | super(FilterResponseNormTlu2d, self).__init__() 22 | self.apply_act = apply_act # apply activation (non-linearity) 23 | self.rms = rms 24 | self.eps = eps 25 | self.weight = nn.Parameter(torch.ones(num_features)) 26 | self.bias = nn.Parameter(torch.zeros(num_features)) 27 | self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | nn.init.ones_(self.weight) 32 | nn.init.zeros_(self.bias) 33 | if self.tau is not None: 34 | nn.init.zeros_(self.tau) 35 | 36 | def forward(self, x): 37 | _assert(x.dim() == 4, 'expected 4D input') 38 | x_dtype = x.dtype 39 | v_shape = (1, -1, 1, 1) 40 | x = x * inv_instance_rms(x, self.eps) 41 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 42 | return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x 43 | 44 | 45 | class FilterResponseNormAct2d(nn.Module): 46 | def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_): 47 | 
super(FilterResponseNormAct2d, self).__init__() 48 | if act_layer is not None and apply_act: 49 | self.act = create_act_layer(act_layer, inplace=inplace) 50 | else: 51 | self.act = nn.Identity() 52 | self.rms = rms 53 | self.eps = eps 54 | self.weight = nn.Parameter(torch.ones(num_features)) 55 | self.bias = nn.Parameter(torch.zeros(num_features)) 56 | self.reset_parameters() 57 | 58 | def reset_parameters(self): 59 | nn.init.ones_(self.weight) 60 | nn.init.zeros_(self.bias) 61 | 62 | def forward(self, x): 63 | _assert(x.dim() == 4, 'expected 4D input') 64 | x_dtype = x.dtype 65 | v_shape = (1, -1, 1, 1) 66 | x = x * inv_instance_rms(x, self.eps) 67 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 68 | return self.act(x) 69 | -------------------------------------------------------------------------------- /model/layers/format.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Union 3 | 4 | import torch 5 | 6 | 7 | class Format(str, Enum): 8 | NCHW = 'NCHW' 9 | NHWC = 'NHWC' 10 | NCL = 'NCL' 11 | NLC = 'NLC' 12 | 13 | 14 | FormatT = Union[str, Format] 15 | 16 | 17 | def get_spatial_dim(fmt: FormatT): 18 | fmt = Format(fmt) 19 | if fmt is Format.NLC: 20 | dim = (1,) 21 | elif fmt is Format.NCL: 22 | dim = (2,) 23 | elif fmt is Format.NHWC: 24 | dim = (1, 2) 25 | else: 26 | dim = (2, 3) 27 | return dim 28 | 29 | 30 | def get_channel_dim(fmt: FormatT): 31 | fmt = Format(fmt) 32 | if fmt is Format.NHWC: 33 | dim = 3 34 | elif fmt is Format.NLC: 35 | dim = 2 36 | else: 37 | dim = 1 38 | return dim 39 | 40 | 41 | def nchw_to(x: torch.Tensor, fmt: Format): 42 | if fmt == Format.NHWC: 43 | x = x.permute(0, 2, 3, 1) 44 | elif fmt == Format.NLC: 45 | x = x.flatten(2).transpose(1, 2) 46 | elif fmt == Format.NCL: 47 | x = x.flatten(2) 48 | return x 49 | 50 | 51 | def nhwc_to(x: torch.Tensor, fmt: Format): 52 | if fmt == Format.NCHW: 53 | x = x.permute(0, 3, 1, 2) 54 | elif fmt == Format.NLC: 55 | x = x.flatten(1, 2) 56 | elif fmt == Format.NCL: 57 | x = x.flatten(1, 2).transpose(1, 2) 58 | return x 59 | -------------------------------------------------------------------------------- /model/layers/gather_excite.py: -------------------------------------------------------------------------------- 1 | """ Gather-Excite Attention Block 2 | 3 | Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348 4 | 5 | Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet 6 | 7 | I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another 8 | impl that covers all of the cases. 
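An added usage sketch (channel count and shapes are illustrative):

    ge = GatherExcite(64, extent=2)        # parameter-free gather via strided avg pool
    y = ge(torch.randn(2, 64, 16, 16))     # -> (2, 64, 16, 16), spatially re-weighted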
9 | 10 | NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation 11 | 12 | Hacked together by / Copyright 2021 Ross Wightman 13 | """ 14 | import math 15 | 16 | from torch import nn as nn 17 | import torch.nn.functional as F 18 | 19 | from .create_act import create_act_layer, get_act_layer 20 | from .create_conv2d import create_conv2d 21 | from .helpers import make_divisible 22 | from .mlp import ConvMlp 23 | 24 | 25 | class GatherExcite(nn.Module): 26 | """ Gather-Excite Attention Module 27 | """ 28 | def __init__( 29 | self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True, 30 | rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False, 31 | act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'): 32 | super(GatherExcite, self).__init__() 33 | self.add_maxpool = add_maxpool 34 | act_layer = get_act_layer(act_layer) 35 | self.extent = extent 36 | if extra_params: 37 | self.gather = nn.Sequential() 38 | if extent == 0: 39 | assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params' 40 | self.gather.add_module( 41 | 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True)) 42 | if norm_layer: 43 | self.gather.add_module(f'norm1', nn.BatchNorm2d(channels)) 44 | else: 45 | assert extent % 2 == 0 46 | num_conv = int(math.log2(extent)) 47 | for i in range(num_conv): 48 | self.gather.add_module( 49 | f'conv{i + 1}', 50 | create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True)) 51 | if norm_layer: 52 | self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels)) 53 | if i != num_conv - 1: 54 | self.gather.add_module(f'act{i + 1}', act_layer(inplace=True)) 55 | else: 56 | self.gather = None 57 | if self.extent == 0: 58 | self.gk = 0 59 | self.gs = 0 60 | else: 61 | assert extent % 2 == 0 62 | self.gk = self.extent * 2 - 1 63 | self.gs = self.extent 64 | 65 | if not rd_channels: 66 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
67 | self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity() 68 | self.gate = create_act_layer(gate_layer) 69 | 70 | def forward(self, x): 71 | size = x.shape[-2:] 72 | if self.gather is not None: 73 | x_ge = self.gather(x) 74 | else: 75 | if self.extent == 0: 76 | # global extent 77 | x_ge = x.mean(dim=(2, 3), keepdims=True) 78 | if self.add_maxpool: 79 | # experimental codepath, may remove or change 80 | x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True) 81 | else: 82 | x_ge = F.avg_pool2d( 83 | x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False) 84 | if self.add_maxpool: 85 | # experimental codepath, may remove or change 86 | x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2) 87 | x_ge = self.mlp(x_ge) 88 | if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1: 89 | x_ge = F.interpolate(x_ge, size=size) 90 | return x * self.gate(x_ge) 91 | -------------------------------------------------------------------------------- /model/layers/global_context.py: -------------------------------------------------------------------------------- 1 | """ Global Context Attention Block 2 | 3 | Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` 4 | - https://arxiv.org/abs/1904.11492 5 | 6 | Official code consulted as reference: https://github.com/xvjiarui/GCNet 7 | 8 | Hacked together by / Copyright 2021 Ross Wightman 9 | """ 10 | from torch import nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .create_act import create_act_layer, get_act_layer 14 | from .helpers import make_divisible 15 | from .mlp import ConvMlp 16 | from .norm import LayerNorm2d 17 | 18 | 19 | class GlobalContext(nn.Module): 20 | 21 | def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False, 22 | rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): 23 | super(GlobalContext, self).__init__() 24 | act_layer = get_act_layer(act_layer) 25 | 26 | self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None 27 | 28 | if rd_channels is None: 29 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
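# added note (illustrative numbers, not from the original source): channels=256 with the
# default rd_ratio=1/8 yields rd_channels=32 for the context MLP(s) built below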
30 | if fuse_add: 31 | self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 32 | else: 33 | self.mlp_add = None 34 | if fuse_scale: 35 | self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 36 | else: 37 | self.mlp_scale = None 38 | 39 | self.gate = create_act_layer(gate_layer) 40 | self.init_last_zero = init_last_zero 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | if self.conv_attn is not None: 45 | nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') 46 | if self.mlp_add is not None: 47 | nn.init.zeros_(self.mlp_add.fc2.weight) 48 | 49 | def forward(self, x): 50 | B, C, H, W = x.shape 51 | 52 | if self.conv_attn is not None: 53 | attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) 54 | attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) 55 | context = x.reshape(B, C, H * W).unsqueeze(1) @ attn 56 | context = context.view(B, C, 1, 1) 57 | else: 58 | context = x.mean(dim=(2, 3), keepdim=True) 59 | 60 | if self.mlp_scale is not None: 61 | mlp_x = self.mlp_scale(context) 62 | x = x * self.gate(mlp_x) 63 | if self.mlp_add is not None: 64 | mlp_x = self.mlp_add(context) 65 | x = x + mlp_x 66 | 67 | return x 68 | -------------------------------------------------------------------------------- /model/layers/grn.py: -------------------------------------------------------------------------------- 1 | """ Global Response Normalization Module 2 | 3 | Based on the GRN layer presented in 4 | `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808 5 | 6 | This implementation 7 | * works for both NCHW and NHWC tensor layouts 8 | * uses affine param names matching existing torch norm layers 9 | * slightly improves eager mode performance via fused addcmul 10 | 11 | Hacked together by / Copyright 2023 Ross Wightman 12 | """ 13 | 14 | import torch 15 | from torch import nn as nn 16 | 17 | 18 | class GlobalResponseNorm(nn.Module): 19 | """ Global Response Normalization layer 20 | """ 21 | def __init__(self, dim, eps=1e-6, channels_last=True): 22 | super().__init__() 23 | self.eps = eps 24 | if channels_last: 25 | self.spatial_dim = (1, 2) 26 | self.channel_dim = -1 27 | self.wb_shape = (1, 1, 1, -1) 28 | else: 29 | self.spatial_dim = (2, 3) 30 | self.channel_dim = 1 31 | self.wb_shape = (1, -1, 1, 1) 32 | 33 | self.weight = nn.Parameter(torch.zeros(dim)) 34 | self.bias = nn.Parameter(torch.zeros(dim)) 35 | 36 | def forward(self, x): 37 | x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True) 38 | x_n = x_g / (x_g.mean(dim=self.channel_dim, keepdim=True) + self.eps) 39 | return x + torch.addcmul(self.bias.view(self.wb_shape), self.weight.view(self.wb_shape), x * x_n) 40 | -------------------------------------------------------------------------------- /model/layers/helpers.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from itertools import repeat 6 | import collections.abc 7 | 8 | 9 | # From PyTorch internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 13 | return tuple(x) 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | to_1tuple = _ntuple(1) 19 | to_2tuple = _ntuple(2) 20 | to_3tuple = _ntuple(3) 21 | to_4tuple = _ntuple(4) 22 | to_ntuple = _ntuple 23 | 24 | 25 | def 
make_divisible(v, divisor=8, min_value=None, round_limit=.9): 26 | min_value = min_value or divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 28 | # Make sure that round down does not go down by more than 10%. 29 | if new_v < round_limit * v: 30 | new_v += divisor 31 | return new_v 32 | 33 | 34 | def extend_tuple(x, n): 35 | # pads a tuple to specified n by padding with last value 36 | if not isinstance(x, (tuple, list)): 37 | x = (x,) 38 | else: 39 | x = tuple(x) 40 | pad_n = n - len(x) 41 | if pad_n <= 0: 42 | return x[:n] 43 | return x + (x[-1],) * pad_n 44 | -------------------------------------------------------------------------------- /model/layers/inplace_abn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | try: 5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync 6 | has_iabn = True 7 | except ImportError: 8 | has_iabn = False 9 | 10 | def inplace_abn(x, weight, bias, running_mean, running_var, 11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): 12 | raise ImportError( 13 | "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'") 14 | 15 | def inplace_abn_sync(**kwargs): 16 | inplace_abn(**kwargs) 17 | 18 | 19 | class InplaceAbn(nn.Module): 20 | """Activated Batch Normalization 21 | 22 | This gathers a BatchNorm and an activation function in a single module 23 | 24 | Parameters 25 | ---------- 26 | num_features : int 27 | Number of feature channels in the input and output. 28 | eps : float 29 | Small constant to prevent numerical issues. 30 | momentum : float 31 | Momentum factor applied to compute running statistics. 32 | affine : bool 33 | If `True` apply learned scale and shift transformation after normalization. 34 | act_layer : str or nn.Module type 35 | Name or type of the activation functions, one of: `leaky_relu`, `elu` 36 | act_param : float 37 | Negative slope for the `leaky_relu` activation. 
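Example (an added sketch; running it requires the optional `inplace_abn` package):

        iabn = InplaceAbn(64, act_layer='leaky_relu', act_param=0.01)
        y = iabn(torch.randn(2, 64, 8, 8))   # fused BN + leaky_relu, same output shape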
38 | """ 39 | 40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, 41 | act_layer="leaky_relu", act_param=0.01, drop_layer=None): 42 | super(InplaceAbn, self).__init__() 43 | self.num_features = num_features 44 | self.affine = affine 45 | self.eps = eps 46 | self.momentum = momentum 47 | if apply_act: 48 | if isinstance(act_layer, str): 49 | assert act_layer in ('leaky_relu', 'elu', 'identity', '') 50 | self.act_name = act_layer if act_layer else 'identity' 51 | else: 52 | # convert act layer passed as type to string 53 | if act_layer == nn.ELU: 54 | self.act_name = 'elu' 55 | elif act_layer == nn.LeakyReLU: 56 | self.act_name = 'leaky_relu' 57 | elif act_layer is None or act_layer == nn.Identity: 58 | self.act_name = 'identity' 59 | else: 60 | assert False, f'Invalid act layer {act_layer.__name__} for IABN' 61 | else: 62 | self.act_name = 'identity' 63 | self.act_param = act_param 64 | if self.affine: 65 | self.weight = nn.Parameter(torch.ones(num_features)) 66 | self.bias = nn.Parameter(torch.zeros(num_features)) 67 | else: 68 | self.register_parameter('weight', None) 69 | self.register_parameter('bias', None) 70 | self.register_buffer('running_mean', torch.zeros(num_features)) 71 | self.register_buffer('running_var', torch.ones(num_features)) 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | nn.init.constant_(self.running_mean, 0) 76 | nn.init.constant_(self.running_var, 1) 77 | if self.affine: 78 | nn.init.constant_(self.weight, 1) 79 | nn.init.constant_(self.bias, 0) 80 | 81 | def forward(self, x): 82 | output = inplace_abn( 83 | x, self.weight, self.bias, self.running_mean, self.running_var, 84 | self.training, self.momentum, self.eps, self.act_name, self.act_param) 85 | if isinstance(output, tuple): 86 | output = output[0] 87 | return output 88 | -------------------------------------------------------------------------------- /model/layers/interpolate.py: -------------------------------------------------------------------------------- 1 | """ Interpolation helpers for timm layers 2 | 3 | RegularGridInterpolator from https://github.com/sbarratt/torch_interpolations 4 | Copyright Shane Barratt, Apache 2.0 license 5 | """ 6 | import torch 7 | from itertools import product 8 | 9 | 10 | class RegularGridInterpolator: 11 | """ Interpolate data defined on a rectilinear grid with even or uneven spacing. 12 | Produces similar results to scipy RegularGridInterpolator or interp2d 13 | in 'linear' mode. 
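    Example (editor's sketch, 1-d case)::

        >>> xs = torch.linspace(0, 1, steps=5)
        >>> rgi = RegularGridInterpolator((xs,), xs ** 2)
        >>> rgi([torch.tensor([0.3])])  # tensor([0.1000]); linear interp of x**2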
14 | 15 | Taken from https://github.com/sbarratt/torch_interpolations 16 | """ 17 | 18 | def __init__(self, points, values): 19 | self.points = points 20 | self.values = values 21 | 22 | assert isinstance(self.points, tuple) or isinstance(self.points, list) 23 | assert isinstance(self.values, torch.Tensor) 24 | 25 | self.ms = list(self.values.shape) 26 | self.n = len(self.points) 27 | 28 | assert len(self.ms) == self.n 29 | 30 | for i, p in enumerate(self.points): 31 | assert isinstance(p, torch.Tensor) 32 | assert p.shape[0] == self.values.shape[i] 33 | 34 | def __call__(self, points_to_interp): 35 | assert self.points is not None 36 | assert self.values is not None 37 | 38 | assert len(points_to_interp) == len(self.points) 39 | K = points_to_interp[0].shape[0] 40 | for x in points_to_interp: 41 | assert x.shape[0] == K 42 | 43 | idxs = [] 44 | dists = [] 45 | overalls = [] 46 | for p, x in zip(self.points, points_to_interp): 47 | idx_right = torch.bucketize(x, p) 48 | idx_right[idx_right >= p.shape[0]] = p.shape[0] - 1 49 | idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1) 50 | dist_left = x - p[idx_left] 51 | dist_right = p[idx_right] - x 52 | dist_left[dist_left < 0] = 0. 53 | dist_right[dist_right < 0] = 0. 54 | both_zero = (dist_left == 0) & (dist_right == 0) 55 | dist_left[both_zero] = dist_right[both_zero] = 1. 56 | 57 | idxs.append((idx_left, idx_right)) 58 | dists.append((dist_left, dist_right)) 59 | overalls.append(dist_left + dist_right) 60 | 61 | numerator = 0. 62 | for indexer in product([0, 1], repeat=self.n): 63 | as_s = [idx[onoff] for onoff, idx in zip(indexer, idxs)] 64 | bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)] 65 | numerator += self.values[as_s] * \ 66 | torch.prod(torch.stack(bs_s), dim=0) 67 | denominator = torch.prod(torch.stack(overalls), dim=0) 68 | return numerator / denominator 69 | -------------------------------------------------------------------------------- /model/layers/linear.py: -------------------------------------------------------------------------------- 1 | """ Linear layer (alternate definition) 2 | """ 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn as nn 6 | 7 | 8 | class Linear(nn.Linear): 9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` 10 | 11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting 12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. 13 | """ 14 | def forward(self, input: torch.Tensor) -> torch.Tensor: 15 | if torch.jit.is_scripting(): 16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None 17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) 18 | else: 19 | return F.linear(input, self.weight, self.bias) 20 | -------------------------------------------------------------------------------- /model/layers/median_pool.py: -------------------------------------------------------------------------------- 1 | """ Median Pool 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .helpers import to_2tuple, to_4tuple 7 | 8 | 9 | class MedianPool2d(nn.Module): 10 | """ Median pool (usable as median filter when stride=1) module. 
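    Editor's note: each output element is the median of a k x k window; with
    stride=1 and same=True this acts as a classic median filter, e.g.
    MedianPool2d(3, stride=1, same=True) maps (1, 3, 8, 8) -> (1, 3, 8, 8).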
11 | 12 | Args: 13 | kernel_size: size of pooling kernel, int or 2-tuple 14 | stride: pool stride, int or 2-tuple 15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 16 | same: override padding and enforce same padding, boolean 17 | """ 18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 19 | super(MedianPool2d, self).__init__() 20 | self.k = to_2tuple(kernel_size) 21 | self.stride = to_2tuple(stride) 22 | self.padding = to_4tuple(padding) # convert to l, r, t, b 23 | self.same = same 24 | 25 | def _padding(self, x): 26 | if self.same: 27 | ih, iw = x.size()[2:] 28 | if ih % self.stride[0] == 0: 29 | ph = max(self.k[0] - self.stride[0], 0) 30 | else: 31 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 32 | if iw % self.stride[1] == 0: 33 | pw = max(self.k[1] - self.stride[1], 0) 34 | else: 35 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 36 | pl = pw // 2 37 | pr = pw - pl 38 | pt = ph // 2 39 | pb = ph - pt 40 | padding = (pl, pr, pt, pb) 41 | else: 42 | padding = self.padding 43 | return padding 44 | 45 | def forward(self, x): 46 | x = F.pad(x, self._padding(x), mode='reflect') 47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 49 | return x 50 | -------------------------------------------------------------------------------- /model/layers/mixed_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Mixed Convolution 2 | 3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | from .conv2d_same import create_conv2d_pad 12 | 13 | 14 | def _split_channels(num_chan, num_groups): 15 | split = [num_chan // num_groups for _ in range(num_groups)] 16 | split[0] += num_chan - sum(split) 17 | return split 18 | 19 | 20 | class MixedConv2d(nn.ModuleDict): 21 | """ Mixed Grouped Convolution 22 | 23 | Based on MDConv and GroupedConv in MixNet impl: 24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 25 | """ 26 | def __init__(self, in_channels, out_channels, kernel_size=3, 27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs): 28 | super(MixedConv2d, self).__init__() 29 | 30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 31 | num_groups = len(kernel_size) 32 | in_splits = _split_channels(in_channels, num_groups) 33 | out_splits = _split_channels(out_channels, num_groups) 34 | self.in_channels = sum(in_splits) 35 | self.out_channels = sum(out_splits) 36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 37 | conv_groups = in_ch if depthwise else 1 38 | # use add_module to keep key space clean 39 | self.add_module( 40 | str(idx), 41 | create_conv2d_pad( 42 | in_ch, out_ch, k, stride=stride, 43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs) 44 | ) 45 | self.splits = in_splits 46 | 47 | def forward(self, x): 48 | x_split = torch.split(x, self.splits, 1) 49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())] 50 | x = torch.cat(x_out, 1) 51 | return x 52 | -------------------------------------------------------------------------------- /model/layers/padding.py: -------------------------------------------------------------------------------- 1 | """ Padding Helpers 2 | 3 | Hacked together by 
/ Copyright 2020 Ross Wightman 4 | """ 5 | import math 6 | from typing import List, Tuple 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | # Calculate symmetric padding for a convolution 13 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 14 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 15 | return padding 16 | 17 | 18 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 19 | def get_same_padding(x: int, kernel_size: int, stride: int, dilation: int): 20 | if isinstance(x, torch.Tensor): 21 | return torch.clamp(((x / stride).ceil() - 1) * stride + (kernel_size - 1) * dilation + 1 - x, min=0) 22 | else: 23 | return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0) 24 | 25 | 26 | # Can SAME padding for given args be done statically? 27 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 28 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 29 | 30 | 31 | def pad_same_arg( 32 | input_size: List[int], 33 | kernel_size: List[int], 34 | stride: List[int], 35 | dilation: List[int] = (1, 1), 36 | ) -> List[int]: 37 | ih, iw = input_size 38 | kh, kw = kernel_size 39 | pad_h = get_same_padding(ih, kh, stride[0], dilation[0]) 40 | pad_w = get_same_padding(iw, kw, stride[1], dilation[1]) 41 | return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] 42 | 43 | 44 | # Dynamically pad input x with 'SAME' padding for conv with specified args 45 | def pad_same( 46 | x, 47 | kernel_size: List[int], 48 | stride: List[int], 49 | dilation: List[int] = (1, 1), 50 | value: float = 0, 51 | ): 52 | ih, iw = x.size()[-2:] 53 | pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0]) 54 | pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1]) 55 | x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2), value=value) 56 | return x 57 | 58 | 59 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 60 | dynamic = False 61 | if isinstance(padding, str): 62 | # for any string padding, the padding will be calculated for you, one of three ways 63 | padding = padding.lower() 64 | if padding == 'same': 65 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 66 | if is_static_pad(kernel_size, **kwargs): 67 | # static case, no extra overhead 68 | padding = get_padding(kernel_size, **kwargs) 69 | else: 70 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 71 | padding = 0 72 | dynamic = True 73 | elif padding == 'valid': 74 | # 'VALID' padding, same as padding=0 75 | padding = 0 76 | else: 77 | # Default to PyTorch style 'same'-ish symmetric padding 78 | padding = get_padding(kernel_size, **kwargs) 79 | return padding, dynamic 80 | -------------------------------------------------------------------------------- /model/layers/patch_dropout.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class PatchDropout(nn.Module): 8 | """ 9 | https://arxiv.org/abs/2212.00794 10 | """ 11 | return_indices: torch.jit.Final[bool] 12 | 13 | def __init__( 14 | self, 15 | prob: float = 0.5, 16 | num_prefix_tokens: int = 1, 17 | ordered: bool = False, 18 | return_indices: bool = False, 19 | ): 20 | super().__init__() 21 | assert 0 <= prob < 1. 
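        # (editor's note) prob is the fraction of patch tokens dropped each training
        # step, e.g. prob=0.5 on 196 tokens keeps 98; prefix (CLS) tokens are kept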
22 | self.prob = prob 23 | self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens) 24 | self.ordered = ordered 25 | self.return_indices = return_indices 26 | 27 | def forward(self, x) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: 28 | if not self.training or self.prob == 0.: 29 | if self.return_indices: 30 | return x, None 31 | return x 32 | 33 | if self.num_prefix_tokens: 34 | prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:] 35 | else: 36 | prefix_tokens = None 37 | 38 | B = x.shape[0] 39 | L = x.shape[1] 40 | num_keep = max(1, int(L * (1. - self.prob))) 41 | keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep] 42 | if self.ordered: 43 | # NOTE does not need to maintain patch order in typical transformer use, 44 | # but possibly useful for debug / visualization 45 | keep_indices = keep_indices.sort(dim=-1)[0] 46 | x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:])) 47 | 48 | if prefix_tokens is not None: 49 | x = torch.cat((prefix_tokens, x), dim=1) 50 | 51 | if self.return_indices: 52 | return x, keep_indices 53 | return x 54 | -------------------------------------------------------------------------------- /model/layers/pool2d_same.py: -------------------------------------------------------------------------------- 1 | """ AvgPool2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import List, Tuple, Optional 9 | 10 | from .helpers import to_2tuple 11 | from .padding import pad_same, get_padding_value 12 | 13 | 14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 15 | ceil_mode: bool = False, count_include_pad: bool = True): 16 | # FIXME how to deal with count_include_pad vs not for external padding? 
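    # (editor's note) pad asymmetrically so output size == ceil(input / stride),
    # then pool with zero internal padding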
17 | x = pad_same(x, kernel_size, stride) 18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 19 | 20 | 21 | class AvgPool2dSame(nn.AvgPool2d): 22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling 23 | """ 24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): 25 | kernel_size = to_2tuple(kernel_size) 26 | stride = to_2tuple(stride) 27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 28 | 29 | def forward(self, x): 30 | x = pad_same(x, self.kernel_size, self.stride) 31 | return F.avg_pool2d( 32 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) 33 | 34 | 35 | def max_pool2d_same( 36 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 37 | dilation: List[int] = (1, 1), ceil_mode: bool = False): 38 | x = pad_same(x, kernel_size, stride, value=-float('inf')) 39 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) 40 | 41 | 42 | class MaxPool2dSame(nn.MaxPool2d): 43 | """ Tensorflow like 'SAME' wrapper for 2D max pooling 44 | """ 45 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): 46 | kernel_size = to_2tuple(kernel_size) 47 | stride = to_2tuple(stride) 48 | dilation = to_2tuple(dilation) 49 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) 50 | 51 | def forward(self, x): 52 | x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) 53 | return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) 54 | 55 | 56 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): 57 | stride = stride or kernel_size 58 | padding = kwargs.pop('padding', '') 59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) 60 | if is_dynamic: 61 | if pool_type == 'avg': 62 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs) 63 | elif pool_type == 'max': 64 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs) 65 | else: 66 | assert False, f'Unsupported pool type {pool_type}' 67 | else: 68 | if pool_type == 'avg': 69 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 70 | elif pool_type == 'max': 71 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 72 | else: 73 | assert False, f'Unsupported pool type {pool_type}' 74 | -------------------------------------------------------------------------------- /model/layers/pos_embed.py: -------------------------------------------------------------------------------- 1 | """ Position Embedding Utilities 2 | 3 | Hacked together by / Copyright 2022 Ross Wightman 4 | """ 5 | import logging 6 | import math 7 | from typing import List, Tuple, Optional, Union 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from .helpers import to_2tuple 13 | 14 | _logger = logging.getLogger(__name__) 15 | 16 | 17 | def resample_abs_pos_embed( 18 | posemb, 19 | new_size: List[int], 20 | old_size: Optional[List[int]] = None, 21 | num_prefix_tokens: int = 1, 22 | interpolation: str = 'bicubic', 23 | antialias: bool = True, 24 | verbose: bool = False, 25 | ): 26 | # sort out sizes, assume square if old size not provided 27 | num_pos_tokens = posemb.shape[1] 28 | num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens 29 | if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]: 30 | return posemb 31 | 32 | 
if old_size is None: 33 | hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens)) 34 | old_size = hw, hw 35 | 36 | if num_prefix_tokens: 37 | posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:] 38 | else: 39 | posemb_prefix, posemb = None, posemb 40 | 41 | # do the interpolation 42 | embed_dim = posemb.shape[-1] 43 | orig_dtype = posemb.dtype 44 | posemb = posemb.float() # interpolate needs float32 45 | posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2) 46 | posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias) 47 | posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim) 48 | posemb = posemb.to(orig_dtype) 49 | 50 | # add back extra (class, etc) prefix tokens 51 | if posemb_prefix is not None: 52 | posemb = torch.cat([posemb_prefix, posemb], dim=1) 53 | 54 | if not torch.jit.is_scripting() and verbose: 55 | _logger.info(f'Resized position embedding: {old_size} to {new_size}.') 56 | 57 | return posemb 58 | 59 | 60 | def resample_abs_pos_embed_nhwc( 61 | posemb, 62 | new_size: List[int], 63 | interpolation: str = 'bicubic', 64 | antialias: bool = True, 65 | verbose: bool = False, 66 | ): 67 | if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]: 68 | return posemb 69 | 70 | orig_dtype = posemb.dtype 71 | posemb = posemb.float() 72 | # do the interpolation 73 | posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2) 74 | posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias) 75 | posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype) 76 | 77 | if not torch.jit.is_scripting() and verbose: 78 | _logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.') 79 | 80 | return posemb 81 | -------------------------------------------------------------------------------- /model/layers/separable_conv.py: -------------------------------------------------------------------------------- 1 | """ Depthwise Separable Conv Modules 2 | 3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the 4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 
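Editor's note: a depthwise conv (groups == in_channels) followed by a pointwise 1x1
conv approximates a dense conv at much lower parameter/FLOP cost; e.g.
SeparableConv2d(32, 64, kernel_size=3) maps (1, 32, 16, 16) -> (1, 64, 16, 16).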
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | from torch import nn as nn 9 | 10 | from .create_conv2d import create_conv2d 11 | from .create_norm_act import get_norm_act_layer 12 | 13 | 14 | class SeparableConvNormAct(nn.Module): 15 | """ Separable Conv w/ trailing Norm and Activation 16 | """ 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, 19 | apply_act=True, drop_layer=None): 20 | super(SeparableConvNormAct, self).__init__() 21 | 22 | self.conv_dw = create_conv2d( 23 | in_channels, int(in_channels * channel_multiplier), kernel_size, 24 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 25 | 26 | self.conv_pw = create_conv2d( 27 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 28 | 29 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 30 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 31 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 32 | 33 | @property 34 | def in_channels(self): 35 | return self.conv_dw.in_channels 36 | 37 | @property 38 | def out_channels(self): 39 | return self.conv_pw.out_channels 40 | 41 | def forward(self, x): 42 | x = self.conv_dw(x) 43 | x = self.conv_pw(x) 44 | x = self.bn(x) 45 | return x 46 | 47 | 48 | SeparableConvBnAct = SeparableConvNormAct 49 | 50 | 51 | class SeparableConv2d(nn.Module): 52 | """ Separable Conv 53 | """ 54 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 55 | channel_multiplier=1.0, pw_kernel_size=1): 56 | super(SeparableConv2d, self).__init__() 57 | 58 | self.conv_dw = create_conv2d( 59 | in_channels, int(in_channels * channel_multiplier), kernel_size, 60 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 61 | 62 | self.conv_pw = create_conv2d( 63 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 64 | 65 | @property 66 | def in_channels(self): 67 | return self.conv_dw.in_channels 68 | 69 | @property 70 | def out_channels(self): 71 | return self.conv_pw.out_channels 72 | 73 | def forward(self, x): 74 | x = self.conv_dw(x) 75 | x = self.conv_pw(x) 76 | return x 77 | -------------------------------------------------------------------------------- /model/layers/space_to_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SpaceToDepth(nn.Module): 6 | bs: torch.jit.Final[int] 7 | 8 | def __init__(self, block_size=4): 9 | super().__init__() 10 | assert block_size == 4 11 | self.bs = block_size 12 | 13 | def forward(self, x): 14 | N, C, H, W = x.size() 15 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 16 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 17 | x = x.view(N, C * self.bs * self.bs, H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 18 | return x 19 | 20 | 21 | @torch.jit.script 22 | class SpaceToDepthJit: 23 | def __call__(self, x: torch.Tensor): 24 | # assuming hard-coded that block_size==4 for acceleration 25 | N, C, H, W = x.size() 26 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) 27 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 28 | x = x.view(N, C * 16, H 
// 4, W // 4) # (N, C*bs^2, H//bs, W//bs)
29 |         return x
30 | 
31 | 
32 | class SpaceToDepthModule(nn.Module):
33 |     def __init__(self, no_jit=False):
34 |         super().__init__()
35 |         if not no_jit:
36 |             self.op = SpaceToDepthJit()
37 |         else:
38 |             self.op = SpaceToDepth()
39 | 
40 |     def forward(self, x):
41 |         return self.op(x)
42 | 
43 | 
44 | class DepthToSpace(nn.Module):
45 | 
46 |     def __init__(self, block_size):
47 |         super().__init__()
48 |         self.bs = block_size
49 | 
50 |     def forward(self, x):
51 |         N, C, H, W = x.size()
52 |         x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W)
53 |         x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
54 |         x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs)
55 |         return x
56 | 
--------------------------------------------------------------------------------
/model/layers/split_attn.py:
--------------------------------------------------------------------------------
1 | """ Split Attention Conv2d (for ResNeSt Models)
2 | 
3 | Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955
4 | 
5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt
6 | 
7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
8 | """
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import nn
12 | 
13 | from .helpers import make_divisible
14 | 
15 | 
16 | class RadixSoftmax(nn.Module):
17 |     def __init__(self, radix, cardinality):
18 |         super(RadixSoftmax, self).__init__()
19 |         self.radix = radix
20 |         self.cardinality = cardinality
21 | 
22 |     def forward(self, x):
23 |         batch = x.size(0)
24 |         if self.radix > 1:
25 |             x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
26 |             x = F.softmax(x, dim=1)
27 |             x = x.reshape(batch, -1)
28 |         else:
29 |             x = torch.sigmoid(x)
30 |         return x
31 | 
32 | 
33 | class SplitAttn(nn.Module):
34 |     """Split-Attention (aka Splat)
35 |     """
36 |     def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
37 |                  dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
38 |                  act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs):
39 |         super(SplitAttn, self).__init__()
40 |         out_channels = out_channels or in_channels
41 |         self.radix = radix
42 |         mid_chs = out_channels * radix
43 |         if rd_channels is None:
44 |             attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
45 |         else:
46 |             attn_chs = rd_channels * radix
47 | 
48 |         padding = kernel_size // 2 if padding is None else padding
49 |         self.conv = nn.Conv2d(
50 |             in_channels, mid_chs, kernel_size, stride, padding, dilation,
51 |             groups=groups * radix, bias=bias, **kwargs)
52 |         self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
53 |         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
54 |         self.act0 = act_layer(inplace=True)
55 |         self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
56 |         self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
57 |         self.act1 = act_layer(inplace=True)
58 |         self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
59 |         self.rsoftmax = RadixSoftmax(radix, groups)
60 | 
61 |     def forward(self, x):
62 |         x = self.conv(x)
63 |         x = self.bn0(x)
64 |         x = self.drop(x)
65 |         x = self.act0(x)
66 | 
67 |         B, RC, H, W = x.shape
68 |         if self.radix > 1:
69 |             x = x.reshape((B, self.radix, RC // self.radix, H, W))
70 |             x_gap = x.sum(dim=1)
71 |         else:
72 |             x_gap = x
73 |         x_gap = x_gap.mean((2, 3), keepdim=True)
74 |         x_gap = self.fc1(x_gap)
75 |         x_gap = self.bn1(x_gap)
76 |         x_gap = self.act1(x_gap)
77 |         x_attn = self.fc2(x_gap)
78 | 
79 |         x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
80 |         if self.radix > 1:
81 |             out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
82 |         else:
83 |             out = x * x_attn
84 |         return out.contiguous()
85 | 
--------------------------------------------------------------------------------
/model/layers/split_batchnorm.py:
--------------------------------------------------------------------------------
1 | """ Split BatchNorm
2 | 
3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through
4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias
5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
6 | namespace.
7 | 
8 | This allows easily removing the auxiliary BN layers after training to efficiently
9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
10 | 'Disentangled Learning via An Auxiliary BN'
11 | 
12 | Hacked together by / Copyright 2020 Ross Wightman
13 | """
14 | import torch
15 | import torch.nn as nn
16 | 
17 | 
18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d):
19 | 
20 |     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
21 |                  track_running_stats=True, num_splits=2):
22 |         super().__init__(num_features, eps, momentum, affine, track_running_stats)
23 |         assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
24 |         self.num_splits = num_splits
25 |         self.aux_bn = nn.ModuleList([
26 |             nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)])
27 | 
28 |     def forward(self, input: torch.Tensor):
29 |         if self.training:  # aux BN only relevant while training
30 |             split_size = input.shape[0] // self.num_splits
31 |             assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
32 |             split_input = input.split(split_size)
33 |             x = [super().forward(split_input[0])]
34 |             for i, a in enumerate(self.aux_bn):
35 |                 x.append(a(split_input[i + 1]))
36 |             return torch.cat(x, dim=0)
37 |         else:
38 |             return super().forward(input)
39 | 
40 | 
41 | def convert_splitbn_model(module, num_splits=2):
42 |     """
43 |     Recursively traverse module and its children to replace all instances of
44 |     ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchNorm2d`.
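    Editor's note: running stats and affine params are copied into both the parent
    BN and each `.aux_bn` child, so the converted model starts as a numerical
    drop-in replacement.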
45 | Args: 46 | module (torch.nn.Module): input module 47 | num_splits: number of separate batchnorm layers to split input across 48 | Example:: 49 | >>> # model is an instance of torch.nn.Module 50 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2) 51 | """ 52 | mod = module 53 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): 54 | return module 55 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 56 | mod = SplitBatchNorm2d( 57 | module.num_features, module.eps, module.momentum, module.affine, 58 | module.track_running_stats, num_splits=num_splits) 59 | mod.running_mean = module.running_mean 60 | mod.running_var = module.running_var 61 | mod.num_batches_tracked = module.num_batches_tracked 62 | if module.affine: 63 | mod.weight.data = module.weight.data.clone().detach() 64 | mod.bias.data = module.bias.data.clone().detach() 65 | for aux in mod.aux_bn: 66 | aux.running_mean = module.running_mean.clone() 67 | aux.running_var = module.running_var.clone() 68 | aux.num_batches_tracked = module.num_batches_tracked.clone() 69 | if module.affine: 70 | aux.weight.data = module.weight.data.clone().detach() 71 | aux.bias.data = module.bias.data.clone().detach() 72 | for name, child in module.named_children(): 73 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) 74 | del module 75 | return mod 76 | -------------------------------------------------------------------------------- /model/layers/squeeze_excite.py: -------------------------------------------------------------------------------- 1 | """ Squeeze-and-Excitation Channel Attention 2 | 3 | An SE implementation originally based on PyTorch SE-Net impl. 4 | Has since evolved with additional functionality / configuration. 5 | 6 | Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507 7 | 8 | Also included is Effective Squeeze-Excitation (ESE). 9 | Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 10 | 11 | Hacked together by / Copyright 2021 Ross Wightman 12 | """ 13 | from torch import nn as nn 14 | 15 | from .create_act import create_act_layer 16 | from .helpers import make_divisible 17 | 18 | 19 | class SEModule(nn.Module): 20 | """ SE Module as defined in original SE-Nets with a few additions 21 | Additions include: 22 | * divisor can be specified to keep channels % div == 0 (default: 8) 23 | * reduction channels can be specified directly by arg (if rd_channels is set) 24 | * reduction channels can be specified by float rd_ratio (default: 1/16) 25 | * global max pooling can be added to the squeeze aggregation 26 | * customizable activation, normalization, and gate layer 27 | """ 28 | def __init__( 29 | self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, 30 | bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): 31 | super(SEModule, self).__init__() 32 | self.add_maxpool = add_maxpool 33 | if not rd_channels: 34 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
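        # (editor's note) e.g. channels=256 with rd_ratio=1/16 gives rd_channels=16,
        # rounded by make_divisible to a multiple of rd_divisor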
35 |         self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias)
36 |         self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()
37 |         self.act = create_act_layer(act_layer, inplace=True)
38 |         self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias)
39 |         self.gate = create_act_layer(gate_layer)
40 | 
41 |     def forward(self, x):
42 |         x_se = x.mean((2, 3), keepdim=True)
43 |         if self.add_maxpool:
44 |             # experimental codepath, may remove or change
45 |             x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
46 |         x_se = self.fc1(x_se)
47 |         x_se = self.act(self.bn(x_se))
48 |         x_se = self.fc2(x_se)
49 |         return x * self.gate(x_se)
50 | 
51 | 
52 | SqueezeExcite = SEModule # alias
53 | 
54 | 
55 | class EffectiveSEModule(nn.Module):
56 |     """ Effective Squeeze-Excitation
57 |     From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
58 |     """
59 |     def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_):
60 |         super(EffectiveSEModule, self).__init__()
61 |         self.add_maxpool = add_maxpool
62 |         self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
63 |         self.gate = create_act_layer(gate_layer)
64 | 
65 |     def forward(self, x):
66 |         x_se = x.mean((2, 3), keepdim=True)
67 |         if self.add_maxpool:
68 |             # experimental codepath, may remove or change
69 |             x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
70 |         x_se = self.fc(x_se)
71 |         return x * self.gate(x_se)
72 | 
73 | 
74 | EffectiveSqueezeExcite = EffectiveSEModule # alias
75 | 
76 | 
77 | class SqueezeExciteCl(nn.Module):
78 |     """ SE Module as defined in original SE-Nets, for channels-last (NHWC) tensors
79 |     Additions include:
80 |         * divisor can be specified to keep channels % div == 0 (default: 8)
81 |         * reduction channels can be specified directly by arg (if rd_channels is set)
82 |         * reduction channels can be specified by float rd_ratio (default: 1/16)
83 |         * squeeze/excite implemented with nn.Linear over the trailing channel dim
84 |         * customizable activation and gate layer
85 |     """
86 |     def __init__(
87 |             self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8,
88 |             bias=True, act_layer=nn.ReLU, gate_layer='sigmoid'):
89 |         super().__init__()
90 |         if not rd_channels:
91 |             rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
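        # (editor's note) e.g. channels=512 with rd_ratio=1/16 gives rd_channels=32;
        # nn.Linear acts on the trailing channel dim of NHWC input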
92 | self.fc1 = nn.Linear(channels, rd_channels, bias=bias) 93 | self.act = create_act_layer(act_layer, inplace=True) 94 | self.fc2 = nn.Linear(rd_channels, channels, bias=bias) 95 | self.gate = create_act_layer(gate_layer) 96 | 97 | def forward(self, x): 98 | x_se = x.mean((1, 2), keepdims=True) # FIXME avg dim [1:n-1], don't assume 2D NHWC 99 | x_se = self.fc1(x_se) 100 | x_se = self.act(x_se) 101 | x_se = self.fc2(x_se) 102 | return x * self.gate(x_se) -------------------------------------------------------------------------------- /model/layers/test_time_pool.py: -------------------------------------------------------------------------------- 1 | """ Test Time Pooling (Average-Max Pool) 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | import logging 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d 11 | 12 | 13 | _logger = logging.getLogger(__name__) 14 | 15 | 16 | class TestTimePoolHead(nn.Module): 17 | def __init__(self, base, original_pool=7): 18 | super(TestTimePoolHead, self).__init__() 19 | self.base = base 20 | self.original_pool = original_pool 21 | base_fc = self.base.get_classifier() 22 | if isinstance(base_fc, nn.Conv2d): 23 | self.fc = base_fc 24 | else: 25 | self.fc = nn.Conv2d( 26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) 27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) 28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) 29 | self.base.reset_classifier(0) # delete original fc layer 30 | 31 | def forward(self, x): 32 | x = self.base.forward_features(x) 33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) 34 | x = self.fc(x) 35 | x = adaptive_avgmax_pool2d(x, 1) 36 | return x.view(x.size(0), -1) 37 | 38 | 39 | def apply_test_time_pool(model, config, use_test_size=False): 40 | test_time_pool = False 41 | if not hasattr(model, 'default_cfg') or not model.default_cfg: 42 | return model, False 43 | if use_test_size and 'test_input_size' in model.default_cfg: 44 | df_input_size = model.default_cfg['test_input_size'] 45 | else: 46 | df_input_size = model.default_cfg['input_size'] 47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: 48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' % 49 | (str(config['input_size'][-2:]), str(df_input_size[-2:]))) 50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) 51 | test_time_pool = True 52 | return model, test_time_pool 53 | -------------------------------------------------------------------------------- /model/layers/trace_utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | from torch import _assert 3 | except ImportError: 4 | def _assert(condition: bool, message: str): 5 | assert condition, message 6 | 7 | 8 | def _float_to_int(x: float) -> int: 9 | """ 10 | Symbolic tracing helper to substitute for inbuilt `int`. 
11 | Hint: Inbuilt `int` can't accept an argument of type `Proxy` 12 | """ 13 | return int(x) 14 | -------------------------------------------------------------------------------- /model/layers/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple, Type, Union 2 | 3 | import torch 4 | 5 | 6 | LayerType = Union[str, Callable, Type[torch.nn.Module]] 7 | PadType = Union[str, int, Tuple[int, int]] 8 | -------------------------------------------------------------------------------- /model/layers/weight_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | 5 | from torch.nn.init import _calculate_fan_in_and_fan_out 6 | 7 | 8 | def _trunc_normal_(tensor, mean, std, a, b): 9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 11 | def norm_cdf(x): 12 | # Computes standard normal cumulative distribution function 13 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 14 | 15 | if (mean < a - 2 * std) or (mean > b + 2 * std): 16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 17 | "The distribution of values may be incorrect.", 18 | stacklevel=2) 19 | 20 | # Values are generated by using a truncated uniform distribution and 21 | # then using the inverse CDF for the normal distribution. 22 | # Get upper and lower cdf values 23 | l = norm_cdf((a - mean) / std) 24 | u = norm_cdf((b - mean) / std) 25 | 26 | # Uniformly fill tensor with values from [l, u], then translate to 27 | # [2l-1, 2u-1]. 28 | tensor.uniform_(2 * l - 1, 2 * u - 1) 29 | 30 | # Use inverse cdf transform for normal distribution to get truncated 31 | # standard normal 32 | tensor.erfinv_() 33 | 34 | # Transform to proper mean, std 35 | tensor.mul_(std * math.sqrt(2.)) 36 | tensor.add_(mean) 37 | 38 | # Clamp to ensure it's in the proper range 39 | tensor.clamp_(min=a, max=b) 40 | return tensor 41 | 42 | 43 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 44 | # type: (Tensor, float, float, float, float) -> Tensor 45 | r"""Fills the input Tensor with values drawn from a truncated 46 | normal distribution. The values are effectively drawn from the 47 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 48 | with values outside :math:`[a, b]` redrawn until they are within 49 | the bounds. The method used for generating the random values works 50 | best when :math:`a \leq \text{mean} \leq b`. 51 | 52 | NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are 53 | applied while sampling the normal with mean/std applied, therefore a, b args 54 | should be adjusted to match the range of mean, std args. 55 | 56 | Args: 57 | tensor: an n-dimensional `torch.Tensor` 58 | mean: the mean of the normal distribution 59 | std: the standard deviation of the normal distribution 60 | a: the minimum cutoff value 61 | b: the maximum cutoff value 62 | Examples: 63 | >>> w = torch.empty(3, 5) 64 | >>> nn.init.trunc_normal_(w) 65 | """ 66 | with torch.no_grad(): 67 | return _trunc_normal_(tensor, mean, std, a, b) 68 | 69 | 70 | def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.): 71 | # type: (Tensor, float, float, float, float) -> Tensor 72 | r"""Fills the input Tensor with values drawn from a truncated 73 | normal distribution. 
The values are effectively drawn from the
74 |     normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
75 |     with values outside :math:`[a, b]` redrawn until they are within
76 |     the bounds. The method used for generating the random values works
77 |     best when :math:`a \leq \text{mean} \leq b`.
78 | 
79 |     NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
80 |     bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
81 |     and the result is subsequently scaled and shifted by the mean and std args.
82 | 
83 |     Args:
84 |         tensor: an n-dimensional `torch.Tensor`
85 |         mean: the mean of the normal distribution
86 |         std: the standard deviation of the normal distribution
87 |         a: the minimum cutoff value
88 |         b: the maximum cutoff value
89 |     Examples:
90 |         >>> w = torch.empty(3, 5)
91 |         >>> nn.init.trunc_normal_(w)
92 |     """
93 |     with torch.no_grad():
94 |         _trunc_normal_(tensor, 0, 1.0, a, b)
95 |         tensor.mul_(std).add_(mean)
96 |     return tensor
97 | 
98 | 
99 | def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
100 |     fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
101 |     if mode == 'fan_in':
102 |         denom = fan_in
103 |     elif mode == 'fan_out':
104 |         denom = fan_out
105 |     elif mode == 'fan_avg':
106 |         denom = (fan_in + fan_out) / 2
107 | 
108 |     variance = scale / denom
109 | 
110 |     if distribution == "truncated_normal":
111 |         # constant is stddev of standard normal truncated to (-2, 2)
112 |         trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978)
113 |     elif distribution == "normal":
114 |         with torch.no_grad():
115 |             tensor.normal_(std=math.sqrt(variance))
116 |     elif distribution == "uniform":
117 |         bound = math.sqrt(3 * variance)
118 |         with torch.no_grad():
119 |             tensor.uniform_(-bound, bound)
120 |     else:
121 |         raise ValueError(f"invalid distribution {distribution}")
122 | 
123 | 
124 | def lecun_normal_(tensor):
125 |     variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
126 | 
--------------------------------------------------------------------------------
/my_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | @FileName:my_dataset.py\n
3 | @Description:\n
4 | @Author:WBobby\n
5 | @Department:CUG\n
6 | @Time:2023/4/29 16:04\n
7 | """
8 | from PIL import Image
9 | import torch
10 | from torch.utils.data import Dataset
11 | 
12 | 
13 | class MyDataSet(Dataset):
14 |     """Custom dataset"""
15 | 
16 |     def __init__(self, images_path: list, images_class: list, transform=None):
17 |         self.images_path = images_path
18 |         self.images_class = images_class
19 |         self.transform = transform
20 | 
21 |     def __len__(self):
22 |         return len(self.images_path)
23 | 
24 |     def __getitem__(self, item):
25 |         img = Image.open(self.images_path[item])
26 |         # PIL mode 'RGB' is a color image, 'L' is grayscale
27 |         if img.mode != 'RGB':
28 |             raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
29 |         label = self.images_class[item]
30 | 
31 |         if self.transform is not None:
32 |             img = self.transform(img)
33 | 
34 |         return img, label
35 | 
36 | 
37 | 
38 | 
39 |     @staticmethod
40 |     def collate_fn(batch):
41 |         # for reference, the official default_collate implementation:
42 |         # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
43 |         images, labels = tuple(zip(*batch))
44 | 
45 |         images = torch.stack(images, dim=0)
46 |         labels = torch.as_tensor(labels)
47 |         return images, labels
48 | 
--------------------------------------------------------------------------------
/parameter.py:
--------------------------------------------------------------------------------
1 | """
2 | @FileName:parameter.py\n
3 | @Description:\n
4 | @Author:WBobby\n
5 | @Department:CUG\n
6 | @Time:2023/6/13 16:10\n
7 | """
8 | import torch
9 | from coca_pytorch import CoCa
10 | from thop import profile
11 | from torch import nn
12 | from torchsummary import summary
13 | from torchvision.models import vgg16, resnet50
14 | from vit_pytorch.simple_vit_with_patch_dropout import SimpleViT
15 | from vit_pytorch.extractor import Extractor
16 | import timm
17 | from BBnet.model import PVFCNet
18 | from EfficientNET.model import efficientnet_b0
19 | from pytorch.timm.models.mobilevit import mobilevit_s
20 | from timm.models.effnetv2 import effnetv2_s
21 | 
22 | x = torch.rand(size=(1, 3, 112, 112))
23 | model = timm.create_model('resnet50', pretrained=False, num_classes=10)
24 | # rewrite the classification layer for the network; alternative models below
25 | # model = PVFCNet(11, 5, 10)
26 | # model = effnetv2_s()
27 | # output = model(x)
28 | # print(output.shape)
29 | summary(model, (3, 112, 112), device="cpu")  # torchsummary prints the table itself
30 | # print('model_name:coat_small')
31 | # flops, params = profile(model, inputs=(x,)); print('flops:{}, params:{}'.format(flops, params))
32 | 
--------------------------------------------------------------------------------