├── models
│   ├── __init__.py
│   ├── transformer.py
│   └── model.py
├── utils
│   ├── __init__.py
│   ├── save_load.py
│   ├── general.py
│   ├── dataset.py
│   ├── provider.py
│   └── pointnet_util.py
├── haq_lib
│   ├── __init__.py
│   ├── lib
│   │   ├── __init__.py
│   │   ├── env
│   │   │   ├── __init__.py
│   │   │   └── linear_quantize_env.py
│   │   ├── rl
│   │   │   ├── __init__.py
│   │   │   ├── ddpg.py
│   │   │   └── memory.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── make_data.py
│   │   │   ├── utils.py
│   │   │   └── data_utils.py
│   │   └── simulator
│   │       └── lookup_tables
│   │           └── qmobilenetv2_imagenet100_batch16_latency_table.npy
│   ├── rl_quantize.py
│   ├── pretrain.py
│   └── finetune.py
├── run
│   ├── search.sh
│   ├── finutune.sh
│   └── pretrain.sh
├── save
│   ├── actor.pkl
│   └── critic.pkl
├── config
│   ├── cls.yaml
│   ├── finetune.yaml
│   ├── finetune-dist.yaml
│   ├── partseg.yaml
│   └── agent.yaml
├── LICENSE
├── README.md
├── model.py
├── search.py
├── pretrain.py
├── finetune.py
└── finetune-distill.py

/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/haq_lib/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/haq_lib/lib/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/haq_lib/lib/env/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/haq_lib/lib/rl/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/haq_lib/lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/run/search.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | python search.py
--------------------------------------------------------------------------------
/run/finutune.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | python finetune.py
--------------------------------------------------------------------------------
/run/pretrain.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 | python pretrain.py
--------------------------------------------------------------------------------
/save/actor.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/Point-Transformers-with-Quantization/HEAD/save/actor.pkl
--------------------------------------------------------------------------------
/save/critic.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/Point-Transformers-with-Quantization/HEAD/save/critic.pkl
--------------------------------------------------------------------------------
/haq_lib/lib/simulator/lookup_tables/qmobilenetv2_imagenet100_batch16_latency_table.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/Point-Transformers-with-Quantization/HEAD/haq_lib/lib/simulator/lookup_tables/qmobilenetv2_imagenet100_batch16_latency_table.npy -------------------------------------------------------------------------------- /config/cls.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 16 2 | epoch: 200 3 | learning_rate: 1e-3 4 | gpu: 1 5 | num_point: 1024 6 | optimizer: Adam 7 | weight_decay: 1e-4 8 | normal: True 9 | 10 | defaults: 11 | - model: Menghao 12 | 13 | hydra: 14 | run: 15 | dir: log/cls/${model.name} 16 | 17 | sweep: 18 | dir: log/cls 19 | subdir: ${model.name} -------------------------------------------------------------------------------- /config/finetune.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 16 2 | epoch: 200 3 | learning_rate: 1e-3 4 | gpu: 0 5 | num_point: 1024 6 | optimizer: Adam 7 | weight_decay: 1e-4 8 | normal: True 9 | lr_decay: 0.5 10 | step_size: 20 11 | num_category: 16 12 | num_part: 50 13 | num_workers: 8 14 | train_episode: 600 15 | output: ./save_finetune 16 | free_high_bit: True 17 | 18 | model: 19 | nneighbor: 16 20 | nblocks: 4 21 | transformer_dim: 512 22 | name: PointTransformer 23 | 24 | work_dir: log/partseg_finetune/PointTransformer 25 | -------------------------------------------------------------------------------- /config/finetune-dist.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 8 2 | epoch: 200 3 | learning_rate: 1e-3 4 | gpu: 0 5 | num_point: 1024 6 | optimizer: Adam 7 | weight_decay: 1e-4 8 | normal: True 9 | lr_decay: 0.5 10 | step_size: 20 11 | num_category: 16 12 | num_part: 50 13 | num_workers: 8 14 | train_episode: 600 15 | output: ./save_finetune 16 | free_high_bit: True 17 | dist_loss_weight: 0.1 18 | 19 | model: 20 | nneighbor: 16 21 | nblocks: 4 22 | transformer_dim: 512 23 | name: PointTransformer 24 | 25 | work_dir: log/partseg_finetune/PointTransformer 26 | -------------------------------------------------------------------------------- /config/partseg.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 16 2 | epoch: 200 3 | learning_rate: 1e-3 4 | gpu: 0 5 | num_point: 1024 6 | optimizer: Adam 7 | weight_decay: 1e-4 8 | normal: True 9 | lr_decay: 0.5 10 | step_size: 20 11 | num_category: 16 12 | num_part: 50 13 | num_workers: 8 14 | train_episode: 600 15 | output: ./save 16 | warmup: 5 17 | fp16: True 18 | 19 | model: 20 | nneighbor: 16 21 | nblocks: 4 22 | transformer_dim: 512 23 | name: PointTransformer 24 | 25 | work_dir: log/partseg/PointTransformer 26 | init_delta: 0.5 27 | delta_decay: 0.99 28 | n_update: 1 29 | -------------------------------------------------------------------------------- /config/agent.yaml: -------------------------------------------------------------------------------- 1 | suffix: indoor3d 2 | output: ./save 3 | seed: 10 4 | linear_quantization: False 5 | preserve_ratio: 0.1 6 | float_bit: 32 7 | is_pruned: False 8 | rmsize: 128 9 | finetune_flag: True 10 | finetune_lr: 0.001 11 | finetune_gamma: 0.8 12 | finetune_epoch: 1 13 | min_bit: 2 14 | max_bit: 4 15 | work_dir: save 16 | # DDPG 17 | hidden1: 300 18 | hidden2: 300 19 | lr_c: 1e-3 20 | lr_a: 1e-4 21 | warmup: 5 22 | discount: 0.9 
23 | bsize: 64
24 | window_length: 1
25 | tau: 0.01
26 | init_delta: 0.5
27 | delta_decay: 0.99
28 | n_update: 1
29 | init_w: 0.0003
30 | epsilon: 50000
--------------------------------------------------------------------------------
/utils/save_load.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import collections
3 | 
4 | def load_pretrained(model, weights):
5 | 
6 |     state_dict = torch.load(weights)
7 |     if 'state_dict' in state_dict:
8 |         state_dict = state_dict['state_dict']
9 |     src_state_dict = model.state_dict()
10 |     fsd = collections.OrderedDict()
11 |     for key, value in state_dict.items():
12 |         if key in src_state_dict:
13 |             if value.shape == src_state_dict[key].shape:
14 |                 fsd[key] = value
15 |                 print('-[INFO] successfully loaded', key)
16 |             else:
17 |                 print('-[WARN] shape mis-match', key)
18 |         else:
19 |             print('-[WARN] unexpected', key)
20 |     model.load_state_dict(fsd, strict=False)
21 | 
--------------------------------------------------------------------------------
/utils/general.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
4 |                'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37],
5 |                'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49],
6 |                'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
7 | seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}
8 | for cat in seg_classes.keys():
9 |     for label in seg_classes[cat]:
10 |         seg_label_to_cat[label] = cat
11 | 
12 | 
13 | def to_categorical(y, num_classes):
14 |     """ 1-hot encodes a tensor """
15 |     new_y = torch.eye(num_classes)[y.cpu().data.numpy(), ]
16 |     if (y.is_cuda):
17 |         return new_y.cuda()
18 |     return new_y
19 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 MIT HAN Lab
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/haq_lib/lib/utils/make_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | from multiprocessing import Pool
4 | from tqdm import tqdm
5 | 
6 | root = os.getcwd()
7 | data_name = 'imagenet100'
8 | src_dir = os.path.join(root, 'data/imagenet')
9 | dst_dir = os.path.join(root, 'data/' + data_name)
10 | txt_path = os.path.join(root, 'lib/utils/' + data_name + '.txt')
11 | 
12 | # os.makedirs('/dev/shm/dataset', exist_ok=True)
13 | # os.makedirs('/dev/shm/dataset/imagenet', exist_ok=True)
14 | 
15 | n_thread = 32
16 | 
17 | 
18 | def copy_func(pair):
19 |     src, dst = pair
20 |     # os.system('rsync -r {} {}'.format(src, dst))
21 |     os.system('ln -s {} {}'.format(src, dst))
22 | 
23 | 
24 | for split in ['train', 'val']:
25 |     src_split_dir = os.path.join(src_dir, split)
26 |     dst_split_dir = os.path.join(dst_dir, split)
27 |     os.makedirs(dst_split_dir, exist_ok=True)
28 |     cls_list = []
29 |     f = open(txt_path, 'r')
30 |     for x in f:
31 |         cls_list.append(x[:9])
32 |     # pair_list = [(os.path.join(src_split_dir, c), os.path.join(dst_split_dir, c)) for c in cls_list]
33 |     pair_list = [(os.path.join(src_split_dir, c), dst_split_dir) for c in cls_list]
34 | 
35 |     p = Pool(n_thread)
36 | 
37 |     for _ in tqdm(p.imap_unordered(copy_func, pair_list), total=len(pair_list)):
38 |         pass
39 |     # p.map(worker, vid_list)
40 |     p.close()
41 |     p.join()
--------------------------------------------------------------------------------
/models/transformer.py:
--------------------------------------------------------------------------------
1 | from utils.pointnet_util import index_points, square_distance
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | 
7 | class TransformerBlock(nn.Module):
8 |     def __init__(self, d_points, d_model, k):
9 |         super().__init__()
10 |         self.fc1 = nn.Linear(d_points, d_model)
11 |         self.fc2 = nn.Linear(d_model, d_points)
12 |         self.fc_delta = nn.Sequential(
13 |             nn.Linear(3, d_model),
14 |             nn.ReLU(),
15 |             nn.Linear(d_model, d_model)
16 |         )
17 |         self.fc_gamma = nn.Sequential(
18 |             nn.Linear(d_model, d_model),
19 |             nn.ReLU(),
20 |             nn.Linear(d_model, d_model)
21 |         )
22 |         self.w_qs = nn.Linear(d_model, d_model, bias=False)
23 |         self.w_ks = nn.Linear(d_model, d_model, bias=False)
24 |         self.w_vs = nn.Linear(d_model, d_model, bias=False)
25 |         self.k = k
26 | 
27 |     # xyz: b x n x 3, features: b x n x f
28 |     def forward(self, xyz, features):
29 |         dists = square_distance(xyz, xyz)
30 |         knn_idx = dists.argsort()[:, :, :self.k]  # b x n x k
31 |         knn_xyz = index_points(xyz, knn_idx)
32 | 
33 |         pre = features
34 |         x = self.fc1(features)
35 |         q, k, v = self.w_qs(x), index_points(self.w_ks(x), knn_idx), index_points(self.w_vs(x), knn_idx)
36 | 
37 |         pos_enc = self.fc_delta(xyz[:, :, None] - knn_xyz)  # b x n x k x f
38 | 
39 |         attn = self.fc_gamma(q[:, :, None] - k + pos_enc)
40 |         attn = F.softmax(attn / np.sqrt(k.size(-1)), dim=-2)  # b x n x k x f
41 | 
42 |         res = torch.einsum('bmnf,bmnf->bmf', attn, v + pos_enc)
43 |         res = self.fc2(res) + pre
44 |         return res, attn
45 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch Implementation of Point Transformers
2 | 
3 | 
4 | 
5 | Reproduces point cloud segmentation based on Point Transformers, and uses the HAQ algorithm for automatic mixed-precision quantization (2-bit and 4-bit) with almost no loss in accuracy
6 | 
7 | ## Preparing the data:
8 | Download the **ShapeNet** dataset from the link below:
[Download link](https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip)
9 | 
10 | 
11 | After downloading, extract the archive to `data/shapenetcore_partanno_segmentation_benchmark_v0_normal`
12 | 
13 | ## Pretraining:
14 | 
15 | ```bash
16 | bash run/pretrain.sh
17 | ```
18 | 
19 | ## Reinforcement-learning search:
20 | 
21 | ```bash
22 | bash run/search.sh
23 | ```
24 | 
25 | ## Finetuning after quantization:
26 | 
27 | ```bash
28 | bash run/finutune.sh
29 | ```
30 | ## Problems encountered and solved:
31 | 
32 | - Inconsistent gradient updates: parameters were being updated with accumulated gradients because the gradients were never zeroed before each update; adding `self.optimizer.zero_grad()` fixes it
33 | - Incorrect Acc/mIoU metrics: mIoU was initially computed per batch and averaged at the end, but some classes can have zero samples in a given batch, which biases the estimate; computing mIoU once over the whole evaluation set fixes it
34 | - Accuracy frozen after quantization: the hardest problem to track down. The linear layers inside the transformer blocks often have few parameters, so k-means clustering (with a fixed number of cluster centers) can leave some centers with no assigned weights, and the corresponding masks contain zeros; updating parameters through those masks drives them to NaN, the outputs become NaN, the parameters stop updating, and the model returns exactly the same output every time (see the defensive sketch after the results tables below)
35 | 
36 | ## Reproduction results:
37 | 
38 | ### ShapeNet:
39 | 
40 | | Class | mIoU |
41 | | ------------------- | -------- |
42 | | Airplane | 0.7901 |
43 | | Bag | 0.7901 |
44 | | Cap | 0.8042 |
45 | | Car | 0.8287 |
46 | | Chair | 0.8985 |
47 | | Earphone | 0.7293 |
48 | | Guitar | 0.8979 |
49 | | Knife | 0.8654 |
50 | | Lamp | 0.8211 |
51 | | Laptop | 0.9524 |
52 | | Motorbike | 0.5616 |
53 | | Mug | 0.9288 |
54 | | Pistol | 0.7693 |
55 | | Rocket | 0.5708 |
56 | | Skateboard | 0.7270 |
57 | | Table | 0.8190 |
58 | | Total | 0.7940 |
59 | 
60 | | ShapeNet | Accuracy | cat.mIOU | ins.mIOU |
61 | | ------------------------ | -------------- | ------------ | ------------ |
62 | | Point Transformer (paper) | None | 0.837 | 0.866 |
63 | | Point Transformer (ours) | 0.93535 | 0.79958 | 0.83802 |
64 | 
65 | ### S3DIS:
66 | 
67 | | S3DIS | Accuracy |
68 | | ------------------------ | -------------- |
69 | | Point Transformer (paper) | 0.908 |
70 | | Point Transformer (ours) | 0.846 |
71 | 
72 | Trained for only 4 epochs due to hardware limitations.
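
The k-means pitfall listed under "Problems encountered and solved" can be avoided by never requesting more centroids than there are distinct weight values. Below is a minimal, illustrative sketch of such a guard; `kmeans_quantize_weight` is a hypothetical helper written for this README, not the quantization code actually used in this repo:

```python
import numpy as np
import torch
from sklearn.cluster import KMeans


def kmeans_quantize_weight(weight: torch.Tensor, n_bits: int) -> torch.Tensor:
    """Quantize a weight tensor to at most 2**n_bits k-means centroids.

    Small linear layers can have fewer distinct values than the requested
    number of centroids; the resulting empty clusters produce all-zero
    masks that drive the parameters to NaN during finetuning, so the
    number of clusters is capped first.
    """
    w = weight.detach().cpu().numpy().reshape(-1, 1)
    n_clusters = min(2 ** n_bits, len(np.unique(w)))  # guard against empty clusters
    km = KMeans(n_clusters=n_clusters, n_init=10).fit(w)
    centroids = km.cluster_centers_.squeeze(1)  # shape: (n_clusters,)
    quantized = centroids[km.labels_].reshape(weight.shape)
    return torch.from_numpy(quantized).to(weight.device, weight.dtype)
```

With the cap in place every centroid owns at least one weight, so no all-zero mask can appear and finetuning no longer produces NaN parameters.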
73 | 
74 | 
75 | 
76 | ## Mitigating the accuracy drop after quantization:
77 | 
78 | - The first few layers of the model and the final classification (segmentation) head have a large impact on accuracy, so quantize them with a higher bit width or leave them unquantized
79 | - Use knowledge distillation to recover accuracy while keeping the model size unchanged
80 | 
81 | ## Quantization results (ShapeNet):
82 | 
83 | 
84 | | Models | Accuracy | cat.mIOU | ins.mIOU |
85 | | ------------------------ | -------------- | ------------ | ------------ |
86 | | Point Transformer (paper) | None | 0.837 | 0.866 |
87 | | Point Transformer (our-no quant) | 0.93535 | 0.79958 | 0.83802 |
88 | | Point Transformer (our-no quant, mix) | 0.93653 | 0.791004 | 0.838491 |
89 | | Point Transformer (our-0.1×preserve) | 0.932 | 0.781 | 0.826 |
90 | | Point Transformer (our-0.1×preserve, mix) | 0.9341 | 0.7894 | 0.8337 |
91 | | Point Transformer (our-0.1×preserve, mix, finetune) | 0.936523 | 0.796603 | 0.837771 |
92 | | Point Transformer (our-0.1×preserve, mix, finetune, distill) | 0.940213 | 0.799304 | 0.839487 |
93 | 
94 | # 2.2 Pytorch Implementation of Point Transformers
95 | 
96 | Reproduces point cloud segmentation based on PCT: Point Cloud Transformer; no quantization has been applied to it yet.
97 | ## Run:
98 | ```bash
99 | PCT.ipynb
100 | ```
101 | ## Reproduction results:
102 | ### modelnet40 (PCT):
103 | 
104 | | modelnet40 | Accuracy |
105 | | ------------------------ | -------------- |
106 | | PCT (paper) | 0.932 |
107 | | PCT (ours) | 0.896576 |
108 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from utils.pointnet_util import PointNetFeaturePropagation, PointNetSetAbstraction
3 | from .transformer import TransformerBlock
4 | 
5 | 
6 | class TransitionDown(nn.Module):
7 |     def __init__(self, k, nneighbor, channels):
8 |         super().__init__()
9 |         self.sa = PointNetSetAbstraction(k, 0, nneighbor, channels[0], channels[1:], group_all=False, knn=True)
10 | 
11 |     def forward(self, xyz, points):
12 |         return self.sa(xyz, points)
13 | 
14 | 
15 | class TransitionUp(nn.Module):
16 |     def __init__(self, dim1, dim2, dim_out):
17 |         class SwapAxes(nn.Module):
18 |             def __init__(self):
19 |                 super().__init__()
20 | 
21 |             def forward(self, x):
22 |                 return x.transpose(1, 2)
23 | 
24 |         super().__init__()
25 |         self.fc1 = nn.Sequential(
26 |             nn.Linear(dim1, dim_out),
27 |             SwapAxes(),
28 |             nn.BatchNorm1d(dim_out),  # TODO
29 |             SwapAxes(),
30 |             nn.ReLU(),
31 |         )
32 |         self.fc2 = nn.Sequential(
33 |             nn.Linear(dim2, dim_out),
34 |             SwapAxes(),
35 |             nn.BatchNorm1d(dim_out),  # TODO
36 |             SwapAxes(),
37 |             nn.ReLU(),
38 |         )
39 |         self.fp = PointNetFeaturePropagation(-1, [])
40 | 
41 |     def forward(self, xyz1, points1, xyz2, points2):
42 |         feats1 = self.fc1(points1)
43 |         feats2 = self.fc2(points2)
44 |         feats1 = self.fp(xyz2.transpose(1, 2), xyz1.transpose(1, 2), None, feats1.transpose(1, 2)).transpose(1, 2)
45 |         return feats1 + feats2
46 | 
47 | 
48 | class Backbone(nn.Module):
49 |     def __init__(self, cfg):
50 |         super().__init__()
51 |         npoints, nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim
52 |         self.fc1 = nn.Sequential(
53 |             nn.Linear(d_points, 32),
54 |             nn.ReLU(),
55 |             nn.Linear(32, 32)
56 |         )
57 |         self.transformer1 = TransformerBlock(32, cfg.model.transformer_dim, nneighbor)
58 |         self.transition_downs = nn.ModuleList()
59 |         self.transformers = nn.ModuleList()
60 |         for i in range(nblocks):
61 |             channel = 32 * 2 ** (i + 1)
62 |             self.transition_downs.append(TransitionDown(npoints // 4 ** (i + 1), nneighbor, [channel // 2 + 3, channel, channel]))
63 |             self.transformers.append(TransformerBlock(channel, cfg.model.transformer_dim, nneighbor))
64 | 
self.nblocks = nblocks 65 | 66 | def forward(self, x): 67 | xyz = x[..., :3] 68 | points = self.transformer1(xyz, self.fc1(x))[0] 69 | 70 | xyz_and_feats = [(xyz, points)] 71 | for i in range(self.nblocks): 72 | xyz, points = self.transition_downs[i](xyz, points) 73 | points = self.transformers[i](xyz, points)[0] 74 | xyz_and_feats.append((xyz, points)) 75 | return points, xyz_and_feats 76 | 77 | 78 | 79 | class PointTransformerSeg(nn.Module): 80 | def __init__(self, cfg): 81 | super().__init__() 82 | self.backbone = Backbone(cfg) 83 | npoints, nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim 84 | self.fc2 = nn.Sequential( 85 | nn.Linear(32 * 2 ** nblocks, 512), 86 | nn.ReLU(), 87 | nn.Linear(512, 512), 88 | nn.ReLU(), 89 | nn.Linear(512, 32 * 2 ** nblocks) 90 | ) 91 | self.transformer2 = TransformerBlock(32 * 2 ** nblocks, cfg.model.transformer_dim, nneighbor) 92 | self.nblocks = nblocks 93 | self.transition_ups = nn.ModuleList() 94 | self.transformers = nn.ModuleList() 95 | for i in reversed(range(nblocks)): 96 | channel = 32 * 2 ** i 97 | self.transition_ups.append(TransitionUp(channel * 2, channel, channel)) 98 | self.transformers.append(TransformerBlock(channel, cfg.model.transformer_dim, nneighbor)) 99 | 100 | self.fc3 = nn.Sequential( 101 | nn.Linear(32, 64), 102 | nn.ReLU(), 103 | nn.Linear(64, 64), 104 | nn.ReLU(), 105 | nn.Linear(64, n_c) 106 | ) 107 | 108 | def forward(self, x): 109 | points, xyz_and_feats = self.backbone(x) 110 | xyz = xyz_and_feats[-1][0] 111 | points = self.transformer2(xyz, self.fc2(points))[0] 112 | 113 | for i in range(self.nblocks): 114 | points = self.transition_ups[i](xyz, points, xyz_and_feats[- i - 2][0], xyz_and_feats[- i - 2][1]) 115 | xyz = xyz_and_feats[- i - 2][0] 116 | points = self.transformers[i](xyz, points)[0] 117 | 118 | return self.fc3(points) 119 | 120 | 121 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from utils.pointnet_util import PointNetFeaturePropagation, PointNetSetAbstraction 3 | from .transformer import TransformerBlock 4 | 5 | 6 | class TransitionDown(nn.Module): 7 | def __init__(self, k, nneighbor, channels): 8 | super().__init__() 9 | self.sa = PointNetSetAbstraction(k, 0, nneighbor, channels[0], channels[1:], group_all=False, knn=True) 10 | 11 | def forward(self, xyz, points): 12 | return self.sa(xyz, points) 13 | 14 | 15 | class TransitionUp(nn.Module): 16 | def __init__(self, dim1, dim2, dim_out): 17 | class SwapAxes(nn.Module): 18 | def __init__(self): 19 | super().__init__() 20 | 21 | def forward(self, x): 22 | return x.transpose(1, 2) 23 | 24 | super().__init__() 25 | self.fc1 = nn.Sequential( 26 | nn.Linear(dim1, dim_out), 27 | SwapAxes(), 28 | nn.BatchNorm1d(dim_out), # TODO 29 | SwapAxes(), 30 | nn.ReLU(), 31 | ) 32 | self.fc2 = nn.Sequential( 33 | nn.Linear(dim2, dim_out), 34 | SwapAxes(), 35 | nn.BatchNorm1d(dim_out), # TODO 36 | SwapAxes(), 37 | nn.ReLU(), 38 | ) 39 | self.fp = PointNetFeaturePropagation(-1, []) 40 | 41 | def forward(self, xyz1, points1, xyz2, points2): 42 | feats1 = self.fc1(points1) 43 | feats2 = self.fc2(points2) 44 | feats1 = self.fp(xyz2.transpose(1, 2), xyz1.transpose(1, 2), None, feats1.transpose(1, 2)).transpose(1, 2) 45 | return feats1 + feats2 46 | 47 | 48 | class Backbone(nn.Module): 49 | def __init__(self, cfg): 50 | super().__init__() 51 | npoints, 
nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim 52 | self.fc1 = nn.Sequential( 53 | nn.Linear(d_points, 32), 54 | nn.ReLU(), 55 | nn.Linear(32, 32) 56 | ) 57 | self.transformer1 = TransformerBlock(32, cfg.model.transformer_dim, nneighbor) 58 | self.transition_downs = nn.ModuleList() 59 | self.transformers = nn.ModuleList() 60 | for i in range(nblocks): 61 | channel = 32 * 2 ** (i + 1) 62 | self.transition_downs.append(TransitionDown(npoints // 4 ** (i + 1), nneighbor, [channel // 2 + 3, channel, channel])) 63 | self.transformers.append(TransformerBlock(channel, cfg.model.transformer_dim, nneighbor)) 64 | self.nblocks = nblocks 65 | 66 | def forward(self, x): 67 | xyz = x[..., :3] 68 | points = self.transformer1(xyz, self.fc1(x))[0] 69 | 70 | xyz_and_feats = [(xyz, points)] 71 | for i in range(self.nblocks): 72 | xyz, points = self.transition_downs[i](xyz, points) 73 | points = self.transformers[i](xyz, points)[0] 74 | xyz_and_feats.append((xyz, points)) 75 | return points, xyz_and_feats 76 | 77 | 78 | 79 | class PointTransformerSeg(nn.Module): 80 | def __init__(self, cfg): 81 | super().__init__() 82 | self.backbone = Backbone(cfg) 83 | npoints, nblocks, nneighbor, n_c, d_points = cfg.num_point, cfg.model.nblocks, cfg.model.nneighbor, cfg.num_class, cfg.input_dim 84 | self.fc2 = nn.Sequential( 85 | nn.Linear(32 * 2 ** nblocks, 512), 86 | nn.ReLU(), 87 | nn.Linear(512, 512), 88 | nn.ReLU(), 89 | nn.Linear(512, 32 * 2 ** nblocks) 90 | ) 91 | self.transformer2 = TransformerBlock(32 * 2 ** nblocks, cfg.model.transformer_dim, nneighbor) 92 | self.nblocks = nblocks 93 | self.transition_ups = nn.ModuleList() 94 | self.transformers = nn.ModuleList() 95 | for i in reversed(range(nblocks)): 96 | channel = 32 * 2 ** i 97 | self.transition_ups.append(TransitionUp(channel * 2, channel, channel)) 98 | self.transformers.append(TransformerBlock(channel, cfg.model.transformer_dim, nneighbor)) 99 | 100 | self.fc3 = nn.Sequential( 101 | nn.Linear(32, 64), 102 | nn.ReLU(), 103 | nn.Linear(64, 64), 104 | nn.ReLU(), 105 | nn.Linear(64, n_c) 106 | ) 107 | 108 | def forward(self, x): 109 | points, xyz_and_feats = self.backbone(x) 110 | xyz = xyz_and_feats[-1][0] 111 | points = self.transformer2(xyz, self.fc2(points))[0] 112 | 113 | for i in range(self.nblocks): 114 | points = self.transition_ups[i](xyz, points, xyz_and_feats[- i - 2][0], xyz_and_feats[- i - 2][1]) 115 | xyz = xyz_and_feats[- i - 2][0] 116 | points = self.transformers[i](xyz, points)[0] 117 | 118 | return self.fc3(points) 119 | 120 | 121 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from torch.utils.data import Dataset 4 | import torch 5 | from utils.pointnet_util import farthest_point_sample, pc_normalize 6 | import json 7 | 8 | 9 | class ModelNetDataLoader(Dataset): 10 | def __init__(self, root, npoint=1024, split='train', uniform=False, normal_channel=True, cache_size=15000): 11 | self.root = root 12 | self.npoints = npoint 13 | self.uniform = uniform 14 | self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt') 15 | 16 | self.cat = [line.rstrip() for line in open(self.catfile)] 17 | self.classes = dict(zip(self.cat, range(len(self.cat)))) 18 | self.normal_channel = normal_channel 19 | 20 | shape_ids = {} 21 | shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 
'modelnet40_train.txt'))] 22 | shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))] 23 | 24 | assert (split == 'train' or split == 'test') 25 | shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]] 26 | # list of (shape_name, shape_txt_file_path) tuple 27 | self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i 28 | in range(len(shape_ids[split]))] 29 | print('The size of %s data is %d'%(split,len(self.datapath))) 30 | 31 | self.cache_size = cache_size # how many data points to cache in memory 32 | self.cache = {} # from index to (point_set, cls) tuple 33 | 34 | def __len__(self): 35 | return len(self.datapath) 36 | 37 | def _get_item(self, index): 38 | if index in self.cache: 39 | point_set, cls = self.cache[index] 40 | else: 41 | fn = self.datapath[index] 42 | cls = self.classes[self.datapath[index][0]] 43 | cls = np.array([cls]).astype(np.int32) 44 | point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32) 45 | if self.uniform: 46 | point_set = farthest_point_sample(point_set, self.npoints) 47 | else: 48 | point_set = point_set[0:self.npoints,:] 49 | 50 | point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) 51 | 52 | if not self.normal_channel: 53 | point_set = point_set[:, 0:3] 54 | 55 | if len(self.cache) < self.cache_size: 56 | self.cache[index] = (point_set, cls) 57 | 58 | return point_set, cls 59 | 60 | def __getitem__(self, index): 61 | return self._get_item(index) 62 | 63 | 64 | class PartNormalDataset(Dataset): 65 | def __init__(self, root='./data/shapenetcore_partanno_segmentation_benchmark_v0_normal', npoints=2500, split='train', class_choice=None, normal_channel=False): 66 | self.npoints = npoints 67 | self.root = root 68 | self.catfile = os.path.join(self.root, 'synsetoffset2category.txt') 69 | self.cat = {} 70 | self.normal_channel = normal_channel 71 | 72 | 73 | with open(self.catfile, 'r') as f: 74 | for line in f: 75 | ls = line.strip().split() 76 | self.cat[ls[0]] = ls[1] 77 | self.cat = {k: v for k, v in self.cat.items()} 78 | self.classes_original = dict(zip(self.cat, range(len(self.cat)))) 79 | 80 | if not class_choice is None: 81 | self.cat = {k:v for k,v in self.cat.items() if k in class_choice} 82 | # print(self.cat) 83 | 84 | self.meta = {} 85 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f: 86 | train_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 87 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f: 88 | val_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 89 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f: 90 | test_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 91 | for item in self.cat: 92 | # print('category', item) 93 | self.meta[item] = [] 94 | dir_point = os.path.join(self.root, self.cat[item]) 95 | fns = sorted(os.listdir(dir_point)) 96 | # print(fns[0][0:-4]) 97 | if split == 'trainval': 98 | fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))] 99 | elif split == 'train': 100 | fns = [fn for fn in fns if fn[0:-4] in train_ids] 101 | elif split == 'val': 102 | fns = [fn for fn in fns if fn[0:-4] in val_ids] 103 | elif split == 'test': 104 | fns = [fn for fn in fns if fn[0:-4] in test_ids] 105 | else: 106 | print('Unknown split: %s. Exiting..' 
% (split)) 107 | exit(-1) 108 | 109 | # print(os.path.basename(fns)) 110 | for fn in fns: 111 | token = (os.path.splitext(os.path.basename(fn))[0]) 112 | self.meta[item].append(os.path.join(dir_point, token + '.txt')) 113 | 114 | self.datapath = [] 115 | for item in self.cat: 116 | for fn in self.meta[item]: 117 | self.datapath.append((item, fn)) 118 | 119 | self.classes = {} 120 | for i in self.cat.keys(): 121 | self.classes[i] = self.classes_original[i] 122 | 123 | # Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels 124 | self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], 125 | 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 126 | 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 127 | 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 128 | 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]} 129 | 130 | # for cat in sorted(self.seg_classes.keys()): 131 | # print(cat, self.seg_classes[cat]) 132 | 133 | self.cache = {} # from index to (point_set, cls, seg) tuple 134 | self.cache_size = 20000 135 | 136 | 137 | def __getitem__(self, index): 138 | if index in self.cache: 139 | point_set, cls, seg = self.cache[index] 140 | else: 141 | fn = self.datapath[index] 142 | cat = self.datapath[index][0] 143 | cls = self.classes[cat] 144 | cls = np.array([cls]).astype(np.int32) 145 | data = np.loadtxt(fn[1]).astype(np.float32) 146 | if not self.normal_channel: 147 | point_set = data[:, 0:3] 148 | else: 149 | point_set = data[:, 0:6] 150 | seg = data[:, -1].astype(np.int32) 151 | if len(self.cache) < self.cache_size: 152 | self.cache[index] = (point_set, cls, seg) 153 | point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) 154 | 155 | choice = np.random.choice(len(seg), self.npoints, replace=True) 156 | # resample 157 | point_set = point_set[choice, :] 158 | seg = seg[choice] 159 | 160 | return point_set, cls, seg 161 | 162 | def __len__(self): 163 | return len(self.datapath) 164 | 165 | 166 | if __name__ == '__main__': 167 | data = ModelNetDataLoader('modelnet40_normal_resampled/', split='train', uniform=False, normal_channel=True) 168 | DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True) 169 | for point,label in DataLoader: 170 | print(point.shape) 171 | print(label.shape) -------------------------------------------------------------------------------- /search.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | import torch.backends.cudnn as cudnn 4 | from copy import deepcopy 5 | from haq_lib.lib.rl.ddpg import DDPG 6 | from haq_lib.lib.env.linear_quantize_env import LinearQuantizeEnv 7 | from haq_lib.lib.env.quantize_env import QuantizeEnv 8 | import hydra 9 | from utils.dataset import PartNormalDataset 10 | import numpy as np 11 | import shutil 12 | import importlib 13 | import torch 14 | import os 15 | import yaml 16 | import warnings 17 | 18 | warnings.filterwarnings('ignore') 19 | 20 | # training models 21 | # rl search 22 | 23 | 24 | def create_attr_dict(yaml_config): 25 | from ast import literal_eval 26 | for key, value in yaml_config.items(): 27 | if type(value) is dict: 28 | yaml_config[key] = value = AttrDict(value) 29 | if isinstance(value, str): 30 | try: 31 | value = literal_eval(value) 32 | except BaseException: 33 | pass 34 | if isinstance(value, AttrDict): 35 | create_attr_dict(yaml_config[key]) 36 | else: 37 | 
yaml_config[key] = value 38 | 39 | 40 | class AttrDict(dict): 41 | def __getattr__(self, key): 42 | return self[key] 43 | 44 | def __setattr__(self, key, value): 45 | if key in self.__dict__: 46 | self.__dict__[key] = value 47 | else: 48 | self[key] = value 49 | 50 | 51 | def inplace_relu(m): 52 | classname = m.__class__.__name__ 53 | if classname.find('ReLU') != -1: 54 | m.inplace = True 55 | 56 | 57 | def to_categorical(y, num_classes): 58 | """ 1-hot encodes a tensor """ 59 | new_y = torch.eye(num_classes)[y.cpu().data.numpy(), ] 60 | if (y.is_cuda): 61 | return new_y.cuda() 62 | return new_y 63 | 64 | 65 | def init_agent(model, pretrained, train_loader, 66 | val_loader, num_category, 67 | num_class, num_point, logger): 68 | with open('config/agent.yaml', 'r') as f: 69 | args = AttrDict(yaml.safe_load(f.read())) 70 | create_attr_dict(args) 71 | base_folder_name = 'output' 72 | if args.suffix is not None: 73 | base_folder_name = base_folder_name + '_' + args.suffix 74 | args.output = os.path.join(args.output, base_folder_name) 75 | print('==> Output path: {}...'.format(args.output)) 76 | 77 | assert torch.cuda.is_available(), 'CUDA is needed for CNN' 78 | 79 | if args.seed > 0: 80 | np.random.seed(args.seed) 81 | torch.manual_seed(args.seed) 82 | torch.cuda.manual_seed_all(args.seed) 83 | 84 | print(' Total params: %.2fM' % (sum(p.numel() 85 | for p in model.parameters())/1000000.0)) 86 | cudnn.benchmark = True 87 | 88 | if args.linear_quantization: 89 | env = LinearQuantizeEnv(model, pretrained, train_loader, val_loader, 90 | compress_ratio=args.preserve_ratio, args=args, 91 | float_bit=args.float_bit, is_model_pruned=args.is_pruned, 92 | num_category=num_category, num_part=num_class, 93 | num_point=num_point, logger=logger) 94 | else: 95 | env = QuantizeEnv(model, pretrained, train_loader, val_loader, 96 | compress_ratio=args.preserve_ratio, args=args, 97 | float_bit=args.float_bit, is_model_pruned=args.is_pruned, 98 | num_category=num_category, num_part=num_class, 99 | num_point=num_point, logger=logger) 100 | 101 | nb_states = env.layer_embedding.shape[1] 102 | nb_actions = 1 # actions for weight and activation quantization 103 | args.rmsize = args.rmsize * len(env.quantizable_idx) # for each layer 104 | print('** Actual replay buffer size: {}'.format(args.rmsize)) 105 | agent = DDPG(nb_states, nb_actions, args) 106 | return agent, env 107 | 108 | 109 | def main(): 110 | with open('config/partseg.yaml', 'r') as f: 111 | args = AttrDict(yaml.safe_load(f.read())) 112 | create_attr_dict(args) 113 | print(args.model) 114 | # HYPER PARAMETER 115 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 116 | work_dir = args.work_dir 117 | if not os.path.exists(work_dir): 118 | os.makedirs(work_dir) 119 | 120 | log_file = os.path.join(work_dir, 'search.log') 121 | logger = logging.getLogger('search') 122 | logger.setLevel(logging.INFO) 123 | file_handler = logging.FileHandler(log_file, 'w') 124 | file_handler.setFormatter( 125 | logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) 126 | file_handler.setLevel(logging.INFO) 127 | logger.addHandler(file_handler) 128 | console = logging.StreamHandler() 129 | logger.addHandler(console) 130 | 131 | root = hydra.utils.to_absolute_path( 132 | 'data/shapenetcore_partanno_segmentation_benchmark_v0_normal/') 133 | 134 | TRAIN_DATASET = PartNormalDataset( 135 | root=root, npoints=args.num_point, split='trainval', normal_channel=args.normal) 136 | train_loader = torch.utils.data.DataLoader( 137 | TRAIN_DATASET, batch_size=args.batch_size, 
shuffle=True, num_workers=args.num_workers, drop_last=True)
138 |     TEST_DATASET = PartNormalDataset(
139 |         root=root, npoints=args.num_point, split='test', normal_channel=args.normal)
140 |     val_loader = torch.utils.data.DataLoader(
141 |         TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
142 | 
143 |     # MODEL LOADING
144 |     args.input_dim = (6 if args.normal else 3) + 16
145 |     args.num_class = 50
146 |     shutil.copy(hydra.utils.to_absolute_path('models/model.py'), '.')
147 | 
148 |     model = getattr(importlib.import_module('models.model'),
149 |                     'PointTransformerSeg')(args).cuda()
150 | 
151 |     pretrained = torch.load("best_model.pth")
152 |     if 'model_state_dict' in pretrained:
153 |         pretrained = pretrained['model_state_dict']
154 |     # pretrained = model.state_dict()
155 |     # for m in model.modules():
156 |     #     print(type(m))
157 |     #     pass
158 |     agent, env = init_agent(model, pretrained, train_loader,
159 |                             val_loader, args.num_category,
160 |                             args.num_class, args.num_point, logger)
161 | 
162 |     best_reward = -math.inf
163 |     best_policy = []
164 |     agent.is_training = True
165 |     step = episode = episode_steps = 0
166 |     episode_reward = 0.
167 |     observation = None
168 |     T = []  # trajectory
169 |     if not os.path.exists(args.output):
170 |         os.mkdir(args.output)
171 |     while episode < args.train_episode:  # counting based on episode
172 |         # reset if it is the start of episode
173 |         if observation is None:
174 |             observation = deepcopy(env.reset())
175 |             agent.reset(observation)
176 | 
177 |         # agent picks an action ...
178 |         if episode <= args.warmup:
179 |             action = agent.random_action()
180 |         else:
181 |             action = agent.select_action(observation, episode=episode)
182 | 
183 |         # env responds with next_observation, reward, terminate_info
184 |         observation2, reward, done, info = env.step(action)
185 |         observation2 = deepcopy(observation2)
186 | 
187 |         T.append([reward, deepcopy(observation),
188 |                   deepcopy(observation2), action, done])
189 | 
190 |         # [optional] save intermediate model
191 |         if episode % int(args.train_episode / 10) == 0:
192 |             agent.save_model(args.output)
193 | 
194 |         # update
195 |         step += 1
196 |         episode_steps += 1
197 |         episode_reward += reward
198 |         observation = deepcopy(observation2)
199 | 
200 |         if done:  # end of episode
201 | 
202 |             logger.info('#{}: episode_reward:{:.4f} acc: {:.4f}, weight: {:.4f} %'.format(episode, episode_reward,
203 |                                                                                           info['accuracy'],
204 |                                                                                           info['w_ratio'] * 100))
205 | 
206 |             final_reward = T[-1][0]
207 |             # agent observes and updates policy
208 |             for r_t, s_t, s_t1, a_t, done in T:
209 |                 agent.observe(final_reward, s_t, s_t1, a_t, done)
210 |             if episode > args.warmup:
211 |                 for _ in range(args.n_update):
212 |                     agent.update_policy()
213 | 
214 |             agent.memory.append(
215 |                 observation,
216 |                 agent.select_action(observation, episode=episode),
217 |                 0., False
218 |             )
219 | 
220 |             # reset
221 |             observation = None
222 |             episode_steps = 0
223 |             episode_reward = 0.
224 | episode += 1 225 | T = [] 226 | 227 | if final_reward > best_reward: 228 | best_reward = final_reward 229 | best_policy = env.strategy 230 | 231 | logger.info('best reward: {}\n'.format(best_reward)) 232 | logger.info('best policy: {}\n'.format(best_policy)) 233 | 234 | return best_policy, best_reward 235 | 236 | 237 | if __name__ == '__main__': 238 | 239 | main() 240 | -------------------------------------------------------------------------------- /haq_lib/lib/rl/ddpg.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.sys.path.insert(0, os.path.abspath("../..")) 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.optim import Adam 9 | 10 | from haq_lib.lib.rl.memory import SequentialMemory 11 | 12 | from haq_lib.lib.utils.utils import to_numpy, to_tensor, sample_from_truncated_normal_distribution 13 | 14 | criterion = nn.MSELoss() 15 | USE_CUDA = torch.cuda.is_available() 16 | 17 | 18 | class Actor(nn.Module): 19 | def __init__(self, nb_states, nb_actions, hidden1=400, hidden2=300, init_w=3e-3): 20 | super(Actor, self).__init__() 21 | self.fc1 = nn.Linear(nb_states, hidden1) 22 | self.fc2 = nn.Linear(hidden1, hidden2) 23 | self.fc3 = nn.Linear(hidden2, nb_actions) 24 | self.relu = nn.ReLU() 25 | self.sigmoid = nn.Sigmoid() 26 | 27 | def forward(self, x): 28 | out = self.fc1(x) 29 | out = self.relu(out) 30 | out = self.fc2(out) 31 | out = self.relu(out) 32 | out = self.fc3(out) 33 | out = self.sigmoid(out) 34 | return out 35 | 36 | 37 | class Critic(nn.Module): 38 | def __init__(self, nb_states, nb_actions, hidden1=400, hidden2=300, init_w=3e-3): 39 | super(Critic, self).__init__() 40 | self.fc11 = nn.Linear(nb_states, hidden1) 41 | self.fc12 = nn.Linear(nb_actions, hidden1) 42 | self.fc2 = nn.Linear(hidden1, hidden2) 43 | self.fc3 = nn.Linear(hidden2, 1) 44 | self.relu = nn.ReLU() 45 | 46 | def forward(self, xs): 47 | x, a = xs 48 | out = self.fc11(x) + self.fc12(a) 49 | out = self.relu(out) 50 | out = self.fc2(out) 51 | out = self.relu(out) 52 | out = self.fc3(out) 53 | return out 54 | 55 | 56 | class DDPG(object): 57 | def __init__(self, nb_states, nb_actions, args): 58 | 59 | if args.seed > 0: 60 | self.seed(args.seed) 61 | 62 | self.nb_states = nb_states 63 | self.nb_actions = nb_actions 64 | 65 | # Create Actor and Critic Network 66 | net_cfg = { 67 | 'hidden1': args.hidden1, 68 | 'hidden2': args.hidden2, 69 | 'init_w': args.init_w 70 | } 71 | self.actor = Actor(self.nb_states, self.nb_actions, **net_cfg) 72 | self.actor_target = Actor(self.nb_states, self.nb_actions, **net_cfg) 73 | self.actor_optim = Adam(self.actor.parameters(), lr=args.lr_a) 74 | 75 | self.critic = Critic(self.nb_states, self.nb_actions, **net_cfg) 76 | self.critic_target = Critic(self.nb_states, self.nb_actions, **net_cfg) 77 | self.critic_optim = Adam(self.critic.parameters(), lr=args.lr_c) 78 | 79 | self.hard_update(self.actor_target, self.actor) # Make sure target is with the same weight 80 | self.hard_update(self.critic_target, self.critic) 81 | 82 | # Create replay buffer 83 | self.memory = SequentialMemory(limit=args.rmsize, window_length=args.window_length) 84 | # self.random_process = OrnsteinUhlenbeckProcess(size=nb_actions, theta=args.ou_theta, mu=args.ou_mu, 85 | # sigma=args.ou_sigma) 86 | 87 | # Hyper-parameters 88 | self.batch_size = args.bsize 89 | self.tau = args.tau 90 | self.discount = args.discount 91 | self.depsilon = 1.0 / args.epsilon 92 | self.lbound = 0. # args.lbound 93 | self.rbound = 1. 
# args.rbound
94 | 
95 |         # noise
96 |         self.init_delta = args.init_delta
97 |         self.delta_decay = args.delta_decay
98 |         self.warmup = args.warmup
99 |         self.delta = args.init_delta
100 |         # loss
101 |         self.value_loss = 0.0
102 |         self.policy_loss = 0.0
103 | 
104 |         #
105 |         self.epsilon = 1.0
106 |         # self.s_t = None  # Most recent state
107 |         # self.a_t = None  # Most recent action
108 |         self.is_training = True
109 | 
110 |         #
111 |         if USE_CUDA: self.cuda()
112 | 
113 |         # moving average baseline
114 |         self.moving_average = None
115 |         self.moving_alpha = 0.5  # based on batch, so small
116 | 
117 |     def update_policy(self):
118 |         # Sample batch
119 |         state_batch, action_batch, reward_batch, \
120 |             next_state_batch, terminal_batch = self.memory.sample_and_split(self.batch_size)
121 | 
122 |         # normalize the reward
123 |         batch_mean_reward = np.mean(reward_batch)
124 |         if self.moving_average is None:
125 |             self.moving_average = batch_mean_reward
126 |         else:
127 |             self.moving_average += self.moving_alpha * (batch_mean_reward - self.moving_average)
128 |         reward_batch -= self.moving_average
129 |         # if reward_batch.std() > 0:
130 |         #     reward_batch /= reward_batch.std()
131 | 
132 |         # Prepare for the target q batch
133 |         with torch.no_grad():
134 |             next_q_values = self.critic_target([
135 |                 to_tensor(next_state_batch),
136 |                 self.actor_target(to_tensor(next_state_batch)),
137 |             ])
138 | 
139 |         target_q_batch = to_tensor(reward_batch) + \
140 |             self.discount * to_tensor(terminal_batch.astype(np.float32)) * next_q_values
141 | 
142 |         # Critic update
143 |         self.critic.zero_grad()
144 | 
145 |         q_batch = self.critic([to_tensor(state_batch), to_tensor(action_batch)])
146 | 
147 |         value_loss = criterion(q_batch, target_q_batch)
148 |         value_loss.backward()
149 |         self.critic_optim.step()
150 | 
151 |         # Actor update
152 |         self.actor.zero_grad()
153 | 
154 |         policy_loss = -self.critic([
155 |             to_tensor(state_batch),
156 |             self.actor(to_tensor(state_batch))
157 |         ])
158 | 
159 |         policy_loss = policy_loss.mean()
160 |         policy_loss.backward()
161 |         self.actor_optim.step()
162 | 
163 |         # Target update
164 |         self.soft_update(self.actor_target, self.actor)
165 |         self.soft_update(self.critic_target, self.critic)
166 | 
167 |         # update for log
168 |         self.value_loss = value_loss
169 |         self.policy_loss = policy_loss
170 | 
171 |     def eval(self):
172 |         self.actor.eval()
173 |         self.actor_target.eval()
174 |         self.critic.eval()
175 |         self.critic_target.eval()
176 | 
177 |     def cuda(self):
178 |         self.actor.cuda()
179 |         self.actor_target.cuda()
180 |         self.critic.cuda()
181 |         self.critic_target.cuda()
182 | 
183 |     def observe(self, r_t, s_t, s_t1, a_t, done):
184 |         if self.is_training:
185 |             self.memory.append(s_t, a_t, r_t, done)  # save to memory
186 |             # self.s_t = s_t1
187 | 
188 |     def random_action(self):
189 |         action = np.random.uniform(self.lbound, self.rbound, self.nb_actions)
190 |         # self.a_t = action
191 |         return action
192 | 
193 |     def select_action(self, s_t, episode, decay_epsilon=True):
194 |         # assert episode >= self.warmup, 'Episode: {} warmup: {}'.format(episode, self.warmup)
195 |         action = to_numpy(self.actor(to_tensor(np.array(s_t).reshape(1, -1)))).squeeze(0)
196 |         delta = self.init_delta * (self.delta_decay ** (episode - self.warmup))
197 |         # action += self.is_training * max(self.epsilon, 0) * self.random_process.sample()
198 |         # from IPython import embed; embed()  # TODO enable decay_epsilon=True
199 |         action = sample_from_truncated_normal_distribution(lower=self.lbound, upper=self.rbound, mu=action, sigma=delta)
200 |         action = np.clip(action,
self.lbound, self.rbound) 201 | # update for log 202 | self.delta = delta 203 | # self.a_t = action 204 | return action 205 | 206 | def reset(self, obs): 207 | pass 208 | # self.s_t = obs 209 | # self.random_process.reset_states() 210 | 211 | def load_weights(self, output): 212 | if output is None: return 213 | 214 | self.actor.load_state_dict( 215 | torch.load('{}/actor.pkl'.format(output)) 216 | ) 217 | 218 | self.critic.load_state_dict( 219 | torch.load('{}/critic.pkl'.format(output)) 220 | ) 221 | 222 | def save_model(self, output): 223 | torch.save( 224 | self.actor.state_dict(), 225 | '{}/actor.pkl'.format(output) 226 | ) 227 | torch.save( 228 | self.critic.state_dict(), 229 | '{}/critic.pkl'.format(output) 230 | ) 231 | 232 | def seed(self, s): 233 | torch.manual_seed(s) 234 | if USE_CUDA: 235 | torch.cuda.manual_seed(s) 236 | 237 | def soft_update(self, target, source): 238 | for target_param, param in zip(target.parameters(), source.parameters()): 239 | target_param.data.copy_( 240 | target_param.data * (1.0 - self.tau) + param.data * self.tau 241 | ) 242 | 243 | def hard_update(self, target, source): 244 | for target_param, param in zip(target.parameters(), source.parameters()): 245 | target_param.data.copy_(param.data) 246 | 247 | def get_delta(self): 248 | return self.delta 249 | 250 | def get_value_loss(self): 251 | return self.value_loss 252 | 253 | def get_policy_loss(self): 254 | return self.policy_loss 255 | -------------------------------------------------------------------------------- /haq_lib/lib/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Code for "[HAQ: Hardware-Aware Automated Quantization with Mixed Precision" 2 | # Kuan Wang*, Zhijian Liu*, Yujun Lin*, Ji Lin, Song Han 3 | # {kuanwang, zhijian, yujunlin, jilin, songhan}@mit.edu 4 | 5 | import torch 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from haq_lib.lib.utils.quantize_utils import QConv2d, QLinear 9 | 10 | 11 | class AverageMeter(object): 12 | def __init__(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def reset(self): 19 | self.val = 0 20 | self.avg = 0 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.val = val 26 | self.sum += val * n 27 | self.count += n 28 | if self.count > 0: 29 | self.avg = self.sum / self.count 30 | 31 | def accumulate(self, val, n=1): 32 | self.sum += val 33 | self.count += n 34 | if self.count > 0: 35 | self.avg = self.sum / self.count 36 | 37 | 38 | class Logger(object): 39 | def __init__(self, fpath, title=None, resume=False): 40 | self.file = None 41 | self.resume = resume 42 | self.title = '' if title == None else title 43 | if fpath is not None: 44 | if resume: 45 | self.file = open(fpath, 'r') 46 | name = self.file.readline() 47 | self.names = name.rstrip().split('\t') 48 | self.numbers = {} 49 | for _, name in enumerate(self.names): 50 | self.numbers[name] = [] 51 | 52 | for numbers in self.file: 53 | numbers = numbers.rstrip().split('\t') 54 | for i in range(0, len(numbers)): 55 | self.numbers[self.names[i]].append(numbers[i]) 56 | self.file.close() 57 | self.file = open(fpath, 'a') 58 | else: 59 | self.file = open(fpath, 'w') 60 | 61 | def set_names(self, names): 62 | if self.resume: 63 | pass 64 | # initialize numbers as empty list 65 | self.numbers = {} 66 | self.names = names 67 | for _, name in enumerate(self.names): 68 | self.file.write(name) 69 | self.file.write('\t') 70 | self.numbers[name] = [] 71 | 
self.file.write('\n') 72 | self.file.flush() 73 | 74 | def append(self, numbers): 75 | assert len(self.names) == len(numbers), 'Numbers do not match names' 76 | for index, num in enumerate(numbers): 77 | self.file.write("{0:.6f}".format(num)) 78 | self.file.write('\t') 79 | self.numbers[self.names[index]].append(num) 80 | self.file.write('\n') 81 | self.file.flush() 82 | 83 | def plot(self, names=None): 84 | names = self.names if names == None else names 85 | numbers = self.numbers 86 | for _, name in enumerate(names): 87 | x = np.arange(len(numbers[name])) 88 | plt.plot(x, np.asarray(numbers[name])) 89 | plt.legend([self.title + '(' + name + ')' for name in names]) 90 | plt.grid(True) 91 | 92 | def close(self): 93 | if self.file is not None: 94 | self.file.close() 95 | 96 | 97 | def accuracy(output, target, topk=(1,)): 98 | maxk = max(topk) 99 | batch_size = target.size(0) 100 | 101 | _, pred = output.topk(maxk, 1, True, True) 102 | pred = pred.t() 103 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 104 | 105 | res = [] 106 | for k in topk: 107 | correct_k = correct[:k].view(-1).float().sum(0) 108 | res.append(correct_k.mul_(100.0 / batch_size)) 109 | return res 110 | 111 | 112 | USE_CUDA = torch.cuda.is_available() 113 | FLOAT = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor 114 | from torch.autograd import Variable 115 | 116 | 117 | def to_numpy(var): 118 | # return var.cpu().data.numpy() 119 | return var.cpu().data.numpy() if USE_CUDA else var.data.numpy() 120 | 121 | 122 | def to_tensor(ndarray, volatile=False, requires_grad=False, dtype=FLOAT): 123 | return Variable( 124 | torch.from_numpy(ndarray), volatile=volatile, requires_grad=requires_grad 125 | ).type(dtype) 126 | 127 | 128 | def sample_from_truncated_normal_distribution(lower, upper, mu, sigma, size=1): 129 | from scipy import stats 130 | return stats.truncnorm.rvs((lower-mu)/sigma, (upper-mu)/sigma, loc=mu, scale=sigma, size=size) 131 | 132 | 133 | # logging 134 | def prRed(prt): print("\033[91m {}\033[00m" .format(prt)) 135 | def prGreen(prt): print("\033[92m {}\033[00m" .format(prt)) 136 | def prYellow(prt): print("\033[93m {}\033[00m" .format(prt)) 137 | def prLightPurple(prt): print("\033[94m {}\033[00m" .format(prt)) 138 | def prPurple(prt): print("\033[95m {}\033[00m" .format(prt)) 139 | def prCyan(prt): print("\033[96m {}\033[00m" .format(prt)) 140 | def prLightGray(prt): print("\033[97m {}\033[00m" .format(prt)) 141 | def prBlack(prt): print("\033[98m {}\033[00m" .format(prt)) 142 | 143 | 144 | def get_num_gen(gen): 145 | return sum(1 for x in gen) 146 | 147 | 148 | def is_leaf(model): 149 | return get_num_gen(model.children()) == 0 150 | 151 | 152 | def get_layer_info(layer): 153 | layer_str = str(layer) 154 | type_name = layer_str[:layer_str.find('(')].strip() 155 | return type_name 156 | 157 | 158 | def get_layer_param(model): 159 | import operator 160 | import functools 161 | 162 | return sum([functools.reduce(operator.mul, i.size(), 1) for i in model.parameters()]) 163 | 164 | 165 | def measure_layer(layer, x): 166 | global count_ops, count_params 167 | delta_ops = 0 168 | delta_params = 0 169 | multi_add = 1 170 | type_name = get_layer_info(layer) 171 | 172 | # ops_conv 173 | if type_name in ['Conv2d', 'QConv2d']: 174 | out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) / 175 | layer.stride[0] + 1) 176 | out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) / 177 | layer.stride[1] + 1) 178 | layer.in_h = x.size()[2] 179 | layer.in_w = x.size()[3] 180 | 
layer.out_h = out_h 181 | layer.out_w = out_w 182 | delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \ 183 | layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add 184 | delta_params = get_layer_param(layer) 185 | layer.flops = delta_ops 186 | layer.params = delta_params 187 | 188 | # ops_nonlinearity 189 | elif type_name in ['ReLU']: 190 | delta_ops = x.numel() / x.size(0) 191 | delta_params = get_layer_param(layer) 192 | 193 | # ops_pooling 194 | elif type_name in ['AvgPool2d']: 195 | in_w = x.size()[2] 196 | kernel_ops = layer.kernel_size * layer.kernel_size 197 | out_w = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1) 198 | out_h = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1) 199 | delta_ops = x.size()[1] * out_w * out_h * kernel_ops 200 | delta_params = get_layer_param(layer) 201 | 202 | elif type_name in ['AdaptiveAvgPool2d']: 203 | delta_ops = x.size()[1] * x.size()[2] * x.size()[3] 204 | delta_params = get_layer_param(layer) 205 | 206 | # ops_linear 207 | elif type_name in ['Linear', 'QLinear']: 208 | weight_ops = layer.weight.numel() * multi_add 209 | if layer.bias is not None: 210 | bias_ops = layer.bias.numel() 211 | else: 212 | bias_ops = 0 213 | layer.in_h = x.size()[1] 214 | layer.in_w = 1 215 | delta_ops = weight_ops + bias_ops 216 | delta_params = get_layer_param(layer) 217 | layer.flops = delta_ops 218 | layer.params = delta_params 219 | 220 | # ops_nothing 221 | elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout']: 222 | delta_params = get_layer_param(layer) 223 | 224 | # unknown layer type 225 | else: 226 | delta_params = get_layer_param(layer) 227 | 228 | count_ops += delta_ops 229 | count_params += delta_params 230 | 231 | return delta_ops, delta_params 232 | 233 | 234 | def measure_model(model, H, W): 235 | global count_ops, count_params 236 | count_ops = 0 237 | count_params = 0 238 | data = torch.zeros(1, H, W).cuda() 239 | 240 | def should_measure(x): 241 | return is_leaf(x) 242 | 243 | def modify_forward(model): 244 | for child in model.children(): 245 | if should_measure(child): 246 | def new_forward(m): 247 | def lambda_forward(x): 248 | measure_layer(m, x) 249 | return m.old_forward(x) 250 | return lambda_forward 251 | child.old_forward = child.forward 252 | child.forward = new_forward(child) 253 | else: 254 | modify_forward(child) 255 | 256 | def restore_forward(model): 257 | for child in model.children(): 258 | # leaf node 259 | if is_leaf(child) and hasattr(child, 'old_forward'): 260 | child.forward = child.old_forward 261 | child.old_forward = None 262 | else: 263 | restore_forward(child) 264 | 265 | modify_forward(model) 266 | model.forward(data) 267 | restore_forward(model) 268 | 269 | return count_ops, count_params 270 | 271 | -------------------------------------------------------------------------------- /haq_lib/lib/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Code for "[HAQ: Hardware-Aware Automated Quantization with Mixed Precision" 2 | # Kuan Wang*, Zhijian Liu*, Yujun Lin*, Ji Lin, Song Han 3 | # {kuanwang, zhijian, yujunlin, jilin, songhan}@mit.edu 4 | 5 | import os 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn.parallel 10 | import torch.optim 11 | import torch.utils.data 12 | import torchvision.transforms as transforms 13 | import torchvision.datasets as datasets 14 | from torch.utils.data.sampler import SubsetRandomSampler 15 | 16 | 17 | def 
get_dataset(dataset_name, batch_size, n_worker, data_root='data/imagenet', for_inception=False): 18 | print('==> Preparing data..') 19 | if dataset_name == 'imagenet': 20 | traindir = os.path.join(data_root, 'train') 21 | valdir = os.path.join(data_root, 'val') 22 | assert os.path.exists(traindir), traindir + ' not found' 23 | assert os.path.exists(valdir), valdir + ' not found' 24 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 25 | std=[0.229, 0.224, 0.225]) 26 | 27 | input_size = 299 if for_inception else 224 28 | 29 | train_loader = torch.utils.data.DataLoader( 30 | datasets.ImageFolder( 31 | traindir, transforms.Compose([ 32 | transforms.RandomResizedCrop(input_size), 33 | transforms.RandomHorizontalFlip(), 34 | transforms.ToTensor(), 35 | normalize, 36 | ])), 37 | batch_size=batch_size, shuffle=True, 38 | num_workers=n_worker, pin_memory=True) 39 | 40 | val_loader = torch.utils.data.DataLoader( 41 | datasets.ImageFolder(valdir, transforms.Compose([ 42 | transforms.Resize(int(input_size / 0.875)), 43 | transforms.CenterCrop(input_size), 44 | transforms.ToTensor(), 45 | normalize, 46 | ])), 47 | batch_size=batch_size, shuffle=False, 48 | num_workers=n_worker, pin_memory=True) 49 | 50 | n_class = 1000 51 | elif dataset_name == 'imagenet100': 52 | traindir = os.path.join(data_root, 'train') 53 | valdir = os.path.join(data_root, 'val') 54 | assert os.path.exists(traindir), traindir + ' not found' 55 | assert os.path.exists(valdir), valdir + ' not found' 56 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 57 | std=[0.229, 0.224, 0.225]) 58 | 59 | input_size = 299 if for_inception else 224 60 | 61 | train_loader = torch.utils.data.DataLoader( 62 | datasets.ImageFolder( 63 | traindir, transforms.Compose([ 64 | transforms.RandomResizedCrop(input_size), 65 | transforms.RandomHorizontalFlip(), 66 | transforms.ToTensor(), 67 | normalize, 68 | ])), 69 | batch_size=batch_size, shuffle=True, 70 | num_workers=n_worker, pin_memory=True) 71 | 72 | val_loader = torch.utils.data.DataLoader( 73 | datasets.ImageFolder(valdir, transforms.Compose([ 74 | transforms.Resize(int(input_size / 0.875)), 75 | transforms.CenterCrop(input_size), 76 | transforms.ToTensor(), 77 | normalize, 78 | ])), 79 | batch_size=batch_size, shuffle=False, 80 | num_workers=n_worker, pin_memory=True) 81 | 82 | n_class = 100 83 | elif dataset_name == 'imagenet10': 84 | traindir = os.path.join(data_root, 'train') 85 | valdir = os.path.join(data_root, 'val') 86 | assert os.path.exists(traindir), traindir + ' not found' 87 | assert os.path.exists(valdir), valdir + ' not found' 88 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 89 | std=[0.229, 0.224, 0.225]) 90 | 91 | input_size = 299 if for_inception else 224 92 | 93 | train_loader = torch.utils.data.DataLoader( 94 | datasets.ImageFolder( 95 | traindir, transforms.Compose([ 96 | transforms.RandomResizedCrop(input_size), 97 | transforms.RandomHorizontalFlip(), 98 | transforms.ToTensor(), 99 | normalize, 100 | ])), 101 | batch_size=batch_size, shuffle=True, 102 | num_workers=n_worker, pin_memory=True) 103 | 104 | val_loader = torch.utils.data.DataLoader( 105 | datasets.ImageFolder(valdir, transforms.Compose([ 106 | transforms.Resize(int(input_size / 0.875)), 107 | transforms.CenterCrop(input_size), 108 | transforms.ToTensor(), 109 | normalize, 110 | ])), 111 | batch_size=batch_size, shuffle=False, 112 | num_workers=n_worker, pin_memory=True) 113 | 114 | n_class = 10 115 | else: 116 | # Add customized data here 117 | raise NotImplementedError 118 
| return train_loader, val_loader, n_class 119 | 120 | 121 | def get_split_train_dataset(dataset_name, batch_size, n_worker, val_size, train_size=None, random_seed=1, 122 | data_root='data/imagenet', for_inception=False, shuffle=True): 123 | if shuffle: 124 | index_sampler = SubsetRandomSampler 125 | else: 126 | # use the same order 127 | class SubsetSequentialSampler(SubsetRandomSampler): 128 | def __iter__(self): 129 | return (self.indices[i] for i in torch.arange(len(self.indices)).int()) 130 | index_sampler = SubsetSequentialSampler 131 | 132 | print('==> Preparing data..') 133 | if dataset_name == 'imagenet': 134 | 135 | traindir = os.path.join(data_root, 'train') 136 | valdir = os.path.join(data_root, 'val') 137 | assert os.path.exists(traindir), traindir + ' not found' 138 | assert os.path.exists(valdir), valdir + ' not found' 139 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 140 | std=[0.229, 0.224, 0.225]) 141 | 142 | input_size = 299 if for_inception else 224 143 | train_transform = transforms.Compose([ 144 | transforms.RandomResizedCrop(input_size), 145 | transforms.RandomHorizontalFlip(), 146 | transforms.ToTensor(), 147 | normalize, 148 | ]) 149 | test_transform = transforms.Compose([ 150 | transforms.Resize(int(input_size/0.875)), 151 | transforms.CenterCrop(input_size), 152 | transforms.ToTensor(), 153 | normalize, 154 | ]) 155 | 156 | trainset = datasets.ImageFolder(traindir, train_transform) 157 | valset = datasets.ImageFolder(traindir, test_transform) 158 | 159 | n_train = len(trainset) 160 | indices = list(range(n_train)) 161 | # shuffle the indices 162 | np.random.seed(random_seed) 163 | np.random.shuffle(indices) 164 | assert val_size < n_train, 'val size should less than n_train' 165 | train_idx, val_idx = indices[val_size:], indices[:val_size] 166 | if train_size: 167 | train_idx = train_idx[:train_size] 168 | print('Data: train: {}, val: {}'.format(len(train_idx), len(val_idx))) 169 | 170 | train_sampler = index_sampler(train_idx) 171 | val_sampler = index_sampler(val_idx) 172 | 173 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=train_sampler, 174 | num_workers=n_worker, pin_memory=True) 175 | val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, sampler=val_sampler, 176 | num_workers=n_worker, pin_memory=True) 177 | n_class = 1000 178 | elif dataset_name == 'imagenet100': 179 | 180 | traindir = os.path.join(data_root, 'train') 181 | valdir = os.path.join(data_root, 'val') 182 | assert os.path.exists(traindir), traindir + ' not found' 183 | assert os.path.exists(valdir), valdir + ' not found' 184 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 185 | std=[0.229, 0.224, 0.225]) 186 | 187 | input_size = 299 if for_inception else 224 188 | train_transform = transforms.Compose([ 189 | transforms.RandomResizedCrop(input_size), 190 | transforms.RandomHorizontalFlip(), 191 | transforms.ToTensor(), 192 | normalize, 193 | ]) 194 | test_transform = transforms.Compose([ 195 | transforms.Resize(int(input_size/0.875)), 196 | transforms.CenterCrop(input_size), 197 | transforms.ToTensor(), 198 | normalize, 199 | ]) 200 | 201 | trainset = datasets.ImageFolder(traindir, train_transform) 202 | valset = datasets.ImageFolder(traindir, test_transform) 203 | 204 | n_train = len(trainset) 205 | indices = list(range(n_train)) 206 | # shuffle the indices 207 | np.random.seed(random_seed) 208 | np.random.shuffle(indices) 209 | assert val_size < n_train, 'val size should less than n_train' 210 | train_idx, 
val_idx = indices[val_size:], indices[:val_size] 211 | if train_size: 212 | train_idx = train_idx[:train_size] 213 | print('Data: train: {}, val: {}'.format(len(train_idx), len(val_idx))) 214 | 215 | train_sampler = index_sampler(train_idx) 216 | val_sampler = index_sampler(val_idx) 217 | 218 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=train_sampler, 219 | num_workers=n_worker, pin_memory=True) 220 | val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, sampler=val_sampler, 221 | num_workers=n_worker, pin_memory=True) 222 | n_class = 100 223 | else: 224 | raise NotImplementedError 225 | 226 | return train_loader, val_loader, n_class 227 | -------------------------------------------------------------------------------- /utils/provider.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def normalize_data(batch_data): 4 | """ Normalize the batch data: center each point cloud at the origin and scale it into the unit sphere. 5 | Input: 6 | BxNxC array 7 | Output: 8 | BxNxC array 9 | """ 10 | B, N, C = batch_data.shape 11 | normal_data = np.zeros((B, N, C)) 12 | for b in range(B): 13 | pc = batch_data[b] 14 | centroid = np.mean(pc, axis=0) 15 | pc = pc - centroid 16 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) 17 | pc = pc / m 18 | normal_data[b] = pc 19 | return normal_data 20 | 21 | 22 | def shuffle_data(data, labels): 23 | """ Shuffle data and labels. 24 | Input: 25 | data: B,N,... numpy array 26 | label: B,... numpy array 27 | Return: 28 | shuffled data, label and shuffle indices 29 | """ 30 | idx = np.arange(len(labels)) 31 | np.random.shuffle(idx) 32 | return data[idx, ...], labels[idx], idx 33 | 34 | def shuffle_points(batch_data): 35 | """ Shuffle orders of points in each point cloud -- changes FPS behavior. 36 | Use the same shuffling idx for the entire batch. 37 | Input: 38 | BxNxC array 39 | Output: 40 | BxNxC array 41 | """ 42 | idx = np.arange(batch_data.shape[1]) 43 | np.random.shuffle(idx) 44 | return batch_data[:,idx,:] 45 | 46 | def rotate_point_cloud(batch_data): 47 | """ Randomly rotate the point clouds to augment the dataset; 48 | rotation is per shape, about the up (Y) axis 49 | Input: 50 | BxNx3 array, original batch of point clouds 51 | Return: 52 | BxNx3 array, rotated batch of point clouds 53 | """ 54 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 55 | for k in range(batch_data.shape[0]): 56 | rotation_angle = np.random.uniform() * 2 * np.pi 57 | cosval = np.cos(rotation_angle) 58 | sinval = np.sin(rotation_angle) 59 | rotation_matrix = np.array([[cosval, 0, sinval], 60 | [0, 1, 0], 61 | [-sinval, 0, cosval]]) 62 | shape_pc = batch_data[k, ...] 63 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 64 | return rotated_data 65 | 66 | def rotate_point_cloud_z(batch_data): 67 | """ Randomly rotate the point clouds to augment the dataset; 68 | rotation is per shape, about the Z axis 69 | Input: 70 | BxNx3 array, original batch of point clouds 71 | Return: 72 | BxNx3 array, rotated batch of point clouds 73 | """ 74 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 75 | for k in range(batch_data.shape[0]): 76 | rotation_angle = np.random.uniform() * 2 * np.pi 77 | cosval = np.cos(rotation_angle) 78 | sinval = np.sin(rotation_angle) 79 | rotation_matrix = np.array([[cosval, sinval, 0], 80 | [-sinval, cosval, 0], 81 | [0, 0, 1]]) 82 | shape_pc = batch_data[k, ...] 83 | rotated_data[k, ...]
= np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 84 | return rotated_data 85 | 86 | def rotate_point_cloud_with_normal(batch_xyz_normal): 87 | ''' Randomly rotate XYZ, normal point cloud. 88 | Input: 89 | batch_xyz_normal: B,N,6, first three channels are XYZ, last three are the normals 90 | Output: 91 | B,N,6, rotated XYZ, normal point cloud 92 | ''' 93 | for k in range(batch_xyz_normal.shape[0]): 94 | rotation_angle = np.random.uniform() * 2 * np.pi 95 | cosval = np.cos(rotation_angle) 96 | sinval = np.sin(rotation_angle) 97 | rotation_matrix = np.array([[cosval, 0, sinval], 98 | [0, 1, 0], 99 | [-sinval, 0, cosval]]) 100 | shape_pc = batch_xyz_normal[k,:,0:3] 101 | shape_normal = batch_xyz_normal[k,:,3:6] 102 | batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 103 | batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix) 104 | return batch_xyz_normal 105 | 106 | def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18): 107 | """ Randomly perturb the point clouds by small rotations 108 | Input: 109 | BxNx6 array, original batch of point clouds and point normals 110 | Return: 111 | BxNx6 array, rotated batch of point clouds and point normals 112 | """ 113 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 114 | for k in range(batch_data.shape[0]): 115 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 116 | Rx = np.array([[1,0,0], 117 | [0,np.cos(angles[0]),-np.sin(angles[0])], 118 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 119 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 120 | [0,1,0], 121 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 122 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 123 | [np.sin(angles[2]),np.cos(angles[2]),0], 124 | [0,0,1]]) 125 | R = np.dot(Rz, np.dot(Ry,Rx)) 126 | shape_pc = batch_data[k,:,0:3] 127 | shape_normal = batch_data[k,:,3:6] 128 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R) 129 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R) 130 | return rotated_data 131 | 132 | 133 | def rotate_point_cloud_by_angle(batch_data, rotation_angle): 134 | """ Rotate the point cloud about the up axis by a given angle. 135 | Input: 136 | BxNx3 array, original batch of point clouds 137 | Return: 138 | BxNx3 array, rotated batch of point clouds 139 | """ 140 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 141 | for k in range(batch_data.shape[0]): 142 | #rotation_angle = np.random.uniform() * 2 * np.pi 143 | cosval = np.cos(rotation_angle) 144 | sinval = np.sin(rotation_angle) 145 | rotation_matrix = np.array([[cosval, 0, sinval], 146 | [0, 1, 0], 147 | [-sinval, 0, cosval]]) 148 | shape_pc = batch_data[k,:,0:3] 149 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 150 | return rotated_data 151 | 152 | def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle): 153 | """ Rotate the point cloud about the up axis by a given angle.
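        The same right-multiplied Y-axis rotation is applied to both coordinates
        and normals so they stay consistent; points are stored as row vectors, so
        the code computes pc @ R. For example, rotation_angle = pi/2 maps the
        point (1, 0, 0) to (0, 0, 1).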
154 | Input: 155 | BxNx6 array, original batch of point clouds with normal 156 | scalar, angle of rotation 157 | Return: 158 | BxNx6 array, rotated batch of point clouds with normal 159 | """ 160 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 161 | for k in range(batch_data.shape[0]): 162 | #rotation_angle = np.random.uniform() * 2 * np.pi 163 | cosval = np.cos(rotation_angle) 164 | sinval = np.sin(rotation_angle) 165 | rotation_matrix = np.array([[cosval, 0, sinval], 166 | [0, 1, 0], 167 | [-sinval, 0, cosval]]) 168 | shape_pc = batch_data[k,:,0:3] 169 | shape_normal = batch_data[k,:,3:6] 170 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 171 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix) 172 | return rotated_data 173 | 174 | 175 | 176 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18): 177 | """ Randomly perturb the point clouds by small rotations 178 | Input: 179 | BxNx3 array, original batch of point clouds 180 | Return: 181 | BxNx3 array, rotated batch of point clouds 182 | """ 183 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 184 | for k in range(batch_data.shape[0]): 185 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 186 | Rx = np.array([[1,0,0], 187 | [0,np.cos(angles[0]),-np.sin(angles[0])], 188 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 189 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 190 | [0,1,0], 191 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 192 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 193 | [np.sin(angles[2]),np.cos(angles[2]),0], 194 | [0,0,1]]) 195 | R = np.dot(Rz, np.dot(Ry,Rx)) 196 | shape_pc = batch_data[k, ...] 197 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R) 198 | return rotated_data 199 | 200 | 201 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): 202 | """ Randomly jitter points. Jittering is per point. 203 | Input: 204 | BxNx3 array, original batch of point clouds 205 | Return: 206 | BxNx3 array, jittered batch of point clouds 207 | """ 208 | B, N, C = batch_data.shape 209 | assert(clip > 0) 210 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip) 211 | jittered_data += batch_data 212 | return jittered_data 213 | 214 | def shift_point_cloud(batch_data, shift_range=0.1): 215 | """ Randomly shift point cloud. Shift is per point cloud. 216 | Input: 217 | BxNx3 array, original batch of point clouds 218 | Return: 219 | BxNx3 array, shifted batch of point clouds 220 | """ 221 | B, N, C = batch_data.shape 222 | shifts = np.random.uniform(-shift_range, shift_range, (B,3)) 223 | for batch_index in range(B): 224 | batch_data[batch_index,:,:] += shifts[batch_index,:] 225 | return batch_data 226 | 227 | 228 | def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25): 229 | """ Randomly scale the point cloud. Scale is per point cloud.
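        The factor is drawn once per cloud from U(scale_low, scale_high) and
        multiplied into every point in place, so with the defaults each cloud
        shrinks or grows by up to roughly 25% (factor in [0.8, 1.25]).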
230 | Input: 231 | BxNx3 array, original batch of point clouds 232 | Return: 233 | BxNx3 array, scaled batch of point clouds 234 | """ 235 | B, N, C = batch_data.shape 236 | scales = np.random.uniform(scale_low, scale_high, B) 237 | for batch_index in range(B): 238 | batch_data[batch_index,:,:] *= scales[batch_index] 239 | return batch_data 240 | 241 | def random_point_dropout(batch_pc, max_dropout_ratio=0.875): 242 | ''' batch_pc: BxNx3 ''' 243 | for b in range(batch_pc.shape[0]): 244 | dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 245 | drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0] 246 | if len(drop_idx)>0: 247 | batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point 248 | return batch_pc 249 | 250 | 251 | -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | import importlib 5 | import shutil 6 | from utils import provider 7 | import numpy as np 8 | 9 | from tqdm import tqdm 10 | from utils.dataset import PartNormalDataset 11 | import hydra 12 | import yaml 13 | from utils.general import to_categorical 14 | from utils.general import seg_classes, seg_label_to_cat 15 | from search import AttrDict, create_attr_dict 16 | 17 | 18 | def inplace_relu(m): 19 | classname = m.__class__.__name__ 20 | if classname.find('ReLU') != -1: 21 | m.inplace = True 22 | 23 | 24 | def main(): 25 | 26 | with open('config/partseg.yaml', 'r') as f: 27 | args = AttrDict(yaml.safe_load(f.read())) 28 | create_attr_dict(args) 29 | 30 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 31 | 32 | work_dir = args.work_dir 33 | if not os.path.exists(work_dir): 34 | os.makedirs(work_dir) 35 | 36 | log_file = os.path.join(work_dir, 'pretrain.log') 37 | logger = logging.getLogger('search') 38 | logger.setLevel(logging.INFO) 39 | file_handler = logging.FileHandler(log_file, 'w') 40 | file_handler.setFormatter( 41 | logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) 42 | file_handler.setLevel(logging.INFO) 43 | logger.addHandler(file_handler) 44 | console = logging.StreamHandler() 45 | logger.addHandler(console) 46 | 47 | root = hydra.utils.to_absolute_path( 48 | 'data/shapenetcore_partanno_segmentation_benchmark_v0_normal/') 49 | 50 | TRAIN_DATASET = PartNormalDataset( 51 | root=root, npoints=args.num_point, split='trainval', normal_channel=args.normal) 52 | trainDataLoader = torch.utils.data.DataLoader( 53 | TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True) 54 | TEST_DATASET = PartNormalDataset( 55 | root=root, npoints=args.num_point, split='test', normal_channel=args.normal) 56 | testDataLoader = torch.utils.data.DataLoader( 57 | TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=10) 58 | 59 | '''MODEL LOADING''' 60 | args.input_dim = (6 if args.normal else 3) + 16 61 | args.num_class = 50 62 | num_category = 16 63 | num_part = args.num_class 64 | shutil.copy(hydra.utils.to_absolute_path('models/model.py'), '.') 65 | 66 | model = getattr(importlib.import_module('models.model'), 67 | 'PointTransformerSeg')(args).cuda() 68 | criterion = torch.nn.CrossEntropyLoss() 69 | 70 | try: 71 | checkpoint = torch.load('best_model.pth') 72 | start_epoch = checkpoint['epoch'] 73 | model.load_state_dict(checkpoint['model_state_dict']) 74 | logger.info('Use pretrain model') 75 | except: 76 | logger.info('No existing model, starting 
training from scratch...') 77 | start_epoch = 0 78 | 79 | if args.optimizer == 'Adam': 80 | optimizer = torch.optim.Adam( 81 | model.parameters(), 82 | lr=args.learning_rate, 83 | betas=(0.9, 0.999), 84 | eps=1e-08, 85 | weight_decay=args.weight_decay 86 | ) 87 | else: 88 | optimizer = torch.optim.SGD( 89 | model.parameters(), lr=args.learning_rate, momentum=0.9) 90 | 91 | def bn_momentum_adjust(m, momentum): 92 | if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d): 93 | m.momentum = momentum 94 | 95 | LEARNING_RATE_CLIP = 1e-5 96 | MOMENTUM_ORIGINAL = 0.1 97 | MOMENTUM_DECCAY = 0.5 98 | MOMENTUM_DECCAY_STEP = args.step_size 99 | 100 | best_acc = 0 101 | global_epoch = 0 102 | best_class_avg_iou = 0 103 | best_inctance_avg_iou = 0 104 | fp16 = args['fp16'] 105 | if fp16: 106 | scaler = torch.cuda.amp.GradScaler() 107 | 108 | for epoch in range(start_epoch, args.epoch): 109 | mean_correct = [] 110 | 111 | logger.info('Epoch %d (%d/%s):' % 112 | (global_epoch + 1, epoch + 1, args.epoch)) 113 | '''Adjust learning rate and BN momentum''' 114 | lr = max(args.learning_rate * (args.lr_decay ** 115 | (epoch // args.step_size)), LEARNING_RATE_CLIP) 116 | logger.info('Learning rate:%f' % lr) 117 | for param_group in optimizer.param_groups: 118 | param_group['lr'] = lr 119 | momentum = MOMENTUM_ORIGINAL * \ 120 | (MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP)) 121 | if momentum < 0.01: 122 | momentum = 0.01 123 | print('BN momentum updated to: %f' % momentum) 124 | model = model.apply(lambda x: bn_momentum_adjust(x, momentum)) 125 | model = model.train() 126 | 127 | '''learning one epoch''' 128 | for i, (points, label, target) in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9): 129 | points = points.data.numpy() 130 | points[:, :, 0:3] = provider.random_scale_point_cloud( 131 | points[:, :, 0:3]) 132 | points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3]) 133 | points = torch.Tensor(points) 134 | target = target.view(-1, 1)[:, 0] 135 | 136 | points, label, target = points.float().cuda( 137 | ), label.long().cuda(), target.long().cuda() 138 | optimizer.zero_grad() 139 | 140 | if fp16: 141 | with torch.cuda.amp.autocast(): 142 | seg_pred = model(torch.cat([points, to_categorical( 143 | label, num_category).repeat(1, points.shape[1], 1)], -1)) 144 | seg_pred = seg_pred.contiguous().view(-1, num_part) 145 | loss = criterion(seg_pred, target) 146 | scaler.scale(loss).backward() 147 | scaler.step(optimizer) 148 | scaler.update() 149 | else: 150 | seg_pred = model(torch.cat([points, to_categorical( 151 | label, num_category).repeat(1, points.shape[1], 1)], -1)) 152 | seg_pred = seg_pred.contiguous().view(-1, num_part) 153 | loss = criterion(seg_pred, target) 154 | loss.backward() 155 | optimizer.step() 156 | 157 | pred_choice = seg_pred.data.max(1)[1] 158 | 159 | correct = pred_choice.eq(target.data).cpu().sum() 160 | mean_correct.append( 161 | correct.item() / (args.batch_size * args.num_point)) 162 | 163 | train_instance_acc = np.mean(mean_correct) 164 | logger.info('Train accuracy is: %.5f' % train_instance_acc) 165 | 166 | with torch.no_grad(): 167 | test_metrics = {} 168 | total_correct = 0 169 | total_seen = 0 170 | total_seen_class = [0 for _ in range(num_part)] 171 | total_correct_class = [0 for _ in range(num_part)] 172 | shape_ious = {cat: [] for cat in seg_classes.keys()} 173 | seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table} 174 | 175 | for cat in seg_classes.keys(): 176 | for label in seg_classes[cat]: 177 | 
seg_label_to_cat[label] = cat 178 | 179 | model = model.eval() 180 | 181 | for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9): 182 | cur_batch_size, NUM_POINT, _ = points.size() 183 | points, label, target = points.float().cuda( 184 | ), label.long().cuda(), target.long().cuda() 185 | seg_pred = model(torch.cat([points, to_categorical( 186 | label, num_category).repeat(1, points.shape[1], 1)], -1)) 187 | cur_pred_val = seg_pred.cpu().data.numpy() 188 | cur_pred_val_logits = cur_pred_val 189 | cur_pred_val = np.zeros( 190 | (cur_batch_size, NUM_POINT)).astype(np.int32) 191 | target = target.cpu().data.numpy() 192 | 193 | for i in range(cur_batch_size): 194 | cat = seg_label_to_cat[target[i, 0]] 195 | logits = cur_pred_val_logits[i, :, :] 196 | cur_pred_val[i, :] = np.argmax( 197 | logits[:, seg_classes[cat]], 1) + seg_classes[cat][0] 198 | 199 | correct = np.sum(cur_pred_val == target) 200 | total_correct += correct 201 | total_seen += (cur_batch_size * NUM_POINT) 202 | 203 | for l in range(num_part): 204 | total_seen_class[l] += np.sum(target == l) 205 | total_correct_class[l] += ( 206 | np.sum((cur_pred_val == l) & (target == l))) 207 | 208 | for i in range(cur_batch_size): 209 | segp = cur_pred_val[i, :] 210 | segl = target[i, :] 211 | cat = seg_label_to_cat[segl[0]] 212 | part_ious = [0.0 for _ in range(len(seg_classes[cat]))] 213 | for l in seg_classes[cat]: 214 | if (np.sum(segl == l) == 0) and ( 215 | np.sum(segp == l) == 0): # part absent in both ground truth and prediction counts as IoU 1 216 | part_ious[l - seg_classes[cat][0]] = 1.0 217 | else: 218 | part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float( 219 | np.sum((segl == l) | (segp == l))) 220 | shape_ious[cat].append(np.mean(part_ious)) 221 | 222 | all_shape_ious = [] 223 | for cat in shape_ious.keys(): 224 | for iou in shape_ious[cat]: 225 | all_shape_ious.append(iou) 226 | shape_ious[cat] = np.mean(shape_ious[cat]) 227 | mean_shape_ious = np.mean(list(shape_ious.values())) 228 | test_metrics['accuracy'] = total_correct / float(total_seen) 229 | test_metrics['class_avg_accuracy'] = np.mean( 230 | np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64)) 231 | for cat in sorted(shape_ious.keys()): 232 | logger.info('eval mIoU of %s %f' % 233 | (cat + ' ' * (14 - len(cat)), shape_ious[cat])) 234 | test_metrics['class_avg_iou'] = mean_shape_ious 235 | test_metrics['inctance_avg_iou'] = np.mean(all_shape_ious) 236 | 237 | logger.info('Epoch %d test Accuracy: %f Class avg mIOU: %f Instance avg mIOU: %f' % ( 238 | epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['inctance_avg_iou'])) 239 | if (test_metrics['inctance_avg_iou'] >= best_inctance_avg_iou): 240 | logger.info('Save model...') 241 | savepath = os.path.join(work_dir, 'best_model.pth') 242 | logger.info('Saving at %s' % savepath) 243 | state = { 244 | 'epoch': epoch, 245 | 'train_acc': train_instance_acc, 246 | 'test_acc': test_metrics['accuracy'], 247 | 'class_avg_iou': test_metrics['class_avg_iou'], 248 | 'inctance_avg_iou': test_metrics['inctance_avg_iou'], 249 | 'model_state_dict': model.state_dict(), 250 | 'optimizer_state_dict': optimizer.state_dict(), 251 | } 252 | torch.save(state, savepath) 253 | logger.info('Saving model....') 254 | 255 | if test_metrics['accuracy'] > best_acc: 256 | best_acc = test_metrics['accuracy'] 257 | if test_metrics['class_avg_iou'] > best_class_avg_iou: 258 | best_class_avg_iou = test_metrics['class_avg_iou'] 259 | if
test_metrics['inctance_avg_iou'] > best_inctance_avg_iou: 260 | best_inctance_avg_iou = test_metrics['inctance_avg_iou'] 261 | logger.info('Best accuracy is: %.5f' % best_acc) 262 | logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou) 263 | logger.info('Best instance avg mIOU is: %.5f' % best_inctance_avg_iou) 264 | global_epoch += 1 265 | 266 | shutil.copy(savepath, './best_model.pth') 267 | 268 | 269 | if __name__ == '__main__': 270 | main() 271 | -------------------------------------------------------------------------------- /haq_lib/lib/rl/memory.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import deque, namedtuple 3 | import warnings 4 | import random 5 | 6 | import numpy as np 7 | 8 | # [reference] https://github.com/matthiasplappert/keras-rl/blob/master/rl/memory.py 9 | 10 | # This is to be understood as a transition: Given `state0`, performing `action` 11 | # yields `reward` and results in `state1`, which might be `terminal`. 12 | Experience = namedtuple('Experience', 'state0, action, reward, state1, terminal1') 13 | 14 | 15 | def sample_batch_indexes(low, high, size): 16 | if high - low >= size: 17 | # We have enough data. Draw without replacement, that is each index is unique in the 18 | # batch. We cannot use `np.random.choice` here because it is horribly inefficient as 19 | # the memory grows. See https://github.com/numpy/numpy/issues/2764 for a discussion. 20 | # `random.sample` does the same thing (drawing without replacement) and is way faster. 21 | try: 22 | r = xrange(low, high) 23 | except NameError: 24 | r = range(low, high) 25 | batch_idxs = random.sample(r, size) 26 | else: 27 | # Not enough data. Help ourselves with sampling from the range, but the same index 28 | # can occur multiple times. This is not good and should be avoided by picking a 29 | # large enough warm-up phase. 30 | warnings.warn( 31 | 'Not enough entries to sample without replacement. Consider increasing your warm-up phase to avoid oversampling!') 32 | batch_idxs = np.random.randint(low, high, size=size)  # randint's half-open range matches the removed random_integers(low, high - 1) 33 | assert len(batch_idxs) == size 34 | return batch_idxs 35 | 36 | 37 | class RingBuffer(object): 38 | def __init__(self, maxlen): 39 | self.maxlen = maxlen 40 | self.start = 0 41 | self.length = 0 42 | self.data = [None for _ in range(maxlen)] 43 | 44 | def __len__(self): 45 | return self.length 46 | 47 | def __getitem__(self, idx): 48 | if idx < 0 or idx >= self.length: 49 | raise KeyError() 50 | return self.data[(self.start + idx) % self.maxlen] 51 | 52 | def append(self, v): 53 | if self.length < self.maxlen: 54 | # We have space, simply increase the length. 55 | self.length += 1 56 | elif self.length == self.maxlen: 57 | # No space, "remove" the first item. 58 | self.start = (self.start + 1) % self.maxlen 59 | else: 60 | # This should never happen. 61 | raise RuntimeError() 62 | self.data[(self.start + self.length - 1) % self.maxlen] = v 63 | 64 | 65 | def zeroed_observation(observation): 66 | if hasattr(observation, 'shape'): 67 | return np.zeros(observation.shape) 68 | elif hasattr(observation, '__iter__'): 69 | out = [] 70 | for x in observation: 71 | out.append(zeroed_observation(x)) 72 | return out 73 | else: 74 | return 0.
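# A minimal usage sketch for the replay memory classes defined below (values are
# illustrative; the sketch relies only on the append/sample_and_split signatures
# defined in this file):
#
#     memory = SequentialMemory(limit=10000, window_length=1)
#     memory.append(observation=obs, action=act, reward=rew, terminal=done)
#     # after enough transitions have been stored:
#     s0, a, r, s1, t1 = memory.sample_and_split(batch_size=64)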
75 | 76 | 77 | class Memory(object): 78 | def __init__(self, window_length, ignore_episode_boundaries=False): 79 | self.window_length = window_length 80 | self.ignore_episode_boundaries = ignore_episode_boundaries 81 | 82 | self.recent_observations = deque(maxlen=window_length) 83 | self.recent_terminals = deque(maxlen=window_length) 84 | 85 | def sample(self, batch_size, batch_idxs=None): 86 | raise NotImplementedError() 87 | 88 | def append(self, observation, action, reward, terminal, training=True): 89 | self.recent_observations.append(observation) 90 | self.recent_terminals.append(terminal) 91 | 92 | def get_recent_state(self, current_observation): 93 | # This code is slightly complicated by the fact that subsequent observations might be 94 | # from different episodes. We ensure that an experience never spans multiple episodes. 95 | # This is probably not that important in practice but it seems cleaner. 96 | state = [current_observation] 97 | idx = len(self.recent_observations) - 1 98 | for offset in range(0, self.window_length - 1): 99 | current_idx = idx - offset 100 | current_terminal = self.recent_terminals[current_idx - 1] if current_idx - 1 >= 0 else False 101 | if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal): 102 | # The previously handled observation was terminal, don't add the current one. 103 | # Otherwise we would leak into a different episode. 104 | break 105 | state.insert(0, self.recent_observations[current_idx]) 106 | while len(state) < self.window_length: 107 | state.insert(0, zeroed_observation(state[0])) 108 | return state 109 | 110 | def get_config(self): 111 | config = { 112 | 'window_length': self.window_length, 113 | 'ignore_episode_boundaries': self.ignore_episode_boundaries, 114 | } 115 | return config 116 | 117 | 118 | class SequentialMemory(Memory): 119 | def __init__(self, limit, **kwargs): 120 | super(SequentialMemory, self).__init__(**kwargs) 121 | 122 | self.limit = limit 123 | 124 | # Do not use deque to implement the memory. This data structure may seem convenient but 125 | # it is way too slow on random access. Instead, we use our own ring buffer implementation. 126 | self.actions = RingBuffer(limit) 127 | self.rewards = RingBuffer(limit) 128 | self.terminals = RingBuffer(limit) 129 | self.observations = RingBuffer(limit) 130 | 131 | def sample(self, batch_size, batch_idxs=None): 132 | if batch_idxs is None: 133 | # Draw random indexes such that we have at least a single entry before each 134 | # index. 135 | batch_idxs = sample_batch_indexes(0, self.nb_entries - 1, size=batch_size) 136 | batch_idxs = np.array(batch_idxs) + 1 137 | assert np.min(batch_idxs) >= 1 138 | assert np.max(batch_idxs) < self.nb_entries 139 | assert len(batch_idxs) == batch_size 140 | 141 | # Create experiences 142 | experiences = [] 143 | for idx in batch_idxs: 144 | terminal0 = self.terminals[idx - 2] if idx >= 2 else False 145 | while terminal0: 146 | # Skip this transition because the environment was reset here. Select a new, random 147 | # transition and use this instead. This may cause the batch to contain the same 148 | # transition twice. 149 | idx = sample_batch_indexes(1, self.nb_entries, size=1)[0] 150 | terminal0 = self.terminals[idx - 2] if idx >= 2 else False 151 | assert 1 <= idx < self.nb_entries 152 | 153 | # This code is slightly complicated by the fact that subsequent observations might be 154 | # from different episodes. We ensure that an experience never spans multiple episodes. 
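            # (With window_length == 1, the default in rl_quantize.py, the windowing
            # loop below is a no-op and state0 is just the single previous observation.)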
155 | # This is probably not that important in practice but it seems cleaner. 156 | state0 = [self.observations[idx - 1]] 157 | for offset in range(0, self.window_length - 1): 158 | current_idx = idx - 2 - offset 159 | current_terminal = self.terminals[current_idx - 1] if current_idx - 1 > 0 else False 160 | if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal): 161 | # The previously handled observation was terminal, don't add the current one. 162 | # Otherwise we would leak into a different episode. 163 | break 164 | state0.insert(0, self.observations[current_idx]) 165 | while len(state0) < self.window_length: 166 | state0.insert(0, zeroed_observation(state0[0])) 167 | action = self.actions[idx - 1] 168 | reward = self.rewards[idx - 1] 169 | terminal1 = self.terminals[idx - 1] 170 | 171 | # Okay, now we need to create the follow-up state. This is state0 shifted one timestep 172 | # to the right. Again, we need to be careful to not include an observation from the next 173 | # episode if the last state is terminal. 174 | state1 = [np.copy(x) for x in state0[1:]] 175 | state1.append(self.observations[idx]) 176 | 177 | assert len(state0) == self.window_length 178 | assert len(state1) == len(state0) 179 | experiences.append(Experience(state0=state0, action=action, reward=reward, 180 | state1=state1, terminal1=terminal1)) 181 | assert len(experiences) == batch_size 182 | return experiences 183 | 184 | def sample_and_split(self, batch_size, batch_idxs=None): 185 | experiences = self.sample(batch_size, batch_idxs) 186 | 187 | state0_batch = [] 188 | reward_batch = [] 189 | action_batch = [] 190 | terminal1_batch = [] 191 | state1_batch = [] 192 | for e in experiences: 193 | state0_batch.append(e.state0) 194 | state1_batch.append(e.state1) 195 | reward_batch.append(e.reward) 196 | action_batch.append(e.action) 197 | terminal1_batch.append(0. if e.terminal1 else 1.) 198 | 199 | # Prepare and validate parameters. 200 | state0_batch = np.array(state0_batch, 'double').reshape(batch_size, -1) 201 | state1_batch = np.array(state1_batch, 'double').reshape(batch_size, -1) 202 | terminal1_batch = np.array(terminal1_batch, 'double').reshape(batch_size, -1) 203 | reward_batch = np.array(reward_batch, 'double').reshape(batch_size, -1) 204 | action_batch = np.array(action_batch, 'double').reshape(batch_size, -1) 205 | 206 | return state0_batch, action_batch, reward_batch, state1_batch, terminal1_batch 207 | 208 | def append(self, observation, action, reward, terminal, training=True): 209 | super(SequentialMemory, self).append(observation, action, reward, terminal, training=training) 210 | 211 | # This needs to be understood as follows: in `observation`, take `action`, obtain `reward` 212 | # and whether the next state is `terminal` or not.
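        # Concretely: observations[i] pairs with the actions[i]/rewards[i] taken from
        # that observation, and terminals[i] records whether the episode ended after
        # that step; sample() above therefore reads state0 at idx - 1 and state1 at idx.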
213 | if training: 214 | self.observations.append(observation) 215 | self.actions.append(action) 216 | self.rewards.append(reward) 217 | self.terminals.append(terminal) 218 | 219 | @property 220 | def nb_entries(self): 221 | return len(self.observations) 222 | 223 | def get_config(self): 224 | config = super(SequentialMemory, self).get_config() 225 | config['limit'] = self.limit 226 | return config 227 | 228 | 229 | class EpisodeParameterMemory(Memory): 230 | def __init__(self, limit, **kwargs): 231 | super(EpisodeParameterMemory, self).__init__(**kwargs) 232 | self.limit = limit 233 | 234 | self.params = RingBuffer(limit) 235 | self.intermediate_rewards = [] 236 | self.total_rewards = RingBuffer(limit) 237 | 238 | def sample(self, batch_size, batch_idxs=None): 239 | if batch_idxs is None: 240 | batch_idxs = sample_batch_indexes(0, self.nb_entries, size=batch_size) 241 | assert len(batch_idxs) == batch_size 242 | 243 | batch_params = [] 244 | batch_total_rewards = [] 245 | for idx in batch_idxs: 246 | batch_params.append(self.params[idx]) 247 | batch_total_rewards.append(self.total_rewards[idx]) 248 | return batch_params, batch_total_rewards 249 | 250 | def append(self, observation, action, reward, terminal, training=True): 251 | super(EpisodeParameterMemory, self).append(observation, action, reward, terminal, training=training) 252 | if training: 253 | self.intermediate_rewards.append(reward) 254 | 255 | def finalize_episode(self, params): 256 | total_reward = sum(self.intermediate_rewards) 257 | self.total_rewards.append(total_reward) 258 | self.params.append(params) 259 | self.intermediate_rewards = [] 260 | 261 | @property 262 | def nb_entries(self): 263 | return len(self.total_rewards) 264 | 265 | def get_config(self): 266 | config = super(EpisodeParameterMemory, self).get_config() 267 | config['limit'] = self.limit 268 | return config 269 | -------------------------------------------------------------------------------- /utils/pointnet_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from time import time 5 | import numpy as np 6 | 7 | 8 | # reference https://github.com/yanx27/Pointnet_Pointnet2_pytorch, modified by Yang You 9 | 10 | 11 | def timeit(tag, t): 12 | print("{}: {}s".format(tag, time() - t)) 13 | return time() 14 | 15 | def pc_normalize(pc): 16 | centroid = np.mean(pc, axis=0) 17 | pc = pc - centroid 18 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 19 | pc = pc / m 20 | return pc 21 | 22 | def square_distance(src, dst): 23 | """ 24 | Calculate the squared Euclidean distance between each pair of points.
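    Example: src = [[[0, 0, 0]]] and dst = [[[1, 2, 2]]] gives dist = 1 + 4 + 4 = 9
    (squared, not square-rooted). The identity expanded next is the classic trick;
    the implementation below instead broadcasts (src - dst) ** 2 and sums it,
    which is numerically equivalent.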
25 | src^T * dst = xn * xm + yn * ym + zn * zm; 26 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 27 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 28 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 29 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 30 | Input: 31 | src: source points, [B, N, C] 32 | dst: target points, [B, M, C] 33 | Output: 34 | dist: per-point square distance, [B, N, M] 35 | """ 36 | return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1) 37 | 38 | 39 | def index_points(points, idx): 40 | """ 41 | Input: 42 | points: input points data, [B, N, C] 43 | idx: sample index data, [B, S, [K]] 44 | Return: 45 | new_points: indexed points data, [B, S, [K], C] 46 | """ 47 | raw_size = idx.size() 48 | idx = idx.reshape(raw_size[0], -1) 49 | res = torch.gather(points, 1, idx[..., None].expand(-1, -1, points.size(-1))) 50 | return res.reshape(*raw_size, -1) 51 | 52 | 53 | def farthest_point_sample(xyz, npoint): 54 | """ 55 | Input: 56 | xyz: pointcloud data, [B, N, 3] 57 | npoint: number of samples 58 | Return: 59 | centroids: sampled pointcloud index, [B, npoint] 60 | """ 61 | device = xyz.device 62 | B, N, C = xyz.shape 63 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 64 | distance = torch.ones(B, N).to(device) * 1e10 65 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 66 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 67 | for i in range(npoint): 68 | centroids[:, i] = farthest 69 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 70 | dist = torch.sum((xyz - centroid) ** 2, -1) 71 | distance = torch.min(distance, dist) 72 | farthest = torch.max(distance, -1)[1] 73 | return centroids 74 | 75 | 76 | def query_ball_point(radius, nsample, xyz, new_xyz): 77 | """ 78 | Input: 79 | radius: local region radius 80 | nsample: max sample number in local region 81 | xyz: all points, [B, N, 3] 82 | new_xyz: query points, [B, S, 3] 83 | Return: 84 | group_idx: grouped points index, [B, S, nsample] 85 | """ 86 | device = xyz.device 87 | B, N, C = xyz.shape 88 | _, S, _ = new_xyz.shape 89 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 90 | sqrdists = square_distance(new_xyz, xyz) 91 | group_idx[sqrdists > radius ** 2] = N 92 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 93 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 94 | mask = group_idx == N 95 | group_idx[mask] = group_first[mask] 96 | return group_idx 97 | 98 | 99 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, knn=False): 100 | """ 101 | Input: 102 | npoint: number of FPS centroids to sample 103 | radius: ball-query radius (ignored when knn=True) 104 | nsample: max number of neighbors per centroid 105 | xyz: input points position data, [B, N, 3] 106 | points: input points data, [B, N, D] 107 | Return: 108 | new_xyz: sampled points position data, [B, npoint, 3] 109 | new_points: sampled points data, [B, npoint, nsample, 3+D] 110 | """ 111 | B, N, C = xyz.shape 112 | S = npoint 113 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint] 114 | torch.cuda.empty_cache() 115 | new_xyz = index_points(xyz, fps_idx) 116 | torch.cuda.empty_cache() 117 | if knn: 118 | dists = square_distance(new_xyz, xyz) # B x npoint x N 119 | idx = dists.argsort()[:, :, :nsample] # B x npoint x K 120 | else: 121 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 122 | torch.cuda.empty_cache() 123 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 124 | torch.cuda.empty_cache() 125 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
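    # grouped_xyz_norm re-expresses each neighborhood relative to its FPS centroid
    # (neighbor coordinates minus centroid, shape [B, npoint, nsample, 3]), so the
    # features fed onward describe translation-invariant local geometry rather
    # than absolute positions.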
torch.cuda.empty_cache() 127 | 128 | if points is not None: 129 | grouped_points = index_points(points, idx) 130 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 131 | else: 132 | new_points = grouped_xyz_norm 133 | if returnfps: 134 | return new_xyz, new_points, grouped_xyz, fps_idx 135 | else: 136 | return new_xyz, new_points 137 | 138 | 139 | def sample_and_group_all(xyz, points): 140 | """ 141 | Input: 142 | xyz: input points position data, [B, N, 3] 143 | points: input points data, [B, N, D] 144 | Return: 145 | new_xyz: sampled points position data, [B, 1, 3] 146 | new_points: sampled points data, [B, 1, N, 3+D] 147 | """ 148 | device = xyz.device 149 | B, N, C = xyz.shape 150 | new_xyz = torch.zeros(B, 1, C).to(device) 151 | grouped_xyz = xyz.view(B, 1, N, C) 152 | if points is not None: 153 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 154 | else: 155 | new_points = grouped_xyz 156 | return new_xyz, new_points 157 | 158 | 159 | class PointNetSetAbstraction(nn.Module): 160 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all, knn=False): 161 | super(PointNetSetAbstraction, self).__init__() 162 | self.npoint = npoint 163 | self.radius = radius 164 | self.nsample = nsample 165 | self.knn = knn 166 | self.mlp_convs = nn.ModuleList() 167 | self.mlp_bns = nn.ModuleList() 168 | last_channel = in_channel 169 | for out_channel in mlp: 170 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 171 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 172 | last_channel = out_channel 173 | self.group_all = group_all 174 | 175 | def forward(self, xyz, points): 176 | """ 177 | Input: 178 | xyz: input points position data, [B, N, C] 179 | points: input points data, [B, N, D] 180 | Return: 181 | new_xyz: sampled points position data, [B, S, C] 182 | new_points_concat: sample points feature data, [B, S, D'] 183 | """ 184 | if self.group_all: 185 | new_xyz, new_points = sample_and_group_all(xyz, points) 186 | else: 187 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points, knn=self.knn) 188 | # new_xyz: sampled points position data, [B, npoint, C] 189 | # new_points: sampled points data, [B, npoint, nsample, C+D] 190 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 191 | for i, conv in enumerate(self.mlp_convs): 192 | bn = self.mlp_bns[i] 193 | new_points = F.relu(bn(conv(new_points))) 194 | 195 | new_points = torch.max(new_points, 2)[0].transpose(1, 2) 196 | return new_xyz, new_points 197 | 198 | 199 | class PointNetSetAbstractionMsg(nn.Module): 200 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list, knn=False): 201 | super(PointNetSetAbstractionMsg, self).__init__() 202 | self.npoint = npoint 203 | self.radius_list = radius_list 204 | self.nsample_list = nsample_list 205 | self.knn = knn 206 | self.conv_blocks = nn.ModuleList() 207 | self.bn_blocks = nn.ModuleList() 208 | for i in range(len(mlp_list)): 209 | convs = nn.ModuleList() 210 | bns = nn.ModuleList() 211 | last_channel = in_channel + 3 212 | for out_channel in mlp_list[i]: 213 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 214 | bns.append(nn.BatchNorm2d(out_channel)) 215 | last_channel = out_channel 216 | self.conv_blocks.append(convs) 217 | self.bn_blocks.append(bns) 218 | 219 | def forward(self, xyz, points, seed_idx=None): 220 | """ 221 | Input: 222 | xyz: input points position data, [B, N, C] 223 | points: input points data, [B, N, D]
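            seed_idx: optional precomputed sampling indices, [B, S]; when provided
                they replace the farthest-point sampling used to pick new_xyz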
224 | Return: 225 | new_xyz: sampled points position data, [B, S, C] 226 | new_points_concat: sample points feature data, [B, S, D'] 227 | """ 228 | 229 | B, N, C = xyz.shape 230 | S = self.npoint 231 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S) if seed_idx is None else seed_idx) 232 | new_points_list = [] 233 | for i, radius in enumerate(self.radius_list): 234 | K = self.nsample_list[i] 235 | if self.knn: 236 | dists = square_distance(new_xyz, xyz) # B x npoint x N 237 | group_idx = dists.argsort()[:, :, :K] # B x npoint x K 238 | else: 239 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 240 | grouped_xyz = index_points(xyz, group_idx) 241 | grouped_xyz -= new_xyz.view(B, S, 1, C) 242 | if points is not None: 243 | grouped_points = index_points(points, group_idx) 244 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 245 | else: 246 | grouped_points = grouped_xyz 247 | 248 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 249 | for j in range(len(self.conv_blocks[i])): 250 | conv = self.conv_blocks[i][j] 251 | bn = self.bn_blocks[i][j] 252 | grouped_points = F.relu(bn(conv(grouped_points))) 253 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 254 | new_points_list.append(new_points) 255 | 256 | new_points_concat = torch.cat(new_points_list, dim=1).transpose(1, 2) 257 | return new_xyz, new_points_concat 258 | 259 | 260 | # Note: this function swaps N and C 261 | class PointNetFeaturePropagation(nn.Module): 262 | def __init__(self, in_channel, mlp): 263 | super(PointNetFeaturePropagation, self).__init__() 264 | self.mlp_convs = nn.ModuleList() 265 | self.mlp_bns = nn.ModuleList() 266 | last_channel = in_channel 267 | for out_channel in mlp: 268 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 269 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 270 | last_channel = out_channel 271 | 272 | def forward(self, xyz1, xyz2, points1, points2): 273 | """ 274 | Input: 275 | xyz1: input points position data, [B, C, N] 276 | xyz2: sampled input points position data, [B, C, S] 277 | points1: input points data, [B, D, N] 278 | points2: input points data, [B, D, S] 279 | Return: 280 | new_points: upsampled points data, [B, D', N] 281 | """ 282 | xyz1 = xyz1.permute(0, 2, 1) 283 | xyz2 = xyz2.permute(0, 2, 1) 284 | 285 | points2 = points2.permute(0, 2, 1) 286 | B, N, C = xyz1.shape 287 | _, S, _ = xyz2.shape 288 | 289 | if S == 1: 290 | interpolated_points = points2.repeat(1, N, 1) 291 | else: 292 | dists = square_distance(xyz1, xyz2) 293 | dists, idx = dists.sort(dim=-1) 294 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 295 | 296 | dist_recip = 1.0 / (dists + 1e-8) 297 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 298 | weight = dist_recip / norm 299 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 300 | 301 | if points1 is not None: 302 | points1 = points1.permute(0, 2, 1) 303 | new_points = torch.cat([points1, interpolated_points], dim=-1) 304 | else: 305 | new_points = interpolated_points 306 | 307 | new_points = new_points.permute(0, 2, 1) 308 | for i, conv in enumerate(self.mlp_convs): 309 | bn = self.mlp_bns[i] 310 | new_points = F.relu(bn(conv(new_points))) 311 | return new_points -------------------------------------------------------------------------------- /haq_lib/rl_quantize.py: -------------------------------------------------------------------------------- 1 | # Code for "HAQ: Hardware-Aware Automated Quantization with Mixed Precision" 2 |
# Kuan Wang*, Zhijian Liu*, Yujun Lin*, Ji Lin, Song Han 3 | # {kuanwang, zhijian, yujunlin, jilin, songhan}@mit.edu 4 | 5 | import os 6 | import math 7 | import argparse 8 | import numpy as np 9 | from copy import deepcopy 10 | 11 | from haq_lib.lib.env.quantize_env import QuantizeEnv 12 | from haq_lib.lib.env.linear_quantize_env import LinearQuantizeEnv 13 | from haq_lib.lib.rl.ddpg import DDPG 14 | 15 | import torch 16 | import torch.backends.cudnn as cudnn 17 | import torchvision.models as models 18 | import haq_lib.models as customized_models 19 | 20 | # Models 21 | default_model_names = sorted(name for name in models.__dict__ 22 | if name.islower() and not name.startswith("__") 23 | and callable(models.__dict__[name])) 24 | 25 | customized_models_names = sorted(name for name in customized_models.__dict__ 26 | if name.islower() and not name.startswith("__") 27 | and callable(customized_models.__dict__[name])) 28 | 29 | for name in customized_models.__dict__: 30 | if name.islower() and not name.startswith("__") and callable(customized_models.__dict__[name]): 31 | models.__dict__[name] = customized_models.__dict__[name] 32 | 33 | model_names = default_model_names + customized_models_names 34 | print('support models: ', model_names) 35 | 36 | 37 | def train(num_episode, agent, env, output, linear_quantization=False, debug=False): 38 | # best record 39 | best_reward = -math.inf 40 | best_policy = [] 41 | 42 | agent.is_training = True 43 | step = episode = episode_steps = 0 44 | episode_reward = 0. 45 | observation = None 46 | T = [] # trajectory 47 | while episode < num_episode: # counting based on episode 48 | # reset if it is the start of episode 49 | if observation is None: 50 | observation = deepcopy(env.reset()) 51 | agent.reset(observation) 52 | 53 | # agent pick action ... 54 | if episode <= args.warmup: 55 | action = agent.random_action() 56 | else: 57 | action = agent.select_action(observation, episode=episode) 58 | 59 | # env response with next_observation, reward, terminate_info 60 | observation2, reward, done, info = env.step(action) 61 | observation2 = deepcopy(observation2) 62 | 63 | T.append([reward, deepcopy(observation), deepcopy(observation2), action, done]) 64 | 65 | # [optional] save intermideate model 66 | if episode % int(num_episode / 10) == 0: 67 | agent.save_model(output) 68 | 69 | # update 70 | step += 1 71 | episode_steps += 1 72 | episode_reward += reward 73 | observation = deepcopy(observation2) 74 | 75 | if done: # end of episode 76 | 77 | if linear_quantization: 78 | if debug: 79 | print('#{}: episode_reward:{:.4f} acc: {:.4f}, cost: {:.4f}'.format(episode, episode_reward, 80 | info['accuracy'], 81 | info['cost'] * 1. / 8e6)) 82 | text_writer.write( 83 | '#{}: episode_reward:{:.4f} acc: {:.4f}, cost: {:.4f}\n'.format(episode, episode_reward, 84 | info['accuracy'], 85 | info['cost'] * 1. 
/ 8e6)) 86 | else: 87 | if debug: 88 | print('#{}: episode_reward:{:.4f} acc: {:.4f}, weight: {:.4f} %'.format(episode, episode_reward, 89 | info['accuracy'], 90 | info['w_ratio'] * 100)) 91 | text_writer.write( 92 | '#{}: episode_reward:{:.4f} acc: {:.4f}, weight: {:.4f} %\n'.format(episode, episode_reward, 93 | info['accuracy'], 94 | info['w_ratio'] * 100)) 95 | 96 | final_reward = T[-1][0] 97 | # agent observe and update policy 98 | for i, (r_t, s_t, s_t1, a_t, done) in enumerate(T): 99 | agent.observe(final_reward, s_t, s_t1, a_t, done) 100 | if episode > args.warmup: 101 | for i in range(args.n_update): 102 | agent.update_policy() 103 | 104 | agent.memory.append( 105 | observation, 106 | agent.select_action(observation, episode=episode), 107 | 0., False 108 | ) 109 | 110 | # reset 111 | observation = None 112 | episode_steps = 0 113 | episode_reward = 0. 114 | episode += 1 115 | T = [] 116 | 117 | if final_reward > best_reward: 118 | best_reward = final_reward 119 | best_policy = env.strategy 120 | 121 | text_writer.write('best reward: {}\n'.format(best_reward)) 122 | text_writer.write('best policy: {}\n'.format(best_policy)) 123 | text_writer.close() 124 | return best_policy, best_reward 125 | 126 | 127 | if __name__ == "__main__": 128 | parser = argparse.ArgumentParser(description='PyTorch Reinforcement Learning') 129 | 130 | parser.add_argument('--suffix', default=None, type=str, help='suffix to help you remember what experiment you ran') 131 | # env 132 | parser.add_argument('--dataset', default='imagenet', type=str, help='dataset to use') 133 | parser.add_argument('--dataset_root', default='data/imagenet', type=str, help='path to dataset') 134 | parser.add_argument('--preserve_ratio', default=0.1, type=float, help='preserve ratio of the model size') 135 | parser.add_argument('--min_bit', default=1, type=float, help='minimum bit to use') 136 | parser.add_argument('--max_bit', default=8, type=float, help='maximum bit to use') 137 | parser.add_argument('--float_bit', default=32, type=int, help='the bit of full precision float') 138 | parser.add_argument('--linear_quantization', dest='linear_quantization', action='store_true') 139 | parser.add_argument('--is_pruned', dest='is_pruned', action='store_true') 140 | # ddpg 141 | parser.add_argument('--hidden1', default=300, type=int, help='hidden units of first fully connected layer') 142 | parser.add_argument('--hidden2', default=300, type=int, help='hidden units of second fully connected layer') 143 | parser.add_argument('--lr_c', default=1e-3, type=float, help='learning rate for critic') 144 | parser.add_argument('--lr_a', default=1e-4, type=float, help='learning rate for actor') 145 | parser.add_argument('--warmup', default=20, type=int, 146 | help='time without training but only filling the replay memory') 147 | parser.add_argument('--discount', default=1., type=float, help='') 148 | parser.add_argument('--bsize', default=64, type=int, help='minibatch size') 149 | parser.add_argument('--rmsize', default=128, type=int, help='memory size for each layer') 150 | parser.add_argument('--window_length', default=1, type=int, help='') 151 | parser.add_argument('--tau', default=0.01, type=float, help='moving average for target network') 152 | # noise (truncated normal distribution) 153 | parser.add_argument('--init_delta', default=0.5, type=float, 154 | help='initial variance of truncated normal distribution') 155 | parser.add_argument('--delta_decay', default=0.99, type=float, 156 | help='delta decay during exploration')
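    # Exploration noise: actions are perturbed with a truncated normal centred on
    # the deterministic policy output; sigma starts at init_delta and the DDPG agent
    # (not shown in this dump) is expected to shrink it multiplicatively, roughly
    # sigma ~ init_delta * delta_decay ** episode, so the search is broad early and
    # nearly greedy late. A sketch of the sampler, mirroring
    # sample_from_truncated_normal_distribution in haq_lib/lib/utils/utils.py:
    #
    #     from scipy import stats
    #     noisy = stats.truncnorm.rvs((lbound - mu) / sigma, (rbound - mu) / sigma,
    #                                 loc=mu, scale=sigma)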
parser.add_argument('--n_update', default=1, type=int, help='number of policy updates per episode') 158 | # training 159 | parser.add_argument('--max_episode_length', default=1e9, type=int, help='') 160 | parser.add_argument('--output', default='../../save', type=str, help='') 161 | parser.add_argument('--debug', dest='debug', action='store_true') 162 | parser.add_argument('--init_w', default=0.003, type=float, help='') 163 | parser.add_argument('--train_episode', default=600, type=int, help='number of training episodes') 164 | parser.add_argument('--epsilon', default=50000, type=int, help='linear decay of exploration policy') 165 | parser.add_argument('--seed', default=234, type=int, help='') 166 | parser.add_argument('--n_worker', default=32, type=int, help='number of data loader workers') 167 | parser.add_argument('--data_bsize', default=256, type=int, help='data batch size') 168 | parser.add_argument('--finetune_epoch', default=1, type=int, help='') 169 | parser.add_argument('--finetune_gamma', default=0.8, type=float, help='finetune gamma') 170 | parser.add_argument('--finetune_lr', default=0.001, type=float, help='finetune learning rate') 171 | parser.add_argument('--finetune_flag', default=True, type=bool, help='whether to finetune') 172 | parser.add_argument('--use_top5', default=False, type=bool, help='whether to use top5 acc in reward') 173 | parser.add_argument('--train_size', default=20000, type=int, help='number of train data size') 174 | parser.add_argument('--val_size', default=10000, type=int, help='number of val data size') 175 | parser.add_argument('--resume', default='default', type=str, help='Resuming model path for testing') 176 | # Architecture 177 | parser.add_argument('--arch', '-a', metavar='ARCH', default='mobilenet_v2', choices=model_names, 178 | help='model architecture:' + ' | '.join(model_names) + ' (default: mobilenet_v2)') 179 | # device options 180 | parser.add_argument('--gpu_id', default='1', type=str, 181 | help='id(s) for CUDA_VISIBLE_DEVICES') 182 | 183 | args = parser.parse_args() 184 | base_folder_name = '{}_{}'.format(args.arch, args.dataset) 185 | if args.suffix is not None: 186 | base_folder_name = base_folder_name + '_' + args.suffix 187 | args.output = os.path.join(args.output, base_folder_name) 188 | os.makedirs(args.output, exist_ok=True)  # ensure the output directory exists before opening the log file text_writer = open(os.path.join(args.output, 'log.txt'), 'w') 189 | print('==> Output path: {}...'.format(args.output)) 190 | 191 | # Use CUDA 192 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 193 | assert torch.cuda.is_available(), 'CUDA is needed for CNN' 194 | 195 | if args.seed > 0: 196 | np.random.seed(args.seed) 197 | torch.manual_seed(args.seed) 198 | torch.cuda.manual_seed_all(args.seed) 199 | 200 | if args.dataset == 'imagenet': 201 | num_classes = 1000 202 | elif args.dataset == 'imagenet100': 203 | num_classes = 100 204 | else: 205 | raise NotImplementedError 206 | model = models.__dict__[args.arch](pretrained=True, num_classes=num_classes) 207 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 208 | model.features = torch.nn.DataParallel(model.features) 209 | model.cuda() 210 | else: 211 | model = torch.nn.DataParallel(model).cuda() 212 | pretrained_model = deepcopy(model.state_dict()) 213 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 214 | cudnn.benchmark = True 215 | 216 | if args.linear_quantization: 217 | env = LinearQuantizeEnv(model, pretrained_model, args.dataset, args.dataset_root, 218 | compress_ratio=args.preserve_ratio, n_data_worker=args.n_worker, 219 |
batch_size=args.data_bsize, args=args, float_bit=args.float_bit, 220 | is_model_pruned=args.is_pruned) 221 | else: 222 | env = QuantizeEnv(model, pretrained_model, args.dataset, args.dataset_root, 223 | compress_ratio=args.preserve_ratio, n_data_worker=args.n_worker, 224 | batch_size=args.data_bsize, args=args, float_bit=args.float_bit, 225 | is_model_pruned=args.is_pruned) 226 | 227 | nb_states = env.layer_embedding.shape[1] 228 | nb_actions = 1 # actions for weight and activation quantization 229 | args.rmsize = args.rmsize * len(env.quantizable_idx) # for each layer 230 | print('** Actual replay buffer size: {}'.format(args.rmsize)) 231 | agent = DDPG(nb_states, nb_actions, args) 232 | 233 | best_policy, best_reward = train(args.train_episode, agent, env, args.output, linear_quantization=args.linear_quantization, debug=args.debug) 234 | print('best_reward: ', best_reward) 235 | print('best_policy: ', best_policy) 236 | 237 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | from search import AttrDict, create_attr_dict 2 | import yaml 3 | import os 4 | import torch 5 | import logging 6 | import importlib 7 | import shutil 8 | from utils import provider 9 | import numpy as np 10 | import torch.nn as nn 11 | from tqdm import tqdm 12 | from utils.dataset import PartNormalDataset 13 | import hydra 14 | 15 | from utils.general import to_categorical 16 | from utils.general import seg_classes, seg_label_to_cat 17 | from haq_lib.lib.utils.quantize_utils import quantize_model, kmeans_update_model 18 | from haq_lib.lib.utils.quantize_utils import QConv2d, QLinear, calibrate 19 | 20 | 21 | def inplace_relu(m): 22 | classname = m.__class__.__name__ 23 | if classname.find('ReLU') != -1: 24 | m.inplace = True 25 | 26 | 27 | def main(): 28 | 29 | with open('config/finetune.yaml', 'r') as f: 30 | args = AttrDict(yaml.safe_load(f.read())) 31 | create_attr_dict(args) 32 | 33 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 34 | 35 | work_dir = args.work_dir 36 | if not os.path.exists(work_dir): 37 | os.makedirs(work_dir) 38 | 39 | log_file = os.path.join(work_dir, 'finetune.log') 40 | logger = logging.getLogger('finetune') 41 | logger.setLevel(logging.INFO) 42 | file_handler = logging.FileHandler(log_file, 'w') 43 | file_handler.setFormatter( 44 | logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) 45 | file_handler.setLevel(logging.INFO) 46 | logger.addHandler(file_handler) 47 | console = logging.StreamHandler() 48 | logger.addHandler(console) 49 | 50 | 51 | root = hydra.utils.to_absolute_path( 52 | 'data/shapenetcore_partanno_segmentation_benchmark_v0_normal/') 53 | 54 | TRAIN_DATASET = PartNormalDataset( 55 | root=root, npoints=args.num_point, split='trainval', normal_channel=args.normal) 56 | trainDataLoader = torch.utils.data.DataLoader( 57 | TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True) 58 | TEST_DATASET = PartNormalDataset( 59 | root=root, npoints=args.num_point, split='test', normal_channel=args.normal) 60 | testDataLoader = torch.utils.data.DataLoader( 61 | TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=10) 62 | 63 | '''MODEL LOADING''' 64 | args.input_dim = (6 if args.normal else 3) + 16 65 | args.num_class = 50 66 | num_category = 16 67 | num_part = args.num_class 68 | 69 | model = getattr(importlib.import_module('models.model'), 70 | 'PointTransformerSeg')(args).cuda() 71 | criterion = 
torch.nn.CrossEntropyLoss() 72 | 73 | assert os.path.exists('best_model.pth'), 'best_model.pth must be provided.' 74 | checkpoint = torch.load('best_model.pth') 75 | start_epoch = checkpoint['epoch'] 76 | model.load_state_dict(checkpoint['model_state_dict']) 77 | logger.info('Using pretrained model') 78 | 79 | if args.optimizer == 'Adam': 80 | optimizer = torch.optim.Adam( 81 | model.parameters(), 82 | lr=args.learning_rate, 83 | betas=(0.9, 0.999), 84 | eps=1e-08, 85 | weight_decay=args.weight_decay 86 | ) 87 | else: 88 | optimizer = torch.optim.SGD( 89 | model.parameters(), lr=args.learning_rate, momentum=0.9) 90 | 91 | def bn_momentum_adjust(m, momentum): 92 | if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d): 93 | m.momentum = momentum 94 | 95 | LEARNING_RATE_CLIP = 1e-5 96 | MOMENTUM_ORIGINAL = 0.1 97 | MOMENTUM_DECCAY = 0.5 98 | MOMENTUM_DECCAY_STEP = args.step_size 99 | 100 | best_acc = 0 101 | global_epoch = 0 102 | best_class_avg_iou = 0 103 | best_inctance_avg_iou = 0 104 | 105 | quantizable_idx = [] 106 | for i, m in enumerate(model.modules()): 107 | if i < 6 or i > 238: 108 | continue 109 | if type(m) in [nn.Conv2d, nn.Linear]: 110 | quantizable_idx.append(i) 111 | print(quantizable_idx) 112 | 113 | strategy = [4, 4, 4, 4, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 114 | 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 115 | 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 4, 2, 4, 4, 116 | 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 2, 2, 2, 117 | 2, 4, 4, 4, 4, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 118 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 119 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2] 120 | print('quantization strategy: ', strategy) 121 | assert len(quantizable_idx) == len(strategy), \ 122 | 'You should provide one bit setting per quantizable layer for weight quantization!'
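# NOTE (sketch): haq_lib's `quantize_model` applies k-means weight sharing: for an
# n-bit layer it clusters the layer's flattened weights into 2**n centroids
# ('k-means++' init, up to `max_iter` Lloyd steps) and snaps every weight to its
# nearest centroid, returning the per-layer label map so that `kmeans_update_model`
# can re-center the shared weights after each optimizer step. A minimal stand-in
# for a single layer (illustrative only, hypothetical name, not the haq_lib code):
def _kmeans_quantize_sketch(weight, n_bit):
    """Snap `weight` in place to 2**n_bit shared centroids; return the label map."""
    from sklearn.cluster import KMeans
    flat = weight.detach().cpu().reshape(-1, 1).numpy()
    km = KMeans(n_clusters=2 ** n_bit, init='k-means++', n_init=1, max_iter=50).fit(flat)
    centroids = torch.from_numpy(km.cluster_centers_.astype(np.float32)).view(-1)
    labels = torch.from_numpy(km.labels_).long().view(weight.shape)
    weight.data.copy_(centroids[labels])  # every weight now holds one of 2**n_bit values
    return labels
# Only the 2**n_bit centroids plus the per-weight labels need to be stored, which is
# where the compression comes from; finetuning then moves centroids, not single weights.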
123 | centroid_label_dict = quantize_model(model, quantizable_idx, strategy, mode='cpu', quantize_bias=False, 124 | centroids_init='k-means++', max_iter=50) 125 | 126 | for epoch in range(start_epoch, args.epoch): 127 | if args.free_high_bit and epoch == args.epoch - 1: 128 | # quantize the high bit layers only at last epoch to save time 129 | centroid_label_dict = quantize_model(model, quantizable_idx, strategy, mode='cpu', quantize_bias=False, 130 | centroids_init='k-means++', max_iter=50, free_high_bit=False) 131 | 132 | mean_correct = [] 133 | 134 | logger.info('Epoch %d (%d/%s):' % 135 | (global_epoch + 1, epoch + 1, args.epoch)) 136 | '''Adjust learning rate and BN momentum''' 137 | lr = max(args.learning_rate * (args.lr_decay ** 138 | (epoch // args.step_size)), LEARNING_RATE_CLIP) 139 | logger.info('Learning rate:%f' % lr) 140 | for param_group in optimizer.param_groups: 141 | param_group['lr'] = lr 142 | momentum = MOMENTUM_ORIGINAL * \ 143 | (MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP)) 144 | if momentum < 0.01: 145 | momentum = 0.01 146 | print('BN momentum updated to: %f' % momentum) 147 | model = model.apply(lambda x: bn_momentum_adjust(x, momentum)) 148 | model = model.train() 149 | 150 | '''learning one epoch''' 151 | for i, (points, label, target) in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9): 152 | points = points.data.numpy() 153 | points[:, :, 0:3] = provider.random_scale_point_cloud( 154 | points[:, :, 0:3]) 155 | points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3]) 156 | points = torch.Tensor(points) 157 | 158 | points, label, target = points.float().cuda( 159 | ), label.long().cuda(), target.long().cuda() 160 | optimizer.zero_grad() 161 | 162 | seg_pred = model(torch.cat([points, to_categorical( 163 | label, num_category).repeat(1, points.shape[1], 1)], -1)) 164 | seg_pred = seg_pred.contiguous().view(-1, num_part) 165 | target = target.view(-1, 1)[:, 0] 166 | pred_choice = seg_pred.data.max(1)[1] 167 | 168 | correct = pred_choice.eq(target.data).cpu().sum() 169 | mean_correct.append( 170 | correct.item() / (args.batch_size * args.num_point)) 171 | loss = criterion(seg_pred, target) 172 | loss.backward() 173 | optimizer.step() 174 | kmeans_update_model( 175 | model, quantizable_idx, centroid_label_dict, free_high_bit=args.free_high_bit) 176 | 177 | train_instance_acc = np.mean(mean_correct) 178 | logger.info('Train accuracy is: %.5f' % train_instance_acc) 179 | 180 | with torch.no_grad(): 181 | test_metrics = {} 182 | total_correct = 0 183 | total_seen = 0 184 | total_seen_class = [0 for _ in range(num_part)] 185 | total_correct_class = [0 for _ in range(num_part)] 186 | shape_ious = {cat: [] for cat in seg_classes.keys()} 187 | seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table} 188 | 189 | for cat in seg_classes.keys(): 190 | for label in seg_classes[cat]: 191 | seg_label_to_cat[label] = cat 192 | 193 | model = model.eval() 194 | 195 | for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9): 196 | cur_batch_size, NUM_POINT, _ = points.size() 197 | points, label, target = points.float().cuda( 198 | ), label.long().cuda(), target.long().cuda() 199 | seg_pred = model(torch.cat([points, to_categorical( 200 | label, num_category).repeat(1, points.shape[1], 1)], -1)) 201 | cur_pred_val = seg_pred.cpu().data.numpy() 202 | cur_pred_val_logits = cur_pred_val 203 | cur_pred_val = np.zeros( 204 | (cur_batch_size, NUM_POINT)).astype(np.int32) 205 | target = 
target.cpu().data.numpy() 206 | 207 | for i in range(cur_batch_size): 208 | cat = seg_label_to_cat[target[i, 0]] 209 | logits = cur_pred_val_logits[i, :, :] 210 | cur_pred_val[i, :] = np.argmax( 211 | logits[:, seg_classes[cat]], 1) + seg_classes[cat][0] 212 | 213 | correct = np.sum(cur_pred_val == target) 214 | total_correct += correct 215 | total_seen += (cur_batch_size * NUM_POINT) 216 | 217 | for l in range(num_part): 218 | total_seen_class[l] += np.sum(target == l) 219 | total_correct_class[l] += ( 220 | np.sum((cur_pred_val == l) & (target == l))) 221 | 222 | for i in range(cur_batch_size): 223 | segp = cur_pred_val[i, :] 224 | segl = target[i, :] 225 | cat = seg_label_to_cat[segl[0]] 226 | part_ious = [0.0 for _ in range(len(seg_classes[cat]))] 227 | for l in seg_classes[cat]: 228 | if (np.sum(segl == l) == 0) and ( 229 | np.sum(segp == l) == 0): # part is not present, no prediction as well 230 | part_ious[l - seg_classes[cat][0]] = 1.0 231 | else: 232 | part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float( 233 | np.sum((segl == l) | (segp == l))) 234 | shape_ious[cat].append(np.mean(part_ious)) 235 | 236 | all_shape_ious = [] 237 | for cat in shape_ious.keys(): 238 | for iou in shape_ious[cat]: 239 | all_shape_ious.append(iou) 240 | shape_ious[cat] = np.mean(shape_ious[cat]) 241 | mean_shape_ious = np.mean(list(shape_ious.values())) 242 | test_metrics['accuracy'] = total_correct / float(total_seen) 243 | test_metrics['class_avg_accuracy'] = np.mean( 244 | np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64)) 245 | for cat in sorted(shape_ious.keys()): 246 | logger.info('eval mIoU of %s %f' % 247 | (cat + ' ' * (14 - len(cat)), shape_ious[cat])) 248 | test_metrics['class_avg_iou'] = mean_shape_ious 249 | test_metrics['inctance_avg_iou'] = np.mean(all_shape_ious) 250 | 251 | logger.info('Epoch %d test Accuracy: %f Class avg mIOU: %f Instance avg mIOU: %f' % ( 252 | epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['inctance_avg_iou'])) 253 | if (test_metrics['inctance_avg_iou'] >= best_inctance_avg_iou): 254 | logger.info('Save model...') 255 | savepath = os.path.join(work_dir, 'best_model.pth') 256 | logger.info('Saving at %s' % savepath) 257 | state = { 258 | 'epoch': epoch, 259 | 'train_acc': train_instance_acc, 260 | 'test_acc': test_metrics['accuracy'], 261 | 'class_avg_iou': test_metrics['class_avg_iou'], 262 | 'inctance_avg_iou': test_metrics['inctance_avg_iou'], 263 | 'model_state_dict': model.state_dict(), 264 | 'optimizer_state_dict': optimizer.state_dict(), 265 | } 266 | torch.save(state, savepath) 267 | logger.info('Saving model....') 268 | 269 | if test_metrics['accuracy'] > best_acc: 270 | best_acc = test_metrics['accuracy'] 271 | if test_metrics['class_avg_iou'] > best_class_avg_iou: 272 | best_class_avg_iou = test_metrics['class_avg_iou'] 273 | if test_metrics['inctance_avg_iou'] > best_inctance_avg_iou: 274 | best_inctance_avg_iou = test_metrics['inctance_avg_iou'] 275 | logger.info('Best accuracy is: %.5f' % best_acc) 276 | logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou) 277 | logger.info('Best instance avg mIOU is: %.5f' % best_inctance_avg_iou) 278 | global_epoch += 1 279 | 280 | shutil.copy(savepath, './finetuned_model.pth') 281 | 282 | 283 | if __name__ == '__main__': 284 | main() 285 | -------------------------------------------------------------------------------- /finetune-distill.py: -------------------------------------------------------------------------------- 1
| from haq_lib.lib.utils.utils import AverageMeter 2 | from search import AttrDict, create_attr_dict 3 | import yaml 4 | import os 5 | import torch 6 | import logging 7 | import importlib 8 | import shutil 9 | from utils import provider 10 | import numpy as np 11 | import torch.nn as nn 12 | from tqdm import tqdm 13 | from utils.dataset import PartNormalDataset 14 | import hydra 15 | from progress.bar import Bar 16 | from utils.general import to_categorical 17 | from utils.general import seg_classes, seg_label_to_cat 18 | from haq_lib.lib.utils.quantize_utils import quantize_model, kmeans_update_model 19 | from haq_lib.lib.utils.quantize_utils import QConv2d, QLinear, calibrate 20 | 21 | 22 | def inplace_relu(m): 23 | classname = m.__class__.__name__ 24 | if classname.find('ReLU') != -1: 25 | m.inplace = True 26 | 27 | 28 | def main(): 29 | 30 | with open('config/finetune-dist.yaml', 'r') as f: 31 | args = AttrDict(yaml.safe_load(f.read())) 32 | create_attr_dict(args) 33 | 34 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 35 | 36 | work_dir = args.work_dir 37 | if not os.path.exists(work_dir): 38 | os.makedirs(work_dir) 39 | 40 | log_file = os.path.join(work_dir, 'finetune-dist.log') 41 | logger = logging.getLogger('finetune') 42 | logger.setLevel(logging.INFO) 43 | file_handler = logging.FileHandler(log_file, 'w') 44 | file_handler.setFormatter( 45 | logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) 46 | file_handler.setLevel(logging.INFO) 47 | logger.addHandler(file_handler) 48 | console = logging.StreamHandler() 49 | logger.addHandler(console) 50 | 51 | 52 | root = hydra.utils.to_absolute_path( 53 | 'data/shapenetcore_partanno_segmentation_benchmark_v0_normal/') 54 | 55 | TRAIN_DATASET = PartNormalDataset( 56 | root=root, npoints=args.num_point, split='trainval', normal_channel=args.normal) 57 | trainDataLoader = torch.utils.data.DataLoader( 58 | TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True) 59 | TEST_DATASET = PartNormalDataset( 60 | root=root, npoints=args.num_point, split='test', normal_channel=args.normal) 61 | testDataLoader = torch.utils.data.DataLoader( 62 | TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=10) 63 | 64 | '''MODEL LOADING''' 65 | args.input_dim = (6 if args.normal else 3) + 16 66 | args.num_class = 50 67 | num_category = 16 68 | num_part = args.num_class 69 | 70 | model = getattr(importlib.import_module('models.model'), 71 | 'PointTransformerSeg')(args).cuda() 72 | 73 | teacher_model = getattr(importlib.import_module('models.model'), 74 | 'PointTransformerSeg')(args).cuda() 75 | 76 | criterion = torch.nn.CrossEntropyLoss() 77 | 78 | assert os.path.exists('best_model.pth'), 'best_model.pth must be provided.' 
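# NOTE (sketch): this script finetunes the k-means-quantized student under a
# frozen full-precision teacher (a second copy of the same pretrained network).
# The objective is hard-label cross-entropy plus `dist_loss_weight` times an MSE
# between student and teacher logits. A hypothetical helper mirroring the loss
# computed in the training loop below (illustrative only):
def _distill_objective_sketch(student_logits, teacher_logits, target, ce, weight):
    """Return (total, task, distill) losses for one batch of part predictions."""
    dloss = nn.functional.mse_loss(student_logits, teacher_logits)  # match the teacher
    tloss = ce(student_logits.contiguous().view(-1, student_logits.size(-1)),
               target.view(-1))  # per-point cross-entropy on hard labels
    return tloss + weight * dloss, tloss, dloss
# Since the teacher never updates, the MSE term simply pulls the quantized student
# back toward its own full-precision behaviour.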
79 | checkpoint = torch.load('best_model.pth') 80 | start_epoch = checkpoint['epoch'] 81 | model.load_state_dict(checkpoint['model_state_dict']) 82 | teacher_model.load_state_dict(checkpoint['model_state_dict']) 83 | logger.info('Using pretrained model') 84 | 85 | if args.optimizer == 'Adam': 86 | optimizer = torch.optim.Adam( 87 | model.parameters(), 88 | lr=args.learning_rate, 89 | betas=(0.9, 0.999), 90 | eps=1e-08, 91 | weight_decay=args.weight_decay 92 | ) 93 | else: 94 | optimizer = torch.optim.SGD( 95 | model.parameters(), lr=args.learning_rate, momentum=0.9) 96 | 97 | def bn_momentum_adjust(m, momentum): 98 | if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d): 99 | m.momentum = momentum 100 | 101 | LEARNING_RATE_CLIP = 1e-5 102 | MOMENTUM_ORIGINAL = 0.1 103 | MOMENTUM_DECCAY = 0.5 104 | MOMENTUM_DECCAY_STEP = args.step_size 105 | 106 | best_acc = 0 107 | global_epoch = 0 108 | best_class_avg_iou = 0 109 | best_inctance_avg_iou = 0 110 | 111 | quantizable_idx = [] 112 | for i, m in enumerate(model.modules()): 113 | if i < 6 or i > 238: 114 | continue 115 | if type(m) in [nn.Conv2d, nn.Linear]: 116 | quantizable_idx.append(i) 117 | print(quantizable_idx) 118 | 119 | strategy = [4, 4, 4, 4, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 120 | 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 121 | 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 4, 2, 4, 4, 122 | 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 2, 2, 2, 123 | 2, 4, 4, 4, 4, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 124 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 125 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2] 126 | print('quantization strategy: ', strategy) 127 | assert len(quantizable_idx) == len(strategy), \ 128 | 'You should provide one bit setting per quantizable layer for weight quantization!' 129 | centroid_label_dict = quantize_model(model, quantizable_idx, strategy, mode='cpu', quantize_bias=False, 130 | centroids_init='k-means++', max_iter=50) 131 | 132 | for epoch in range(start_epoch, args.epoch): 133 | if args.free_high_bit and epoch == args.epoch - 1: 134 | # quantize the high bit layers only at last epoch to save time 135 | centroid_label_dict = quantize_model(model, quantizable_idx, strategy, mode='cpu', quantize_bias=False, 136 | centroids_init='k-means++', max_iter=50, free_high_bit=False) 137 | 138 | mean_correct = [] 139 | 140 | logger.info('Epoch %d (%d/%s):' % 141 | (global_epoch + 1, epoch + 1, args.epoch)) 142 | '''Adjust learning rate and BN momentum''' 143 | lr = max(args.learning_rate * (args.lr_decay ** 144 | (epoch // args.step_size)), LEARNING_RATE_CLIP) 145 | logger.info('Learning rate:%f' % lr) 146 | for param_group in optimizer.param_groups: 147 | param_group['lr'] = lr 148 | momentum = MOMENTUM_ORIGINAL * \ 149 | (MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP)) 150 | if momentum < 0.01: 151 | momentum = 0.01 152 | print('BN momentum updated to: %f' % momentum) 153 | model = model.apply(lambda x: bn_momentum_adjust(x, momentum)) 154 | model = model.train() 155 | teacher_model.eval() 156 | 157 | '''learning one epoch''' 158 | total_loss = AverageMeter() 159 | distill_loss = AverageMeter() 160 | bar = Bar('distill-finetune:', max=len(trainDataLoader)) 161 | for i, (points, label, target) in enumerate(trainDataLoader): 162 | points = points.data.numpy() 163 | points[:, :, 0:3] = provider.random_scale_point_cloud( 164 | points[:, :, 0:3]) 165 | points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3]) 166 | points = torch.Tensor(points) 167 | 168 | points, label, target = points.float().cuda( 169 | ), label.long().cuda(),
target.long().cuda() 170 | optimizer.zero_grad() 171 | inputs = torch.cat([points, to_categorical( 172 | label, num_category).repeat(1, points.shape[1], 1)], -1) 173 | 174 | seg_pred = model(inputs) 175 | teacher_seg_pred = teacher_model(inputs) 176 | dloss = nn.functional.mse_loss(seg_pred, teacher_seg_pred) 177 | 178 | seg_pred = seg_pred.contiguous().view(-1, num_part) 179 | target = target.view(-1, 1)[:, 0] 180 | pred_choice = seg_pred.data.max(1)[1] 181 | 182 | correct = pred_choice.eq(target.data).cpu().sum() 183 | mean_correct.append( 184 | correct.item() / (args.batch_size * args.num_point)) 185 | tloss = criterion(seg_pred, target) 186 | loss = tloss + dloss * args['dist_loss_weight'] 187 | loss.backward() 188 | optimizer.step() 189 | kmeans_update_model( 190 | model, quantizable_idx, centroid_label_dict, free_high_bit=args.free_high_bit) 191 | total_loss.update(tloss.detach().cpu().numpy()) 192 | distill_loss.update(dloss.detach().cpu().numpy()) 193 | bar.suffix = \ 194 | '({batch}/{size}) TLoss: {total_loss:.3f} | DLoss: {distill_loss:.3f} | ETA: {eta:} | ' \ 195 | .format( 196 | batch=i + 1, 197 | size=len(trainDataLoader), 198 | total_loss=total_loss.avg, 199 | distill_loss=distill_loss.avg, 200 | eta=bar.eta_td 201 | ) 202 | bar.next() 203 | 204 | train_instance_acc = np.mean(mean_correct) 205 | logger.info('Train accuracy is: %.5f TLoss: %.5f DLoss: %.5f' % (train_instance_acc, total_loss.avg, distill_loss.avg)) 206 | 207 | with torch.no_grad(): 208 | test_metrics = {} 209 | total_correct = 0 210 | total_seen = 0 211 | total_seen_class = [0 for _ in range(num_part)] 212 | total_correct_class = [0 for _ in range(num_part)] 213 | shape_ious = {cat: [] for cat in seg_classes.keys()} 214 | seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table} 215 | 216 | for cat in seg_classes.keys(): 217 | for label in seg_classes[cat]: 218 | seg_label_to_cat[label] = cat 219 | 220 | model = model.eval() 221 | 222 | for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9): 223 | cur_batch_size, NUM_POINT, _ = points.size() 224 | points, label, target = points.float().cuda( 225 | ), label.long().cuda(), target.long().cuda() 226 | seg_pred = model(torch.cat([points, to_categorical( 227 | label, num_category).repeat(1, points.shape[1], 1)], -1)) 228 | cur_pred_val = seg_pred.cpu().data.numpy() 229 | cur_pred_val_logits = cur_pred_val 230 | cur_pred_val = np.zeros( 231 | (cur_batch_size, NUM_POINT)).astype(np.int32) 232 | target = target.cpu().data.numpy() 233 | 234 | for i in range(cur_batch_size): 235 | cat = seg_label_to_cat[target[i, 0]] 236 | logits = cur_pred_val_logits[i, :, :] 237 | cur_pred_val[i, :] = np.argmax( 238 | logits[:, seg_classes[cat]], 1) + seg_classes[cat][0] 239 | 240 | correct = np.sum(cur_pred_val == target) 241 | total_correct += correct 242 | total_seen += (cur_batch_size * NUM_POINT) 243 | 244 | for l in range(num_part): 245 | total_seen_class[l] += np.sum(target == l) 246 | total_correct_class[l] += ( 247 | np.sum((cur_pred_val == l) & (target == l))) 248 | 249 | for i in range(cur_batch_size): 250 | segp = cur_pred_val[i, :] 251 | segl = target[i, :] 252 | cat = seg_label_to_cat[segl[0]] 253 | part_ious = [0.0 for _ in range(len(seg_classes[cat]))] 254 | for l in seg_classes[cat]: 255 | if (np.sum(segl == l) == 0) and ( 256 | np.sum(segp == l) == 0): # part is not present, no prediction as well 257 | part_ious[l - seg_classes[cat][0]] = 1.0 258 | else: 259 | part_ious[l - seg_classes[cat][0]] = 
np.sum((segl == l) & (segp == l)) / float( 260 | np.sum((segl == l) | (segp == l))) 261 | shape_ious[cat].append(np.mean(part_ious)) 262 | 263 | all_shape_ious = [] 264 | for cat in shape_ious.keys(): 265 | for iou in shape_ious[cat]: 266 | all_shape_ious.append(iou) 267 | shape_ious[cat] = np.mean(shape_ious[cat]) 268 | mean_shape_ious = np.mean(list(shape_ious.values())) 269 | test_metrics['accuracy'] = total_correct / float(total_seen) 270 | test_metrics['class_avg_accuracy'] = np.mean( 271 | np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64)) 272 | for cat in sorted(shape_ious.keys()): 273 | logger.info('eval mIoU of %s %f' % 274 | (cat + ' ' * (14 - len(cat)), shape_ious[cat])) 275 | test_metrics['class_avg_iou'] = mean_shape_ious 276 | test_metrics['inctance_avg_iou'] = np.mean(all_shape_ious) 277 | 278 | logger.info('Epoch %d test Accuracy: %f Class avg mIOU: %f Instance avg mIOU: %f' % ( 279 | epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['inctance_avg_iou'])) 280 | if (test_metrics['inctance_avg_iou'] >= best_inctance_avg_iou): 281 | logger.info('Save model...') 282 | savepath = os.path.join(work_dir, 'best_model.pth') 283 | logger.info('Saving at %s' % savepath) 284 | state = { 285 | 'epoch': epoch, 286 | 'train_acc': train_instance_acc, 287 | 'test_acc': test_metrics['accuracy'], 288 | 'class_avg_iou': test_metrics['class_avg_iou'], 289 | 'inctance_avg_iou': test_metrics['inctance_avg_iou'], 290 | 'model_state_dict': model.state_dict(), 291 | 'optimizer_state_dict': optimizer.state_dict(), 292 | } 293 | torch.save(state, savepath) 294 | logger.info('Saving model....') 295 | 296 | if test_metrics['accuracy'] > best_acc: 297 | best_acc = test_metrics['accuracy'] 298 | if test_metrics['class_avg_iou'] > best_class_avg_iou: 299 | best_class_avg_iou = test_metrics['class_avg_iou'] 300 | if test_metrics['inctance_avg_iou'] > best_inctance_avg_iou: 301 | best_inctance_avg_iou = test_metrics['inctance_avg_iou'] 302 | logger.info('Best accuracy is: %.5f' % best_acc) 303 | logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou) 304 | logger.info('Best instance avg mIOU is: %.5f' % best_inctance_avg_iou) 305 | global_epoch += 1 306 | 307 | shutil.copy(savepath, './finetuned_model.pth') 308 | 309 | 310 | if __name__ == '__main__': 311 | main() 312 | -------------------------------------------------------------------------------- /haq_lib/pretrain.py: -------------------------------------------------------------------------------- 1 | # Code for "HAQ: Hardware-Aware Automated Quantization with Mixed Precision" 2 | # Kuan Wang*, Zhijian Liu*, Yujun Lin*, Ji Lin, Song Han 3 | # {kuanwang, zhijian, yujunlin, jilin, songhan}@mit.edu 4 | 5 | import os 6 | import time 7 | import math 8 | import random 9 | import shutil 10 | import argparse 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.parallel 15 | import torch.backends.cudnn as cudnn 16 | import torch.optim as optim 17 | import torchvision.models as models 18 | import haq_lib.models as customized_models 19 | 20 | 21 | from lib.utils.utils import Logger, AverageMeter, accuracy 22 | from lib.utils.data_utils import get_dataset 23 | from progress.bar import Bar 24 | 25 | 26 | # Models 27 | default_model_names = sorted(name for name in models.__dict__ 28 | if name.islower() and not name.startswith("__") 29 | and callable(models.__dict__[name])) 30 | 31 | customized_models_names = sorted(name for name in customized_models.__dict__ 32 | if name.islower() and
not name.startswith("__") 33 | and callable(customized_models.__dict__[name])) 34 | 35 | for name in customized_models.__dict__: 36 | if name.islower() and not name.startswith("__") and callable(customized_models.__dict__[name]): 37 | models.__dict__[name] = customized_models.__dict__[name] 38 | 39 | model_names = default_model_names + customized_models_names 40 | 41 | # Parse arguments 42 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 43 | 44 | # Datasets 45 | parser.add_argument('-d', '--data', default='data/imagenet', type=str) 46 | parser.add_argument('--data_name', default='imagenet', type=str) 47 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 48 | help='number of data loading workers (default: 16)') 49 | # Optimization options 50 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 51 | help='number of total epochs to run') 52 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 53 | help='manual epoch number (useful on restarts)') 54 | parser.add_argument('--warmup_epoch', default=0, type=int, metavar='N', 55 | help='manual warmup epoch number (useful on restarts)') 56 | parser.add_argument('--train_batch', default=256, type=int, metavar='N', 57 | help='train batchsize (default: 256)') 58 | parser.add_argument('--test_batch', default=512, type=int, metavar='N', 59 | help='test batchsize (default: 512)') 60 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 61 | metavar='LR', help='initial learning rate') 62 | parser.add_argument('--lr_type', default='cos', type=str, 63 | help='lr scheduler (exp/cos/step3/fixed)') 64 | parser.add_argument('--schedule', type=int, nargs='+', default=[31, 61, 91], 65 | help='Decrease learning rate at these epochs.') 66 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') 67 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 68 | help='momentum') 69 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 70 | metavar='W', help='weight decay (default: 1e-4)') 71 | # Checkpoints 72 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', 73 | help='path to save checkpoint (default: checkpoint)') 74 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 75 | help='path to latest checkpoint (default: none)') 76 | parser.add_argument('--pretrained', action='store_true', 77 | help='use pretrained model') 78 | # Quantization 79 | parser.add_argument('--half', action='store_true', 80 | help='train with half precision via apex amp') 81 | parser.add_argument('--half_type', default='O1', type=str, 82 | help='half type: O0/O1/O2/O3') 83 | # Architecture 84 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', choices=model_names, 85 | help='model architecture:' + ' | '.join(model_names) + ' (default: resnet50)') 86 | # Miscs 87 | parser.add_argument('--manualSeed', type=int, help='manual seed') 88 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 89 | help='evaluate model on validation set') 90 | # Device options 91 | parser.add_argument('--gpu_id', default='1', type=str, 92 | help='id(s) for CUDA_VISIBLE_DEVICES') 93 | 94 | args = parser.parse_args() 95 | state = {k: v for k, v in args._get_kwargs()} 96 | lr_current = state['lr'] 97 | 98 | # Use CUDA 99 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 100 | use_cuda = torch.cuda.is_available() 101 | 102 | # Random seed 103 | if
args.manualSeed is None: 104 | args.manualSeed = random.randint(1, 10000) 105 | random.seed(args.manualSeed) 106 | torch.manual_seed(args.manualSeed) 107 | if use_cuda: 108 | torch.cuda.manual_seed_all(args.manualSeed) 109 | 110 | 111 | best_acc = 0 # best test accuracy 112 | 113 | 114 | def load_my_state_dict(model, state_dict): 115 | model_state = model.state_dict() 116 | for name, param in state_dict.items(): 117 | if name not in model_state: 118 | continue 119 | param_data = param.data 120 | if model_state[name].shape == param_data.shape: 121 | # print("load%s"%name) 122 | model_state[name].copy_(param_data) 123 | 124 | 125 | def train(train_loader, model, criterion, optimizer, epoch, use_cuda): 126 | # switch to train mode 127 | model.train() 128 | 129 | batch_time = AverageMeter() 130 | data_time = AverageMeter() 131 | losses = AverageMeter() 132 | top1 = AverageMeter() 133 | top5 = AverageMeter() 134 | end = time.time() 135 | 136 | bar = Bar('Processing', max=len(train_loader)) 137 | for batch_idx, (inputs, targets) in enumerate(train_loader): 138 | # measure data loading time 139 | data_time.update(time.time() - end) 140 | 141 | if use_cuda: 142 | inputs, targets = inputs.cuda(), targets.cuda() 143 | inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets) 144 | 145 | # compute output 146 | outputs = model(inputs) 147 | loss = criterion(outputs, targets) 148 | 149 | # measure accuracy and record loss 150 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 151 | losses.update(loss.item(), inputs.size(0)) 152 | top1.update(prec1.item(), inputs.size(0)) 153 | top5.update(prec5.item(), inputs.size(0)) 154 | 155 | # compute gradient 156 | optimizer.zero_grad() 157 | if args.half: 158 | with apex.amp.scale_loss(loss, optimizer) as scaled_loss: 159 | scaled_loss.backward() 160 | # with amp_handle.scale_loss(loss, optimizer) as scaled_loss: 161 | # scaled_loss.backward() 162 | else: 163 | loss.backward() 164 | # do SGD step 165 | optimizer.step() 166 | 167 | # measure elapsed time 168 | batch_time.update(time.time() - end) 169 | end = time.time() 170 | 171 | # plot progress 172 | if batch_idx % 1 == 0: 173 | bar.suffix = \ 174 | '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \ 175 | 'Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 176 | batch=batch_idx + 1, 177 | size=len(train_loader), 178 | data=data_time.val, 179 | bt=batch_time.val, 180 | total=bar.elapsed_td, 181 | eta=bar.eta_td, 182 | loss=losses.avg, 183 | top1=top1.avg, 184 | top5=top5.avg, 185 | ) 186 | bar.next() 187 | bar.finish() 188 | return losses.avg, top1.avg 189 | 190 | 191 | def test(val_loader, model, criterion, epoch, use_cuda): 192 | global best_acc 193 | 194 | batch_time = AverageMeter() 195 | data_time = AverageMeter() 196 | losses = AverageMeter() 197 | top1 = AverageMeter() 198 | top5 = AverageMeter() 199 | 200 | with torch.no_grad(): 201 | # switch to evaluate mode 202 | model.eval() 203 | 204 | end = time.time() 205 | bar = Bar('Processing', max=len(val_loader)) 206 | for batch_idx, (inputs, targets) in enumerate(val_loader): 207 | # measure data loading time 208 | data_time.update(time.time() - end) 209 | 210 | if use_cuda: 211 | inputs, targets = inputs.cuda(), targets.cuda() 212 | inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets) 213 | 214 | # compute output 215 | outputs = model(inputs) 216 | loss = criterion(outputs, targets) 217 | 218 | # measure accuracy 
and record loss 219 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 220 | losses.update(loss.item(), inputs.size(0)) 221 | top1.update(prec1.item(), inputs.size(0)) 222 | top5.update(prec5.item(), inputs.size(0)) 223 | 224 | # measure elapsed time 225 | batch_time.update(time.time() - end) 226 | end = time.time() 227 | 228 | # plot progress 229 | if batch_idx % 1 == 0: 230 | bar.suffix = \ 231 | '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \ 232 | 'Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 233 | batch=batch_idx + 1, 234 | size=len(val_loader), 235 | data=data_time.avg, 236 | bt=batch_time.avg, 237 | total=bar.elapsed_td, 238 | eta=bar.eta_td, 239 | loss=losses.avg, 240 | top1=top1.avg, 241 | top5=top5.avg, 242 | ) 243 | bar.next() 244 | bar.finish() 245 | return losses.avg, top1.avg 246 | 247 | 248 | def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'): 249 | filepath = os.path.join(checkpoint, filename) 250 | torch.save(state, filepath) 251 | if is_best: 252 | shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) 253 | 254 | 255 | def adjust_learning_rate(optimizer, epoch): 256 | global lr_current 257 | global best_acc 258 | if epoch < args.warmup_epoch: 259 | lr_current = state['lr']*args.gamma 260 | elif args.lr_type == 'cos': 261 | # cos 262 | lr_current = 0.5 * args.lr * (1 + math.cos(math.pi * epoch / args.epochs)) 263 | elif args.lr_type == 'exp': 264 | step = 1 265 | decay = args.gamma 266 | lr_current = args.lr * (decay ** (epoch // step)) 267 | elif epoch in args.schedule: 268 | lr_current *= args.gamma 269 | for param_group in optimizer.param_groups: 270 | param_group['lr'] = lr_current 271 | 272 | 273 | if __name__ == '__main__': 274 | start_epoch = args.start_epoch # start from epoch 0 or last checkpoint epoch 275 | 276 | if not os.path.isdir(args.checkpoint): 277 | os.makedirs(args.checkpoint) 278 | 279 | train_loader, val_loader, n_class = get_dataset(dataset_name=args.data_name, batch_size=args.train_batch, 280 | n_worker=args.workers, data_root=args.data) 281 | 282 | model = models.__dict__[args.arch](pretrained=args.pretrained, num_classes=n_class) 283 | print("=> creating model '{}'".format(args.arch), ' pretrained is ', args.pretrained) 284 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 285 | cudnn.benchmark = True 286 | 287 | # define loss function (criterion) and optimizer 288 | criterion = nn.CrossEntropyLoss().cuda() 289 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 290 | 291 | # use HalfTensor 292 | if args.half: 293 | try: 294 | import apex 295 | except ImportError: 296 | raise ImportError("Please install apex from https://github.com/NVIDIA/apex") 297 | model.cuda() 298 | model, optimizer = apex.amp.initialize(model, optimizer, opt_level=args.half_type) 299 | 300 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 301 | model.features = torch.nn.DataParallel(model.features) 302 | model = model.cuda() 303 | else: 304 | model = torch.nn.DataParallel(model).cuda() 305 | 306 | # Resume 307 | title = 'ImageNet-' + args.arch 308 | if args.resume: 309 | # Load checkpoint. 310 | print('==> Resuming from checkpoint..') 311 | assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!' 
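# The resumed checkpoint is expected to carry the keys written by this script's
# save_checkpoint() -- 'epoch', 'state_dict', 'acc', 'best_acc' and 'optimizer' --
# so that best_acc, start_epoch and the optimizer state pick up exactly where the
# interrupted run stopped, e.g. (hypothetical path):
#   python haq_lib/pretrain.py -a resnet50 --resume checkpoint/checkpoint.pth.tar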
312 | args.checkpoint = os.path.dirname(args.resume) 313 | checkpoint = torch.load(args.resume) 314 | best_acc = checkpoint['best_acc'] 315 | print(best_acc) 316 | start_epoch = checkpoint['epoch'] 317 | model.load_state_dict(checkpoint['state_dict'], strict=False) 318 | optimizer.load_state_dict(checkpoint['optimizer']) 319 | if os.path.isfile(os.path.join(args.checkpoint, 'log.txt')): 320 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True) 321 | else: 322 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title) 323 | logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.']) 324 | else: 325 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title) 326 | logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.']) 327 | 328 | print('save the checkpoint to ', args.checkpoint) 329 | 330 | if args.evaluate: 331 | print('\nEvaluation only') 332 | test_loss, test_acc = test(val_loader, model, criterion, start_epoch, use_cuda) 333 | print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc)) 334 | exit() 335 | 336 | # Train and val 337 | for epoch in range(start_epoch, args.epochs): 338 | adjust_learning_rate(optimizer, epoch) 339 | print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, lr_current)) 340 | 341 | train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, use_cuda) 342 | test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda) 343 | 344 | # append logger file 345 | logger.append([lr_current, train_loss, test_loss, train_acc, test_acc]) 346 | 347 | # save model 348 | is_best = test_acc > best_acc 349 | best_acc = max(test_acc, best_acc) 350 | save_checkpoint({ 351 | 'epoch': epoch + 1, 352 | 'state_dict': model.state_dict(), 353 | 'acc': test_acc, 354 | 'best_acc': best_acc, 355 | 'optimizer' : optimizer.state_dict(), 356 | }, is_best, checkpoint=args.checkpoint) 357 | 358 | # ============ TensorBoard logging ============# 359 | # (1) Log the scalar values 360 | info = { 361 | 'train_loss': train_loss, 362 | 'train_accuracy': train_acc, 363 | 'test_loss': test_loss, 364 | 'test_accuracy': test_acc, 365 | 'learning_rate': lr_current 366 | } 367 | 368 | logger.close() 369 | 370 | print('Best acc:') 371 | print(best_acc) 372 | 373 | -------------------------------------------------------------------------------- /haq_lib/finetune.py: -------------------------------------------------------------------------------- 1 | # Code for "[HAQ: Hardware-Aware Automated Quantization with Mixed Precision" 2 | # Kuan Wang*, Zhijian Liu*, Yujun Lin*, Ji Lin, Song Han 3 | # {kuanwang, zhijian, yujunlin, jilin, songhan}@mit.edu 4 | 5 | import os 6 | import time 7 | import math 8 | import random 9 | import shutil 10 | import argparse 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.parallel 15 | import torch.backends.cudnn as cudnn 16 | import torch.optim as optim 17 | import torchvision.models as models 18 | import haq_lib.models as customized_models 19 | 20 | from lib.utils.utils import Logger, AverageMeter, accuracy 21 | from lib.utils.data_utils import get_dataset 22 | from progress.bar import Bar 23 | from lib.utils.quantize_utils import quantize_model, kmeans_update_model, QConv2d, QLinear, calibrate 24 | 25 | 26 | # Models 27 | default_model_names = sorted(name for name in models.__dict__ 28 | if name.islower() and not name.startswith("__") 29 | and callable(models.__dict__[name])) 30 | 31 | 
customized_models_names = sorted(name for name in customized_models.__dict__ 32 | if name.islower() and not name.startswith("__") 33 | and callable(customized_models.__dict__[name])) 34 | 35 | for name in customized_models.__dict__: 36 | if name.islower() and not name.startswith("__") and callable(customized_models.__dict__[name]): 37 | models.__dict__[name] = customized_models.__dict__[name] 38 | 39 | model_names = default_model_names + customized_models_names 40 | 41 | # Parse arguments 42 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 43 | 44 | # Datasets 45 | parser.add_argument('-d', '--data', default='data/imagenet', type=str) 46 | parser.add_argument('--data_name', default='imagenet', type=str) 47 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 48 | help='number of data loading workers (default: 16)') 49 | # Optimization options 50 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 51 | help='number of total epochs to run') 52 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 53 | help='manual epoch number (useful on restarts)') 54 | parser.add_argument('--warmup_epoch', default=0, type=int, metavar='N', 55 | help='manual warmup epoch number (useful on restarts)') 56 | parser.add_argument('--train_batch', default=256, type=int, metavar='N', 57 | help='train batchsize (default: 256)') 58 | parser.add_argument('--test_batch', default=512, type=int, metavar='N', 59 | help='test batchsize (default: 512)') 60 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 61 | metavar='LR', help='initial learning rate') 62 | parser.add_argument('--lr_type', default='cos', type=str, 63 | help='lr scheduler (exp/cos/step3/fixed)') 64 | parser.add_argument('--schedule', type=int, nargs='+', default=[31, 61, 91], 65 | help='Decrease learning rate at these epochs.') 66 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') 67 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 68 | help='momentum') 69 | parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float, 70 | metavar='W', help='weight decay (default: 1e-5)') 71 | # Checkpoints 72 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', 73 | help='path to save checkpoint (default: checkpoint)') 74 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 75 | help='path to latest checkpoint (default: none)') 76 | parser.add_argument('--pretrained', action='store_true', 77 | help='use pretrained model') 78 | # Quantization 79 | parser.add_argument('--linear_quantization', dest='linear_quantization', action='store_true', 80 | help='quantize both weights and activations') 81 | parser.add_argument('--free_high_bit', default=True, type=bool, 82 | help='leave high-bit (>6) layers unquantized until the last epoch to save time') 83 | parser.add_argument('--half', action='store_true', 84 | help='train with half precision via apex amp') 85 | parser.add_argument('--half_type', default='O1', type=str, 86 | help='half type: O0/O1/O2/O3') 87 | # Architecture 88 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', choices=model_names, 89 | help='model architecture:' + ' | '.join(model_names) + ' (default: resnet50)') 90 | # Miscs 91 | parser.add_argument('--manualSeed', type=int, help='manual seed') 92 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 93 | help='evaluate model on validation set') 94 | # Device options 95 |
parser.add_argument('--gpu_id', default='1', type=str, 96 | help='id(s) for CUDA_VISIBLE_DEVICES') 97 | 98 | args = parser.parse_args() 99 | state = {k: v for k, v in args._get_kwargs()} 100 | lr_current = state['lr'] 101 | 102 | # Use CUDA 103 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 104 | use_cuda = torch.cuda.is_available() 105 | 106 | # Random seed 107 | if args.manualSeed is None: 108 | args.manualSeed = random.randint(1, 10000) 109 | random.seed(args.manualSeed) 110 | torch.manual_seed(args.manualSeed) 111 | if use_cuda: 112 | torch.cuda.manual_seed_all(args.manualSeed) 113 | 114 | 115 | best_acc = 0 # best test accuracy 116 | 117 | 118 | def load_my_state_dict(model, state_dict): 119 | model_state = model.state_dict() 120 | for name, param in state_dict.items(): 121 | if name not in model_state: 122 | continue 123 | param_data = param.data 124 | if model_state[name].shape == param_data.shape: 125 | # print("load%s"%name) 126 | model_state[name].copy_(param_data) 127 | 128 | 129 | def train(train_loader, model, criterion, optimizer, epoch, use_cuda): 130 | # switch to train mode 131 | model.train() 132 | 133 | batch_time = AverageMeter() 134 | data_time = AverageMeter() 135 | losses = AverageMeter() 136 | top1 = AverageMeter() 137 | top5 = AverageMeter() 138 | end = time.time() 139 | 140 | bar = Bar('Processing', max=len(train_loader)) 141 | for batch_idx, (inputs, targets) in enumerate(train_loader): 142 | # measure data loading time 143 | data_time.update(time.time() - end) 144 | 145 | if use_cuda: 146 | inputs, targets = inputs.cuda(), targets.cuda() 147 | inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets) 148 | 149 | # compute output 150 | outputs = model(inputs) 151 | loss = criterion(outputs, targets) 152 | 153 | # measure accuracy and record loss 154 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 155 | losses.update(loss.item(), inputs.size(0)) 156 | top1.update(prec1.item(), inputs.size(0)) 157 | top5.update(prec5.item(), inputs.size(0)) 158 | 159 | # compute gradient 160 | optimizer.zero_grad() 161 | if args.half: 162 | with apex.amp.scale_loss(loss, optimizer) as scaled_loss: 163 | scaled_loss.backward() 164 | # with amp_handle.scale_loss(loss, optimizer) as scaled_loss: 165 | # scaled_loss.backward() 166 | else: 167 | loss.backward() 168 | # do SGD step 169 | optimizer.step() 170 | 171 | if not args.linear_quantization: 172 | kmeans_update_model(model, quantizable_idx, centroid_label_dict, free_high_bit=args.free_high_bit) 173 | 174 | # measure elapsed time 175 | batch_time.update(time.time() - end) 176 | end = time.time() 177 | 178 | # plot progress 179 | if batch_idx % 1 == 0: 180 | bar.suffix = \ 181 | '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \ 182 | 'Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 183 | batch=batch_idx + 1, 184 | size=len(train_loader), 185 | data=data_time.val, 186 | bt=batch_time.val, 187 | total=bar.elapsed_td, 188 | eta=bar.eta_td, 189 | loss=losses.avg, 190 | top1=top1.avg, 191 | top5=top5.avg, 192 | ) 193 | bar.next() 194 | bar.finish() 195 | return losses.avg, top1.avg 196 | 197 | 198 | def test(val_loader, model, criterion, epoch, use_cuda): 199 | global best_acc 200 | 201 | batch_time = AverageMeter() 202 | data_time = AverageMeter() 203 | losses = AverageMeter() 204 | top1 = AverageMeter() 205 | top5 = AverageMeter() 206 | 207 | with torch.no_grad(): 208 | # switch to evaluate mode 209 | model.eval() 210 | 211 | 
end = time.time() 212 | bar = Bar('Processing', max=len(val_loader)) 213 | for batch_idx, (inputs, targets) in enumerate(val_loader): 214 | # measure data loading time 215 | data_time.update(time.time() - end) 216 | 217 | if use_cuda: 218 | inputs, targets = inputs.cuda(), targets.cuda() 219 | inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets) 220 | 221 | # compute output 222 | outputs = model(inputs) 223 | loss = criterion(outputs, targets) 224 | 225 | # measure accuracy and record loss 226 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 227 | losses.update(loss.item(), inputs.size(0)) 228 | top1.update(prec1.item(), inputs.size(0)) 229 | top5.update(prec5.item(), inputs.size(0)) 230 | 231 | # measure elapsed time 232 | batch_time.update(time.time() - end) 233 | end = time.time() 234 | 235 | # plot progress 236 | if batch_idx % 1 == 0: 237 | bar.suffix = \ 238 | '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \ 239 | 'Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 240 | batch=batch_idx + 1, 241 | size=len(val_loader), 242 | data=data_time.avg, 243 | bt=batch_time.avg, 244 | total=bar.elapsed_td, 245 | eta=bar.eta_td, 246 | loss=losses.avg, 247 | top1=top1.avg, 248 | top5=top5.avg, 249 | ) 250 | bar.next() 251 | bar.finish() 252 | return losses.avg, top1.avg 253 | 254 | 255 | def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'): 256 | filepath = os.path.join(checkpoint, filename) 257 | torch.save(state, filepath) 258 | if is_best: 259 | shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) 260 | 261 | 262 | def adjust_learning_rate(optimizer, epoch): 263 | global lr_current 264 | global best_acc 265 | if epoch < args.warmup_epoch: 266 | lr_current = state['lr']*args.gamma 267 | elif args.lr_type == 'cos': 268 | # cos 269 | lr_current = 0.5 * args.lr * (1 + math.cos(math.pi * epoch / args.epochs)) 270 | elif args.lr_type == 'exp': 271 | step = 1 272 | decay = args.gamma 273 | lr_current = args.lr * (decay ** (epoch // step)) 274 | elif epoch in args.schedule: 275 | lr_current *= args.gamma 276 | for param_group in optimizer.param_groups: 277 | param_group['lr'] = lr_current 278 | 279 | 280 | if __name__ == '__main__': 281 | start_epoch = args.start_epoch # start from epoch 0 or last checkpoint epoch 282 | 283 | if not os.path.isdir(args.checkpoint): 284 | os.makedirs(args.checkpoint) 285 | 286 | train_loader, val_loader, n_class = get_dataset(dataset_name=args.data_name, batch_size=args.train_batch, 287 | n_worker=args.workers, data_root=args.data) 288 | 289 | model = models.__dict__[args.arch](pretrained=args.pretrained) 290 | print("=> creating model '{}'".format(args.arch), ' pretrained is ', args.pretrained) 291 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 292 | cudnn.benchmark = True 293 | 294 | # define loss function (criterion) and optimizer 295 | criterion = nn.CrossEntropyLoss().cuda() 296 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 297 | 298 | # use HalfTensor 299 | if args.half: 300 | try: 301 | import apex 302 | except ImportError: 303 | raise ImportError("Please install apex from https://github.com/NVIDIA/apex") 304 | model.cuda() 305 | model, optimizer = apex.amp.initialize(model, optimizer, opt_level=args.half_type) 306 | 307 | if args.linear_quantization: 308 | quantizable_idx = 
[] 309 | for i, m in enumerate(model.modules()): 310 | if type(m) in [QConv2d, QLinear]: 311 | quantizable_idx.append(i) 312 | # print(model) 313 | print(quantizable_idx) 314 | 315 | if 'mobilenetv2' in args.arch: 316 | strategy = [[8, -1], [7, 7], [5, 6], [4, 6], [5, 6], [5, 7], [5, 6], [7, 4], [4, 6], [4, 6], [7, 7], [5, 6], [4, 6], [7, 3], [5, 7], [4, 7], [7, 3], [5, 7], [4, 7], [7, 7], [4, 7], [4, 7], [6, 4], [6, 7], [4, 7], [7, 4], [6, 7], [5, 7], [7, 4], [6, 7], [5, 7], [7, 4], [6, 7], [6, 7], [6, 4], [5, 7], [6, 7], [6, 4], [5, 7], [6, 7], [7, 7], [4, 7], [7, 7], [7, 7], [4, 7], [7, 7], [7, 7], [4, 7], [7, 7], [7, 7], [4, 7], [7, 7], [8, 8]] 317 | else: 318 | raise NotImplementedError 319 | 320 | print(strategy) 321 | quantize_layer_bit_dict = {n: b for n, b in zip(quantizable_idx, strategy)} 322 | for i, layer in enumerate(model.modules()): 323 | if i not in quantizable_idx: 324 | continue 325 | else: 326 | layer.w_bit = quantize_layer_bit_dict[i][0] 327 | layer.a_bit = quantize_layer_bit_dict[i][1] 328 | model = model.cuda() 329 | model = calibrate(model, train_loader) 330 | else: 331 | quantizable_idx = [] 332 | for i, m in enumerate(model.modules()): 333 | if type(m) in [nn.Conv2d, nn.Linear]: 334 | quantizable_idx.append(i) 335 | print(quantizable_idx) 336 | 337 | if args.arch.startswith('resnet50'): 338 | # resnet50 ratio 10% 339 | strategy = [6, 6, 6, 6, 5, 5, 6, 5, 5, 6, 5, 5, 6, 5, 5, 5, 5, 5, 4, 5, 4, 4, 5, 4, 4, 4, 3, 4, 340 | 4, 4, 3, 4, 4, 3, 4, 4, 3, 4, 4, 3, 4, 4, 3, 4, 3, 3, 2, 3, 2, 3, 3, 2, 3, 4] 341 | else: 342 | # you can put your own strategy here 343 | raise NotImplementedError 344 | print('strategy for ' + args.arch + ': ', strategy) 345 | 346 | assert len(quantizable_idx) == len(strategy), \ 347 | 'You should provide the same number of bit setting as layer list for weight quantization!' 348 | centroid_label_dict = quantize_model(model, quantizable_idx, strategy, mode='cpu', quantize_bias=False, 349 | centroids_init='k-means++', max_iter=50) 350 | 351 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 352 | model.features = torch.nn.DataParallel(model.features) 353 | model = model.cuda() 354 | else: 355 | model = torch.nn.DataParallel(model).cuda() 356 | 357 | # Resume 358 | title = 'ImageNet-' + args.arch 359 | if args.resume: 360 | # Load checkpoint. 361 | print('==> Resuming from checkpoint..') 362 | assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!' 
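# Note: at this point the model has already been quantized (or had its per-layer
# w_bit/a_bit set) and wrapped in DataParallel, so the resumed checkpoint should
# come from a previous run of this script; strict=False in the load below simply
# tolerates keys that do not match the quantized wrapper.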
363 | args.checkpoint = os.path.dirname(args.resume) 364 | checkpoint = torch.load(args.resume) 365 | best_acc = checkpoint['best_acc'] 366 | print(best_acc) 367 | start_epoch = checkpoint['epoch'] 368 | model.load_state_dict(checkpoint['state_dict'], strict=False) 369 | optimizer.load_state_dict(checkpoint['optimizer']) 370 | if os.path.isfile(os.path.join(args.checkpoint, 'log.txt')): 371 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True) 372 | else: 373 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title) 374 | logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.']) 375 | else: 376 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title) 377 | logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.']) 378 | 379 | if args.evaluate: 380 | print('\nEvaluation only') 381 | test_loss, test_acc = test(val_loader, model, criterion, start_epoch, use_cuda) 382 | print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc)) 383 | exit() 384 | 385 | # Train and val 386 | for epoch in range(start_epoch, args.epochs): 387 | adjust_learning_rate(optimizer, epoch) 388 | # if args.free_high_bit and args.epochs - epoch < args.epochs // 10: 389 | if args.free_high_bit and epoch == args.epochs - 1 and (not args.linear_quantization): 390 | # quantize the high bit layers only at last epoch to save time 391 | centroid_label_dict = quantize_model(model, quantizable_idx, strategy, mode='cpu', quantize_bias=False, 392 | centroids_init='k-means++', max_iter=50, free_high_bit=False) 393 | 394 | print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, lr_current)) 395 | 396 | train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, use_cuda) 397 | test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda) 398 | 399 | # append logger file 400 | logger.append([lr_current, train_loss, test_loss, train_acc, test_acc]) 401 | 402 | # save model 403 | is_best = test_acc > best_acc 404 | best_acc = max(test_acc, best_acc) 405 | save_checkpoint({ 406 | 'epoch': epoch + 1, 407 | 'state_dict': model.state_dict(), 408 | 'acc': test_acc, 409 | 'best_acc': best_acc, 410 | 'optimizer' : optimizer.state_dict(), 411 | }, is_best, checkpoint=args.checkpoint) 412 | 413 | logger.close() 414 | 415 | print('Best acc:') 416 | print(best_acc) 417 | 418 | -------------------------------------------------------------------------------- /haq_lib/lib/env/linear_quantize_env.py: -------------------------------------------------------------------------------- 1 | # Code for "[HAQ: Hardware-Aware Automated Quantization with Mixed Precision" 2 | # Kuan Wang*, Zhijian Liu*, Yujun Lin*, Ji Lin, Song Han 3 | # {kuanwang, zhijian, yujunlin, jilin, songhan}@mit.edu 4 | 5 | import os 6 | import time 7 | import math 8 | import torch 9 | import numpy as np 10 | import torch.nn as nn 11 | from copy import deepcopy 12 | import torch.optim as optim 13 | from progress.bar import Bar 14 | 15 | from haq_lib.lib.utils.utils import AverageMeter, accuracy, prGreen, measure_model 16 | from haq_lib.lib.utils.data_utils import get_split_train_dataset 17 | from haq_lib.lib.utils.quantize_utils import QConv2d, QLinear, calibrate 18 | 19 | 20 | class LinearQuantizeEnv: 21 | def __init__(self, model, pretrained_model, data, data_root, compress_ratio, args, n_data_worker=16, 22 | batch_size=256, float_bit=8, is_model_pruned=False): 23 | # default setting 24 | self.quantizable_layer_types 
= [QConv2d, QLinear] 25 | 26 | # save options 27 | self.model = model 28 | self.model_for_measure = deepcopy(model) 29 | self.model_name = args.arch 30 | self.cur_ind = 0 31 | self.strategy = [] # quantization strategy 32 | 33 | self.finetune_lr = args.finetune_lr 34 | self.optimizer = optim.SGD(model.parameters(), lr=args.finetune_lr, momentum=0.9, weight_decay=1e-5) 35 | self.criterion = nn.CrossEntropyLoss().cuda() 36 | self.pretrained_model = pretrained_model 37 | self.n_data_worker = n_data_worker 38 | self.batch_size = batch_size 39 | self.data_type = data 40 | self.data_root = data_root 41 | self.compress_ratio = compress_ratio 42 | self.is_model_pruned = is_model_pruned 43 | self.val_size = args.val_size 44 | self.train_size = args.train_size 45 | self.finetune_gamma = args.finetune_gamma 46 | self.finetune_lr = args.finetune_lr 47 | self.finetune_flag = args.finetune_flag 48 | self.finetune_epoch = args.finetune_epoch 49 | 50 | # options from args 51 | self.min_bit = args.min_bit 52 | self.max_bit = args.max_bit 53 | self.float_bit = float_bit * 1. 54 | self.last_weight_action = self.max_bit 55 | self.last_activation_action = self.max_bit 56 | self.action_radio_button = True 57 | 58 | self.is_inception = args.arch.startswith('inception') 59 | self.is_imagenet = ('imagenet' in data) 60 | self.use_top5 = args.use_top5 61 | 62 | # init reward 63 | self.best_reward = -math.inf 64 | 65 | # prepare data 66 | self._init_data() 67 | 68 | # build indexes 69 | self._build_index() 70 | self.n_quantizable_layer = len(self.quantizable_idx) 71 | 72 | self.model.load_state_dict(self.pretrained_model, strict=True) 73 | # self.org_acc = self._validate(self.val_loader, self.model) 74 | self.org_acc = self._validate(self.train_loader, self.model) 75 | # build embedding (static part), same as pruning 76 | self._build_state_embedding() 77 | 78 | # mode 79 | self.cost_mode = 'cloud_latency' 80 | self.simulator_batch = 16 81 | self.cost_lookuptable = self._get_lookuptable() 82 | 83 | # sanity check 84 | assert self.compress_ratio > self._min_cost() / self._org_cost(), \ 85 | 'Error! The target compress_ratio is lower than the minimum achievable with min_bit!' 86 | 87 | # restore weight 88 | self.reset() 89 | print('=> original acc: {:.3f}% on split dataset (train: {:7d}, val: {:7d})'.format(self.org_acc, 90 | self.train_size, self.val_size)) 91 | print('=> original cost: {:.4f}'.format(self._org_cost())) 92 | 93 | def adjust_learning_rate(self): 94 | for param_group in self.optimizer.param_groups: 95 | param_group['lr'] *= self.finetune_gamma 96 | 97 | def step(self, action): 98 | # Pseudo prune and get the corresponding statistics.
    def step(self, action):
        # Pseudo-quantize and collect the corresponding statistics; the real
        # quantization does not happen until all pseudo actions have been taken.
        action = self._action_wall(action)  # map the continuous action to an integer bit-width

        if self.action_radio_button:
            self.last_weight_action = action
        else:
            self.last_activation_action = action
            self.strategy.append([self.last_weight_action, self.last_activation_action])  # save action to strategy

        # all the actions are made
        if self._is_final_layer() and (not self.action_radio_button):
            self._final_action_wall()
            assert len(self.strategy) == len(self.quantizable_idx)
            cost = self._cur_cost()
            cost_ratio = cost / self._org_cost()

            self._set_mixed_precision(quantizable_idx=self.quantizable_idx, strategy=self.strategy)
            self.model = calibrate(self.model, self.train_loader)
            if self.finetune_flag:
                acc = self._finetune(self.train_loader, self.model, epochs=self.finetune_epoch, verbose=False)
                # train_acc = self._finetune(self.train_loader, self.model, epochs=self.finetune_epoch, verbose=False)
                # acc = self._validate(self.val_loader, self.model)
            else:
                acc = self._validate(self.val_loader, self.model)

            # reward = self.reward(acc, w_size_ratio)
            reward = self.reward(acc)

            info_set = {'cost_ratio': cost_ratio, 'accuracy': acc, 'cost': cost}

            if reward > self.best_reward:
                self.best_reward = reward
                prGreen('New best policy: {}, reward: {:.3f}, acc: {:.3f}, cost_ratio: {:.3f}'.format(
                    self.strategy, self.best_reward, acc, cost_ratio))

            obs = self.layer_embedding[self.cur_ind, :].copy()  # actually the same as the last state
            done = True
            self.action_radio_button = not self.action_radio_button
            return obs, reward, done, info_set

        cost = self._cur_cost()
        info_set = {'cost': cost}
        reward = 0
        done = False

        if self.action_radio_button:
            self.layer_embedding[self.cur_ind][-1] = 0.0
        else:
            self.cur_ind += 1  # the index of the next layer
            self.layer_embedding[self.cur_ind][-1] = 1.0
        self.layer_embedding[self.cur_ind][-2] = float(action) / float(self.max_bit)
        self.layer_embedding[self.cur_ind][-1] = float(self.action_radio_button)
        # build the next state (modified in place)
        obs = self.layer_embedding[self.cur_ind, :].copy()
        self.action_radio_button = not self.action_radio_button
        return obs, reward, done, info_set
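Each quantizable layer therefore consumes two `step` calls, one for the weight bit-width and one for the activation bit-width, and the reward stays zero until the final step, where the strategy is applied, calibrated, optionally finetuned, and scored. A hypothetical random-action driver, sketched here only to show the interaction protocol (the real search loop lives in `rl_quantize.py` and uses the DDPG agent):

```python
# Hypothetical driver showing how an agent interacts with LinearQuantizeEnv.
import numpy as np

def random_search_episode(env):
    obs = env.reset()
    done = False
    while not done:
        action = np.random.uniform(0., 1.)        # an agent would predict this
        obs, reward, done, info = env.step(action)
    # reward is only non-zero on the final step, after calibration/finetuning
    return reward, info
```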
    # for quantization
    def reward(self, acc, cost_ratio=None):
        if cost_ratio is not None:
            return (acc - self.org_acc + 1. / cost_ratio) * 0.1
        return (acc - self.org_acc) * 0.1

    def reset(self):
        # restore the env by loading the pretrained model
        self.model.load_state_dict(self.pretrained_model, strict=False)
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.finetune_lr, momentum=0.9, weight_decay=4e-5)
        self.cur_ind = 0
        self.strategy = []  # quantization strategy
        obs = self.layer_embedding[0].copy()
        return obs

    def _is_final_layer(self):
        return self.cur_ind == len(self.quantizable_idx) - 1

    def _final_action_wall(self):
        # greedily lower bit-widths, starting from the last layer, until the
        # strategy fits the target cost budget
        target = self.compress_ratio * self._org_cost()
        min_cost = 0
        for i, n_bit in enumerate(self.strategy):
            min_cost += self.cost_lookuptable[i][int(self.min_bit - 1)][int(self.min_bit - 1)]

        print('before action_wall: ', self.strategy, min_cost, self._cur_cost())
        while min_cost < self._cur_cost() and target < self._cur_cost():
            # print('current: ', self.strategy, min_cost, self._cur_cost())
            for i, n_bit in enumerate(reversed(self.strategy)):
                if n_bit[1] > self.min_bit:
                    self.strategy[-(i + 1)][1] -= 1
                    self._keep_first_last_layer()
                    if target >= self._cur_cost():
                        break
                if n_bit[0] > self.min_bit:
                    self.strategy[-(i + 1)][0] -= 1
                    self._keep_first_last_layer()
                    if target >= self._cur_cost():
                        break
        print('after action_wall: ', self.strategy, min_cost, self._cur_cost())

    def _keep_first_last_layer(self):
        # pin the first and last layers to high precision
        self.strategy[0][0] = 8
        # self.strategy[0][1] = 8
        # the input image is already 8 bit
        self.strategy[0][1] = -1
        self.strategy[-1][0] = 8
        self.strategy[-1][1] = 8

    def _action_wall(self, action):
        assert len(self.strategy) == self.cur_ind
        # limit the action to a certain range
        action = float(action)
        min_bit, max_bit = self.bound_list[self.cur_ind]
        lbound, rbound = min_bit - 0.5, max_bit + 0.5  # same stride length for each bit
        action = (rbound - lbound) * action + lbound
        action = int(np.round(action, 0))
        return action  # not constrained here

    def _set_mixed_precision(self, quantizable_idx, strategy):
        assert len(quantizable_idx) == len(strategy), \
            'You should provide the same number of bit settings as layers for weight quantization!'
        quantize_layer_bit_dict = {n: b for n, b in zip(quantizable_idx, strategy)}
        for i, layer in enumerate(self.model.modules()):
            if i not in quantizable_idx:
                continue
            else:
                layer.w_bit = quantize_layer_bit_dict[i][0]
                layer.a_bit = quantize_layer_bit_dict[i][1]
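`_action_wall` maps a continuous action in [0, 1] affinely onto [min_bit - 0.5, max_bit + 0.5] and then rounds, so every integer bit-width gets an equally wide slice of the action space. A worked example with `min_bit=2`, `max_bit=4` (the bounds in `config/agent.yaml`):

```python
# Worked example of the mapping in _action_wall:
# bit = round((rbound - lbound) * a + lbound), lbound = min_bit - 0.5, rbound = max_bit + 0.5
import numpy as np

lbound, rbound = 2 - 0.5, 4 + 0.5
for a in (0.0, 0.4, 0.9):
    print(a, int(np.round((rbound - lbound) * a + lbound)))  # -> 2, 3, 4
```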
    def _cur_cost(self):
        cur_cost = 0.
        # cost of the current (quantized) strategy
        for i, n_bit in enumerate(self.strategy):
            cur_cost += self.cost_lookuptable[i, n_bit[0] - 1, n_bit[1] - 1]
        return cur_cost

    def _org_cost(self):
        org_cost = 0
        for i in range(self.cost_lookuptable.shape[0]):
            org_cost += self.cost_lookuptable[i, int(self.float_bit - 1), int(self.float_bit - 1)]
        return org_cost

    def _min_cost(self):
        min_cost = 0
        for i in range(self.cost_lookuptable.shape[0]):
            if i == 0 or i == (self.cost_lookuptable.shape[0] - 1):
                min_cost += self.cost_lookuptable[i, -1, -1]
            else:
                min_cost += self.cost_lookuptable[i, int(self.min_bit - 1), int(self.min_bit - 1)]
        return min_cost

    def _init_data(self):
        self.train_loader, self.val_loader, n_class = get_split_train_dataset(
            self.data_type, self.batch_size, self.n_data_worker, data_root=self.data_root,
            val_size=self.val_size, train_size=self.train_size, for_inception=self.is_inception)

    def _build_index(self):
        self.quantizable_idx = []
        self.bound_list = []
        for i, m in enumerate(self.model.modules()):
            if type(m) in self.quantizable_layer_types:
                self.quantizable_idx.append(i)
                self.bound_list.append((self.min_bit, self.max_bit))
        print('=> Final bound list: {}'.format(self.bound_list))

    def _build_state_embedding(self):
        # measure the model on a 224x224 input for ImageNet, 32x32 for CIFAR
        if self.is_imagenet:
            measure_model(self.model_for_measure, 224, 224)
        else:
            measure_model(self.model_for_measure, 32, 32)
        # build the static part of the state embedding
        layer_embedding = []
        module_list = list(self.model_for_measure.modules())
        for i, ind in enumerate(self.quantizable_idx):
            m = module_list[ind]
            this_state = []
            if type(m) == nn.Conv2d or type(m) == QConv2d:
                this_state.append([int(m.in_channels == m.groups)])  # layer type, 1 for conv_dw
                this_state.append([m.in_channels])  # in channels
                this_state.append([m.out_channels])  # out channels
                this_state.append([m.stride[0]])  # stride
                this_state.append([m.kernel_size[0]])  # kernel size
                this_state.append([np.prod(m.weight.size())])  # weight size
                this_state.append([m.in_w * m.in_h])  # input feature map size
            elif type(m) == nn.Linear or type(m) == QLinear:
                this_state.append([0.])  # layer type, 0 for fc
                this_state.append([m.in_features])  # in channels
                this_state.append([m.out_features])  # out channels
                this_state.append([0.])  # stride
                this_state.append([1.])  # kernel size
                this_state.append([np.prod(m.weight.size())])  # weight size
                this_state.append([m.in_w * m.in_h])  # input feature map size

            this_state.append([i])  # index
            this_state.append([1.])  # bits, 1 is the max bit
            this_state.append([1.])  # action radio button, 1 is the weight action
            layer_embedding.append(np.hstack(this_state))

        # normalize each state dimension to [0, 1]
        layer_embedding = np.array(layer_embedding, 'float')
        print('=> shape of embedding (n_layer * n_dim): {}'.format(layer_embedding.shape))
        assert len(layer_embedding.shape) == 2, layer_embedding.shape
        for i in range(layer_embedding.shape[1]):
            fmin = min(layer_embedding[:, i])
            fmax = max(layer_embedding[:, i])
            if fmax - fmin > 0:
                layer_embedding[:, i] = (layer_embedding[:, i] - fmin) / (fmax - fmin)

        self.layer_embedding = layer_embedding
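The cost of a strategy is thus a pure table lookup: for layer `i` with weight bit `w` and activation bit `a`, the latency is `table[i, w - 1, a - 1]`, summed over layers. A synthetic illustration (the table shape and values here are made up; the real table is the `.npy` file loaded by `_get_lookuptable`):

```python
# Synthetic illustration of how _cur_cost reads the latency lookup table.
import numpy as np

table = np.arange(2 * 8 * 8, dtype=float).reshape(2, 8, 8)  # 2 layers, bits 1..8
strategy = [[8, 8], [3, 4]]                                 # [w_bit, a_bit] per layer
cost = sum(table[i, w - 1, a - 1] for i, (w, a) in enumerate(strategy))
print(cost)  # table[0, 7, 7] + table[1, 2, 3] = 63.0 + 83.0 = 146.0
```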
    def _get_lookuptable(self):
        lookup_table_folder = 'lib/simulator/lookup_tables/'
        os.makedirs(lookup_table_folder, exist_ok=True)
        if self.cost_mode == 'cloud_latency':
            fname = lookup_table_folder + self.model_name + '_' + self.data_type \
                    + '_batch' + str(self.simulator_batch) + '_latency_table.npy'
        else:
            # add your own cost lookup table here
            raise NotImplementedError

        if os.path.isfile(fname):
            print('load latency table: ', fname)
            latency_list = np.load(fname)
            print(latency_list)
        else:
            # you can put your own simulator/lookup table here
            raise NotImplementedError
        return latency_list.copy()

    def _finetune(self, train_loader, model, epochs=1, verbose=True):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        best_acc = 0.

        # switch to train mode
        model.train()
        end = time.time()
        t1 = time.time()
        for epoch in range(epochs):
            # create a fresh progress bar per epoch so the count and ETA restart
            bar = Bar('train:', max=len(train_loader))
            for i, (inputs, targets) in enumerate(train_loader):
                input_var, target_var = inputs.cuda(), targets.cuda()

                # measure data loading time
                data_time.update(time.time() - end)

                # compute output
                output = model(input_var)
                loss = self.criterion(output, target_var)

                # measure accuracy and record loss
                prec1, prec5 = accuracy(output.data, target_var, topk=(1, 5))
                losses.update(loss.item(), inputs.size(0))
                top1.update(prec1.item(), inputs.size(0))
                top5.update(prec5.item(), inputs.size(0))

                # compute gradient
                self.optimizer.zero_grad()
                loss.backward()

                # do SGD step
                self.optimizer.step()

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                bar.suffix = \
                    '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                    'Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                        batch=i + 1,
                        size=len(train_loader),
                        data=data_time.val,
                        bt=batch_time.val,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        top1=top1.avg,
                        top5=top5.avg,
                    )
                bar.next()
            bar.finish()

            if self.use_top5:
                if top5.avg > best_acc:
                    best_acc = top5.avg
            else:
                if top1.avg > best_acc:
                    best_acc = top1.avg
            self.adjust_learning_rate()
        t2 = time.time()
        if verbose:
            print('* Test loss: %.3f  top1: %.3f  top5: %.3f  time: %.3f' % (losses.avg, top1.avg, top5.avg, t2 - t1))
        return best_acc
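`_finetune` and `_validate` both lean on the `AverageMeter` helper imported from `haq_lib/lib/utils/utils.py`. A minimal sketch of the convention it is assumed to follow (the widely used `val`/`avg` running-average pattern; the repo's class may differ slightly):

```python
# Minimal sketch of the AverageMeter convention assumed above.
class AverageMeter:
    def __init__(self):
        self.val = self.sum = self.count = self.avg = 0.

    def update(self, val, n=1):
        self.val = val                    # most recent value
        self.sum += val * n               # weighted running sum
        self.count += n
        self.avg = self.sum / self.count  # running average
```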
    def _validate(self, val_loader, model, verbose=False):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        t1 = time.time()
        with torch.no_grad():
            # switch to evaluate mode
            model.eval()

            end = time.time()
            bar = Bar('valid:', max=len(val_loader))
            for i, (inputs, targets) in enumerate(val_loader):
                # measure data loading time
                data_time.update(time.time() - end)

                input_var, target_var = inputs.cuda(), targets.cuda()

                # compute output
                output = model(input_var)
                loss = self.criterion(output, target_var)

                # measure accuracy and record loss
                prec1, prec5 = accuracy(output.data, target_var, topk=(1, 5))
                losses.update(loss.item(), inputs.size(0))
                top1.update(prec1.item(), inputs.size(0))
                top5.update(prec5.item(), inputs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                bar.suffix = \
                    '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | ' \
                    'Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                        batch=i + 1,
                        size=len(val_loader),
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        loss=losses.avg,
                        top1=top1.avg,
                        top5=top5.avg,
                    )
                bar.next()
            bar.finish()
        t2 = time.time()
        if verbose:
            print('* Test loss: %.3f  top1: %.3f  top5: %.3f  time: %.3f' % (losses.avg, top1.avg, top5.avg, t2 - t1))
        if self.use_top5:
            return top5.avg
        else:
            return top1.avg

--------------------------------------------------------------------------------
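Closing note: the env above is driven by the DDPG agent in `haq_lib/lib/rl/ddpg.py` from `rl_quantize.py`. A hypothetical outline of that outer loop, shown only to connect the pieces (the `agent` method names here are illustrative, not the repo's exact API; episode and warmup counts echo the `train_episode`/`warmup` config keys):

```python
# Hypothetical outline of the outer search loop (see rl_quantize.py for the real one).
def search(env, agent, n_episodes=600, warmup=5):
    for episode in range(n_episodes):
        obs, done = env.reset(), False
        while not done:
            # explore randomly during warmup, then follow the actor
            action = agent.random_action() if episode < warmup else agent.select_action(obs)
            next_obs, reward, done, info = env.step(action)
            agent.observe(obs, action, reward, next_obs, done)  # fill replay memory
            if episode >= warmup:
                agent.update_policy()
            obs = next_obs
```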