├── data
│   ├── __init__.py
│   └── contrast.py
├── results
│   └── pred
│       └── Misc_*.png  (96 prediction masks named Misc_<id>.png)
├── loss
│   ├── __init__.py
│   └── loss.py
├── metric
│   ├── __init__.py
│   ├── sigmoid.py
│   └── samplewise.py
├── model
│   ├── __init__.py
│   ├── segmentation.py
│   └── contrast.py
├── params
│   └── BottomUpLocal_r_1_b_4_0.7614.params
├── .gitignore
├── README.md
├── visualize_local_contrast_networks.py
└── train_alcnet.py

/data/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

from .contrast import IceContrast
--------------------------------------------------------------------------------

/results/pred/Misc_*.png (96 binary PNG masks, contents omitted):
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YimianDai/open-alcnet/HEAD/results/pred/Misc_*.png
--------------------------------------------------------------------------------

/loss/__init__.py:
--------------------------------------------------------------------------------
# from .detection import *
from .loss import SoftIoULoss, SamplewiseSoftIoULoss
--------------------------------------------------------------------------------

/metric/__init__.py:
--------------------------------------------------------------------------------
from .sigmoid import *
from .samplewise import SamplewiseSigmoidMetric, ROCMetric
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
from .contrast import PCMNet, PlainNet, MPCMNet, LayerwiseMPCMNet, MPCMResNetFPN, ResNetFCN
--------------------------------------------------------------------------------

/params/BottomUpLocal_r_1_b_4_0.7614.params (binary weights, contents omitted):
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YimianDai/open-alcnet/HEAD/params/BottomUpLocal_r_1_b_4_0.7614.params
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
# Created by .ignore support plugin (hsz.mobi)

*.iml
*.xml
*.pyc
*.log
#*.pdf
*.params
*.states
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# open-alcnet

Code and trained models for ALCNet (attentional local contrast networks for infrared small target detection).

## Requirements

Install [MXNet](https://mxnet.apache.org/) and [Gluon-CV](https://gluon-cv.mxnet.io/):

```
pip install --upgrade mxnet-cu100 gluoncv
```

## Dataset

The models are trained and evaluated on the SIRST (single-frame infrared small target) dataset.

## Experiments

The trained model params are in `./params`.
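The released weights should reproduce the masks under `./results/pred`. A sketch of an evaluation command (all flags are defined in `visualize_local_contrast_networks.py` and the values shown are its defaults; the dataset root is resolved inside the script):

```
python visualize_local_contrast_networks.py --net-choice MPCMResNetFPN \
    --pyramid-fuse bottomuplocal --blocks 4 --dataset DENTIST \
    --val-split test --eval
```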
## Citation

Please cite our paper in your publications if our work helps your research. BibTeX reference is as follows.

```
@inproceedings{dai21acm,
  title     = {Asymmetric Contextual Modulation for Infrared Small Target Detection},
  author    = {Yimian Dai and Yiquan Wu and Fei Zhou and Kobus Barnard},
  booktitle = {{IEEE} Winter Conference on Applications of Computer Vision, {WACV} 2021},
  year      = {2021}
}

@article{TGRS21ALCNet,
  author  = {{Dai}, Yimian and {Wu}, Yiquan and {Zhou}, Fei and {Barnard}, Kobus},
  title   = {{Attentional Local Contrast Networks for Infrared Small Target Detection}},
  journal = {IEEE Transactions on Geoscience and Remote Sensing},
  pages   = {1--12},
  year    = {2021},
}
```
--------------------------------------------------------------------------------

/loss/loss.py:
--------------------------------------------------------------------------------
from gluoncv.loss import Loss as gcvLoss
import mxnet as mx
from mxnet import nd, gluon


class SamplewiseSoftIoULoss(gcvLoss):
    """Soft IoU loss computed per sample, then averaged over the batch."""

    def __init__(self, batch_axis=0, weight=None):
        super(SamplewiseSoftIoULoss, self).__init__(weight, batch_axis)

    def hybrid_forward(self, F, pred, target):
        pred = F.sigmoid(pred)

        intersection = (pred * target).sum(axis=(1, 2, 3))
        pred = pred.sum(axis=(1, 2, 3))
        target = target.sum(axis=(1, 2, 3))
        smooth = .1

        loss = (intersection + smooth) / (pred + target - intersection + smooth)
        loss = (1 - loss).mean()

        return loss


class SoftIoULoss(gcvLoss):
    """Soft IoU loss computed over all pixels of the batch at once."""

    def __init__(self, batch_axis=0, weight=None):
        super(SoftIoULoss, self).__init__(weight, batch_axis)

    def hybrid_forward(self, F, pred, target):
        pred = F.sigmoid(pred)
        smooth = 1

        intersection = pred * target
        loss = (intersection.sum() + smooth) / (pred.sum() + target.sum() -
                                                intersection.sum() + smooth)
        loss = 1 - loss.mean()

        return loss
--------------------------------------------------------------------------------
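As a quick sanity check of the two losses above, they can be exercised on random tensors. This is a hypothetical usage snippet, not part of the repository:

```python
import mxnet as mx
from loss import SoftIoULoss, SamplewiseSoftIoULoss

# Fake logits and sparse binary masks: batch of 4, 1 channel, 240 x 240 (the default crop size).
pred = mx.nd.random.normal(shape=(4, 1, 240, 240))
target = (mx.nd.random.uniform(shape=(4, 1, 240, 240)) > 0.99).astype('float32')

criterion = SoftIoULoss()
loss = criterion(pred, target)  # scalar NDArray: 1 - soft IoU over the whole batch
print('SoftIoULoss:', loss.asscalar())

# The samplewise variant computes soft IoU per image first, then averages.
print('SamplewiseSoftIoULoss:', SamplewiseSoftIoULoss()(pred, target).asscalar())
```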
/metric/sigmoid.py:
--------------------------------------------------------------------------------
"""Evaluation Metrics for Semantic Segmentation of Foreground Only"""
import threading
import numpy as np
import mxnet as mx
from mxnet import nd
from mxnet.metric import EvalMetric

__all__ = ['SigmoidMetric', 'batch_pix_accuracy', 'batch_intersection_union']


class SigmoidMetric(EvalMetric):
    """Computes pixAcc and mIoU metric scores."""

    def __init__(self, nclass):
        super(SigmoidMetric, self).__init__('pixAcc & mIoU')
        self.nclass = nclass
        self.lock = threading.Lock()
        self.reset()

    def update(self, preds, labels):
        """Updates the internal evaluation result.

        Parameters
        ----------
        labels : 'NDArray' or list of `NDArray`
            The labels of the data.
        preds : 'NDArray' or list of `NDArray`
            Predicted values.
        """
        def evaluate_worker(self, label, pred):
            correct, labeled = batch_pix_accuracy(pred, label)
            inter, union = batch_intersection_union(pred, label, self.nclass)
            with self.lock:
                self.total_correct += correct
                self.total_label += labeled
                self.total_inter += inter
                self.total_union += union

        if isinstance(preds, mx.nd.NDArray):
            evaluate_worker(self, labels, preds)
        elif isinstance(preds, (list, tuple)):
            threads = [threading.Thread(target=evaluate_worker,
                                        args=(self, label, pred))
                       for (label, pred) in zip(labels, preds)]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

    def get(self):
        """Gets the current evaluation result.

        Returns
        -------
        metrics : tuple of float
            pixAcc and mIoU
        """
        pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
        mIoU = IoU.mean()
        return pixAcc, mIoU

    def reset(self):
        """Resets the internal evaluation result to initial state."""
        self.total_inter = 0
        self.total_union = 0
        self.total_correct = 0
        self.total_label = 0


def batch_pix_accuracy(output, target):
    """PixAcc"""
    # inputs are NDArrays: output is 4D, target is 3D or 4D
    # category 0 is the ignored class, typically background / boundary
    if len(target.shape) == 3:
        target = nd.expand_dims(target, axis=1).asnumpy().astype('int64')  # T
    elif len(target.shape) == 4:
        target = target.asnumpy().astype('int64')  # T
    else:
        raise ValueError("Unknown target dimension")
    assert output.shape == target.shape, "Predict and Label Shape Don't Match"
    predict = (output.asnumpy() > 0).astype('int64')  # P
    pixel_labeled = np.sum(target > 0)  # T
    pixel_correct = np.sum((predict == target) * (target > 0))  # TP

    assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
    return pixel_correct, pixel_labeled


def batch_intersection_union(output, target, nclass):
    """mIoU"""
    # inputs are NDArrays: output is 4D, target is 3D or 4D
    # category 0 is the ignored class, typically background / boundary
    mini = 1
    maxi = 1   # nclass
    nbins = 1  # nclass
    predict = (output.asnumpy() > 0).astype('int64')  # P
    if len(target.shape) == 3:
        target = nd.expand_dims(target, axis=1).asnumpy().astype('int64')  # T
    elif len(target.shape) == 4:
        target = target.asnumpy().astype('int64')  # T
    else:
        raise ValueError("Unknown target dimension")
    intersection = predict * (predict == target)  # TP

    # areas of intersection and union
    area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))
    area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
    area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
    area_union = area_pred + area_lab - area_inter
    assert (area_inter <= area_union).all(), \
        "Intersection area should be smaller than Union area"
    return area_inter, area_union
--------------------------------------------------------------------------------
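A minimal usage sketch for `SigmoidMetric` (hypothetical, not in the repo): the metric consumes raw logits and thresholds them at 0 internally, which corresponds to sigmoid(x) > 0.5.

```python
import mxnet as mx
from metric import SigmoidMetric

metric = SigmoidMetric(nclass=1)

# Raw network logits (N, 1, H, W) and binary masks of the same shape.
output = mx.nd.random.normal(shape=(2, 1, 256, 256))
labels = (mx.nd.random.uniform(shape=(2, 1, 256, 256)) > 0.99).astype('int64')

metric.update(preds=output, labels=labels)
pixAcc, mIoU = metric.get()
print('pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))
```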
/data/contrast.py:
--------------------------------------------------------------------------------
from gluoncv.data.segbase import SegmentationDataset
from gluoncv.data.base import VisionDataset
from PIL import Image, ImageOps, ImageFilter
import platform, os
import logging
import mxnet as mx
import random
from mxnet import cpu, nd
import numpy as np

try:
    import xml.etree.cElementTree as ET
except ImportError:
    import xml.etree.ElementTree as ET


class IceContrast(SegmentationDataset):
    """Iceberg Segmentation dataset."""
    NUM_CLASS = 1

    def __init__(self, base_dir='DENTIST', root=os.path.join('~', 'Nutstore Files', 'Dataset'),
                 split='train', mode=None, transform=None, include_name=False, **kwargs):
        super(IceContrast, self).__init__(root, split, mode, transform, **kwargs)
        self.include_name = include_name
        self.base_dir = base_dir
        self._root = os.path.expanduser(os.path.join(root, base_dir))
        self._transform = transform
        self._split = split
        self.mode = mode
        self._items = self._load_items(split)
        if base_dir == 'DENTIST':
            self._anno_path = os.path.join('{}', 'masks/', '{}_pixels0.png')
            self._image_path = os.path.join('{}', 'images', '{}.png')
        elif base_dir == 'Iceberg':
            self._anno_path = os.path.join('{}', 'labels/mask/', '{}_pixels0.png')
            self._image_path = os.path.join('{}', 'images', '{}.png')
        else:
            raise ValueError("Unknown base dir")

    def _load_items(self, split):
        """Load individual image indices from split files."""
        ids = []
        root = self._root
        lf = os.path.join(root, split + '.txt')
        with open(lf, 'r') as f:
            ids += [(root, line.strip()) for line in f.readlines()]

        random.shuffle(ids)
        return ids

    def __getitem__(self, idx):
        img_id = self._items[idx]
        img_path = self._image_path.format(*img_id)
        label_path = self._anno_path.format(*img_id)

        img = Image.open(img_path).convert('RGB')
        if self.mode == 'test':
            img = img.resize((self.base_size, self.base_size), Image.BILINEAR)
            img = self._img_transform(img)
            if self.transform is not None:
                img = self.transform(img)
            return img, img_id[-1]
        mask = Image.open(label_path)
        # synchronized transform
        if self.mode == 'train':
            img, mask = self._sync_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._val_sync_transform(img, mask)
        else:
            assert self.mode == 'testval'
            if self.base_dir == 'DENTIST':
                img, mask = self._testval_sync_transform(img, mask)
            else:
                img, mask = self._img_transform(img), self._mask_transform(mask)
        # general resize, normalize and toTensor
        if self.transform is not None:
            img = self.transform(img)
        mask = nd.expand_dims(mask, axis=0).astype('float32') / 255.0

        if self.include_name:
            return img, mask, img_id[-1]
        else:
            return img, mask
    def __len__(self):
        return len(self._items)

    @property
    def classes(self):
        """Category names."""
        return ('iceberg',)  # one-element tuple; the trailing comma matters

    def _testval_sync_transform(self, img, mask):
        base_size = self.base_size
        img = img.resize((base_size, base_size), Image.BILINEAR)
        mask = mask.resize((base_size, base_size), Image.NEAREST)
        # final transform
        img, mask = self._img_transform(img), self._mask_transform(mask)
        return img, mask

    def _sync_transform(self, img, mask):
        # random mirror
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        crop_size = self.crop_size
        # random scale (of the long edge)
        long_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            oh = long_size
            ow = int(1.0 * w * long_size / h + 0.5)
            short_size = ow
        else:
            ow = long_size
            oh = int(1.0 * h * long_size / w + 0.5)
            short_size = oh
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad if the scaled image is smaller than the crop
        if short_size < crop_size:
            padh = crop_size - oh if oh < crop_size else 0
            padw = crop_size - ow if ow < crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
        # random crop to crop_size
        w, h = img.size
        x1 = random.randint(0, w - crop_size)
        y1 = random.randint(0, h - crop_size)
        img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        # gaussian blur as in PSP
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
        # final transform
        img, mask = self._img_transform(img), self._mask_transform(mask)
        return img, mask
--------------------------------------------------------------------------------
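A hypothetical sketch of loading the test split directly, mirroring the `Trainer` setup in `visualize_local_contrast_networks.py` below; the dataset root and the presence of `DENTIST/images`, `DENTIST/masks`, and `test.txt` are assumptions about your local layout:

```python
import os
from mxnet.gluon.data.vision import transforms
from data import IceContrast

input_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
])

valset = IceContrast(base_dir='DENTIST', root=os.path.join('~', 'datasets'),
                     split='test', mode='testval', include_name=True,
                     base_size=256, crop_size=240, transform=input_transform)

img, mask, name = valset[0]  # normalized image, float mask in [0, 1], image id
print(len(valset), img.shape, mask.shape, name)
```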
/visualize_local_contrast_networks.py:
--------------------------------------------------------------------------------
import os
import scipy.misc
import platform
import timeit
import sys
import socket
import argparse
import numpy as np
from utils import summary
from tqdm import tqdm

import mxnet as mx
from mxnet import gluon
from mxnet.gluon.data.vision import transforms

from data import IceContrast
from model import MPCMResNetFPN
from loss import SoftIoULoss

import matplotlib.pyplot as plt


def parse_args():
    """Training Options for Segmentation Experiments"""
    parser = argparse.ArgumentParser(description='MXNet Gluon Segmentation')
    # model
    parser.add_argument('--net-choice', type=str, default='MPCMResNetFPN',
                        help='model name: PCMNet, PlainNet, MPCMResNetFPN')
    parser.add_argument('--pyramid-mode', type=str, default='Dec',
                        help='Inc, Dec')
    parser.add_argument('--scale-mode', type=str, default='Multiple',
                        help='Single, Multiple, Selective')
    parser.add_argument('--pyramid-fuse', type=str, default='bottomuplocal',
                        help='add, max, sk, bottomuplocal')
    parser.add_argument('--cue', type=str, default='lcm', help='lcm or orig')
    # dataset
    parser.add_argument('--dataset', type=str, default='DENTIST',
                        help='dataset name: DENTIST or Iceberg (default: DENTIST)')
    parser.add_argument('--workers', type=int, default=48,
                        metavar='N', help='dataloader threads')
    parser.add_argument('--base-size', type=int, default=256,
                        help='base image size')
    parser.add_argument('--blocks', type=int, default=4,
                        help='number of residual blocks per stage')
    parser.add_argument('--channels', type=int, default=16,
                        help='channels')
    parser.add_argument('--shift', type=int, default=13,
                        help='shift')
    parser.add_argument('--iou-thresh', type=float, default=0.5,
                        help='IoU threshold')
    parser.add_argument('--crop-size', type=int, default=240,
                        help='crop image size')
    parser.add_argument('--train-split', type=str, default='trainval',
                        help='dataset train split (default: trainval)')
    parser.add_argument('--val-split', type=str, default='test',
                        help='dataset val split (default: test)')

    # training hyper params
    parser.add_argument('--epochs', type=int, default=300, metavar='N',
                        help='number of epochs to train (default: 300)')
    parser.add_argument('--start_epoch', type=int, default=0,
                        metavar='N', help='start epoch (default: 0)')
    parser.add_argument('--batch-size', type=int, default=1,
                        metavar='N', help='input batch size for training (default: 1)')
    parser.add_argument('--test-batch-size', type=int, default=1,
                        metavar='N', help='input batch size for testing (default: 1)')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--lr-decay', type=float, default=0.1,
                        help='decay rate of learning rate (default: 0.1)')
    parser.add_argument('--lr-decay-epoch', type=str, default='100,200',
                        help='epochs at which learning rate decays (default: 100,200)')
    parser.add_argument('--gamma', type=int, default=2,
                        help='gamma for Focal Soft IoU Loss')
    parser.add_argument('--lambd', type=int, default=1,
                        help='lambd for TV Soft IoU Loss')
    parser.add_argument('--momentum', type=float, default=0.9,
                        metavar='M', help='momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=1e-4,
                        metavar='M', help='weight decay (default: 1e-4)')
    parser.add_argument('--no-wd', action='store_true',
                        help='whether to remove weight decay on bias, '
                             'and beta/gamma for batchnorm layers')
    parser.add_argument('--sparsity', action='store_true', default=False, help='')
    parser.add_argument('--score-thresh', type=float, default=0.5,
                        help='score threshold')
    # cuda and logging
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--gpus', type=str, default='0',
                        help='Training with GPUs, you can specify 1,3 for example.')
    parser.add_argument('--kvstore', type=str, default='device',
                        help='kvstore to use for trainer/module.')
    parser.add_argument('--dtype', type=str, default='float32',
                        help='data type for training (default: float32)')
    parser.add_argument('--wd', type=float, default=0.0001,
                        help='weight decay rate (default: 0.0001)')

    # checkpointing
    parser.add_argument('--resume', type=str, default=None,
                        help='put the path to resuming file if needed')
    parser.add_argument('--colab', action='store_true', default=False,
                        help='whether running on Colab')

    # evaluation only
    parser.add_argument('--eval', action='store_true', default=False,
                        help='evaluation only')
    parser.add_argument('--no-val', action='store_true', default=False,
                        help='skip validation during training')
    parser.add_argument('--metric', type=str, default='mAP',
                        help='F1, IoU, mAP')

    # synchronized Batch Normalization
    parser.add_argument('--syncbn', action='store_true', default=False,
                        help='using Synchronized Cross-GPU BatchNorm')

    # the parser
    args = parser.parse_args()
    # handle contexts
    if args.no_cuda or (len(mx.test_utils.list_gpus()) == 0):
        print('Using CPU')
        args.kvstore = 'local'
        args.ctx = [mx.cpu(0)]
    else:
        args.ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
        print('Number of GPUs:', len(args.ctx))

    # Synchronized BatchNorm
    args.norm_layer = mx.gluon.contrib.nn.SyncBatchNorm if args.syncbn \
        else mx.gluon.nn.BatchNorm
    args.norm_kwargs = {'num_devices': len(args.ctx)} if args.syncbn else {}
    print(args)
    return args

class Trainer(object):
    def __init__(self, args):
        self.args = args
        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),  # ImageNet mean and std
        ])
        ######################## dataset and dataloader ########################
        if platform.system() == "Darwin":
            data_root = os.path.join('~', 'Nutstore Files', 'Dataset')
        elif platform.system() == "Linux":
            data_root = os.path.join('~', 'datasets')
            if args.colab:
                data_root = '/content/datasets'
        else:
            raise ValueError('Notice Dataset Path')

        data_kwargs = {'base_size': args.base_size, 'transform': input_transform,
                       'crop_size': args.crop_size, 'root': data_root,
                       'base_dir': args.dataset}
        valset = IceContrast(split=args.val_split, mode='testval', include_name=True,
                             **data_kwargs)
        self.valset = valset

        net_choice = args.net_choice
        print("net_choice: ", net_choice)

        if net_choice == 'MPCMResNetFPN':
            layers = [self.args.blocks] * 3
            channels = [8, 16, 32, 64]
            shift = self.args.shift
            pyramid_mode = self.args.pyramid_mode
            scale_mode = self.args.scale_mode
            pyramid_fuse = self.args.pyramid_fuse

            model = MPCMResNetFPN(layers=layers, channels=channels, shift=shift,
                                  pyramid_mode=pyramid_mode, scale_mode=scale_mode,
                                  pyramid_fuse=pyramid_fuse, classes=valset.NUM_CLASS)
            print("scale_mode: ", scale_mode)
            print("pyramid_fuse: ", pyramid_fuse)
            print("layers: ", layers)
            print("channels: ", channels)
            print("shift: ", shift)
        else:
            raise ValueError('Unknown net_choice')

        self.host_name = socket.gethostname()
        self.save_prefix = 'MLCPFN' + '_' + args.scale_mode + '_' + args.pyramid_fuse + '_'

        params_path = './params/BottomUpLocal_r_1_b_4_0.7614.params'
        model.load_parameters(params_path, ctx=args.ctx)
        self.net = model

        # leftovers from the training script; unused for visualization
        kv = mx.kv.create(args.kvstore)

        optimizer_params = {
            'wd': args.weight_decay,
            'learning_rate': args.lr
        }
        if args.dtype == 'float16':
            optimizer_params['multi_precision'] = True

        if args.no_wd:
            for k, v in self.net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

    def validation(self, epoch):
        """Run the network over the validation set and save the score maps."""
        save_path = os.path.expanduser('~/Downloads/')

        mx.nd.waitall()
        for img, mask, img_id in self.valset:
            exp_img = img.expand_dims(axis=0)
            pred = self.net(exp_img).squeeze().asnumpy()
            plt.imsave(save_path + img_id + '.png', pred)


if __name__ == "__main__":
    args = parse_args()
    trainer = Trainer(args)
    if args.eval:
        print('Evaluating model: ', args.resume)
        trainer.validation(args.start_epoch)
    else:
        print('Starting Epoch:', args.start_epoch)
        print('Total Epochs:', args.epochs)
        trainer.validation(0)
--------------------------------------------------------------------------------
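`validation` above only writes score maps to disk. A hypothetical helper showing how the samplewise metric (defined next) could be plugged into the same loop; `evaluate` is illustrative and not part of the repository:

```python
from metric import SamplewiseSigmoidMetric

def evaluate(trainer, score_thresh=0.5):
    """Samplewise mIoU over the validation set (sketch)."""
    metric = SamplewiseSigmoidMetric(nclass=1, score_thresh=score_thresh)
    for img, mask, img_id in trainer.valset:
        pred = trainer.net(img.expand_dims(axis=0))           # logits, (1, 1, H, W)
        metric.update(preds=pred, labels=mask.expand_dims(axis=0))
    mIoU, _ = metric.get()
    return mIoU
```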

/metric/samplewise.py:
--------------------------------------------------------------------------------
"""Evaluation Metrics for Semantic Segmentation of Foreground Only"""
import threading
import numpy as np
import mxnet as mx
from mxnet import nd
from mxnet.metric import EvalMetric

__all__ = ['SamplewiseSigmoidMetric', 'batch_intersection_union']


class SamplewiseSigmoidMetric(EvalMetric):
    """Computes per-sample IoU and averages it into mIoU."""

    def __init__(self, nclass, score_thresh=0.5):
        super(SamplewiseSigmoidMetric, self).__init__('pixAcc & mIoU')
        self.nclass = nclass
        self.score_thresh = score_thresh
        self.lock = threading.Lock()
        self.reset()

    def update(self, preds, labels):
        """Updates the internal evaluation result.

        Parameters
        ----------
        labels : 'NDArray' or list of `NDArray`
            The labels of the data.
        preds : 'NDArray' or list of `NDArray`
            Predicted values.
        """
        def evaluate_worker(self, label, pred):
            inter_arr, union_arr = batch_intersection_union(
                pred, label, self.nclass, self.score_thresh)
            with self.lock:
                self.total_inter = np.append(self.total_inter, inter_arr)
                self.total_union = np.append(self.total_union, union_arr)

        if isinstance(preds, mx.nd.NDArray):
            evaluate_worker(self, labels, preds)
        elif isinstance(preds, (list, tuple)):
            threads = [threading.Thread(target=evaluate_worker,
                                        args=(self, label, pred))
                       for (label, pred) in zip(labels, preds)]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

    def get(self):
        """Gets the current evaluation result.

        Returns
        -------
        metrics : tuple of float
            mIoU, returned twice to keep the interface of SigmoidMetric
        """
        IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
        mIoU = IoU.mean()
        return mIoU, mIoU

    def reset(self):
        """Resets the internal evaluation result to initial state."""
        self.total_inter = np.array([])
        self.total_union = np.array([])
        self.total_correct = np.array([])
        self.total_label = np.array([])

def batch_intersection_union(output, target, nclass, score_thresh):
    """Per-sample intersection and union areas."""
    # inputs are NDArrays: output is 4D, target is 3D or 4D
    # category 0 is the ignored class, typically background / boundary
    mini = 1
    maxi = 1   # nclass
    nbins = 1  # nclass

    predict = (nd.sigmoid(output).asnumpy() > score_thresh).astype('int64')  # P
    if len(target.shape) == 3:
        target = nd.expand_dims(target, axis=1).asnumpy().astype('int64')  # T
    elif len(target.shape) == 4:
        target = target.asnumpy().astype('int64')  # T
    else:
        raise ValueError("Unknown target dimension")
    intersection = predict * (predict == target)  # TP

    num_sample = intersection.shape[0]
    area_inter_arr = np.zeros(num_sample)
    area_pred_arr = np.zeros(num_sample)
    area_lab_arr = np.zeros(num_sample)
    area_union_arr = np.zeros(num_sample)
    for b in range(num_sample):
        # areas of intersection and union
        area_inter, _ = np.histogram(intersection[b], bins=nbins, range=(mini, maxi))
        area_inter_arr[b] = area_inter

        area_pred, _ = np.histogram(predict[b], bins=nbins, range=(mini, maxi))
        area_pred_arr[b] = area_pred

        area_lab, _ = np.histogram(target[b], bins=nbins, range=(mini, maxi))
        area_lab_arr[b] = area_lab

        area_union = area_pred + area_lab - area_inter
        area_union_arr[b] = area_union

        assert (area_inter <= area_union).all(), \
            "Intersection area should be smaller than Union area"

    return area_inter_arr, area_union_arr


class ROCMetric(EvalMetric):
    """Accumulates true/false positive rates over a grid of score thresholds."""

    def __init__(self, nclass, bins):
        super(ROCMetric, self).__init__('ROC')
        self.nclass = nclass
        self.lock = threading.Lock()
        self.bins = bins
        self.tp_arr = np.zeros(self.bins + 1)
        self.pos_arr = np.zeros(self.bins + 1)
        self.fp_arr = np.zeros(self.bins + 1)
        self.neg_arr = np.zeros(self.bins + 1)

    def update(self, preds, labels):
        """Updates the internal evaluation result.

        Parameters
        ----------
        labels : 'NDArray' or list of `NDArray`
            The labels of the data.
        preds : 'NDArray' or list of `NDArray`
            Predicted values.
        """
        def evaluate_worker(self, label, pred):
            for iBin in range(self.bins + 1):
                score_thresh = (iBin + 0.0) / self.bins
                i_tp, i_pos, i_fp, i_neg = cal_tp_pos_fp_neg(pred, label, self.nclass,
                                                             score_thresh)
                with self.lock:
                    self.tp_arr[iBin] += i_tp
                    self.pos_arr[iBin] += i_pos
                    self.fp_arr[iBin] += i_fp
                    self.neg_arr[iBin] += i_neg

        if isinstance(preds, mx.nd.NDArray):
            evaluate_worker(self, labels, preds)
        elif isinstance(preds, (list, tuple)):
            threads = [threading.Thread(target=evaluate_worker,
                                        args=(self, label, pred))
                       for (label, pred) in zip(labels, preds)]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

    def get(self):
        """Gets the current evaluation result.

        Returns
        -------
        metrics : tuple of ndarray
            true positive rates and false positive rates per threshold
        """
        tp_rates = self.tp_arr / (self.pos_arr + 0.001)
        fp_rates = self.fp_arr / (self.neg_arr + 0.001)

        return tp_rates, fp_rates
self.reset() 134 | 135 | def update(self, preds, labels): 136 | """Updates the internal evaluation result. 137 | 138 | Parameters 139 | ---------- 140 | labels : 'NDArray' or list of `NDArray` 141 | The labels of the data. 142 | 143 | preds : 'NDArray' or list of `NDArray` 144 | Predicted values. 145 | """ 146 | def evaluate_worker(self, label, pred): 147 | for iBin in range(self.bins+1): 148 | score_thresh = (iBin + 0.0) / self.bins 149 | # print(iBin, "-th, score_thresh: ", score_thresh) 150 | i_tp, i_pos, i_fp, i_neg = cal_tp_pos_fp_neg(pred, label, self.nclass, 151 | score_thresh) 152 | # print("i_tp: ", i_tp) 153 | # print("i_fp: ", i_fp) 154 | with self.lock: 155 | self.tp_arr[iBin] += i_tp 156 | self.pos_arr[iBin] += i_pos 157 | self.fp_arr[iBin] += i_fp 158 | self.neg_arr[iBin] += i_neg 159 | 160 | if isinstance(preds, mx.nd.NDArray): 161 | evaluate_worker(self, labels, preds) 162 | elif isinstance(preds, (list, tuple)): 163 | threads = [threading.Thread(target=evaluate_worker, 164 | args=(self, label, pred), 165 | ) 166 | for (label, pred) in zip(labels, preds)] 167 | for thread in threads: 168 | thread.start() 169 | for thread in threads: 170 | thread.join() 171 | 172 | def get(self): 173 | """Gets the current evaluation result. 174 | 175 | Returns 176 | ------- 177 | metrics : tuple of float 178 | pixAcc and mIoU 179 | """ 180 | # print("self.total_correct: ", self.total_correct) 181 | # print("self.total_label: ", self.total_label) 182 | # print("self.total_union: ", self.total_union) 183 | # pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label) 184 | tp_rates = self.tp_arr / (self.pos_arr + 0.001) 185 | fp_rates = self.fp_arr / (self.neg_arr + 0.001) 186 | 187 | return tp_rates, fp_rates 188 | # return self.tp_arr, self.fp_arr 189 | 190 | # def reset(self): 191 | # """Resets the internal evaluation result to initial state.""" 192 | # self.tp_arr = np.ones(self.bins+1) 193 | # self.pos_arr = np.ones(self.bins+1) 194 | # self.fp_arr = np.ones(self.bins+1) 195 | # self.neg_arr = np.ones(self.bins+1) 196 | 197 | 198 | 199 | 200 | def cal_tp_pos_fp_neg(output, target, nclass, score_thresh): 201 | """mIoU""" 202 | # inputs are NDarray, output 4D, target 3D 203 | # the category 0 is ignored class, typically for background / boundary 204 | mini = 1 205 | maxi = 1 # nclass 206 | nbins = 1 # nclass 207 | 208 | predict = (nd.sigmoid(output).asnumpy() > score_thresh).astype('int64') # P 209 | # predict = (output.asnumpy() > 0).astype('int64') # P 210 | if len(target.shape) == 3: 211 | target = nd.expand_dims(target, axis=1).asnumpy().astype('int64') # T 212 | elif len(target.shape) == 4: 213 | target = target.asnumpy().astype('int64') # T 214 | else: 215 | raise ValueError("Unknown target dimension") 216 | intersection = predict * (predict == target) # TP 217 | tp = intersection.sum() 218 | fp = (predict * (predict != target)).sum() # FP 219 | tn = ((1 - predict) * (predict == target)).sum() # TN 220 | fn = ((predict != target) * (1 - predict)).sum() # FN 221 | pos = tp + fn 222 | neg = fp + tn 223 | 224 | return tp, pos, fp, neg 225 | 226 | 227 | def cal_normalized_tp_pos_fp_neg(output, target, nclass, score_thresh): 228 | """mIoU""" 229 | # inputs are NDarray, output 4D, target 3D 230 | # the category 0 is ignored class, typically for background / boundary 231 | mini = 1 232 | maxi = 1 # nclass 233 | nbins = 1 # nclass 234 | 235 | predict = (nd.sigmoid(output).asnumpy() > score_thresh).astype('int64') # P 236 | # predict = (output.asnumpy() > 0).astype('int64') # P 237 
| if len(target.shape) == 3: 238 | target = nd.expand_dims(target, axis=1).asnumpy().astype('int64') # T 239 | elif len(target.shape) == 4: 240 | target = target.asnumpy().astype('int64') # T 241 | else: 242 | raise ValueError("Unknown target dimension") 243 | intersection = predict * (predict == target) # TP 244 | tp = intersection.sum() 245 | fp = (predict * (predict != target)).sum() # FP 246 | tn = ((1 - predict) * (predict == target)).sum() # TN 247 | fn = ((predict != target) * (1 - predict)).sum() # FN 248 | pos = tp + fn 249 | neg = fp + tn 250 | 251 | return tp, pos, fp, neg 252 | 253 | 254 | 255 | class nROCMetric(EvalMetric): 256 | """Computes pixAcc and mIoU metric scores 257 | """ 258 | def __init__(self, nclass, bins): 259 | super(nROCMetric, self).__init__('ROC') 260 | self.nclass = nclass 261 | self.lock = threading.Lock() 262 | self.bins = bins 263 | self.tp_arr = np.zeros(self.bins+1) 264 | self.pos_arr = np.zeros(self.bins+1) 265 | self.fp_arr = np.zeros(self.bins+1) 266 | self.neg_arr = np.zeros(self.bins+1) 267 | # self.reset() 268 | 269 | def update(self, preds, labels): 270 | """Updates the internal evaluation result. 271 | 272 | Parameters 273 | ---------- 274 | labels : 'NDArray' or list of `NDArray` 275 | The labels of the data. 276 | 277 | preds : 'NDArray' or list of `NDArray` 278 | Predicted values. 279 | """ 280 | def evaluate_worker(self, label, pred): 281 | for iBin in range(self.bins+1): 282 | score_thresh = (iBin + 0.0) / self.bins 283 | # print(iBin, "-th, score_thresh: ", score_thresh) 284 | i_tp, i_pos, i_fp, i_neg = cal_tp_pos_fp_neg(pred, label, self.nclass, 285 | score_thresh) 286 | # print("i_tp: ", i_tp) 287 | # print("i_fp: ", i_fp) 288 | with self.lock: 289 | self.tp_arr[iBin] += i_tp 290 | self.pos_arr[iBin] += i_pos 291 | self.fp_arr[iBin] += i_fp 292 | self.neg_arr[iBin] += i_neg 293 | 294 | if isinstance(preds, mx.nd.NDArray): 295 | evaluate_worker(self, labels, preds) 296 | elif isinstance(preds, (list, tuple)): 297 | threads = [threading.Thread(target=evaluate_worker, 298 | args=(self, label, pred), 299 | ) 300 | for (label, pred) in zip(labels, preds)] 301 | for thread in threads: 302 | thread.start() 303 | for thread in threads: 304 | thread.join() 305 | 306 | def get(self): 307 | """Gets the current evaluation result. 
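
        Entry ``i`` of each returned array corresponds to the score
        threshold ``i / bins``, so together the two arrays trace an ROC
        curve. Illustrative sketch (``roc`` is a hypothetical instance and
        the plotting call assumes matplotlib; neither is part of this
        module):

        >>> tpr, fpr = roc.get()
        >>> # plt.plot(fpr, tpr) would draw the ROC curve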
308 | 309 | Returns 310 | ------- 311 | metrics : tuple of numpy.ndarray 312 | true positive rates and false positive rates, one entry per score threshold 313 | """ 314 | tp_rates = self.tp_arr / (self.pos_arr + 0.001) 315 | fp_rates = self.fp_arr / (self.neg_arr + 0.001) 316 | 317 | return tp_rates, fp_rates 318 | -------------------------------------------------------------------------------- /train_alcnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.system("taskset -p -c 40-47 %d" % os.getpid()) 3 | 4 | import platform 5 | import sys 6 | import socket 7 | import argparse 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | from datetime import datetime 12 | 13 | import mxnet as mx 14 | from mxnet import gluon, autograd, init 15 | from mxnet.gluon.data.vision import transforms 16 | 17 | from gluoncv.utils import LRScheduler 18 | 19 | from data import IceContrast 20 | from model import MPCMResNetFPN 21 | from metric import SigmoidMetric, SamplewiseSigmoidMetric 22 | from loss import SoftIoULoss, SamplewiseSoftIoULoss 23 | 24 | def parse_args(): 25 | """Training Options for Segmentation Experiments""" 26 | parser = argparse.ArgumentParser(description='MXNet Gluon \ 27 | Segmentation') 28 | # model 29 | parser.add_argument('--net-choice', type=str, default='PCMNet', 30 | help='model name: PCMNet, PlainNet, MPCMResNetFPN, ResNetFCN') 31 | parser.add_argument('--pyramid-mode', type=str, default='Dec', 32 | help='Inc, Dec') 33 | parser.add_argument('--r', type=int, default=2, help='1, 2, 4') 34 | parser.add_argument('--summary', action='store_true', 35 | help='print parameters') 36 | parser.add_argument('--scale-mode', type=str, default='xxx', 37 | help='Single, Multiple, Selective') 38 | parser.add_argument('--pyramid-fuse', type=str, default='sk', 39 | help='add, max, sk') 40 | parser.add_argument('--cue', type=str, default='lcm', help='lcm or orig') 41 | # dataset 42 | parser.add_argument('--dataset', type=str, default='DENTIST', 43 | help='dataset name: DENTIST or Iceberg (default: DENTIST)') 44 | parser.add_argument('--workers', type=int, default=48, 45 | metavar='N', help='dataloader threads') 46 | parser.add_argument('--base-size', type=int, default=512, 47 | help='base image size') 48 | parser.add_argument('--blocks', type=int, default=4, 49 | help='number of residual blocks per stage') 50 | parser.add_argument('--channels', type=int, default=16, 51 | help='channels') 52 | parser.add_argument('--shift', type=int, default=13, 53 | help='shift for the PCM contrast measure') 54 | parser.add_argument('--iou-thresh', type=float, default=0.5, 55 | help='IoU threshold') 56 | parser.add_argument('--crop-size', type=int, default=480, 57 | help='crop image size') 58 | parser.add_argument('--train-split', type=str, default='trainval', 59 | help='dataset train split (default: trainval)') 60 | parser.add_argument('--val-split', type=str, default='test', 61 | help='dataset val split (default: test)') 62 | 63 | # training hyper params 64 | parser.add_argument('--epochs', type=int, default=300, metavar='N', 65 | help='number of epochs to train (default: 300)') 66 | parser.add_argument('--start_epoch', type=int, default=0, 67 | metavar='N', help='start epoch (default: 0)') 68 | parser.add_argument('--batch-size', type=int, default=2, 69 | metavar='N', help='input batch size for \ 70 | training (default: 2)') 71 | parser.add_argument('--test-batch-size', type=int, default=32, 72 | metavar='N', help='input batch size for \ 73 | testing (default: 32)') 74 | parser.add_argument('--lr', type=float, default=0.1, metavar='LR', 75 | help='learning rate (default: 0.1)') 76 | 
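    # Illustrative invocation (a sketch; the flag values below are
    # assumptions chosen for demonstration, not the repo's documented
    # settings):
    #   python train_alcnet.py --net-choice MPCMResNetFPN --pyramid-fuse sk \
    #       --blocks 4 --r 2 --epochs 300 --lr 0.1 --batch-size 8 --gpus 0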
parser.add_argument('--lr-decay', type=float, default=0.1, 77 | help='decay rate of learning rate. default is 0.1.') 78 | parser.add_argument('--lr-decay-epoch', type=str, default='100,200', 79 | help='epochs at which learning rate decays. default is 40,60.') 80 | parser.add_argument('--gamma', type=int, default=2, 81 | help='gamma for Focal Soft IoU Loss') 82 | parser.add_argument('--lambd', type=int, default=1, 83 | help='lambd for TV Soft IoU Loss') 84 | parser.add_argument('--momentum', type=float, default=0.9, 85 | metavar='M', help='momentum (default: 0.9)') 86 | parser.add_argument('--weight-decay', type=float, default=1e-4, 87 | metavar='M', help='w-decay (default: 1e-4)') 88 | parser.add_argument('--no-wd', action='store_true', 89 | help='whether to remove weight decay on bias, \ 90 | and beta/gamma for batchnorm layers.') 91 | parser.add_argument('--sparsity', action='store_true', default= 92 | False, help='') 93 | parser.add_argument('--score-thresh', type=float, default=0.5, 94 | help='score-thresh') 95 | # cuda and logging 96 | parser.add_argument('--no-cuda', action='store_true', default= 97 | False, help='disables CUDA training') 98 | # parser.add_argument('--no-cuda', action='store_true', default= 99 | # True, help='disables CUDA training') 100 | # parser.add_argument('--ngpus', type=int, 101 | # default=len(mx.test_utils.list_gpus()), 102 | # help='number of GPUs (default: 4)') 103 | parser.add_argument('--gpus', type=str, default='0', 104 | help='Training with GPUs, you can specify 1,3 for example.') 105 | parser.add_argument('--kvstore', type=str, default='device', 106 | help='kvstore to use for trainer/module.') 107 | parser.add_argument('--dtype', type=str, default='float32', 108 | help='data type for training. default is float32') 109 | parser.add_argument('--wd', type=float, default=0.0001, 110 | help='weight decay rate. 
default is 0.0001.') 111 | 112 | # checking point 113 | parser.add_argument('--resume', type=str, default=None, 114 | help='put the path to resuming file if needed') 115 | parser.add_argument('--colab', action='store_true', default= 116 | False, help='whether using colab') 117 | 118 | # evaluation only 119 | parser.add_argument('--eval', action='store_true', default= False, 120 | help='evaluation only') 121 | parser.add_argument('--no-val', action='store_true', default= False, 122 | help='skip validation during training') 123 | parser.add_argument('--metric', type=str, default='mAP', 124 | help='F1, IoU, mAP') 125 | 126 | # synchronized Batch Normalization 127 | parser.add_argument('--syncbn', action='store_true', default= False, 128 | help='using Synchronized Cross-GPU BatchNorm') 129 | # the parser 130 | args = parser.parse_args() 131 | # handle contexts 132 | if args.no_cuda or (len(mx.test_utils.list_gpus()) == 0): 133 | print('Using CPU') 134 | args.kvstore = 'local' 135 | args.ctx = [mx.cpu(0)] 136 | else: 137 | # print('Number of GPUs:', args.ngpus) 138 | # args.ctx = [mx.gpu(i) for i in range(args.ngpus)] 139 | args.ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()] 140 | print('Number of GPUs:', len(args.ctx)) 141 | 142 | # Synchronized BatchNorm 143 | args.norm_layer = mx.gluon.contrib.nn.SyncBatchNorm if args.syncbn \ 144 | else mx.gluon.nn.BatchNorm 145 | # args.norm_kwargs = {'num_devices': args.ngpus} if args.syncbn else {} 146 | args.norm_kwargs = {'num_devices': len(args.ctx)} if args.syncbn else {} 147 | print(args) 148 | return args 149 | 150 | 151 | class Trainer(object): 152 | def __init__(self, args): 153 | self.args = args 154 | # image transform 155 | input_transform = transforms.Compose([ 156 | transforms.ToTensor(), 157 | transforms.Normalize([.485, .456, .406], [.229, .224, .225]), # Default mean and std 158 | # transforms.Normalize([.418, .447, .571], [.091, .078, .076]), # Iceberg mean and std 159 | ]) 160 | ################################# dataset and dataloader ################################# 161 | if platform.system() == "Darwin": 162 | data_root = os.path.join('~', 'Nutstore Files', 'Dataset') 163 | elif platform.system() == "Linux": 164 | data_root = os.path.join('~', 'datasets') 165 | if args.colab: 166 | # data_root = '/content/gdrive/My Drive/Colab Notebooks/datasets' 167 | data_root = '/content/datasets' 168 | else: 169 | raise ValueError('Notice Dataset Path') 170 | 171 | data_kwargs = {'base_size': args.base_size, 'transform': input_transform, 172 | 'crop_size': args.crop_size, 'root': data_root, 173 | 'base_dir' : args.dataset} 174 | trainset = IceContrast(split=args.train_split, mode='train', **data_kwargs) 175 | valset = IceContrast(split=args.val_split, mode='testval', **data_kwargs) 176 | 177 | self.train_data = gluon.data.DataLoader(trainset, args.batch_size, shuffle=True, 178 | last_batch='rollover', num_workers=args.workers) 179 | self.eval_data = gluon.data.DataLoader(valset, args.test_batch_size, 180 | last_batch='rollover', num_workers=args.workers) 181 | 182 | # net_choice = 'PCMNet' # ResNetFPN, PCMNet, MPCMNet, LayerwiseMPCMNet 183 | net_choice = self.args.net_choice 184 | print("net_choice: ", net_choice) 185 | 186 | if net_choice == 'MPCMResNetFPN': 187 | r = self.args.r 188 | layers = [self.args.blocks] * 3 189 | channels = [8, 16, 32, 64] 190 | shift = self.args.shift 191 | pyramid_mode = self.args.pyramid_mode 192 | scale_mode = self.args.scale_mode 193 | pyramid_fuse = self.args.pyramid_fuse 194 | 195 | model = 
MPCMResNetFPN(layers=layers, channels=channels, shift=shift, 196 | pyramid_mode=pyramid_mode, scale_mode=scale_mode, 197 | pyramid_fuse=pyramid_fuse, r=r, classes=trainset.NUM_CLASS) 198 | print("net_choice: ", net_choice) 199 | print("scale_mode: ", scale_mode) 200 | print("pyramid_fuse: ", pyramid_fuse) 201 | print("r: ", r) 202 | print("layers: ", layers) 203 | print("channels: ", channels) 204 | print("shift: ", shift) 205 | 206 | 207 | self.host_name = socket.gethostname() 208 | self.save_prefix = self.host_name + '_' + net_choice + '_scale-mode_' + args.scale_mode + \ 209 | '_pyramid-fuse_' + args.pyramid_fuse + '_b_' + str(args.blocks) 210 | if args.net_choice == 'ResNetFCN': 211 | self.save_prefix = self.host_name + '_' + net_choice + '_b_' + str(args.blocks) 212 | 213 | # resume checkpoint if needed 214 | if args.resume is not None: 215 | if os.path.isfile(args.resume): 216 | model.load_parameters(args.resume, ctx=args.ctx) 217 | else: 218 | raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume)) 219 | else: 220 | # model.initialize(init=init.Xavier(), ctx=args.ctx, force_reinit=True) 221 | model.initialize(init=init.MSRAPrelu(), ctx=args.ctx, force_reinit=True) 222 | print("Model Initializing") 223 | print("args.ctx: ", args.ctx) 224 | 225 | 226 | self.net = model 227 | # self.net.summary(mx.nd.zeros((1, 3, 480, 480))) 228 | 229 | if args.summary: 230 | self.net.summary(mx.nd.zeros((1, 3, 480, 480), self.args.ctx[0])) 231 | sys.exit() 232 | 233 | # create criterion 234 | self.criterion = SoftIoULoss() 235 | 236 | # optimizer and lr scheduling 237 | self.lr_scheduler = LRScheduler(mode='poly', base_lr=args.lr, 238 | nepochs=args.epochs, 239 | iters_per_epoch=len(self.train_data), 240 | power=0.9) 241 | kv = mx.kv.create(args.kvstore) 242 | 243 | # For SGD 244 | # optimizer_params = {'lr_scheduler': self.lr_scheduler, 245 | # 'wd': args.weight_decay, 246 | # 'momentum': args.momentum, 247 | # 'learning_rate': args.lr 248 | # } 249 | optimizer_params = { 250 | # 'lr_scheduler': self.lr_scheduler, 251 | 'wd': args.weight_decay, 252 | 'learning_rate': args.lr 253 | } 254 | # For Adam 255 | 256 | if args.dtype == 'float16': 257 | optimizer_params['multi_precision'] = True 258 | 259 | if args.no_wd: 260 | for k, v in self.net.collect_params('.*beta|.*gamma|.*bias').items(): 261 | v.wd_mult = 0.0 262 | 263 | # self.optimizer = gluon.Trainer(self.net.collect_params(), 'sgd', 264 | # optimizer_params, kvstore = kv) 265 | # self.optimizer = gluon.Trainer(self.net.collect_params(), 'adam', 266 | # optimizer_params, kvstore = kv) 267 | self.optimizer = gluon.Trainer(self.net.collect_params(), 'adagrad', 268 | optimizer_params, kvstore=kv) 269 | # self.optimizer = gluon.Trainer(self.net.collect_params(), 'nag', 270 | # optimizer_params, kvstore=kv) 271 | 272 | ################################# evaluation metrics ################################# 273 | 274 | self.iou_metric = SigmoidMetric(1) 275 | self.nIoU_metric = SamplewiseSigmoidMetric(1, score_thresh=self.args.score_thresh) 276 | # self.metric = Seg2DetVOC07MApMetric(iou_thresh=self.args.iou_thresh, 277 | # sparsity=self.args.sparsity, 278 | # score_thresh=self.args.score_thresh) 279 | self.best_metric = 0 280 | self.best_iou = 0 281 | self.best_nIoU = 0 282 | self.is_best = False 283 | 284 | def training(self, epoch): 285 | tbar = tqdm(self.train_data) 286 | train_loss = 0.0 287 | for i, batch in enumerate(tbar): 288 | data = gluon.utils.split_and_load(batch[0], ctx_list=self.args.ctx, batch_axis=0) 289 | labels = 
gluon.utils.split_and_load(batch[1], ctx_list=self.args.ctx, batch_axis=0) 290 | losses = [] 291 | with autograd.record(True): 292 | for x, y in zip(data, labels): 293 | pred = self.net(x) 294 | loss = self.criterion(pred, y) 295 | losses.append(loss) 296 | mx.nd.waitall() 297 | autograd.backward(losses) 298 | self.optimizer.step(self.args.batch_size) 299 | for loss in losses: 300 | train_loss += np.mean(loss.asnumpy()) / len(losses) 301 | tbar.set_description('Epoch %d, training loss %.4f' % (epoch, train_loss/(i+1))) 302 | 303 | def validation(self, epoch): 304 | self.iou_metric.reset() 305 | self.nIoU_metric.reset() 306 | # self.metric.reset() 307 | tbar = tqdm(self.eval_data) 308 | for i, batch in enumerate(tbar): 309 | data = gluon.utils.split_and_load(batch[0], ctx_list=self.args.ctx, batch_axis=0) 310 | labels = gluon.utils.split_and_load(batch[1], ctx_list=self.args.ctx, batch_axis=0) 311 | preds = [] 312 | for x, y in zip(data, labels): 313 | pred = self.net(x) 314 | preds.append(pred) 315 | # self.metric.update(preds, labels) 316 | self.iou_metric.update(preds, labels) 317 | self.nIoU_metric.update(preds, labels) 318 | 319 | _, IoU = self.iou_metric.get() 320 | _, nIoU = self.nIoU_metric.get() 321 | tbar.set_description('Epoch %d, IoU: %.4f, nIoU: %.4f' % (epoch, IoU, nIoU)) 322 | 323 | if IoU > self.best_iou: 324 | self.best_iou = IoU 325 | self.net.save_parameters('tmp_{:s}_best_{:s}.params'.format( 326 | self.save_prefix, 'IoU')) 327 | with open(self.save_prefix + '_GPU_' + self.args.gpus + 328 | '_best_IoU.log', 'a') as f: 329 | now = datetime.now() 330 | dt_string = now.strftime("%d/%m/%Y %H:%M:%S") 331 | f.write('{} - {:04d}:\t{:.4f}\n'.format(dt_string, epoch, IoU)) 332 | 333 | if nIoU > self.best_nIoU: 334 | self.best_nIoU = nIoU 335 | self.net.save_parameters('tmp_{:s}_best_{:s}.params'.format( 336 | self.save_prefix, 'nIoU')) 337 | with open(self.save_prefix + '_GPU_' + self.args.gpus + 338 | '_best_nIoU.log', 'a') as f: 339 | now = datetime.now() 340 | dt_string = now.strftime("%d/%m/%Y %H:%M:%S") 341 | f.write('{} - {:04d}:\t{:.4f}\n'.format(dt_string, epoch, nIoU)) 342 | 343 | if epoch >= self.args.epochs - 1: 344 | print("best_iou: ", self.best_iou) 345 | print("best_nIoU: ", self.best_nIoU) 346 | 347 | if epoch >= self.args.epochs - 1: 348 | with open(self.save_prefix + '_GPU_' + self.args.gpus + 349 | '_best_IoU.log', 'a') as f: 350 | f.write('Finished\n') 351 | self.net.save_parameters('tmp_{:s}_best_{:s}_{:s}.params'.format( 352 | self.save_prefix, 'IoU', str(self.best_iou))) 353 | self.net.save_parameters('tmp_{:s}_best_{:s}_{:s}.params'.format( 354 | self.save_prefix, 'nIoU', str(self.best_nIoU))) 355 | 356 | 357 | if __name__ == "__main__": 358 | args = parse_args() 359 | trainer = Trainer(args) 360 | if args.eval: 361 | print('Evaluating model: ', args.resume) 362 | trainer.validation(args.start_epoch) 363 | else: 364 | print('Starting Epoch:', args.start_epoch) 365 | print('Total Epochs:', args.epochs) 366 | for epoch in range(args.start_epoch, args.epochs): 367 | trainer.training(epoch) 368 | if not trainer.args.no_val: 369 | trainer.validation(epoch) 370 | -------------------------------------------------------------------------------- /model/segmentation.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | from mxnet.gluon.block import HybridBlock 4 | from mxnet.gluon import nn 5 | from mxnet.gluon.nn import BatchNorm 6 | from gluoncv.model_zoo.fcn import _FCNHead 7 | from 
mxnet import nd 8 | 9 | # from gluoncv.model_zoo.resnetv1b import BasicBlockV1b 10 | from gluoncv.model_zoo.cifarresnet import CIFARBasicBlockV1 11 | 12 | 13 | class ASKCResNetFPN(HybridBlock): 14 | def __init__(self, layers, channels, fuse_mode, act_dilation, classes=1, tinyFlag=False, 15 | norm_layer=BatchNorm, norm_kwargs=None, **kwargs): 16 | super(ASKCResNetFPN, self).__init__(**kwargs) 17 | 18 | self.layer_num = len(layers) 19 | self.tinyFlag = tinyFlag 20 | with self.name_scope(): 21 | 22 | stem_width = int(channels[0]) 23 | self.stem = nn.HybridSequential(prefix='stem') 24 | self.stem.add(norm_layer(scale=False, center=False, 25 | **({} if norm_kwargs is None else norm_kwargs))) 26 | if tinyFlag: 27 | self.stem.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=1, 28 | padding=1, use_bias=False)) 29 | self.stem.add(norm_layer(in_channels=stem_width*2)) 30 | self.stem.add(nn.Activation('relu')) 31 | else: 32 | self.stem.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=2, 33 | padding=1, use_bias=False)) 34 | self.stem.add(norm_layer(in_channels=stem_width)) 35 | self.stem.add(nn.Activation('relu')) 36 | self.stem.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=1, 37 | padding=1, use_bias=False)) 38 | self.stem.add(norm_layer(in_channels=stem_width)) 39 | self.stem.add(nn.Activation('relu')) 40 | self.stem.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=1, 41 | padding=1, use_bias=False)) 42 | self.stem.add(norm_layer(in_channels=stem_width*2)) 43 | self.stem.add(nn.Activation('relu')) 44 | self.stem.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1)) 45 | 46 | # self.head1 = _FCNHead(in_channels=channels[1], channels=classes) 47 | # self.head2 = _FCNHead(in_channels=channels[2], channels=classes) 48 | # self.head3 = _FCNHead(in_channels=channels[3], channels=classes) 49 | # self.head4 = _FCNHead(in_channels=channels[4], channels=classes) 50 | 51 | self.head = _FCNHead(in_channels=channels[1], channels=classes) 52 | 53 | self.layer1 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[0], 54 | channels=channels[1], stride=1, stage_index=1, 55 | in_channels=channels[1]) 56 | 57 | self.layer2 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[1], 58 | channels=channels[2], stride=2, stage_index=2, 59 | in_channels=channels[1]) 60 | 61 | self.layer3 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[2], 62 | channels=channels[3], stride=2, stage_index=3, 63 | in_channels=channels[2]) 64 | 65 | if self.layer_num == 4: 66 | self.layer4 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[3], 67 | channels=channels[4], stride=2, stage_index=4, 68 | in_channels=channels[3]) 69 | 70 | if self.layer_num == 4: 71 | self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[3], 72 | act_dilation=act_dilation) # channels[4] 73 | 74 | self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[2], 75 | act_dilation=act_dilation) # 64 76 | self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[1], 77 | act_dilation=act_dilation) # 32 78 | 79 | # if fuse_order == 'reverse': 80 | # self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[2]) # channels[2] 81 | # self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[3]) # channels[3] 82 | # self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[4]) # channels[4] 83 | # elif fuse_order == 'normal': 84 | # self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[4]) # channels[4] 85 | # self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[4]) # channels[4] 86 
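            # Sketch of the active top-down path (see hybrid_forward below):
            # each fuse layer combines an upsampled deeper feature map with
            # the next shallower one, schematically
            #   out = fuse34(upsample(c4), c3)   # only when layer_num == 4
            #   out = fuse23(upsample(out), c2)
            #   out = fuse12(upsample(out), c1)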
| # self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[4]) # channels[4] 87 | 88 | def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0, 89 | norm_layer=BatchNorm, norm_kwargs=None): 90 | layer = nn.HybridSequential(prefix='stage%d_'%stage_index) 91 | with layer.name_scope(): 92 | downsample = (channels != in_channels) or (stride != 1) 93 | layer.add(block(channels, stride, downsample, in_channels=in_channels, 94 | prefix='', norm_layer=norm_layer, norm_kwargs=norm_kwargs)) 95 | for _ in range(layers-1): 96 | layer.add(block(channels, 1, False, in_channels=channels, prefix='', 97 | norm_layer=norm_layer, norm_kwargs=norm_kwargs)) 98 | return layer 99 | 100 | def _fuse_layer(self, fuse_mode, channels, act_dilation): 101 | if fuse_mode == 'Direct_Add': 102 | fuse_layer = Direct_AddFuse_Reduce(channels=channels) 103 | elif fuse_mode == 'Concat': 104 | fuse_layer = ConcatFuse_Reduce(channels=channels) 105 | elif fuse_mode == 'SK': 106 | fuse_layer = SKFuse_Reduce(channels=channels) 107 | # elif fuse_mode == 'LocalCha': 108 | # fuse_layer = LocalChaFuse(channels=channels) 109 | # elif fuse_mode == 'GlobalCha': 110 | # fuse_layer = GlobalChaFuse(channels=channels) 111 | elif fuse_mode == 'LocalGlobalCha': 112 | fuse_layer = LocalGlobalChaFuse_Reduce(channels=channels) 113 | elif fuse_mode == 'LocalLocalCha': 114 | fuse_layer = LocalLocalChaFuse_Reduce(channels=channels) 115 | elif fuse_mode == 'GlobalGlobalCha': 116 | fuse_layer = GlobalGlobalChaFuse_Reduce(channels=channels) 117 | elif fuse_mode == 'IASKCChaFuse': 118 | fuse_layer = IASKCChaFuse_Reduce(channels=channels) 119 | elif fuse_mode == 'AYforXplusY': 120 | fuse_layer = AYforXplusYChaFuse_Reduce(channels=channels) 121 | elif fuse_mode == 'AXYforXplusY': 122 | fuse_layer = AXYforXplusYChaFuse_Reduce(channels=channels) 123 | elif fuse_mode == 'XplusAYforY': 124 | fuse_layer = XplusAYforYChaFuse_Reduce(channels=channels) 125 | elif fuse_mode == 'GAU': 126 | fuse_layer = GAUChaFuse_Reduce(channels=channels) 127 | elif fuse_mode == 'LocalGAU': 128 | fuse_layer = LocalGAUChaFuse_Reduce(channels=channels) 129 | elif fuse_mode == 'SpaFuse': 130 | fuse_layer = SpaFuse_Reduce(channels=channels, act_dialtion=act_dilation) 131 | elif fuse_mode == 'BiLocalCha': 132 | fuse_layer = BiLocalChaFuse_Reduce(channels=channels) 133 | elif fuse_mode == 'BiGlobalLocalCha': 134 | fuse_layer = BiGlobalLocalChaFuse_Reduce(channels=channels) 135 | elif fuse_mode == 'AsymBiLocalCha': 136 | fuse_layer = AsymBiLocalChaFuse_Reduce(channels=channels) 137 | elif fuse_mode == 'BiGlobalCha': 138 | fuse_layer = BiGlobalChaFuse_Reduce(channels=channels) 139 | elif fuse_mode == 'BiSpaCha': 140 | fuse_layer = BiSpaChaFuse_Reduce(channels=channels) 141 | elif fuse_mode == 'AsymBiSpaCha': 142 | fuse_layer = AsymBiSpaChaFuse_Reduce(channels=channels) 143 | # elif fuse_mode == 'LocalSpa': 144 | # fuse_layer = LocalSpaFuse(channels=channels, act_dilation=act_dilation) 145 | # elif fuse_mode == 'GlobalSpa': 146 | # fuse_layer = GlobalSpaFuse(channels=channels, act_dilation=act_dilation) 147 | # elif fuse_mode == 'SK_MSSpa': 148 | # # fuse_layer.add(SK_MSSpaFuse(channels=channels, act_dilation=act_dilation)) 149 | # fuse_layer = SK_MSSpaFuse(channels=channels, act_dilation=act_dilation) 150 | else: 151 | raise ValueError('Unknown fuse_mode') 152 | 153 | return fuse_layer 154 | 155 | def hybrid_forward(self, F, x): 156 | 157 | _, _, hei, wid = x.shape 158 | 159 | x = self.stem(x) # down 4, 32 160 | c1 = self.layer1(x) # down 4, 32 161 | c2 = 
self.layer2(c1) # down 8, 64 162 | out = self.layer3(c2) # down 16, 128 163 | if self.layer_num == 4: 164 | c4 = self.layer4(out) # down 32 165 | if self.tinyFlag: 166 | c4 = F.contrib.BilinearResize2D(c4, height=hei//4, width=wid//4) # down 4 167 | else: 168 | c4 = F.contrib.BilinearResize2D(c4, height=hei//16, width=wid//16) # down 16 169 | out = self.fuse34(c4, out) 170 | if self.tinyFlag: 171 | out = F.contrib.BilinearResize2D(out, height=hei//2, width=wid//2) # down 2, 128 172 | else: 173 | out = F.contrib.BilinearResize2D(out, height=hei//8, width=wid//8) # down 8, 128 174 | out = self.fuse23(out, c2) 175 | if self.tinyFlag: 176 | out = F.contrib.BilinearResize2D(out, height=hei, width=wid) # down 1 177 | else: 178 | out = F.contrib.BilinearResize2D(out, height=hei//4, width=wid//4) # down 8 179 | out = self.fuse12(out, c1) 180 | 181 | pred = self.head(out) 182 | if self.tinyFlag: 183 | out = pred 184 | else: 185 | out = F.contrib.BilinearResize2D(pred, height=hei, width=wid) # down 4 186 | 187 | ######### reverse order ########## 188 | # up_c2 = F.contrib.BilinearResize2D(c2, height=hei//4, width=wid//4) # down 4 189 | # fuse2 = self.fuse12(up_c2, c1) # down 4, channels[2] 190 | # 191 | # up_c3 = F.contrib.BilinearResize2D(c3, height=hei//4, width=wid//4) # down 4 192 | # fuse3 = self.fuse23(up_c3, fuse2) # down 4, channels[3] 193 | # 194 | # up_c4 = F.contrib.BilinearResize2D(c4, height=hei//4, width=wid//4) # down 4 195 | # fuse4 = self.fuse34(up_c4, fuse3) # down 4, channels[4] 196 | # 197 | 198 | ######### normal order ########## 199 | # out = F.contrib.BilinearResize2D(c4, height=hei//16, width=wid//16) 200 | # out = self.fuse34(out, c3) 201 | # out = F.contrib.BilinearResize2D(out, height=hei//8, width=wid//8) 202 | # out = self.fuse23(out, c2) 203 | # out = F.contrib.BilinearResize2D(out, height=hei//4, width=wid//4) 204 | # out = self.fuse12(out, c1) 205 | # out = self.head(out) 206 | # out = F.contrib.BilinearResize2D(out, height=hei, width=wid) 207 | 208 | 209 | return out 210 | 211 | def evaluate(self, x): 212 | """evaluating network with inputs and targets""" 213 | return self.forward(x) 214 | 215 | 216 | class BasicContextNet(HybridBlock): 217 | def __init__(self, dilations=[1, 1, 2, 4, 8, 16], channels=16, classes=1, 218 | conv_mode='xxx', act_type='relu', skernel=3, act_dilation=16, 219 | useReLU=False, use_act_head=False, check_fullly=False, act_layers=4, 220 | act_order='xxx', asBackbone=False, addstem=False, maxpool=True, **kwargs): 221 | super(BasicContextNet, self).__init__(**kwargs) 222 | assert act_type in ['swish', 'prelu', 'relu', 'xUnit', 'SeqATAC', 'SpaATAC', 'ChaATAC', 223 | 'MSSeqATAC', 'MSSeqATACAdd', 'MSSeqATACConcat'], "Unknown act_type" 224 | assert conv_mode in ['fixed', 'learned', 'ChaDyReF', 'SeqDyReF', 'SK_ChaDyReF', 225 | 'SK_1x1DepthDyReF', 'SK_MSSpaDyReF', 'SK_SpaDyReF', 226 | 'Direct_Add', 'SKCell', 'SK_SeqDyReF', 'Sub_MSSpaDyReF', 227 | 'SK_MSSeqDyReF', 'iAAMSSpaDyReF'], \ 228 | "Unknown conv_mode" 229 | # stem_width = int(channels // 2) 230 | with self.name_scope(): 231 | self.features = nn.HybridSequential(prefix='') 232 | if addstem: 233 | self.features.add(nn.Conv2D(channels=channels, kernel_size=3, strides=2, 234 | padding=1, use_bias=False)) 235 | self.features.add(nn.BatchNorm(in_channels=channels)) 236 | self.features.add(nn.Activation('relu')) 237 | self.features.add(nn.Conv2D(channels=channels, kernel_size=3, strides=1, 238 | padding=1, use_bias=False)) 239 | self.features.add(nn.BatchNorm(in_channels=channels)) 240 | 
self.features.add(nn.Activation('relu')) 241 | self.features.add(nn.Conv2D(channels=channels*2, kernel_size=3, strides=1, 242 | padding=1, use_bias=False)) 243 | self.features.add(nn.BatchNorm(in_channels=channels*2)) 244 | self.features.add(nn.Activation('relu')) 245 | if maxpool: 246 | self.features.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1)) 247 | 248 | for i, dilation in enumerate(dilations): 249 | self.features.add(self._make_layer( 250 | dilation=dilation, channels=channels, stage_index=i, conv_mode=conv_mode, 251 | act_type=act_type, skernel=skernel, act_dilation=act_dilation, 252 | useReLU=useReLU, check_fullly=check_fullly, act_layers=act_layers, 253 | act_order=act_order, asBackbone=asBackbone)) 254 | if use_act_head: 255 | self.head = ATAC_FCNHead(head_act=act_type, useReLU=useReLU, 256 | in_channels=channels, channels=classes) 257 | else: 258 | self.head = _FCNHead(in_channels=channels, channels=classes) 259 | 260 | def _make_layer(self, dilation, channels, stage_index, conv_mode, act_type, skernel, 261 | act_dilation, useReLU, check_fullly, act_layers, act_order, asBackbone): 262 | layer = nn.HybridSequential(prefix='stage%d_' % stage_index) 263 | with layer.name_scope(): 264 | 265 | if check_fullly: 266 | if act_order == 'bac': 267 | # 后面的层优先用 Attention 268 | if stage_index + act_layers < 5: 269 | act_type = 'relu' 270 | elif act_order == 'pre': 271 | # 前面的层优先用 Attention 272 | if act_layers - stage_index - 1 < 0: 273 | act_type = 'relu' 274 | else: 275 | raise ValueError('Unknown act_order') 276 | 277 | if conv_mode == 'fixed': 278 | 279 | layer.add(nn.Conv2D(channels=channels, kernel_size=3, dilation=dilation, 280 | padding=dilation)) 281 | layer.add(nn.BatchNorm()) 282 | 283 | if act_type == 'prelu': 284 | layer.add(nn.PReLU()) 285 | elif act_type == 'relu': 286 | layer.add(nn.Activation('relu')) 287 | elif act_type == 'swish': 288 | layer.add(nn.Swish()) 289 | elif act_type == 'xUnit': 290 | layer.add(xUnit(channels=channels, skernel_size=5)) 291 | elif act_type == 'SpaATAC': 292 | layer.add(SpaATAC(skernel=skernel, channels=channels, dilation=act_dilation, 293 | useReLU=useReLU, asBackbone=asBackbone)) 294 | elif act_type == 'ChaATAC': 295 | layer.add(ChaATAC(channels=channels, useReLU=useReLU, useGlobal=False, 296 | asBackbone=asBackbone)) 297 | elif act_type == 'SeqATAC': 298 | layer.add(SeqATAC(skernel=skernel, channels=channels, dilation=act_dilation, 299 | useReLU=useReLU, asBackbone=asBackbone)) 300 | # layer.add(DilatedSeqATACBackbone(channels=channels, dilation=act_dilation)) 301 | elif act_type == 'MSSeqATAC': 302 | layer.add(MSSeqATAC(skernel=skernel, channels=channels, dilation=act_dilation, 303 | useReLU=useReLU, asBackbone=asBackbone)) 304 | # layer.add(DilatedSeqATACBackbone(channels=channels, dilation=act_dilation)) 305 | elif act_type == 'MSSeqATACAdd': 306 | layer.add(MSSeqATACAdd(skernel=skernel, channels=channels, 307 | dilation=act_dilation, useReLU=useReLU, 308 | asBackbone=asBackbone)) 309 | elif act_type == 'MSSeqATACConcat': 310 | layer.add(MSSeqATACConcat(skernel=skernel, channels=channels, 311 | dilation=act_dilation, useReLU=useReLU, 312 | asBackbone=asBackbone)) 313 | else: 314 | raise ValueError('Unknown act_type') 315 | 316 | elif conv_mode == 'learned': 317 | layer.add(LearnedCell(channels=channels, dilations=dilation)) 318 | elif conv_mode == 'ChaDyReF': 319 | layer.add(ChaDyReFCell(channels=channels, dilations=dilation)) 320 | elif conv_mode == 'SK_ChaDyReF': 321 | layer.add(SK_ChaDyReFCell(channels=channels, dilations=dilation)) 
322 | elif conv_mode == 'SK_1x1DepthDyReF': 323 | layer.add(SK_1x1DepthDyReFCell(channels=channels, dilations=dilation)) 324 | elif conv_mode == 'SK_MSSpaDyReF': 325 | layer.add(SK_MSSpaDyReFCell(channels=channels, dilations=dilation, 326 | act_dilation=act_dilation, 327 | asBackbone=asBackbone)) 328 | elif conv_mode == 'iAAMSSpaDyReF': 329 | layer.add(iAAMSSpaDyReFCell(channels=channels, dilations=dilation, 330 | asBackbone=asBackbone)) 331 | elif conv_mode == 'SK_MSSeqDyReF': 332 | layer.add(SK_MSSeqDyReFCell(channels=channels, dilations=dilation, 333 | asBackbone=asBackbone)) 334 | elif conv_mode == 'Sub_MSSpaDyReF': 335 | layer.add(Sub_MSSpaDyReFCell(channels=channels, dilations=dilation, 336 | asBackbone=asBackbone)) 337 | 338 | elif conv_mode == 'Direct_Add': 339 | layer.add(Direct_AddCell(channels=channels, dilations=dilation, 340 | asBackbone=asBackbone)) 341 | elif conv_mode == 'SK_SpaDyReF': 342 | layer.add(SK_SpaDyReFCell(channels=channels, dilations=dilation, 343 | act_dilation=act_dilation)) 344 | elif conv_mode == 'SKCell': 345 | layer.add(SKCell(channels=channels, dilations=dilation)) 346 | elif conv_mode == 'SeqDyReF': 347 | layer.add(SeqDyReFCell(channels=channels, dilations=dilation, 348 | act_dilation=act_dilation, useReLU=useReLU, 349 | asBackbone=asBackbone)) 350 | elif conv_mode == 'SK_SeqDyReF': 351 | layer.add(SK_SeqDyReFCell(channels=channels, dilations=dilation, 352 | act_dilation=act_dilation, useReLU=useReLU, 353 | asBackbone=asBackbone)) 354 | elif conv_mode == 'dynamic': 355 | layer.add(DynamicCell(channels=channels, dilations=dilation)) 356 | else: 357 | raise ValueError('Unknown conv_mode') 358 | return layer 359 | 360 | def hybrid_forward(self, F, x): 361 | 362 | _, _, hei, wid = x.shape 363 | x = self.features(x) 364 | x = self.head(x) 365 | 366 | out = F.contrib.BilinearResize2D(x, height=hei, width=wid) 367 | 368 | return out 369 | 370 | def evaluate(self, x): 371 | """evaluating network with inputs and targets""" 372 | return self.forward(x) 373 | 374 | 375 | class ATAC_FCNHead(HybridBlock): 376 | # pylint: disable=redefined-outer-name 377 | def __init__(self, head_act, useReLU, in_channels, channels, norm_layer=nn.BatchNorm, 378 | norm_kwargs=None, **kwargs): 379 | super(ATAC_FCNHead, self).__init__() 380 | with self.name_scope(): 381 | self.block = nn.HybridSequential() 382 | inter_channels = in_channels // 4 383 | with self.block.name_scope(): 384 | self.block.add(nn.Conv2D(in_channels=in_channels, channels=inter_channels, 385 | kernel_size=3, padding=1, use_bias=False)) 386 | self.block.add(norm_layer(in_channels=inter_channels, 387 | **({} if norm_kwargs is None else norm_kwargs))) 388 | # self.block.add(nn.Activation('relu')) 389 | 390 | if head_act == 'prelu': 391 | self.block.add(nn.PReLU()) 392 | elif head_act == 'relu': 393 | self.block.add(nn.Activation('relu')) 394 | elif head_act == 'xUnit': 395 | self.block.add(xUnit(channels=inter_channels)) 396 | elif head_act == 'SpaATAC': 397 | self.block.add(SpaATAC(skernel=3, channels=inter_channels, dilation=1, 398 | useReLU=useReLU)) 399 | elif head_act == 'ChaATAC': 400 | self.block.add(ChaATAC(channels=inter_channels, useReLU=useReLU, 401 | useGlobal=False)) 402 | elif head_act == 'SeqATAC': 403 | self.block.add(SeqATAC(skernel=3, channels=inter_channels, dilation=1, 404 | useReLU=useReLU)) 405 | # layer.add(DilatedSeqATACBackbone(channels=channels, dilation=act_dilation)) 406 | else: 407 | raise ValueError('Unknown act_type') 408 | 409 | self.block.add(nn.Dropout(0.1)) 410 | 
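            # Standard FCN-head pattern: a 3x3 conv reduces in_channels to
            # in_channels // 4, followed by normalization, a configurable
            # activation, dropout, and the 1x1 classifier conv below
            # (e.g. in_channels=16, classes=1 maps 16 -> 4 -> 1 channels).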
self.block.add(nn.Conv2D(in_channels=inter_channels, channels=channels, 411 | kernel_size=1)) 412 | 413 | # pylint: disable=arguments-differ 414 | def hybrid_forward(self, F, x): 415 | return self.block(x) 416 | 417 | 418 | class DyRefNet(HybridBlock): 419 | def __init__(self, dilations=[1, 1, 2, 4, 8, 16], channels=16, classes=1, 420 | act_type='relu', skernel=3, act_dilation=16, useReLU=False, 421 | use_act_head=False, check_fullly=False, act_layers=4, act_order='xxx', 422 | asBackbone=False, **kwargs): 423 | super(DyRefNet, self).__init__(**kwargs) 424 | assert act_type in ['prelu', 'relu', 'xUnit', 'SeqATAC', 'SpaATAC', 'ChaATAC', 'MSSeqATAC'], \ 425 | "Unknown act_type" 426 | with self.name_scope(): 427 | self.features = nn.HybridSequential(prefix='') 428 | for i, dilation in enumerate(dilations): 429 | self.features.add(self._make_layer( 430 | dilation=dilation, channels=channels, stage_index=i, act_type=act_type, 431 | skernel=skernel, act_dilation=act_dilation, useReLU=useReLU, 432 | check_fullly=check_fullly, act_layers=act_layers, act_order=act_order, 433 | asBackbone=asBackbone)) 434 | if use_act_head: 435 | self.head = ATAC_FCNHead(head_act=act_type, useReLU=useReLU, 436 | in_channels=channels, channels=classes) 437 | else: 438 | self.head = _FCNHead(in_channels=channels, channels=classes) 439 | 440 | def _make_layer(self, dilation, channels, stage_index, act_type, skernel, 441 | act_dilation, useReLU, check_fullly, act_layers, act_order, asBackbone): 442 | layer = nn.HybridSequential(prefix='stage%d_' % stage_index) 443 | with layer.name_scope(): 444 | layer.add(nn.Conv2D(channels=channels, kernel_size=3, dilation=dilation, 445 | padding=dilation)) 446 | layer.add(nn.BatchNorm()) 447 | 448 | if check_fullly: 449 | if act_order == 'bac': 450 | # 后面的层优先用 Attention 451 | if stage_index + act_layers < 4: 452 | act_type = 'relu' 453 | elif act_order == 'pre': 454 | # 前面的层优先用 Attention 455 | if act_layers - stage_index - 1 < 0: 456 | act_type = 'relu' 457 | else: 458 | raise ValueError('Unknown act_order') 459 | 460 | if act_type == 'prelu': 461 | layer.add(nn.PReLU()) 462 | elif act_type == 'relu': 463 | layer.add(nn.Activation('relu')) 464 | elif act_type == 'xUnit': 465 | layer.add(xUnit(channels=channels, skernel_size=5)) 466 | elif act_type == 'SpaATAC': 467 | layer.add(SpaATAC(skernel=skernel, channels=channels, dilation=act_dilation, 468 | useReLU=useReLU, asBackbone=asBackbone)) 469 | elif act_type == 'ChaATAC': 470 | layer.add(ChaATAC(channels=channels, useReLU=useReLU, useGlobal=False, 471 | asBackbone=asBackbone)) 472 | elif act_type == 'SeqATAC': 473 | layer.add(SeqATAC(skernel=skernel, channels=channels, dilation=act_dilation, 474 | useReLU=useReLU, asBackbone=asBackbone)) 475 | # layer.add(DilatedSeqATACBackbone(channels=channels, dilation=act_dilation)) 476 | elif act_type == 'MSSeqATAC': 477 | layer.add(MSSeqATAC(skernel=skernel, channels=channels, 478 | dilation=act_dilation, useReLU=useReLU, 479 | asBackbone=asBackbone)) 480 | # layer.add(DilatedSeqATACBackbone(channels=channels, dilation=act_dilation)) 481 | else: 482 | raise ValueError('Unknown act_type') 483 | return layer 484 | 485 | def hybrid_forward(self, F, x): 486 | x = self.features(x) 487 | x = self.head(x) 488 | 489 | return x 490 | 491 | 492 | class VisualBasicContextNet(HybridBlock): 493 | def __init__(self, dilations=[1, 1, 2, 4, 8, 16], channels=16, classes=1, 494 | conv_mode='xxx', act_type='relu', skernel=3, act_dilation=16, 495 | useReLU=False, use_act_head=False, check_fullly=False, 
act_layers=4, 496 | act_order='xxx', asBackbone=False, addstem=False, **kwargs): 497 | super(VisualBasicContextNet, self).__init__(**kwargs) 498 | assert act_type in ['swish', 'prelu', 'relu', 'xUnit', 'SeqATAC', 'SpaATAC', 'ChaATAC', 499 | 'MSSeqATAC', 'MSSeqATACAdd', 'MSSeqATACConcat'], "Unknown act_type" 500 | assert conv_mode in ['learned', 'fixed', 'dynamic'], "Unknown conv_mode" 501 | self.act_type = act_type 502 | with self.name_scope(): 503 | self.features = nn.HybridSequential(prefix='') 504 | if addstem: 505 | self.features.add(nn.Conv2D(channels=channels, kernel_size=3, strides=2, 506 | padding=1, use_bias=False)) 507 | self.features.add(nn.BatchNorm(in_channels=channels)) 508 | self.features.add(nn.Activation('relu')) 509 | self.features.add(nn.Conv2D(channels=channels, kernel_size=3, strides=1, 510 | padding=1, use_bias=False)) 511 | self.features.add(nn.BatchNorm(in_channels=channels)) 512 | self.features.add(nn.Activation('relu')) 513 | self.features.add(nn.Conv2D(channels=channels*2, kernel_size=3, strides=1, 514 | padding=1, use_bias=False)) 515 | self.features.add(nn.BatchNorm(in_channels=channels*2)) 516 | self.features.add(nn.Activation('relu')) 517 | self.features.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1)) 518 | 519 | for i, dilation in enumerate(dilations[:-1]): 520 | self.features.add(self._make_layer( 521 | dilation=dilation, channels=channels, stage_index=i, conv_mode=conv_mode, 522 | act_type=act_type, skernel=skernel, act_dilation=act_dilation, 523 | useReLU=useReLU, check_fullly=check_fullly, act_layers=act_layers, 524 | act_order=act_order, asBackbone=asBackbone)) 525 | 526 | self.features.add(nn.Conv2D(channels=channels, kernel_size=3, dilation=dilations[-1], 527 | padding=dilations[-1])) 528 | self.features.add(nn.BatchNorm()) 529 | 530 | self.attention = nn.HybridSequential(prefix='') 531 | if act_type == 'MSSeqATACConcat': 532 | self.attention.add(MSSeqAttentionMap( 533 | skernel=skernel, channels=channels, dilation=act_dilation, 534 | useReLU=useReLU, asBackbone=asBackbone)) 535 | elif act_type == 'xUnit': 536 | self.attention.add(xUnitAttentionMap(channels=channels, skernel_size=5)) 537 | elif act_type == 'relu': 538 | self.attention.add(nn.Activation('relu')) 539 | 540 | if use_act_head: 541 | self.head = ATAC_FCNHead(head_act=act_type, useReLU=useReLU, 542 | in_channels=channels, channels=classes) 543 | else: 544 | self.head = _FCNHead(in_channels=channels, channels=classes) 545 | 546 | def _make_layer(self, dilation, channels, stage_index, conv_mode, act_type, skernel, 547 | act_dilation, useReLU, check_fullly, act_layers, act_order, asBackbone): 548 | layer = nn.HybridSequential(prefix='stage%d_' % stage_index) 549 | with layer.name_scope(): 550 | 551 | if check_fullly: 552 | if act_order == 'bac': 553 | # 后面的层优先用 Attention 554 | if stage_index + act_layers < 5: 555 | act_type = 'relu' 556 | elif act_order == 'pre': 557 | # 前面的层优先用 Attention 558 | if act_layers - stage_index - 1 < 0: 559 | act_type = 'relu' 560 | else: 561 | raise ValueError('Unknown act_order') 562 | 563 | if conv_mode == 'fixed': 564 | 565 | layer.add(nn.Conv2D(channels=channels, kernel_size=3, dilation=dilation, 566 | padding=dilation)) 567 | layer.add(nn.BatchNorm()) 568 | 569 | if act_type == 'prelu': 570 | layer.add(nn.PReLU()) 571 | elif act_type == 'relu': 572 | layer.add(nn.Activation('relu')) 573 | elif act_type == 'swish': 574 | layer.add(nn.Swish()) 575 | elif act_type == 'xUnit': 576 | layer.add(xUnit(channels=channels, skernel_size=5)) 577 | elif act_type == 
'SpaATAC': 578 | layer.add(SpaATAC(skernel=skernel, channels=channels, dilation=act_dilation, 579 | useReLU=useReLU, asBackbone=asBackbone)) 580 | elif act_type == 'ChaATAC': 581 | layer.add(ChaATAC(channels=channels, useReLU=useReLU, useGlobal=False, 582 | asBackbone=asBackbone)) 583 | elif act_type == 'SeqATAC': 584 | layer.add(SeqATAC(skernel=skernel, channels=channels, dilation=act_dilation, 585 | useReLU=useReLU, asBackbone=asBackbone)) 586 | # layer.add(DilatedSeqATACBackbone(channels=channels, dilation=act_dilation)) 587 | elif act_type == 'MSSeqATAC': 588 | layer.add(MSSeqATAC(skernel=skernel, channels=channels, dilation=act_dilation, 589 | useReLU=useReLU, asBackbone=asBackbone)) 590 | # layer.add(DilatedSeqATACBackbone(channels=channels, dilation=act_dilation)) 591 | elif act_type == 'MSSeqATACAdd': 592 | layer.add(MSSeqATACAdd(skernel=skernel, channels=channels, 593 | dilation=act_dilation, useReLU=useReLU, 594 | asBackbone=asBackbone)) 595 | elif act_type == 'MSSeqATACConcat': 596 | layer.add(MSSeqATACConcat(skernel=skernel, channels=channels, 597 | dilation=act_dilation, useReLU=useReLU, 598 | asBackbone=asBackbone)) 599 | else: 600 | raise ValueError('Unknown act_type') 601 | 602 | elif conv_mode == 'learned': 603 | layer.add(LearnedCell(channels=channels, dilations=dilation)) 604 | elif conv_mode == 'dynamic': 605 | layer.add(DynamicCell(channels=channels, dilations=dilation)) 606 | else: 607 | raise ValueError('Unknown conv_mode') 608 | return layer 609 | 610 | def hybrid_forward(self, F, x): 611 | 612 | _, _, hei, wid = x.shape 613 | x = self.features(x) 614 | if self.act_type == 'relu': 615 | x = self.attention(x) 616 | elif self.act_type == 'MSSeqATACConcat' or self.act_type == 'xUnit': 617 | a = self.attention(x) 618 | x = x * a 619 | # elif self.act_type == 'xUnit': 620 | # a = self.attention(x) 621 | # x = x * a 622 | else: 623 | raise ValueError("Unknown self.act_type") 624 | 625 | x = self.head(x) 626 | 627 | out = F.contrib.BilinearResize2D(x, height=hei, width=wid) 628 | 629 | return out 630 | 631 | def evaluate(self, x): 632 | """evaluating network with inputs and targets""" 633 | return self.forward(x) 634 | 635 | 636 | class BasicContextFPN(HybridBlock): 637 | def __init__(self, dilations=[1, 1, 2, 4, 8, 16], channels=16, classes=1, 638 | conv_mode='xxx', fuse_mode='xxx', act_type='relu', skernel=3, act_dilation=16, 639 | useReLU=False, use_act_head=False, check_fullly=False, act_layers=4, 640 | act_order='xxx', asBackbone=False, addstem=False, maxpool=True, **kwargs): 641 | super(BasicContextFPN, self).__init__(**kwargs) 642 | 643 | assert act_type in ['swish', 'prelu', 'relu', 'xUnit', 'SeqATAC', 'SpaATAC', 'ChaATAC', 644 | 'MSSeqATAC', 'MSSeqATACAdd', 'MSSeqATACConcat'], "Unknown act_type" 645 | assert conv_mode in ['fixed', 'learned', 'ChaDyReF', 'SeqDyReF', 'SK_ChaDyReF', 646 | 'SK_1x1DepthDyReF', 'SK_MSSpaDyReF', 'SK_SpaDyReF', 'Direct_Add', 647 | 'SKCell', 'SK_SeqDyReF', 'Sub_MSSpaDyReF', 'SK_MSSeqDyReF'], \ 648 | "Unknown conv_mode" 649 | # assert fuse_mode in ['Direct_Add', 'SK', 'SK_MSSpa', 'LocalCha', 'GlobalCha', 'LocalGlobalCha', 'MSSpaLGCha'], \ 650 | # "Unknown fuse_mode" 651 | stem_width = int(channels // 2) 652 | self.layer_num = len(dilations) 653 | with self.name_scope(): 654 | self.stem = nn.HybridSequential(prefix='stem') 655 | if addstem: 656 | self.stem.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=2, 657 | padding=1, use_bias=False)) 658 | self.stem.add(nn.BatchNorm(in_channels=stem_width)) 659 | 
self.stem.add(nn.Activation('relu')) 660 | self.stem.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=1, 661 | padding=1, use_bias=False)) 662 | self.stem.add(nn.BatchNorm(in_channels=stem_width)) 663 | self.stem.add(nn.Activation('relu')) 664 | self.stem.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=1, 665 | padding=1, use_bias=False)) 666 | self.stem.add(nn.BatchNorm(in_channels=stem_width*2)) 667 | self.stem.add(nn.Activation('relu')) 668 | if maxpool: 669 | self.stem.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1)) 670 | 671 | self.stage_1 = nn.HybridSequential(prefix='stage_1') 672 | self.stage_1.add(self._make_layer( 673 | dilation=dilations[0], channels=channels, stage_index=0, conv_mode=conv_mode, 674 | act_type=act_type, skernel=skernel, act_dilation=act_dilation, 675 | useReLU=useReLU, check_fullly=check_fullly, act_layers=act_layers, 676 | act_order=act_order, asBackbone=asBackbone)) 677 | if self.layer_num >= 2: 678 | self.stage_1.add(self._make_layer( 679 | dilation=dilations[1], channels=channels, stage_index=1, conv_mode=conv_mode, 680 | act_type=act_type, skernel=skernel, act_dilation=act_dilation, 681 | useReLU=useReLU, check_fullly=check_fullly, act_layers=act_layers, 682 | act_order=act_order, asBackbone=asBackbone)) 683 | 684 | # (1, 1, 2) 685 | if self.layer_num >= 3: 686 | self.stage_2 = self._make_layer( 687 | dilation=dilations[2], channels=channels, stage_index=2, 688 | conv_mode=conv_mode, act_type=act_type, skernel=skernel, 689 | act_dilation=act_dilation, useReLU=useReLU, check_fullly=check_fullly, 690 | act_layers=act_layers, act_order=act_order, asBackbone=asBackbone) 691 | self.fuse12 = self._fuse_layer(fuse_mode=fuse_mode, channels=channels, 692 | act_dilation=act_dilation, useReLU=useReLU, 693 | fuse_index=12) 694 | 695 | # (1, 1, 2, 4) 696 | if self.layer_num >= 4: 697 | self.stage_3 = self._make_layer( 698 | dilation=dilations[3], channels=channels, stage_index=3, 699 | conv_mode=conv_mode, act_type=act_type, skernel=skernel, 700 | act_dilation=act_dilation, useReLU=useReLU, check_fullly=check_fullly, 701 | act_layers=act_layers, act_order=act_order, asBackbone=asBackbone) 702 | self.fuse23 = self._fuse_layer(fuse_mode=fuse_mode, channels=channels, 703 | act_dilation=act_dilation, useReLU=useReLU, 704 | fuse_index=23) 705 | 706 | # (1, 1, 2, 4, 8) 707 | if self.layer_num >= 5: 708 | self.stage_4 = self._make_layer( 709 | dilation=dilations[4], channels=channels, stage_index=4, 710 | conv_mode=conv_mode, act_type=act_type, skernel=skernel, 711 | act_dilation=act_dilation, useReLU=useReLU, check_fullly=check_fullly, 712 | act_layers=act_layers, act_order=act_order, asBackbone=asBackbone) 713 | self.fuse34 = self._fuse_layer(fuse_mode=fuse_mode, channels=channels, 714 | act_dilation=act_dilation, useReLU=useReLU, 715 | fuse_index=34) 716 | 717 | # (1, 1, 2, 4, 8, 16) 718 | if self.layer_num >= 6: 719 | self.stage_5 = self._make_layer( 720 | dilation=dilations[5], channels=channels, stage_index=5, 721 | conv_mode=conv_mode, act_type=act_type, skernel=skernel, 722 | act_dilation=act_dilation, useReLU=useReLU, check_fullly=check_fullly, 723 | act_layers=act_layers, act_order=act_order, asBackbone=asBackbone) 724 | self.fuse45 = self._fuse_layer(fuse_mode=fuse_mode, channels=channels, 725 | act_dilation=act_dilation, useReLU=useReLU, 726 | fuse_index=45) 727 | 728 | self.head = _FCNHead(in_channels=channels, channels=classes) 729 | 730 | def _make_layer(self, dilation, channels, stage_index, conv_mode, act_type, skernel, 731 | 
act_dilation, useReLU, check_fullly, act_layers, act_order, asBackbone): 732 | layer = nn.HybridSequential(prefix='stage%d_' % stage_index) 733 | with layer.name_scope(): 734 | 735 | if conv_mode == 'fixed': 736 | layer.add(nn.Conv2D(channels=channels, kernel_size=3, dilation=dilation, 737 | padding=dilation)) 738 | elif conv_mode == 'learned': 739 | layer.add(LearnedConv(channels=channels, dilations=dilation)) 740 | elif conv_mode == 'ChaDyReF': 741 | layer.add(ChaDyReFConv(channels=channels, dilations=dilation)) 742 | elif conv_mode == 'SK_ChaDyReF': 743 | layer.add(SK_ChaDyReFConv(channels=channels, dilations=dilation)) 744 | elif conv_mode == 'SK_1x1DepthDyReF': 745 | layer.add(SK_1x1DepthDyReFConv(channels=channels, dilations=dilation)) 746 | elif conv_mode == 'SK_MSSpaDyReF': 747 | layer.add(SK_MSSpaDyReFConv(channels=channels, dilations=dilation, 748 | asBackbone=asBackbone)) 749 | elif conv_mode == 'Direct_Add': 750 | layer.add(Direct_AddConv(channels=channels, dilations=dilation, 751 | asBackbone=asBackbone)) 752 | elif conv_mode == 'SK_SpaDyReF': 753 | layer.add(SK_SpaDyReFConv(channels=channels, dilations=dilation, 754 | act_dilation=act_dilation)) 755 | elif conv_mode == 'SKCell': 756 | layer.add(SKConv(channels=channels, dilations=dilation)) 757 | elif conv_mode == 'SeqDyReF': 758 | layer.add(SeqDyReFConv(channels=channels, dilations=dilation, 759 | act_dilation=act_dilation, useReLU=useReLU, 760 | asBackbone=asBackbone)) 761 | elif conv_mode == 'SK_SeqDyReF': 762 | layer.add(SK_SeqDyReFConv(channels=channels, dilations=dilation, 763 | act_dilation=act_dilation, useReLU=useReLU, 764 | asBackbone=asBackbone)) 765 | else: 766 | raise ValueError('Unknown conv_mode') 767 | 768 | layer.add(nn.BatchNorm()) 769 | layer.add(nn.Activation('relu')) 770 | 771 | return layer 772 | 773 | def _fuse_layer(self, fuse_mode, channels, act_dilation, useReLU, fuse_index): 774 | # fuse_layer = nn.HybridSequential(prefix='fuse%d_' % fuse_index) 775 | 776 | if fuse_mode == 'Direct_Add': 777 | # fuse_layer.add(Direct_AddFuse(channels=channels)) 778 | fuse_layer = Direct_AddFuse(channels=channels) 779 | elif fuse_mode == 'SK': 780 | fuse_layer = SKFuse(channels=channels) 781 | elif fuse_mode == 'LocalCha': 782 | fuse_layer = LocalChaFuse(channels=channels) 783 | elif fuse_mode == 'GlobalCha': 784 | fuse_layer = GlobalChaFuse(channels=channels) 785 | elif fuse_mode == 'LocalGlobalCha': 786 | fuse_layer = LocalGlobalChaFuse(channels=channels) 787 | elif fuse_mode == 'LocalSpa': 788 | fuse_layer = LocalSpaFuse(channels=channels, act_dilation=act_dilation) 789 | elif fuse_mode == 'GlobalSpa': 790 | fuse_layer = GlobalSpaFuse(channels=channels, act_dilation=act_dilation) 791 | elif fuse_mode == 'SK_MSSpa': 792 | # fuse_layer.add(SK_MSSpaFuse(channels=channels, act_dilation=act_dilation)) 793 | fuse_layer = SK_MSSpaFuse(channels=channels, act_dilation=act_dilation) 794 | else: 795 | raise ValueError('Unknown fuse_mode') 796 | 797 | return fuse_layer 798 | 799 | def hybrid_forward(self, F, x): 800 | 801 | _, _, hei, wid = x.shape 802 | 803 | xs = self.stem(x) # Subsampling 4 804 | x1 = self.stage_1(xs) # Subsampling 4, dilation 1 805 | 806 | if self.layer_num <= 2: 807 | xf = x1 808 | elif self.layer_num == 3: 809 | x2 = self.stage_2(x1) # Subsampling 4, dilation 2 810 | xf = self.fuse12(x2, x1) 811 | # xf = x2 + x1 812 | elif self.layer_num == 4: 813 | x2 = self.stage_2(x1) # Subsampling 4, dilation 2 814 | x3 = self.stage_3(x2) # Subsampling 4, dilation 4 815 | xf = self.fuse23(x3, x2) 816 | xf = 
self.fuse12(xf, x1) 817 | # xf = x3 + x2 818 | # xf = xf + x1 819 | elif self.layer_num == 5: 820 | x2 = self.stage_2(x1) # Subsampling 4, dilation 2 821 | x3 = self.stage_3(x2) # Subsampling 4, dilation 4 822 | x4 = self.stage_4(x3) # Subsampling 4, dilation 8 823 | xf = self.fuse34(x4, x3) 824 | xf = self.fuse23(xf, x2) 825 | xf = self.fuse12(xf, x1) 826 | # xf = x4 + x3 827 | # xf = xf + x2 828 | # xf = xf + x1 829 | elif self.layer_num == 6: 830 | x2 = self.stage_2(x1) # Subsampling 4, dilation 2 831 | x3 = self.stage_3(x2) # Subsampling 4, dilation 4 832 | x4 = self.stage_4(x3) # Subsampling 4, dilation 8 833 | x5 = self.stage_5(x4) # Subsampling 4, dilation 16 834 | xf = self.fuse45(x5, x4) 835 | xf = self.fuse34(xf, x3) 836 | xf = self.fuse23(xf, x2) 837 | xf = self.fuse12(xf, x1) 838 | # xf = x5 + x4 839 | # xf = xf + x3 840 | # xf = xf + x2 841 | # xf = xf + x1 842 | 843 | xo = self.head(xf) 844 | 845 | out = F.contrib.BilinearResize2D(xo, height=hei, width=wid) 846 | 847 | return out 848 | 849 | def evaluate(self, x): 850 | """evaluating network with inputs and targets""" 851 | return self.forward(x) 852 | 853 | 854 | 855 | # class MPCMResNetFPN(HybridBlock): 856 | # def __init__(self, layers, channels, shift=3, classes=1, 857 | # norm_layer=BatchNorm, norm_kwargs=None, **kwargs): 858 | # super(MPCMResNetFPN, self).__init__(**kwargs) 859 | # 860 | # self.layer_num = len(layers) 861 | # with self.name_scope(): 862 | # 863 | # self.shift = shift 864 | # 865 | # stem_width = int(channels[0]) 866 | # self.stem = nn.HybridSequential(prefix='stem') 867 | # self.stem.add(norm_layer(scale=False, center=False, 868 | # **({} if norm_kwargs is None else norm_kwargs))) 869 | # 870 | # self.stem.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=2, 871 | # padding=1, use_bias=False)) 872 | # self.stem.add(norm_layer(in_channels=stem_width)) 873 | # self.stem.add(nn.Activation('relu')) 874 | # self.stem.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=1, 875 | # padding=1, use_bias=False)) 876 | # self.stem.add(norm_layer(in_channels=stem_width)) 877 | # self.stem.add(nn.Activation('relu')) 878 | # self.stem.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=1, 879 | # padding=1, use_bias=False)) 880 | # self.stem.add(norm_layer(in_channels=stem_width*2)) 881 | # self.stem.add(nn.Activation('relu')) 882 | # 883 | # self.head = _FCNHead(in_channels=channels[-1], channels=classes) 884 | # 885 | # self.layer1 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[0], 886 | # channels=channels[1], stride=1, stage_index=1, 887 | # in_channels=channels[1]) 888 | # 889 | # self.layer2 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[1], 890 | # channels=channels[2], stride=2, stage_index=2, 891 | # in_channels=channels[1]) 892 | # 893 | # self.layer3 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[2], 894 | # channels=channels[3], stride=2, stage_index=3, 895 | # in_channels=channels[2]) 896 | # 897 | # self.inc_c2 = nn.HybridSequential(prefix='inc_c2') 898 | # self.inc_c2.add(nn.Conv2D(channels=channels[3], kernel_size=1, strides=1, 899 | # padding=0, use_bias=False)) 900 | # self.inc_c2.add(norm_layer(in_channels=channels[-1])) 901 | # self.inc_c2.add(nn.Activation('relu')) 902 | # 903 | # self.inc_c1 = nn.HybridSequential(prefix='inc_c1') 904 | # self.inc_c1.add(nn.Conv2D(channels=channels[3], kernel_size=1, strides=1, 905 | # padding=0, use_bias=False)) 906 | # self.inc_c1.add(norm_layer(in_channels=channels[-1])) 907 | # 
self.inc_c1.add(nn.Activation('relu')) 908 | # 909 | # 910 | # def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0, 911 | # norm_layer=BatchNorm, norm_kwargs=None): 912 | # layer = nn.HybridSequential(prefix='stage%d_'%stage_index) 913 | # with layer.name_scope(): 914 | # downsample = (channels != in_channels) or (stride != 1) 915 | # layer.add(block(channels, stride, downsample, in_channels=in_channels, 916 | # prefix='', norm_layer=norm_layer, norm_kwargs=norm_kwargs)) 917 | # for _ in range(layers-1): 918 | # layer.add(block(channels, 1, False, in_channels=channels, prefix='', 919 | # norm_layer=norm_layer, norm_kwargs=norm_kwargs)) 920 | # return layer 921 | # 922 | # def hybrid_forward(self, F, x): 923 | # 924 | # _, _, orig_hei, orig_wid = x.shape 925 | # x = self.stem(x) # sub 2 926 | # c1 = self.layer1(x) # sub 2 927 | # _, _, c1_hei, c1_wid = c1.shape 928 | # c2 = self.layer2(c1) # sub 4 929 | # _, _, c2_hei, c2_wid = c2.shape 930 | # c3 = self.layer3(c2) # sub 8 931 | # _, _, c3_hei, c3_wid = c3.shape 932 | # 933 | # # 1. upsampling(c3) -> c3PCM # size: sub 4 934 | # 935 | # # c3 -> c3PCM 936 | # # 2. pwconv(c2) -> c2PCM # size: sub 4 937 | # # 3. upsampling(c3PCM + c2PCM) # size: sub 2 938 | # # 4. pwconv(c1) -> c1PCM # size: sub 2 939 | # # 5. upsampling(upsampling(c3PCM + c2PCM)) + c1PCM 940 | # # 6. upsampling(upsampling(c3PCM + c2PCM)) + c1PCM 941 | # 942 | # c3pcm = self.cal_pcm(c3, shift=self.shift) 943 | # up_c3pcm = F.contrib.BilinearResize2D(c3pcm, height=c2_hei, width=c2_wid) # sub 4, 64 944 | # 945 | # inc_c2 = self.inc_c2(c2) # sub 4, 64 946 | # c2pcm = self.cal_pcm(inc_c2, shift=self.shift) 947 | # 948 | # c23pcm = up_c3pcm + c2pcm # sub 4, 64 949 | # 950 | # up_c23pcm = F.contrib.BilinearResize2D(c23pcm, height=c1_hei, width=c1_wid) # sub 2, 64 951 | # inc_c1 = self.inc_c1(c1) # sub 2, 64 952 | # c1pcm = self.cal_pcm(inc_c1, shift=self.shift) 953 | # 954 | # out = up_c23pcm + c1pcm # sub 2, 64 955 | # pred = self.head(out) 956 | # out = F.contrib.BilinearResize2D(pred, height=orig_hei, width=orig_wid) 957 | # 958 | # return out 959 | # 960 | # def evaluate(self, x): 961 | # """evaluating network with inputs and targets""" 962 | # return self.forward(x) 963 | # 964 | # def circ_shift(self, cen, shift): 965 | # 966 | # _, _, hei, wid = cen.shape 967 | # 968 | # ######## B1 ######### 969 | # # old: AD => new: CB 970 | # # BC => DA 971 | # B1_NW = cen[:, :, shift:, shift:] # B1_NW is cen's SE 972 | # B1_NE = cen[:, :, shift:, :shift] # B1_NE is cen's SW 973 | # B1_SW = cen[:, :, :shift, shift:] # B1_SW is cen's NE 974 | # B1_SE = cen[:, :, :shift, :shift] # B1_SE is cen's NW 975 | # B1_N = nd.concat(B1_NW, B1_NE, dim=3) 976 | # B1_S = nd.concat(B1_SW, B1_SE, dim=3) 977 | # B1 = nd.concat(B1_N, B1_S, dim=2) 978 | # 979 | # ######## B2 ######### 980 | # # old: A => new: B 981 | # # B => A 982 | # B2_N = cen[:, :, shift:, :] # B2_N is cen's S 983 | # B2_S = cen[:, :, :shift, :] # B2_S is cen's N 984 | # B2 = nd.concat(B2_N, B2_S, dim=2) 985 | # 986 | # ######## B3 ######### 987 | # # old: AD => new: CB 988 | # # BC => DA 989 | # B3_NW = cen[:, :, shift:, wid-shift:] # B3_NW is cen's SE 990 | # B3_NE = cen[:, :, shift:, :wid-shift] # B3_NE is cen's SW 991 | # B3_SW = cen[:, :, :shift, wid-shift:] # B3_SW is cen's NE 992 | # B3_SE = cen[:, :, :shift, :wid-shift] # B1_SE is cen's NW 993 | # B3_N = nd.concat(B3_NW, B3_NE, dim=3) 994 | # B3_S = nd.concat(B3_SW, B3_SE, dim=3) 995 | # B3 = nd.concat(B3_N, B3_S, dim=2) 996 | # 997 | # ######## B4 
######### 998 | # # old: AB => new: BA 999 | # B4_W = cen[:, :, :, wid-shift:] # B2_W is cen's E 1000 | # B4_E = cen[:, :, :, :wid-shift] # B2_E is cen's S 1001 | # B4 = nd.concat(B4_W, B4_E, dim=3) 1002 | # 1003 | # ######## B5 ######### 1004 | # # old: AD => new: CB 1005 | # # BC => DA 1006 | # B5_NW = cen[:, :, hei-shift:, wid-shift:] # B5_NW is cen's SE 1007 | # B5_NE = cen[:, :, hei-shift:, :wid-shift] # B5_NE is cen's SW 1008 | # B5_SW = cen[:, :, :hei-shift, wid-shift:] # B5_SW is cen's NE 1009 | # B5_SE = cen[:, :, :hei-shift, :wid-shift] # B5_SE is cen's NW 1010 | # B5_N = nd.concat(B5_NW, B5_NE, dim=3) 1011 | # B5_S = nd.concat(B5_SW, B5_SE, dim=3) 1012 | # B5 = nd.concat(B5_N, B5_S, dim=2) 1013 | # 1014 | # ######## B6 ######### 1015 | # # old: A => new: B 1016 | # # B => A 1017 | # B6_N = cen[:, :, hei-shift:, :] # B6_N is cen's S 1018 | # B6_S = cen[:, :, :hei-shift, :] # B6_S is cen's N 1019 | # B6 = nd.concat(B6_N, B6_S, dim=2) 1020 | # 1021 | # ######## B7 ######### 1022 | # # old: AD => new: CB 1023 | # # BC => DA 1024 | # B7_NW = cen[:, :, hei-shift:, shift:] # B7_NW is cen's SE 1025 | # B7_NE = cen[:, :, hei-shift:, :shift] # B7_NE is cen's SW 1026 | # B7_SW = cen[:, :, :hei-shift, shift:] # B7_SW is cen's NE 1027 | # B7_SE = cen[:, :, :hei-shift, :shift] # B7_SE is cen's NW 1028 | # B7_N = nd.concat(B7_NW, B7_NE, dim=3) 1029 | # B7_S = nd.concat(B7_SW, B7_SE, dim=3) 1030 | # B7 = nd.concat(B7_N, B7_S, dim=2) 1031 | # 1032 | # ######## B8 ######### 1033 | # # old: AB => new: BA 1034 | # B8_W = cen[:, :, :, shift:] # B8_W is cen's E 1035 | # B8_E = cen[:, :, :, :shift] # B8_E is cen's S 1036 | # B8 = nd.concat(B8_W, B8_E, dim=3) 1037 | # 1038 | # return B1, B2, B3, B4, B5, B6, B7, B8 1039 | # 1040 | # def cal_pcm(self, cen, shift): 1041 | # 1042 | # B1, B2, B3, B4, B5, B6, B7, B8 = self.circ_shift(cen, shift=shift) 1043 | # s1 = (B1 - cen) * (B5 - cen) 1044 | # s2 = (B2 - cen) * (B6 - cen) 1045 | # s3 = (B3 - cen) * (B7 - cen) 1046 | # s4 = (B4 - cen) * (B8 - cen) 1047 | # 1048 | # c12 = nd.minimum(s1, s2) 1049 | # c123 = nd.minimum(c12, s3) 1050 | # c1234 = nd.minimum(c123, s4) 1051 | # 1052 | # return c1234 1053 | 1054 | 1055 | -------------------------------------------------------------------------------- /model/contrast.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from mxnet.gluon.block import HybridBlock 3 | from mxnet.gluon import nn 4 | from mxnet.gluon.nn import BatchNorm 5 | from gluoncv.model_zoo.fcn import _FCNHead 6 | from mxnet import nd 7 | from gluoncv.model_zoo.cifarresnet import CIFARBasicBlockV1 8 | 9 | 10 | def circ_shift(cen, shift): 11 | 12 | _, _, hei, wid = cen.shape 13 | 14 | ######## B1 ######### 15 | # old: AD => new: CB 16 | # BC => DA 17 | B1_NW = cen[:, :, shift:, shift:] # B1_NW is cen's SE 18 | B1_NE = cen[:, :, shift:, :shift] # B1_NE is cen's SW 19 | B1_SW = cen[:, :, :shift, shift:] # B1_SW is cen's NE 20 | B1_SE = cen[:, :, :shift, :shift] # B1_SE is cen's NW 21 | B1_N = nd.concat(B1_NW, B1_NE, dim=3) 22 | B1_S = nd.concat(B1_SW, B1_SE, dim=3) 23 | B1 = nd.concat(B1_N, B1_S, dim=2) 24 | 25 | ######## B2 ######### 26 | # old: A => new: B 27 | # B => A 28 | B2_N = cen[:, :, shift:, :] # B2_N is cen's S 29 | B2_S = cen[:, :, :shift, :] # B2_S is cen's N 30 | B2 = nd.concat(B2_N, B2_S, dim=2) 31 | 32 | ######## B3 ######### 33 | # old: AD => new: CB 34 | # BC => DA 35 | B3_NW = cen[:, :, shift:, wid-shift:] # B3_NW is cen's SE 36 | B3_NE = cen[:, :, shift:, 
:wid-shift] # B3_NE is cen's SW 37 | B3_SW = cen[:, :, :shift, wid-shift:] # B3_SW is cen's NE 38 | B3_SE = cen[:, :, :shift, :wid-shift] # B3_SE is cen's NW 39 | B3_N = nd.concat(B3_NW, B3_NE, dim=3) 40 | B3_S = nd.concat(B3_SW, B3_SE, dim=3) 41 | B3 = nd.concat(B3_N, B3_S, dim=2) 42 | 43 | ######## B4 ######### 44 | # old: AB => new: BA 45 | B4_W = cen[:, :, :, wid-shift:] # B4_W is cen's E 46 | B4_E = cen[:, :, :, :wid-shift] # B4_E is cen's W 47 | B4 = nd.concat(B4_W, B4_E, dim=3) 48 | 49 | ######## B5 ######### 50 | # old: AD => new: CB 51 | # BC => DA 52 | B5_NW = cen[:, :, hei-shift:, wid-shift:] # B5_NW is cen's SE 53 | B5_NE = cen[:, :, hei-shift:, :wid-shift] # B5_NE is cen's SW 54 | B5_SW = cen[:, :, :hei-shift, wid-shift:] # B5_SW is cen's NE 55 | B5_SE = cen[:, :, :hei-shift, :wid-shift] # B5_SE is cen's NW 56 | B5_N = nd.concat(B5_NW, B5_NE, dim=3) 57 | B5_S = nd.concat(B5_SW, B5_SE, dim=3) 58 | B5 = nd.concat(B5_N, B5_S, dim=2) 59 | 60 | ######## B6 ######### 61 | # old: A => new: B 62 | # B => A 63 | B6_N = cen[:, :, hei-shift:, :] # B6_N is cen's S 64 | B6_S = cen[:, :, :hei-shift, :] # B6_S is cen's N 65 | B6 = nd.concat(B6_N, B6_S, dim=2) 66 | 67 | ######## B7 ######### 68 | # old: AD => new: CB 69 | # BC => DA 70 | B7_NW = cen[:, :, hei-shift:, shift:] # B7_NW is cen's SE 71 | B7_NE = cen[:, :, hei-shift:, :shift] # B7_NE is cen's SW 72 | B7_SW = cen[:, :, :hei-shift, shift:] # B7_SW is cen's NE 73 | B7_SE = cen[:, :, :hei-shift, :shift] # B7_SE is cen's NW 74 | B7_N = nd.concat(B7_NW, B7_NE, dim=3) 75 | B7_S = nd.concat(B7_SW, B7_SE, dim=3) 76 | B7 = nd.concat(B7_N, B7_S, dim=2) 77 | 78 | ######## B8 ######### 79 | # old: AB => new: BA 80 | B8_W = cen[:, :, :, shift:] # B8_W is cen's E 81 | B8_E = cen[:, :, :, :shift] # B8_E is cen's W 82 | B8 = nd.concat(B8_W, B8_E, dim=3) 83 | 84 | return B1, B2, B3, B4, B5, B6, B7, B8 85 | 86 | 87 | def cal_pcm(cen, shift): 88 | 89 | B1, B2, B3, B4, B5, B6, B7, B8 = circ_shift(cen, shift=shift) 90 | s1 = (B1 - cen) * (B5 - cen) 91 | s2 = (B2 - cen) * (B6 - cen) 92 | s3 = (B3 - cen) * (B7 - cen) 93 | s4 = (B4 - cen) * (B8 - cen) 94 | 95 | c12 = nd.minimum(s1, s2) 96 | c123 = nd.minimum(c12, s3) 97 | c1234 = nd.minimum(c123, s4) 98 | 99 | return c1234 100 | 101 | 102 | class PCMNet(HybridBlock): 103 | def __init__(self, dilations=[1, 1, 2, 4, 8, 16], channels=16, classes=1, addstem=True, 104 | maxpool=False, shift='xxx', **kwargs): 105 | super(PCMNet, self).__init__(**kwargs) 106 | 107 | self.shift = shift 108 | stem_width = int(channels // 2) 109 | with self.name_scope(): 110 | self.features = nn.HybridSequential(prefix='') 111 | if addstem: 112 | self.features.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=2, 113 | padding=1, use_bias=False)) 114 | self.features.add(nn.BatchNorm(in_channels=stem_width)) 115 | self.features.add(nn.Activation('relu')) 116 | self.features.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=1, 117 | padding=1, use_bias=False)) 118 | self.features.add(nn.BatchNorm(in_channels=stem_width)) 119 | self.features.add(nn.Activation('relu')) 120 | self.features.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=1, 121 | padding=1, use_bias=False)) 122 | self.features.add(nn.BatchNorm(in_channels=stem_width*2)) 123 | self.features.add(nn.Activation('relu')) 124 | if maxpool: 125 | self.features.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1)) 126 | 127 | for i, dilation in enumerate(dilations): 128 | self.features.add(self._make_layer( 129 | dilation=dilation, channels=channels, 
stage_index=i)) 130 | 131 | self.head = _FCNHead(in_channels=channels, channels=classes) 132 | 133 | def _make_layer(self, dilation, channels, stage_index): 134 | layer = nn.HybridSequential(prefix='stage%d_' % stage_index) 135 | with layer.name_scope(): 136 | 137 | layer.add(nn.Conv2D(channels=channels, kernel_size=3, dilation=dilation, 138 | padding=dilation)) 139 | layer.add(nn.BatchNorm()) 140 | layer.add(nn.Activation('relu')) 141 | 142 | return layer 143 | 144 | def hybrid_forward(self, F, x): 145 | 146 | _, _, hei, wid = x.shape 147 | 148 | cen = self.features(x) 149 | c1234 = cal_pcm(cen, self.shift) 150 | x = self.head(c1234) 151 | 152 | out = F.contrib.BilinearResize2D(x, height=hei, width=wid) 153 | 154 | return out 155 | 156 | def evaluate(self, x): 157 | """evaluating network with inputs and targets""" 158 | return self.forward(x) 159 | 160 | 161 | class MPCMNet(HybridBlock): 162 | def __init__(self, dilations=[1, 1, 2, 4, 8, 16], channels=16, classes=1, addstem=True, 163 | maxpool=False, **kwargs): 164 | super(MPCMNet, self).__init__(**kwargs) 165 | 166 | stem_width = int(channels // 2) 167 | with self.name_scope(): 168 | self.features = nn.HybridSequential(prefix='') 169 | if addstem: 170 | self.features.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=2, 171 | padding=1, use_bias=False)) 172 | self.features.add(nn.BatchNorm(in_channels=stem_width)) 173 | self.features.add(nn.Activation('relu')) 174 | self.features.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=1, 175 | padding=1, use_bias=False)) 176 | self.features.add(nn.BatchNorm(in_channels=stem_width)) 177 | self.features.add(nn.Activation('relu')) 178 | self.features.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=1, 179 | padding=1, use_bias=False)) 180 | self.features.add(nn.BatchNorm(in_channels=stem_width*2)) 181 | self.features.add(nn.Activation('relu')) 182 | if maxpool: 183 | self.features.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1)) 184 | 185 | for i, dilation in enumerate(dilations): 186 | self.features.add(self._make_layer( 187 | dilation=dilation, channels=channels, stage_index=i)) 188 | 189 | self.head = _FCNHead(in_channels=channels, channels=classes) 190 | 191 | def _make_layer(self, dilation, channels, stage_index): 192 | layer = nn.HybridSequential(prefix='stage%d_' % stage_index) 193 | with layer.name_scope(): 194 | 195 | layer.add(nn.Conv2D(channels=channels, kernel_size=3, dilation=dilation, 196 | padding=dilation)) 197 | layer.add(nn.BatchNorm()) 198 | layer.add(nn.Activation('relu')) 199 | 200 | return layer 201 | 202 | def hybrid_forward(self, F, x): 203 | 204 | _, _, hei, wid = x.shape 205 | 206 | cen = self.features(x) 207 | 208 | # pcm9 = self.cal_pcm(cen, shift=9) 209 | # pcm17 = self.cal_pcm(cen, shift=17) 210 | # pcm25 = self.cal_pcm(cen, shift=25) 211 | # pcm33 = self.cal_pcm(cen, shift=33) 212 | # mpcm = nd.maximum(nd.maximum(nd.maximum(pcm9, pcm17), pcm25), pcm33) 213 | 214 | pcm9 = cal_pcm(cen, shift=9) 215 | pcm13 = cal_pcm(cen, shift=13) 216 | pcm17 = cal_pcm(cen, shift=17) 217 | # pcm21 = self.cal_pcm(cen, shift=21) 218 | # mpcm = nd.maximum(nd.maximum(nd.maximum(pcm9, pcm13), pcm17), pcm21) 219 | mpcm = nd.maximum(nd.maximum(pcm9, pcm13), pcm17) 220 | 221 | x = self.head(mpcm) 222 | 223 | out = F.contrib.BilinearResize2D(x, height=hei, width=wid) 224 | 225 | return out 226 | 227 | def evaluate(self, x): 228 | """evaluating network with inputs and targets""" 229 | return self.forward(x) 230 | 231 | 232 | class LayerwiseMPCMNet(HybridBlock): 233 
| def __init__(self, dilations=[1, 1, 2, 4, 8, 16], channels=16, classes=1, addstem=True, 234 | maxpool=False, **kwargs): 235 | super(LayerwiseMPCMNet, self).__init__(**kwargs) 236 | 237 | stem_width = int(channels // 2) 238 | with self.name_scope(): 239 | self.features = nn.HybridSequential(prefix='') 240 | if addstem: 241 | self.features.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=2, 242 | padding=1, use_bias=False)) 243 | self.features.add(nn.BatchNorm(in_channels=stem_width)) 244 | self.features.add(nn.Activation('relu')) 245 | self.features.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=1, 246 | padding=1, use_bias=False)) 247 | self.features.add(nn.BatchNorm(in_channels=stem_width)) 248 | self.features.add(nn.Activation('relu')) 249 | self.features.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=1, 250 | padding=1, use_bias=False)) 251 | self.features.add(nn.BatchNorm(in_channels=stem_width*2)) 252 | self.features.add(nn.Activation('relu')) 253 | if maxpool: 254 | self.features.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1)) 255 | self.features.add(CalMPCM()) 256 | 257 | for i, dilation in enumerate(dilations): 258 | self.features.add(self._make_layer( 259 | dilation=dilation, channels=channels, stage_index=i)) 260 | 261 | self.head = _FCNHead(in_channels=channels, channels=classes) 262 | 263 | def _make_layer(self, dilation, channels, stage_index): 264 | layer = nn.HybridSequential(prefix='stage%d_' % stage_index) 265 | with layer.name_scope(): 266 | 267 | layer.add(nn.Conv2D(channels=channels, kernel_size=3, dilation=dilation, 268 | padding=dilation)) 269 | layer.add(nn.BatchNorm()) 270 | layer.add(nn.Activation('relu')) 271 | layer.add(CalMPCM()) 272 | 273 | return layer 274 | 275 | def hybrid_forward(self, F, x): 276 | 277 | _, _, hei, wid = x.shape 278 | 279 | x = self.features(x) 280 | x = self.head(x) 281 | 282 | out = F.contrib.BilinearResize2D(x, height=hei, width=wid) 283 | 284 | return out 285 | 286 | def evaluate(self, x): 287 | """evaluating network with inputs and targets""" 288 | return self.forward(x) 289 | 290 | 291 | class PlainNet(HybridBlock): 292 | def __init__(self, dilations=[1, 1, 2, 4, 8, 16], channels=16, classes=1, addstem=True, 293 | maxpool=False, **kwargs): 294 | super(PlainNet, self).__init__(**kwargs) 295 | 296 | stem_width = int(channels // 2) 297 | with self.name_scope(): 298 | self.features = nn.HybridSequential(prefix='') 299 | if addstem: 300 | self.features.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=2, 301 | padding=1, use_bias=False)) 302 | self.features.add(nn.BatchNorm(in_channels=stem_width)) 303 | self.features.add(nn.Activation('relu')) 304 | self.features.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=1, 305 | padding=1, use_bias=False)) 306 | self.features.add(nn.BatchNorm(in_channels=stem_width)) 307 | self.features.add(nn.Activation('relu')) 308 | self.features.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=1, 309 | padding=1, use_bias=False)) 310 | self.features.add(nn.BatchNorm(in_channels=stem_width*2)) 311 | self.features.add(nn.Activation('relu')) 312 | if maxpool: 313 | self.features.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1)) 314 | 315 | for i, dilation in enumerate(dilations): 316 | self.features.add(self._make_layer( 317 | dilation=dilation, channels=channels, stage_index=i)) 318 | 319 | self.head = _FCNHead(in_channels=channels, channels=classes) 320 | 321 | def _make_layer(self, dilation, channels, stage_index): 322 | layer = 
nn.HybridSequential(prefix='stage%d_' % stage_index) 323 | with layer.name_scope(): 324 | 325 | layer.add(nn.Conv2D(channels=channels, kernel_size=3, dilation=dilation, 326 | padding=dilation)) 327 | layer.add(nn.BatchNorm()) 328 | layer.add(nn.Activation('relu')) 329 | 330 | return layer 331 | 332 | def hybrid_forward(self, F, x): 333 | 334 | _, _, hei, wid = x.shape 335 | 336 | cen = self.features(x) 337 | 338 | x = self.head(cen) 339 | 340 | out = F.contrib.BilinearResize2D(x, height=hei, width=wid) 341 | 342 | return out 343 | 344 | def evaluate(self, x): 345 | """evaluating network with inputs and targets""" 346 | return self.forward(x) 347 | 348 | 349 | class CalMPCM(HybridBlock): 350 | def __init__(self, **kwargs): 351 | super(CalMPCM, self).__init__(**kwargs) 352 | 353 | def hybrid_forward(self, F, x): 354 | 355 | pcm9 = cal_pcm(x, shift=9) 356 | pcm13 = cal_pcm(x, shift=13) 357 | pcm17 = cal_pcm(x, shift=17) 358 | mpcm = nd.maximum(nd.maximum(pcm9, pcm13), pcm17) 359 | 360 | return mpcm 361 | 362 | def evaluate(self, x): 363 | """evaluating network with inputs and targets""" 364 | return self.forward(x) 365 | 366 | 367 | class MPCMResNetFPN(HybridBlock): 368 | def __init__(self, layers, channels, shift=3, pyramid_mode='xxx', scale_mode='xxx', 369 | pyramid_fuse='xxx', r=2, classes=1, norm_layer=BatchNorm, norm_kwargs=None, 370 | **kwargs): 371 | super(MPCMResNetFPN, self).__init__(**kwargs) 372 | 373 | self.layer_num = len(layers) 374 | with self.name_scope(): 375 | 376 | self.r = r 377 | self.shift = shift 378 | self.pyramid_mode = pyramid_mode 379 | self.scale_mode = scale_mode 380 | self.pyramid_fuse = pyramid_fuse 381 | 382 | stem_width = int(channels[0]) 383 | self.stem = nn.HybridSequential(prefix='stem') 384 | self.stem.add(norm_layer(scale=False, center=False, 385 | **({} if norm_kwargs is None else norm_kwargs))) 386 | self.stem.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=1, 387 | padding=1, use_bias=False)) 388 | self.stem.add(norm_layer(in_channels=stem_width*2)) 389 | self.stem.add(nn.Activation('relu')) 390 | 391 | self.head = _FCNHead(in_channels=channels[1], channels=classes) 392 | 393 | self.layer1 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[0], 394 | channels=channels[1], stride=1, stage_index=1, 395 | in_channels=channels[1]) 396 | 397 | self.layer2 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[1], 398 | channels=channels[2], stride=2, stage_index=2, 399 | in_channels=channels[1]) 400 | 401 | self.layer3 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[2], 402 | channels=channels[3], stride=2, stage_index=3, 403 | in_channels=channels[2]) 404 | 405 | if pyramid_mode == 'Dec': 406 | 407 | self.dec_c2 = nn.HybridSequential(prefix='dec_c2') 408 | self.dec_c2.add(nn.Conv2D(channels=channels[1], kernel_size=1, strides=1, 409 | padding=0, use_bias=False)) 410 | self.dec_c2.add(norm_layer(in_channels=channels[1])) 411 | self.dec_c2.add(nn.Activation('relu')) 412 | 413 | self.dec_c3 = nn.HybridSequential(prefix='dec_c3') 414 | self.dec_c3.add(nn.Conv2D(channels=channels[1], kernel_size=1, strides=1, 415 | padding=0, use_bias=False)) 416 | self.dec_c3.add(norm_layer(in_channels=channels[1])) 417 | self.dec_c3.add(nn.Activation('relu')) 418 | 419 | # if self.scale_mode == 'Selective': 420 | # # self.fuse_mpcm_c1 = GlobMPCMFuse(channels=channels[1]) 421 | # # self.fuse_mpcm_c2 = GlobMPCMFuse(channels=channels[2]) 422 | # # self.fuse_mpcm_c3 = GlobMPCMFuse(channels=channels[3]) 423 | # 424 | # self.fuse_mpcm_c1 = 
LocalMPCMFuse(channels=channels[1]) 425 | # self.fuse_mpcm_c2 = LocalMPCMFuse(channels=channels[2]) 426 | # self.fuse_mpcm_c3 = LocalMPCMFuse(channels=channels[3]) 427 | 428 | if self.scale_mode == 'biglobal': 429 | self.fuse_mpcm_c1 = BiGlobal_MPCMFuse(channels=channels[1]) 430 | self.fuse_mpcm_c2 = BiGlobal_MPCMFuse(channels=channels[2]) 431 | self.fuse_mpcm_c3 = BiGlobal_MPCMFuse(channels=channels[3]) 432 | elif self.scale_mode == 'bilocal': 433 | self.fuse_mpcm_c1 = BiLocal_MPCMFuse(channels=channels[1]) 434 | self.fuse_mpcm_c2 = BiLocal_MPCMFuse(channels=channels[2]) 435 | self.fuse_mpcm_c3 = BiLocal_MPCMFuse(channels=channels[3]) 436 | elif self.scale_mode == 'add': 437 | self.fuse_mpcm_c1 = Add_MPCMFuse(channels=channels[1]) 438 | self.fuse_mpcm_c2 = Add_MPCMFuse(channels=channels[2]) 439 | self.fuse_mpcm_c3 = Add_MPCMFuse(channels=channels[3]) 440 | elif self.scale_mode == 'globalsk': 441 | self.fuse_mpcm_c1 = GlobalSK_MPCMFuse(channels=channels[1]) 442 | self.fuse_mpcm_c2 = GlobalSK_MPCMFuse(channels=channels[2]) 443 | self.fuse_mpcm_c3 = GlobalSK_MPCMFuse(channels=channels[3]) 444 | elif self.scale_mode == 'localsk': 445 | self.fuse_mpcm_c1 = LocalSK_MPCMFuse(channels=channels[1]) 446 | self.fuse_mpcm_c2 = LocalSK_MPCMFuse(channels=channels[2]) 447 | self.fuse_mpcm_c3 = LocalSK_MPCMFuse(channels=channels[3]) 448 | 449 | if self.pyramid_fuse == 'globalsk': 450 | self.globalsk_fpn_2 = GlobalSK_FPNFuse(channels=channels[1], r=self.r) 451 | self.globalsk_fpn_1 = GlobalSK_FPNFuse(channels=channels[1], r=self.r) 452 | elif self.pyramid_fuse == 'localsk': 453 | self.localsk_fpn_2 = LocalSK_FPNFuse(channels=channels[1], r=self.r) 454 | self.localsk_fpn_1 = LocalSK_FPNFuse(channels=channels[1], r=self.r) 455 | elif self.pyramid_fuse == 'mutualglobalsk': 456 | self.mutualglobalsk_fpn_2 = MutualSKGlobal_FPNFuse(channels=channels[1]) 457 | self.mutualglobalsk_fpn_1 = MutualSKGlobal_FPNFuse(channels=channels[1]) 458 | elif self.pyramid_fuse == 'mutuallocalsk': 459 | self.mutuallocalsk_fpn_2 = MutualSKLocal_FPNFuse(channels=channels[1]) 460 | self.mutuallocalsk_fpn_1 = MutualSKLocal_FPNFuse(channels=channels[1]) 461 | elif self.pyramid_fuse == 'bilocal': 462 | self.bilocal_fpn_2 = BiLocal_FPNFuse(channels=channels[1]) 463 | self.bilocal_fpn_1 = BiLocal_FPNFuse(channels=channels[1]) 464 | elif self.pyramid_fuse == 'biglobal': 465 | self.biglobal_fpn_2 = BiGlobal_FPNFuse(channels=channels[1]) 466 | self.biglobal_fpn_1 = BiGlobal_FPNFuse(channels=channels[1]) 467 | elif self.pyramid_fuse == 'remo': 468 | self.remo_fpn_2 = ReMo_FPNFuse(channels=channels[1]) 469 | self.remo_fpn_1 = ReMo_FPNFuse(channels=channels[1]) 470 | elif self.pyramid_fuse == 'localremo': 471 | self.localremo_fpn_2 = LocalReMo_FPNFuse(channels=channels[1]) 472 | self.localremo_fpn_1 = LocalReMo_FPNFuse(channels=channels[1]) 473 | elif self.pyramid_fuse == 'asymbi': 474 | self.asymbi_fpn_2 = AsymBi_FPNFuse(channels=channels[1]) 475 | self.asymbi_fpn_1 = AsymBi_FPNFuse(channels=channels[1]) 476 | elif self.pyramid_fuse == 'topdownlocal': 477 | self.topdownlocal_fpn_2 = TopDownLocal_FPNFuse(channels=channels[1]) 478 | self.topdownlocal_fpn_1 = TopDownLocal_FPNFuse(channels=channels[1]) 479 | elif self.pyramid_fuse == 'bottomuplocal': 480 | self.bottomuplocal_fpn_2 = BottomUpLocal_FPNFuse(channels=channels[1]) 481 | self.bottomuplocal_fpn_1 = BottomUpLocal_FPNFuse(channels=channels[1]) 482 | elif self.pyramid_fuse == 'bottomupglobal': 483 | self.bottomupglobal_fpn_2 = BottomUpGlobal_FPNFuse(channels=channels[1]) 484 | 
self.bottomupglobal_fpn_1 = BottomUpGlobal_FPNFuse(channels=channels[1]) 485 | 486 | 487 | # elif pyramid_mode == 'Inc': 488 | # 489 | # self.inc_c2 = nn.HybridSequential(prefix='inc_c2') 490 | # self.inc_c2.add(nn.Conv2D(channels=channels[3], kernel_size=1, strides=1, 491 | # padding=0, use_bias=False)) 492 | # self.inc_c2.add(norm_layer(in_channels=channels[-1])) 493 | # self.inc_c2.add(nn.Activation('relu')) 494 | # 495 | # self.inc_c1 = nn.HybridSequential(prefix='inc_c1') 496 | # self.inc_c1.add(nn.Conv2D(channels=channels[3], kernel_size=1, strides=1, 497 | # padding=0, use_bias=False)) 498 | # self.inc_c1.add(norm_layer(in_channels=channels[-1])) 499 | # self.inc_c1.add(nn.Activation('relu')) 500 | # else: 501 | # raise ValueError("unknown pyramid_mode") 502 | 503 | def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0, 504 | norm_layer=BatchNorm, norm_kwargs=None): 505 | layer = nn.HybridSequential(prefix='stage%d_'%stage_index) 506 | with layer.name_scope(): 507 | downsample = (channels != in_channels) or (stride != 1) 508 | layer.add(block(channels, stride, downsample, in_channels=in_channels, 509 | prefix='', norm_layer=norm_layer, norm_kwargs=norm_kwargs)) 510 | for _ in range(layers-1): 511 | layer.add(block(channels, 1, False, in_channels=channels, prefix='', 512 | norm_layer=norm_layer, norm_kwargs=norm_kwargs)) 513 | return layer 514 | 515 | def hybrid_forward(self, F, x): 516 | 517 | _, _, orig_hei, orig_wid = x.shape 518 | x = self.stem(x) # sub 2 519 | c1 = self.layer1(x) # sub 2 520 | _, _, c1_hei, c1_wid = c1.shape 521 | c2 = self.layer2(c1) # sub 4 522 | _, _, c2_hei, c2_wid = c2.shape 523 | c3 = self.layer3(c2) # sub 8 524 | _, _, c3_hei, c3_wid = c3.shape 525 | 526 | # 1. upsampling(c3) -> c3PCM # size: sub 4 527 | 528 | # c3 -> c3PCM 529 | # 2. pwconv(c2) -> c2PCM # size: sub 4 530 | # 3. upsampling(c3PCM + c2PCM) # size: sub 2 531 | # 4. pwconv(c1) -> c1PCM # size: sub 2 532 | # 5. upsampling(upsampling(c3PCM + c2PCM)) + c1PCM 533 | # 6. 
upsampling(upsampling(c3PCM + c2PCM)) + c1PCM 534 | 535 | if self.pyramid_mode == 'Dec': 536 | if self.scale_mode == 'Single': 537 | c3pcm = cal_pcm(c3, shift=self.shift) # sub 8, 64 538 | elif self.scale_mode == 'Multiple': 539 | c3pcm = self.cal_mpcm(c3) # sub 8, 64 540 | elif self.scale_mode == 'biglobal': 541 | c3pcm = self.fuse_mpcm_c3(c3) # sub 8, 64 542 | elif self.scale_mode == 'bilocal': 543 | c3pcm = self.fuse_mpcm_c3(c3) # sub 8, 64 544 | elif self.scale_mode == 'add': 545 | c3pcm = self.fuse_mpcm_c3(c3) # sub 8, 64 546 | elif self.scale_mode == 'globalsk': 547 | c3pcm = self.fuse_mpcm_c3(c3) # sub 8, 64 548 | elif self.scale_mode == 'localsk': 549 | c3pcm = self.fuse_mpcm_c3(c3) # sub 8, 64 550 | else: 551 | raise ValueError("unknown self.scale_mode") 552 | c3pcm = self.dec_c3(c3pcm) # sub 8, 16 553 | up_c3pcm = F.contrib.BilinearResize2D(c3pcm, height=c2_hei, width=c2_wid) # sub 4, 16 554 | 555 | if self.scale_mode == 'Single': 556 | c2pcm = cal_pcm(c2, shift=self.shift) # sub 4, 32 557 | elif self.scale_mode == 'Multiple': 558 | c2pcm = self.cal_mpcm(c2) # sub 4, 32 559 | elif self.scale_mode == 'biglobal': 560 | c2pcm = self.fuse_mpcm_c2(c2) # sub 4, 32 561 | elif self.scale_mode == 'bilocal': 562 | c2pcm = self.fuse_mpcm_c2(c2) # sub 4, 32 563 | elif self.scale_mode == 'add': 564 | c2pcm = self.fuse_mpcm_c2(c2) # sub 4, 32 565 | elif self.scale_mode == 'globalsk': 566 | c2pcm = self.fuse_mpcm_c2(c2) # sub 4, 32 567 | elif self.scale_mode == 'localsk': 568 | c2pcm = self.fuse_mpcm_c2(c2) # sub 4, 32 569 | else: 570 | raise ValueError("unknown self.scale_mode") 571 | c2pcm = self.dec_c2(c2pcm) # sub 4, 16 572 | 573 | if self.pyramid_fuse == 'add': 574 | c23pcm = up_c3pcm + c2pcm # sub 4, 16 575 | elif self.pyramid_fuse == 'max': 576 | c23pcm = nd.maximum(up_c3pcm, c2pcm) # sub 4, 16 577 | elif self.pyramid_fuse == 'bilocal': 578 | c23pcm = self.bilocal_fpn_2(up_c3pcm, c2pcm) 579 | elif self.pyramid_fuse == 'biglobal': 580 | c23pcm = self.biglobal_fpn_2(up_c3pcm, c2pcm) 581 | elif self.pyramid_fuse == 'globalsk': 582 | c23pcm = self.globalsk_fpn_2(up_c3pcm, c2pcm) 583 | elif self.pyramid_fuse == 'localsk': 584 | c23pcm = self.localsk_fpn_2(up_c3pcm, c2pcm) 585 | elif self.pyramid_fuse == 'mutualglobalsk': 586 | c23pcm = self.mutualglobalsk_fpn_2(up_c3pcm, c2pcm) 587 | elif self.pyramid_fuse == 'mutuallocalsk': 588 | c23pcm = self.mutuallocalsk_fpn_2(up_c3pcm, c2pcm) 589 | elif self.pyramid_fuse == 'remo': 590 | c23pcm = self.remo_fpn_2(up_c3pcm, c2pcm) 591 | elif self.pyramid_fuse == 'localremo': 592 | c23pcm = self.localremo_fpn_2(up_c3pcm, c2pcm) 593 | elif self.pyramid_fuse == 'asymbi': 594 | c23pcm = self.asymbi_fpn_2(up_c3pcm, c2pcm) 595 | elif self.pyramid_fuse == 'topdownlocal': 596 | c23pcm = self.topdownlocal_fpn_2(up_c3pcm, c2pcm) 597 | elif self.pyramid_fuse == 'bottomuplocal': 598 | c23pcm = self.bottomuplocal_fpn_2(up_c3pcm, c2pcm) 599 | elif self.pyramid_fuse == 'bottomupglobal': 600 | c23pcm = self.bottomupglobal_fpn_2(up_c3pcm, c2pcm) 601 | else: 602 | raise ValueError("unknown self.pyramid_fuse") 603 | 604 | up_c23pcm = F.contrib.BilinearResize2D(c23pcm, height=c1_hei, width=c1_wid) # sub 2, 16 605 | 606 | if self.scale_mode == 'Single': 607 | c1pcm = cal_pcm(c1, shift=self.shift) # sub 2, 16 608 | elif self.scale_mode == 'Multiple': 609 | c1pcm = self.cal_mpcm(c1) # sub 2, 16 610 | elif self.scale_mode == 'biglobal': 611 | c1pcm = self.fuse_mpcm_c1(c1) # sub 2, 16 612 | elif self.scale_mode == 'bilocal': 613 | c1pcm = self.fuse_mpcm_c1(c1) # sub 2, 16 614 | elif self.scale_mode == 'add': 615 | c1pcm = self.fuse_mpcm_c1(c1) # sub 2, 16 616 | elif self.scale_mode == 'globalsk': 617 | c1pcm = self.fuse_mpcm_c1(c1) # sub 2, 16 618 | elif self.scale_mode == 'localsk': 619 | c1pcm = self.fuse_mpcm_c1(c1) # sub 2, 16 620 | else: 621 | raise ValueError("unknown self.scale_mode") 622 | 623 | if self.pyramid_fuse == 'add': 624 | out = up_c23pcm + c1pcm 625 | elif self.pyramid_fuse == 'max': 626 | out = nd.maximum(up_c23pcm, c1pcm) 627 | elif self.pyramid_fuse == 'bilocal': 628 | out = self.bilocal_fpn_1(up_c23pcm, c1pcm) 629 | elif self.pyramid_fuse == 'biglobal': 630 | out = self.biglobal_fpn_1(up_c23pcm, c1pcm) 631 | elif self.pyramid_fuse == 'globalsk': 632 | out = self.globalsk_fpn_1(up_c23pcm, c1pcm) 633 | elif self.pyramid_fuse == 'localsk': 634 | out = self.localsk_fpn_1(up_c23pcm, c1pcm) 635 | elif self.pyramid_fuse == 'mutualglobalsk': 636 | out = self.mutualglobalsk_fpn_1(up_c23pcm, c1pcm) 637 | elif self.pyramid_fuse == 'mutuallocalsk': 638 | out = self.mutuallocalsk_fpn_1(up_c23pcm, c1pcm) 639 | elif self.pyramid_fuse == 'remo': 640 | out = self.remo_fpn_1(up_c23pcm, c1pcm) 641 | elif self.pyramid_fuse == 'localremo': 642 | out = self.localremo_fpn_1(up_c23pcm, c1pcm) 643 | elif self.pyramid_fuse == 'asymbi': 644 | out = self.asymbi_fpn_1(up_c23pcm, c1pcm) 645 | elif self.pyramid_fuse == 'topdownlocal': 646 | out = self.topdownlocal_fpn_1(up_c23pcm, c1pcm) 647 | elif self.pyramid_fuse == 'bottomuplocal': 648 | out = self.bottomuplocal_fpn_1(up_c23pcm, c1pcm) 649 | elif self.pyramid_fuse == 'bottomupglobal': 650 | out = self.bottomupglobal_fpn_1(up_c23pcm, c1pcm) 651 | else: 652 | raise ValueError("unknown self.pyramid_fuse") 653 | 654 | elif self.pyramid_mode == 'Inc': 655 | 656 | c3pcm = cal_pcm(c3, shift=self.shift) 657 | up_c3pcm = F.contrib.BilinearResize2D(c3pcm, height=c2_hei, width=c2_wid) # sub 4, 64 658 | 659 | inc_c2 = self.inc_c2(c2) # sub 4, 64 660 | c2pcm = cal_pcm(inc_c2, shift=self.shift) 661 | 662 | c23pcm = up_c3pcm + c2pcm # sub 4, 64 663 | 664 | up_c23pcm = F.contrib.BilinearResize2D(c23pcm, height=c1_hei, width=c1_wid) # sub 2, 64 665 | inc_c1 = self.inc_c1(c1) # sub 2, 64 666 | c1pcm = cal_pcm(inc_c1, shift=self.shift) 667 | 668 | out = up_c23pcm + c1pcm # sub 2, 64 669 | 670 | pred = self.head(out) 671 | out = F.contrib.BilinearResize2D(pred, height=orig_hei, width=orig_wid) 672 | 673 | return out 674 | 675 | def evaluate(self, x): 676 | """evaluating network with inputs and targets""" 677 | return self.forward(x) 678 | 679 | def cal_mpcm(self, cen): 680 | # pcm11 = cal_pcm(cen, shift=11) 681 | pcm13 = cal_pcm(cen, shift=13) 682 | pcm17 = cal_pcm(cen, shift=17) 683 | mpcm = nd.maximum(pcm13, pcm17) 684 | # mpcm = nd.maximum(pcm11, nd.maximum(pcm13, pcm17)) 685 | 686 | return mpcm 687 | 688 | 689 | #### MPCM Fuse 690 | 691 | 692 | class Add_MPCMFuse(HybridBlock): 693 | def __init__(self, channels=64): 694 | super(Add_MPCMFuse, self).__init__() 695 | 696 | with self.name_scope(): 697 | 698 | self.bn1 = nn.BatchNorm() 699 | self.bn2 = nn.BatchNorm() 700 | 701 | def hybrid_forward(self, F, cen): 702 | 703 | pcm13 = cal_pcm(cen, shift=13) 704 | pcm17 = cal_pcm(cen, shift=17) 705 | 706 | pcm13 = self.bn1(pcm13) 707 | pcm17 = self.bn2(pcm17) 708 | 709 | xo = pcm13 + pcm17 710 | 711 | return xo 712 | 713 | 714 | class LocalSK_MPCMFuse(HybridBlock): 715 | def __init__(self, channels=64): 716 | super(LocalSK_MPCMFuse, self).__init__() 717 | inter_channels = int(channels // 2) 718 | 719 | with self.name_scope(): 720 | 721 | 
self.local_att = nn.HybridSequential(prefix='local_att') 722 | self.local_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 723 | self.local_att.add(nn.BatchNorm()) 724 | self.local_att.add(nn.Activation('relu')) 725 | self.local_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 726 | self.local_att.add(nn.BatchNorm()) 727 | self.local_att.add(nn.Activation('sigmoid')) 728 | 729 | def hybrid_forward(self, F, cen): 730 | 731 | pcm13 = cal_pcm(cen, shift=13) 732 | pcm17 = cal_pcm(cen, shift=17) 733 | 734 | xa = pcm13 + pcm17 735 | # xa = cen 736 | wei = self.local_att(xa) 737 | 738 | xo = 2 * F.broadcast_mul(pcm13, wei) + 2 * F.broadcast_mul(pcm17, 1-wei) 739 | 740 | return xo 741 | 742 | 743 | class GlobalSK_MPCMFuse(HybridBlock): 744 | def __init__(self, channels=64): 745 | super(GlobalSK_MPCMFuse, self).__init__() 746 | inter_channels = int(channels // 2) 747 | 748 | with self.name_scope(): 749 | 750 | self.global_att = nn.HybridSequential(prefix='global_att') 751 | self.global_att.add(nn.GlobalAvgPool2D()) 752 | self.global_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 753 | self.global_att.add(nn.BatchNorm()) 754 | self.global_att.add(nn.Activation('relu')) 755 | self.global_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 756 | self.global_att.add(nn.BatchNorm()) 757 | self.global_att.add(nn.Activation('sigmoid')) 758 | 759 | def hybrid_forward(self, F, cen): 760 | 761 | pcm13 = cal_pcm(cen, shift=13) 762 | pcm17 = cal_pcm(cen, shift=17) 763 | 764 | xa = pcm13 + pcm17 765 | wei = self.global_att(xa) 766 | 767 | xo = 2 * F.broadcast_mul(pcm13, wei) + 2 * F.broadcast_mul(pcm17, 1-wei) 768 | 769 | return xo 770 | 771 | 772 | class BiLocal_MPCMFuse(HybridBlock): 773 | def __init__(self, channels=64): 774 | super(BiLocal_MPCMFuse, self).__init__() 775 | inter_channels = int(channels // 2) 776 | 777 | with self.name_scope(): 778 | 779 | self.bn1 = nn.BatchNorm() 780 | self.bn2 = nn.BatchNorm() 781 | 782 | self.topdown_att = nn.HybridSequential(prefix='topdown_att') 783 | self.topdown_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 784 | self.topdown_att.add(nn.BatchNorm()) 785 | self.topdown_att.add(nn.Activation('relu')) 786 | self.topdown_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 787 | self.topdown_att.add(nn.BatchNorm()) 788 | self.topdown_att.add(nn.Activation('sigmoid')) 789 | 790 | self.bottomup_att = nn.HybridSequential(prefix='bottomup_att') 791 | self.bottomup_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 792 | self.bottomup_att.add(nn.BatchNorm()) 793 | self.bottomup_att.add(nn.Activation('relu')) 794 | self.bottomup_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 795 | self.bottomup_att.add(nn.BatchNorm()) 796 | self.bottomup_att.add(nn.Activation('sigmoid')) 797 | 798 | def hybrid_forward(self, F, cen): 799 | 800 | pcm13 = cal_pcm(cen, shift=13) 801 | pcm17 = cal_pcm(cen, shift=17) 802 | 803 | pcm13 = self.bn1(pcm13) 804 | pcm17 = self.bn2(pcm17) 805 | 806 | topdown_wei = self.topdown_att(pcm17) 807 | bottomup_wei = self.bottomup_att(pcm13) 808 | 809 | xo = F.broadcast_mul(topdown_wei, pcm13) + F.broadcast_mul(bottomup_wei, pcm17) 810 | 811 | return xo 812 | 813 | 814 | class BiGlobal_MPCMFuse(HybridBlock): 815 | def __init__(self, channels=64): 816 | super(BiGlobal_MPCMFuse, self).__init__() 817 | inter_channels = int(channels // 2) 818 | 819 | with self.name_scope(): 820 | 821 | self.bn1 = nn.BatchNorm() 822 
| self.bn2 = nn.BatchNorm() 823 | 824 | self.topdown_att = nn.HybridSequential(prefix='topdown_att') 825 | self.topdown_att.add(nn.GlobalAvgPool2D()) 826 | self.topdown_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 827 | self.topdown_att.add(nn.BatchNorm()) 828 | self.topdown_att.add(nn.Activation('relu')) 829 | self.topdown_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 830 | self.topdown_att.add(nn.BatchNorm()) 831 | self.topdown_att.add(nn.Activation('sigmoid')) 832 | 833 | self.bottomup_att = nn.HybridSequential(prefix='bottomup_att') 834 | self.bottomup_att.add(nn.GlobalAvgPool2D()) 835 | self.bottomup_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 836 | self.bottomup_att.add(nn.BatchNorm()) 837 | self.bottomup_att.add(nn.Activation('relu')) 838 | self.bottomup_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 839 | self.bottomup_att.add(nn.BatchNorm()) 840 | self.bottomup_att.add(nn.Activation('sigmoid')) 841 | 842 | def hybrid_forward(self, F, cen): 843 | 844 | pcm13 = cal_pcm(cen, shift=13) 845 | pcm17 = cal_pcm(cen, shift=17) 846 | 847 | pcm13 = self.bn1(pcm13) 848 | pcm17 = self.bn2(pcm17) 849 | 850 | topdown_wei = self.topdown_att(pcm17) 851 | bottomup_wei = self.bottomup_att(pcm13) 852 | 853 | xo = F.broadcast_mul(topdown_wei, pcm13) + F.broadcast_mul(bottomup_wei, pcm17) 854 | 855 | return xo 856 | 857 | 858 | ### FPN Fuse 859 | 860 | class BiGlobal_FPNFuse(HybridBlock): 861 | def __init__(self, channels=64): 862 | super(BiGlobal_FPNFuse, self).__init__() 863 | inter_channels = int(channels // 2) 864 | 865 | with self.name_scope(): 866 | 867 | self.bn1 = nn.BatchNorm() 868 | self.bn2 = nn.BatchNorm() 869 | 870 | self.topdown_att = nn.HybridSequential(prefix='topdown_att') 871 | self.topdown_att.add(nn.GlobalAvgPool2D()) 872 | self.topdown_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 873 | self.topdown_att.add(nn.BatchNorm()) 874 | self.topdown_att.add(nn.Activation('relu')) 875 | self.topdown_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 876 | self.topdown_att.add(nn.BatchNorm()) 877 | self.topdown_att.add(nn.Activation('sigmoid')) 878 | 879 | self.bottomup_att = nn.HybridSequential(prefix='bottomup_att') 880 | self.bottomup_att.add(nn.GlobalAvgPool2D()) 881 | self.bottomup_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 882 | self.bottomup_att.add(nn.BatchNorm()) 883 | self.bottomup_att.add(nn.Activation('relu')) 884 | self.bottomup_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 885 | self.bottomup_att.add(nn.BatchNorm()) 886 | self.bottomup_att.add(nn.Activation('sigmoid')) 887 | 888 | def hybrid_forward(self, F, x, residual): 889 | 890 | x = self.bn1(x) 891 | residual = self.bn2(residual) 892 | 893 | topdown_wei = self.topdown_att(x) 894 | bottomup_wei = self.bottomup_att(residual) 895 | 896 | xo = F.broadcast_mul(topdown_wei, residual) + F.broadcast_mul(bottomup_wei, x) 897 | 898 | return xo 899 | 900 | 901 | class BiLocal_FPNFuse(HybridBlock): 902 | def __init__(self, channels=64): 903 | super(BiLocal_FPNFuse, self).__init__() 904 | inter_channels = int(channels // 2) 905 | 906 | with self.name_scope(): 907 | 908 | self.bn1 = nn.BatchNorm() 909 | self.bn2 = nn.BatchNorm() 910 | 911 | self.topdown_att = nn.HybridSequential(prefix='topdown_att') 912 | self.topdown_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 913 | self.topdown_att.add(nn.BatchNorm()) 914 | 
self.topdown_att.add(nn.Activation('relu')) 915 | self.topdown_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 916 | self.topdown_att.add(nn.BatchNorm()) 917 | self.topdown_att.add(nn.Activation('sigmoid')) 918 | 919 | self.bottomup_att = nn.HybridSequential(prefix='bottomup_att') 920 | self.bottomup_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 921 | self.bottomup_att.add(nn.BatchNorm()) 922 | self.bottomup_att.add(nn.Activation('relu')) 923 | self.bottomup_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 924 | self.bottomup_att.add(nn.BatchNorm()) 925 | self.bottomup_att.add(nn.Activation('sigmoid')) 926 | 927 | def hybrid_forward(self, F, x, residual): 928 | 929 | x = self.bn1(x) 930 | residual = self.bn2(residual) 931 | 932 | topdown_wei = self.topdown_att(x) 933 | bottomup_wei = self.bottomup_att(residual) 934 | 935 | xo = F.broadcast_mul(topdown_wei, residual) + F.broadcast_mul(bottomup_wei, x) 936 | 937 | return xo 938 | 939 | 940 | class AsymBi_FPNFuse(HybridBlock): 941 | def __init__(self, channels=64): 942 | super(AsymBi_FPNFuse, self).__init__() 943 | inter_channels = int(channels // 2) 944 | 945 | with self.name_scope(): 946 | 947 | self.bn1 = nn.BatchNorm() 948 | self.bn2 = nn.BatchNorm() 949 | 950 | self.topdown_att = nn.HybridSequential(prefix='topdown_att') 951 | self.topdown_att.add(nn.GlobalAvgPool2D()) 952 | self.topdown_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 953 | self.topdown_att.add(nn.BatchNorm()) 954 | self.topdown_att.add(nn.Activation('relu')) 955 | self.topdown_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 956 | self.topdown_att.add(nn.BatchNorm()) 957 | self.topdown_att.add(nn.Activation('sigmoid')) 958 | 959 | self.bottomup_att = nn.HybridSequential(prefix='bottomup_att') 960 | self.bottomup_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 961 | self.bottomup_att.add(nn.BatchNorm()) 962 | self.bottomup_att.add(nn.Activation('relu')) 963 | self.bottomup_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 964 | self.bottomup_att.add(nn.BatchNorm()) 965 | self.bottomup_att.add(nn.Activation('sigmoid')) 966 | 967 | def hybrid_forward(self, F, x, residual): 968 | 969 | x = self.bn1(x) 970 | residual = self.bn2(residual) 971 | 972 | topdown_wei = self.topdown_att(x) 973 | bottomup_wei = self.bottomup_att(residual) 974 | 975 | xo = F.broadcast_mul(topdown_wei, residual) + F.broadcast_mul(bottomup_wei, x) 976 | 977 | return xo 978 | 979 | 980 | class BottomUpLocal_FPNFuse(HybridBlock): 981 | def __init__(self, channels=64): 982 | super(BottomUpLocal_FPNFuse, self).__init__() 983 | inter_channels = int(channels // 1) 984 | 985 | with self.name_scope(): 986 | 987 | self.bn1 = nn.BatchNorm() 988 | self.bn2 = nn.BatchNorm() 989 | 990 | self.bottomup_att = nn.HybridSequential(prefix='bottomup_att') 991 | self.bottomup_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 992 | self.bottomup_att.add(nn.BatchNorm()) 993 | self.bottomup_att.add(nn.Activation('relu')) 994 | self.bottomup_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 995 | self.bottomup_att.add(nn.BatchNorm()) 996 | self.bottomup_att.add(nn.Activation('sigmoid')) 997 | 998 | def hybrid_forward(self, F, x, residual): 999 | 1000 | x = self.bn1(x) 1001 | residual = self.bn2(residual) 1002 | 1003 | bottomup_wei = self.bottomup_att(residual) 1004 | 1005 | xo = F.broadcast_mul(bottomup_wei, x) + residual 1006 | 1007 | return xo 1008 
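# A minimal usage sketch for BottomUpLocal_FPNFuse above, assuming MXNet with Gluon is
# installed; the shapes below are illustrative examples only, and both inputs must share
# the same (N, C, H, W) shape with C equal to `channels`:
#
#   from mxnet import nd
#   fuse = BottomUpLocal_FPNFuse(channels=16)
#   fuse.initialize()
#   x = nd.random.uniform(shape=(2, 16, 120, 120))         # upsampled deeper feature
#   residual = nd.random.uniform(shape=(2, 16, 120, 120))  # shallower lateral feature
#   out = fuse(x, residual)  # gate computed from `residual` modulates `x`, then `residual` is added back
#   assert out.shape == x.shape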
| 1009 | 1010 | class BottomUpGlobal_FPNFuse(HybridBlock): 1011 | def __init__(self, channels=64): 1012 | super(BottomUpGlobal_FPNFuse, self).__init__() 1013 | inter_channels = int(channels // 1) 1014 | 1015 | with self.name_scope(): 1016 | 1017 | self.bn1 = nn.BatchNorm() 1018 | self.bn2 = nn.BatchNorm() 1019 | 1020 | self.bottomup_att = nn.HybridSequential(prefix='bottomup_att') 1021 | self.bottomup_att.add(nn.GlobalAvgPool2D()) 1022 | self.bottomup_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 1023 | self.bottomup_att.add(nn.BatchNorm()) 1024 | self.bottomup_att.add(nn.Activation('relu')) 1025 | self.bottomup_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 1026 | self.bottomup_att.add(nn.BatchNorm()) 1027 | self.bottomup_att.add(nn.Activation('sigmoid')) 1028 | 1029 | def hybrid_forward(self, F, x, residual): 1030 | 1031 | x = self.bn1(x) 1032 | residual = self.bn2(residual) 1033 | 1034 | bottomup_wei = self.bottomup_att(residual) 1035 | 1036 | xo = F.broadcast_mul(bottomup_wei, x) + residual 1037 | 1038 | return xo 1039 | 1040 | 1041 | 1042 | class TopDownLocal_FPNFuse(HybridBlock): 1043 | def __init__(self, channels=64): 1044 | super(TopDownLocal_FPNFuse, self).__init__() 1045 | inter_channels = int(channels // 1) 1046 | 1047 | with self.name_scope(): 1048 | 1049 | self.bn1 = nn.BatchNorm() 1050 | self.bn2 = nn.BatchNorm() 1051 | 1052 | self.topdown_att = nn.HybridSequential(prefix='topdown_att') 1053 | self.topdown_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 1054 | self.topdown_att.add(nn.BatchNorm()) 1055 | self.topdown_att.add(nn.Activation('relu')) 1056 | self.topdown_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 1057 | self.topdown_att.add(nn.BatchNorm()) 1058 | self.topdown_att.add(nn.Activation('sigmoid')) 1059 | 1060 | def hybrid_forward(self, F, x, residual): 1061 | 1062 | x = self.bn1(x) 1063 | residual = self.bn2(residual) 1064 | topdown_wei = self.topdown_att(x) 1065 | 1066 | xo = x + F.broadcast_mul(topdown_wei, residual) 1067 | 1068 | return xo 1069 | 1070 | 1071 | class GlobalSK_FPNFuse(HybridBlock): 1072 | def __init__(self, channels=64, r=2): 1073 | super(GlobalSK_FPNFuse, self).__init__() 1074 | inter_channels = int(channels // r) 1075 | 1076 | with self.name_scope(): 1077 | 1078 | self.bn1 = nn.BatchNorm() 1079 | self.bn2 = nn.BatchNorm() 1080 | 1081 | self.global_att = nn.HybridSequential(prefix='global_att') 1082 | self.global_att.add(nn.GlobalAvgPool2D()) 1083 | self.global_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 1084 | self.global_att.add(nn.BatchNorm()) 1085 | self.global_att.add(nn.Activation('relu')) 1086 | self.global_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 1087 | self.global_att.add(nn.BatchNorm()) 1088 | self.global_att.add(nn.Activation('sigmoid')) 1089 | 1090 | def hybrid_forward(self, F, x, residual): 1091 | 1092 | x = self.bn1(x) 1093 | residual = self.bn2(residual) 1094 | xa = x + residual 1095 | 1096 | wei = self.global_att(xa) 1097 | 1098 | xo = F.broadcast_mul(x, wei) + F.broadcast_mul(residual, 1-wei) 1099 | 1100 | return xo 1101 | 1102 | 1103 | class MutualSKLocal_FPNFuse(HybridBlock): 1104 | def __init__(self, channels=64): 1105 | super(MutualSKLocal_FPNFuse, self).__init__() 1106 | inter_channels = int(channels // 2) 1107 | 1108 | with self.name_scope(): 1109 | 1110 | self.bn1 = nn.BatchNorm() 1111 | self.bn2 = nn.BatchNorm() 1112 | 1113 | self.topdown = nn.HybridSequential(prefix='topdown') 1114 
| self.topdown.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 1115 | self.topdown.add(nn.BatchNorm()) 1116 | self.topdown.add(nn.Activation('relu')) 1117 | self.topdown.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 1118 | self.topdown.add(nn.BatchNorm()) 1119 | self.topdown.add(nn.Activation('sigmoid')) 1120 | 1121 | self.bottomup = nn.HybridSequential(prefix='bottomup') 1122 | self.bottomup.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 1123 | self.bottomup.add(nn.BatchNorm()) 1124 | self.bottomup.add(nn.Activation('relu')) 1125 | self.bottomup.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 1126 | self.bottomup.add(nn.BatchNorm()) 1127 | self.bottomup.add(nn.Activation('sigmoid')) 1128 | 1129 | def hybrid_forward(self, F, x, residual): 1130 | 1131 | x = self.bn1(x) 1132 | residual = self.bn2(residual) 1133 | xa = x + residual 1134 | 1135 | topdown_wei = self.topdown(xa) 1136 | bottomup_wei = self.bottomup(xa) 1137 | 1138 | xo = F.broadcast_mul(x, topdown_wei) + F.broadcast_mul(residual, bottomup_wei) 1139 | 1140 | return xo 1141 | 1142 | 1143 | class MutualSKGlobal_FPNFuse(HybridBlock): 1144 | def __init__(self, channels=64): 1145 | super(MutualSKGlobal_FPNFuse, self).__init__() 1146 | inter_channels = int(channels // 2) 1147 | 1148 | with self.name_scope(): 1149 | 1150 | self.bn1 = nn.BatchNorm() 1151 | self.bn2 = nn.BatchNorm() 1152 | 1153 | self.topdown = nn.HybridSequential(prefix='topdown') 1154 | self.topdown.add(nn.GlobalAvgPool2D()) 1155 | self.topdown.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 1156 | self.topdown.add(nn.BatchNorm()) 1157 | self.topdown.add(nn.Activation('relu')) 1158 | self.topdown.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 1159 | self.topdown.add(nn.BatchNorm()) 1160 | self.topdown.add(nn.Activation('sigmoid')) 1161 | 1162 | self.bottomup = nn.HybridSequential(prefix='bottomup') 1163 | self.bottomup.add(nn.GlobalAvgPool2D()) 1164 | self.bottomup.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 1165 | self.bottomup.add(nn.BatchNorm()) 1166 | self.bottomup.add(nn.Activation('relu')) 1167 | self.bottomup.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 1168 | self.bottomup.add(nn.BatchNorm()) 1169 | self.bottomup.add(nn.Activation('sigmoid')) 1170 | 1171 | def hybrid_forward(self, F, x, residual): 1172 | 1173 | x = self.bn1(x) 1174 | residual = self.bn2(residual) 1175 | xa = x + residual 1176 | 1177 | topdown_wei = self.topdown(xa) 1178 | bottomup_wei = self.bottomup(xa) 1179 | 1180 | xo = F.broadcast_mul(x, topdown_wei) + F.broadcast_mul(residual, bottomup_wei) 1181 | 1182 | return xo 1183 | 1184 | 1185 | 1186 | class LocalSK_FPNFuse(HybridBlock): 1187 | def __init__(self, channels=64, r=2): 1188 | super(LocalSK_FPNFuse, self).__init__() 1189 | inter_channels = int(channels // r) 1190 | 1191 | with self.name_scope(): 1192 | 1193 | self.bn1 = nn.BatchNorm() 1194 | self.bn2 = nn.BatchNorm() 1195 | 1196 | self.global_att = nn.HybridSequential(prefix='global_att') 1197 | self.global_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 1198 | self.global_att.add(nn.BatchNorm()) 1199 | self.global_att.add(nn.Activation('relu')) 1200 | self.global_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 1201 | self.global_att.add(nn.BatchNorm()) 1202 | self.global_att.add(nn.Activation('sigmoid')) 1203 | 1204 | def hybrid_forward(self, F, x, residual): 1205 | 1206 | x = self.bn1(x) 1207 | residual 
= self.bn2(residual) 1208 | xa = x + residual 1209 | 1210 | wei = self.global_att(xa) 1211 | 1212 | xo = F.broadcast_mul(x, wei) + F.broadcast_mul(residual, 1-wei) 1213 | 1214 | return xo 1215 | 1216 | 1217 | 1218 | class ReMo_FPNFuse(HybridBlock): 1219 | def __init__(self, channels=64, r=2): 1220 | super(ReMo_FPNFuse, self).__init__() 1221 | inter_channels = int(channels // r) 1222 | 1223 | with self.name_scope(): 1224 | 1225 | self.global_att = nn.HybridSequential(prefix='global_att') 1226 | self.global_att.add(nn.GlobalAvgPool2D()) 1227 | self.global_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 1228 | self.global_att.add(nn.BatchNorm()) 1229 | self.global_att.add(nn.Activation('relu')) 1230 | self.global_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 1231 | self.global_att.add(nn.BatchNorm()) 1232 | self.global_att.add(nn.Activation('sigmoid')) 1233 | 1234 | self.local_att = nn.HybridSequential(prefix='local_att') 1235 | self.local_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 1236 | self.local_att.add(nn.BatchNorm()) 1237 | self.local_att.add(nn.Activation('relu')) 1238 | self.local_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 1239 | self.local_att.add(nn.BatchNorm()) 1240 | self.local_att.add(nn.Activation('sigmoid')) 1241 | 1242 | def hybrid_forward(self, F, x, residual): 1243 | 1244 | global_wei = self.global_att(x) 1245 | local_wei = self.local_att(residual) 1246 | 1247 | wei = F.broadcast_add(global_wei, local_wei) 1248 | xo = F.broadcast_mul(x + residual, wei) 1249 | 1250 | return xo 1251 | 1252 | 1253 | 1254 | class LocalReMo_FPNFuse(HybridBlock): 1255 | def __init__(self, channels=64, r=2): 1256 | super(LocalReMo_FPNFuse, self).__init__() 1257 | inter_channels = int(channels // r) 1258 | 1259 | with self.name_scope(): 1260 | 1261 | self.global_att = nn.HybridSequential(prefix='global_att') 1262 | self.global_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 1263 | self.global_att.add(nn.BatchNorm()) 1264 | self.global_att.add(nn.Activation('relu')) 1265 | self.global_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 1266 | self.global_att.add(nn.BatchNorm()) 1267 | self.global_att.add(nn.Activation('sigmoid')) 1268 | 1269 | self.local_att = nn.HybridSequential(prefix='local_att') 1270 | self.local_att.add(nn.Conv2D(inter_channels, kernel_size=1, strides=1, padding=0)) 1271 | self.local_att.add(nn.BatchNorm()) 1272 | self.local_att.add(nn.Activation('relu')) 1273 | self.local_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) 1274 | self.local_att.add(nn.BatchNorm()) 1275 | self.local_att.add(nn.Activation('sigmoid')) 1276 | 1277 | def hybrid_forward(self, F, x, residual): 1278 | 1279 | global_wei = self.global_att(x) 1280 | local_wei = self.local_att(residual) 1281 | 1282 | wei = F.broadcast_add(global_wei, local_wei) 1283 | xo = F.broadcast_mul(x + residual, wei) 1284 | 1285 | return xo 1286 | 1287 | 1288 | 1289 | class ResNetFCN(HybridBlock): 1290 | def __init__(self, layers, channels, classes=1, norm_layer=BatchNorm, norm_kwargs=None, 1291 | **kwargs): 1292 | super(ResNetFCN, self).__init__(**kwargs) 1293 | 1294 | self.layer_num = len(layers) 1295 | with self.name_scope(): 1296 | 1297 | stem_width = int(channels[0]) 1298 | self.stem = nn.HybridSequential(prefix='stem') 1299 | self.stem.add(norm_layer(scale=False, center=False, 1300 | **({} if norm_kwargs is None else norm_kwargs))) 1301 | 
class ResNetFCN(HybridBlock):
    def __init__(self, layers, channels, classes=1, norm_layer=BatchNorm, norm_kwargs=None,
                 **kwargs):
        super(ResNetFCN, self).__init__(**kwargs)

        self.layer_num = len(layers)
        with self.name_scope():

            stem_width = int(channels[0])
            self.stem = nn.HybridSequential(prefix='stem')
            self.stem.add(norm_layer(scale=False, center=False,
                                     **({} if norm_kwargs is None else norm_kwargs)))
            self.stem.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=1,
                                    padding=1, use_bias=False))
            self.stem.add(norm_layer(in_channels=stem_width*2))
            self.stem.add(nn.Activation('relu'))

            self.head = _FCNHead(in_channels=channels[-1], channels=classes)

            self.layer1 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[0],
                                           channels=channels[1], stride=1, stage_index=1,
                                           in_channels=channels[1])

            self.layer2 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[1],
                                           channels=channels[2], stride=2, stage_index=2,
                                           in_channels=channels[1])

            self.layer3 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[2],
                                           channels=channels[3], stride=2, stage_index=3,
                                           in_channels=channels[2])

    def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0,
                    norm_layer=BatchNorm, norm_kwargs=None):
        layer = nn.HybridSequential(prefix='stage%d_' % stage_index)
        with layer.name_scope():
            # downsample whenever the channel count or the spatial stride changes
            downsample = (channels != in_channels) or (stride != 1)
            layer.add(block(channels, stride, downsample, in_channels=in_channels,
                            prefix='', norm_layer=norm_layer, norm_kwargs=norm_kwargs))
            for _ in range(layers - 1):
                layer.add(block(channels, 1, False, in_channels=channels, prefix='',
                                norm_layer=norm_layer, norm_kwargs=norm_kwargs))
        return layer

    def hybrid_forward(self, F, x):
        # NOTE: x.shape requires a concrete NDArray input, so this network
        # cannot be hybridized as written
        _, _, orig_hei, orig_wid = x.shape
        x = self.stem(x)      # full resolution (the stem conv uses stride 1)
        c1 = self.layer1(x)   # 1/1
        c2 = self.layer2(c1)  # 1/2
        c3 = self.layer3(c2)  # 1/4

        pred = self.head(c3)
        out = F.contrib.BilinearResize2D(pred, height=orig_hei, width=orig_wid)

        return out

    def evaluate(self, x):
        """Evaluate the network on a batch of inputs."""
        return self.forward(x)
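# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original file): instantiating ResNetFCN
# with a CIFAR-style ResNet-20 layout. The layers/channels values and the
# input size are assumptions for illustration; the configuration used by the
# actual training script may differ.
# ---------------------------------------------------------------------------
def _demo_resnet_fcn():
    import mxnet as mx
    from mxnet import nd

    net = ResNetFCN(layers=[3, 3, 3], channels=[8, 16, 32, 64], classes=1)
    net.initialize(ctx=mx.cpu())
    x = nd.random.uniform(shape=(1, 3, 256, 256))
    out = net(x)                          # per-pixel logits at input resolution
    assert out.shape == (1, 1, 256, 256)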
class ResNetFPN(HybridBlock):
    # NOTE: apart from the class name, this network is currently identical
    # to ResNetFCN above
    def __init__(self, layers, channels, classes=1, norm_layer=BatchNorm, norm_kwargs=None,
                 **kwargs):
        super(ResNetFPN, self).__init__(**kwargs)

        self.layer_num = len(layers)
        with self.name_scope():

            stem_width = int(channels[0])
            self.stem = nn.HybridSequential(prefix='stem')
            self.stem.add(norm_layer(scale=False, center=False,
                                     **({} if norm_kwargs is None else norm_kwargs)))
            self.stem.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=1,
                                    padding=1, use_bias=False))
            self.stem.add(norm_layer(in_channels=stem_width*2))
            self.stem.add(nn.Activation('relu'))

            self.head = _FCNHead(in_channels=channels[-1], channels=classes)

            self.layer1 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[0],
                                           channels=channels[1], stride=1, stage_index=1,
                                           in_channels=channels[1])

            self.layer2 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[1],
                                           channels=channels[2], stride=2, stage_index=2,
                                           in_channels=channels[1])

            self.layer3 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[2],
                                           channels=channels[3], stride=2, stage_index=3,
                                           in_channels=channels[2])

    def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0,
                    norm_layer=BatchNorm, norm_kwargs=None):
        layer = nn.HybridSequential(prefix='stage%d_' % stage_index)
        with layer.name_scope():
            downsample = (channels != in_channels) or (stride != 1)
            layer.add(block(channels, stride, downsample, in_channels=in_channels,
                            prefix='', norm_layer=norm_layer, norm_kwargs=norm_kwargs))
            for _ in range(layers - 1):
                layer.add(block(channels, 1, False, in_channels=channels, prefix='',
                                norm_layer=norm_layer, norm_kwargs=norm_kwargs))
        return layer

    def hybrid_forward(self, F, x):

        _, _, orig_hei, orig_wid = x.shape
        x = self.stem(x)      # full resolution (the stem conv uses stride 1)
        c1 = self.layer1(x)   # 1/1
        c2 = self.layer2(c1)  # 1/2
        c3 = self.layer3(c2)  # 1/4

        pred = self.head(c3)
        out = F.contrib.BilinearResize2D(pred, height=orig_hei, width=orig_wid)

        return out

    def evaluate(self, x):
        """Evaluate the network on a batch of inputs."""
        return self.forward(x)
--------------------------------------------------------------------------------