├── Data
│   ├── Datasets
│   │   ├── BERLIN
│   │   │   └── .gitkeep
│   │   └── EUSAR
│   │       └── .gitkeep
│   ├── Test
│   │   ├── BERLIN
│   │   │   └── .gitkeep
│   │   └── EUSAR
│   │       └── .gitkeep
│   ├── Train
│   │   ├── BERLIN
│   │   │   └── .gitkeep
│   │   └── EUSAR
│   │       └── .gitkeep
│   └── Val
│       ├── BERLIN
│       │   └── .gitkeep
│       └── EUSAR
│           └── .gitkeep
├── Docker
│   └── dockerfile
├── Docs
│   ├── arch
│   │   ├── disc.tex
│   │   ├── gen.tex
│   │   └── prova.tex
│   ├── final_presentation
│   │   └── unitn1.pdf
│   └── manuscript
│       ├── img
│       │   ├── compare.png
│       │   └── map_img.png
│       ├── manuscript.pdf
│       └── presentation.pdf
├── LICENSE
├── Lib
│   ├── Datasets
│   │   ├── EUSAR
│   │   │   ├── ConcatDataset.py
│   │   │   ├── EUSARDataset.py
│   │   │   └── __pycache__
│   │   │       ├── ConcatDataset.cpython-38.pyc
│   │   │       ├── EUSARDataset.cpython-37.pyc
│   │   │       └── EUSARDataset.cpython-38.pyc
│   │   ├── processing
│   │   │   ├── EUSAR_data_processing.py
│   │   │   ├── __pycache__
│   │   │   │   ├── EUSAR_data_processing.cpython-38.pyc
│   │   │   │   ├── utility.cpython-37.pyc
│   │   │   │   └── utility.cpython-38.pyc
│   │   │   └── utility.py
│   │   └── runner
│   │       ├── EUSAR_tile_processing.py
│   │       └── cut_in_patches.py
│   ├── Nets
│   │   ├── BL
│   │   │   ├── BL.py
│   │   │   └── __pycache__
│   │   │       └── BL.cpython-38.pyc
│   │   ├── CAT
│   │   │   └── CAT.py
│   │   ├── Cycle_AT
│   │   │   ├── Cycle_AT.py
│   │   │   └── __pycache__
│   │   │       ├── CGAN_arch.cpython-38.pyc
│   │   │       ├── DecayLR.cpython-38.pyc
│   │   │       ├── ReplyBuffer.cpython-38.pyc
│   │   │       ├── SegNet_arch.cpython-38.pyc
│   │   │       ├── TrainerCGAN.cpython-38.pyc
│   │   │       └── config_routineCGAN.cpython-38.pyc
│   │   ├── RT
│   │   │   └── RT.py
│   │   ├── SN
│   │   │   ├── SN.py
│   │   │   └── __pycache__
│   │   │       ├── SN.cpython-37.pyc
│   │   │       └── SN.cpython-38.pyc
│   │   └── utils
│   │       ├── arch
│   │       │   └── arch.py
│   │       ├── config
│   │       │   ├── __pycache__
│   │       │   │   ├── config_routine.cpython-37.pyc
│   │       │   │   ├── config_routine.cpython-38.pyc
│   │       │   │   ├── general_parser.cpython-37.pyc
│   │       │   │   ├── general_parser.cpython-38.pyc
│   │       │   │   ├── specific_parser.cpython-37.pyc
│   │       │   │   └── specific_parser.cpython-38.pyc
│   │       │   ├── config_routine.py
│   │       │   ├── general_parser.py
│   │       │   └── specific_parser.py
│   │       └── generic
│   │           ├── DecayLR.py
│   │           ├── ReplyBuffer.py
│   │           ├── __pycache__
│   │           │   ├── DecayLR.cpython-37.pyc
│   │           │   ├── DecayLR.cpython-38.pyc
│   │           │   ├── ReplyBuffer.cpython-37.pyc
│   │           │   ├── ReplyBuffer.cpython-38.pyc
│   │           │   ├── generic_training.cpython-37.pyc
│   │           │   ├── generic_training.cpython-38.pyc
│   │           │   ├── image2tensorboard.cpython-37.pyc
│   │           │   ├── image2tensorboard.cpython-38.pyc
│   │           │   ├── init_weights.cpython-37.pyc
│   │           │   ├── init_weights.cpython-38.pyc
│   │           │   ├── trainSN.cpython-37.pyc
│   │           │   └── trainSN.cpython-38.pyc
│   │           ├── generic_training.py
│   │           ├── image2tensorboard.py
│   │           ├── init_weights.py
│   │           ├── tile_creator.py
│   │           └── trainSN.py
│   └── utils
│       ├── Logger
│       │   ├── Logger.py
│       │   ├── Logger_cmp.py
│       │   └── __pycache__
│       │       ├── Logger.cpython-37.pyc
│       │       ├── Logger.cpython-38.pyc
│       │       └── Logger_cmp.cpython-38.pyc
│       ├── __pycache__
│       │   ├── generic_utils.cpython-38.pyc
│       │   ├── image2tensorboard.cpython-38.pyc
│       │   ├── init_weights.cpython-38.pyc
│       │   └── net_util.cpython-38.pyc
│       ├── generic
│       │   ├── __pycache__
│       │   │   ├── generic_utils.cpython-37.pyc
│       │   │   └── generic_utils.cpython-38.pyc
│       │   └── generic_utils.py
│       └── metrics
│           ├── Accuracy.py
│           └── __pycache__
│               ├── Accuracy.cpython-37.pyc
│               └── Accuracy.cpython-38.pyc
├── eval.py
├── for_server.sh
├── mainBL.py
├── mainCAT.py
├── mainCycle_AT.py
├── mainRT.py
├── mainSN.py
├── readme.md
└── runs
    ├── agan
    │   └── .gitkeep
    ├── bl
    │   └── .gitkeep
    ├── cgan
    │   └── .gitkeep
    ├── rgan
    │   └── .gitkeep
    ├── runs
    │   └── .gitkeep
    └── sn
        └── .gitkeep

/Data/Datasets/BERLIN/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Data/Datasets/BERLIN/.gitkeep
--------------------------------------------------------------------------------
/Data/Datasets/EUSAR/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Data/Datasets/EUSAR/.gitkeep
--------------------------------------------------------------------------------
/Data/Test/BERLIN/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Data/Test/BERLIN/.gitkeep
--------------------------------------------------------------------------------
/Data/Test/EUSAR/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Data/Test/EUSAR/.gitkeep
--------------------------------------------------------------------------------
/Data/Train/BERLIN/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Data/Train/BERLIN/.gitkeep
--------------------------------------------------------------------------------
/Data/Train/EUSAR/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Data/Train/EUSAR/.gitkeep
--------------------------------------------------------------------------------
/Data/Val/BERLIN/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Data/Val/BERLIN/.gitkeep
--------------------------------------------------------------------------------
/Data/Val/EUSAR/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Data/Val/EUSAR/.gitkeep
--------------------------------------------------------------------------------
/Docker/dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime
2 | LABEL maintainer=ac.alessandrocattoi@gmail.com
3 | LABEL version="1.0"
4 | LABEL description="This is a custom image to run GAN transcoding with pytorch, basically standard pytorch env plus some py lib"
5 | WORKDIR /
6 | RUN apt-get update
7 | RUN apt install tree -y
8 | RUN apt install nano -y
9 | RUN apt install screen -y
10 | RUN apt-get update
11 | RUN pip install -U scikit-learn
12 | RUN python -m pip install --user numpy scipy
13 | RUN pip install colour
14 | RUN python -m pip install -U scikit-image
15 | RUN pip install plotly
16 | RUN pip install tensorboard
17 | RUN pip install pandas
18 | RUN mkdir -p /home/ale/Documents/Python/13_Tesi_2/
19 | RUN echo 'export PYTHONPATH="/home/ale/Documents/Python/13_Tesi_2/"' >>~/.bashrc
20 | RUN echo 'h=/home/ale/Documents/Python/13_Tesi_2' >>~/.bashrc
21 | 
--------------------------------------------------------------------------------
/Docs/arch/disc.tex:
--------------------------------------------------------------------------------
1 | \documentclass[border=15pt, multi, tikz]{standalone}
2 | \usepackage{import}
3 | \subimport{../../layers/}{init}
4 | \usetikzlibrary{positioning}
5 | \usetikzlibrary{3d} %for including external image
6 | 
7 | \def\resnet{rgb:red,5}
8 | \def\tanh{rgb:blue,5}
9 | \def\dropout{rgb:black,5}
10 | \def\relu{rgb:blue,5}
11 | \def\DeconvColor{rgb:pink,2}
12 | \def\instnormcolor{rgb:green,5}
13 | \def\ConvColor{rgb:yellow,5;red,2;white,5}
14 | \def\SumColor{rgb:blue,5;green,15}
15 | \def\padcolor{rgb:magenta,5;black,7}
16 | 
17 | 
18 | \def\ConvReluColor{rgb:yellow,5}
19 | \def\PoolColor{rgb:red,1;black,0.6}
20 | \def\SoftmaxColor{rgb:magenta,5;black,7}
21 | 
22 | \begin{document}
23 | \begin{tikzpicture}
24 | \tikzstyle{connection}=[ultra thick,every
node/.style={sloped,allow upside down},draw=\edgecolor,opacity=0.7] 25 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 26 | %% Draw Layer Blocks 27 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 28 | \node[canvas is zy plane at x=0] (temp) at (-2,0,0) {\includegraphics[width=7cm,height=8cm]{rgb.png}}; 29 | 30 | \node[draw,align=center] at (9.5,7) {\Huge Discriminator}; 31 | 32 | \pic[shift={(0,4.50)}] at (5,0,0) {Box={name=leg2, caption=Conv2d,% 33 | fill=\ConvColor,opacity=0.5,height=6,width=6,depth=6}}; 34 | \pic[shift={(0,4.50)}] at (9,0,0) {Box={name=leg3, caption=InstanceNorm2d,% 35 | fill=\instnormcolor,opacity=0.5,height=6,width=6,depth=6}}; 36 | \pic[shift={(0,4.50)}] at (13,0,0) {Box={name=leg4, caption=LeakyRelu,% 37 | fill=\relu,opacity=0.5,height=6,width=6,depth=6}}; 38 | 39 | 40 | %CONV FIRST 41 | \pic[shift={(1,0,0)}] at (0,0,0) {Box={name=conv1,% 42 | fill=\ConvColor,opacity=0.5,height=20,width=3,depth=20}}; 43 | \pic[shift={(0,0,0)}] at (conv1-east) {Box={name=relu1,% 44 | fill=\relu,opacity=0.5,height=20,width=2,depth=20}}; 45 | %DESCRIPTION 46 | \node[draw,align=center] at (1.5,-3.6) {Input\\Conv2D\\(4X4@64 stride=2)}; 47 | 48 | %CONV1 49 | \pic[shift={(3,0,0)}] at (relu1-east) {Box={name=conv2,% 50 | fill=\ConvColor,opacity=0.5,height=10,width=4,depth=10}}; 51 | \pic[shift={(0,0,0)}] at (conv2-east) {Box={name=instnorm1,% 52 | fill=\instnormcolor,opacity=0.5,height=10,width=2,depth=10}}; 53 | \pic[shift={(0,0,0)}] at (instnorm1-east) {Box={name=relu2,% 54 | fill=\relu,opacity=0.5,height=10,width=2,depth=10}}; 55 | %DESCRIPTION 56 | \node[draw,align=center] at (5.8,-2.2) {Convolutional\\Conv2D\\(4X4@128 stride=2)}; 57 | 58 | %CONV2 59 | \pic[shift={(2.5,0,0)}] at (relu2-east) {Box={name=conv3,% 60 | fill=\ConvColor,opacity=0.5,height=5,width=8,depth=5}}; 61 | \pic[shift={(0,0,0)}] at (conv3-east) {Box={name=instnorm2,% 62 | 
fill=\instnormcolor,opacity=0.5,height=5,width=2,depth=5}}; 63 | \pic[shift={(0,0,0)}] at (instnorm2-east) {Box={name=relu3,% 64 | fill=\relu,opacity=0.5,height=5,width=2,depth=5}}; 65 | %DESCRIPTION 66 | \node[draw,align=center] at (10.3,-1.5) {Convolutional\\Conv2D\\(3x3@256 stride=2)}; 67 | 68 | %CONV3 69 | \pic[shift={(2,0,0)}] at (relu3-east) {Box={name=conv4,% 70 | fill=\ConvColor,opacity=0.5,height=4.5,width=16,depth=4.5}}; 71 | \pic[shift={(0,0,0)}] at (conv4-east) {Box={name=instnorm3,% 72 | fill=\instnormcolor,opacity=0.5,height=4.5,width=2,depth=4.5}}; 73 | \pic[shift={(0,0,0)}] at (instnorm3-east) {Box={name=relu4,% 74 | fill=\relu,opacity=0.5,height=4.5,width=2,depth=4.5}}; 75 | %DESCRIPTION 76 | \node[draw,align=center] at (15.5,-1.5) {Convolutional\\Conv2D\\(4X4@512)}; 77 | 78 | %CONV LAST 79 | \pic[shift={(1.5,0,0)}] at (relu4-east) {Box={name=conv5,% 80 | fill=\ConvColor,opacity=0.5,height=4,width=2,depth=4}}; 81 | %DESCRIPTION 82 | \node[draw,align=center] at (19.1,-1.5) {Convolutional\\Conv2D\\(4X4@1)}; 83 | 84 | \pic[shift={(1,0,0)}] at (conv5-east) {Box={name=ph,% 85 | fill=\tanh,opacity=0.5,height=5,width=0,depth=5}}; 86 | \node[canvas is zy plane at x=0] (1,0,0) at (ph-east) {\includegraphics[width=1cm,height=1cm]{res.jpg}}; 87 | 88 | 89 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 90 | %% Draw connections 91 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 92 | \draw [connection] (-2,0,0) -- node {\midarrow} (conv1-west); 93 | \draw [connection] (relu1-east) -- node {\midarrow} (conv2-west); 94 | \draw [connection] (relu2-east) -- node {\midarrow} (conv3-west); 95 | \draw [connection] (relu3-east) -- node {\midarrow} (conv4-west); 96 | \draw [connection] (relu4-east) -- node {\midarrow} (conv5-west); 97 | \draw [connection] (conv5-east) -- node {\midarrow} (ph-west); 98 | 99 | 
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 100 | 101 | \end{tikzpicture} 102 | \end{document}\grid 103 | -------------------------------------------------------------------------------- /Docs/arch/gen.tex: -------------------------------------------------------------------------------- 1 | \documentclass[border=15pt, multi, tikz]{standalone} 2 | \usepackage{import} 3 | \subimport{../../layers/}{init} 4 | \usetikzlibrary{positioning} 5 | \usetikzlibrary{3d} %for including external image 6 | 7 | \def\resnet{rgb:red,5} 8 | \def\tanh{rgb:blue,5} 9 | \def\dropout{rgb:black,5} 10 | \def\relu{rgb:orange,5} 11 | \def\DeconvColor{rgb:pink,2} 12 | \def\instnormcolor{rgb:green,5} 13 | \def\ConvColor{rgb:yellow,5;red,2;white,5} 14 | \def\SumColor{rgb:blue,5;green,15} 15 | \def\padcolor{rgb:magenta,5;black,7} 16 | 17 | 18 | \def\ConvReluColor{rgb:yellow,5} 19 | \def\PoolColor{rgb:red,1;black,0.6} 20 | \def\SoftmaxColor{rgb:magenta,5;black,7} 21 | 22 | \begin{document} 23 | \begin{tikzpicture} 24 | \tikzstyle{connection}=[ultra thick,every node/.style={sloped,allow upside down},draw=\edgecolor,opacity=0.7] 25 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 26 | %% Draw Layer Blocks 27 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 28 | \node[canvas is zy plane at x=0] (temp) at (-2,0,0) {\includegraphics[width=7cm,height=8cm]{train_sar.png}}; 29 | 30 | \pic[shift={(0,4.50)}] at (6,0,0) {Box={name=leg1, caption=ReflectionPad2d,% 31 | fill=\padcolor,opacity=0.5,height=6,width=6,depth=6}}; 32 | \pic[shift={(0,4.50)}] at (9.5,0,0) {Box={name=leg2, caption=Conv2d,% 33 | fill=\ConvColor,opacity=0.5,height=6,width=6,depth=6}}; 34 | \pic[shift={(0,4.50)}] at (13,0,0) {Box={name=leg3, caption=InstanceNorm2d,% 35 | fill=\instnormcolor,opacity=0.5,height=6,width=6,depth=6}}; 36 | \pic[shift={(0,4.50)}] at (16.5,0,0) {Box={name=leg4, caption=Relu,% 37 | 
fill=\relu,opacity=0.5,height=6,width=6,depth=6}}; 38 | \pic[shift={(0,4.50)}] at (20,0,0) {Box={name=leg5, caption=Dropout,% 39 | fill=\dropout,opacity=0.5,height=6,width=6,depth=6}}; 40 | \pic[shift={(0,4.50)}] at (23.5,0,0) {Box={name=leg6, caption=ResBlock,% 41 | fill=\resnet,opacity=0.5,height=6,width=6,depth=6}}; 42 | \pic[shift={(0,4.50)}] at (27,0,0) {Box={name=leg7, caption=ConvTranspose2d,% 43 | fill=\DeconvColor,opacity=0.5,height=6,width=6,depth=6}}; 44 | 45 | %Input 46 | \pic[shift={(0,0,0)}] at (0,0,0) {Box={name=pad1,% 47 | fill=\padcolor,opacity=0.5,height=45,width=1,depth=45}}; 48 | \pic[shift={(0,0,0)}] at (pad1-east) {Box={name=conv1,% 49 | fill=\ConvColor,opacity=0.5,height=40,width=2,depth=40}}; 50 | \pic[shift={(0,0,0)}] at (conv1-east) {Box={name=instnorm1,% 51 | fill=\instnormcolor,opacity=0.5,height=40,width=1,depth=40}}; 52 | \pic[shift={(0,0,0)}] at (instnorm1-east) {Box={name=relu,% 53 | fill=\relu,opacity=0.5,height=40,width=1,depth=40}}; 54 | \node[draw,align=center] at (-0.7,-7.1) {Input\\ReflectionPad2d(3)\\Conv2D(7x7@64)}; 55 | 56 | %conv1 57 | \pic[shift={(1,0,0)}] at (relu-east) {Box={name=conv2,% 58 | fill=\ConvColor,opacity=0.5,height=20,width=3,depth=20}}; 59 | \pic[shift={(0,0,0)}] at (conv2-east) {Box={name=instnorm2,% 60 | fill=\instnormcolor,opacity=0.5,height=20,width=1,depth=20}}; 61 | \pic[shift={(0,0,0)}] at (instnorm2-east) {Box={name=relu2,% 62 | fill=\relu,opacity=0.5,height=20,width=1,depth=20}}; 63 | \node[draw,align=center] at (2,-3.8) {Downsampling\\Conv2D\\(3x3@128 stride=2)}; 64 | 65 | %conv2 66 | \pic[shift={(1,0,0)}] at (relu2-east) {Box={name=conv3,% 67 | fill=\ConvColor,opacity=0.5,height=10,width=5,depth=10}}; 68 | \pic[shift={(0,0,0)}] at (conv3-east) {Box={name=instnorm3,% 69 | fill=\instnormcolor,opacity=0.5,height=10,width=1,depth=10}}; 70 | \pic[shift={(0,0,0)}] at (instnorm3-east) {Box={name=relu3,% 71 | fill=\relu,opacity=0.5,height=10,width=1,depth=10}}; 72 | \node[draw,align=center] at (4.5,-2.3) 
{Downsampling\\Conv2D\\(3x3@256 stride=2)}; 73 | 74 | %RES1 75 | \pic[shift={(1,0,0)}] at (relu3-east) {Box={name=conv4,% 76 | fill=\ConvColor,opacity=0.5,height=10,width=5,depth=10}}; 77 | \pic[shift={(0,0,0)}] at (conv4-east) {Box={name=instnorm4,% 78 | fill=\instnormcolor,opacity=0.5,height=10,width=1,depth=10}}; 79 | \pic[shift={(0,0,0)}] at (instnorm4-east) {Box={name=relu4,% 80 | fill=\relu,opacity=0.5,height=10,width=1,depth=10}}; 81 | \pic[shift={(0,0,0)}] at (relu4-east) {Box={name=dropout1,% 82 | fill=\dropout,opacity=0.5,height=10,width=1,depth=10}}; 83 | \pic[shift={(0,0,0)}] at (dropout1-east) {Box={name=pad2,% 84 | fill=\padcolor,opacity=0.5,height=10,width=1,depth=10}}; 85 | \pic[shift={(0,0,0)}] at (pad2-east) {Box={name=conv5,% 86 | fill=\ConvColor,opacity=0.5,height=10,width=5,depth=10}}; 87 | \pic[shift={(0,0,0)}] at (conv5-east) {Box={name=instnorm5,% 88 | fill=\instnormcolor,opacity=0.5,height=10,width=1,depth=10}}; 89 | \node[draw,align=center] at (7.8,-2.51) {Residual Block\\ Conv2D(3x3@256)\\ Dropout(0.5)\\ReflectionPad2d(1)}; 90 | 91 | %PLUS 92 | \pic[shift={(1,0,0)}] at (instnorm5-east) {Ball={name=sum1,% 93 | fill=\SumColor,opacity=0.6,% 94 | radius=1.5,logo=$+$}}; 95 | \pic[shift={(0,-4,0)}] at (dropout1-west) {Box={name=dummy1,% 96 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 97 | \pic[shift={(0,-3.70,0)}] at (sum1-south) {Box={name=dummy21,% 98 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 99 | 100 | %RES2 101 | \pic[shift={(0.8,0,0)}] at (sum1-east) {Box={name=resnet2,% 102 | fill=\resnet,opacity=0.5,height=10,width=1,depth=10}}; 103 | \node[draw,align=center] at (11.65,-1.8) {ResBlock}; 104 | 105 | %PLUS 106 | \pic[shift={(1,0,0)}] at (resnet2-east) {Ball={name=sum2,% 107 | fill=\SumColor,opacity=0.6,% 108 | radius=1.5,logo=$+$}}; 109 | \pic[shift={(0,-4,0)}] at (resnet2-east) {Box={name=dummy2,% 110 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 111 | \pic[shift={(0,-3.70,0)}] at 
(sum2-south) {Box={name=dummy22,% 112 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 113 | 114 | %RES3 115 | \pic[shift={(0.8,0,0)}] at (sum2-east) {Box={name=resnet3,% 116 | fill=\resnet,opacity=0.5,height=10,width=1,depth=10}}; 117 | \node[draw,align=center] at (13.95,-1.8) {ResBlock}; 118 | 119 | %PLUS 120 | \pic[shift={(1,0,0)}] at (resnet3-east) {Ball={name=sum3,% 121 | fill=\SumColor,opacity=0.6,% 122 | radius=1.5,logo=$+$}}; 123 | \pic[shift={(0,-4,0)}] at (resnet3-east) {Box={name=dummy3,% 124 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 125 | \pic[shift={(0,-3.70,0)}] at (sum3-south) {Box={name=dummy23,% 126 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 127 | 128 | %res4 129 | \pic[shift={(0.8,0,0)}] at (sum3-east) {Box={name=resnet4,% 130 | fill=\resnet,opacity=0.5,height=10,width=1,depth=10}}; 131 | \node[draw,align=center] at (16.25,-1.8) {ResBlock}; 132 | 133 | %PLUS 134 | \pic[shift={(1,0,0)}] at (resnet4-east) {Ball={name=sum4,% 135 | fill=\SumColor,opacity=0.6,% 136 | radius=1.5,logo=$+$}}; 137 | \pic[shift={(0,-4,0)}] at (resnet4-east) {Box={name=dummy4,% 138 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 139 | \pic[shift={(0,-3.70,0)}] at (sum4-south) {Box={name=dummy24,% 140 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 141 | 142 | %res5 143 | \pic[shift={(0.8,0,0)}] at (sum4-east) {Box={name=resnet5,% 144 | fill=\resnet,opacity=0.5,height=10,width=1,depth=10}}; 145 | \node[draw,align=center] at (18.55,-1.8) {ResBlock}; 146 | 147 | %PLUS 148 | \pic[shift={(1,0,0)}] at (resnet5-east) {Ball={name=sum5,% 149 | fill=\SumColor,opacity=0.6,% 150 | radius=1.5,logo=$+$}}; 151 | \pic[shift={(0,-4,0)}] at (resnet5-east) {Box={name=dummy5,% 152 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 153 | \pic[shift={(0,-3.70,0)}] at (sum5-south) {Box={name=dummy25,% 154 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 155 | 156 | %res6 157 | 
\pic[shift={(0.8,0,0)}] at (sum5-east) {Box={name=resnet6,% 158 | fill=\resnet,opacity=0.5,height=10,width=1,depth=10}}; 159 | \node[draw,align=center] at (20.85,-1.8) {ResBlock}; 160 | 161 | %PLUS 162 | \pic[shift={(1,0,0)}] at (resnet6-east) {Ball={name=sum6,% 163 | fill=\SumColor,opacity=0.6,% 164 | radius=1.5,logo=$+$}}; 165 | \pic[shift={(0,-4,0)}] at (resnet6-east) {Box={name=dummy6,% 166 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 167 | \pic[shift={(0,-3.70,0)}] at (sum6-south) {Box={name=dummy26,% 168 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 169 | 170 | %res7 171 | \pic[shift={(0.8,0,0)}] at (sum6-east) {Box={name=resnet7,% 172 | fill=\resnet,opacity=0.5,height=10,width=1,depth=10}}; 173 | \node[draw,align=center] at (23.2,-1.8) {ResBlock}; 174 | 175 | %PLUS 176 | \pic[shift={(1,0,0)}] at (resnet7-east) {Ball={name=sum7,% 177 | fill=\SumColor,opacity=0.6,% 178 | radius=1.5,logo=$+$}}; 179 | \pic[shift={(0,-4,0)}] at (resnet7-east) {Box={name=dummy7,% 180 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 181 | \pic[shift={(0,-3.70,0)}] at (sum7-south) {Box={name=dummy27,% 182 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 183 | 184 | %res8 185 | \pic[shift={(0.8,0,0)}] at (sum7-east) {Box={name=resnet8,% 186 | fill=\resnet,opacity=0.5,height=10,width=1,depth=10}}; 187 | \node[draw,align=center] at (25.5,-1.8) {ResBlock}; 188 | 189 | %PLUS 190 | \pic[shift={(1,0,0)}] at (resnet8-east) {Ball={name=sum8,% 191 | fill=\SumColor,opacity=0.6,% 192 | radius=1.5,logo=$+$}}; 193 | \pic[shift={(0,-4,0)}] at (resnet8-east) {Box={name=dummy8,% 194 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 195 | \pic[shift={(0,-3.70,0)}] at (sum8-south) {Box={name=dummy28,% 196 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 197 | 198 | 199 | %res9 200 | \pic[shift={(0.8,0,0)}] at (sum8-east) {Box={name=resnet9,% 201 | fill=\resnet,opacity=0.5,height=10,width=1,depth=10}}; 202 | 
\node[draw,align=center] at (27.75,-1.8) {ResBlock}; 203 | 204 | %PLUS 205 | \pic[shift={(1,0,0)}] at (resnet9-east) {Ball={name=sum9,% 206 | fill=\SumColor,opacity=0.6,% 207 | radius=1.5,logo=$+$}}; 208 | \pic[shift={(0,-4,0)}] at (resnet9-east) {Box={name=dummy9,% 209 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 210 | \pic[shift={(0,-3.70,0)}] at (sum9-south) {Box={name=dummy29,% 211 | fill=\instnormcolor,opacity=0.5,height=0,width=0,depth=0}}; 212 | 213 | %UPSAMPLE1 214 | \pic[shift={(1.5,0,0)}] at (sum9-east) {Box={name=conv6,% 215 | fill=\DeconvColor,opacity=0.5,height=20,width=3,depth=20}}; 216 | \pic[shift={(0,0,0)}] at (conv6-east) {Box={name=instnorm6,% 217 | fill=\instnormcolor,opacity=0.5,height=20,width=1,depth=20}}; 218 | \pic[shift={(0,0,0)}] at (instnorm6-east) {Box={name=relu5,% 219 | fill=\relu,opacity=0.5,height=20,width=1,depth=20}}; 220 | \node[draw,align=center] at (29.95,-3.75) {Upsampling\\ConvTranspose2d\\(3x3@128 stride=2)}; 221 | 222 | %UPSAMPLE2 223 | \pic[shift={(1.5,0,0)}] at (relu5-east) {Box={name=conv7,% 224 | fill=\DeconvColor,opacity=0.5,height=40,width=2,depth=40}}; 225 | \pic[shift={(0,0,0)}] at (conv7-east) {Box={name=instnorm7,% 226 | fill=\instnormcolor,opacity=0.5,height=40,width=1,depth=40}}; 227 | \pic[shift={(0,0,0)}] at (instnorm7-east) {Box={name=relu6,% 228 | fill=\relu,opacity=0.5,height=40,width=1,depth=40}}; 229 | \node[draw,align=center] at (31.7,-6.4) {Upsampling\\ConvTranspose2d\\(3x3@64 stride=2)}; 230 | 231 | %OUTPUT 232 | \pic[shift={(1.5,0,0)}] at (relu6-east) {Box={name=pad3,% 233 | fill=\padcolor,opacity=0.5,height=45,width=1,depth=45}}; 234 | \pic[shift={(0,0,0)}] at (pad3-east) {Box={name=conv8,% 235 | fill=\ConvColor,opacity=0.5,height=40,width=1,depth=40}}; 236 | \pic[shift={(0,0,0)}] at (conv8-east) {Box={name=tanh,% 237 | fill=\ConvColor,opacity=0.5,height=40,width=1,depth=40}}; 238 | \node[draw,align=center] at (36.4,-7.3) {Output\\ReflectionPad2d(3)\\Conv2D(7x7@N of Optical 
channels)\\Conv2D(1x1@N of Optical channels)}; 239 | 240 | 241 | \pic[shift={(2,0,0)}] at (conv8-east) {Box={name=ph,% 242 | fill=\tanh,opacity=0.5,height=5,width=0,depth=5}}; 243 | \node[canvas is zy plane at x=0] (1,0,0) at (ph-east) {\includegraphics[width=7.5cm,height=8.5cm]{train.png}}; 244 | 245 | 246 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 247 | %% Draw connections 248 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 249 | 250 | \draw [connection] (relu-east) -- node {\midarrow} (conv2-west); 251 | \draw [connection] (relu2-east) -- node {\midarrow} (conv3-west); 252 | \draw [connection] (relu3-east) -- node {\midarrow} (conv4-west); 253 | \draw [connection] (instnorm5-east) -- node {\midarrow} (sum1-west); 254 | 255 | \draw [connection] (instnorm5-east) -- node {\midarrow} (sum1-west); 256 | \draw [connection] (resnet2-east) -- node {\midarrow} (sum2-west); 257 | \draw [connection] (resnet3-east) -- node {\midarrow} (sum3-west); 258 | \draw [connection] (resnet4-east) -- node {\midarrow} (sum4-west); 259 | \draw [connection] (resnet5-east) -- node {\midarrow} (sum5-west); 260 | \draw [connection] (resnet6-east) -- node {\midarrow} (sum6-west); 261 | \draw [connection] (resnet7-east) -- node {\midarrow} (sum7-west); 262 | \draw [connection] (resnet8-east) -- node {\midarrow} (sum8-west); 263 | \draw [connection] (resnet9-east) -- node {\midarrow} (sum9-west); 264 | 265 | \draw [connection] (sum1-east) -- (resnet2-west); 266 | \draw [connection] (sum2-east) -- (resnet3-west); 267 | \draw [connection] (sum3-east) -- (resnet4-west); 268 | \draw [connection] (sum4-east) -- (resnet5-west); 269 | \draw [connection] (sum5-east) -- (resnet6-west); 270 | \draw [connection] (sum6-east) -- (resnet7-west); 271 | \draw [connection] (sum7-east) -- (resnet8-west); 272 | \draw [connection] (sum8-east) -- (resnet9-west); 273 | \draw [connection] (sum9-east) -- node {\midarrow} 
(conv6-west); 274 | 275 | 276 | \path (relu3-east) -- (conv4-west) coordinate[pos=0.25] (between1) ; 277 | \path (resnet2-east) -- (sum1-west) coordinate[pos=0.6] (between2) ; 278 | \path (resnet3-east) -- (sum2-west) coordinate[pos=0.6] (between3) ; 279 | \path (resnet4-east) -- (sum3-west) coordinate[pos=0.6] (between4) ; 280 | \path (resnet5-east) -- (sum4-west) coordinate[pos=0.6] (between5) ; 281 | \path (resnet6-east) -- (sum5-west) coordinate[pos=0.6] (between6) ; 282 | \path (resnet7-east) -- (sum6-west) coordinate[pos=0.6] (between7) ; 283 | \path (resnet8-east) -- (sum7-west) coordinate[pos=0.6] (between8) ; 284 | \path (resnet9-east) -- (sum8-west) coordinate[pos=0.6] (between9) ; 285 | 286 | \draw [connection] (between1) -- (dummy1-west-|between1) -- (dummy1-west); 287 | \draw [connection] (between2) -- (dummy2-west-|between2) -- (dummy2-west); 288 | \draw [connection] (between3) -- (dummy3-west-|between3) -- (dummy3-west); 289 | \draw [connection] (between4) -- (dummy4-west-|between4) -- (dummy4-west); 290 | \draw [connection] (between5) -- (dummy5-west-|between5) -- (dummy5-west); 291 | \draw [connection] (between6) -- (dummy6-west-|between6) -- (dummy6-west); 292 | \draw [connection] (between7) -- (dummy7-west-|between7) -- (dummy7-west); 293 | \draw [connection] (between8) -- (dummy8-west-|between8) -- (dummy8-west); 294 | \draw [connection] (between9) -- (dummy9-west-|between9) -- (dummy9-west); 295 | 296 | 297 | \draw [connection] (dummy1-east) -- (dummy21-west); 298 | \draw [connection] (dummy2-east) -- (dummy22-west); 299 | \draw [connection] (dummy3-east) -- (dummy23-west); 300 | \draw [connection] (dummy4-east) -- (dummy24-west); 301 | \draw [connection] (dummy5-east) -- (dummy25-west); 302 | \draw [connection] (dummy6-east) -- (dummy26-west); 303 | \draw [connection] (dummy7-east) -- (dummy27-west); 304 | \draw [connection] (dummy8-east) -- (dummy28-west); 305 | \draw [connection] (dummy9-east) -- (dummy29-west); 306 | 307 | \draw 
[connection] (dummy21-east) -- node {\midarrow} (sum1-south); 308 | \draw [connection] (dummy22-east) -- node {\midarrow} (sum2-south); 309 | \draw [connection] (dummy23-east) -- node {\midarrow} (sum3-south); 310 | \draw [connection] (dummy24-east) -- node {\midarrow} (sum4-south); 311 | \draw [connection] (dummy25-east) -- node {\midarrow} (sum5-south); 312 | \draw [connection] (dummy26-east) -- node {\midarrow} (sum6-south); 313 | \draw [connection] (dummy27-east) -- node {\midarrow} (sum7-south); 314 | \draw [connection] (dummy28-east) -- node {\midarrow} (sum8-south); 315 | \draw [connection] (dummy29-east) -- node {\midarrow} (sum9-south); 316 | 317 | \draw [connection] (relu5-east) -- node {\midarrow} (conv7-west); 318 | \draw [connection] (relu6-east) -- node {\midarrow} (pad3-west); 319 | 320 | 321 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 322 | 323 | \end{tikzpicture} 324 | \end{document}\grid 325 | -------------------------------------------------------------------------------- /Docs/arch/prova.tex: -------------------------------------------------------------------------------- 1 | \documentclass[border=15pt, multi, tikz]{standalone} 2 | \usepackage{import} 3 | \subimport{../../layers/}{init} 4 | \usetikzlibrary{positioning} 5 | \usetikzlibrary{3d} %for including external image 6 | 7 | \def\resnet{rgb:red,5} 8 | \def\tanh{rgb:blue,5} 9 | \def\dropout{rgb:black,5} 10 | \def\relu{rgb:orange,5} 11 | \def\DeconvColor{rgb:pink,2} 12 | \def\instnormcolor{rgb:green,5} 13 | \def\ConvColor{rgb:yellow,5;red,2;white,5} 14 | \def\SumColor{rgb:blue,5;green,15} 15 | \def\padcolor{rgb:magenta,5;black,7} 16 | 17 | 18 | \def\ConvReluColor{rgb:yellow,5} 19 | \def\PoolColor{rgb:red,1;black,0.6} 20 | \def\SoftmaxColor{rgb:magenta,5;black,7} 21 | \def\black{rgb:black,1} 22 | 23 | 24 | \begin{document} 25 | \begin{tikzpicture} 26 | \tikzstyle{connection}=[ultra thick,every node/.style={sloped,allow upside 
down},draw=\edgecolor,opacity=0.7] 27 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 28 | %% Draw Layer Blocks 29 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 30 | \node[canvas is zy plane at x=0] (temp) at (-2,0,0) {\includegraphics[width=7cm,height=8cm]{train_sar.png}}; 31 | 32 | 33 | %Input 34 | \pic[shift={(0,0,0)}] at (0,0,0) {Box={name=input,% 35 | fill=\dropout,opacity=0.5,height=40,width=5,depth=40}}; 36 | \node[draw,align=center] at (-1.1,-6) {Input}; 37 | 38 | 39 | %DOWN1 40 | \pic[shift={(1.5,0,0)}] at (input-east) {Box={name=down1,% 41 | fill=\dropout,opacity=0.5,height=20,width=6,depth=20}}; 42 | \node[draw,align=center] at (2.8,-3.2) {Downsampling1}; 43 | 44 | 45 | %DOWN2 46 | \pic[shift={(1.5,0,0)}] at (down1-east) {Box={name=down2,% 47 | fill=\dropout,opacity=0.5,height=10,width=8,depth=10}}; 48 | \node[draw,align=center] at (5.7,-1.8) {Downsampling2}; 49 | 50 | 51 | %Input 52 | \pic[shift={(1,0,0)}] at (down2-east) {Box={name=res,% 53 | fill=\dropout,opacity=0.5,height=10,width=20,depth=10}}; 54 | \node[draw,align=center] at (9.5,-1.8) {Residual Blocks}; 55 | 56 | 57 | %UPSAMPLE1 58 | \pic[shift={(1.5,0,0)}] at (res-east) {Box={name=conv6,% 59 | fill=\DeconvColor,opacity=0.5,height=20,width=3,depth=20}}; 60 | \pic[shift={(0,0,0)}] at (conv6-east) {Box={name=instnorm6,% 61 | fill=\instnormcolor,opacity=0.5,height=20,width=1,depth=20}}; 62 | \pic[shift={(0,0,0)}] at (instnorm6-east) {Box={name=relu5,% 63 | fill=\relu,opacity=0.5,height=20,width=1,depth=20}}; 64 | \node[draw,align=center] at (12.5,-3.75) {Upsampling\\ConvTranspose2d\\(3x3@128 stride=2)}; 65 | 66 | %UPSAMPLE2 67 | \pic[shift={(1.5,0,0)}] at (relu5-east) {Box={name=conv7,% 68 | fill=\DeconvColor,opacity=0.5,height=40,width=2,depth=40}}; 69 | \pic[shift={(0,0,0)}] at (conv7-east) {Box={name=instnorm7,% 70 | fill=\instnormcolor,opacity=0.5,height=40,width=1,depth=40}}; 71 | \pic[shift={(0,0,0)}] at 
(instnorm7-east) {Box={name=relu6,% 72 | fill=\relu,opacity=0.5,height=40,width=1,depth=40}}; 73 | \node[draw,align=center] at (14.4,-6.4) {Upsampling\\ConvTranspose2d\\(3x3@64 stride=2)}; 74 | 75 | %OUTPUT 76 | \pic[shift={(1.5,0,0)}] at (relu6-east) {Box={name=pad3,% 77 | fill=\padcolor,opacity=0.5,height=45,width=1,depth=45}}; 78 | \pic[shift={(0,0,0)}] at (pad3-east) {Box={name=conv8,% 79 | fill=\ConvColor,opacity=0.5,height=40,width=1,depth=40}}; 80 | \node[draw,align=center] at (17.7,-7.1) {Output\\ReflectionPad2d(3)\\Conv2D(7x7@64)}; 81 | 82 | 83 | \pic[shift={(2,0,0)}] at (conv8-east) {Box={name=ph,% 84 | fill=\black,opacity=1,height=42.8,width=0.01,depth=38}}; 85 | \node[canvas is zy plane at x=0] (0,0,0) at (ph-east) {\includegraphics[width=7.5cm,height=8.5cm]{label.png}}; 86 | 87 | 88 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 89 | %% Draw connections 90 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 91 | \draw [connection] (input-east) -- node {\midarrow} (down1-west); 92 | \draw [connection] (down1-east) -- node {\midarrow} (down2-west); 93 | \draw [connection] (down2-east) -- node {\midarrow} (res-west); 94 | \draw [connection] (res-east) -- node {\midarrow} (conv6-west); 95 | \draw [connection] (relu5-east) -- node {\midarrow} (conv7-west); 96 | \draw [connection] (relu6-east) -- node {\midarrow} (pad3-west); 97 | 98 | 99 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 100 | 101 | \end{tikzpicture} 102 | \end{document}\grid 103 | -------------------------------------------------------------------------------- /Docs/final_presentation/unitn1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Docs/final_presentation/unitn1.pdf 
-------------------------------------------------------------------------------- /Docs/manuscript/img/compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Docs/manuscript/img/compare.png -------------------------------------------------------------------------------- /Docs/manuscript/img/map_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Docs/manuscript/img/map_img.png -------------------------------------------------------------------------------- /Docs/manuscript/manuscript.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Docs/manuscript/manuscript.pdf -------------------------------------------------------------------------------- /Docs/manuscript/presentation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Docs/manuscript/presentation.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Alessandro 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the 
Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Lib/Datasets/EUSAR/ConcatDataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 5 | """ 6 | Author: Alessandro Cattoi 7 | Description: Concat class implementation 8 | """ 9 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 10 | 11 | 12 | class ConcatDataset(Dataset): 13 | """ 14 | This class can be employed to concatenate PyTorch Dataset instances; it is not used in this project 15 | """ 16 | def __init__(self, *datasets): 17 | self.datasets = datasets 18 | 19 | def __getitem__(self, i): 20 | return tuple(d[i] for d in self.datasets) 21 | 22 | def __len__(self): 23 | return min(len(d) for d in self.datasets) -------------------------------------------------------------------------------- /Lib/Datasets/EUSAR/EUSARDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | from
Lib.Datasets.processing.utility import one_hot_2_label_value 5 | import random 6 | 7 | 8 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 9 | """ 10 | Author: Alessandro Cattoi 11 | Description: This file overloads the PyTorch Dataset class 12 | """ 13 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 14 | 15 | 16 | class EUSARDataset(Dataset): 17 | """ 18 | Implementation of the torch Dataset class for the EUSAR dataset 19 | The Dataset must be structured as follows: 20 | The naming convention for patches is: XXXX_XXXXXXX_data.npy 21 | Where: 22 | - XXXX is the tile ID (from 0 to N) 23 | - XXXXXXX is the patch ID (from 0 to M) 24 | - data is one of: label, rgb or radar 25 | Dataset 26 | --/ radar 27 | --/ radar patches in np format 28 | --/ rgb 29 | --/ rgb patches in np format 30 | --/ label 31 | --/ label patches in np format 32 | """ 33 | def __init__(self, global_path, b_label=False, b_rgb=False, sar_c=5, opt_c=5, randomized=False): 34 | """ 35 | Init function 36 | :param global_path: path to the directory of data (which should be composed of: rgb, radar, label) 37 | :param b_label: use label data? 38 | :param b_rgb: use rgb data?
40 | :param sar_c: number of channels of sar data 41 | :param opt_c: number of channels of optical data 42 | :param randomized: if True, access patches in a random order 43 | """ 44 | # global path to dataset folder 45 | self.data_path = os.path.join(global_path) 46 | # The dataset folder should always contain folders named as follows 47 | self.radar = "radar" 48 | self.label = "label" 49 | self.rgb = "rgb" 50 | self.b_label = b_label 51 | self.b_rgb = b_rgb 52 | self.sar_c = sar_c 53 | self.opt_c = opt_c 54 | self.randomized = randomized 55 | # create a list of the files in the radar folder 56 | self.radar_list = sorted(os.listdir(os.path.join(self.data_path, self.radar))) 57 | # randomly sample file names to be able to access data randomly 58 | self.rand_list = random.sample(list(range(0, len(self.radar_list))), k=len(self.radar_list)) 59 | # label and rgb lists are generated only if the data are requested 60 | if self.b_label: 61 | self.label_list = sorted(os.listdir(os.path.join(self.data_path, self.label))) 62 | if self.b_rgb: 63 | self.rgb_list = sorted(os.listdir(os.path.join(self.data_path, self.rgb))) 64 | 65 | def __len__(self): 66 | """ 67 | If b_label and b_rgb are True we want both to transcode and to train for segmentation, so the length is radar_list 68 | If b_label is False and b_rgb is True we want only to transcode, so the length is radar_list 69 | If b_label is True and b_rgb is False we want only to segment, so the length is label_list 70 | :return: number of samples in dataset 71 | """ 72 | if self.b_label and not self.b_rgb: 73 | return len(self.label_list) 74 | else: 75 | return len(self.radar_list) 76 | 77 | def __getitem__(self, idx): 78 | """ 79 | This function returns a dictionary of the requested data 80 | :param idx: index 81 | :return: dictionary with {"radar": radar, "label": label, "rgb": rgb, "name": name} 82 | """ 83 | if self.randomized: 84 | idx = self.rand_list[idx] 85 | 86 | radar_name = self.radar_list[idx] 87 | radar = np.load(os.path.join(self.data_path, self.radar, radar_name),
allow_pickle=True)[:self.sar_c] 88 | if self.b_label and not self.b_rgb: 89 | label_name = self.label_list[idx] 90 | label = np.load(os.path.join(self.data_path, self.label, label_name), allow_pickle=True) 91 | label = one_hot_2_label_value(label) 92 | return {"radar": radar, "label": label, "name": label_name} 93 | elif self.b_rgb and not self.b_label: 94 | rgb_name = self.rgb_list[idx] 95 | rgb = np.load(os.path.join(self.data_path, self.rgb, rgb_name), allow_pickle=True)[:self.opt_c] 96 | return {"radar": radar, "rgb": rgb, "name": radar_name} 97 | elif self.b_label and self.b_rgb: 98 | label_name = self.label_list[idx] 99 | label = np.load(os.path.join(self.data_path, self.label, label_name), allow_pickle=True) 100 | label = one_hot_2_label_value(label) 101 | rgb_name = self.rgb_list[idx] 102 | rgb = np.load(os.path.join(self.data_path, self.rgb, rgb_name), allow_pickle=True)[:self.opt_c] 103 | return {"radar": radar, "rgb": rgb, "label": label, "name": radar_name} 104 | -------------------------------------------------------------------------------- /Lib/Datasets/EUSAR/__pycache__/ConcatDataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Datasets/EUSAR/__pycache__/ConcatDataset.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/Datasets/EUSAR/__pycache__/EUSARDataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Datasets/EUSAR/__pycache__/EUSARDataset.cpython-37.pyc -------------------------------------------------------------------------------- /Lib/Datasets/EUSAR/__pycache__/EUSARDataset.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Datasets/EUSAR/__pycache__/EUSARDataset.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/Datasets/processing/EUSAR_data_processing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | from PIL import Image 5 | from skimage import io 6 | from scipy import ndimage 7 | from Lib.Datasets.processing.utility import remove_duplicates, simple_normalizer 8 | 9 | 10 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 11 | """ 12 | Author: Alessandro Cattoi 13 | Description: This file implements functions for EUSAR data preprocessing, such as radar feature extraction 14 | """ 15 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 16 | 17 | 18 | def png_to_numpy(data_path, dest_path, name): 19 | """ 20 | This function transforms many black and white png label images into a single npy array of shape [Classes, W, H] 21 | The resulting array is one-hot encoded 22 | :param data_path: where the data are currently stored 23 | :param dest_path: where the new data will be stored 24 | :param name: name is the output name of the label variable global_path/(name + _label.npy) 25 | :return: NA 26 | """ 27 | t = time.time() 28 | image_list = os.listdir(data_path) 29 | f = open(os.path.join(dest_path, "label", name + "_label.txt"), "a") 30 | f.write("Images to be merged {}\n".format(image_list)) 31 | w, h = Image.open(os.path.join(data_path, image_list[0])).size 32 | f.write("Sample size w {}, h {}\n".format(h, w)) 33 | image = np.zeros((len(image_list), h, w), dtype=np.uint8) 34 | f.write("Created numpy array of format {} and type
{}\n".format(image.shape, image.dtype)) 35 | 36 | for i, img_name in enumerate(image_list): 37 | f.write("Band {} name {}\n".format(i, img_name)) 38 | img = Image.open(os.path.join(data_path, img_name)).convert("L") 39 | img = np.asarray(img) 40 | image[i] = img 41 | 42 | image = (image / 255).astype(np.float32) 43 | image = remove_duplicates(image) 44 | f.write("Shape {} type {}\n".format(image.shape, image.dtype)) 45 | u, c = np.unique(image, return_counts=True) 46 | f.write("Unique {} count {}\n".format(u, c)) 47 | image = image.astype(np.uint8) 48 | np.save(os.path.join(dest_path, "label", name + "_label.npy"), image, allow_pickle=True) 49 | 50 | f.write("Execution time = {:.2f} s".format(time.time() - t)) 51 | f.close() 52 | 53 | 54 | def process_tile(tile_path, tile_name, dest_path, action, name, filter_type='box', filter_kernel=3, center=False): 55 | """ 56 | Open a tif or numpy tile and apply an action, i.e. a function which processes the data and stores the result in the 57 | new output location dest_path 58 | :param tile_path: path to the directory containing the tile 59 | :param tile_name: name of the tile 60 | :param dest_path: where to store the new tile 61 | :param action: what kind of processing to apply {"SAR_feature", "tif2npy"} 62 | :param name: prefix prepended to the output file names 63 | :param filter_type: ['box', 'gaus'] 64 | :param filter_kernel: ['3', '0.4'] filter size, e.g. 3 means 3x3; for more detail see the specific functions 65 | :param center: center values or not (currently, if center is set it re-normalizes between -1 and 1) 66 | :return: NA 67 | """ 68 | 69 | t = time.time() 70 | f = open(os.path.join(dest_path, name + "_" + tile_name.split(".")[0] + ".txt"), "a") 71 | formato = tile_name.split(".")[-1] 72 | if formato == "tif": 73 | f.write("Image format {}\n".format(formato)) 74 | tile = io.imread(os.path.join(tile_path, tile_name)) 75 | f.write("Input tile shape {} type {}\n".format(tile.shape, tile.dtype)) 76 | new_tile = np.asarray(tile, order="F") 77 | new_tile = np.rollaxis(new_tile, 2, 0) 78 | new_tile =
new_tile.astype('float32') 79 | f.write("Output tile shape {} type {}\n".format(new_tile.shape, new_tile.dtype)) 80 | elif formato == "npy": 81 | f.write("Image format {}\n".format(formato)) 82 | new_tile = np.load(os.path.join(tile_path, tile_name)) 83 | f.write("Input tile shape {} type {}\n".format(new_tile.shape, new_tile.dtype)) 84 | new_tile = new_tile.astype('float32') 85 | f.write("Output tile shape {} type {}\n".format(new_tile.shape, new_tile.dtype)) 86 | else: 87 | f.write("Image format TILE FORMAT INCORRECT\n") 88 | raise NotImplementedError("TILE FORMAT INCORRECT") 89 | 90 | if action == "tif2npy": 91 | f.write("Run tif2npy\n") 92 | new_tile[[0, 1, 2, 3, 4]] = new_tile[[2, 1, 0, 3, 4]] 93 | u, c = np.unique(new_tile, return_counts=True) 94 | f.write("Input tile shape {} type {} unique {} count {}\n".format(new_tile.shape, new_tile.dtype, u, c)) 95 | # new_tile, mean, std, center, mx = normalizer(new_tile, f, center, center) 96 | vect = np.load(os.path.join(tile_path, 'norm.npy')) 97 | mean = vect[0] 98 | std = vect[1] 99 | new_tile = simple_normalizer(new_tile, f, mean, std) 100 | f.write("Norm param mean {} std {}\n".format(mean, std)) 101 | f.write("Saved norm tile shape {} type {} min {} max {}\n".format(new_tile.shape, new_tile.dtype, 102 | np.min(new_tile), np.max(new_tile))) 103 | np.save(os.path.join(dest_path, name + "_" + tile_name.split(".")[0]), new_tile, allow_pickle=True) 104 | '''np.save(os.path.join(dest_path, name + "_mean_" + tile_name.split(".")[0]), mean, allow_pickle=True) 105 | np.save(os.path.join(dest_path, name + "_std_" + tile_name.split(".")[0]), std, allow_pickle=True) 106 | np.save(os.path.join(dest_path, name + "_center_" + tile_name.split(".")[0]), center, allow_pickle=True) 107 | np.save(os.path.join(dest_path, name + "_max_" + tile_name.split(".")[0]), mx, allow_pickle=True)''' 108 | elif action == "SAR_feature": 109 | f.write("Run real_SAR_feature_extractor\n") 110 | u, c = np.unique(new_tile, return_counts=True) 111 | 
f.write("Input tile shape {} type {} unique {} count {}\n".format(new_tile.shape, new_tile.dtype, u, c)) 112 | new_tile = real_SAR_feature_extractor(new_tile, f, filter_type, filter_kernel) 113 | # new_tile, mean, std, center, mx = normalizer(new_tile, f, center, center) 114 | vect = np.load(os.path.join(tile_path, 'norm.npy')) 115 | mean = vect[0] 116 | std = vect[1] 117 | new_tile = simple_normalizer(new_tile, f, mean, std) 118 | f.write("Norm param mean {} std {}\n".format(mean, std)) 119 | f.write("Saved norm tile shape {} type {} min {} max {}\n".format(new_tile.shape, new_tile.dtype, 120 | np.min(new_tile), np.max(new_tile))) 121 | np.save(os.path.join(dest_path, name + "_" + tile_name.split(".")[0]), new_tile, allow_pickle=True) 122 | '''np.save(os.path.join(dest_path, name + "_mean_" + tile_name.split(".")[0]), mean, allow_pickle=True) 123 | np.save(os.path.join(dest_path, name + "_std_" + tile_name.split(".")[0]), std, allow_pickle=True) 124 | np.save(os.path.join(dest_path, name + "_center_" + tile_name.split(".")[0]), center, allow_pickle=True) 125 | np.save(os.path.join(dest_path, name + "_max_" + tile_name.split(".")[0]), mx, allow_pickle=True)''' 126 | 127 | f.write("Execution time = {:.2f} s".format(time.time() - t)) 128 | f.close() 129 | 130 | 131 | def real_SAR_feature_extractor(raw_data, f, filer_type="box", filter_kernel=3): 132 | """ 133 | This function takes an np array of 4 bands [{R_hh,I_hh,R_hv,I_hv},w,h] and returns an array of shape [5,w,h] 134 | where each pixel is composed of 5 real values calculated as in the EUSAR paper 135 | To process [4,14000,9000] of type float32 it requires around 30 minutes 136 | Gaussian filter param 137 | - sigma < 0.5 -> 3x3 138 | - 0.5 <= sigma < 0.8334 -> 5x5 139 | - 0.8334 <= sigma < 1.17 -> 7x7 140 | - 1.17 <= sigma < 1.45 -> 9x9 141 | - sigma >= 1.45 -> 11x11 142 | Box filter only takes the kernel size, e.g. 7 = 7x7 etc. 143 | :param raw_data: input data to be processed 144 | :param f: open log file handle 145 | :param filer_type: {"gauss",
"box"} 146 | :param filter_kernel: 147 | :return: 148 | """ 149 | f.write("Data shape {} type {}\n".format(raw_data.shape, raw_data.dtype)) 150 | 151 | # image dim 152 | w = raw_data.shape[1] 153 | h = raw_data.shape[2] 154 | 155 | # create equivalent complex vector 156 | complex_data = np.zeros((2, w, h), dtype=np.complex64) 157 | # fill the complex vector 158 | complex_data[0].real = raw_data[0] 159 | complex_data[0].imag = raw_data[1] 160 | complex_data[1].real = raw_data[2] 161 | complex_data[1].imag = raw_data[3] 162 | 163 | f.write("Complex_data shape {} type {}\n".format(complex_data.shape, complex_data.dtype)) 164 | 165 | # init vector for single pixel cov matrix 166 | s = np.zeros((1, 2), dtype=np.complex64) 167 | 168 | # initialize output feature vector 169 | data = np.zeros((5, w, h), dtype=np.float32) 170 | 171 | # initialize vector for all pixel covariances 172 | covariance = np.zeros((4, w, h), dtype=np.float32) 173 | 174 | # loop over every pixel 175 | for i in range(0, w): 176 | for j in range(0, h): 177 | s[0, :] = complex_data[:, i, j] 178 | s_conj = np.conjugate(s) 179 | temp = s * s_conj.T 180 | c11_real = temp[0, 0].real 181 | c22_real = temp[1, 1].real 182 | c12 = temp[0, 1] 183 | covariance[:, i, j] = np.array((c11_real, c22_real, c12.real, c12.imag)) 184 | 185 | 186 | # filter 187 | if filer_type == "box": 188 | for i in range(covariance.shape[0]): 189 | covariance[i] = ndimage.uniform_filter(covariance[i], filter_kernel, mode="wrap") 190 | else: 191 | covariance = ndimage.gaussian_filter(covariance, filter_kernel, truncate=3) 192 | 193 | f.write("Filter type applied {} kernel dim {}\n".format(filer_type, filter_kernel)) 194 | # Calculate C12 abs 195 | # Init complex vector 196 | complex_c12 = np.zeros((w, h), dtype=np.complex64) 197 | # Fill the complex vector 198 | complex_c12.real = covariance[2] 199 | complex_c12.imag = covariance[3] 200 | 201 | c12_abs = np.abs(complex_c12) 202 | 203 | # removes zeros 204 | covariance = np.where(covariance <=
0.0, 1e-06, covariance) 205 | c12_abs = np.where(c12_abs <= 0.0, 1e-06, c12_abs) 206 | 207 | # create output feature vector 208 | data[0] = np.log10(covariance[0]) 209 | data[1] = np.log10(covariance[1]) 210 | data[2] = np.log10(c12_abs) 211 | data[3] = covariance[2] / c12_abs 212 | data[4] = covariance[3] / c12_abs 213 | 214 | f.write("Return data array with 5 polsar real features as in the EUSAR paper\n") 215 | 216 | return data 217 | 218 | 219 | def print_values(data, comp=None): 220 | 221 | for i in range(2): 222 | for k in range(data.shape[0]): 223 | if comp is not None: 224 | print("{}-{}-{}-{}".format(k, i, i, data[k, i, i])) 225 | if k < 2: 226 | print("{}-{}-{}-{}".format(k, i, i, comp[k, i, i])) 227 | else: 228 | print("{}-{}-{}-{}".format(k, i, i, data[k, i, i])) 229 | print() 230 | 231 | 232 | 233 | -------------------------------------------------------------------------------- /Lib/Datasets/processing/__pycache__/EUSAR_data_processing.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Datasets/processing/__pycache__/EUSAR_data_processing.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/Datasets/processing/__pycache__/utility.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Datasets/processing/__pycache__/utility.cpython-37.pyc -------------------------------------------------------------------------------- /Lib/Datasets/processing/__pycache__/utility.cpython-38.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Datasets/processing/__pycache__/utility.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/Datasets/runner/EUSAR_tile_processing.py: -------------------------------------------------------------------------------- 1 | import os 2 | from Lib.Datasets.processing.EUSAR_data_processing import png_to_numpy, process_tile 3 | 4 | 5 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 6 | """ 7 | Author: Alessandro Cattoi 8 | Description: This file is a runner to preprocess raw tiles 9 | """ 10 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 11 | 12 | 13 | orig = ['orig', 'orig_test'] 14 | dest = ['Train_corr', 'Test_corr'] 15 | for i in range(2): 16 | print(i) 17 | global_path = '/home/ale/Documents/Python/13_Tesi_2/' 18 | orig_path = 'Data/Datasets/EUSAR/' + orig[i] + '/' 19 | print(orig_path) 20 | dest_path = os.path.join(global_path, 'Data/Datasets/EUSAR/', dest[i]) 21 | print(dest_path) 22 | try: 23 | os.mkdir(dest_path) 24 | os.mkdir(os.path.join(dest_path, 'radar')) 25 | os.mkdir(os.path.join(dest_path, 'label')) 26 | os.mkdir(os.path.join(dest_path, 'rgb')) 27 | except FileExistsError: 28 | print('Already existing folder') 29 | 30 | label_path = os.path.join(global_path, orig_path, 'label/') 31 | # create unique numpy array of labels 32 | png_to_numpy(label_path, dest_path, '1') 33 | 34 | tile_path = os.path.join(global_path, orig_path, 'radar') 35 | tile_name = 'radar.tif' 36 | dest_folder_radar = os.path.join(global_path, dest_path, 'radar/') 37 | # process SAR tile 38 | process_tile(tile_path, tile_name, dest_folder_radar, 'SAR_feature', '1', 'box', 3) 39 | 40 | tile_path = os.path.join(global_path, orig_path, 'rgb') 41 | tile_name = 'rgb.tif' 42 | dest_folder_rgb =
os.path.join(global_path, dest_path, 'rgb/') 43 | # process rgb tile 44 | process_tile(tile_path, tile_name, dest_folder_rgb, 'tif2npy', '1') 45 | -------------------------------------------------------------------------------- /Lib/Datasets/runner/cut_in_patches.py: -------------------------------------------------------------------------------- 1 | import os 2 | from Lib.Datasets.processing.utility import cut_tiles, cut_tiles_radar 3 | 4 | 5 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 6 | """ 7 | Author: Alessandro Cattoi 8 | Description: This file is a runner to cut processed tiles into patches 9 | """ 10 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 11 | 12 | 13 | dataset_name = 'EUSAR' 14 | ps = [128] # [32, 128, 192] 15 | dataset_type = ['Test_corr', 'Train_corr'] 16 | for typ in dataset_type: 17 | for patch_size in ps: 18 | global_path = '/home/ale/Documents/Python/13_Tesi_2/' 19 | data_orig = os.path.join(global_path, 'Data/Datasets/', dataset_name, typ) 20 | dest = os.path.join(global_path, 'Data/', typ.split('_')[0], dataset_name) 21 | dest_path = os.path.join(dest, str(patch_size) + '_sn_corr/') 22 | dest_path_trans = os.path.join(dest, str(patch_size) + '_trans_corr/') 23 | print('Creating:') 24 | print(' - Global path {}'.format(global_path)) 25 | print(' - Data path {}'.format(data_orig)) 26 | print(' - Dest path {}\n{}'.format(dest_path, dest_path_trans)) 27 | try: 28 | os.mkdir(dest_path_trans) 29 | os.mkdir(dest_path) 30 | os.mkdir(os.path.join(dest_path_trans, 'radar')) 31 | os.mkdir(os.path.join(dest_path_trans, 'rgb')) 32 | os.mkdir(os.path.join(dest_path, 'radar')) 33 | os.mkdir(os.path.join(dest_path, 'label')) 34 | os.mkdir(os.path.join(dest_path, 'rgb')) 35 | except FileExistsError: 36 | print('Already existing folder') 37 | 38 | max_n_bad_pix = 1 39 | overlapping = 0.5 40 | padding = True 41 | 42 |
cut_tiles(data_orig, dest_path, '1', patch_size, max_n_bad_pix, overlapping, padding) 43 | cut_tiles_radar(data_orig, dest_path_trans, '1', patch_size, max_n_bad_pix, overlapping, padding) 44 | print(typ) 45 | print(dest_path) 46 | for i in os.listdir(dest_path): 47 | if '.' not in i: 48 | print(i) 49 | print(len(os.listdir(os.path.join(dest_path, i)))) 50 | 51 | print(dest_path_trans) 52 | for i in os.listdir(dest_path_trans): 53 | if '.' not in i: 54 | print(i) 55 | print(len(os.listdir(os.path.join(dest_path_trans, i)))) 56 | -------------------------------------------------------------------------------- /Lib/Nets/BL/BL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import os 4 | from Lib.Nets.utils.arch.arch import Generator, newSN 5 | from Lib.Nets.utils.generic.image2tensorboard import display_label_single_c, display_input, display_predictions 6 | from tqdm import tqdm 7 | from Lib.utils.generic.generic_utils import start, stop 8 | from Lib.utils.metrics.Accuracy import Accuracy 9 | from Lib.utils.Logger.Logger import Logger 10 | from Lib.Nets.utils.generic.generic_training import set_requires_grad, calculate_accuracy, breaker 11 | import itertools 12 | import pickle as pkl 13 | 14 | 15 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 16 | """ 17 | Author: Alessandro Cattoi 18 | Description: This file implements the Baseline 19 | """ 20 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 21 | 22 | 23 | class BL: 24 | """ 25 | This class implements all the variables, functions and methods required to deploy a shallow U-Net which is used as a 26 | classifier on top of a pretrained model 27 | """ 28 | def __init__(self, opt, device): 29 | """ 30 | define network 31 | :param opt: opt contains all the variables that are configurable when
launching the script, check the 32 | folder Lib/Nets/utils/config/: there are three scripts which are used to configure networks 33 | :param device: cuda device 34 | :return: 35 | """ 36 | # image placeholders 37 | self.real_S = None 38 | self.label = None 39 | self.seg_map = None 40 | # cost function values 41 | self.loss = None 42 | # general 43 | self.device = device 44 | self.opt = opt 45 | self.Accuracy_test = Accuracy() 46 | self.Accuracy_train = Accuracy() 47 | self.Logger_test = Logger(self.opt.mode) 48 | self.Logger_train = Logger(self.opt.mode) 49 | self.sar_c_vis = min(self.opt.sar_c, 3) 50 | self.posx_train = pkl.load(open(os.path.join(opt.data_dir_train, 'posx.pkl'), "rb")) 51 | self.posy_train = pkl.load(open(os.path.join(opt.data_dir_train, 'posy.pkl'), "rb")) 52 | self.posx_test = pkl.load(open(os.path.join(opt.data_dir_test, 'posx.pkl'), "rb")) 53 | self.posy_test = pkl.load(open(os.path.join(opt.data_dir_test, 'posy.pkl'), "rb")) 54 | # net 55 | self.netG_S2O = Generator(self.opt.sar_c, self.opt.optical_c, self.opt.dropout, self.opt.bias) 56 | # self.SN = SN(self.opt.sar_c, self.opt.N_classes).to(self.device) 57 | self.SegNet = newSN(self.opt.N_classes, self.opt.bias, self.opt.dropout).to(self.device) 58 | 59 | if self.opt.mode == "train": 60 | print('Mode -> train') 61 | 62 | set_requires_grad(self.SegNet, True) 63 | set_requires_grad(self.netG_S2O, True) 64 | 65 | self.criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean').to(self.device) 66 | self.G_net_chain = itertools.chain(self.netG_S2O.parameters(), self.SegNet.parameters()) 67 | self.optimizer = torch.optim.Adam(self.G_net_chain, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) 68 | 69 | elif self.opt.mode == "eval": 70 | print('Mode -> eval') 71 | set_requires_grad(self.netG_S2O, False) 72 | set_requires_grad(self.SegNet, False) 73 | file_GAN = os.path.join(self.opt.global_path, self.opt.pretrained_GAN, 74 | 'checkpoint_epoch_' + str(self.opt.GAN_epoch) + '.pt') 75 | file_SN
= os.path.join(self.opt.global_path, self.opt.restoring_rep_path, 76 | 'checkpoint_epoch_' + str(self.opt.start_from_epoch) + '.pt') 77 | self.load(file_GAN, file_SN) 78 | temp = 22 - (9 - self.opt.res_block_N) 79 | self.netG_S2O = self.netG_S2O.model[0:temp] 80 | self.netG_S2O.to(self.device) 81 | 82 | def set_input(self, data): 83 | """ 84 | Unpack input data from the dataloader 85 | :param data: 86 | :return: 87 | """ 88 | self.real_S = data['radar'].to(self.device) 89 | self.label = data['label'].to(self.device).type(torch.long) 90 | 91 | def forward(self): 92 | """ 93 | Run forward pass 94 | :return: 95 | """ 96 | # use the output of the first up sampling layer (second last relu) 97 | # detach should detach from Generator. 98 | feature_map = self.netG_S2O(self.real_S) 99 | self.seg_map = self.SegNet(feature_map) 100 | 101 | def backward(self): 102 | """ 103 | Calculates the loss and its gradients by running loss.backward() 104 | :return: NA 105 | """ 106 | self.loss = self.criterion(self.seg_map, self.label) 107 | self.loss.backward() 108 | 109 | def optimize(self): 110 | """ 111 | Calculate losses, gradients, and update network weights; called in every training iteration 112 | :return: 113 | """ 114 | # run the forward pass
115 | self.forward() 116 | # set gradients to zero 117 | self.optimizer.zero_grad() 118 | # calculate gradients 119 | self.backward() 120 | # update weights 121 | self.optimizer.step() 122 | 123 | def save_model(self, epoch): 124 | """ 125 | Save model parameters 126 | :param epoch: current epoch 127 | :return: 128 | """ 129 | out_file = os.path.join(self.opt.checkpoint_dir, 'checkpoint_epoch_' + str(epoch) + ".pt") 130 | # save model 131 | data = {"G_S2O": self.netG_S2O.state_dict(), 132 | "SN": self.SegNet.state_dict(), 133 | "opt": self.optimizer.state_dict(), 134 | } 135 | torch.save(data, out_file) 136 | 137 | def load(self, file_GAN, file_SN): 138 | """ 139 | Restore generator and classifier parameters 140 | :param file_GAN: file from which to load GAN parameters 141 | :param file_SN: file from which to load SN parameters 142 | :return: 143 | """ 144 | data_GAN = torch.load(file_GAN) 145 | data_SN = torch.load(file_SN) 146 | if self.opt.mode == 'train': 147 | self.netG_S2O.load_state_dict(data_GAN['G_S2O']) 148 | self.SegNet.load_state_dict(data_SN['SN']) 149 | self.optimizer.load_state_dict(data_SN['opt']) 150 | elif self.opt.mode == 'eval': 151 | self.netG_S2O.load_state_dict(data_GAN['G_S2O']) 152 | self.SegNet.load_state_dict(data_SN['SN']) 153 | 154 | def tb_add_step_loss(self, writer, global_step): 155 | """ 156 | Saves segnet loss and accuracy to tensorboard 157 | :param writer: pointer to tb 158 | :param global_step: step for tb 159 | :return: 160 | """ 161 | # log loss 162 | temp_loss = self.loss.item() 163 | writer.add_scalar("Train/Loss", temp_loss, global_step=global_step) 164 | loss_SN = {"loss_SN": temp_loss} 165 | self.Logger_train.append_SN_loss(loss_SN) 166 | 167 | def tb_add_step_images(self, writer=None, global_step=None): 168 | """ 169 | Saves all net images to tensorboard 170 | - real_S 171 | - label 172 | - prediction 173 | :param writer: pointer to tb writer 174 | :param global_step: step for tb 175 | :return: 176 | """ 177 | label_norm,
mask = display_label_single_c(self.label) 178 |         seg_map_norm = display_predictions(self.seg_map, True, mask) 179 |         real_S_norm = display_input(self.real_S[:, 0:self.sar_c_vis, :, :], False) 180 |         # if no writer is passed, instead of updating tb the function returns the images; 181 |         # this variant is useful to create tiles 182 |         if writer is None: 183 |             return label_norm[0], seg_map_norm[0] 184 |         else: 185 |             writer.add_images("1 - Labels", label_norm, global_step=global_step) 186 |             writer.add_images("2 - Map", seg_map_norm, global_step=global_step) 187 |             writer.add_images("3 - Radar Input", real_S_norm, global_step=global_step) 188 |  189 |     def train(self, train_dataset, eval_dataset, writer): 190 |         """ 191 |         Run the training for the required epochs 192 |         :param train_dataset: dataset used to train the network 193 |         :param eval_dataset: dataset used to evaluate the network 194 |         :param writer: a tensorboard instance to track info 195 |         :return: 196 |         """ 197 |         global_step = 0 198 |         if self.opt.acc_log_freq == 0: 199 |             calculate_accuracy(self, eval_dataset, writer, global_step, "Test", 0, self.posx_test, self.posy_test, self.opt.test_size, True) 200 |             calculate_accuracy(self, train_dataset, writer, global_step, "Train", 0, self.posx_train, self.posy_train, self.opt.train_size, True) 201 |             self.Logger_train.append_acc_step({"step": global_step}) 202 |             self.Logger_test.append_acc_step({"step": global_step}) 203 |         for epoch in range(self.opt.tot_epochs): 204 |             t = start() 205 |             text_line = "=" * 27 + "BL EPOCH " + str(epoch) + "/" + str(self.opt.tot_epochs) + "=" * 27 206 |             print(text_line) 207 |             progress_bar = tqdm(enumerate(train_dataset), total=len(train_dataset)) 208 |             # Train on each patch in the dataset 209 |             for i, data in progress_bar: 210 |                 self.set_input(data) 211 |                 self.optimize() 212 |                 self.tb_add_step_loss(writer, global_step) 213 |                 self.Logger_train.append_loss_step({"step": global_step}) 214 |                 global_step = global_step + 1 215 |             if epoch >= 0 and epoch % self.opt.acc_log_freq == 0: 216 |                 
calculate_accuracy(self, eval_dataset, writer, global_step, "Test", epoch, self.posx_test, self.posy_test, self.opt.test_size, True) 217 |                 calculate_accuracy(self, train_dataset, writer, global_step, "Train", epoch, self.posx_train, self.posy_train, self.opt.train_size, True) 218 |                 self.Logger_train.append_acc_step({"step": global_step}) 219 |                 self.Logger_test.append_acc_step({"step": global_step}) 220 |             if epoch >= 0 and epoch % self.opt.save_model_freq == 0: 221 |                 self.save_model(epoch) 222 |                 self.Logger_test.save_logger(self.opt.checkpoint_dir, name='test') 223 |                 self.Logger_train.save_logger(self.opt.checkpoint_dir, name='train') 224 |             # Epoch duration 225 |             s = 'BL Epoch {} '.format(epoch) 226 |             stop(t, s) 227 |             if breaker(self.opt, epoch): 228 |                 print('EXECUTION FORCED TO STOP AT {} EPOCHS'.format(epoch)) 229 |                 break 230 | 
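Both BL and SN truncate the pretrained generator with `self.netG_S2O.model[0:temp]`, where `temp = 22 - (9 - res_block_N)` keeps everything up to the first upsampling ReLU. This works because `nn.Sequential` supports slicing and returns a new `nn.Sequential`. A minimal sketch with a hypothetical miniature backbone (not the repo's 22-module generator):

```python
import torch
import torch.nn as nn

# Toy stand-in for netG_S2O.model: slicing an nn.Sequential with [0:k]
# returns a new nn.Sequential holding the first k modules, which is how
# SN/BL expose an intermediate feature map instead of the final output.
backbone = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),   # index 0
    nn.ReLU(),                       # index 1
    nn.Conv2d(8, 16, 3, padding=1),  # index 2
    nn.ReLU(),                       # index 3
    nn.Conv2d(16, 1, 3, padding=1),  # index 4 - output head we want to drop
)

feature_extractor = backbone[0:4]  # keep modules 0..3, up to the last ReLU
features = feature_extractor(torch.randn(1, 3, 32, 32))
print(tuple(features.shape))  # (1, 16, 32, 32)
```

The sliced module shares parameters with the original, so loading a checkpoint before slicing (as `BL.__init__` does) populates the truncated network as well.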
-------------------------------------------------------------------------------- /Lib/Nets/RT/RT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from Lib.Nets.utils.arch.arch import Generator 5 | from Lib.Nets.utils.generic.init_weights import init_weights 6 | from Lib.Nets.utils.generic.DecayLR import
DecayLR 7 | from Lib.Nets.utils.generic.image2tensorboard import display_input, reconstruct_tile 8 | from tqdm import tqdm 9 | from Lib.utils.Logger.Logger import Logger 10 | from Lib.Nets.utils.generic.generic_training import set_requires_grad, breaker 11 | from Lib.Nets.utils.generic.trainSN import trainSN 12 | from Lib.utils.generic.generic_utils import start, stop 13 | import pickle as pkl 14 |  15 |  16 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 17 | """ 18 | Author: Alessandro Cattoi 19 | Description: This file implements the regressive transcoder 20 | """ 21 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 22 |  23 |  24 | class RT: 25 |     """ 26 |     This class implements all the variables and methods required to deploy a very simple conditional GAN 27 |     trained with a regression loss only 28 |     """ 29 |  30 |     def __init__(self, opt, device): 31 |         """ 32 |         :param opt: code options 33 |         :param device: cuda device 34 |         :return: 35 |         """ 36 |         # images placeholders 37 |         self.real_S = None 38 |         self.fake_O = None 39 |         self.label = None  # real_O (rgb) 40 |         self.name = None 41 |  42 |         # cost function values 43 |         self.loss = None 44 |  45 |         # general 46 |         self.device = device 47 |         self.opt = opt 48 |         self.Logger = Logger(self.opt.mode) 49 |         self.trans = None 50 |         self.trans_eval = None 51 |         self.sar_c_vis = min(self.opt.sar_c, 3) 52 |         self.opt_c_vis = min(self.opt.optical_c, 3) 53 |         self.posx_train = pkl.load(open(os.path.join(opt.data_dir_train, 'posx.pkl'), "rb")) 54 |         self.posy_train = pkl.load(open(os.path.join(opt.data_dir_train, 'posy.pkl'), "rb")) 55 |         self.posx_test = pkl.load(open(os.path.join(opt.data_dir_test, 'posx.pkl'), "rb")) 56 |         self.posy_test = pkl.load(open(os.path.join(opt.data_dir_test, 'posy.pkl'), "rb")) 57 |         self.flag = False 58 |         # Define generator 59 |         self.netG_S2O = 
Generator(self.opt.sar_c, self.opt.optical_c, self.opt.dropout, self.opt.bias).to(self.device) 60 |  61 |         if self.opt.mode == "train": 62 |             print('Mode -> train') 63 |             set_requires_grad(self.netG_S2O, True) 64 |             # init weights 65 |             init_weights(self.netG_S2O, self.opt.init_type, self.opt.init_gain) 66 |             # define loss functions 67 |             self.criterion = torch.nn.L1Loss().to(self.device) 68 |             # [SETUP] TODO: Use cross entropy or BCELoss? 69 |             # initialize optimizers 70 |             # [SETUP] TODO: which optimizer, Adam or SGD? 71 |             self.optimizer = torch.optim.Adam(self.netG_S2O.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) 72 |             # reference the step method of the DecayLR class 73 |             self.lr_lambda = DecayLR(self.opt.tot_epochs, self.opt.start_from_epoch, self.opt.decay_epochs).step 74 |             # initialise the network scheduler, passing the function above 75 |             self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lr_lambda) 76 |             if self.opt.restoring_rep_path is not None: 77 |                 file = os.path.join(self.opt.restoring_rep_path, 78 |                                     'checkpoint_epoch_' + str(self.opt.start_from_epoch) + '.pt') 79 |                 self.load(file) 80 |         elif self.opt.mode == "eval": 81 |             self.sar_c_vis = self.opt.sar_c 82 |             self.opt_c_vis = self.opt.optical_c 83 |             print('Mode -> eval') 84 |             set_requires_grad(self.netG_S2O, False) 85 |             file = os.path.join(self.opt.restoring_rep_path, 86 |                                 'checkpoint_epoch_' + str(self.opt.start_from_epoch) + '.pt') 87 |             self.load(file) 88 |  89 |     def set_input(self, data): 90 |         """ 91 |         Unpack input data from the dataloader 92 |         :param data: 93 |         :return: 94 |         """ 95 |         self.real_S = data['radar'].to(self.device) 96 |         self.label = data['rgb'].to(self.device) 97 |         self.name = data['name'] 98 |  99 |     def forward(self, var_name): 100 |         """ 101 |         Run forward pass 102 |         :return: 103 |         """ 104 |         self.fake_O = self.netG_S2O(self.real_S)  # G_S(S) 105 |         if self.flag: 106 |             name = self.name 107 |             img = self.fake_O.cpu().detach() 108 |             for i, n in enumerate(name): 109 |                 
temp = int(n.split('_')[1]) - 1 110 |                 getattr(self, var_name)[temp] = img[i, 0:self.opt_c_vis] 111 |  112 |     def backward(self): 113 |         """ 114 |         Calculate the loss for the generator 115 |         :return: 116 |         """ 117 |         self.loss = self.criterion(self.fake_O, self.label) 118 |         self.loss.backward() 119 |  120 |     def update_learning_rate(self): 121 |         """ 122 |         Request a step from the scheduler to update the learning rate 123 |         :return: 124 |         """ 125 |         old_lr = self.optimizer.param_groups[0]['lr'] 126 |         self.lr_scheduler.step() 127 |         lr = self.optimizer.param_groups[0]['lr'] 128 |         print('learning rate %.7f -> %.7f' % (old_lr, lr)) 129 |  130 |     def optimize(self): 131 |         """ 132 |         Calculate losses, gradients, and update network weights; called in every training iteration 133 |         :return: 134 |         """ 135 |         # compute fake images and reconstruction images. 136 |         self.forward('trans') 137 |         # set G_S and G_O's gradients to zero 138 |         self.optimizer.zero_grad() 139 |         # calculate gradients for G_S and G_O 140 |         self.backward() 141 |         # update G's weights 142 |         self.optimizer.step() 143 |  144 |     def save_model(self, epoch): 145 |         """ 146 |         Save model parameters 147 |         :param epoch: actual epoch 148 |         :return: 149 |         """ 150 |         out_file = os.path.join(self.opt.checkpoint_dir, 'checkpoint_epoch_' + str(epoch) + ".pt") 151 |         # save model 152 |         data = {"G_S2O": self.netG_S2O.state_dict(), 153 |                 "opt": self.optimizer.state_dict(), 154 |                 } 155 |         torch.save(data, out_file) 156 |  157 |     def load(self, file): 158 |         """ 159 |         Restore model parameters 160 |         :param file: file from which to load parameters 161 |         :return: 162 |         """ 163 |         data = torch.load(file) 164 |         if self.opt.mode == 'train': 165 |             self.netG_S2O.load_state_dict(data['G_S2O']) 166 |             self.optimizer.load_state_dict(data['opt']) 167 |         elif self.opt.mode == 'eval': 168 |             self.netG_S2O.load_state_dict(data['G_S2O']) 169 |  170 |     def tb_add_step_loss_g(self, writer, global_step): 171 |         """ 172 |         This function adds G losses to
tensorboard and stores the value in the logger 173 |  174 |         - loss 175 |         :return: 176 |         """ 177 |         step_loss = { 178 |             'loss': self.loss.item(), 179 |         } 180 |         writer.add_scalars("Train/Generator", step_loss, global_step=global_step) 181 |         self.Logger.append_G(step_loss) 182 |  183 |     def tb_add_step_images(self, writer=None, global_step=None): 184 |         """ 185 |         Saves all net images to tensorboard 186 |         :param writer: pointer to tb 187 |         :param global_step: step for tb 188 |         :return: 189 |         """ 190 |         real_S_norm = display_input(self.real_S[:, 0:self.sar_c_vis, :, :], False) 191 |         real_O_norm = display_input(self.label[:, 0:3, :, :], False) 192 |         fake_O_norm = display_input(self.fake_O[:, 0:3, :, :], False) 193 |         real = np.concatenate([real_S_norm, real_O_norm, fake_O_norm]) 194 |         # if no writer is passed, instead of updating tb the function returns the images; 195 |         # this variant is useful to create tiles 196 |         if writer is None: 197 |             return real_S_norm[0], real_O_norm[0], fake_O_norm[0] 198 |         else: 199 |             writer.add_images("Real Radar - Real Optical - Fake Optical", real, global_step=global_step) 200 |  201 |     def train(self, train_dataset, eval_dataset, writer=None): 202 |         """ 203 |         Run the training for the required epochs 204 |         :param train_dataset: dataset used to train the network 205 |         :param eval_dataset: dataset used to eval the network 206 |         :param writer: a tensorboard instance to track info 207 |         :return: 208 |         """ 209 |         self.trans = torch.zeros((len(train_dataset.dataset), self.opt_c_vis, self.opt.patch_size, self.opt.patch_size), 210 |                                  dtype=torch.float32) 211 |         epoch = self.opt.start_from_epoch 212 |         global_step = epoch * len(train_dataset) 213 |         for epoch in range(epoch, self.opt.tot_epochs): 214 |             t = start() 215 |             text_line = "=" * 30 + "EPOCH " + str(epoch) + "/" + str(self.opt.tot_epochs) + "=" * 30 216 |             print(text_line) 217 |  218 |             progress_bar = tqdm(enumerate(train_dataset), total=len(train_dataset)) 219 |             # 
Train on each patch in the dataset 220 |             for i, data in progress_bar: 221 |                 self.set_input(data) 222 |                 self.optimize() 223 |                 # write generator loss to tensorboard 224 |                 if global_step > 0 and global_step % self.opt.loss_log_freq == 0: 225 |                     self.tb_add_step_loss_g(writer, global_step) 226 |                     self.Logger.append_loss_step({"step": global_step}) 227 |                 global_step = global_step + 1 228 |  229 |             self.update_learning_rate() 230 |             if epoch >= 0 and epoch % self.opt.save_model_freq == 0: 231 |                 self.save_model(epoch) 232 |                 # reconstruct_tile('train', self.opt.patch_size, self.posx_train, self.posy_train, self.opt.tb_dir, 233 |                 #                 self.opt.train_size, epoch, self.trans)  # , parameter_path=par_path) 234 |             if epoch >= 0 and epoch % self.opt.images_log_freq == 0: 235 |                 self.eval(train_dataset, epoch, self.posx_train, self.posy_train, 'train') 236 |                 self.eval(eval_dataset, epoch, self.posx_test, self.posy_test, 'test') 237 |  238 |                 trainSN(self.opt, epoch, self.device) 239 |             self.Logger.save_logger(self.opt.checkpoint_dir, '') 240 |             # Epoch duration 241 |             s = 'Epoch {} took'.format(epoch) 242 |             stop(t, s) 243 |             if breaker(self.opt, epoch): 244 |                 print('EXECUTION FORCED TO STOP AT {} EPOCHS'.format(epoch)) 245 |                 break 246 |  247 |     def eval(self, dataset=None, epoch=0, posx=None, posy=None, name=''): 248 |         set_requires_grad(self.netG_S2O, False) 249 |         self.flag = True 250 |         self.trans_eval = torch.zeros((len(dataset.dataset), self.opt_c_vis, self.opt.patch_size, self.opt.patch_size), 251 |                                       dtype=torch.float32) 252 |         progress_bar = tqdm(enumerate(dataset), total=len(dataset)) 253 |         # Translate each patch in the dataset 254 |         for i, data in progress_bar: 255 |             self.set_input(data) 256 |             self.forward('trans_eval') 257 |         reconstruct_tile(name, self.opt.patch_size, posx, posy, self.opt.tb_dir, self.opt.test_size, epoch, self.trans_eval) 258 |         # , parameter_path=par_path) 259 |         self.flag = False 260 |         set_requires_grad(self.netG_S2O, True) 261 | 
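RT builds its scheduler from `DecayLR(tot_epochs, start_from_epoch, decay_epochs).step` passed to `torch.optim.lr_scheduler.LambdaLR`. The `DecayLR` implementation itself is not included in this excerpt; a CycleGAN-style linear decay rule of the kind it presumably implements looks like this (class and parameter names here are illustrative, not the repo's):

```python
class LinearDecay:
    """Multiplicative lr factor for LambdaLR: constant 1.0 until
    `decay_start`, then linear decay to zero at `tot_epochs`."""

    def __init__(self, tot_epochs, start_epoch, decay_start):
        self.tot_epochs = tot_epochs
        self.start_epoch = start_epoch  # offset when resuming from a checkpoint
        self.decay_start = decay_start

    def step(self, epoch):
        # LambdaLR multiplies the base lr by the value returned here
        overshoot = max(0, epoch + self.start_epoch - self.decay_start)
        return 1.0 - overshoot / (self.tot_epochs - self.decay_start)

sched = LinearDecay(tot_epochs=200, start_epoch=0, decay_start=100)
print(sched.step(50))   # 1.0  (flat phase)
print(sched.step(150))  # 0.5  (halfway through the decay)
```

`update_learning_rate` then only needs to call `lr_scheduler.step()` once per epoch, which is exactly what `RT.train` does after the patch loop.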
-------------------------------------------------------------------------------- /Lib/Nets/SN/SN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import os 4 | from Lib.Nets.utils.arch.arch import Generator, newSN 5 | from Lib.Nets.utils.generic.image2tensorboard import display_label_single_c, display_input, display_predictions 6 | from tqdm import tqdm 7 | from Lib.utils.generic.generic_utils import start, stop 8 | from Lib.utils.metrics.Accuracy import Accuracy 9 | from Lib.utils.Logger.Logger import Logger 10 | from Lib.Nets.utils.generic.generic_training import set_requires_grad, calculate_accuracy, breaker 11 | import pickle as pkl 12 |  13 |  14 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 15 | """ 16 | Author: Alessandro Cattoi 17 | Description: This file implements the classifier network to put on top of the feature extractors 18 | """ 19 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 20 |  21 |  22 | class SN: 23 |     """ 24 |     This class implements all the variables and methods required to deploy a shallow U-Net which is used as a classifier 25 |     on top of a pretrained model 26 |     """ 27 |     def __init__(self, opt, device): 28 |         """ 29 |         define network 30 |         :param opt: opt contains all the variables that are configurable when launching the script; check the 31 |         folder Lib/Nets/utils/config/, where there are three scripts used to configure networks 32 |         :param device: cuda device 33 |         :return: 34 |         """ 35 |         # images placeholders 36 |         self.real_S = None 37 |         self.label = None 38 |         self.seg_map = None 39 |         # cost function values 40 |         self.loss = None 41 |         # general 42 |         self.device = device 43 |         self.opt = opt 44 |         self.Accuracy_test = Accuracy() 45 |         self.Accuracy_train = Accuracy() 46 |         self.Logger_test = 
Logger(self.opt.mode) 47 |         self.Logger_train = Logger(self.opt.mode) 48 |         self.sar_c_vis = min(self.opt.sar_c, 3) 49 |         self.posx_train = pkl.load(open(os.path.join(opt.data_dir_train, 'posx.pkl'), "rb")) 50 |         self.posy_train = pkl.load(open(os.path.join(opt.data_dir_train, 'posy.pkl'), "rb")) 51 |         self.posx_test = pkl.load(open(os.path.join(opt.data_dir_test, 'posx.pkl'), "rb")) 52 |         self.posy_test = pkl.load(open(os.path.join(opt.data_dir_test, 'posy.pkl'), "rb")) 53 |         # net 54 |         self.netG_S2O = Generator(self.opt.sar_c, self.opt.optical_c, self.opt.dropout, self.opt.bias) 55 |         #self.SN = SN(self.opt.sar_c, self.opt.N_classes).to(self.device) 56 |         self.SegNet = newSN(self.opt.N_classes, self.opt.bias, self.opt.dropout).to(self.device) 57 |  58 |         if self.opt.mode == "train": 59 |             print('Mode -> train') 60 |             if self.opt.GAN_epoch is not None: 61 |                 file = os.path.join(self.opt.global_path, self.opt.pretrained_GAN, 62 |                                     'checkpoint_epoch_' + str(self.opt.GAN_epoch) + '.pt') 63 |                 self.load_GAN(file) 64 |                 print('Loaded model {}'.format(file)) 65 |                 #self.load_SN(file) 66 |                 #self.SN.to(self.device) 67 |             else: 68 |                 # [SETUP] TODO: init weight of segnet 69 |                 # init weights 70 |                 pass 71 |             set_requires_grad(self.SegNet, True) 72 |             set_requires_grad(self.netG_S2O, False) 73 |  74 |             self.criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean').to(self.device) 75 |             # self.optimizer = torch.optim.RMSprop(self.SN.parameters(), lr=self.opt.lr_SN, 76 |             #                                     weight_decay=self.opt.weight_decay_SN) 77 |             self.optimizer = torch.optim.Adam(self.SegNet.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) 78 |         elif self.opt.mode == "eval": 79 |             print('Mode -> eval') 80 |             set_requires_grad(self.netG_S2O, False) 81 |             set_requires_grad(self.SegNet, False) 82 |             file_GAN = os.path.join(self.opt.global_path, self.opt.pretrained_GAN, 83 |                                     'checkpoint_epoch_' + str(self.opt.GAN_epoch) + '.pt') 84 |             file_SN = os.path.join(self.opt.global_path, self.opt.restoring_rep_path, 85 |                                    
'checkpoint_epoch_' + str(self.opt.start_from_epoch) + '.pt') 86 |             self.load_all(file_GAN, file_SN) 87 |             temp = 22 - (9 - self.opt.res_block_N) 88 |             self.netG_S2O = self.netG_S2O.model[0:temp] 89 |             self.netG_S2O.to(self.device) 90 |  91 |     def set_input(self, data): 92 |         """ 93 |         Unpack input data from the dataloader 94 |         :param data: 95 |         :return: 96 |         """ 97 |         self.real_S = data['radar'].to(self.device) 98 |         self.label = data['label'].to(self.device).type(torch.long) 99 |  100 |     def forward(self): 101 |         """ 102 |         Run forward pass 103 |         :return: 104 |         """ 105 |         # use the output of the first up sampling layer (second last relu) 106 |         # detach cuts the graph so no gradients flow back into the Generator. 107 |         feature_map = self.netG_S2O(self.real_S).detach() 108 |         self.seg_map = self.SegNet(feature_map) 109 |  110 |     def backward(self): 111 |         """ 112 |         Calculate the loss and compute gradients by running loss.backward() 113 |         :return: NA 114 |         """ 115 |         self.loss = self.criterion(self.seg_map, self.label) 116 |  117 |         self.loss.backward() 118 |  119 |     def optimize(self): 120 |         """ 121 |         Calculate losses, gradients, and update network weights; called in every training iteration 122 |         :return: 123 |         """ 124 |         # compute fake images and reconstruction images.
125 |         self.forward() 126 |         set_requires_grad(self.netG_S2O, False) 127 |         # set gradients to zero 128 |         self.optimizer.zero_grad() 129 |         # calculate gradients 130 |         self.backward() 131 |         # update only the SegNet weights 132 |         self.optimizer.step() 133 |  134 |     def save_model(self, epoch): 135 |         """ 136 |         Save model parameters 137 |         :param epoch: actual epoch 138 |         :return: 139 |         """ 140 |         out_file = os.path.join(self.opt.checkpoint_dir, 'checkpoint_epoch_' + str(epoch) + ".pt") 141 |         # save model 142 |         data = {"G_S2O": self.netG_S2O.state_dict(), 143 |                 "SN": self.SegNet.state_dict(), 144 |                 "opt_SegNet": self.optimizer.state_dict(), 145 |                 } 146 |  147 |         torch.save(data, out_file) 148 |  149 |     def load_GAN(self, file): 150 |         """ 151 |         Restore generator parameters 152 |         :param file: file from which to load parameters 153 |         :return: 154 |         """ 155 |         data = torch.load(file) 156 |         self.netG_S2O.load_state_dict(data['G_S2O']) 157 |  158 |     def load_SN(self, file): 159 |         """ 160 |         Restore classifier parameters from selected pretrained generator layers 161 |         :param file: file from which to load parameters 162 |         :return: 163 |         """ 164 |         pt_dict = torch.load(file) 165 |         pretrained_dict = {k: v for k, v in pt_dict['G_S2O'].items() if '19' in k or '22' in k} 166 |  167 |         SN_dict = self.SegNet.state_dict() 168 |  169 |         SN_dict_names = [] 170 |         for k, v in SN_dict.items(): 171 |             SN_dict_names.append(k) 172 |  173 |         i = 0 174 |         for k, v in pretrained_dict.items(): 175 |             SN_dict[SN_dict_names[i]] = v 176 |             i = i + 1 177 |  178 |         self.SegNet.load_state_dict(SN_dict) 179 |  180 |     def load_all(self, file_GAN, file_SN): 181 |         """ 182 |         Restore generator and classifier parameters 183 |         :param file_GAN: file from which to load the GAN parameters 184 |         :param file_SN: file from which to load the SN parameters 185 |         :return: 186 |         """ 187 |         data_GAN = torch.load(file_GAN) 188 |         data_SN = torch.load(file_SN) 189 |         if self.opt.mode == 'train': 190 |             self.netG_S2O.load_state_dict(data_GAN['G_S2O']) 191 |             self.SegNet.load_state_dict(data_SN['SN']) 192 |             
self.optimizer.load_state_dict(data_SN['opt_SegNet']) 193 |         elif self.opt.mode == 'eval': 194 |             self.netG_S2O.load_state_dict(data_GAN['G_S2O']) 195 |             self.SegNet.load_state_dict(data_SN['SN']) 196 |  197 |     def tb_add_step_loss(self, writer, global_step): 198 |         """ 199 |         Save the SegNet loss to tensorboard and the logger 200 |         :param writer: pointer to tb 201 |         :param global_step: step for tb 202 |         :return: 203 |         """ 204 |         # log loss 205 |         temp_loss = self.loss.item() 206 |         writer.add_scalar("Train/Loss", temp_loss, global_step=global_step) 207 |         loss_SN = {"loss_SN": temp_loss} 208 |         self.Logger_train.append_SN_loss(loss_SN) 209 |  210 |     def tb_add_step_images(self, writer=None, global_step=None): 211 |         """ 212 |         Saves all net images to tensorboard 213 |         - real_S 214 |         - label 215 |         - prediction 216 |         :param writer: pointer to tb writer 217 |         :param global_step: step for tb 218 |         :return: 219 |         """ 220 |         label_norm, mask = display_label_single_c(self.label) 221 |         seg_map_norm = display_predictions(self.seg_map, True, mask) 222 |         real_S_norm = display_input(self.real_S[:, 0:self.sar_c_vis, :, :], False) 223 |         # if no writer is passed, instead of updating tb the function returns the images; 224 |         # this variant is useful to create tiles 225 |         if writer is None: 226 |             return label_norm[0], seg_map_norm[0] 227 |         else: 228 |             writer.add_images("Train/1 - Labels", label_norm, global_step=global_step) 229 |             writer.add_images("Train/2 - Map", seg_map_norm, global_step=global_step) 230 |             writer.add_images("Train/3 - Radar Input", real_S_norm, global_step=global_step) 231 |  232 |     def train(self, train_dataset, eval_dataset, writer): 233 |         """ 234 |         Run the training for the required epochs 235 |         :param train_dataset: dataset used to train the network 236 |         :param eval_dataset: dataset used to evaluate the network 237 |         :param writer: a tensorboard instance to track info 238 |         :return: 239 |         """ 240 |         global_step = 0 241 |         if self.opt.acc_log_freq == 117: 242 |             calculate_accuracy(self, eval_dataset, writer, 
global_step, "Test", 0, self.posx_test, self.posy_test, self.opt.test_size) 243 |             calculate_accuracy(self, train_dataset, writer, global_step, "Train", 0, self.posx_train, self.posy_train, self.opt.train_size) 244 |             self.Logger_train.append_acc_step({"step": global_step}) 245 |             self.Logger_test.append_acc_step({"step": global_step}) 246 |         for epoch in range(self.opt.tot_epochs): 247 |             t = start() 248 |             text_line = "=" * 27 + "SN EPOCH " + str(epoch) + "/" + str(self.opt.tot_epochs) + "=" * 27 249 |             print(text_line) 250 |             progress_bar = tqdm(enumerate(train_dataset), total=len(train_dataset)) 251 |             # Train on each patch in the dataset 252 |             for i, data in progress_bar: 253 |                 self.set_input(data) 254 |                 self.optimize() 255 |                 self.tb_add_step_loss(writer, global_step) 256 |                 #self.Logger_train.append_loss_step({"step": global_step}) 257 |                 global_step = global_step + 1 258 |  259 |             if (epoch > 0 and epoch % self.opt.acc_log_freq == 0) or self.opt.acc_log_freq == 117: 260 |                 calculate_accuracy(self, eval_dataset, writer, global_step, "Test", epoch, self.posx_test, self.posy_test, self.opt.test_size) 261 |                 #calculate_accuracy(self, train_dataset, writer, global_step, "Train", epoch, self.posx_train, self.posy_train, self.opt.train_size) 262 |                 #self.Logger_train.append_acc_step({"step": global_step}) 263 |                 self.Logger_test.append_acc_step({"step": global_step}) 264 |  265 |             if epoch > 0 and epoch % self.opt.save_model_freq == 0: 266 |                 self.save_model(epoch) 267 |                 self.Logger_test.save_logger(self.opt.checkpoint_dir, name='test') 268 |                 #self.Logger_train.save_logger(self.opt.checkpoint_dir, name='train') 269 |             # Epoch duration 270 |             s = 'SN Epoch {} '.format(epoch) 271 |             stop(t, s) 272 |             if breaker(self.opt, epoch): 273 |                 print('EXECUTION FORCED TO STOP AT {} EPOCHS'.format(epoch)) 274 |                 break 275 | --------------------------------------------------------------------------------
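`load_SN` above seeds the classifier by filtering the pretrained generator's state dict for key names containing '19' or '22' and then copying the selected tensors into the SegNet dict purely by position. A dependency-free sketch of that remapping, with toy dicts standing in for real state dicts (the key names here are made up for illustration):

```python
from collections import OrderedDict

def remap_by_position(pretrained, target, keep=('19', '22')):
    """Copy values from `pretrained` into `target` by position, keeping only
    keys whose names contain one of the `keep` substrings (the load_SN idea).
    Positional copying assumes the selected tensors match the target shapes."""
    selected = OrderedDict(
        (k, v) for k, v in pretrained.items() if any(tag in k for tag in keep)
    )
    target_names = list(target.keys())
    out = dict(target)
    for i, (_, v) in enumerate(selected.items()):
        out[target_names[i]] = v
    return out

pretrained = {'model.19.weight': 'W19', 'model.20.weight': 'W20', 'model.22.bias': 'B22'}
sn_dict = {'up.0.weight': None, 'up.0.bias': None, 'up.1.weight': None}
print(remap_by_position(pretrained, sn_dict))
# {'up.0.weight': 'W19', 'up.0.bias': 'B22', 'up.1.weight': None}
```

Matching by insertion order rather than by name is fragile: it only works because the generator's upsampling layers and `newSN`'s layers are declared in the same order with the same shapes.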
/Lib/Nets/utils/arch/arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 |  4 |  5 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 6 | """ 7 | Author: Alessandro Cattoi 8 | Description: This file defines all the architectures employed in this work 9 | """ 10 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 11 |  12 |  13 | class Discriminator(nn.Module): 14 |     """ 15 |     Define the discriminator architecture 16 |     """ 17 |     def __init__(self, input_c, use_bias): 18 |         super(Discriminator, self).__init__() 19 |         self.register_buffer('real_label', torch.tensor(1.0)) 20 |         self.register_buffer('fake_label', torch.tensor(0.0)) 21 |  22 |         self.model = nn.Sequential( 23 |             # fixed first 24 |             nn.Conv2d(input_c, 64, 4, stride=2, padding=1), 25 |             nn.LeakyReLU(0.2, inplace=True), 26 |  27 |             # n_layers_D = 3 means 2 layers for the discriminator + 1 fixed 28 |             # 1 29 |             nn.Conv2d(64, 128, 4, stride=2, padding=1, bias=use_bias), 30 |             nn.InstanceNorm2d(128), 31 |             nn.LeakyReLU(0.2, inplace=True), 32 |  33 |             # 2 34 |             nn.Conv2d(128, 256, 4, stride=2, padding=1, bias=use_bias), 35 |             
nn.InstanceNorm2d(256), 36 | nn.LeakyReLU(0.2, inplace=True), 37 | 38 | # 3 fixed 39 | nn.Conv2d(256, 512, 4, padding=1, bias=use_bias), 40 | nn.InstanceNorm2d(512), 41 | nn.LeakyReLU(0.2, inplace=True), 42 | 43 | # fixed last 44 | nn.Conv2d(512, 1, 4, padding=1), 45 | ) 46 | 47 | def forward(self, x): 48 | x = self.model(x) 49 | return x 50 | 51 | 52 | class Generator(nn.Module): 53 | """ 54 | Define the generator architecture 55 | """ 56 | def __init__(self, input_c, output_c, use_dropout, use_bias): 57 | super(Generator, self).__init__() 58 | self.model = nn.Sequential( 59 | # First fixed layer 60 | nn.ReflectionPad2d(3), 61 | nn.Conv2d(input_c, 64, 7, bias=use_bias), 62 | nn.InstanceNorm2d(64), 63 | nn.ReLU(inplace=True), 64 | 65 | # Downsampling fixed at 2 66 | nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=use_bias), 67 | nn.InstanceNorm2d(128), 68 | nn.ReLU(inplace=True), 69 | nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=use_bias), 70 | nn.InstanceNorm2d(256), 71 | nn.ReLU(inplace=True), 72 | #nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=use_bias), 73 | #nn.InstanceNorm2d(512), 74 | #nn.ReLU(inplace=True), 75 | 76 | # Residual blocks 77 | ResidualBlock(256, use_dropout, use_bias), 78 | ResidualBlock(256, use_dropout, use_bias), 79 | ResidualBlock(256, use_dropout, use_bias), 80 | 81 | ResidualBlock(256, use_dropout, use_bias), 82 | ResidualBlock(256, use_dropout, use_bias), 83 | ResidualBlock(256, use_dropout, use_bias), 84 | 85 | ResidualBlock(256, use_dropout, use_bias), 86 | ResidualBlock(256, use_dropout, use_bias), 87 | ResidualBlock(256, use_dropout, use_bias), 88 | 89 | # Upsampling fixed at 2 90 | #nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1, bias=use_bias), 91 | #nn.InstanceNorm2d(256), 92 | #nn.ReLU(inplace=True), 93 | nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1, bias=use_bias), 94 | nn.InstanceNorm2d(128), 95 | nn.ReLU(inplace=True), 96 | nn.ConvTranspose2d(128, 64, 3, stride=2, 
padding=1, output_padding=1, bias=use_bias), 97 |             nn.InstanceNorm2d(64), 98 |             nn.ReLU(inplace=True), 99 |  100 |             # Last layer fixed 101 |             nn.ReflectionPad2d(3), 102 |             nn.Conv2d(64, output_c, 7), 103 |             #nn.Tanh(), 104 |             nn.Conv2d(output_c, output_c, 1), 105 |             #nn.Conv2d(output_c, output_c, 1), 106 |         ) 107 |  108 |  109 |     def forward(self, x): 110 |         return self.model(x) 111 |  112 |  113 | class ResidualBlock(nn.Module): 114 |     """ 115 |     Residual block definition 116 |     """ 117 |     def __init__(self, in_channels, use_dropout, use_bias): 118 |         super(ResidualBlock, self).__init__() 119 |  120 |         if use_dropout: 121 |             self.res = nn.Sequential(nn.ReflectionPad2d(1), 122 |                                      nn.Conv2d(in_channels, in_channels, 3, bias=use_bias), 123 |                                      nn.InstanceNorm2d(in_channels), 124 |                                      nn.ReLU(inplace=True), 125 |                                      nn.Dropout(0.5), 126 |                                      nn.ReflectionPad2d(1), 127 |                                      nn.Conv2d(in_channels, in_channels, 3, bias=use_bias), 128 |                                      nn.InstanceNorm2d(in_channels)) 129 |         else: 130 |             self.res = nn.Sequential(nn.ReflectionPad2d(1), 131 |                                      nn.Conv2d(in_channels, in_channels, 3, bias=use_bias), 132 |                                      nn.InstanceNorm2d(in_channels), 133 |                                      nn.ReLU(inplace=True), 134 |                                      nn.ReflectionPad2d(1), 135 |                                      nn.Conv2d(in_channels, in_channels, 3, bias=use_bias), 136 |                                      nn.InstanceNorm2d(in_channels)) 137 |  138 |     def forward(self, x): 139 |         return x + self.res(x) 140 |  141 |  142 | class PrintLayer(nn.Module): 143 |     """ 144 |     This layer can be inserted to debug networks 145 |     """ 146 |     def __init__(self): 147 |         super(PrintLayer, self).__init__() 148 |  149 |     def forward(self, x): 150 |         # Do your print / debug stuff here 151 |         print(x.shape) 152 |         return x 153 |  154 |  155 | class newSN(nn.Module): 156 |     """ 157 |     This class implements the architecture that performs the classification task 158 |     """ 159 |     def __init__(self, output_c, use_bias, use_dropout): 160 |         super(newSN, self).__init__() 161 |         self.model = nn.Sequential( 162 |             #ResidualBlock(256, use_dropout, use_bias), 163 |             #ResidualBlock(256, use_dropout, use_bias), 164 |  165
| # Upsampling fixed at 2 166 | # nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1, bias=use_bias), 167 | # nn.InstanceNorm2d(256), 168 | # nn.ReLU(inplace=True), 169 | nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1, bias=use_bias), 170 | nn.InstanceNorm2d(128), 171 | nn.ReLU(inplace=True), 172 | nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1, bias=use_bias), 173 | nn.InstanceNorm2d(64), 174 | nn.ReLU(inplace=True), 175 | 176 | # Last layer fixed 177 | nn.ReflectionPad2d(3), 178 | nn.Conv2d(64, output_c, 7) 179 | ) 180 | 181 | def forward(self, x): 182 | return self.model(x) 183 | -------------------------------------------------------------------------------- /Lib/Nets/utils/config/__pycache__/config_routine.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/config/__pycache__/config_routine.cpython-37.pyc -------------------------------------------------------------------------------- /Lib/Nets/utils/config/__pycache__/config_routine.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/config/__pycache__/config_routine.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/Nets/utils/config/__pycache__/general_parser.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/config/__pycache__/general_parser.cpython-37.pyc -------------------------------------------------------------------------------- 
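The decoder stages above rely on `nn.ConvTranspose2d(..., stride=2, padding=1, output_padding=1)` doubling the spatial size, so the two stacked stages of `newSN` upsample by a factor of 4 overall, while the final `ReflectionPad2d(3)` followed by a 7x7 convolution leaves the size unchanged. A quick sketch of the size arithmetic (plain Python, no torch required; the helper names are illustrative, not part of the repo):

```python
def conv_transpose2d_out(size, kernel=3, stride=2, padding=1, output_padding=1):
    # PyTorch ConvTranspose2d formula: out = (in - 1)*stride - 2*padding + kernel + output_padding
    return (size - 1) * stride - 2 * padding + kernel + output_padding

def conv2d_out(size, kernel, padding=0, stride=1):
    # PyTorch Conv2d formula: out = (in + 2*padding - kernel) // stride + 1
    return (size + 2 * padding - kernel) // stride + 1

h = 48                               # e.g. a 48x48 feature map with 256 channels
h = conv_transpose2d_out(h)          # 256 -> 128 channels, 48 -> 96
h = conv_transpose2d_out(h)          # 128 -> 64 channels,  96 -> 192
h = conv2d_out(h + 2 * 3, kernel=7)  # ReflectionPad2d(3) then 7x7 conv: size preserved
print(h)                             # 192: the decoder upsamples by 4 overall
```

The same arithmetic explains why the commented-out 512->256 stage would make the decoder an 8x upsampler.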
/Lib/Nets/utils/config/__pycache__/general_parser.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/config/__pycache__/general_parser.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/Nets/utils/config/__pycache__/specific_parser.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/config/__pycache__/specific_parser.cpython-37.pyc -------------------------------------------------------------------------------- /Lib/Nets/utils/config/__pycache__/specific_parser.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/config/__pycache__/specific_parser.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/Nets/utils/config/config_routine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle as pkl 3 | import torch 4 | import random 5 | from Lib.utils.generic.generic_utils import set_rand_seed 6 | 7 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 8 | """ 9 | Author: Alessandro Cattoi 10 | Description: This function sets up parameters based on the arguments passed 11 | """ 12 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 13 | 14 | def config_routine(args): 15 | """ 16 | This routine is used to set up folders, save some status.py data and ask
for some info when setting up a network 17 | both for training and for evaluation 18 | :param args: arguments passed when launching script 19 | :return: 20 | """ 21 | if args.restore_training: 22 | temp_epoch = args.start_from_epoch 23 | temp_path = args.restoring_rep_path 24 | args = pkl.load(open(os.path.join(args.global_path, args.restoring_rep_path, 'args.pkl'), "rb")) 25 | args.start_from_epoch = temp_epoch 26 | args.restoring_rep_path = temp_path 27 | random.seed(args.seed) 28 | torch.manual_seed(args.seed) 29 | print('Model path = {}'.format(os.path.join(args.global_path, args.restoring_rep_path))) 30 | print('Resuming epoch = {}'.format(args.start_from_epoch)) 31 | print('Continuing exp {}'.format(args.experiment_name)) 32 | temp = input('Correct?(y/n) ') 33 | if 'y' in temp: 34 | print("Let's go...") 35 | else: 36 | raise NotImplementedError("INCORRECT INIT") 37 | else: 38 | if args.experiment_name == "": 39 | name = input('Experiment name? (DESCRIPTION)') 40 | args.experiment_name = name 41 | # create required folder 42 | try: 43 | list_dir = os.listdir(os.path.join(args.global_path, args.log_dir)) 44 | list_dir = list(filter(lambda x: '.'
not in x, list_dir)) 45 | id_list = [] 46 | if list_dir: 47 | for i in list_dir: 48 | id_list.append(int(i.split('_')[0])) 49 | unique_id = max(id_list) + 10 50 | else: 51 | unique_id = 10 52 | 53 | args.data_dir_train = os.path.join(args.global_path, args.data_dir_train) 54 | args.data_dir_test = os.path.join(args.global_path, args.data_dir_test) 55 | args.data_dir_val = os.path.join(args.global_path, args.data_dir_val) 56 | args.log_dir = os.path.join(args.global_path, args.log_dir) 57 | args.pretrained_GAN = os.path.join(args.global_path, args.pretrained_GAN) 58 | args.log_dir = os.path.join(args.log_dir, str(unique_id) + '_' + args.experiment_name) 59 | print('Created ' + str(unique_id) + '_' + args.experiment_name) 60 | except: 61 | args.log_dir = os.path.join(args.global_path, args.log_dir) 62 | args.log_dir = os.path.join(args.log_dir, args.experiment_name) 63 | 64 | if args.sar_c != args.optical_c: 65 | args.lambda_identity = 0.0 66 | 67 | file = open(os.path.join(args.data_dir_train, '1_log.txt'), 'r') 68 | lines = file.readlines() 69 | values = lines[9].split(',') 70 | dim1 = values[1].split(')')[0] 71 | dim2 = values[2].split(')')[0] 72 | dim1 = int(dim1) 73 | dim2 = int(dim2) 74 | args.train_size = [dim1, dim2] 75 | 76 | file = open(os.path.join(args.data_dir_test, '1_log.txt'), 'r') 77 | lines = file.readlines() 78 | values = lines[9].split(',') 79 | dim1 = values[1].split(')')[0] 80 | dim2 = values[2].split(')')[0] 81 | dim1 = int(dim1) 82 | dim2 = int(dim2) 83 | args.test_size = [dim1, dim2] 84 | 85 | if 'BERLIN' in args.data_dir_train: 86 | args.sar_c = 2 87 | args.N_classes = 10 88 | 89 | if '32' in args.data_dir_train: 90 | args.patch_size = 32 91 | elif '128' in args.data_dir_train: 92 | args.patch_size = 128 93 | elif '256' in args.data_dir_train: 94 | args.patch_size = 256 95 | else: 96 | args.patch_size = 192 97 | 98 | if args.restoring_rep_path is not None: 99 | args.restoring_rep_path = os.path.join(args.global_path, args.restoring_rep_path, 
"checkpoints") 100 | 101 | os.mkdir(args.log_dir) 102 | # add new argument checkpoint_dir 103 | args.checkpoint_dir = os.path.join(args.log_dir, "checkpoints") 104 | os.mkdir(args.checkpoint_dir) 105 | # add new argument tb_dir for tensorboard 106 | args.tb_dir = os.path.join(args.log_dir, "tb") 107 | os.mkdir(args.tb_dir) 108 | 109 | # if seed not available 110 | if args.seed is None: 111 | # generate seed 112 | args.seed = set_rand_seed() 113 | print("Random Seed: ", args.seed) 114 | # Set seed of random generators 115 | random.seed(args.seed) 116 | torch.manual_seed(args.seed) 117 | 118 | # create a log file with all options specified 119 | f = open(os.path.join(args.log_dir, "param.txt"), "a") 120 | text_line = "=" * 20 + "CONFIG" + "=" * 20 + '\n' 121 | f.write(text_line) 122 | for arg in vars(args): 123 | text_line = '{0:20} {1}\n'.format(arg, getattr(args, arg)) 124 | f.write(text_line) 125 | f.close() 126 | # save config variable 127 | pkl.dump(args, open(os.path.join(args.checkpoint_dir, "args.pkl"), "wb")) 128 | 129 | # check gpu 130 | if torch.cuda.is_available(): 131 | print("GPU devices found: {}".format(torch.cuda.device_count())) 132 | else: 133 | raise NotImplementedError("GPU PROBLEM") 134 | 135 | # Create a file in which the total number of epochs is stored 136 | # This file can be used to correctly stop the execution at a different number of epochs 137 | f = open(os.path.join(args.log_dir, "q.txt"), "w") 138 | val = 'epoch=' + str(args.tot_epochs) 139 | f.write(val) 140 | f.close() 141 | 142 | return args 143 | -------------------------------------------------------------------------------- /Lib/Nets/utils/config/general_parser.py: -------------------------------------------------------------------------------- 1 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 2 | """ 3 | Author: Alessandro Cattoi 4 | Description: This file defines a parser to describe all the parameters employed
in this work. This function allows the user to 5 | pass parameters when launching the script. 6 | """ 7 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 8 | 9 | def general_parser(parser): 10 | """ 11 | This is the parser which collects all the parameters passed through the command to launch the script 12 | - experiment_name 13 | - run_description 14 | - val_freq 15 | - acc_log_freq 16 | - loss_log_freq 17 | - images_log_freq 18 | - SN_log_freq 19 | - save_model_freq 20 | - restore_training 21 | - start_from_epoch 22 | - restoring_rep_path 23 | - D_training_ratio 24 | - buff_dim 25 | - th_low 26 | - th_high 27 | - th_b_h_ratio 28 | - th_b_l_ratio 29 | - th_b_h_pool 30 | - th_b_l_pool 31 | - res_block_N 32 | - pool_prc_O 33 | - pool_prc_S 34 | - drop_prc 35 | - seed 36 | - tot_epochs 37 | - decay_epochs 38 | - patch_size 39 | - batch_size 40 | - sar_c 41 | - optical_c 42 | - N_classes 43 | - loss_type 44 | - lr 45 | - beta1 46 | - workers 47 | - dropout 48 | - lambda_S 49 | - lambda_O 50 | - lambda_identity 51 | - pool_size 52 | - pool 53 | - conditioned 54 | - dropping 55 | - init_type 56 | - init_gain 57 | - global_path 58 | - data_dir_train 59 | - data_dir_train2 60 | - data_dir_test 61 | - data_dir_test2 62 | - data_dir_val 63 | - log_dir 64 | - train_set_prc 65 | - mode 66 | - lr_SN 67 | - weight_decay_SN 68 | - batch_size_SN 69 | """ 70 | # general 71 | parser.add_argument('--experiment_name', type=str, default="", 72 | help="experiment name.
will be used in the path names for log- and save files") 73 | parser.add_argument('--run_description', type=str, default="", 74 | help="describes the run") 75 | parser.add_argument('--val_freq', type=int, default=1000, 76 | help='validation will be run every val_freq batches/optimization steps during training') 77 | parser.add_argument('--acc_log_freq', type=int, default=500, 78 | help='accuracy will be logged every acc_log_freq batches/optimization steps during training') 79 | parser.add_argument('--loss_log_freq', type=int, default=10, 80 | help='tensorboard loss logs will be written every loss_log_freq batches/optimization steps') 81 | parser.add_argument('--images_log_freq', type=int, default=1000, 82 | help='tensorboard image logs will be written every images_log_freq batches/optimization steps') 83 | parser.add_argument('--SN_log_freq', type=int, default=5, 84 | help='SN tensorboard logs will be written every SN_log_freq epochs') 85 | parser.add_argument('--save_model_freq', type=int, default=5, 86 | help='the model will be saved every save_model_freq epochs') 87 | parser.add_argument('--restore_training', type=bool, default=False, 88 | help='restore a previous training') 89 | parser.add_argument('--start_from_epoch', type=int, default=0, 90 | help='epoch from which to restart training') 91 | parser.add_argument('--restoring_rep_path', type=str, default=None, 92 | help='path of the training run to restore') 93 | parser.add_argument('--D_training_ratio', type=int, default=5, 94 | help='one discriminator opt step is performed every D_training_ratio generator opt steps') 95 | parser.add_argument('--buff_dim', type=int, default=10000, 96 | help='size of the buffer whose mean loss decides whether to change the D ratio') 97 | parser.add_argument('--th_low', type=float, default=0.45, 98 | help='D ratio low threshold') 99 | parser.add_argument('--th_high', type=float, default=0.55, 100 | help='D ratio high threshold') 101 | parser.add_argument('--th_b_h_ratio', type=int,
default=100, 102 | help='buffer high threshold for the D ratio') 103 | parser.add_argument('--th_b_l_ratio', type=int, default=2, 104 | help='buffer low threshold for the D ratio') 105 | parser.add_argument('--th_b_h_pool', type=float, default=0.9, 106 | help='buffer high threshold for the pool') 107 | parser.add_argument('--th_b_l_pool', type=float, default=0.4, 108 | help='buffer low threshold for the pool') 109 | 110 | 111 | parser.add_argument('--res_block_N', type=int, default=9, 112 | help='number of residual blocks (patch size 128 -> 6, patch size 256 -> 9)') 113 | parser.add_argument('--pool_prc_S', type=float, default=0.5, 114 | help='probability of choosing an old generated patch instead of the new one') 115 | parser.add_argument('--pool_prc_O', type=float, default=0.5, 116 | help='probability of choosing an old generated patch instead of the new one') 117 | parser.add_argument('--drop_prc', type=float, default=0.5, 118 | help='dropout probability') 119 | 120 | # training hyperparameters 121 | parser.add_argument('--seed', type=int, default=1, 122 | help='torch random seed') 123 | parser.add_argument('--tot_epochs', type=int, default=200, 124 | help='number of training epochs (default: 200)') 125 | parser.add_argument("--decay_epochs", type=int, default=100, 126 | help="epoch at which the learning rate starts linearly decaying to 0.
(default:100)") 127 | parser.add_argument('--patch_size', type=int, default=32, 128 | help='patch size for training and validation (default: 32)') 129 | parser.add_argument('--batch_size', type=int, default=1, 130 | help='batch size for training and validation (default: 1)') 131 | parser.add_argument('--sar_c', type=int, default=5, 132 | help='number of channels in SAR images') 133 | parser.add_argument('--optical_c', type=int, default=4, 134 | help='number of channels in optical images') 135 | parser.add_argument('--N_classes', type=int, default=5, 136 | help='number of classes') 137 | parser.add_argument('--loss_type', type=str, default="lsgan", 138 | help='[lsgan | wgan]') 139 | parser.add_argument("--lr", type=float, default=0.0002, 140 | help="learning rate. (default:0.0002)") 141 | parser.add_argument('--beta1', type=float, default=0.5, 142 | help='beta1 term for adam') 143 | parser.add_argument('--workers', type=int, default=4, 144 | help='number of workers for dataloading (default: 4)') 145 | parser.add_argument('--dropout', type=bool, default=True, 146 | help='dropout for the generators: True for training, False for testing') 147 | parser.add_argument('--bias', type=bool, default=True, 148 | help='bias for G and D: True for InstanceNorm2D as normalization func') 149 | parser.add_argument("--lambda_S", type=float, default=10, 150 | help="weight for cycle loss (S -> O -> S)") 151 | parser.add_argument("--lambda_O", type=float, default=10, 152 | help="weight for cycle loss (O -> S -> O)") 153 | parser.add_argument("--lambda_identity", type=float, default=0.5, 154 | help="Scales the weight of the identity mapping loss.") 155 | parser.add_argument("--lambda_A", type=float, default=100, 156 | help="Scales the weight of the regression loss.") 157 | parser.add_argument("--lambda_gp", type=float, default=10, 158 | help="Scales the weight of the gradient penalty only for wgan loss.") 159 | parser.add_argument("--pool_size", type=int, default=50, 160 | help="dim of
the pool of images") 161 | parser.add_argument("--pool", type=bool, default=True, 162 | help="whether to use the image pool") 163 | parser.add_argument("--conditioned", type=bool, default=False, 164 | help="whether to use the conditioned setup") 165 | parser.add_argument("--dropping", type=bool, default=False, 166 | help="whether to use dropping") 167 | parser.add_argument('--init_type', type=str, default='normal', 168 | help='network initialization [normal | xavier | kaiming | orthogonal]') 169 | parser.add_argument('--init_gain', type=float, default=0.02, 170 | help='scaling factor for normal, xavier and orthogonal.') 171 | 172 | # data 173 | parser.add_argument('--global_path', type=str, default="/home/ale/Documents/Python/13_Tesi_2/", 174 | help='global base path of the project') 175 | parser.add_argument('--data_dir_train', type=str, default="Data/Train/EUSAR/32_box_double_norm", 176 | help='path to training dataset') 177 | parser.add_argument('--data_dir_train2', type=str, default="Data/Train/EUSAR/32_box_double_norm", 178 | help='path to training dataset') 179 | parser.add_argument('--data_dir_test', type=str, default="", 180 | help='path to test dataset') 181 | parser.add_argument('--data_dir_test2', type=str, default="", 182 | help='path to test dataset') 183 | parser.add_argument('--data_dir_val', type=str, default="Data/Train/EUSAR/32_box_double_norm", 184 | help='path to validation dataset') 185 | parser.add_argument('--log_dir', type=str, default="Runs/Runs_CGAN/", 186 | help='path to dir for code logs') 187 | parser.add_argument('--prc_train', type=float, default=1, 188 | help='fraction of the train dataset to use') 189 | parser.add_argument('--prc_test', type=float, default=1, 190 | help='fraction of the test dataset to use') 191 | parser.add_argument('--prc_val', type=float, default=1, 192 | help='fraction of the val dataset to use') 193 | 194 | # SN param 195 | parser.add_argument('--mode', type=str, default='trainSN', 196 | help='set up model mode [train | eval]') 197 | parser.add_argument('--lr_SN', type=float,
default=0.01, 198 | help='learning rate (default: 1e-2)') 199 | parser.add_argument('--weight_decay_SN', type=float, default=5e-4, 200 | help='weight-decay (default: 5e-4)') 201 | parser.add_argument('--batch_size_SN', type=int, default=32, 202 | help='batch size for training and validation \ 203 | (default: 32)') 204 | parser.add_argument('--pretrained_GAN', type=str, 205 | default='Runs/Runs_CGAN/4_2020-11-12_11-09-03_norm_data_first/norm_data_first_checkpoints', 206 | help='restore a pretrained model') 207 | parser.add_argument('--GAN_epoch', type=int, default=None, 208 | help='which epoch restore?') 209 | return parser 210 | -------------------------------------------------------------------------------- /Lib/Nets/utils/config/specific_parser.py: -------------------------------------------------------------------------------- 1 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 2 | """ 3 | Author: Alessandro Cattoi 4 | Description: This file defines a function to overwrite parsed parameters. If specified is possible to overwrite passed 5 | parameters. 
As a result, it is possible to define certain parameters in the initial part of the scripts, as done in the mains 6 | """ 7 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 8 | 9 | 10 | def specific_parser(parser, log=False, run_folder=None, mode=None, tot_epochs=None, restoring_rep_path=None, 11 | start_from_epoch=None, pretrained_GAN=None, GAN_epoch=None, data_dir_train=None, data_dir_train2=None, 12 | data_dir_test=None, data_dir_test2=None, images_log_freq=None, batch_size=None, batch_size_SN=None, 13 | acc_log_freq=None, loss_log_freq=None, experiment_name=None, run_description=None, prc_train=None, 14 | prc_test=None, prc_val=None, sar_c=None, optical_c=None, N_classes=None, patch_size=None, SN_log_freq=None, 15 | save_model_freq=None, lambda_identity=None, D_training_ratio=None, lambda_A=None, loss_type=None, 16 | lambda_gp=None, res_block_N=None, pool_prc_O=None, pool_prc_S=None, buff_dim=None, th_low=None, th_high=None, 17 | pool=None, conditioned=None, dropping=None, th_b_h_ratio=None, th_b_l_ratio=None, th_b_h_pool=None, 18 | th_b_l_pool=None, drop_prc=None, seed=None): 19 | """ 20 | This is an intermediate layer between the general parser and the config routine that allows whoever uses this code to easily 21 | access parameters and change them when building an experiment 22 | :param parser: 23 | :param log: decide if print or not 24 | :param run_folder: new value for run folder 25 | :param mode: train mode 26 | :param tot_epochs: 27 | :param restoring_rep_path: 28 | :param start_from_epoch: 29 | :param pretrained_GAN: 30 | :param GAN_epoch: 31 | :param data_dir_train: 32 | :param data_dir_train2: 33 | :param data_dir_test: 34 | :param data_dir_test2: 35 | :param images_log_freq: 36 | :param batch_size: 37 | :param batch_size_SN: 38 | :param acc_log_freq: 39 | :param loss_log_freq: 40 | :param experiment_name: 41 | :param run_description: 42 | :param prc_train: 43 | :param prc_test: 44 | :param
prc_val: 45 | :param sar_c: 46 | :param optical_c: 47 | :param N_classes: 48 | :param patch_size: 49 | :param SN_log_freq: 50 | :param save_model_freq: 51 | :param lambda_identity: 52 | :param D_training_ratio: 53 | :param lambda_A: 54 | :param loss_type: 55 | :param lambda_gp: 56 | :param res_block_N: 57 | :param pool_prc_O: 58 | :param pool_prc_S: 59 | :param buff_dim: 60 | :param th_low: 61 | :param th_high: 62 | :param pool: 63 | :param conditioned: 64 | :param dropping: 65 | :param th_b_h_ratio: 66 | :param th_b_l_ratio: 67 | :param th_b_h_pool: 68 | :param th_b_l_pool: 69 | :param drop_prc: 70 | :return: args 71 | """ 72 | args = parser.parse_args() 73 | print('SPECIFIC CONFIG') 74 | args.log_dir = update_arg(args.log_dir, run_folder, 'log_dir', log) 75 | args.tot_epochs = update_arg(args.tot_epochs, tot_epochs, 'tot_epochs', log) 76 | args.mode = update_arg(args.mode, mode, 'mode', log) 77 | args.restoring_rep_path = update_arg(args.restoring_rep_path, restoring_rep_path, 'restoring_rep_path', log) 78 | args.start_from_epoch = update_arg(args.start_from_epoch, start_from_epoch, 'start_from_epoch', log) 79 | args.pretrained_GAN = update_arg(args.pretrained_GAN, pretrained_GAN, 'pretrained_GAN', log) 80 | args.GAN_epoch = update_arg(args.GAN_epoch, GAN_epoch, 'GAN_epoch', log) 81 | args.data_dir_train = update_arg(args.data_dir_train, data_dir_train, 'data_dir_train', log) 82 | args.data_dir_train2 = update_arg(args.data_dir_train2, data_dir_train2, 'data_dir_train2', log) 83 | args.data_dir_test = update_arg(args.data_dir_test, data_dir_test, 'data_dir_test', log) 84 | args.data_dir_test2 = update_arg(args.data_dir_test2, data_dir_test2, 'data_dir_test2', log) 85 | args.images_log_freq = update_arg(args.images_log_freq, images_log_freq, 'images_log_freq', log) 86 | args.batch_size = update_arg(args.batch_size, batch_size, 'batch_size', log) 87 | args.batch_size_SN = update_arg(args.batch_size_SN, batch_size_SN, 'batch_size_SN', log) 88 | args.acc_log_freq = 
update_arg(args.acc_log_freq, acc_log_freq, 'acc_log_freq', log) 89 | args.loss_log_freq = update_arg(args.loss_log_freq, loss_log_freq, 'loss_log_freq', log) 90 | args.experiment_name = update_arg(args.experiment_name, experiment_name, 'experiment_name', log) 91 | args.run_description = update_arg(args.run_description, run_description, 'run_description', log) 92 | args.prc_train = update_arg(args.prc_train, prc_train, 'prc_train', log) 93 | args.prc_test = update_arg(args.prc_test, prc_test, 'prc_test', log) 94 | args.prc_val = update_arg(args.prc_val, prc_val, 'prc_val', log) 95 | args.sar_c = update_arg(args.sar_c, sar_c, 'sar_c', log) 96 | args.optical_c = update_arg(args.optical_c, optical_c, 'optical_c', log) 97 | args.N_classes = update_arg(args.N_classes, N_classes, 'N_classes', log) 98 | args.patch_size = update_arg(args.patch_size, patch_size, 'patch_size', log) 99 | args.SN_log_freq = update_arg(args.SN_log_freq, SN_log_freq, 'SN_log_freq', log) 100 | args.save_model_freq = update_arg(args.save_model_freq, save_model_freq, 'save_model_freq', log) 101 | args.lambda_identity = update_arg(args.lambda_identity, lambda_identity, 'lambda_identity', log) 102 | args.D_training_ratio = update_arg(args.D_training_ratio, D_training_ratio, 'D_training_ratio', log) 103 | args.lambda_A = update_arg(args.lambda_A, lambda_A, 'lambda_A', log) 104 | args.loss_type = update_arg(args.loss_type, loss_type, 'loss_type', log) 105 | args.lambda_gp = update_arg(args.lambda_gp, lambda_gp, 'lambda_gp', log) 106 | args.res_block_N = update_arg(args.res_block_N, res_block_N, 'res_block_N', log) 107 | args.pool_prc_O = update_arg(args.pool_prc_O, pool_prc_O, 'pool_prc_O', log) 108 | args.pool_prc_S = update_arg(args.pool_prc_S, pool_prc_S, 'pool_prc_S', log) 109 | args.buff_dim = update_arg(args.buff_dim, buff_dim, 'buff_dim', log) 110 | args.th_low = update_arg(args.th_low, th_low, 'th_low', log) 111 | args.th_high = update_arg(args.th_high, th_high, 'th_high', log) 112 | args.pool 
= update_arg(args.pool, pool, 'pool', log) 113 | args.conditioned = update_arg(args.conditioned, conditioned, 'conditioned', log) 114 | args.dropping = update_arg(args.dropping, dropping, 'dropping', log) 115 | args.th_b_h_ratio = update_arg(args.th_b_h_ratio, th_b_h_ratio, 'th_b_h_ratio', log) 116 | args.th_b_l_ratio = update_arg(args.th_b_l_ratio, th_b_l_ratio, 'th_b_l_ratio', log) 117 | args.th_b_h_pool = update_arg(args.th_b_h_pool, th_b_h_pool, 'th_b_h_pool', log) 118 | args.th_b_l_pool = update_arg(args.th_b_l_pool, th_b_l_pool, 'th_b_l_pool', log) 119 | args.drop_prc = update_arg(args.drop_prc, drop_prc, 'drop_prc', log) 120 | args.seed = update_arg(args.seed, seed, 'seed', log) 121 | return args 122 | 123 | 124 | def update_arg(original, new_val, name, log=False): 125 | """ 126 | Decide whether to update the value or keep the original 127 | :param original: 128 | :param new_val: 129 | :param name: name of the variable 130 | :param log: decide if print or not 131 | :return: 132 | """ 133 | if new_val is None: 134 | out_val = original 135 | else: 136 | out_val = new_val 137 | if log: 138 | print(' - ' + name + ' = {}'.format(out_val)) 139 | 140 | return out_val 141 | -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/DecayLR.py: -------------------------------------------------------------------------------- 1 | class DecayLR: 2 | """ 3 | The learning rate scheduler requires a function that is used to update the lr value 4 | This is the function used in the project models 5 | """ 6 | def __init__(self, epochs, checkpoint_epoch, decay_epochs): 7 | self.epochs = epochs 8 | self.checkpoint_epoch = checkpoint_epoch 9 | self.decay_epochs = decay_epochs 10 | 11 | def step(self, epoch): 12 | return 1.0 - max(0, epoch + self.checkpoint_epoch - self.decay_epochs) / (self.epochs - self.decay_epochs) 13 | -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/ReplyBuffer.py:
-------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class ReplayBuffer: 6 | """ 7 | This class implements a buffer which stores the last 50 generated images and returns, with a probability of 1/2, 8 | either the new one or one of the last 50 9 | """ 10 | def __init__(self, max_size=50): 11 | self.max_size = max_size 12 | # it is a list of batches! 13 | self.data = [] 14 | 15 | def push_and_pop(self, data, prc): 16 | """ 17 | :param prc: probability of returning the newly generated element 18 | :param data: batch of images 19 | :return: batch of images randomly drawn from the last max_size generated 20 | """ 21 | to_return = [] 22 | # element is each single image in the batch 23 | for element in data.data: 24 | # insert batch dim in the first position 25 | element = torch.unsqueeze(element, 0) 26 | if len(self.data) < self.max_size: 27 | # if the buffer is not full, store the new element 28 | self.data.append(element) 29 | # return the latest stored element 30 | to_return.append(element) 31 | else: 32 | # when the buffer is full, with probability (1 - prc) 33 | if random.uniform(0, 1) > prc: 34 | # return a random element in the buffer 35 | # the latest added element is stored in place of the returned element 36 | i = random.randint(0, self.max_size - 1) 37 | to_return.append(self.data[i].clone()) 38 | self.data[i] = element 39 | else: 40 | # otherwise return the latest element 41 | to_return.append(element) 42 | # convert the list of images back into a batch 43 | return torch.cat(to_return) 44 | -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/__pycache__/DecayLR.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/generic/__pycache__/DecayLR.cpython-37.pyc --------------------------------------------------------------------------------
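The pool update rule in `ReplayBuffer.push_and_pop` above is independent of tensors, so the same policy can be exercised on plain Python objects. A torch-free sketch (`SimpleReplayBuffer` is an illustrative name, not part of the repo):

```python
import random

class SimpleReplayBuffer:
    """Torch-free sketch of the policy in Lib/Nets/utils/generic/ReplyBuffer.py."""
    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, element, prc=0.5):
        # While the pool is still filling, store the element and return it as-is.
        if len(self.data) < self.max_size:
            self.data.append(element)
            return element
        # Pool full: with probability (1 - prc) swap the new element into a
        # random slot and return the old occupant; otherwise return the new one.
        if random.uniform(0, 1) > prc:
            i = random.randint(0, self.max_size - 1)
            old = self.data[i]
            self.data[i] = element
            return old
        return element

buf = SimpleReplayBuffer(max_size=3)
print([buf.push_and_pop(x) for x in range(3)])  # pool filling: [0, 1, 2]
```

Feeding the discriminator from such a pool (as CycleGAN does) decorrelates its updates from the most recent generator outputs.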
/Lib/Nets/utils/generic/__pycache__/DecayLR.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/generic/__pycache__/DecayLR.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/__pycache__/ReplyBuffer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/generic/__pycache__/ReplyBuffer.cpython-37.pyc -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/__pycache__/ReplyBuffer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/generic/__pycache__/ReplyBuffer.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/__pycache__/generic_training.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/generic/__pycache__/generic_training.cpython-37.pyc -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/__pycache__/generic_training.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/generic/__pycache__/generic_training.cpython-38.pyc 
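The `DecayLR.step` method defined earlier (in `Lib/Nets/utils/generic/DecayLR.py`) returns the multiplier that `torch.optim.lr_scheduler.LambdaLR` applies to the base learning rate: constant until `decay_epochs`, then linear down to 0 at `epochs`. The schedule can be checked without torch by replicating the class inline:

```python
class DecayLR:
    # Same schedule as Lib/Nets/utils/generic/DecayLR.py
    def __init__(self, epochs, checkpoint_epoch, decay_epochs):
        self.epochs = epochs
        self.checkpoint_epoch = checkpoint_epoch
        self.decay_epochs = decay_epochs

    def step(self, epoch):
        # 1.0 until decay_epochs, then linear decay reaching 0.0 at `epochs`
        return 1.0 - max(0, epoch + self.checkpoint_epoch - self.decay_epochs) / (self.epochs - self.decay_epochs)

# Defaults from general_parser: tot_epochs=200, decay_epochs=100, fresh run
sched = DecayLR(epochs=200, checkpoint_epoch=0, decay_epochs=100)
print(sched.step(0))    # 1.0 -> full lr for the first 100 epochs
print(sched.step(150))  # 0.5 -> halfway through the decay phase
print(sched.step(200))  # 0.0 -> lr reaches zero at the last epoch
```

In the training code this would plug in as `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=DecayLR(200, 0, 100).step)`; a non-zero `checkpoint_epoch` shifts the schedule when resuming a run.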
-------------------------------------------------------------------------------- /Lib/Nets/utils/generic/__pycache__/image2tensorboard.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/generic/__pycache__/image2tensorboard.cpython-37.pyc -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/__pycache__/image2tensorboard.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/generic/__pycache__/image2tensorboard.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/__pycache__/init_weights.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/generic/__pycache__/init_weights.cpython-37.pyc -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/__pycache__/init_weights.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/generic/__pycache__/init_weights.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/__pycache__/trainSN.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/generic/__pycache__/trainSN.cpython-37.pyc -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/__pycache__/trainSN.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/Nets/utils/generic/__pycache__/trainSN.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/generic_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Subset 4 | from Lib.Nets.utils.generic.image2tensorboard import display_label_single_c, display_predictions 5 | from random import randint, uniform 6 | from PIL import Image 7 | import numpy as np 8 | 9 | 10 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 11 | """ 12 | Author: Alessandro Cattoi 13 | Description: This file defines some functions which are common to all the implemented networks. 14 | """ 15 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 16 | 17 | 18 | def breaker(args, epoch): 19 | """ 20 | When the config routine is executed, a q.txt (q = quit) file is created in which the required number of epochs is written. 21 | By manually changing this value it is possible to stop training at a different number of epochs. 22 | It is better to stop the network in this way so that each epoch is completely executed. 
23 | This function reads the value from the q file and returns a bool which states whether or not it equals the current 24 | number of epochs 25 | :param args: running parameters 26 | :param epoch: current epoch 27 | :return: 28 | """ 29 | f = open(os.path.join(args.global_path, args.log_dir, 'q.txt'), "r") 30 | epoch_end = int(f.readline().split('=')[-1]) 31 | f.close() 32 | return epoch_end == epoch 33 | 34 | 35 | def set_requires_grad(model, requires_grad=False, N=None): 36 | """ 37 | Set requires_grad=False for all the networks to avoid unnecessary computations 38 | :param model: model whose gradient flags are set (only the first N parameters if N is given) 39 | :param requires_grad: True to enable gradients 40 | :return: 41 | """ 42 | if N is None: 43 | for param in model.parameters(): 44 | param.requires_grad = requires_grad 45 | if requires_grad: 46 | model.train() 47 | else: 48 | model.eval() 49 | else: 50 | for i, param in enumerate(model.parameters()): 51 | if i < N: 52 | param.requires_grad = requires_grad 53 | 54 | 55 | def print_param(model): 56 | """ 57 | Takes a model as input and prints all parameter data; it is a debugging function 58 | :param model: 59 | :return: 60 | """ 61 | i = 0 62 | for name, param in model.named_parameters(): 63 | print(name, param.data, param.requires_grad) 64 | i = i + 1 65 | print(i) 66 | 67 | 68 | def get_subset(dataset, prc, end=False): 69 | """ 70 | Return a subset of the dataset containing the given fraction (prc) of patches 71 | Randomness is disabled so that all networks use exactly the same validation set 72 | :param dataset: dataset from which to create the loader 73 | :param prc: fraction of the dataset to keep 74 | :param end: if True, take the samples from the end of the dataset 75 | :return: 76 | """ 77 | if prc >= 1: 78 | return dataset 79 | length = len(dataset) 80 | n_sample = int(length * prc) 81 | if end: 82 | subset = Subset(dataset, range(length - n_sample, length)) 83 | print('Range of prc data {} is [{}, {}]'.format(prc, length - n_sample, length)) 84 | else: 85 | start = 0 86 | subset = Subset(dataset, range(start, start + n_sample)) 87 | print('Range of prc 
data {} is [{}, {}]'.format(prc, start, start + n_sample)) 88 | return subset 89 | 90 | 91 | def calculate_accuracy(Net, dataset, writer, global_step, name, epoch, posx, posy, size, BL=False): 92 | """ 93 | Classify dataset and calculate accuracy 94 | :param Net: 95 | :param dataset: dataset to use to calculate accuracy 96 | :param writer: pointer to tb 97 | :param global_step: 98 | :param name: 99 | :param epoch: 100 | :return: 101 | """ 102 | # iterate all or partially the dataset and for each sample calculate the accuracy 103 | # bar = tqdm(enumerate(dataset), total=len(dataset)) 104 | map_list = np.zeros((len(dataset.dataset), 3, Net.opt.patch_size, Net.opt.patch_size), 105 | dtype=np.float32) 106 | 107 | set_requires_grad(Net.netG_S2O, False) 108 | set_requires_grad(Net.SegNet, False) 109 | for i, data in enumerate(dataset): 110 | label = data['label'].to(Net.device) 111 | radar = data['radar'].to(Net.device) 112 | patch_names = data['name'] 113 | feature_map = Net.netG_S2O(radar) 114 | seg_map = Net.SegNet(feature_map) 115 | if name == 'Test': 116 | Net.Accuracy_test.update_acc(label, seg_map) 117 | else: 118 | Net.Accuracy_train.update_acc(label, seg_map) 119 | 120 | label_norm, mask = display_label_single_c(label) 121 | seg_map_norm = display_predictions(seg_map, True, mask) 122 | #writer.add_images(name + '/' + str(i) + "/Map", seg_map_norm, global_step=i) 123 | #writer.add_images(name + "/Labels", label_norm, global_step=i) 124 | 125 | img = seg_map_norm 126 | for k, n in enumerate(patch_names): 127 | temp = int(n.split('_')[1]) - 1 128 | map_list[temp] = img[k] 129 | 130 | set_requires_grad(Net.netG_S2O, BL) 131 | set_requires_grad(Net.SegNet, True) 132 | 133 | q_ps = int(Net.opt.patch_size / 4) 134 | ps_d = Net.opt.patch_size - q_ps 135 | tile = np.zeros((3, size[0], size[1])) 136 | for i in range(map_list.shape[0]): 137 | x = int(posx[i]) 138 | y = int(posy[i]) 139 | patch_o = map_list[i] 140 | tile[:, x + q_ps:x + ps_d, y + q_ps:y + ps_d] = patch_o[:, 
q_ps:ps_d, q_ps:ps_d] 141 | temp = tile[:, 800:7800, 1500:12000] 142 | temp = np.moveaxis(temp, 0, 2) 143 | temp = temp * 255 144 | temp = temp.astype(np.uint8) 145 | temp = Image.fromarray(temp, 'RGB') 146 | temp.save(os.path.join(Net.opt.tb_dir, str(epoch) + name + '_map.png')) 147 | 148 | # get the dictionary of the mean accuracies 149 | if name == 'Test': 150 | acc = Net.Accuracy_test.get_mean_dict() 151 | Net.Accuracy_test.reinit() 152 | # log the accuracy values 153 | Net.Logger_test.append_SN_acc(acc) 154 | # add accuracies to tensorboard 155 | else: 156 | acc = Net.Accuracy_train.get_mean_dict() 157 | Net.Accuracy_train.reinit() 158 | # log the accuracy values 159 | Net.Logger_train.append_SN_acc(acc) 160 | # add accuracies to tensorboard 161 | print('Accuracy = {}'.format(acc)) 162 | writer.add_scalars(name + "/Accuracy", acc, global_step=global_step) 163 | 164 | 165 | def cal_gp(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0): 166 | """ 167 | Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 168 | 169 | Arguments: 170 | netD (network) -- discriminator network 171 | real_data (tensor array) -- real images 172 | fake_data (tensor array) -- generated images from the generator 173 | device (str) -- GPU torch.device 174 | type (str) -- if we mix real and fake data or not [real | fake | mixed]. 175 | constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2 176 | lambda_gp (float) -- weight for this loss 177 | 178 | Returns the gradient penalty loss 179 | """ 180 | if type == 'real': # either use real images, fake images, or a linear interpolation of two. 
181 | interpolatesv = real_data 182 | elif type == 'fake': 183 | interpolatesv = fake_data 184 | elif type == 'mixed': 185 | alpha = torch.rand(real_data.shape[0], 1, device=device) 186 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape) 187 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) 188 | else: 189 | raise NotImplementedError('{} not implemented'.format(type)) 190 | interpolatesv.requires_grad_(True) 191 | disc_interpolates = netD(interpolatesv) 192 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, 193 | grad_outputs=torch.ones(disc_interpolates.size()).to(device), 194 | create_graph=True, retain_graph=True, only_inputs=True) 195 | gradients = gradients[0].view(real_data.size(0), -1) # flatten the data 196 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # eps added for numerical stability 197 | return gradient_penalty, gradients 198 | 199 | 200 | class List: 201 | """ 202 | This class implements a fixed-size buffer that keeps a running mean of the pushed values 203 | """ 204 | def __init__(self, size=50): 205 | self.size = size 206 | # it is a list of batch values 
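For context, here is a minimal self-contained sketch of the penalty that `cal_gp` computes in its default `type='mixed'` branch. The toy critic `netD` and the random tensors are illustrative stand-ins, not the repository's discriminators:

```python
import torch
import torch.nn as nn

# Toy stand-in for a WGAN-GP critic; purely illustrative.
netD = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
real = torch.randn(4, 3, 8, 8)
fake = torch.randn(4, 3, 8, 8)

# Mix real and fake samples with a per-sample random factor, as in type='mixed'.
alpha = torch.rand(real.shape[0], 1, 1, 1)
interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

# Gradient of the critic output w.r.t. the interpolated input.
d_out = netD(interp)
grads = torch.autograd.grad(outputs=d_out, inputs=interp,
                            grad_outputs=torch.ones_like(d_out),
                            create_graph=True, only_inputs=True)[0]
grads = grads.view(real.size(0), -1)

# Penalise deviation of the per-sample gradient norm from constant=1,
# weighted by lambda_gp=10, matching cal_gp's defaults.
gp = (((grads + 1e-16).norm(2, dim=1) - 1.0) ** 2).mean() * 10.0
```

Because `create_graph=True`, `gp` itself is differentiable and can be added to the discriminator loss before `backward()`.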
207 | self.data = [] 208 | self.data.append(0.5) 209 | 210 | def push_and_pop(self, value): 211 | if len(self.data) < self.size: 212 | self.data.append(value) 213 | else: 214 | self.data.pop(0) 215 | self.data.append(value) 216 | 217 | def mean(self): 218 | return sum(self.data)/len(self.data) 219 | 220 | 221 | def drop_channel(data, min_damping_coeff=0, p_th=0.7): 222 | n_channel = data.shape[1] 223 | channels = list(range(n_channel)) 224 | N_of_drop_ch = randint(0, n_channel-1) 225 | prc = uniform(0, 1) 226 | if prc > p_th: 227 | for i in range(N_of_drop_ch): 228 | ch = randint(0, len(channels)-1) 229 | target_ch = channels.pop(ch) 230 | damping_coeff = round(uniform(0, min_damping_coeff), 1) 231 | data[:, target_ch] = data[:, target_ch] * damping_coeff 232 | return data 233 | -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/image2tensorboard.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from PIL import ImageColor, Image 4 | import numpy as np 5 | from skimage import exposure 6 | from scipy.ndimage import zoom 7 | 8 | 9 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 10 | """ 11 | Author: Alessandro Cattoi 12 | Description: This file defines some functions employed to manage images 13 | """ 14 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 15 | 16 | 17 | def display_label_multi_c(tensor): 18 | """ 19 | Display one-hot encoded labels 20 | :param tensor: input tensor of labels [NCHW] 21 | :return: a tensor of labels ready for tensorboard 22 | """ 23 | 24 | # calculate mask 25 | mask = get_mask_multi_c(tensor) 26 | tensor = torch.argmax(tensor, 1) 27 | # apply the mask, subtracting 5 where there are non-classified pixels 28 | tensor = tensor - mask 29 | cmap = mycmap() 30 | # transform image 
values to color map 31 | tensor = cmap[tensor] 32 | 33 | # bring the last axis to the second position 34 | tensor = np.rollaxis(tensor, 3, 1) 35 | return tensor, mask 36 | 37 | 38 | def display_label_single_c(tensor): 39 | """ 40 | Displays labels 41 | :param tensor: input tensor of label [NHW] 42 | :return: a tensor of label ready for tensorboard 43 | """ 44 | tensor = tensor.cpu().numpy() 45 | # calculate mask 46 | mask = get_mask_single_c(tensor) 47 | # apply the mask, subtracting 5 where there are non classified pixel 48 | tensor = tensor - mask 49 | cmap = mycmap() 50 | # transform image values to color map 51 | tensor = tensor.astype(int) 52 | tensor = cmap[tensor] 53 | # bring the last axis to the second position 54 | tensor = np.rollaxis(tensor, 3, 1) 55 | return tensor, mask 56 | 57 | 58 | def display_predictions(tensor, mask_image=True, mask=0): 59 | """ 60 | Display segmentation map 61 | :param tensor: input tensor of label [NCHW] 62 | :param mask_image: boolean to decide if to mask or not 63 | :param mask: mask array 64 | :return: a tensor of label ready for tensorboard 65 | """ 66 | tensor = torch.argmax(tensor, 1) 67 | tensor = tensor.cpu().numpy() 68 | # apply the mask, subtracting 5 where there are non classified pixel 69 | if mask_image: 70 | tensor = np.where(mask == 256, -1, tensor) 71 | cmap = mycmap() 72 | # transform image values to color map 73 | tensor = tensor.astype(int) 74 | tensor = cmap[tensor] 75 | # bring the last axis to the second position 76 | tensor = np.rollaxis(tensor, 3, 1) 77 | 78 | return tensor 79 | 80 | 81 | def display_input(tensor, mask_image=False, mask=0): 82 | """ 83 | Display any kind of input 84 | :param tensor: tensor of rgb or radar [NCHW] 85 | :param mask_image: mask the image or not 86 | :param mask: mask array 87 | :return: return a tensor ready for tensorboard 88 | """ 89 | if tensor.shape[1] == 2: 90 | temp = torch.zeros((tensor.shape[0], 3, tensor.shape[2], tensor.shape[2])) 91 | temp[:, 0] = tensor[:, 0] 92 
| temp[:, 1] = tensor[:, 1] 93 | tensor = temp 94 | else: 95 | min_val = torch.min(tensor) 96 | # make all values positive 97 | if min_val < 0: 98 | tensor = tensor - min_val 99 | max_val = torch.max(tensor) 100 | # check that values are meaningful and not all equal to zero 101 | if max_val != 0: 102 | tensor = torch.true_divide(tensor, max_val) 103 | tensor = tensor.cpu().detach().numpy() 104 | if mask_image: 105 | # [5] TODO: mask correctly 106 | tensor = tensor * mask 107 | tensor = histogram_equalize(tensor) 108 | return tensor 109 | 110 | 111 | def mycmap(norm_val=255, color_code="RGB"): 112 | """ 113 | Define a colormap to be used when colouring maps 114 | :param norm_val: divide the rgb vector by norm_val (1: no action, 255: normalize to 1) 115 | :param color_code: colour coding, RGBA or RGB 116 | :return: an array of rgb or rgba codes 117 | """ 118 | colours = np.array(["#0000ff", # blue 119 | "#ff0000", # red 120 | "#00ff00", # lime 121 | "#565656", # dark grey 122 | "#145a32", # military green 123 | "#FFFF00", # yellow 124 | "#FF7000", # orange 125 | "#FFFFFF", # white 126 | "#FF00FF", # fuchsia 127 | "#767676", # grey 128 | "#00CF84", # aqua green 129 | "#ffffff"]) # white 130 | # init array of colours 131 | cmap = np.zeros((len(colours), len(color_code))) 132 | # convert hex colour to rgb or rgba 133 | for i in range(len(colours)): 134 | cmap[i] = (np.array(ImageColor.getcolor(colours[i], color_code))) 135 | # normalize values between 0 and 1 136 | cmap = cmap/norm_val 137 | return cmap 138 | 139 | 140 | def histogram_equalize(img): 141 | """ 142 | Stretch the image histogram across all channels to improve image quality 143 | :param img: image with any shape 144 | :return: same image with histogram stretched 145 | """ 146 | img_cdf, bin_centers = exposure.cumulative_distribution(img) 147 | return 
np.array(np.interp(img, bin_centers, img_cdf)) 148 | 149 | 150 | def get_mask_multi_c(tensor): 151 | """ 152 | 153 | :param tensor: input tensor of labels 154 | :return: a mask for that label 155 | """ 156 | # create a mask with 1 where there are no classified pixels 157 | mask = torch.where(tensor == 0, torch.as_tensor(1), torch.as_tensor(0)) 158 | 159 | mask = (mask[:, 0, :, :] * mask[:, 1, :, :] * mask[:, 2, :, :] * mask[:, 3, :, :] * mask[:, 4, :, :])*5 160 | 161 | 162 | return mask 163 | 164 | 165 | def get_mask_single_c(tensor): 166 | """ 167 | 168 | :param tensor: input tensor of labels 169 | :return: a mask for that label 170 | """ 171 | # create a mask with 1 where there are no classified pixels 172 | mask = np.where(tensor == 255, 256, 0) 173 | return mask 174 | 175 | 176 | def negate_mask(mask): 177 | """ 178 | 179 | :param mask: mask of the image 180 | :return: converts 0 to 1 and any other value to 0 181 | """ 182 | mask = torch.where(mask == 0, torch.as_tensor(1), torch.as_tensor(0)) 183 | 184 | return mask 185 | 186 | 187 | def norm(data, scale_factor=1, scale=False): 188 | """ 189 | Brings data into the right form for saving with PIL 190 | :param data: image to be transformed 191 | :param scale_factor: 192 | :param scale: 193 | :return: 194 | """ 195 | if data.shape[0] < 4: 196 | data = np.moveaxis(data, 0, -1) 197 | else: 198 | data = np.moveaxis(data, 0, 1) 199 | data = data * 255 200 | if scale: 201 | data = zoom(data, (scale_factor, scale_factor, 1), mode='nearest', prefilter=False) 202 | return data 203 | 204 | 205 | def denorm(data, parameters_path, typ): 206 | """ 207 | De-normalise an image, recovering the original encoding 208 | :param data: image to be decoded 209 | :param parameters_path: path to the encoding parameters 210 | :param typ: 'radar' or 'rgb' 211 | :return: 212 | """ 213 | try: 214 | mx = np.load(os.path.join(parameters_path, '1_max_radar.npy')) 215 | for i in range(3): 216 | data[i, :, :] = data[i, :, :] * mx[i] 217 | except: 218 | pass 219 | center 
= np.load(os.path.join(parameters_path, '1_center_' + typ + '.npy')) 220 | std = np.load(os.path.join(parameters_path, '1_std_' + typ + '.npy')) 221 | mean = np.load(os.path.join(parameters_path, '1_mean_' + typ + '.npy')) 222 | for i in range(3): 223 | try: 224 | _ = len(center) 225 | data[i, :, :] = data[i, :, :] + center[i] 226 | except: 227 | pass 228 | data[i, :, :] = data[i, :, :] * std[i] 229 | data[i, :, :] = data[i, :, :] + mean[i] 230 | return data 231 | 232 | 233 | def reconstruct_tile(name, ps, posx, posy, save_dir, size, epoch, data, rgb=True, parameter_path=None, data_s=None): 234 | media = 0.7678273 235 | q_ps = int(ps / 4) 236 | ps_d = ps - q_ps 237 | ch = data.shape[1] 238 | tile = torch.zeros((ch, size[0], size[1])) 239 | for i in range(data.shape[0]): 240 | x = int(posx[i]) 241 | y = int(posy[i]) 242 | patch_o = data[i] 243 | tile[:, x + q_ps:x + ps_d, y + q_ps:y + ps_d] = patch_o[:, q_ps:ps_d, q_ps:ps_d] 244 | #temp = np.array(tile[:, 800:7800, 1500:12000]) 245 | #np.save(os.path.join(save_dir, str(epoch) + name + '.png'), temp) 246 | #tile_s = denorm(tile_s, os.path.join(parameter_path, 'radar'), 'radar') 247 | torch.save(tile, os.path.join(save_dir, str(epoch) + name + '.pt')) 248 | temp = tile.cpu().detach().numpy() 249 | temp = np.moveaxis(temp[0:3, 800:7800, 1500:12000], 0, -1) 250 | temp = temp - np.min(temp) 251 | if rgb: 252 | temp = np.log(temp + 1) 253 | temp = np.where(temp > 2.5 * media, 2.5 * media, temp) 254 | temp = temp / (2.5 * media) 255 | else: 256 | temp = temp / np.max(temp) 257 | temp = temp * 255 258 | temp = temp.astype(np.uint8) 259 | temp = Image.fromarray(temp, 'RGB') 260 | temp.save(os.path.join(save_dir, str(epoch) + name + '.png')) 261 | -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/init_weights.py: -------------------------------------------------------------------------------- 1 | from torch.nn import init 2 | 3 | 4 | def init_weights(net, 
init_type='normal', init_gain=0.02): 5 | """ 6 | Init network weights 7 | :param net: net to which init weights 8 | :param init_type: type of weight initialisation 9 | :param init_gain: gain value for weight initialisation func 10 | :return: 11 | """ 12 | 13 | def init_func(m): 14 | """ 15 | :param m: is like self so is the class Generator for example 16 | :return: 17 | """ 18 | class_name = m.__class__.__name__ 19 | if hasattr(m, 'weight') and (class_name.find('Conv') != -1 or class_name.find('Linear') != -1): 20 | if init_type == 'normal': 21 | init.normal_(m.weight.data, 0.0, init_gain) 22 | elif init_type == 'xavier': 23 | init.xavier_normal_(m.weight.data, gain=init_gain) 24 | elif init_type == 'kaiming': 25 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 26 | elif init_type == 'orthogonal': 27 | init.orthogonal_(m.weight.data, gain=init_gain) 28 | else: 29 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 30 | if hasattr(m, 'bias') and m.bias is not None: 31 | init.constant_(m.bias.data, 0.0) 32 | # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 
33 | elif class_name.find('BatchNorm2d') != -1: 34 | init.normal_(m.weight.data, 1.0, init_gain) 35 | init.constant_(m.bias.data, 0.0) 36 | 37 | # print('initialize network with %s' % init_type) 38 | # apply the initialization function 39 | net.apply(init_func) 40 | -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/tile_creator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from Lib.Nets.utils.generic.image2tensorboard import reconstruct_tile 4 | import pickle as pkl 5 | 6 | path = '/home/ale/Documents/Python/13_Tesi_2/runs/agan/10_32_idt/checkpoints/args.pkl' 7 | opt = pkl.load(open(path, "rb")) 8 | posx = pkl.load(open(os.path.join(opt.data_dir_train, 'posx.pkl'), "rb")) 9 | posy = pkl.load(open(os.path.join(opt.data_dir_train, 'posy.pkl'), "rb")) 10 | 11 | file_list = os.listdir(opt.tb_dir) 12 | tile_list = list(filter(lambda x: '.pt' in x, file_list)) 13 | name = 'RT' 14 | 15 | par_path = '/home/ale/Documents/Python/13_Tesi_2/Data/Datasets/EUSAR/Train/' 16 | for i in tile_list: 17 | epoch = i.split('.')[0] 18 | trans = torch.load(os.path.join(opt.tb_dir, epoch + '.pt')) 19 | reconstruct_tile(name, opt.patch_size, posx, posy, opt.tb_dir, [8736, 13984], epoch, trans)#, parameter_path=par_path) 20 | -------------------------------------------------------------------------------- /Lib/Nets/utils/generic/trainSN.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.tensorboard import SummaryWriter 3 | from Lib.Nets.SN.SN import SN 4 | from Lib.Nets.utils.config.config_routine import config_routine 5 | from Lib.Nets.utils.config.general_parser import general_parser 6 | from Lib.Nets.utils.config.specific_parser import specific_parser 7 | from Lib.Datasets.EUSAR.EUSARDataset import EUSARDataset 8 | import argparse 9 | from torch.utils.data import DataLoader 10 | from 
Lib.Nets.utils.generic.generic_training import get_subset 11 | 12 | 13 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 14 | """ 15 | Author: Alessandro Cattoi 16 | Description: This function is employed to test feature extraction capability. It can be called by network 17 | implementations to train a classifier on top of the features just learned. 18 | """ 19 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 20 | 21 | 22 | def trainSN(options, epoch, device): 23 | """ 24 | Run a quick SN training 25 | :param options: pretrained model options 26 | :param epoch: epoch to be loaded 27 | :param device: 28 | :return: 29 | """ 30 | """-------------------------------CONFIG----------------------------------""" 31 | parser = argparse.ArgumentParser(description="PyTorch Regression GAN") 32 | parser = general_parser(parser) 33 | opt = specific_parser( 34 | parser=parser, run_folder=options.log_dir, mode='train', tot_epochs=30, pretrained_GAN=options.checkpoint_dir, 35 | GAN_epoch=epoch, acc_log_freq=options.acc_log_freq, loss_log_freq=options.loss_log_freq, 36 | batch_size_SN=options.batch_size_SN, images_log_freq=options.images_log_freq, 37 | data_dir_train=options.data_dir_train2, data_dir_test=options.data_dir_test2, 38 | experiment_name='SN'+str(epoch), sar_c=options.sar_c, optical_c=options.optical_c, 39 | save_model_freq=1000, res_block_N=options.res_block_N) 40 | 41 | opt = config_routine(opt) 42 | 43 | """-----------------------------DATA LOADER--------------------------------""" 44 | train_dataset = EUSARDataset(os.path.join(options.data_dir_train2), True, False, options.sar_c, options.optical_c) 45 | train_dataset = get_subset(train_dataset, options.prc_test) 46 | train_dataset = DataLoader(train_dataset, batch_size=options.batch_size_SN, shuffle=True, 47 | num_workers=options.workers, pin_memory=True, 
drop_last=False) 48 | 49 | test_dataset = EUSARDataset(os.path.join(options.data_dir_test2), True, False, options.sar_c, options.optical_c) 50 | test_dataset = get_subset(test_dataset, options.prc_test, True) 51 | test_dataset = DataLoader(test_dataset, batch_size=options.batch_size_SN, shuffle=False, 52 | num_workers=options.workers, pin_memory=True, drop_last=False) 53 | 54 | """--------------------------------TRAIN-----------------------------------""" 55 | # Init model 56 | model = SN(opt, device) 57 | 58 | # set up tensorboard logging 59 | writer = SummaryWriter(log_dir=os.path.join(opt.tb_dir)) 60 | # Model Training 61 | model.train(train_dataset, test_dataset, writer) 62 | -------------------------------------------------------------------------------- /Lib/utils/Logger/Logger.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pickle as pkl 3 | from Lib.utils.generic.generic_utils import color_hex_list_generator 4 | import plotly.graph_objects as go 5 | from plotly.subplots import make_subplots 6 | from Lib.utils.generic.generic_utils import moving_average 7 | import os 8 | 9 | 10 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 11 | """ 12 | Author: Alessandro Cattoi 13 | Description: This class is thought to log data during training, then reload the results and plot them 14 | """ 15 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 16 | 17 | 18 | class Logger: 19 | """ 20 | This class is thought to store and plot data. 
21 | The storing is made during training execution, the plotting is performed after 22 | """ 23 | def __init__(self, mode='', shades=5): 24 | self.mode = mode 25 | self.epoch = None 26 | self.shades = shades 27 | self.shade_name = [] 28 | self.color_shade = self.get_shades() 29 | self.base_color = self.get_base_color() 30 | self.fig_G_loss = None 31 | self.fig_D_S_loss = None 32 | self.fig_D_O_loss = None 33 | self.fig_GAN = None 34 | self.fig_SN_loss = None 35 | self.fig_SN_acc = None 36 | self.fig_SN = None 37 | self.loss_step_df = pd.DataFrame() 38 | self.acc_step_df = pd.DataFrame() 39 | self.SN_loss_df = pd.DataFrame() 40 | self.SN_acc_df = pd.DataFrame() 41 | if self.mode == "train": 42 | self.G_loss_df = pd.DataFrame() 43 | self.D_S_loss_df = pd.DataFrame() 44 | self.D_O_loss_df = pd.DataFrame() 45 | 46 | # All these function append a dictionary which is a sample of each of these acquired variable 47 | def append_G(self, new_sample): 48 | self.G_loss_df = self.G_loss_df.append(new_sample, ignore_index=True) 49 | 50 | def append_D_S(self, new_sample): 51 | self.D_S_loss_df = self.D_S_loss_df.append(new_sample, ignore_index=True) 52 | 53 | def append_D_O(self, new_sample): 54 | self.D_O_loss_df = self.D_O_loss_df.append(new_sample, ignore_index=True) 55 | 56 | def append_SN_loss(self, new_sample): 57 | self.SN_loss_df = self.SN_loss_df.append(new_sample, ignore_index=True) 58 | 59 | def append_loss_step(self, new_sample): 60 | self.loss_step_df = self.loss_step_df.append(new_sample, ignore_index=True) 61 | 62 | def append_SN_acc(self, new_sample): 63 | self.SN_acc_df = self.SN_acc_df.append(new_sample, ignore_index=True) 64 | 65 | def append_acc_step(self, new_sample): 66 | self.acc_step_df = self.acc_step_df.append(new_sample, ignore_index=True) 67 | 68 | def save_logger(self, path, name, epoch=None): 69 | self.epoch = epoch 70 | if epoch is not None: 71 | file = os.path.join(path, str(self.epoch) + "_" + name + "_logger.pkl") 72 | else: 73 | file = 
os.path.join(path, name + "logger.pkl") 74 | pkl.dump(self, open(file, "wb")) 75 | 76 | def get_shades(self): 77 | """ 78 | Create a colour palette for each base colour. The base colours are: 79 | KEYS: 80 | - red 81 | - violet 82 | - blue 83 | - green 84 | - yellow 85 | - grey 86 | For each of them a list of self.shades shaded colours is created 87 | :return: dict of shaded-colour lists, one per base colour 88 | """ 89 | self.shade_name = ['red', 'blue', 'yellow', 'green', 'violet', 'grey'] 90 | red_list = color_hex_list_generator("#FF0000", "#F5B7B1", self.shades) 91 | viol_list = color_hex_list_generator("#4A235A", "#EBDEF0", self.shades) 92 | blue_list = color_hex_list_generator("#0B3AEA", "#8FF0FC", self.shades) 93 | green_list = color_hex_list_generator("#145A32", "#D1F2EB", self.shades) 94 | yellow_list = color_hex_list_generator("#FF6C00", "#F7DC6F", self.shades) 95 | grey_list = color_hex_list_generator("#17202A", "#F2F3F4", self.shades) 96 | color = { 97 | "red": red_list, 98 | "violet": viol_list, 99 | "blue": blue_list, 100 | "green": green_list, 101 | "yellow": yellow_list, 102 | "grey": grey_list, 103 | } 104 | return color 105 | 106 | def get_base_color(self): 107 | """ 108 | Return a list of complementary colours 109 | :return: list of complementary colours 110 | """ 111 | color = [ 112 | '#2971B0', # "blue": 113 | '#3A9E1F', # "green": 114 | '#E39112', # "Orange": 115 | '#EC4040', # "red": 116 | '#00FFFF', # "light_blue" 117 | '#FFFF00', # "yellow": 118 | '#FF00FF', # "fuchsia": 119 | '#00FF00', # "light_green": 120 | '#000000', # "black": 121 | '#757575', # "grey": 122 | '#7800FF', # "violet": 123 | ] 124 | return color 125 | 126 | def create_figure(self): 127 | """ 128 | Create a figure and add two plots to it 129 | :return: 130 | """ 131 | self.loss_step_df = self.norm_step(self.loss_step_df) 132 | #self.acc_step_df = self.norm_step(self.acc_step_df) 133 | if self.mode == "train": 134 | self.fig_G_loss = self.generate_plot(self.G_loss_df, self.loss_step_df) 135 | self.fig_G_loss = 
self.plot_layout(self.fig_G_loss, "Generator Losses") 136 | self.fig_D_S_loss = self.generate_plot(self.D_S_loss_df, self.loss_step_df) 137 | self.fig_D_S_loss = self.plot_layout(self.fig_D_S_loss, "SAR Discriminator Losses") 138 | self.fig_D_O_loss = self.generate_plot(self.D_O_loss_df, self.loss_step_df) 139 | self.fig_D_O_loss = self.plot_layout(self.fig_D_O_loss, "Optical Discriminator Losses") 140 | 141 | '''self.fig_SN_loss = self.generate_plot(self.SN_loss_df, self.loss_step_df) 142 | self.fig_SN_loss = self.plot_layout(self.fig_SN_loss, "Segmentation Network Losses") 143 | self.fig_SN_acc = self.generate_plot(self.SN_acc_df, self.acc_step_df) 144 | self.fig_SN_acc = self.plot_layout(self.fig_SN_acc, "Segmentation Accuracy", "Epochs", "Accuracy [%]")''' 145 | 146 | def generate_plot(self, df, df_x): 147 | """ 148 | Generates a plot where each column is added as a singal 149 | :param df: is the y values 150 | :param df_x: is the x value 151 | :return: 152 | """ 153 | fig = go.Figure() 154 | for i, col in enumerate(df.columns): 155 | color = dict(color=self.base_color[i]) 156 | fig.add_trace(go.Scatter(x=df_x['step'], y=df[col], mode='lines+markers', line=color, name=col)) 157 | return fig 158 | 159 | def create_subplot(self): 160 | if self.mode == "train": 161 | self.fig_GAN = self.generate_subplot(self.loss_step_df, self.loss_step_df, self.G_loss_df, self.D_S_loss_df, self.D_O_loss_df) 162 | self.fig_GAN = self.plot_layout(self.fig_GAN, "GAN Losses", "", "") 163 | self.fig_SN = self.generate_subplot(self.loss_step_df, self.acc_step_df, self.SN_loss_df, self.SN_acc_df) 164 | self.fig_SN = self.plot_layout(self.fig_SN, "Segmentation Network Performances", "", "") 165 | 166 | def generate_subplot(self, df_x1, df_x2, df1, df2, df3=None): 167 | if df3 is not None: 168 | row = 3 169 | else: 170 | row = 2 171 | 172 | fig = make_subplots(rows=row, cols=1, 173 | shared_xaxes=True, 174 | vertical_spacing=0.02) 175 | 176 | for i, col in enumerate(df1.columns): 177 | 
color = dict(color=self.base_color[i]) 178 | fig.add_trace(go.Scatter(x=df_x1['step'], y=df1[col], mode='lines+markers', line=color, name=col), row=1, col=1) 179 | 180 | for i, col in enumerate(df2.columns): 181 | color = dict(color=self.base_color[i]) 182 | fig.add_trace(go.Scatter(x=df_x2['step'], y=df2[col], mode='lines+markers', line=color, name=col), row=2, col=1) 183 | 184 | if df3 is not None: 185 | for i, col in enumerate(df3.columns): 186 | color = dict(color=self.base_color[i]) 187 | fig.add_trace(go.Scatter(x=df_x1['step'], y=df3[col], mode='lines+markers', line=color, name=col), row=3, col=1) 188 | return fig 189 | 190 | @staticmethod 191 | def plot_layout(fig, title, x_title="Epochs", y_title="Loss Value"): 192 | """ 193 | 194 | :param fig: fig to which the layout is applied 195 | :param title: title 196 | :param x_title: x-axis title 197 | :param y_title: y-axis title 198 | :return: 199 | """ 200 | fig.update_layout( 201 | #showlegend=False, 202 | title={'text': title, 'y':0.95, 'x':0.5, 'xanchor': 'center', 'yanchor': 'top'}, 203 | xaxis_title=x_title, 204 | yaxis_title=y_title, 205 | legend_x=0.84, 206 | legend_y=0.01, 207 | font=dict(family="Times New Roman, monospace", size=29, color="Black"), 208 | legend=dict(title="Legend", bgcolor="White", bordercolor="Black", borderwidth=2) 209 | ) 210 | fig.update_xaxes(showline=True, linewidth=0.5, linecolor='Black', mirror=True, range=[-0.3, 30.3]) 211 | #TODO: range 212 | fig.update_yaxes(showline=True, linewidth=0.5, linecolor='Black', mirror=True, range=[0, 100], tickmode = 'linear', dtick = 10) 213 | fig.update_xaxes(ticks="outside") 214 | fig.update_yaxes(ticks="outside") 215 | return fig 216 | 217 | def norm_step(self, df): 218 | """ 219 | Normalise the step values by the number of epochs so that the x axis shows the number of epochs, not the number of steps 220 | :param df: input data 221 | :return: 222 | """ 223 | mx = df['step'].iloc[-1] 224 | if self.epoch is not None: 225 | norm = mx/self.epoch 226 | 
for i, val in enumerate(df['step']): 227 | df['step'][i] = val/norm 228 | return df 229 | 230 | def save_fig(self, path="/home/ale/Desktop"): 231 | """ 232 | Save the figures as standalone HTML files 233 | :param path: destination directory 234 | :return: 235 | """ 236 | self.fig_G_loss.write_html(path + "/fig_G_loss.html") 237 | self.fig_D_S_loss.write_html(path + "/fig_D_S_loss.html") 238 | self.fig_D_O_loss.write_html(path + "/fig_D_O_loss.html") 239 | self.fig_SN_loss.write_html(path + "/fig_SN_loss.html") 240 | self.fig_SN_acc.write_html(path + "/fig_SN_acc.html") 241 | 242 | def filter_all(self, win, pad): 243 | """ 244 | Filter all the loss curves with a moving mean 245 | :param win: moving-mean window size 246 | :param pad: whether to also filter the first samples (forwarded to moving_average as correct_init) 247 | :return: 248 | """ 249 | if self.mode == "train": 250 | self.G_loss_df = self.filter_df(self.G_loss_df, win, pad) 251 | self.D_S_loss_df = self.filter_df(self.D_S_loss_df, win, pad) 252 | self.D_O_loss_df = self.filter_df(self.D_O_loss_df, win, pad) 253 | self.SN_loss_df = self.filter_df(self.SN_loss_df, win, pad) 254 | # self.SN_acc_df = self.filter_df(self.SN_acc_df, win, pad) 255 | 256 | @staticmethod 257 | def filter_df(df, win, pad): 258 | """ 259 | Apply a moving mean to every column 260 | :param df: data to be filtered 261 | :param win: window size 262 | :param pad: whether to also filter the first samples 263 | :return: 264 | """ 265 | for i in df.columns: 266 | df[i] = moving_average(df[i], win, pad) 267 | return df 268 | -------------------------------------------------------------------------------- /Lib/utils/Logger/Logger_cmp.py: -------------------------------------------------------------------------------- 1 | import plotly.graph_objects as go 2 | from Lib.utils.Logger.Logger import Logger 3 | 4 | 5 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 6 | """ 7 | Author: Alessandro Cattoi 8 | Description: This class is thought to compare different Logger instances 9 | """ 10 | # 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 11 | 12 | 13 | class Logger_cmp(Logger): 14 | """ 15 | This class is an extension of Logger used to compare loggers. It is basically composed of a list of Logger instances and 16 | some methods to plot each Logger against the others 17 | """ 18 | def __init__(self, mode, shades, logger_list, name_list, title): 19 | """ 20 | 21 | :param mode: not implemented yet 22 | :param shades: number of shades, i.e. the number of signals for each Logger 23 | :param logger_list: list of Logger instances 24 | :param name_list: name of each logger 25 | :param title: general title 26 | The title of the graph is composed as: title + name_list[0] + VS + name_list[1] + VS + ... + name_list[N] 27 | """ 28 | self.mode = mode 29 | self.shades = shades 30 | self.shade_name = [] 31 | self.color_shade = self.get_shades() 32 | self.base_color = self.get_base_color() 33 | self.logger_list = logger_list 34 | self.name_list = name_list 35 | self.title = title 36 | self.len = len(logger_list) 37 | 38 | def create_figure(self): 39 | """ 40 | This method creates the figures for the loss functions and the accuracy 41 | :return: 42 | """ 43 | if self.mode == "trainall": 44 | self.fig_G_loss = self.generate_plot('G_loss_df')  # generate_plot expects the attribute name, not the DataFrame 45 | self.fig_G_loss = self.plot_layout(self.fig_G_loss, "Generator Losses") 46 | self.fig_D_S_loss = self.generate_plot('D_S_loss_df') 47 | self.fig_D_S_loss = self.plot_layout(self.fig_D_S_loss, "SAR Discriminator Losses") 48 | self.fig_D_O_loss = self.generate_plot('D_O_loss_df') 49 | self.fig_D_O_loss = self.plot_layout(self.fig_D_O_loss, "Optical Discriminator Losses") 50 | 51 | self.fig_SN_loss = self.generate_plot('SN_loss_df') 52 | title = self.get_title(self.title) 53 | self.fig_SN_loss = self.plot_layout(self.fig_SN_loss, title) 54 | 55 | self.fig_SN_acc = self.generate_plot('SN_acc_df') 56 | title = self.get_title(self.title) 57 | self.fig_SN_acc = 
self.plot_layout(self.fig_SN_acc, title, "Epochs", "OA [%]") 58 | 59 | def generate_plot(self, var): 60 | """ 61 | This function iterates over each signal of each logger and adds them to the plot, always assigning a different colour 62 | :param var: name of the logger attribute containing the signals to plot (the x axis comes from the matching step attribute) 63 | :return: 64 | """ 65 | fig = go.Figure() 66 | for j, logger in enumerate(self.logger_list): 67 | df = getattr(logger, var) 68 | if 'acc' in var: 69 | df_x = getattr(logger, 'acc_step_df') 70 | else: 71 | df_x = getattr(logger, 'loss_step_df') 72 | self.epoch = logger.epoch 73 | df_x = self.norm_step(df_x) 74 | #shades = self.color_shade[self.shade_name[j]] 75 | shades = self.base_color 76 | for i, col in enumerate(df.columns): 77 | color = shades[j] 78 | if i == 0: 79 | #TODO: which metrics plot, *100 80 | fig.add_trace(go.Scatter(x=df_x['step'], y=df[col]*100, mode='lines+markers', line=dict(color=color, width=4), 81 | marker=dict(color=color, size=10), name=col, )) 82 | return fig 83 | 84 | def get_title(self, title): 85 | """ 86 | Create the title from the name list 87 | :param title: base title 88 | :return: 89 | """ 90 | print(title) 91 | ''' 92 | if title != "": 93 | for i in range(self.len - 1): 94 | title = title + self.name_list[i] + ' VS ' 95 | title = title + self.name_list[-1]''' 96 | return title 97 | -------------------------------------------------------------------------------- /Lib/utils/Logger/__pycache__/Logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/utils/Logger/__pycache__/Logger.cpython-37.pyc -------------------------------------------------------------------------------- /Lib/utils/Logger/__pycache__/Logger.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/utils/Logger/__pycache__/Logger.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/utils/Logger/__pycache__/Logger_cmp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/utils/Logger/__pycache__/Logger_cmp.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/utils/__pycache__/generic_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/utils/__pycache__/generic_utils.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/utils/__pycache__/image2tensorboard.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/utils/__pycache__/image2tensorboard.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/utils/__pycache__/init_weights.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/utils/__pycache__/init_weights.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/utils/__pycache__/net_util.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/utils/__pycache__/net_util.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/utils/generic/__pycache__/generic_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/utils/generic/__pycache__/generic_utils.cpython-37.pyc -------------------------------------------------------------------------------- /Lib/utils/generic/__pycache__/generic_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/utils/generic/__pycache__/generic_utils.cpython-38.pyc -------------------------------------------------------------------------------- /Lib/utils/generic/generic_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import random 3 | import numpy as np 4 | from colour import Color 5 | import os 6 | 7 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 8 | """ 9 | Author: Alessandro Cattoi 10 | Description: This module collects some general-purpose functions 11 | """ 12 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 13 | 14 | 15 | def start(): 16 | """ 17 | Return the current time; works together with stop 18 | :return: 19 | """ 20 | tic = time.time() 21 | return tic 22 | 23 | 24 | def stop(toc, process_name=''): 25 | """ 26 | Works together with start: takes the time captured by start and computes and prints the elapsed time 27 | :param toc: time value returned by start 28 | :param 
process_name: optional label printed before the execution time to give some detail 29 | :return: 30 | """ 31 | print("{} execution time = {:.5f} s".format(process_name, time.time() - toc)) 32 | 33 | 34 | def set_rand_seed(): 35 | """ 36 | Return a random value between 1 and 10000 37 | :return: 38 | """ 39 | return random.randint(1, 10000) 40 | 41 | 42 | def mean(arr, support=None): 43 | """ 44 | Calculate the mean of an array (weighted as sum(arr)/sum(support) when a support list is given) 45 | :param arr: input list of data 46 | :return: 47 | """ 48 | if support is not None: 49 | return sum(arr)/sum(support) 50 | else: 51 | return sum(arr)/len(arr) 52 | 53 | 54 | def moving_average(data, win_dim=3, correct_init=False): 55 | """ 56 | Calculate a moving mean. The first samples are typically not filtered because there are not enough samples yet. 57 | If correct_init is used, a tenth of the filtering window is copied into the first samples. 58 | :param data: input array of any type 59 | :param win_dim: dimension of the filtering window 60 | :param correct_init: whether to also filter the first samples 61 | :return: averaged array 62 | """ 63 | k = 10 64 | data = np.array(data).astype(float) 65 | win = [] 66 | mov_data = [] 67 | for value in data: 68 | if len(win) >= win_dim: 69 | win.pop(0) 70 | win.append(value) 71 | else: 72 | win.append(value) 73 | mov_data.append(np.mean(win)) 74 | if correct_init: 75 | back_ward_mean = mov_data[int(win_dim/k):2*int(win_dim/k)] 76 | mov_data[0:int(win_dim/k)] = back_ward_mean 77 | return mov_data 78 | 79 | 80 | def color_hex_list_generator(start_col, end_col, n_shades=5): 81 | """ 82 | Create a list of shaded colours 83 | :param start_col: starting colour 84 | :param end_col: destination colour 85 | :param n_shades: number of colours 86 | :return: list of colours 87 | """ 88 | start_col = Color(start_col) 89 | colors = list(start_col.range_to(Color(end_col), n_shades)) 90 | color_list = [] 91 | for i in colors: 92 | color_list.append(i.hex) 93 | return color_list 94 | 95 | 96 | def 
get_norm_param(data_path): 97 | """ 98 | :param data_path: path to the radar or rgb folder 99 | :return: ['mean', 'std', 'center', 'max'] 100 | """ 101 | dir_list = sorted(os.listdir(data_path)) 102 | param_name = [] 103 | for x in dir_list: 104 | param_name.append(x.split('_')[1]) 105 | param_name = list(filter(lambda x: '.' not in x, param_name)) 106 | file_name = list(filter(lambda x: x.split('_')[1] in param_name, dir_list)) 107 | temp = [] 108 | for x in file_name: 109 | temp.append(np.load(os.path.join(data_path, x))) 110 | return temp[0], temp[1], temp[2], temp[3] -------------------------------------------------------------------------------- /Lib/utils/metrics/Accuracy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn import metrics 3 | import torch 4 | from Lib.utils.generic.generic_utils import mean 5 | 6 | 7 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 8 | """ 9 | Author: Alessandro Cattoi 10 | Description: This file implements a class that calculates the accuracy of the classification results 11 | """ 12 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 13 | 14 | 15 | class Accuracy: 16 | """ 17 | NB: always use batches of at least 16, because otherwise accuracy_w is very inaccurate. 18 | When accuracy is calculated and a class is not present, false positives of that class are not counted, so the accuracy is 19 | strongly overestimated. 20 | This class takes prediction and label tensors of any shape, flattens them, removes pixels that have no label and calculates an 21 | overall accuracy of the tensor. Each calculated value is then stored in a list. 22 | At any time, get_mean_dict returns the mean values of the stored score lists in a dictionary. 
23 | Labels are composed as follows: 24 | CLASS NAME VALUE 25 | - forest --> 0 26 | - street --> 1 27 | - field --> 2 28 | - urban --> 3 29 | - water --> 4 30 | - Not classified --> 255 31 | """ 32 | def __init__(self, labels=[0, 1, 2, 3, 4]): 33 | super(Accuracy, self).__init__() 34 | self.f1 = [] 35 | self.AA = [] 36 | self.weighted_acc = [] 37 | self.support = [] 38 | self.labels = labels 39 | 40 | def update_acc(self, y_true, y_pred): 41 | """ 42 | Calculates accuracy scores for the classification 43 | AA is the number of correctly classified pixels divided by the total number of pixels 44 | weighted_acc is the weighted mean, over the classes, of the correct pixels of each class over the total pixels assigned to that class 45 | :param y_true: labels 46 | :param y_pred: prediction 47 | :return: (the f1, AA and AA_w scores are accumulated in the class lists; see get_mean_dict) 48 | """ 49 | y_true = y_true.cpu().numpy() 50 | y_pred = torch.argmax(y_pred, 1) 51 | y_pred = y_pred.cpu().numpy() 52 | 53 | y_true_flat = y_true.flatten() 54 | y_pred_flat = y_pred.flatten() 55 | 56 | index_f = np.argwhere(y_true_flat == 255) 57 | y_true_flat = np.delete(y_true_flat, index_f) 58 | y_pred_flat = np.delete(y_pred_flat, index_f) 59 | # print(metrics.classification_report(y_true_flat, y_pred_flat, zero_division=0)) 60 | # print(metrics.confusion_matrix(y_true_flat, y_pred_flat)) 61 | 62 | self.f1.append(metrics.f1_score(y_true_flat, y_pred_flat, labels=self.labels, average='macro', zero_division=0)*len(y_true_flat)) 63 | self.AA.append(metrics.accuracy_score(y_true_flat, y_pred_flat)*len(y_true_flat)) 64 | self.support.append(len(y_true_flat)) 65 | self.weighted_acc.append(metrics.precision_score(y_true_flat, y_pred_flat, average='macro', zero_division=0)*len(y_true_flat)) 66 | 67 | def get_mean_dict(self): 68 | """ 69 | Calculates the mean of each class list attribute and returns them in a dictionary 70 | (which can easily be stored in a Logger or loaded into tensorboard) 71 | :return: 72 | """ 73 | accuracy = { 74 | 'f1': mean(self.f1, self.support), 75 | 'AA': mean(self.AA, self.support), 76 | 
'AA_w': mean(self.weighted_acc, self.support), 77 | } 78 | return accuracy 79 | 80 | def reinit(self): 81 | self.f1 = [] 82 | self.AA = [] 83 | self.weighted_acc = [] 84 | self.support = [] 85 | -------------------------------------------------------------------------------- /Lib/utils/metrics/__pycache__/Accuracy.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/utils/metrics/__pycache__/Accuracy.cpython-37.pyc -------------------------------------------------------------------------------- /Lib/utils/metrics/__pycache__/Accuracy.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/Lib/utils/metrics/__pycache__/Accuracy.cpython-38.pyc -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import DataLoader 4 | import torch.backends.cudnn as cudnn 5 | from Lib.Datasets.EUSAR.EUSARDataset import EUSARDataset 6 | from Lib.Nets.SN.SN import SN 7 | from Lib.Nets.utils.config.config_routine import config_routine 8 | from Lib.Nets.utils.config.general_parser import general_parser 9 | from Lib.Nets.utils.config.specific_parser import specific_parser 10 | import argparse 11 | 12 | 13 | """-------------------------------CONFIG----------------------------------""" 14 | parser = argparse.ArgumentParser(description="PyTorch Regressor GAN") 15 | parser = general_parser(parser) 16 | opt = specific_parser( 17 | parser=parser, run_folder='runs/sn/', mode='eval', tot_epochs=30, res_block_N=6, 18 | restoring_rep_path=None, start_from_epoch=None, 19 | pretrained_GAN=None, GAN_epoch=None, 
seed=None, 20 | batch_size_SN=16, prc_train=1, prc_test=1, prc_val=None, sar_c=5, optical_c=4, 21 | data_dir_train='Data/Train/EUSAR/128_sn_corr', data_dir_test='Data/Test/EUSAR/128_sn_corr', 22 | acc_log_freq=29, loss_log_freq=1, save_model_freq=100, images_log_freq=None, 23 | experiment_name='sn_', 24 | run_description='Classifico con nuova metrica accuracy e tutto pt') 25 | opt = config_routine(opt) 26 | 27 | """-------------------------------LOAD DATA----------------------------------""" 28 | train_dataset = EUSARDataset(os.path.join(opt.data_dir_train), True, False, opt.sar_c, opt.optical_c) 29 | test_dataset = EUSARDataset(os.path.join(opt.data_dir_test), False, True, opt.sar_c, opt.optical_c) 30 | train_dataset = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, 31 | num_workers=opt.workers, pin_memory=True, drop_last=False) 32 | test_dataset = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, 33 | num_workers=opt.workers, pin_memory=True, drop_last=False) 34 | 35 | """---------------------------------TRAIN------------------------------------""" 36 | # Set cuda 37 | device = torch.device("cuda:0") 38 | cudnn.benchmark = True 39 | 40 | # Init model 41 | model = SN(opt, device) 42 | # Model Training 43 | 44 | for i in train_dataset: 45 | model.set_input(i) 46 | model.forward() 47 | 48 | -------------------------------------------------------------------------------- /for_server.sh: -------------------------------------------------------------------------------- 1 | # Quick commit command to commit last changes to be run on GPU server 2 | commit_msg=${1:-"for server"} 3 | cd /home/ale/Documents/Python/13_Tesi_2/ || exit 4 | git add --all 5 | git commit -m "$commit_msg" 6 | git push -u origin master 7 | git status -------------------------------------------------------------------------------- /mainBL.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from 
torch.utils.tensorboard import SummaryWriter 4 | import torch.backends.cudnn as cudnn 5 | from Lib.Datasets.EUSAR.EUSARDataset import EUSARDataset 6 | from Lib.Nets.BL.BL import BL 7 | from Lib.Nets.utils.config.config_routine import config_routine 8 | from Lib.Nets.utils.config.general_parser import general_parser 9 | from Lib.Nets.utils.config.specific_parser import specific_parser 10 | from Lib.Nets.utils.generic.generic_training import get_subset 11 | from torch.utils.data.dataloader import DataLoader 12 | import argparse 13 | 14 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 15 | """ 16 | Author: Alessandro Cattoi 17 | Description: This main can be employed to fine tune any network to perform Semantic Segmentation. 18 | "CONFIG" section give quick access to some parameter setting. 19 | """ 20 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 21 | 22 | """-------------------------------CONFIG----------------------------------""" 23 | parser = argparse.ArgumentParser(description="PyTorch Regression GAN") 24 | parser = general_parser(parser) 25 | opt = specific_parser( 26 | parser=parser, run_folder='runs/bl/', mode='train', tot_epochs=30, res_block_N=6, 27 | restoring_rep_path=None, start_from_epoch=None, 28 | batch_size_SN=16, prc_train=1, prc_test=1, prc_val=None, sar_c=5, optical_c=4, 29 | data_dir_train='Data/Train/EUSAR/128_sn_corr', data_dir_test='Data/Test/EUSAR/128_sn_corr', 30 | images_log_freq=50, acc_log_freq=1, loss_log_freq=10, save_model_freq=1, 31 | experiment_name='prova', 32 | run_description='faccio il train con 128 alla BL con i vecchi dati') 33 | opt = config_routine(opt) 34 | 35 | """-------------------------------LOAD DATA----------------------------------""" 36 | train_dataset = EUSARDataset(os.path.join(opt.data_dir_train), True, False, opt.sar_c, opt.optical_c) 37 | train_dataset = 
get_subset(train_dataset, opt.prc_train) 38 | test_dataset = EUSARDataset(os.path.join(opt.data_dir_test), True, False, opt.sar_c, opt.optical_c) 39 | test_dataset = get_subset(test_dataset, opt.prc_test, True) 40 | train_dataset = DataLoader(train_dataset, batch_size=opt.batch_size_SN, shuffle=True, 41 | num_workers=opt.workers, pin_memory=True, drop_last=False) 42 | test_dataset = DataLoader(test_dataset, batch_size=opt.batch_size_SN, shuffle=False, 43 | num_workers=opt.workers, pin_memory=True, drop_last=False) 44 | 45 | """---------------------------------TRAIN------------------------------------""" 46 | # Set cuda 47 | device = torch.device("cuda:0") 48 | cudnn.benchmark = True 49 | 50 | # Init model 51 | model = BL(opt, device) 52 | 53 | # set up tensorboard logging 54 | writer = SummaryWriter(log_dir=os.path.join(opt.tb_dir)) 55 | 56 | # Model Training 57 | model.train(train_dataset, test_dataset, writer) 58 | -------------------------------------------------------------------------------- /mainCAT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.tensorboard import SummaryWriter 4 | import torch.backends.cudnn as cudnn 5 | from Lib.Datasets.EUSAR.EUSARDataset import EUSARDataset 6 | from Lib.Nets.CAT.CAT import CAT 7 | from Lib.Nets.utils.config.config_routine import config_routine 8 | from Lib.Nets.utils.config.general_parser import general_parser 9 | from Lib.Nets.utils.config.specific_parser import specific_parser 10 | from Lib.Nets.utils.generic.generic_training import get_subset 11 | from torch.utils.data.dataloader import DataLoader 12 | import argparse 13 | 14 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 15 | """ 16 | Author: Alessandro Cattoi 17 | Description: This main is employed to train the Conditional Adversarial Transcoder. 
18 | "CONFIG" section give quick access to some parameter setting. 19 | """ 20 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 21 | """-------------------------------CONFIG----------------------------------""" 22 | parser = argparse.ArgumentParser(description="CONDITIONAL ADVERSARIAL TRANSCODER") 23 | parser = general_parser(parser) 24 | opt = specific_parser( 25 | parser=parser, log=False, run_folder='runs/agan/', mode='train', tot_epochs=211, loss_type='lsgan', 26 | restoring_rep_path=None, start_from_epoch=None, D_training_ratio=5, lambda_A=10, lambda_gp=None, 27 | buff_dim=1, th_low=None, th_high=None, pool=False, dropping=False, pool_prc_O=0.5, 28 | batch_size=1, prc_train=1, prc_test=1, prc_val=None, sar_c=5, optical_c=4, batch_size_SN=16, res_block_N=6, 29 | data_dir_train='Data/Train/EUSAR/128_trans_corr', data_dir_train2='Data/Train/EUSAR/128_sn_corr', 30 | data_dir_test='Data/Test/EUSAR/128_trans_corr', data_dir_test2='Data/Test/EUSAR/128_sn_corr', 31 | acc_log_freq=50, loss_log_freq=10, save_model_freq=20, images_log_freq=20, 32 | experiment_name='128_corr_data_2_disc', 33 | run_description='provo cgan std con 2 disc') 34 | opt = config_routine(opt) 35 | 36 | """-------------------------------LOAD DATA----------------------------------""" 37 | train_dataset = EUSARDataset(os.path.join(opt.data_dir_train), False, True, opt.sar_c, opt.optical_c) 38 | test_dataset = EUSARDataset(os.path.join(opt.data_dir_test), False, True, opt.sar_c, opt.optical_c) 39 | train_dataset = get_subset(train_dataset, opt.prc_train) 40 | test_dataset = get_subset(test_dataset, opt.prc_train) 41 | train_dataset = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, 42 | num_workers=opt.workers, pin_memory=True, drop_last=False) 43 | test_dataset = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, 44 | num_workers=opt.workers, pin_memory=True, drop_last=False) 45 | 46 | 
"""---------------------------------TRAIN------------------------------------""" 47 | # Set cuda 48 | device = torch.device("cuda:0") 49 | cudnn.benchmark = True 50 | 51 | # Init model 52 | model = CAT(opt, device) 53 | 54 | # set up tensorboard logging 55 | writer = SummaryWriter(log_dir=os.path.join(opt.global_path, opt.tb_dir)) 56 | 57 | # Model Training 58 | model.train(train_dataset=train_dataset, eval_dataset=test_dataset, writer=writer) 59 | -------------------------------------------------------------------------------- /mainCycle_AT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from torch.utils.tensorboard import SummaryWriter 5 | import torch.backends.cudnn as cudnn 6 | from Lib.Datasets.EUSAR.EUSARDataset import EUSARDataset 7 | from Lib.Nets.Cycle_AT.Cycle_AT import Cycle_AT 8 | from Lib.Nets.utils.config.config_routine import config_routine 9 | from Lib.Nets.utils.config.general_parser import general_parser 10 | from Lib.Nets.utils.config.specific_parser import specific_parser 11 | from Lib.Nets.utils.generic.generic_training import get_subset 12 | import argparse 13 | 14 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 15 | """ 16 | Author: Alessandro Cattoi 17 | Description: This main can be employed to train the Cycle Consistent Adversarial Transcoder. 18 | "CONFIG" section give quick access to some parameter setting. 
19 | """ 20 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 21 | 22 | """-------------------------------CONFIG----------------------------------""" 23 | parser = argparse.ArgumentParser(description="PyTorch Cycle GAN") 24 | parser = general_parser(parser) 25 | opt = specific_parser( 26 | parser=parser, log=False, run_folder='runs/cgan/', mode='train', tot_epochs=200, loss_type='lsgan', 27 | restoring_rep_path=None, start_from_epoch=None, 28 | D_training_ratio=1, res_block_N=6, pool_prc_O=0.5, pool_prc_S=0.5, 29 | buff_dim=1, th_low=0.45, th_high=0.55, pool=False, conditioned=False, dropping=False, 30 | batch_size=1, prc_train=1, prc_test=1, prc_val=None, sar_c=5, optical_c=4, batch_size_SN=16, 31 | th_b_h_ratio=200, th_b_l_ratio=2, th_b_h_pool=0.2, th_b_l_pool=0.8, drop_prc=0, 32 | data_dir_train='Data/Train/EUSAR/128_trans_corr', data_dir_train2='Data/Train/EUSAR/128_sn_corr', 33 | data_dir_test='Data/Test/EUSAR/128_trans_corr', data_dir_test2='Data/Test/EUSAR/128_sn_corr', 34 | acc_log_freq=50, loss_log_freq=10, save_model_freq=20, images_log_freq=1, 35 | experiment_name='std_conv_corr_2_disc', 36 | run_description='provo la cgfan standard che ha fatto 80 con due disc') 37 | opt = config_routine(opt) 38 | """-------------------------------LOAD DATA----------------------------------""" 39 | train_dataset = EUSARDataset(os.path.join(opt.data_dir_train), False, True, opt.sar_c, opt.optical_c) 40 | test_dataset = EUSARDataset(os.path.join(opt.data_dir_test), False, True, opt.sar_c, opt.optical_c) 41 | train_dataset = get_subset(train_dataset, opt.prc_train) 42 | test_dataset = get_subset(test_dataset, opt.prc_train) 43 | train_dataset = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, 44 | num_workers=opt.workers, pin_memory=True, drop_last=False) 45 | test_dataset = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, 46 | num_workers=opt.workers, 
pin_memory=True, drop_last=False) 47 | 48 | """---------------------------------TRAIN------------------------------------""" 49 | # Set cuda 50 | device = torch.device("cuda:0") 51 | cudnn.benchmark = True 52 | 53 | # Init model 54 | model = Cycle_AT(opt, device) 55 | 56 | # set up tensorboard logging 57 | writer = SummaryWriter(log_dir=os.path.join(opt.global_path, opt.tb_dir)) 58 | 59 | # Model Training 60 | model.train(train_dataset=train_dataset, eval_dataset=test_dataset, writer=writer) 61 | -------------------------------------------------------------------------------- /mainRT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from torch.utils.tensorboard import SummaryWriter 5 | import torch.backends.cudnn as cudnn 6 | from Lib.Datasets.EUSAR.EUSARDataset import EUSARDataset 7 | from Lib.Nets.RT.RT import RT 8 | from Lib.Nets.utils.config.config_routine import config_routine 9 | from Lib.Nets.utils.config.general_parser import general_parser 10 | from Lib.Nets.utils.config.specific_parser import specific_parser 11 | from Lib.Nets.utils.generic.generic_training import get_subset 12 | import argparse 13 | 14 | 15 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 16 | """ 17 | Author: Alessandro Cattoi 18 | Description: This main can be employed to train the Regressive Transcoder. 19 | "CONFIG" section give quick access to some parameter setting. 
20 | """ 21 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 22 | 23 | """-------------------------------CONFIG----------------------------------""" 24 | parser = argparse.ArgumentParser(description="PyTorch Regressor GAN") 25 | parser = general_parser(parser) 26 | opt = specific_parser( 27 | parser=parser, log=False, run_folder='runs/rgan/', mode='train', tot_epochs=200, 28 | restoring_rep_path=None, start_from_epoch=None, res_block_N=6, 29 | batch_size=1, prc_train=1, prc_test=1, prc_val=None, sar_c=5, optical_c=4, batch_size_SN=16, 30 | data_dir_train='Data/Train/EUSAR/128_trans_corr', data_dir_train2='Data/Train/EUSAR/128_sn_corr', 31 | data_dir_test='Data/Test/EUSAR/128_trans_corr', data_dir_test2='Data/Test/EUSAR/128_sn_corr', 32 | acc_log_freq=50, loss_log_freq=10, save_model_freq=1, images_log_freq=1, 33 | experiment_name='128_conv_std_dati_corr_long', 34 | run_description='Test di rgan con conv std e nuovi dati più lunga') 35 | opt = config_routine(opt) 36 | 37 | """-------------------------------LOAD DATA----------------------------------""" 38 | train_dataset = EUSARDataset(os.path.join(opt.data_dir_train), False, True, opt.sar_c, opt.optical_c) 39 | test_dataset = EUSARDataset(os.path.join(opt.data_dir_test), False, True, opt.sar_c, opt.optical_c) 40 | train_dataset = get_subset(train_dataset, opt.prc_train) 41 | test_dataset = get_subset(test_dataset, opt.prc_train) 42 | train_dataset = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, 43 | num_workers=opt.workers, pin_memory=True, drop_last=False) 44 | test_dataset = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, 45 | num_workers=opt.workers, pin_memory=True, drop_last=False) 46 | 47 | """---------------------------------TRAIN------------------------------------""" 48 | # Set cuda 49 | device = torch.device("cuda:0") 50 | cudnn.benchmark = True 51 | 52 | # Init model 53 | model = RT(opt, 
           device)

# Set up tensorboard logging
writer = SummaryWriter(log_dir=os.path.join(opt.global_path, opt.tb_dir))

# Model training
model.train(train_dataset=train_dataset, eval_dataset=test_dataset, writer=writer)
--------------------------------------------------------------------------------
/mainSN.py:
--------------------------------------------------------------------------------
import os
import torch
from torch.utils.tensorboard import SummaryWriter
import torch.backends.cudnn as cudnn
from Lib.Datasets.EUSAR.EUSARDataset import EUSARDataset
from Lib.Nets.SN.SN import SN
from Lib.Nets.utils.config.config_routine import config_routine
from Lib.Nets.utils.config.general_parser import general_parser
from Lib.Nets.utils.config.specific_parser import specific_parser
from Lib.Nets.utils.generic.generic_training import get_subset
from torch.utils.data import DataLoader
import argparse

"""-------------------------------CONFIG----------------------------------"""
parser = argparse.ArgumentParser(description="PyTorch Segmentation Network")  # was "PyTorch Regression GAN" (copy-paste leftover)
parser = general_parser(parser)
opt = specific_parser(
    parser=parser, run_folder='runs/sn/', mode='train', tot_epochs=30, res_block_N=6,
    restoring_rep_path=None, start_from_epoch=None,
    pretrained_GAN=None, GAN_epoch=None, seed=1,
    batch_size_SN=16, prc_train=1, prc_test=1, prc_val=None, sar_c=5, optical_c=4,
    data_dir_train='Data/Train/EUSAR/128_sn_corr', data_dir_test='Data/Test/EUSAR/128_sn_corr',
    acc_log_freq=29, loss_log_freq=1, save_model_freq=100, images_log_freq=None,
    experiment_name='prova',
    run_description='Classifying with the new accuracy metric and everything pt')

opt = config_routine(opt)
"""-------------------------------LOAD DATA----------------------------------"""
train_dataset = EUSARDataset(os.path.join(opt.data_dir_train), True, False, opt.sar_c, opt.optical_c,
                             True)
train_dataset = get_subset(train_dataset, opt.prc_train)
train_dataset = DataLoader(train_dataset, batch_size=opt.batch_size_SN, shuffle=True,
                           num_workers=opt.workers, pin_memory=True, drop_last=False)

test_dataset = EUSARDataset(os.path.join(opt.data_dir_test), True, False, opt.sar_c, opt.optical_c)
test_dataset = get_subset(test_dataset, opt.prc_test, True)
test_dataset = DataLoader(test_dataset, batch_size=opt.batch_size_SN, shuffle=False,
                          num_workers=opt.workers, pin_memory=True, drop_last=False)

"""--------------------------------TRAIN-----------------------------------"""
# Set cuda
device = torch.device("cuda:0")
cudnn.benchmark = True

# Init model
model = SN(opt, device)

# Set up tensorboard logging
writer = SummaryWriter(log_dir=os.path.join(opt.tb_dir))

# Model training
model.train(train_dataset, test_dataset, writer)
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# Self-Supervised Learning for Semantic Segmentation of Pol-SAR Images via SAR-to-Optical Transcoding Using PyTorch

* This framework was developed during my Master's thesis for the Master's degree in Information and Communication
  Engineering [@UniTN](https://www.unitn.it/ "UniTN website") -
  [Course Description](https://offertaformativa.unitn.it/en/lm/information-and-communications-engineering "Course website")
* If you want to discover every detail about the choices, motivations, and experiments, check out my
  [Final Dissertation](Docs/manuscript/manuscript.pdf "Final Dissertation").
* For a brief overview, you can instead check out my [Final Presentation](Docs/manuscript/presentation.pdf "Final Presentation").
* Below I explain how to reuse the code or repeat the experiments performed.
* Please, if you find this work helpful or interesting, drop a comment and let me know; if you have problems or
  curiosities, do not hesitate to contact me.

## Visual Results
### Transcoding Results
Transcoding comparison between three randomly sampled areas and the three types of transcoders implemented


### Classification Results (% states the amount of labelled data employed)
Classification results comparison using different pretrained models and different amounts of labelled data


## How to use this code
The repo is structured as follows:
```
.
├── Data
│   ├── Datasets
│   ├── Test ⮕ Here store the patches prepared according to Lib/Datasets/EUSAR/
│   ├── Train ⮕ For the train and test sets
├── Docker ⮕ Here the Docker file configuration is stored
├── Docs
│   ├── arch ⮕ Here are some architecture images
│   └── manuscript ⮕ Here is the final manuscript
├── Lib ⮕ Here are all the code resources
│   ├── Datasets
│   │   ├── EUSAR ⮕ PyTorch Dataset class overload
│   │   ├── processing ⮕ Dataset preprocessing
│   │   └── runner ⮕ Some runners to perform dataset processing
│   ├── Nets ⮕ Here is the implementation of each network deployed in this framework
│   │   ├── BL ⮕ Fully supervised framework used as benchmark
│   │   ├── CAT ⮕ Supervised Conditional Adversarial Transcoder
│   │   ├── Cycle_AT ⮕ Unsupervised Cycle-Consistent Adversarial Transcoder
│   │   ├── RT ⮕ Supervised Regressive Transcoder
│   │   ├── SN ⮕ Segmentation Network to perform semantic segmentation using the features
│   │          learned during the transcoding phase.
├── eval.py
├── mainBL.py ⮕ Main file to train the Baseline
├── mainCAT.py ⮕ Main file to train the Conditional Adversarial Transcoder
├── mainCycle_AT.py ⮕ Main file to train the Cycle-Consistent Adversarial Transcoder
├── mainRT.py ⮕ Main file to train the Regressive Transcoder
├── mainSN.py ⮕ Main file to train the Segmentation Network
├── readme.md
```

## Getting started
* Create the dataset. The file [EUSARDataset.py](Lib/Datasets/EUSAR/EUSARDataset.py "EUSAR Dataset Class") implements a PyTorch Dataset. To work
  with it, the data has to be stored in the Data folders as specified in that file.
* The Docker folder stores the file employed to build my Docker image, which is public:
  [myDocker](https://hub.docker.com/repository/docker/cattale/pytorch "Docker Hub"). It extends the PyTorch Docker
  image with some additional libraries and settings.
* Docs stores the final report of this work; for any doubt refer to it, as it covers almost everything.
* Lib contains all the libraries used to perform the networks' operations.
* Once you have prepared the repo, the dataset, and the Docker image, you can run the main files.

### Prepare the Dataset
The data employed in my work were coupled radar and optical images:
* Radar images were dual-polarized C-band Sentinel-1 products
* Optical images were RGB+NIR Sentinel-2 images
* The labelled set was composed of
  * Forests
  * Streets
  * Fields
  * Urban
  * Water

All the images employed have a 10x10 m resolution.
You can follow the instructions in [EUSARDataset.py](Lib/Datasets/EUSAR/EUSARDataset.py "EUSAR Dataset Class") and create a dataset compliant with my EUSARDataset class, or create your own.
In the former case you have 100% compatibility; in the latter you could encounter some problems.
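To avoid misplaced patches, the expected Data/ skeleton can be created programmatically. Below is a minimal sketch: the split and site folder names are taken from this repo's tree, while `make_data_skeleton` itself is a hypothetical convenience helper, not part of Lib.

```python
import os

# Folder names below mirror this repo's tree: Data/{Train,Test,Val}/{EUSAR,BERLIN}
SPLITS = ("Train", "Test", "Val")
SITES = ("EUSAR", "BERLIN")


def make_data_skeleton(root="Data"):
    """Create the empty folder skeleton in which the prepared patches are stored."""
    created = []
    for split in SPLITS:
        for site in SITES:
            path = os.path.join(root, split, site)
            os.makedirs(path, exist_ok=True)  # idempotent: safe to re-run
            created.append(path)
    return created
```

The prepared patches (e.g. the folders referenced by `data_dir_train`/`data_dir_test` in the mains) then go under the matching split/site folder, such as `Data/Train/EUSAR/`.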
### Prerequisites
All training of the implemented networks has been performed on an Nvidia GeForce RTX 2070 SUPER with 8 GB
of dedicated memory. The code requires at least 8 GB of free GPU memory and 8 GB of free RAM. Approximate
running times can be found in the report.

### Prepare the Machine
To prepare the Docker container you can run this command:
```shell
docker create --name [yourPreferredName] --gpus all --shm-size 8G -it -v [folder/Path/To/Project]:[Folder/In/The/Docker] cattale/pytorch:latest
```
The parts between square brackets are parameters you can change:
* [yourPreferredName] choose a name for your container (here you should clone the project)
* [folder/Path/To/Project] the folder in which you store your project
* [Folder/In/The/Docker] folder in the Docker container where you will run your code

### Configure the Test
Now you need to configure the scripts to run the test you want to perform.
The parameters are configured as follows when a script is launched:
* [general_parser.py](Lib/Nets/utils/config/general_parser.py "General Parser") defines all the configurable parameters, so refer to
  it for a detailed list.
* After the arguments passed to the script have been parsed, it is possible to modify them in a mask in the mains using
  [specific_parser.py](Lib/Nets/utils/config/specific_parser.py "Specific Parser"). This script overwrites only the arguments that are
  specified; it is useful when a lot of parameters change between tests.
* Lastly, [config_routine.py](Lib/Nets/utils/config/config_routine.py "Configure Routine") is run. This script configures the
  environment based on the defined parameters.

### Run the Code
To run the scripts follow these instructions.
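Before running anything, it may help to see how the three configuration stages described above compose. The sketch below is purely illustrative: the parameter names are examples, not the real signatures, and the actual list of options lives in `general_parser.py`.

```python
import argparse

# Illustrative sketch of the three configuration stages; all names suffixed
# with `_sketch` and the parameters below are examples, not the real API.

def general_parser_sketch(parser):
    """Stage 1: declare every configurable parameter with a default."""
    parser.add_argument("--mode", default="train")
    parser.add_argument("--tot_epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=1)
    return parser


def specific_parser_sketch(parser, **overrides):
    """Stage 2: parse the CLI arguments, then overwrite the ones a main pins down."""
    opt = parser.parse_args([])  # [] so the sketch runs without CLI arguments
    for name, value in overrides.items():
        if value is not None:
            setattr(opt, name, value)
    return opt


def config_routine_sketch(opt):
    """Stage 3: derive the run environment from the final parameter set."""
    opt.run_dir = "runs/{}_{}ep".format(opt.mode, opt.tot_epochs)
    return opt


opt = config_routine_sketch(
    specific_parser_sketch(general_parser_sketch(argparse.ArgumentParser()), tot_epochs=200))
```

This mirrors the call order used in every main: `general_parser` declares, `specific_parser` overrides, `config_routine` finalizes.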
* Start your docker container using the command ```docker container start [container_name]```
* Then enter your container using the command ```docker container attach [container_name]```
* Now navigate in the container to the project folder and run one of the provided mains:
  * [mainBL.py](mainBL.py "Train the Baseline")
  * [mainRT.py](mainRT.py "Train the Regressive Transcoder")
  * [mainCAT.py](mainCAT.py "Train the Conditional Adversarial Transcoder")
  * [mainCycle_AT.py](mainCycle_AT.py "Train the Cycle-Consistent Adversarial Transcoder")
  * [mainSN.py](mainSN.py "Train the Segmentation Network")
* To run one of the scripts above, use the command ```python main*.py```

## Acknowledgments
Last but not least, this implementation is based on the work of [Zhu et al. 2017a](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix "pytorch-CycleGAN-and-pix2pix").
--------------------------------------------------------------------------------
/runs/agan/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/runs/agan/.gitkeep
--------------------------------------------------------------------------------
/runs/bl/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/runs/bl/.gitkeep
--------------------------------------------------------------------------------
/runs/cgan/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/runs/cgan/.gitkeep
--------------------------------------------------------------------------------
/runs/rgan/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/runs/rgan/.gitkeep
--------------------------------------------------------------------------------
/runs/runs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/runs/runs/.gitkeep
--------------------------------------------------------------------------------
/runs/sn/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cattale93/pytorch_self_supervised_learning/162c26446837a7c0ffd2679b3fb54ba01f204123/runs/sn/.gitkeep
--------------------------------------------------------------------------------