├── .gitignore ├── README.md ├── conditioned_main.py ├── datasets.py ├── evaluations.py ├── learn_transform.py ├── loss.py ├── main.py ├── models.py ├── nima ├── __init__.py ├── cli.py ├── common.py ├── inference │ ├── __init__.py │ ├── app.py │ ├── inference_model.py │ └── utils.py ├── mobile_net_v2.py ├── model.py └── train │ ├── __init__.py │ ├── clean_dataset.py │ ├── datasets.py │ ├── emd_loss.py │ ├── main.py │ └── utils.py ├── requirements.txt ├── scrape_fivek.py ├── semseg ├── LICENSE ├── __init__.py ├── lib │ ├── nn │ │ ├── __init__.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── batchnorm.py │ │ │ ├── comm.py │ │ │ ├── replicate.py │ │ │ ├── tests │ │ │ │ ├── test_numeric_batchnorm.py │ │ │ │ └── test_sync_batchnorm.py │ │ │ └── unittest.py │ │ └── parallel │ │ │ ├── __init__.py │ │ │ └── data_parallel.py │ └── utils │ │ ├── __init__.py │ │ ├── data │ │ ├── __init__.py │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── distributed.py │ │ └── sampler.py │ │ └── th.py └── models │ ├── __init__.py │ ├── mobilenet.py │ ├── models.py │ ├── resnet.py │ └── resnext.py ├── ssim.py ├── torch_utils.py ├── transforms.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints/ 3 | *.pth 4 | *.pt 5 | log/ 6 | log_histeq/ 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch Neural Enhance 2 | ![ ](https://proceduralia.github.io/assets/can.png) 3 | 4 | This repository contains the code for the experiments presented in the technical report [An empirical evaluation of convolutional neural networks for image 5 | enhancement](https://proceduralia.github.io/assets/IACV_Project.pdf). 6 | 7 | The repository contains the code for evaluating CAN32 and UNet models, together with a number of different loss functions, on the task of supervised image enhancement, imitating experts from the MIT-Adobe FiveK dataset. 8 | 9 | The models can also be conditioned, by *conditional batch normalization*, on the categorical features contained in the dataset. 10 | 11 | ![ ](https://proceduralia.github.io/assets/flower.png) 12 | 13 | In addition, the script *learn_transform.py* performs training for learning Contrast Limited Adaptive Histogram Equalization (CLAHE) on the CIFAR10 dataset, using different architectures. 14 | 15 | --- 16 | 17 | Written in collaboration with [ennnas](https://github.com/ennnas) for the Computer Vision MSc course at Politecnico di Milano. 18 | 19 | ## Installation 20 | 21 | ``` 22 | git clone https://github.com/proceduralia/neural_enhance 23 | cd neural_enhance 24 | conda create --name myenv --file requirements.txt 25 | source activate myenv 26 | ``` 27 | 28 | To download the MIT-Adobe FiveK dataset run: 29 | 30 | ``` 31 | python scrape_fivek.py --base_dir path/to/data 32 | ``` 33 | 34 | ## Training 35 | To train a model without using categorical features as additional input run: 36 | 37 | ``` 38 | python main.py --model_type unet --loss l1nima --data_path path/to/data 39 | ``` 40 | 41 | To train a model using categorical features as additional input run: 42 | 43 | ``` 44 | python conditioned_main.py --model_type unet --loss l1nima --data_path path/to/data 45 | ``` 46 | 47 | ## Evaluation 48 | To evaluate a model (without conditions) run: 49 | ``` 50 | python evaluations.py --model_type unet --image_path path/to/image --final_dir path/to/model_folder 51 | ``` 52 | 53 | ## Citation 54 | If you found this repository or the report useful for your research work, you can cite them: 55 | ``` 56 | @misc{pytorchenhance, 57 | author = {Nasca, Ennio and D'Oro, Pierluca}, 58 | title = {An empirical evaluation of convolutional neural networks for image enhancement}, 59 | year = {2019}, 60 | publisher = {GitHub}, 61 | journal = {GitHub repository}, 62 | howpublished = {\url{https://github.com/proceduralia/pytorch-neural-enhance}}, 63 | } 64 | ``` 65 | -------------------------------------------------------------------------------- /conditioned_main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | from torch.utils.data import Dataset, DataLoader, random_split 5 | from torchvision import transforms 6 | from torchvision.utils import make_grid 7 | from tensorboardX import SummaryWriter 8 | import argparse 9 | import datetime 10 | import os 11 | import random 12 | from datasets import FivekDataset 13 | from models import ConditionalCAN, ConditionalUNet 14 | from torch_utils import JoinedDataLoader 15 | from loss import ColorSSIM, NimaLoss 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--batch_size', type=int, default=8, help='input batch size') 19 | parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train for') 20 | parser.add_argument('--lr', type=float, default=2e-4, help='learning rate') 21 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 22 | parser.add_argument('--cuda_idx', type=int, default=1, help='cuda device id') 23 | parser.add_argument('--manual_seed', type=int, help='manual seed') 24 | parser.add_argument('--logdir', default='log', help='logdir for tensorboard') 25 | parser.add_argument('--run_tag', default='', help='tags for the current run') 26 | parser.add_argument('--checkpoint_every', default=10, help='number of epochs after which saving checkpoints') 27 | parser.add_argument('--checkpoint_dir', default="checkpoints", help='directory for the checkpoints') 28 | parser.add_argument('--final_dir', default="final_models", help='directory for the final models') 29 | parser.add_argument('--model_type', default='can32', choices=['can32','unet'], help='type of model to use') 30 | parser.add_argument('--loss', default='mse', choices=['mse','mae','l1nima','l2nima','l1ssim','colorssim'], help='loss to be used') 31 | parser.add_argument('--gamma', default=0.001, type=float, help='gamma to be used only in case of Nima Loss') 32 | parser.add_argument('--data_path', default='/home/iacv3_1/fivek', help='path of the base directory of the dataset') 33 | opt = parser.parse_args() 34 | 35 | #Create writer for tensorboard 36 | date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M") 37 | run_name = "{}_{}".format(opt.run_tag,date) if opt.run_tag != '' else date 38 | log_dir_name = os.path.join(opt.logdir, run_name) 39 | writer = SummaryWriter(log_dir_name) 40 | writer.add_text('Options', str(opt), 0) 41 | print(opt) 42 | 43 | if opt.manual_seed is None: 44 | opt.manual_seed = random.randint(1, 10000) 45 | print("Random Seed: ", opt.manual_seed) 46 | random.seed(opt.manual_seed) 47 | torch.manual_seed(opt.manual_seed) 48 | 49 | os.makedirs(opt.checkpoint_dir, exist_ok=True) 50 | 51 | if torch.cuda.is_available() and not opt.cuda: 52 | print("You should run with CUDA.") 53 | device = torch.device("cuda:"+str(opt.cuda_idx) if opt.cuda else "cpu") 54 | 55 | landscape_transform = transforms.Compose([ 56 | transforms.Resize((332, 500)), 57 | transforms.ToTensor(), 58 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #normalize in [-1,1] 59 | ]) 60 | portrait_transform = transforms.Compose([ 61 | transforms.Resize((500, 332)), 62 | transforms.ToTensor(), 63 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #normalize in [-1,1] 64 | ]) 65 | landscape_dataset = FivekDataset(opt.data_path, expert_idx=2, transform=landscape_transform, filter_ratio="landscape", use_features=True) 66 | portrait_dataset = FivekDataset(opt.data_path, expert_idx=2, transform=portrait_transform, filter_ratio="portrait", use_features=True) 67 | 68 | 69 | train_size = int(0.8 * len(landscape_dataset)) 70 | test_size = len(landscape_dataset) - train_size 71 | train_landscape_dataset, test_landscape_dataset = random_split(landscape_dataset, [train_size, test_size]) 72 | 73 | train_size = int(0.8 * len(portrait_dataset)) 74 | test_size = len(portrait_dataset) - train_size 75 | train_portrait_dataset, test_portrait_dataset = random_split(portrait_dataset, [train_size, test_size]) 76 | 77 | train_landscape_loader = DataLoader(train_landscape_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=2) 78 | train_portrait_loader = DataLoader(train_portrait_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=2) 79 | train_loader = JoinedDataLoader(train_landscape_loader, train_portrait_loader) 80 | 81 | test_landscape_loader = DataLoader(test_landscape_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=2) 82 | test_portrait_loader = DataLoader(test_portrait_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=2) 83 | test_loader = JoinedDataLoader(test_landscape_loader, test_portrait_loader) 84 | 85 | 86 | if opt.model_type == 'can32': 87 | model = ConditionalCAN() 88 | if opt.model_type == 'unet': 89 | model = ConditionalUNet() 90 | assert model 91 | model = model.to(device) 92 | 93 | if opt.loss == "mse": 94 | criterion = nn.MSELoss() 95 | if opt.loss == "mae": 96 | criterion = nn.L1Loss() 97 | if opt.loss == "l1nima": 98 | criterion = NimaLoss(device,opt.gamma,nn.L1Loss()) 99 | if opt.loss == "l2nima": 100 | criterion = NimaLoss(device,opt.gamma,nn.MSELoss()) 101 | if opt.loss == "l1ssim": 102 | criterion = ColorSSIM(device,'l1') 103 | if opt.loss == "colorssim": 104 | criterion = ColorSSIM(device) 105 | assert criterion 106 | criterion = criterion.to(device) 107 | 108 | optimizer = optim.Adam(model.parameters(), lr=opt.lr) 109 | 110 | #Select random idxs for displaying 111 | test_idxs = random.sample(range(len(test_landscape_dataset)), 3) 112 | for epoch in range(opt.epochs): 113 | model.train() 114 | cumulative_loss = 0.0 115 | for i, (im_o, im_t, feats) in enumerate(train_loader): 116 | im_o, im_t = im_o.to(device), im_t.to(device) 117 | feats = [f.to(device) for f in feats] 118 | optimizer.zero_grad() 119 | 120 | output = model(im_o, feats) 121 | loss = criterion(output, im_t) 122 | loss.backward() 123 | optimizer.step() 124 | cumulative_loss += loss.item() 125 | print('[Epoch %d, Batch %2d] loss: %.3f' % 126 | (epoch + 1, i + 1, cumulative_loss / (i+1)), end="\r") 127 | #Evaluate 128 | writer.add_scalar('Train Error', cumulative_loss / len(train_loader), epoch) 129 | #Checkpointing 130 | if (epoch+1) % opt.checkpoint_every == 0: 131 | torch.save(model.state_dict(), os.path.join(opt.checkpoint_dir, "{}_epoch{}.pt".format(opt.run_tag, epoch+1))) 132 | 133 | #Model evaluation 134 | model.eval() 135 | test_loss = [] 136 | for i, (im_o, im_t, feats) in enumerate(test_loader): 137 | im_o, im_t = im_o.to(device), im_t.to(device) 138 | feats = [f.to(device) for f in feats] 139 | with torch.no_grad(): 140 | output = model(im_o, feats) 141 | test_loss.append(criterion(output, im_t).item()) 142 | avg_loss = sum(test_loss)/len(test_loss) 143 | writer.add_scalar('Test Error', avg_loss, epoch) 144 | 145 | for idx in test_idxs: 146 | original, actual, feats = test_landscape_dataset[idx] 147 | original, actual = original.unsqueeze(0).to(device), actual.unsqueeze(0).to(device) 148 | feats = [f.unsqueeze(0).to(device) for f in feats] 149 | estimated = model(original, feats) 150 | images = torch.cat((original, estimated, actual)) 151 | grid = make_grid(images, nrow=1, normalize=True, range=(-1,1)) 152 | writer.add_image('{}:Original|Estimated|Actual'.format(idx), grid, epoch) 153 | 154 | print("Training Finished") 155 | torch.save(model.state_dict(), os.path.join(opt.final_dir, "{}_final.pt".format(opt.run_tag))) 156 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | from torch.utils.data import Dataset 3 | import torch 4 | import numpy as np 5 | from sklearn.preprocessing import LabelEncoder 6 | import os 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | from skimage.color import rgb2gray 10 | from skimage.transform import resize 11 | from skimage.exposure import equalize_hist, equalize_adapthist 12 | from PIL import Image 13 | import pandas as pd 14 | 15 | class TransformedCifarDataset(Dataset): 16 | """ 17 | Dataset with (x, transformed_x) couples, given CIFAR10 and a skimage-style transformation 18 | """ 19 | def __init__(self, transformation, root='./data', tag='adhist', train=True, normalize=True): 20 | """Args: 21 | transformation (callable): the skimage-style transformation to be applied 22 | root (str): the root of the original cifar data 23 | train (bool): true for training set, false for test set 24 | normalize (bool): if true, normalize the data in [-1,1] 25 | """ 26 | original_train_path = os.path.join(root, "original_train"+tag+'.pt') 27 | transformed_train_path = os.path.join(root, "transformed_train"+tag+'.pt') 28 | original_test_path = os.path.join(root, "original_test"+tag+'.pt') 29 | transformed_test_path = os.path.join(root, "transformed_test"+tag+'.pt') 30 | 31 | if train: 32 | if os.path.exists(original_train_path): 33 | self.original_data = torch.load(original_train_path) 34 | self.transformed_data = torch.load(transformed_train_path) 35 | else: 36 | data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True).train_data 37 | gray_data = np.array([rgb2gray(im) for im in data]) 38 | self.original_data = torch.FloatTensor(gray_data) 39 | self.transformed_data = torch.FloatTensor(np.array([transformation(im) for im in gray_data])) 40 | torch.save(self.original_data, original_train_path) 41 | torch.save(self.transformed_data, transformed_train_path) 42 | 43 | if not train: 44 | if os.path.exists(original_test_path): 45 | self.original_data = torch.load(original_test_path) 46 | self.transformed_data = torch.load(transformed_test_path) 47 | else: 48 | data = torchvision.datasets.CIFAR10(root='./data', train=False, download=True).test_data 49 | gray_data = np.array([rgb2gray(im) for im in data]) 50 | self.original_data = torch.FloatTensor(gray_data) 51 | self.transformed_data = torch.FloatTensor(np.array([transformation(im) for im in gray_data])) 52 | torch.save(self.original_data, original_test_path) 53 | torch.save(self.transformed_data, transformed_test_path) 54 | 55 | if normalize: 56 | normalization = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 57 | self.original_data = normalization(self.original_data) 58 | self.transformed_data = normalization(self.transformed_data) 59 | self.original_data = self.original_data.unsqueeze(1) 60 | self.transformed_data = self.transformed_data.unsqueeze(1) 61 | 62 | def __getitem__(self, idx): 63 | return self.original_data[idx], self.transformed_data[idx] 64 | 65 | def __len__(self): 66 | return len(self.original_data) 67 | 68 | class FivekDataset(Dataset): 69 | def __init__(self, base_path, expert_idx=2, transform=None, filter_ratio=None, use_features=False): 70 | """Fivek dataset class. 71 | Args: 72 | - base_path (str): base path with the directories 73 | - expert_idx (int): index of the ground truth expert 74 | - transform (torchvision transform): to be applied to both original and improved images 75 | - filter_ratio (str): "landscape" or "portrait" filter 76 | - use_features (bool): whether to use the (subject, light, location, time) features or not 77 | """ 78 | self.base_path = base_path 79 | self.expert_idx = expert_idx 80 | self.use_features = use_features 81 | if use_features: 82 | self.info_df = pd.read_csv(os.path.join(base_path, 'mitdatainfo.csv')) 83 | self.features = ["subject", "light", "location", "time"] 84 | self.encoders = {} 85 | for feature_name in self.features: 86 | self.encoders[feature_name] = LabelEncoder().fit(self.info_df[feature_name]) 87 | self.encoded_features = torch.LongTensor(np.vstack([self.encoders[feat].transform(self.info_df[feat]) for feat in self.features]).T) 88 | 89 | self.transform = transform 90 | if filter_ratio: 91 | assert filter_ratio in ["landscape", "portrait"] 92 | self.filter_ratio = filter_ratio 93 | self.original_path = os.path.join(base_path, 'original') 94 | self.expert_path = os.path.join(base_path, 'expert'+str(expert_idx)) 95 | 96 | self.len = len(os.listdir(self.original_path)) 97 | #TODO inefficient... Just save this data in the csv 98 | original_shapes = [] 99 | for i in range(self.len): 100 | original_shapes.append(Image.open(os.path.join(self.original_path, "{}.png".format(i))).size) 101 | self.landscape_idxs = [i for i in range(len(original_shapes)) if original_shapes[i][0] > original_shapes[i][1]] 102 | self.portrait_idxs = [i for i in range(len(original_shapes)) if original_shapes[i][0] < original_shapes[i][1]] 103 | 104 | def __getitem__(self, idx): 105 | #Alter index if poltrait or landscape filter is selected 106 | idx = int(idx) 107 | if self.filter_ratio == "landscape": 108 | idx = self.landscape_idxs[idx] 109 | if self.filter_ratio == "portrait": 110 | idx = self.portrait_idxs[idx] 111 | original_im = Image.open(os.path.join(self.original_path, str(idx)+'.png')) 112 | expert_im = Image.open(os.path.join(self.expert_path, str(idx)+'.png')) 113 | if self.transform: 114 | original_im = self.transform(original_im) 115 | expert_im = self.transform(expert_im) 116 | if self.use_features: 117 | #Retrieve features from dataframe and transform them 118 | feats = self.encoded_features[idx] 119 | #Create tuple of tensors 120 | return original_im, expert_im, tuple([tens for tens in feats]) 121 | else: 122 | return original_im, expert_im 123 | 124 | def __len__(self): 125 | if self.filter_ratio == "landscape": 126 | return len(self.landscape_idxs) 127 | if self.filter_ratio == "portrait": 128 | return len(self.portrait_idxs) 129 | else: 130 | return self.len 131 | 132 | if __name__ == "__main__": 133 | dataset = FivekDataset(base_path="/home/iacv3_1/fivek", use_features=True) 134 | original_im, expert_im, feats = dataset[0] 135 | print(original_im.size, expert_im.size, feats) 136 | -------------------------------------------------------------------------------- /evaluations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | from torch.utils.data import Dataset, DataLoader, random_split 5 | from torchvision import transforms 6 | from torchvision.utils import make_grid, save_image 7 | from tensorboardX import SummaryWriter 8 | import argparse 9 | import datetime 10 | import os 11 | import random 12 | from datasets import FivekDataset 13 | from models import CAN, SandOCAN, UNet 14 | from torch_utils import JoinedDataLoader, load_model 15 | from PIL import Image 16 | 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 20 | parser.add_argument('--cuda_idx', type=int, default=1, help='cuda device id') 21 | parser.add_argument('--logdir', default='log', help='logdir for tensorboard') 22 | parser.add_argument('--image_path', default=None, help='path to the image to enhance') 23 | parser.add_argument('--run_tag', default='evaluation', help='tags for the current run') 24 | parser.add_argument('--final_dir', default="final_models", help='directory for the final_models') 25 | parser.add_argument('--model_type', default='can32', choices=['can32', 'sandocan32','unet'], help='type of model to use') 26 | parser.add_argument('--data_path', default='/home/iacv3_1/fivek', help='path of the base directory of the dataset') 27 | opt = parser.parse_args() 28 | 29 | #Create writer for tensorboard 30 | date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M") 31 | run_name = "{}_{}".format(opt.run_tag,date) if opt.run_tag != '' else date 32 | log_dir_name = os.path.join(opt.logdir, run_name) 33 | writer = SummaryWriter(log_dir_name) 34 | writer.add_text('Options', str(opt), 0) 35 | print(opt) 36 | 37 | if torch.cuda.is_available() and not opt.cuda: 38 | print("You should run with CUDA.") 39 | device = torch.device("cuda:"+str(opt.cuda_idx) if opt.cuda else "cpu") 40 | 41 | landscape_transform = transforms.Compose([ 42 | transforms.Resize((332, 500)), 43 | transforms.ToTensor(), 44 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #normalize in [-1,1] 45 | ]) 46 | portrait_transform = transforms.Compose([ 47 | transforms.Resize((500, 332)), 48 | transforms.ToTensor(), 49 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #normalize in [-1,1] 50 | ]) 51 | """ 52 | landscape_dataset = FivekDataset(opt.data_path, expert_idx=2, transform=landscape_transform, filter_ratio="landscape") 53 | portrait_dataset = FivekDataset(opt.data_path, expert_idx=2, transform=portrait_transform, filter_ratio="portrait") 54 | 55 | 56 | landscape_loader = DataLoader(train_landscape_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=2) 57 | portrait_loader = DataLoader(train_portrait_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=2) 58 | loader = JoinedDataLoader(train_landscape_loader, train_portrait_loader) 59 | """ 60 | 61 | im = Image.open(opt.image_path) 62 | if im.size()[2] > im.size()[1]: 63 | im = landscape_transform(im) 64 | else: 65 | im = portrait_transform(im) 66 | 67 | im = im.to(device) 68 | im = im.unsqueeze(0) 69 | 70 | if opt.model_type == 'can32': 71 | model = CAN(n_channels=32) 72 | if opt.model_type == 'sandocan32': 73 | model = SandOCAN() 74 | if opt.model_type == 'unet': 75 | model = UNet() 76 | assert model 77 | 78 | models_path = [f for f in os.listdir(opt.final_dir) if f.startswith(opt.model_type)] 79 | #print(models_path) 80 | images = im 81 | for model_name in models_path: 82 | print('Loading model' + model_name) 83 | model.load_state_dict(torch.load(os.path.join(opt.final_dir,model_name), map_location=lambda storage, loc: storage)) 84 | model.to(device) 85 | images = torch.cat((images,model(im))) 86 | names = [' '.join(m.split('_')[1:-1]) for m in models_path] 87 | filename = opt.model_type+'actual_'+'_'.join(names)+'.png' 88 | save_image(images,filename,normalize=True,range=(-1,1)) 89 | grid = make_grid(images, nrow=1, normalize=True, range=(-1,1)) 90 | writer.add_image(filename, grid) 91 | -------------------------------------------------------------------------------- /learn_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision.utils import make_grid 6 | from tensorboardX import SummaryWriter 7 | import numpy as np 8 | import argparse 9 | from functools import partial 10 | import datetime 11 | import os 12 | import random 13 | from datasets import TransformedCifarDataset 14 | from skimage.exposure import equalize_hist, equalize_adapthist 15 | from transforms import unsharp_mask 16 | from scipy.stats import wasserstein_distance 17 | from models import MLP, NaiveCNN, LittleUnet 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--batch_size', type=int, default=8, help='input batch size') 21 | parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train for') 22 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 23 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 24 | parser.add_argument('--cuda_idx', type=int, default=1, help='cuda device id') 25 | parser.add_argument('--outf', default='.', help='folder for model checkpoints') 26 | parser.add_argument('--manual_seed', type=int, help='manual seed') 27 | parser.add_argument('--logdir', default='log_histeq', help='logdir for tensorboard') 28 | parser.add_argument('--run_tag', default='', help='tags for the current run') 29 | parser.add_argument('--checkpoint_every', default=10, help='number of epochs after which saving checkpoints') 30 | parser.add_argument('--model_type', default='unet', choices=['unet', 'cnn', 'mlp'], help='type of model to use') 31 | parser.add_argument('--loss', default='mse', choices=['mse','mae'], help='type of loss to use') 32 | parser.add_argument('--initial_1by1', action="store_true", help='whether to use the initial 1 by 1 convs in unet') 33 | parser.add_argument('--transform', default='ad_hist_eq', choices=['hist_eq','ad_hist_eq','unsharp'], help='transformation to be learned') 34 | opt = parser.parse_args() 35 | 36 | #Create writer for tensorboard 37 | date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M") 38 | run_name = "{}_{}".format(opt.run_tag,date) if opt.run_tag != '' else date 39 | log_dir_name = os.path.join(opt.logdir, run_name) 40 | writer = SummaryWriter(log_dir_name) 41 | writer.add_text('Options', str(opt), 0) 42 | print(opt) 43 | 44 | if opt.manual_seed is None: 45 | opt.manual_seed = random.randint(1, 10000) 46 | print("Random Seed: ", opt.manual_seed) 47 | random.seed(opt.manual_seed) 48 | torch.manual_seed(opt.manual_seed) 49 | 50 | if torch.cuda.is_available() and not opt.cuda: 51 | print("You should run with CUDA.") 52 | device = torch.device("cuda:"+str(opt.cuda_idx) if opt.cuda else "cpu") 53 | 54 | transforms = { 55 | 'hist_eq': equalize_hist, 56 | 'ad_hist_eq': partial(equalize_adapthist, kernel_size=32//4), 57 | 'unsharp': partial(unsharp_mask, amount=1.0) 58 | } 59 | 60 | dataset = TransformedCifarDataset(transforms[opt.transform]) 61 | loader = DataLoader(dataset, batch_size=opt.batch_size, 62 | shuffle=True, num_workers=2) 63 | test_dataset = TransformedCifarDataset(transforms[opt.transform], train=False) 64 | test_loader = DataLoader(test_dataset, batch_size=opt.batch_size, 65 | shuffle=False, num_workers=2) 66 | 67 | if opt.model_type == 'unet': 68 | model = LittleUnet(initial_1by1=opt.initial_1by1) 69 | if opt.model_type == 'cnn': 70 | model = NaiveCNN() 71 | if opt.model_type == 'mlp': 72 | model = MLP() 73 | assert model 74 | 75 | model = model.to(device) 76 | if opt.loss == "mse": 77 | criterion = nn.MSELoss() 78 | if opt.loss == "mae": 79 | criterion = nn.L1Loss() 80 | 81 | criterion = criterion.to(device) 82 | optimizer = optim.Adam(model.parameters(), lr=opt.lr) 83 | 84 | for epoch in range(opt.epochs): 85 | model.train() 86 | cumulative_loss = 0.0 87 | for i, (im_o, im_t) in enumerate(loader): 88 | im_o, im_t = im_o.to(device), im_t.to(device) 89 | optimizer.zero_grad() 90 | 91 | output = model(im_o) 92 | loss = criterion(output, im_t) 93 | loss.backward() 94 | optimizer.step() 95 | 96 | cumulative_loss += loss.item() 97 | print('[Epoch %d, Batch %2d] loss: %.3f' % 98 | (epoch + 1, i + 1, cumulative_loss / (i+1)), end="\r") 99 | #Evaluate 100 | writer.add_scalar('MSE Train', cumulative_loss / len(loader), epoch) 101 | model.eval() 102 | 103 | test_loss = [] 104 | wass_dist = [] 105 | for i, (im_o, im_t) in enumerate(test_loader): 106 | im_o, im_t = im_o.to(device), im_t.to(device) 107 | with torch.no_grad(): 108 | output = model(im_o) 109 | test_loss.append(criterion(output, im_t).item()) 110 | actual_hists = np.array([np.histogram(im, bins=255, density=True)[0] for im in im_t.cpu().numpy()]) 111 | pred_hists = np.array([np.histogram(pred, bins=255, density=True)[0] for pred in output.cpu().numpy()]) 112 | wass_dist.append(np.mean([wasserstein_distance(i, j) for i,j in zip(actual_hists, pred_hists)])) 113 | writer.add_scalar('MSE Test', sum(test_loss)/len(test_loss), epoch) 114 | writer.add_scalar('Avg Wasserstein distance', sum(wass_dist)/len(wass_dist), epoch) 115 | 116 | #Make list of type [original1,estimated1,actual1,original2,estimated2,actual2] 117 | original, actual = test_dataset[:5] 118 | original, actual = original.to(device), actual.to(device) 119 | estimated = model(original) 120 | #Original, tran and estimated are (5, 1, 32, 32) 121 | images = [[o,e,a] for o,e,a in zip(original,estimated,actual)] 122 | images = torch.cat([i for k in images for i in k]).unsqueeze(1) 123 | #Make a grid, in each row, original|estimated|actual 124 | grid = make_grid(images, nrow=len(images)//5, normalize=True) 125 | writer.add_image('Original|Estimated|Actual', grid, epoch) 126 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from nima.inference.inference_model import InferenceModel 5 | from ssim import SSIM 6 | import math 7 | import numbers 8 | 9 | class ColorSSIM(nn.Module): 10 | def __init__(self,device,fidelity=None): 11 | super().__init__() 12 | self.smoothing = GaussianSmoothing(3,10,5) 13 | self.smoothing = self.smoothing.to(device) 14 | self.ssim = SSIM() 15 | self.w1 = 1 16 | if fidelity=='l1': 17 | self.fidelity = nn.L1Loss().to(device) 18 | else: 19 | self.w1 = 0.00001 20 | self.fidelity = self.color_loss 21 | 22 | 23 | def forward(self,original_img,target_img): 24 | return self.w1*self.fidelity(original_img,target_img) + (1-self.ssim(original_img,target_img)) 25 | 26 | def color_loss(self,original_img,target_img): 27 | batch_size = original_img.size()[0] 28 | original_blur = self.smoothing(original_img) 29 | target_blur = self.smoothing(target_img) 30 | color_loss = torch.sum(torch.pow(target_blur - original_blur,2))/(2*batch_size) 31 | return color_loss 32 | 33 | class NimaLoss(nn.Module): 34 | def __init__(self,device,gamma,fidelity): 35 | super().__init__() 36 | self.model = InferenceModel(device) 37 | self.fidelity = fidelity 38 | self.fidelity = self.fidelity.to(device) 39 | self.gamma = gamma 40 | 41 | def forward(self,x,y): 42 | score = self.model.predict(x) 43 | return self.fidelity(x,y) + self.gamma*(10 - score) 44 | 45 | 46 | class GaussianSmoothing(nn.Module): 47 | """ 48 | Apply gaussian smoothing on a 49 | 1d, 2d or 3d tensor. Filtering is performed seperately for each channel 50 | in the input using a depthwise convolution. 51 | Arguments: 52 | channels (int, sequence): Number of channels of the input tensors. Output will 53 | have this number of channels as well. 54 | kernel_size (int, sequence): Size of the gaussian kernel. 55 | sigma (float, sequence): Standard deviation of the gaussian kernel. 56 | dim (int, optional): The number of dimensions of the data. 57 | Default value is 2 (spatial). 58 | Example: 59 | smoothing = GaussianSmoothing(3, 5, 1) 60 | input = torch.rand(1, 3, 100, 100) 61 | input = F.pad(input, (2, 2, 2, 2), mode='reflect') 62 | output = smoothing(input) 63 | """ 64 | def __init__(self, channels, kernel_size, sigma, dim=2): 65 | super(GaussianSmoothing, self).__init__() 66 | if isinstance(kernel_size, numbers.Number): 67 | kernel_size = [kernel_size] * dim 68 | if isinstance(sigma, numbers.Number): 69 | sigma = [sigma] * dim 70 | 71 | # The gaussian kernel is the product of the 72 | # gaussian function of each dimension. 73 | kernel = 1 74 | meshgrids = torch.meshgrid( 75 | [ 76 | torch.arange(size, dtype=torch.float32) 77 | for size in kernel_size 78 | ] 79 | ) 80 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 81 | mean = (size - 1) / 2 82 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ 83 | torch.exp(-((mgrid - mean) / (2 * std)) ** 2) 84 | 85 | # Make sure sum of values in gaussian kernel equals 1. 86 | kernel = kernel / torch.sum(kernel) 87 | 88 | # Reshape to depthwise convolutional weight 89 | kernel = kernel.view(1, 1, *kernel.size()) 90 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 91 | 92 | self.register_buffer('weight', kernel) 93 | self.groups = channels 94 | 95 | if dim == 1: 96 | self.conv = F.conv1d 97 | elif dim == 2: 98 | self.conv = F.conv2d 99 | elif dim == 3: 100 | self.conv = F.conv3d 101 | else: 102 | raise RuntimeError( 103 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) 104 | ) 105 | 106 | def forward(self, x): 107 | """ 108 | Apply gaussian filter to input. 109 | Arguments: 110 | input (torch.Tensor): Input to apply gaussian filter on. 111 | Returns: 112 | filtered (torch.Tensor): Filtered output. 113 | """ 114 | x = F.pad(x, (2, 2, 2, 2), mode='reflect') 115 | return self.conv(x, weight=self.weight, groups=self.groups) 116 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | from torch.utils.data import Dataset, DataLoader, random_split 5 | from torchvision import transforms 6 | from torchvision.utils import make_grid 7 | from tensorboardX import SummaryWriter 8 | import argparse 9 | import datetime 10 | import os 11 | import random 12 | from datasets import FivekDataset 13 | from models import CAN, SandOCAN, UNet 14 | from torch_utils import JoinedDataLoader, load_model 15 | from loss import ColorSSIM, NimaLoss 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--batch_size', type=int, default=8, help='input batch size') 19 | parser.add_argument('--epochs', type=int, default=52, help='number of epochs to train for') 20 | parser.add_argument('--lr', type=float, default=2e-4, help='learning rate') 21 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 22 | parser.add_argument('--cuda_idx', type=int, default=1, help='cuda device id') 23 | parser.add_argument('--manual_seed', type=int, help='manual seed') 24 | parser.add_argument('--logdir', default='log', help='logdir for tensorboard') 25 | parser.add_argument('--run_tag', default='', help='tags for the current run') 26 | parser.add_argument('--checkpoint_every', default=10, help='number of epochs after which saving checkpoints') 27 | parser.add_argument('--checkpoint_dir', default="checkpoints", help='directory for the checkpoints') 28 | parser.add_argument('--final_dir', default="final_models", help='directory for the final_models') 29 | parser.add_argument('--model_type', default='can32', choices=['can32', 'sandocan32','unet'], help='type of model to use') 30 | parser.add_argument('--load_model', action='store_true', help='enables load from latest checkpoint') 31 | parser.add_argument('--loss', default='mse', choices=['mse','mae','l1nima','l2nima','l1ssim','colorssim'], help='loss to be used') 32 | parser.add_argument('--gamma', default=0.001, type=float, help='gamma to be used only in case of Nima Loss') 33 | parser.add_argument('--data_path', default='/home/iacv3_1/fivek', help='path of the base directory of the dataset') 34 | opt = parser.parse_args() 35 | 36 | #Create writer for tensorboard 37 | date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M") 38 | run_name = "{}_{}".format(opt.run_tag,date) if opt.run_tag != '' else date 39 | log_dir_name = os.path.join(opt.logdir, run_name) 40 | writer = SummaryWriter(log_dir_name) 41 | writer.add_text('Options', str(opt), 0) 42 | print(opt) 43 | 44 | if opt.manual_seed is None: 45 | opt.manual_seed = random.randint(1, 10000) 46 | print("Random Seed: ", opt.manual_seed) 47 | random.seed(opt.manual_seed) 48 | torch.manual_seed(opt.manual_seed) 49 | start_epoch = 0 50 | 51 | os.makedirs(opt.checkpoint_dir, exist_ok=True) 52 | 53 | if torch.cuda.is_available() and not opt.cuda: 54 | print("You should run with CUDA.") 55 | device = torch.device("cuda:"+str(opt.cuda_idx) if opt.cuda else "cpu") 56 | 57 | landscape_transform = transforms.Compose([ 58 | transforms.Resize((332, 500)), 59 | transforms.ToTensor(), 60 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #normalize in [-1,1] 61 | ]) 62 | portrait_transform = transforms.Compose([ 63 | transforms.Resize((500, 332)), 64 | transforms.ToTensor(), 65 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #normalize in [-1,1] 66 | ]) 67 | landscape_dataset = FivekDataset(opt.data_path, expert_idx=2, transform=landscape_transform, filter_ratio="landscape") 68 | portrait_dataset = FivekDataset(opt.data_path, expert_idx=2, transform=portrait_transform, filter_ratio="portrait") 69 | 70 | 71 | train_size = int(0.8 * len(landscape_dataset)) 72 | test_size = len(landscape_dataset) - train_size 73 | train_landscape_dataset, test_landscape_dataset = random_split(landscape_dataset, [train_size, test_size]) 74 | 75 | train_size = int(0.8 * len(portrait_dataset)) 76 | test_size = len(portrait_dataset) - train_size 77 | train_portrait_dataset, test_portrait_dataset = random_split(portrait_dataset, [train_size, test_size]) 78 | 79 | train_landscape_loader = DataLoader(train_landscape_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=2) 80 | train_portrait_loader = DataLoader(train_portrait_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=2) 81 | train_loader = JoinedDataLoader(train_landscape_loader, train_portrait_loader) 82 | 83 | test_landscape_loader = DataLoader(test_landscape_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=2) 84 | test_portrait_loader = DataLoader(test_portrait_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=2) 85 | test_loader = JoinedDataLoader(test_landscape_loader, test_portrait_loader) 86 | 87 | 88 | if opt.model_type == 'can32': 89 | model = CAN(n_channels=32) 90 | if opt.model_type == 'sandocan32': 91 | model = SandOCAN() 92 | if opt.model_type == 'unet': 93 | model = UNet() 94 | assert model 95 | 96 | if opt.load_model: 97 | model, start_epoch = load_model(model, opt.checkpoint_dir, opt.run_tag) 98 | model = model.to(device) 99 | 100 | if opt.loss == "mse": 101 | criterion = nn.MSELoss() 102 | if opt.loss == "mae": 103 | criterion = nn.L1Loss() 104 | if opt.loss == "l1nima": 105 | criterion = NimaLoss(device,opt.gamma,nn.L1Loss()) 106 | if opt.loss == "l2nima": 107 | criterion = NimaLoss(device,opt.gamma,nn.MSELoss()) 108 | if opt.loss == "l1ssim": 109 | criterion = ColorSSIM(device,'l1') 110 | if opt.loss == "colorssim": 111 | criterion = ColorSSIM(device) 112 | assert criterion 113 | criterion = criterion.to(device) 114 | 115 | optimizer = optim.Adam(model.parameters(), lr=opt.lr) 116 | 117 | #Select random idxs for displaying 118 | test_idxs = random.sample(range(len(test_landscape_dataset)), 3) 119 | for epoch in range(start_epoch, opt.epochs): 120 | model.train() 121 | cumulative_loss = 0.0 122 | for i, (im_o, im_t) in enumerate(train_loader): 123 | im_o, im_t = im_o.to(device), im_t.to(device) 124 | optimizer.zero_grad() 125 | 126 | output = model(im_o) 127 | loss = criterion(output, im_t) 128 | loss.backward() 129 | optimizer.step() 130 | cumulative_loss += loss.item() 131 | print('[Epoch %d, Batch %2d] loss: %.3f' % 132 | (epoch + 1, i + 1, cumulative_loss / (i+1)), end="\r") 133 | #Evaluate 134 | writer.add_scalar('Train Error', cumulative_loss / len(train_loader), epoch) 135 | #Checkpointing 136 | if (epoch+1) % opt.checkpoint_every == 0: 137 | torch.save(model.state_dict(), os.path.join(opt.checkpoint_dir, "{}_epoch{}.pt".format(opt.run_tag, epoch+1))) 138 | 139 | #Model evaluation 140 | model.eval() 141 | test_loss = [] 142 | for i, (im_o, im_t) in enumerate(test_loader): 143 | im_o, im_t = im_o.to(device), im_t.to(device) 144 | with torch.no_grad(): 145 | output = model(im_o) 146 | test_loss.append(criterion(output, im_t).item()) 147 | avg_loss = sum(test_loss)/len(test_loss) 148 | writer.add_scalar('Test Error', avg_loss, epoch) 149 | 150 | for idx in test_idxs: 151 | original, actual = test_landscape_dataset[idx] 152 | original, actual = original.unsqueeze(0).to(device), actual.unsqueeze(0).to(device) 153 | estimated = model(original) 154 | images = torch.cat((original, estimated, actual)) 155 | grid = make_grid(images, nrow=1, normalize=True, range=(-1,1)) 156 | writer.add_image('{}:Original|Estimated|Actual'.format(idx), grid, epoch) 157 | 158 | print("Training completed succesfully") 159 | 160 | torch.save(model.state_dict(), os.path.join(opt.final_dir, "{}_final.pt".format(opt.run_tag))) 161 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_utils import conv_out_shape, same_padding 5 | from torchvision.models import vgg16_bn 6 | from semseg.models.models import SemSegNet 7 | 8 | class ResidualBlock(nn.Module): 9 | def __init__(self, in_channels, out_channels, stride=1): 10 | super(ResidualBlock, self).__init__() 11 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 12 | self.bn1 = nn.BatchNorm2d(out_channels) 13 | self.relu = nn.ReLU(inplace=True) 14 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(out_channels) 16 | 17 | def forward(self, x): 18 | residual = x 19 | out = self.conv1(x) 20 | out = self.bn1(out) 21 | out = self.relu(out) 22 | out = self.conv2(out) 23 | out = self.bn2(out) 24 | out += residual 25 | out = self.relu(out) 26 | return out 27 | 28 | class ContinuousConditionalBatchNorm2d(nn.Module): 29 | def __init__(self, n_channels, input_dim): 30 | super().__init__() 31 | self.n_channels = n_channels 32 | #BatchNorm with affine=False is just normalization without params 33 | self.bn = nn.BatchNorm2d(n_channels, affine=False) 34 | #Map continuous condition to required size 35 | self.linear = nn.Linear(input_dim, n_channels*2) 36 | 37 | def forward(self, x, y): 38 | out = self.bn(x) 39 | gamma, beta = self.linear(y).chunk(2, 1) 40 | out = gamma.view(-1, self.n_channels, 1, 1) * out + beta.view(-1, self.n_channels, 1, 1) 41 | return out 42 | 43 | 44 | class ConditionalBatchNorm2d(nn.Module): 45 | def __init__(self, n_channels, num_classes): 46 | super().__init__() 47 | self.n_channels = n_channels 48 | #BatchNorm with affine=False is just normalization without params 49 | self.bn = nn.BatchNorm2d(n_channels, affine=False) 50 | self.embed = nn.Embedding(num_classes, n_channels * 2) 51 | #First half of embedding is for gamma (scale parameter) 52 | self.embed.weight.data[:, :n_channels].normal_(1, 0.02) 53 | #Second half of the embedding is for beta (bias parameter) 54 | self.embed.weight.data[:, n_channels:].zero_() 55 | 56 | def forward(self, x, y): 57 | out = self.bn(x) 58 | gamma, beta = self.embed(y).chunk(2, 1) 59 | out = gamma.view(-1, self.n_channels, 1, 1) * out + beta.view(-1, self.n_channels, 1, 1) 60 | return out 61 | 62 | class AdaptiveBatchNorm2d(nn.Module): 63 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 64 | super(AdaptiveBatchNorm2d,self).__init__() 65 | self.bn = nn.BatchNorm2d(num_features, eps, momentum, affine) 66 | self.a = nn.Parameter(torch.FloatTensor(1, 1, 1, 1)) 67 | self.b = nn.Parameter(torch.FloatTensor(1, 1, 1, 1)) 68 | 69 | def forward(self, x): 70 | return self.a * x + self.b * self.bn(x) 71 | 72 | class MultipleConditionalBatchNorm2d(nn.Module): 73 | """Conditional BatchNorm to be used in the case of multiple categorical input features. 74 | Args: 75 | - n_channels (int): number of feature maps of the convolutional layer to be conditioned 76 | - num_classes (iterable of ints): list of the number of classes for the different feature embeddings 77 | - adaptive (bool): use adaptive batchnorm instead of standadrd batchnorm 78 | """ 79 | def __init__(self, n_channels, nums_classes, adaptive=False): 80 | super().__init__() 81 | self.n_channels = n_channels 82 | self.nums_classes = nums_classes 83 | #use embedding dim such that the total size is n_channels 84 | embedding_dims = [n_channels//len(nums_classes) for i in range(len(nums_classes)-1)] 85 | embedding_dims.append((n_channels//len(nums_classes)) + (n_channels % len(nums_classes))) 86 | embedding_dims = [dim*2 for dim in embedding_dims] 87 | 88 | #BatchNorm with affine=False is just normalization without params 89 | self.bn = AdaptiveBatchNorm2d(n_channels, affine=False) if adaptive else nn.BatchNorm2d(n_channels, affine=False) 90 | #An embedding for each different categorical feature 91 | self.embeddings = nn.ModuleList([nn.Embedding(num_classes, dim) for num_classes, dim in zip(nums_classes, embedding_dims)]) 92 | 93 | #Initialize embeddings 94 | for emb in self.embeddings: 95 | #First half of embedding is for gamma (scale parameter) 96 | emb.weight.data[:, :n_channels].normal_(1, 0.02) 97 | #Second half of the embedding is for beta (bias parameter) 98 | emb.weight.data[:, n_channels:].zero_() 99 | 100 | def forward(self, x, class_idxs): 101 | out = self.bn(x) 102 | concatenated_embeddings = [emb(idx) for emb, idx in zip(self.embeddings, class_idxs)] 103 | concatenated_embeddings = torch.cat(concatenated_embeddings, dim=1) 104 | #gamma, beta = concatenated_embeddings.chunk(2, 1) 105 | gamma, beta = concatenated_embeddings[:,::2], concatenated_embeddings[:,1::2] 106 | #print(gamma, beta) 107 | out = gamma.view(-1, self.n_channels, 1, 1) * out + beta.view(-1, self.n_channels, 1, 1) 108 | return out 109 | 110 | 111 | class MLP(nn.Module): 112 | """ 113 | A one-hidden-layer MLP. 114 | Takes (batch_size, 1, imsize, imsize) tensors as input, and outputs tensor of same shape. 115 | Outputs in range [-1,1] 116 | """ 117 | def __init__(self, imsize=32, n_channels=1, hidden_dim=512, dropout_rate=0.1): 118 | super().__init__() 119 | self.imsize = imsize 120 | self.net = nn.Sequential( 121 | nn.Linear(imsize*imsize*n_channels, hidden_dim), 122 | nn.ReLU(), 123 | nn.Dropout(dropout_rate), 124 | nn.Linear(hidden_dim, imsize*imsize*n_channels), 125 | nn.Tanh() 126 | ) 127 | 128 | def forward(self, x): 129 | batch_size, n_channels = x.size(0), x.size(1) 130 | x = x.view(batch_size, -1) 131 | x = self.net(x) 132 | x = x.view(batch_size, n_channels, self.imsize, self.imsize) 133 | return x 134 | 135 | class CAN(nn.Module): 136 | """Context Aggregation Network based on Table 3 of 137 | "Fast Image Processing with Fully-Convolutional Nets". 138 | In the original paper: n_channels=32, n_middle_blocks=7 139 | """ 140 | def __init__(self, n_channels=32, n_middle_blocks=5, adaptive=False): 141 | super().__init__() 142 | self.bn = AdaptiveBatchNorm2d if adaptive else nn.BatchNorm2d 143 | self.first_block = nn.Sequential( 144 | nn.Conv2d(3, n_channels, kernel_size=3, padding=same_padding(3, 1)), 145 | self.bn(n_channels), 146 | nn.LeakyReLU(0.2), 147 | ) 148 | 149 | #Layers from 2 to 8 150 | blocks = [] 151 | for i in range(1, n_middle_blocks+1): 152 | d = 2**i 153 | blocks.append(nn.Sequential( 154 | nn.Conv2d(n_channels, n_channels, kernel_size=3, dilation=d, padding=same_padding(3, d)), 155 | self.bn(n_channels), 156 | nn.LeakyReLU(0.2) 157 | )) 158 | self.middle_blocks = nn.Sequential(*blocks) 159 | 160 | self.last_blocks = nn.Sequential( 161 | nn.Conv2d(n_channels, n_channels, kernel_size=3, padding=same_padding(3, 1)), 162 | self.bn(n_channels), 163 | nn.LeakyReLU(0.2), 164 | nn.Conv2d(n_channels, 3, kernel_size=1) 165 | ) 166 | 167 | def forward(self, x): 168 | x = self.first_block(x) 169 | x = self.middle_blocks(x) 170 | x = self.last_blocks(x) 171 | return x 172 | 173 | class ConditionalCAN(nn.Module): 174 | """Context Aggregation Network that can be conditioned by multiple categorical classes. 175 | Conditioning is done by conditional batch normalization based on 176 | 177 | Expected input: (image, (class_idx1,class_idx2,...)). 178 | """ 179 | def __init__(self, nums_classes=(6,3,3,4), n_channels=32, n_middle_blocks=5, adaptive=False): 180 | super().__init__() 181 | self.bn = AdaptiveBatchNorm2d if adaptive else nn.BatchNorm2d 182 | self.first_block = nn.Sequential( 183 | nn.Conv2d(3, n_channels, kernel_size=3, padding=same_padding(3, 1)), 184 | self.bn(n_channels), 185 | nn.LeakyReLU(0.2), 186 | ) 187 | 188 | #Layers from 2 to 8 189 | self.middle_convs = nn.ModuleList([ 190 | nn.Conv2d(n_channels, n_channels, kernel_size=3, dilation=2**i, padding=same_padding(3, 2**i)) 191 | for i in range(1, n_middle_blocks+1) 192 | ]) 193 | self.middle_cbns = nn.ModuleList([ 194 | MultipleConditionalBatchNorm2d(n_channels, nums_classes, adaptive=adaptive) 195 | for i in range(1, n_middle_blocks+1) 196 | ]) 197 | self.last_blocks = nn.Sequential( 198 | nn.Conv2d(n_channels, n_channels, kernel_size=3, padding=same_padding(3, 1)), 199 | self.bn(n_channels), 200 | nn.LeakyReLU(0.2), 201 | nn.Conv2d(n_channels, 3, kernel_size=1) 202 | ) 203 | 204 | def forward(self, x, class_idxs): 205 | x = self.first_block(x) 206 | for conv, cbn in zip(self.middle_convs, self.middle_cbns): 207 | x = conv(x) 208 | x = cbn(x, class_idxs) 209 | x = nn.functional.leaky_relu(x, 0.2) 210 | x = self.last_blocks(x) 211 | return x 212 | 213 | class SandOCAN(nn.Module): 214 | """Context Aggregation Network that can be conditioned by semantic segmentation information. 215 | Conditioning with conditional batch normalization 216 | 217 | Expected input: (image, semantic_map). 218 | 219 | Args: 220 | - n_channels (int): number of channels for the internal convolutions 221 | - n_classes (int): number of semantic classes 222 | - emb_dim (int): dimensionality of the semantic embedding 223 | - n_middle blocks (int): number of middle blocks at the center of the can network 224 | - adaptive (bool): whether to use adaptive batch normalization or not 225 | """ 226 | def __init__(self, n_channels=32, n_classes=150, emb_dim=64, n_middle_blocks=5, adaptive=False): 227 | super().__init__() 228 | self.sem_net = SemSegNet() 229 | #Set no gradients for the pretrained model 230 | for param in self.sem_net.parameters(): 231 | param.requires_grad=True 232 | self.bn = AdaptiveBatchNorm2d if adaptive else nn.BatchNorm2d 233 | self.first_block = nn.Sequential( 234 | nn.Conv2d(3, n_channels, kernel_size=3, padding=same_padding(3, 1)), 235 | self.bn(n_channels), 236 | nn.LeakyReLU(0.2), 237 | ) 238 | 239 | #Layers from 2 to 8 240 | self.middle_convs = nn.ModuleList([ 241 | nn.Conv2d(n_channels, n_channels, kernel_size=3, dilation=2**i, padding=same_padding(3, 2**i)) 242 | for i in range(1, n_middle_blocks+1) 243 | ]) 244 | self.middle_cbns = nn.ModuleList([ 245 | ContinuousConditionalBatchNorm2d(n_channels, emb_dim) 246 | for i in range(1, n_middle_blocks+1) 247 | ]) 248 | 249 | self.downsampling_net = nn.Sequential(*( 250 | [nn.Conv2d(n_classes, emb_dim, 1), nn.LeakyReLU(0.2)] + \ 251 | [ResidualBlock(emb_dim, emb_dim) for i in range(3)] + \ 252 | [nn.AdaptiveAvgPool2d((1, 1))] 253 | )) 254 | 255 | self.last_blocks = nn.Sequential( 256 | nn.Conv2d(n_channels, n_channels, kernel_size=3, padding=same_padding(3, 1)), 257 | self.bn(n_channels), 258 | nn.LeakyReLU(0.2), 259 | nn.Conv2d(n_channels, 3, kernel_size=1) 260 | ) 261 | 262 | def forward(self, x): 263 | maps = self.sem_net(x) 264 | x = self.first_block(x) 265 | #Semantic embedding is computed only once for all layers 266 | emb = self.downsampling_net(maps) 267 | emb = emb.view(-1, emb.size(1)) 268 | for conv, cbn in zip(self.middle_convs, self.middle_cbns): 269 | x = conv(x) 270 | x = cbn(x, emb) 271 | x = nn.functional.leaky_relu(x, 0.2) 272 | x = self.last_blocks(x) 273 | return x 274 | 275 | class UNet(nn.Module): 276 | """ 277 | Standard Unet 278 | """ 279 | def __init__(self): 280 | super().__init__() 281 | self.inc = unet_block(3,64,False) 282 | self.down1 = unet_block(64,128) 283 | self.down2 = unet_block(128,256) 284 | self.down3 = unet_block(256,512) 285 | self.down4 = unet_block(512,512) 286 | self.up1 = unet_up(1024,256) 287 | self.up2 = unet_up(512,128) 288 | self.up3 = unet_up(256,64) 289 | self.up4 = unet_up(128,64) 290 | self.outc = nn.Conv2d(64,3,1) 291 | 292 | def forward(self, x): 293 | x1 = self.inc(x) 294 | x2 = self.down1(x1) 295 | x3 = self.down2(x2) 296 | x4 = self.down3(x3) 297 | x5 = self.down4(x4) 298 | x = self.up1(x5, x4) 299 | x = self.up2(x, x3) 300 | x = self.up3(x, x2) 301 | x = self.up4(x, x1) 302 | x = self.outc(x) 303 | return x 304 | 305 | class unet_block(nn.Module): 306 | '''(conv => BN => ReLU) * 2''' 307 | def __init__(self,in_ch,out_ch,down=True): 308 | super(unet_block,self).__init__() 309 | self.down = down 310 | self.pool = nn.MaxPool2d(2) 311 | self.block = nn.Sequential( 312 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 313 | nn.BatchNorm2d(out_ch), 314 | nn.ReLU(inplace=True), 315 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 316 | nn.BatchNorm2d(out_ch), 317 | nn.ReLU(inplace=True) ) 318 | 319 | def forward(self,x): 320 | if self.down: 321 | x = self.pool(x) 322 | x = self.block(x) 323 | return x 324 | 325 | class unet_up(nn.Module): 326 | def __init__(self,in_ch,out_ch,bilinear=True): 327 | super(unet_up, self).__init__() 328 | if bilinear: 329 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 330 | else: 331 | self.up = nn.Conv2dTranspose(in_ch//2,out_ch//2,2,stride=2) 332 | 333 | self.conv = unet_block(in_ch,out_ch,False) 334 | 335 | def forward(self,x1,x2): 336 | x1 = self.up(x1) 337 | 338 | diffY = x2.size()[2] - x1.size()[2] 339 | diffX = x2.size()[3] - x1.size()[3] 340 | 341 | x1 = F.pad(x1, (diffX // 2, diffX - diffX//2, 342 | diffY // 2, diffY - diffY//2)) 343 | 344 | x = torch.cat([x2, x1], dim=1) 345 | x = self.conv(x) 346 | return x 347 | 348 | 349 | class ConditionalUNet(nn.Module): 350 | """ 351 | Conditional Unet 352 | """ 353 | def __init__(self, nums_classes=(6,3,3,4)): 354 | super().__init__() 355 | self.inc = unet_block(3,64,False) 356 | self.down1 = unet_block(64,128) 357 | self.down2 = unet_block(128,256) 358 | self.down3 = unet_block(256,512) 359 | self.down4 = unet_block(512,512) 360 | self.up1 = cond_unet_up(1024,256, nums_classes=nums_classes) 361 | self.up2 = cond_unet_up(512,128, nums_classes=nums_classes) 362 | self.up3 = cond_unet_up(256,64, nums_classes=nums_classes) 363 | self.up4 = cond_unet_up(128,64, nums_classes=nums_classes) 364 | self.outc = nn.Conv2d(64,3,1) 365 | 366 | def forward(self, x, feat): 367 | x1 = self.inc(x) 368 | x2 = self.down1(x1) 369 | x3 = self.down2(x2) 370 | x4 = self.down3(x3) 371 | x5 = self.down4(x4) 372 | x = self.up1(x5, x4, feat) 373 | x = self.up2(x, x3, feat) 374 | x = self.up3(x, x2, feat) 375 | x = self.up4(x, x1, feat) 376 | x = self.outc(x) 377 | return x 378 | 379 | class cond_unet_block(nn.Module): 380 | '''(conv => BN => ReLU) * 2''' 381 | def __init__(self,in_ch,out_ch,down=True, nums_classes=(6,3,3,4)): 382 | super().__init__() 383 | self.down = down 384 | self.pool = nn.MaxPool2d(2) 385 | self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1) 386 | self.bn1 = MultipleConditionalBatchNorm2d(out_ch, nums_classes) 387 | self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1) 388 | self.bn2 = MultipleConditionalBatchNorm2d(out_ch, nums_classes) 389 | 390 | def forward(self,x,feat): 391 | if self.down: 392 | x = self.pool(x) 393 | #x = self.block(x) 394 | x = self.conv1(x) 395 | x = F.relu(self.bn1(x,feat)) 396 | x = self.conv2(x) 397 | x = F.relu(self.bn2(x,feat)) 398 | return x 399 | 400 | class cond_unet_up(nn.Module): 401 | def __init__(self,in_ch,out_ch,bilinear=True,nums_classes=(6,3,3,4)): 402 | super().__init__() 403 | if bilinear: 404 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 405 | else: 406 | self.up = nn.Conv2dTranspose(in_ch//2,out_ch//2,2,stride=2) 407 | 408 | self.conv = cond_unet_block(in_ch,out_ch,False,nums_classes) 409 | 410 | def forward(self,x1,x2,feat): 411 | x1 = self.up(x1) 412 | 413 | diffY = x2.size()[2] - x1.size()[2] 414 | diffX = x2.size()[3] - x1.size()[3] 415 | 416 | x1 = F.pad(x1, (diffX // 2, diffX - diffX//2, 417 | diffY // 2, diffY - diffY//2)) 418 | 419 | x = torch.cat([x2, x1], dim=1) 420 | x = self.conv(x,feat) 421 | return x 422 | 423 | 424 | class NaiveCNN(nn.Module): 425 | """ 426 | Naive CNN with a bunch of convolutions and a fully connected at the end 427 | Output in [-1,1] 428 | """ 429 | def __init__(self, imsize=32, n_channels=1, dropout_rate=0.2): 430 | super().__init__() 431 | self.imsize = imsize 432 | 433 | self.conv1 = nn.Conv2d(n_channels, 4, kernel_size=5) 434 | current_shape = conv_out_shape((imsize, imsize), self.conv1) 435 | self.conv2 = nn.Conv2d(4, 8, kernel_size=5) 436 | current_shape = conv_out_shape(current_shape, self.conv2) 437 | self.conv3 = nn.Conv2d(8, 16, kernel_size=5) 438 | self.shape_before_dense = conv_out_shape(current_shape, self.conv3) 439 | self.dropout = nn.Dropout2d(p=dropout_rate) 440 | self.linear = nn.Linear(16*self.shape_before_dense[0]*self.shape_before_dense[1], imsize*imsize) 441 | 442 | def forward(self, x): 443 | batch_size, n_channels = x.size(0), x.size(1) 444 | x = torch.relu(self.conv1(x)) 445 | x = torch.relu(self.conv2(x)) 446 | x = torch.relu(self.conv3(x)) 447 | x = self.dropout(x) 448 | #Reshape before fully connected 449 | x = x.view(batch_size, -1) 450 | x = torch.tanh(self.linear(x)) 451 | x = x.view(batch_size, n_channels, self.imsize, self.imsize) 452 | return x 453 | 454 | class LittleUnet(nn.Module): 455 | """A little U-net style CNN based on concatenations and transposed convolution 456 | Output in [-1,1] 457 | """ 458 | def __init__(self, imsize=32, n_channels=1, initial_1by1=False): 459 | super().__init__() 460 | if initial_1by1: 461 | self.conv1 = nn.Sequential( 462 | nn.Conv2d(n_channels, 255, kernel_size=1), 463 | nn.ReLU(), 464 | nn.Conv2d(255, 255, kernel_size=1), 465 | nn.ReLU(), 466 | nn.Conv2d(255, 32, kernel_size=3, stride=2), 467 | nn.BatchNorm2d(32), 468 | nn.ReLU() 469 | ) 470 | else: 471 | self.conv1 = nn.Sequential( 472 | nn.Conv2d(n_channels, 32, kernel_size=3, stride=2), 473 | nn.BatchNorm2d(32), 474 | nn.ReLU() 475 | ) 476 | self.conv2 = nn.Sequential( 477 | nn.Conv2d(32, 64, kernel_size=3, stride=2), 478 | nn.BatchNorm2d(64), 479 | nn.ReLU(), 480 | ) 481 | self.conv3 = nn.Sequential( 482 | nn.Conv2d(64, 128, kernel_size=3), 483 | nn.BatchNorm2d(128), 484 | nn.ReLU(), 485 | ) 486 | 487 | self.conv_tran1 = nn.Sequential( 488 | nn.ConvTranspose2d(128, 64, kernel_size=3), 489 | nn.BatchNorm2d(64), 490 | nn.ReLU() 491 | ) 492 | self.conv_tran2 = nn.Sequential( 493 | nn.ConvTranspose2d(64*2, 32, kernel_size=3, stride=2), 494 | nn.BatchNorm2d(32), 495 | nn.ReLU() 496 | ) 497 | self.conv_tran3 = nn.Sequential( 498 | nn.ConvTranspose2d(32*2, 1, kernel_size=4, stride=2), 499 | nn.Tanh() 500 | ) 501 | 502 | def forward(self, x): 503 | out1 = self.conv1(x) 504 | out2 = self.conv2(out1) 505 | out3 = self.conv3(out2) 506 | x = self.conv_tran1(out3) 507 | x = self.conv_tran2(torch.cat((out2, x), dim=1)) 508 | x = self.conv_tran3(torch.cat((out1, x), dim=1)) 509 | return x 510 | 511 | 512 | class VGG(nn.Module): 513 | def __init__(self,device): 514 | super(VGG,self).__init__() 515 | self.model = vgg16_bn(True).features 516 | self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(1,3,1,1) 517 | self.mean = self.mean.to(device) 518 | for param in self.model.parameters(): 519 | param.requires_grad = False 520 | 521 | def forward(self, x): 522 | # Normalize the images since we have [-1,1] and vgg wants [0,1] 523 | x = (x*0.5)+0.5 524 | x = x*255 - self.mean 525 | x = self.model(x) 526 | return x 527 | 528 | if __name__ == "__main__": 529 | im = torch.randn(8, 1, 32, 32) 530 | 531 | mlp = MLP() 532 | #Test mlp forward 533 | assert mlp(im).size() == im.size() 534 | 535 | naive_cnn = NaiveCNN() 536 | #Test naive cnn forward 537 | assert naive_cnn(im).size() == im.size() 538 | 539 | unet = LittleUnet(initial_1by1=True) 540 | #Test little unet forward 541 | assert unet(im).size() == im.size() 542 | 543 | im = torch.randn(8, 3, 60, 80) 544 | can32 = CAN() 545 | #Test can32 forward 546 | assert can32(im).size() == im.size() 547 | 548 | feat = torch.randn(1, 32, 300, 500) 549 | class_idxs = (torch.LongTensor([1]), torch.LongTensor([0]), torch.LongTensor([0]), torch.LongTensor([1])) 550 | nums_classes = (6, 3, 3, 4) 551 | cbn = MultipleConditionalBatchNorm2d(n_channels=32, nums_classes=nums_classes) 552 | assert cbn(feat, class_idxs).size() == feat.size() 553 | 554 | c_can32 = ConditionalCAN(nums_classes) 555 | assert c_can32(im, class_idxs).size() == im.size() 556 | 557 | sandocan32 = SandOCAN() 558 | assert sandocan32(im).size() == im.size() 559 | 560 | im = torch.randn(1, 3, 300, 500) 561 | unet = UNet() 562 | assert unet(im).size() == im.size() 563 | print("Tests run correctly!") 564 | -------------------------------------------------------------------------------- /nima/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proceduralia/pytorch-neural-enhance/6c9fbc378eae3d4ad1317dd01a3cb1a157909d84/nima/__init__.py -------------------------------------------------------------------------------- /nima/cli.py: -------------------------------------------------------------------------------- 1 | import click 2 | 3 | from nima.train.clean_dataset import clean_and_split 4 | from nima.train.utils import TrainParams, ValidateParams 5 | from nima.train.main import start_train, start_check_model 6 | from nima.inference.inference_model import InferenceModel 7 | 8 | 9 | @click.group() 10 | def cli(): 11 | pass 12 | 13 | 14 | @click.command() 15 | @click.option('--path_to_ava_txt', help='origin AVA.txt file', required=True) 16 | @click.option('--path_to_save_csv', help='where save train.csv|val.csv|test.csv', required=True) 17 | @click.option('--path_to_images', help='images directory', required=True) 18 | def prepare_dataset(path_to_ava_txt, path_to_save_csv, path_to_images): 19 | click.echo('Clean and split dataset to train|val|test') 20 | clean_and_split(path_to_ava_txt=path_to_ava_txt, path_to_save_csv=path_to_save_csv, path_to_images=path_to_images) 21 | click.echo('Done') 22 | 23 | 24 | @click.command() 25 | @click.option('--path_to_save_csv', help='where save train.csv|val.csv|test.csv', required=True) 26 | @click.option('--path_to_images', help='images directory', required=True) 27 | @click.option('--experiment_dir_name', help='unique experiment name and directory to save all logs and weight', 28 | required=True) 29 | @click.option('--batch_size', help='batch size', required=True, type=int) 30 | @click.option('--num_workers', help='number of reading workers', required=True, type=int) 31 | @click.option('--num_epoch', help='number of epoch', required=True, type=int) 32 | @click.option('--init_lr', help='initial learning rate', required=True, type=float) 33 | def train_model(path_to_save_csv, path_to_images, experiment_dir_name, batch_size, num_workers, num_epoch, init_lr): 34 | click.echo('Train and Validate model save all logs too tensorboard and params to params.json') 35 | params = TrainParams(path_to_save_csv=path_to_save_csv, path_to_images=path_to_images, 36 | experiment_dir_name=experiment_dir_name, batch_size=batch_size, num_workers=num_workers, 37 | num_epoch=num_epoch, init_lr=init_lr) 38 | start_train(params) 39 | 40 | 41 | @click.command() 42 | @click.option('--path_to_model_weight', help='path to model weight .pth file', required=True) 43 | @click.option('--path_to_image', help='image ', required=True) 44 | def get_image_score(path_to_model_weight, path_to_image): 45 | model = InferenceModel(path_to_model=path_to_model_weight) 46 | result = model.predict_from_file(path_to_image) 47 | click.echo(result) 48 | 49 | 50 | @click.command() 51 | @click.option('--path_to_model_weight', help='path to model weight .pth file', required=True) 52 | @click.option('--path_to_save_csv', help='where save train.csv|val.csv|test.csv', required=True) 53 | @click.option('--path_to_images', help='images directory', required=True) 54 | @click.option('--batch_size', help='batch size', required=True, type=int) 55 | @click.option('--num_workers', help='number of reading workers', required=True, type=int) 56 | def validate_model(path_to_model_weight, path_to_save_csv, path_to_images, batch_size, num_workers): 57 | params = ValidateParams(path_to_save_csv=path_to_save_csv, path_to_model_weight=path_to_model_weight, 58 | path_to_images=path_to_images, num_workers=num_workers, batch_size=batch_size) 59 | 60 | val_loss, test_loss = start_check_model(params) 61 | click.echo(f"val_loss = {val_loss}; test_loss = {test_loss}") 62 | 63 | 64 | cli.add_command(prepare_dataset) 65 | cli.add_command(train_model) 66 | cli.add_command(validate_model) 67 | cli.add_command(get_image_score) 68 | 69 | 70 | if __name__ == '__main__': 71 | cli() 72 | -------------------------------------------------------------------------------- /nima/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import numpy as np 4 | from torchvision import transforms 5 | 6 | 7 | IMAGE_NET_MEAN = [0.485, 0.456, 0.406] 8 | IMAGE_NET_STD = [0.229, 0.224, 0.225] 9 | 10 | 11 | class Transform: 12 | def __init__(self): 13 | normalize = transforms.Normalize( 14 | mean=IMAGE_NET_MEAN, 15 | std=IMAGE_NET_STD) 16 | 17 | self._train_transform = transforms.Compose([ 18 | transforms.Resize((256, 256)), 19 | transforms.RandomHorizontalFlip(), 20 | transforms.RandomCrop((224, 224)), 21 | transforms.ToTensor(), 22 | normalize]) 23 | 24 | self._val_transform = transforms.Compose([ 25 | transforms.Resize((224, 224)), 26 | transforms.ToTensor(), 27 | normalize]) 28 | 29 | self._eval_transform = transforms.Compose([ 30 | #transforms.Resize((224, 224)), 31 | normalize]) 32 | 33 | @property 34 | def train_transform(self): 35 | return self._train_transform 36 | 37 | @property 38 | def val_transform(self): 39 | return self._val_transform 40 | 41 | @property 42 | def eval_transform(self): 43 | return self._eval_transform 44 | 45 | 46 | def get_mean_score(score): 47 | buckets = np.arange(1, 11) 48 | mu = (buckets * score).sum() 49 | return mu 50 | 51 | 52 | def get_std_score(scores): 53 | si = np.arange(1, 11) 54 | mean = get_mean_score(scores) 55 | std = np.sqrt(np.sum(((si - mean) ** 2) * scores)) 56 | return std 57 | 58 | 59 | def download_file(url, local_filename, chunk_size=1024): 60 | if os.path.exists(local_filename): 61 | return local_filename 62 | r = requests.get(url, stream=True) 63 | with open(local_filename, 'wb') as f: 64 | for chunk in r.iter_content(chunk_size=chunk_size): 65 | if chunk: 66 | f.write(chunk) 67 | return local_filename 68 | -------------------------------------------------------------------------------- /nima/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proceduralia/pytorch-neural-enhance/6c9fbc378eae3d4ad1317dd01a3cb1a157909d84/nima/inference/__init__.py -------------------------------------------------------------------------------- /nima/inference/app.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, redirect, url_for, request, jsonify 2 | from PIL import Image 3 | 4 | from flasgger import Swagger 5 | 6 | from nima.inference.inference_model import InferenceModel 7 | 8 | app = Flask(__name__) 9 | Swagger(app=app) 10 | app.model = InferenceModel.create_model() 11 | 12 | 13 | @app.route('/') 14 | def index(): 15 | return redirect(url_for('health_check')) 16 | 17 | 18 | @app.route('/api/health_check') 19 | def health_check(): 20 | return "ok" 21 | 22 | 23 | @app.route('/api/get_scores', methods=['POST']) 24 | def get_scores(): 25 | """ 26 | NIMA Pytorch 27 | 28 | --- 29 | tags: 30 | - Get Scores 31 | consumes: 32 | - multipart/form-data 33 | parameters: 34 | - in: formData 35 | type: file 36 | name: file 37 | required: true 38 | description: Upload your file. 39 | responses: 40 | 200: 41 | description: Scores for image 42 | schema: 43 | id: Palette 44 | type: object 45 | properties: 46 | mean_score: 47 | type: float 48 | std_score: 49 | type: float 50 | scores: 51 | type: array 52 | items: 53 | type: float 54 | examples: 55 | { 56 | "mean_score": 5.385255615692586, 57 | "scores": [ 58 | 0.0049467734061181545, 59 | 0.018246186897158623, 60 | 0.05434520170092583, 61 | 0.16275958716869354, 62 | 0.3268744945526123, 63 | 0.24433879554271698, 64 | 0.11257114261388779, 65 | 0.05015537887811661, 66 | 0.017528045922517776, 67 | 0.00823438260704279 68 | ], 69 | "std_score": 1.451693009595486 70 | } 71 | """ 72 | 73 | img = Image.open(request.files['file']) 74 | result = app.model.predict_from_pil_image(img) 75 | return jsonify(result) 76 | 77 | 78 | if __name__ == '__main__': 79 | app.run(host='0.0.0.0', port=5000, debug=True) 80 | -------------------------------------------------------------------------------- /nima/inference/inference_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.datasets.folder import default_loader 3 | import torch.nn.functional as F 4 | from decouple import config 5 | 6 | from nima.model import NIMA 7 | from nima.common import Transform, get_mean_score, get_std_score 8 | from nima.common import download_file 9 | from nima.inference.utils import format_output 10 | 11 | class InferenceModel: 12 | def __init__(self,device): 13 | self.transform = Transform().eval_transform 14 | self.model = NIMA(pretrained_base_model=True) 15 | self.model = self.model.to(device) 16 | self.model.eval() 17 | 18 | def predict_from_file(self, image_path): 19 | image = default_loader(image_path) 20 | return self.predict(image) 21 | 22 | def predict_from_pil_image(self, image): 23 | image = image.convert('RGB') 24 | return self.predict(image) 25 | 26 | def predict(self, image): 27 | image = image*0.5 + 0.5 #rescale from [-1,1]-->[0,1] 28 | image = F.interpolate(image,size=(224,224),mode='bilinear') 29 | with torch.no_grad(): 30 | prob = self.model(image).data.cpu().numpy()[0] 31 | 32 | mean_score = get_mean_score(prob) 33 | std_score = get_std_score(prob) 34 | return mean_score+std_score 35 | -------------------------------------------------------------------------------- /nima/inference/utils.py: -------------------------------------------------------------------------------- 1 | def format_output(mean_score, std_score, prob): 2 | return { 3 | 'mean_score': float(mean_score), 4 | 'std_score': float(std_score), 5 | 'scores': [float(x) for x in prob] 6 | } 7 | -------------------------------------------------------------------------------- /nima/mobile_net_v2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from nima.common import download_file 8 | 9 | MOBILE_NET_V2_UTR = 'https://s3-us-west-1.amazonaws.com/models-nima/mobilenetv2.pth.tar' 10 | 11 | 12 | def conv_bn(inp, oup, stride): 13 | return nn.Sequential( 14 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 15 | nn.BatchNorm2d(oup), 16 | nn.ReLU(inplace=True) 17 | ) 18 | 19 | 20 | def conv_1x1_bn(inp, oup): 21 | return nn.Sequential( 22 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 23 | nn.BatchNorm2d(oup), 24 | nn.ReLU(inplace=True) 25 | ) 26 | 27 | 28 | class InvertedResidual(nn.Module): 29 | def __init__(self, inp, oup, stride, expand_ratio): 30 | super(InvertedResidual, self).__init__() 31 | self.stride = stride 32 | assert stride in [1, 2] 33 | 34 | self.use_res_connect = self.stride == 1 and inp == oup 35 | 36 | self.conv = nn.Sequential( 37 | # pw 38 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 39 | nn.BatchNorm2d(inp * expand_ratio), 40 | nn.ReLU6(inplace=True), 41 | # dw 42 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False), 43 | nn.BatchNorm2d(inp * expand_ratio), 44 | nn.ReLU6(inplace=True), 45 | # pw-linear 46 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 47 | nn.BatchNorm2d(oup), 48 | ) 49 | 50 | def forward(self, x): 51 | if self.use_res_connect: 52 | return x + self.conv(x) 53 | else: 54 | return self.conv(x) 55 | 56 | 57 | class MobileNetV2(nn.Module): 58 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 59 | super(MobileNetV2, self).__init__() 60 | # setting of inverted residual blocks 61 | self.interverted_residual_setting = [ 62 | # t, c, n, s 63 | [1, 16, 1, 1], 64 | [6, 24, 2, 2], 65 | [6, 32, 3, 2], 66 | [6, 64, 4, 2], 67 | [6, 96, 3, 1], 68 | [6, 160, 3, 2], 69 | [6, 320, 1, 1], 70 | ] 71 | 72 | # building first layer 73 | assert input_size % 32 == 0 74 | input_channel = int(32 * width_mult) 75 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 76 | self.features = [conv_bn(3, input_channel, 2)] 77 | # building inverted residual blocks 78 | for t, c, n, s in self.interverted_residual_setting: 79 | output_channel = int(c * width_mult) 80 | for i in range(n): 81 | if i == 0: 82 | self.features.append(InvertedResidual(input_channel, output_channel, s, t)) 83 | else: 84 | self.features.append(InvertedResidual(input_channel, output_channel, 1, t)) 85 | input_channel = output_channel 86 | # building last several layers 87 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 88 | self.features.append(nn.AvgPool2d(input_size // 32)) 89 | # make it nn.Sequential 90 | self.features = nn.Sequential(*self.features) 91 | 92 | # building classifier 93 | self.classifier = nn.Sequential( 94 | nn.Dropout(), 95 | nn.Linear(self.last_channel, n_class), 96 | ) 97 | 98 | self._initialize_weights() 99 | 100 | def forward(self, x): 101 | x = self.features(x) 102 | x = x.view(-1, self.last_channel) 103 | x = self.classifier(x) 104 | return x 105 | 106 | def _initialize_weights(self): 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 110 | m.weight.data.normal_(0, math.sqrt(2. / n)) 111 | if m.bias is not None: 112 | m.bias.data.zero_() 113 | elif isinstance(m, nn.BatchNorm2d): 114 | m.weight.data.fill_(1) 115 | m.bias.data.zero_() 116 | elif isinstance(m, nn.Linear): 117 | n = m.weight.size(1) 118 | m.weight.data.normal_(0, 0.01) 119 | m.bias.data.zero_() 120 | 121 | 122 | def mobile_net_v2(pretrained=True): 123 | model = MobileNetV2() 124 | if pretrained: 125 | path_to_model = '/tmp/mobilenetv2.pth.tar' 126 | if not os.path.exists(path_to_model): 127 | path_to_model = download_file(MOBILE_NET_V2_UTR, path_to_model) 128 | state_dict = torch.load(path_to_model, map_location=lambda storage, loc: storage) 129 | model.load_state_dict(state_dict) 130 | return model 131 | -------------------------------------------------------------------------------- /nima/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from nima.mobile_net_v2 import mobile_net_v2 4 | 5 | 6 | class NIMA(nn.Module): 7 | def __init__(self, pretrained_base_model=True): 8 | super(NIMA, self).__init__() 9 | base_model = mobile_net_v2(pretrained=pretrained_base_model) 10 | base_model = nn.Sequential(*list(base_model.children())[:-1]) 11 | 12 | self.base_model = base_model 13 | 14 | self.head = nn.Sequential( 15 | nn.ReLU(inplace=True), 16 | nn.Dropout(p=0.75), 17 | nn.Linear(1280, 10), 18 | nn.Softmax(dim=1) 19 | ) 20 | 21 | def forward(self, x): 22 | x = self.base_model(x) 23 | x = x.view(x.size(0), -1) 24 | x = self.head(x) 25 | return x 26 | -------------------------------------------------------------------------------- /nima/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proceduralia/pytorch-neural-enhance/6c9fbc378eae3d4ad1317dd01a3cb1a157909d84/nima/train/__init__.py -------------------------------------------------------------------------------- /nima/train/clean_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from concurrent.futures import ThreadPoolExecutor, as_completed 3 | 4 | import pandas as pd 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torchvision.datasets.folder import default_loader 8 | from sklearn.model_selection import train_test_split 9 | 10 | 11 | from nima.train.utils import SCORE_NAMES, TAG_NAMES 12 | 13 | 14 | def _remove_all_not_found_image(df: pd.DataFrame, path_to_images: str) -> pd.DataFrame: 15 | clean_rows = [] 16 | for _, row in df.iterrows(): 17 | image_id = row['image_id'] 18 | try: 19 | _ = default_loader(os.path.join(path_to_images, f"{image_id}.jpg")) 20 | except (FileNotFoundError, OSError): 21 | pass 22 | else: 23 | clean_rows.append(row) 24 | df_clean = pd.DataFrame(clean_rows) 25 | return df_clean 26 | 27 | 28 | def remove_all_not_found_image(df: pd.DataFrame, path_to_images: str, num_workers: int = 64) -> pd.DataFrame: 29 | futures = [] 30 | results = [] 31 | with ThreadPoolExecutor(max_workers=num_workers) as executor: 32 | for df_batch in np.array_split(df, num_workers): 33 | future = executor.submit(_remove_all_not_found_image, df=df_batch, path_to_images=path_to_images) 34 | futures.append(future) 35 | for future in tqdm(as_completed(futures)): 36 | results.append(future.result()) 37 | new_df = pd.concat(results) 38 | return new_df 39 | 40 | 41 | def _read_ava_txt(path_to_ava: str) -> pd.DataFrame: 42 | df = pd.read_csv(path_to_ava, header=None, sep=' ') 43 | del df[0] 44 | scores_names = SCORE_NAMES 45 | tag_names = TAG_NAMES 46 | df.columns = ['image_id'] + scores_names + tag_names 47 | return df 48 | 49 | 50 | def clean_and_split(path_to_ava_txt: str, path_to_save_csv: str, path_to_images: str): 51 | df = _read_ava_txt(path_to_ava_txt) 52 | df = remove_all_not_found_image(df, path_to_images) 53 | 54 | df_train, df_val_test = train_test_split(df, train_size=0.9) 55 | df_val, df_test = train_test_split(df_val_test, train_size=0.5) 56 | 57 | df_train.to_csv(os.path.join(path_to_save_csv, 'train.csv'), index=False) 58 | df_val.to_csv(os.path.join(path_to_save_csv, 'val.csv'), index=False) 59 | df_test.to_csv(os.path.join(path_to_save_csv, 'test.csv'), index=False) 60 | -------------------------------------------------------------------------------- /nima/train/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | import numpy as np 5 | 6 | from torch.utils.data import Dataset 7 | from torchvision.datasets.folder import default_loader 8 | 9 | 10 | from nima.train.utils import SCORE_NAMES 11 | 12 | 13 | class AVADataset(Dataset): 14 | def __init__(self, path_to_csv: str, images_path: str, transform): 15 | self.df = pd.read_csv(path_to_csv) 16 | self.images_path = images_path 17 | self.transform = transform 18 | 19 | def __len__(self): 20 | return self.df.shape[0] 21 | 22 | def __getitem__(self, item): 23 | row = self.df.iloc[item] 24 | y = np.array([row[k] for k in SCORE_NAMES]) 25 | p = y / y.sum() 26 | 27 | image_id = row['image_id'] 28 | image_path = os.path.join(self.images_path, f'{image_id}.jpg') 29 | image = default_loader(image_path) 30 | x = self.transform(image) 31 | return x, p.astype('float32') 32 | -------------------------------------------------------------------------------- /nima/train/emd_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | 6 | class EDMLoss(nn.Module): 7 | def __init__(self): 8 | super(EDMLoss, self).__init__() 9 | 10 | def forward(self, p_target: Variable, p_estimate: Variable): 11 | assert p_target.shape == p_estimate.shape 12 | # cdf for values [1, 2, ..., 10] 13 | cdf_target = torch.cumsum(p_target, dim=1) 14 | # cdf for values [1, 2, ..., 10] 15 | cdf_estimate = torch.cumsum(p_estimate, dim=1) 16 | cdf_diff = cdf_estimate - cdf_target 17 | samplewise_emd = torch.sqrt(torch.mean(torch.pow(torch.abs(cdf_diff), 2))) 18 | return samplewise_emd.mean() 19 | -------------------------------------------------------------------------------- /nima/train/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import torch 4 | 5 | from torch.utils.data import DataLoader 6 | 7 | from tensorboardX import SummaryWriter 8 | 9 | 10 | from nima.model import NIMA 11 | from nima.train.datasets import AVADataset 12 | from nima.train.emd_loss import EDMLoss 13 | from nima.common import Transform 14 | from nima.train.utils import TrainParams, ValidateParams, AverageMeter 15 | 16 | use_gpu = torch.cuda.is_available() 17 | device = torch.device("cuda" if use_gpu else "cpu") 18 | 19 | 20 | def train(model, loader, optimizer, criterion, writer=None, global_step=None, name=None): 21 | model.train() 22 | train_losses = AverageMeter() 23 | for idx, (x, y) in enumerate(tqdm(loader)): 24 | x = x.to(device) 25 | y = y.to(device) 26 | y_pred = model(x) 27 | loss = criterion(p_target=y, p_estimate=y_pred) 28 | optimizer.zero_grad() 29 | loss.backward() 30 | optimizer.step() 31 | train_losses.update(loss.item(), x.size(0)) 32 | 33 | if writer is not None: 34 | writer.add_scalar(f"{name}/train_loss.avg", train_losses.avg, global_step=global_step + idx) 35 | return train_losses.avg 36 | 37 | 38 | def validate(model, loader, criterion, writer=None, global_step=None, name=None): 39 | model.eval() 40 | validate_losses = AverageMeter() 41 | for idx, (x, y) in enumerate(tqdm(loader)): 42 | x = x.to(device) 43 | y = y.to(device) 44 | y_pred = model(x) 45 | loss = criterion(p_target=y, p_estimate=y_pred) 46 | validate_losses.update(loss.item(), x.size(0)) 47 | 48 | if writer is not None: 49 | writer.add_scalar(f"{name}/val_loss.avg", validate_losses.avg, global_step=global_step + idx) 50 | return validate_losses.avg 51 | 52 | 53 | def _create_train_data_part(params: TrainParams): 54 | train_csv_path = os.path.join(params.path_to_save_csv, 'train.csv') 55 | val_csv_path = os.path.join(params.path_to_save_csv, 'val.csv') 56 | 57 | transform = Transform() 58 | train_ds = AVADataset(train_csv_path, params.path_to_images, transform.train_transform) 59 | val_ds = AVADataset(val_csv_path, params.path_to_images, transform.val_transform) 60 | 61 | train_loader = DataLoader(train_ds, batch_size=params.batch_size, num_workers=params.num_workers, shuffle=True) 62 | val_loader = DataLoader(val_ds, batch_size=params.batch_size, num_workers=params.num_workers, shuffle=False) 63 | 64 | return train_loader, val_loader 65 | 66 | 67 | def _create_val_data_part(params: TrainParams): 68 | val_csv_path = os.path.join(params.path_to_save_csv, 'val.csv') 69 | test_csv_path = os.path.join(params.path_to_save_csv, 'test.csv') 70 | 71 | transform = Transform() 72 | val_ds = AVADataset(val_csv_path, params.path_to_images, transform.val_transform) 73 | test_ds = AVADataset(test_csv_path, params.path_to_images, transform.val_transform) 74 | 75 | val_loader = DataLoader(val_ds, batch_size=params.batch_size, num_workers=params.num_workers, shuffle=False) 76 | test_loader = DataLoader(test_ds, batch_size=params.batch_size, num_workers=params.num_workers, shuffle=False) 77 | 78 | return val_loader, test_loader 79 | 80 | 81 | def start_train(params: TrainParams): 82 | train_loader, val_loader = _create_train_data_part(params=params) 83 | model = NIMA() 84 | optimizer = torch.optim.Adam(model.parameters(), lr=params.init_lr) 85 | criterion = EDMLoss() 86 | model = model.to(device) 87 | criterion.to(device) 88 | 89 | writer = SummaryWriter(log_dir=os.path.join(params.experiment_dir_name, 'logs')) 90 | os.makedirs(params.experiment_dir_name, exist_ok=True) 91 | params.save_params(os.path.join(params.experiment_dir_name, 'params.json')) 92 | 93 | for e in range(params.num_epoch): 94 | train_loss = train(model=model, loader=train_loader, optimizer=optimizer, criterion=criterion, 95 | writer=writer, global_step=len(train_loader.dataset) * e, 96 | name=f"{params.experiment_dir_name}_by_batch") 97 | val_loss = validate(model=model, loader=val_loader, criterion=criterion, 98 | writer=writer, global_step=len(train_loader.dataset) * e, 99 | name=f"{params.experiment_dir_name}_by_batch") 100 | 101 | model_name = f"emd_loss_epoch_{e}_train_{train_loss}_{val_loss}.pth" 102 | torch.save(model.module.state_dict(), os.path.join(params.experiment_dir_name, model_name)) 103 | writer.add_scalar(f"{params.experiment_dir_name}_by_epoch/train_loss", train_loss, global_step=e) 104 | writer.add_scalar(f"{params.experiment_dir_name}_by_epoch/val_loss", val_loss, global_step=e) 105 | 106 | writer.export_scalars_to_json(os.path.join(params.experiment_dir_name, 'all_scalars.json')) 107 | writer.close() 108 | 109 | 110 | def start_check_model(params: ValidateParams): 111 | val_loader, test_loader = _create_val_data_part(params) 112 | model = NIMA() 113 | model.load_state_dict(torch.load(params.path_to_model_weight)) 114 | criterion = EDMLoss() 115 | 116 | model = model.to(device) 117 | criterion.to(device) 118 | 119 | val_loss = validate(model=model, loader=val_loader, criterion=criterion) 120 | test_loss = validate(model=model, loader=test_loader, criterion=criterion) 121 | return val_loss, test_loss 122 | -------------------------------------------------------------------------------- /nima/train/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import namedtuple 3 | 4 | _SCORE_FIRST_COLUMN = 2 5 | _SCORE_LAST_COLUMN = 12 6 | _TAG_FIRST_COLUMN = 1 7 | _TAG_LAST_COLUMN = 4 8 | 9 | SCORE_NAMES = [f'score{i}' for i in range(_SCORE_FIRST_COLUMN, _SCORE_LAST_COLUMN)] 10 | TAG_NAMES = [f'tag{i}' for i in range(_TAG_FIRST_COLUMN, _TAG_LAST_COLUMN)] 11 | 12 | 13 | class TrainParams(namedtuple('TrainParams', ['path_to_save_csv', 'path_to_images', 14 | 'experiment_dir_name', 'batch_size', 15 | 'num_workers', 'num_epoch', 'init_lr'])): 16 | def save_params(self, file_path: str): 17 | with open(file_path, 'w') as f: 18 | json.dump(self._asdict(), f) 19 | 20 | 21 | class ValidateParams(namedtuple('TrainParams', ['path_to_save_csv', 'path_to_model_weight', 22 | 'path_to_images', 'batch_size', 23 | 'num_workers'])): 24 | pass 25 | 26 | 27 | class AverageMeter(object): 28 | def __init__(self): 29 | self.reset() 30 | 31 | def reset(self): 32 | self.val = 0 33 | self.avg = 0 34 | self.sum = 0 35 | self.count = 0 36 | 37 | def update(self, val, n=1): 38 | self.val = val 39 | self.sum += val * n 40 | self.count += n 41 | self.avg = self.sum / self.count 42 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | @EXPLICIT 5 | https://repo.anaconda.com/pkgs/main/linux-64/blas-1.0-mkl.tar.bz2 6 | https://repo.anaconda.com/pkgs/main/linux-64/ca-certificates-2018.03.07-0.tar.bz2 7 | https://repo.anaconda.com/pkgs/main/linux-64/conda-env-2.6.0-1.tar.bz2 8 | https://repo.anaconda.com/pkgs/main/linux-64/intel-openmp-2019.1-144.tar.bz2 9 | https://repo.anaconda.com/pkgs/main/linux-64/libgcc-ng-8.2.0-hdf63c60_1.tar.bz2 10 | https://repo.anaconda.com/pkgs/main/linux-64/libgfortran-ng-7.3.0-hdf63c60_0.tar.bz2 11 | https://repo.anaconda.com/pkgs/main/linux-64/libstdcxx-ng-8.2.0-hdf63c60_1.tar.bz2 12 | https://repo.anaconda.com/pkgs/main/linux-64/bzip2-1.0.6-h14c3975_5.tar.bz2 13 | https://repo.anaconda.com/pkgs/main/linux-64/expat-2.2.6-he6710b0_0.tar.bz2 14 | https://repo.anaconda.com/pkgs/main/linux-64/fribidi-1.0.5-h7b6447c_0.tar.bz2 15 | https://repo.anaconda.com/pkgs/main/linux-64/gmp-6.1.2-h6c8ec71_1.tar.bz2 16 | https://repo.anaconda.com/pkgs/main/linux-64/graphite2-1.3.12-h23475e2_2.tar.bz2 17 | https://repo.anaconda.com/pkgs/main/linux-64/icu-58.2-h9c2bf20_1.tar.bz2 18 | https://repo.anaconda.com/pkgs/main/linux-64/jbig-2.1-hdba287a_0.tar.bz2 19 | https://repo.anaconda.com/pkgs/main/linux-64/jpeg-9b-h024ee3a_2.tar.bz2 20 | https://repo.anaconda.com/pkgs/main/linux-64/libffi-3.2.1-hd88cf55_4.tar.bz2 21 | https://repo.anaconda.com/pkgs/main/linux-64/liblief-0.9.0-h7725739_1.tar.bz2 22 | https://repo.anaconda.com/pkgs/main/linux-64/libsodium-1.0.16-h1bed415_0.tar.bz2 23 | https://repo.anaconda.com/pkgs/main/linux-64/libtool-2.4.6-h7b6447c_5.tar.bz2 24 | https://repo.anaconda.com/pkgs/main/linux-64/libuuid-1.0.3-h1bed415_2.tar.bz2 25 | https://repo.anaconda.com/pkgs/main/linux-64/libxcb-1.13-h1bed415_1.tar.bz2 26 | https://repo.anaconda.com/pkgs/main/linux-64/lz4-c-1.8.1.2-h14c3975_0.tar.bz2 27 | https://repo.anaconda.com/pkgs/main/linux-64/lzo-2.10-h49e0be7_2.tar.bz2 28 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-2019.1-144.tar.bz2 29 | https://repo.anaconda.com/pkgs/main/linux-64/ncurses-6.1-he6710b0_1.tar.bz2 30 | https://repo.anaconda.com/pkgs/main/linux-64/openssl-1.1.1a-h7b6447c_0.tar.bz2 31 | https://repo.anaconda.com/pkgs/main/linux-64/patchelf-0.9-he6710b0_3.tar.bz2 32 | https://repo.anaconda.com/pkgs/main/linux-64/pcre-8.42-h439df22_0.tar.bz2 33 | https://repo.anaconda.com/pkgs/main/linux-64/pixman-0.34.0-hceecf20_3.tar.bz2 34 | https://repo.anaconda.com/pkgs/main/linux-64/snappy-1.1.7-hbae5bb6_3.tar.bz2 35 | https://repo.anaconda.com/pkgs/main/linux-64/xz-5.2.4-h14c3975_4.tar.bz2 36 | https://repo.anaconda.com/pkgs/main/linux-64/yaml-0.1.7-had09818_2.tar.bz2 37 | https://repo.anaconda.com/pkgs/main/linux-64/zlib-1.2.11-h7b6447c_3.tar.bz2 38 | https://repo.anaconda.com/pkgs/main/linux-64/blosc-1.14.4-hdbcaa40_0.tar.bz2 39 | https://repo.anaconda.com/pkgs/main/linux-64/glib-2.56.2-hd408876_0.tar.bz2 40 | https://repo.anaconda.com/pkgs/main/linux-64/hdf5-1.10.2-hba1933b_1.tar.bz2 41 | https://repo.anaconda.com/pkgs/main/linux-64/libedit-3.1.20170329-h6b74fdf_2.tar.bz2 42 | https://repo.anaconda.com/pkgs/main/linux-64/libpng-1.6.35-hbc83047_0.tar.bz2 43 | https://repo.anaconda.com/pkgs/main/linux-64/libssh2-1.8.0-h1ba5d50_4.tar.bz2 44 | https://repo.anaconda.com/pkgs/main/linux-64/libtiff-4.0.9-he85c1e1_2.tar.bz2 45 | https://repo.anaconda.com/pkgs/main/linux-64/libxml2-2.9.8-h26e45fe_1.tar.bz2 46 | https://repo.anaconda.com/pkgs/main/linux-64/mpfr-4.0.1-hdf1c602_3.tar.bz2 47 | https://repo.anaconda.com/pkgs/main/linux-64/pandoc-1.19.2.1-hea2e7c5_1.tar.bz2 48 | https://repo.anaconda.com/pkgs/main/linux-64/readline-7.0-h7b6447c_5.tar.bz2 49 | https://repo.anaconda.com/pkgs/main/linux-64/tk-8.6.8-hbc83047_0.tar.bz2 50 | https://repo.anaconda.com/pkgs/main/linux-64/zeromq-4.2.5-hf484d3e_1.tar.bz2 51 | https://repo.anaconda.com/pkgs/main/linux-64/zstd-1.3.7-h0b5b093_0.tar.bz2 52 | https://repo.anaconda.com/pkgs/main/linux-64/dbus-1.13.2-h714fa37_1.tar.bz2 53 | https://repo.anaconda.com/pkgs/main/linux-64/freetype-2.9.1-h8a8886c_1.tar.bz2 54 | https://repo.anaconda.com/pkgs/main/linux-64/gstreamer-1.14.0-hb453b48_1.tar.bz2 55 | https://repo.anaconda.com/pkgs/main/linux-64/krb5-1.16.1-h173b8e3_7.tar.bz2 56 | https://repo.anaconda.com/pkgs/main/linux-64/libarchive-3.3.3-h5d8350f_5.tar.bz2 57 | https://repo.anaconda.com/pkgs/main/linux-64/libxslt-1.1.32-h1312cb7_0.tar.bz2 58 | https://repo.anaconda.com/pkgs/main/linux-64/mpc-1.1.0-h10f8cd9_1.tar.bz2 59 | https://repo.anaconda.com/pkgs/main/linux-64/sqlite-3.26.0-h7b6447c_0.tar.bz2 60 | https://repo.anaconda.com/pkgs/main/linux-64/unixodbc-2.3.7-h14c3975_0.tar.bz2 61 | https://repo.anaconda.com/pkgs/main/linux-64/fontconfig-2.13.0-h9420a91_0.tar.bz2 62 | https://repo.anaconda.com/pkgs/main/linux-64/gst-plugins-base-1.14.0-hbbd80ab_1.tar.bz2 63 | https://repo.anaconda.com/pkgs/main/linux-64/libcurl-7.63.0-h20c2e04_1000.tar.bz2 64 | https://repo.anaconda.com/pkgs/main/linux-64/python-3.7.1-h0371630_7.tar.bz2 65 | https://repo.anaconda.com/pkgs/main/linux-64/alabaster-0.7.12-py37_0.tar.bz2 66 | https://repo.anaconda.com/pkgs/main/linux-64/asn1crypto-0.24.0-py37_0.tar.bz2 67 | https://repo.anaconda.com/pkgs/main/linux-64/atomicwrites-1.2.1-py37_0.tar.bz2 68 | https://repo.anaconda.com/pkgs/main/linux-64/attrs-18.2.0-py37h28b3542_0.tar.bz2 69 | https://repo.anaconda.com/pkgs/main/linux-64/backcall-0.1.0-py37_0.tar.bz2 70 | https://repo.anaconda.com/pkgs/main/linux-64/backports-1.0-py37_1.tar.bz2 71 | https://repo.anaconda.com/pkgs/main/linux-64/beautifulsoup4-4.6.3-py37_0.tar.bz2 72 | https://repo.anaconda.com/pkgs/main/linux-64/bitarray-0.8.3-py37h14c3975_0.tar.bz2 73 | https://repo.anaconda.com/pkgs/main/linux-64/boto-2.49.0-py37_0.tar.bz2 74 | https://repo.anaconda.com/pkgs/main/linux-64/cairo-1.14.12-h8948797_3.tar.bz2 75 | https://repo.anaconda.com/pkgs/main/linux-64/certifi-2018.11.29-py37_0.tar.bz2 76 | https://repo.anaconda.com/pkgs/main/linux-64/chardet-3.0.4-py37_1.tar.bz2 77 | https://repo.anaconda.com/pkgs/main/linux-64/click-7.0-py37_0.tar.bz2 78 | https://repo.anaconda.com/pkgs/main/linux-64/cloudpickle-0.6.1-py37_0.tar.bz2 79 | https://repo.anaconda.com/pkgs/main/linux-64/colorama-0.4.1-py37_0.tar.bz2 80 | https://repo.anaconda.com/pkgs/main/linux-64/contextlib2-0.5.5-py37_0.tar.bz2 81 | https://repo.anaconda.com/pkgs/main/linux-64/curl-7.63.0-hbc83047_1000.tar.bz2 82 | https://repo.anaconda.com/pkgs/main/linux-64/dask-core-1.0.0-py37_0.tar.bz2 83 | https://repo.anaconda.com/pkgs/main/linux-64/decorator-4.3.0-py37_0.tar.bz2 84 | https://repo.anaconda.com/pkgs/main/linux-64/defusedxml-0.5.0-py37_1.tar.bz2 85 | https://repo.anaconda.com/pkgs/main/linux-64/docutils-0.14-py37_0.tar.bz2 86 | https://repo.anaconda.com/pkgs/main/linux-64/entrypoints-0.2.3-py37_2.tar.bz2 87 | https://repo.anaconda.com/pkgs/main/linux-64/et_xmlfile-1.0.1-py37_0.tar.bz2 88 | https://repo.anaconda.com/pkgs/main/linux-64/fastcache-1.0.2-py37h14c3975_2.tar.bz2 89 | https://repo.anaconda.com/pkgs/main/linux-64/filelock-3.0.10-py37_0.tar.bz2 90 | https://repo.anaconda.com/pkgs/main/linux-64/future-0.17.1-py37_0.tar.bz2 91 | https://repo.anaconda.com/pkgs/main/linux-64/glob2-0.6-py37_1.tar.bz2 92 | https://repo.anaconda.com/pkgs/main/linux-64/gmpy2-2.0.8-py37h10f8cd9_2.tar.bz2 93 | https://repo.anaconda.com/pkgs/main/linux-64/greenlet-0.4.15-py37h7b6447c_0.tar.bz2 94 | https://repo.anaconda.com/pkgs/main/linux-64/heapdict-1.0.0-py37_2.tar.bz2 95 | https://repo.anaconda.com/pkgs/main/linux-64/idna-2.8-py37_0.tar.bz2 96 | https://repo.anaconda.com/pkgs/main/linux-64/imagesize-1.1.0-py37_0.tar.bz2 97 | https://repo.anaconda.com/pkgs/main/linux-64/importlib_metadata-0.6-py37_0.tar.bz2 98 | https://repo.anaconda.com/pkgs/main/linux-64/ipython_genutils-0.2.0-py37_0.tar.bz2 99 | https://repo.anaconda.com/pkgs/main/linux-64/itsdangerous-1.1.0-py37_0.tar.bz2 100 | https://repo.anaconda.com/pkgs/main/linux-64/jdcal-1.4-py37_0.tar.bz2 101 | https://repo.anaconda.com/pkgs/main/linux-64/jeepney-0.4-py37_0.tar.bz2 102 | https://repo.anaconda.com/pkgs/main/linux-64/kiwisolver-1.0.1-py37hf484d3e_0.tar.bz2 103 | https://repo.anaconda.com/pkgs/main/linux-64/lazy-object-proxy-1.3.1-py37h14c3975_2.tar.bz2 104 | https://repo.anaconda.com/pkgs/main/linux-64/llvmlite-0.26.0-py37hd408876_0.tar.bz2 105 | https://repo.anaconda.com/pkgs/main/linux-64/locket-0.2.0-py37_1.tar.bz2 106 | https://repo.anaconda.com/pkgs/main/linux-64/lxml-4.2.5-py37hefd8a0e_0.tar.bz2 107 | https://repo.anaconda.com/pkgs/main/linux-64/markupsafe-1.1.0-py37h7b6447c_0.tar.bz2 108 | https://repo.anaconda.com/pkgs/main/linux-64/mccabe-0.6.1-py37_1.tar.bz2 109 | https://repo.anaconda.com/pkgs/main/linux-64/mistune-0.8.4-py37h7b6447c_0.tar.bz2 110 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-service-1.1.2-py37he904b0f_5.tar.bz2 111 | https://repo.anaconda.com/pkgs/main/linux-64/mpmath-1.1.0-py37_0.tar.bz2 112 | https://repo.anaconda.com/pkgs/main/linux-64/msgpack-python-0.5.6-py37h6bb024c_1.tar.bz2 113 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-base-1.15.4-py37hde5b4d6_0.tar.bz2 114 | https://repo.anaconda.com/pkgs/main/linux-64/olefile-0.46-py37_0.tar.bz2 115 | https://repo.anaconda.com/pkgs/main/linux-64/pandocfilters-1.4.2-py37_1.tar.bz2 116 | https://repo.anaconda.com/pkgs/main/linux-64/parso-0.3.1-py37_0.tar.bz2 117 | https://repo.anaconda.com/pkgs/main/linux-64/pep8-1.7.1-py37_0.tar.bz2 118 | https://repo.anaconda.com/pkgs/main/linux-64/pickleshare-0.7.5-py37_0.tar.bz2 119 | https://repo.anaconda.com/pkgs/main/linux-64/pkginfo-1.4.2-py37_1.tar.bz2 120 | https://repo.anaconda.com/pkgs/main/linux-64/pluggy-0.8.0-py37_0.tar.bz2 121 | https://repo.anaconda.com/pkgs/main/linux-64/ply-3.11-py37_0.tar.bz2 122 | https://repo.anaconda.com/pkgs/main/linux-64/prometheus_client-0.5.0-py37_0.tar.bz2 123 | https://repo.anaconda.com/pkgs/main/linux-64/psutil-5.4.8-py37h7b6447c_0.tar.bz2 124 | https://repo.anaconda.com/pkgs/main/linux-64/ptyprocess-0.6.0-py37_0.tar.bz2 125 | https://repo.anaconda.com/pkgs/main/linux-64/py-1.7.0-py37_0.tar.bz2 126 | https://repo.anaconda.com/pkgs/main/linux-64/py-lief-0.9.0-py37h7725739_1.tar.bz2 127 | https://repo.anaconda.com/pkgs/main/linux-64/pycodestyle-2.4.0-py37_0.tar.bz2 128 | https://repo.anaconda.com/pkgs/main/linux-64/pycosat-0.6.3-py37h14c3975_0.tar.bz2 129 | https://repo.anaconda.com/pkgs/main/linux-64/pycparser-2.19-py37_0.tar.bz2 130 | https://repo.anaconda.com/pkgs/main/linux-64/pycrypto-2.6.1-py37h14c3975_9.tar.bz2 131 | https://repo.anaconda.com/pkgs/main/linux-64/pycurl-7.43.0.2-py37h1ba5d50_0.tar.bz2 132 | https://repo.anaconda.com/pkgs/main/linux-64/pyflakes-2.0.0-py37_0.tar.bz2 133 | https://repo.anaconda.com/pkgs/main/linux-64/pyodbc-4.0.25-py37he6710b0_0.tar.bz2 134 | https://repo.anaconda.com/pkgs/main/linux-64/pyparsing-2.3.0-py37_0.tar.bz2 135 | https://repo.anaconda.com/pkgs/main/linux-64/pysocks-1.6.8-py37_0.tar.bz2 136 | https://repo.anaconda.com/pkgs/main/linux-64/python-libarchive-c-2.8-py37_6.tar.bz2 137 | https://repo.anaconda.com/pkgs/main/linux-64/pytz-2018.7-py37_0.tar.bz2 138 | https://repo.anaconda.com/pkgs/main/linux-64/pyyaml-3.13-py37h14c3975_0.tar.bz2 139 | https://repo.anaconda.com/pkgs/main/linux-64/pyzmq-17.1.2-py37h14c3975_0.tar.bz2 140 | https://repo.anaconda.com/pkgs/main/linux-64/qt-5.9.7-h5867ecd_1.tar.bz2 141 | https://repo.anaconda.com/pkgs/main/linux-64/qtpy-1.5.2-py37_0.tar.bz2 142 | https://repo.anaconda.com/pkgs/main/linux-64/rope-0.11.0-py37_0.tar.bz2 143 | https://repo.anaconda.com/pkgs/main/linux-64/ruamel_yaml-0.15.46-py37h14c3975_0.tar.bz2 144 | https://repo.anaconda.com/pkgs/main/linux-64/send2trash-1.5.0-py37_0.tar.bz2 145 | https://repo.anaconda.com/pkgs/main/linux-64/simplegeneric-0.8.1-py37_2.tar.bz2 146 | https://repo.anaconda.com/pkgs/main/linux-64/sip-4.19.8-py37hf484d3e_0.tar.bz2 147 | https://repo.anaconda.com/pkgs/main/linux-64/six-1.12.0-py37_0.tar.bz2 148 | https://repo.anaconda.com/pkgs/main/linux-64/snowballstemmer-1.2.1-py37_0.tar.bz2 149 | https://repo.anaconda.com/pkgs/main/linux-64/sortedcontainers-2.1.0-py37_0.tar.bz2 150 | https://repo.anaconda.com/pkgs/main/linux-64/sphinxcontrib-1.0-py37_1.tar.bz2 151 | https://repo.anaconda.com/pkgs/main/linux-64/sqlalchemy-1.2.15-py37h7b6447c_0.tar.bz2 152 | https://repo.anaconda.com/pkgs/main/linux-64/tblib-1.3.2-py37_0.tar.bz2 153 | https://repo.anaconda.com/pkgs/main/linux-64/testpath-0.4.2-py37_0.tar.bz2 154 | https://repo.anaconda.com/pkgs/main/linux-64/toolz-0.9.0-py37_0.tar.bz2 155 | https://repo.anaconda.com/pkgs/main/linux-64/tornado-5.1.1-py37h7b6447c_0.tar.bz2 156 | https://repo.anaconda.com/pkgs/main/linux-64/tqdm-4.28.1-py37h28b3542_0.tar.bz2 157 | https://repo.anaconda.com/pkgs/main/linux-64/unicodecsv-0.14.1-py37_0.tar.bz2 158 | https://repo.anaconda.com/pkgs/main/linux-64/wcwidth-0.1.7-py37_0.tar.bz2 159 | https://repo.anaconda.com/pkgs/main/linux-64/webencodings-0.5.1-py37_1.tar.bz2 160 | https://repo.anaconda.com/pkgs/main/linux-64/werkzeug-0.14.1-py37_0.tar.bz2 161 | https://repo.anaconda.com/pkgs/main/linux-64/wrapt-1.10.11-py37h14c3975_2.tar.bz2 162 | https://repo.anaconda.com/pkgs/main/linux-64/wurlitzer-1.0.2-py37_0.tar.bz2 163 | https://repo.anaconda.com/pkgs/main/linux-64/xlrd-1.2.0-py37_0.tar.bz2 164 | https://repo.anaconda.com/pkgs/main/linux-64/xlsxwriter-1.1.2-py37_0.tar.bz2 165 | https://repo.anaconda.com/pkgs/main/linux-64/xlwt-1.3.0-py37_0.tar.bz2 166 | https://repo.anaconda.com/pkgs/main/linux-64/astroid-2.1.0-py37_0.tar.bz2 167 | https://repo.anaconda.com/pkgs/main/linux-64/babel-2.6.0-py37_0.tar.bz2 168 | https://repo.anaconda.com/pkgs/main/linux-64/backports.os-0.1.1-py37_0.tar.bz2 169 | https://repo.anaconda.com/pkgs/main/linux-64/backports.shutil_get_terminal_size-1.0.0-py37_2.tar.bz2 170 | https://repo.anaconda.com/pkgs/main/linux-64/cffi-1.11.5-py37he75722e_1.tar.bz2 171 | https://repo.anaconda.com/pkgs/main/linux-64/cycler-0.10.0-py37_0.tar.bz2 172 | https://repo.anaconda.com/pkgs/main/linux-64/cytoolz-0.9.0.1-py37h14c3975_1.tar.bz2 173 | https://repo.anaconda.com/pkgs/main/linux-64/harfbuzz-1.8.8-hffaf4a1_0.tar.bz2 174 | https://repo.anaconda.com/pkgs/main/linux-64/html5lib-1.0.1-py37_0.tar.bz2 175 | https://repo.anaconda.com/pkgs/main/linux-64/jedi-0.13.2-py37_0.tar.bz2 176 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_fft-1.0.6-py37hd81dba3_0.tar.bz2 177 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_random-1.0.2-py37hd81dba3_0.tar.bz2 178 | https://repo.anaconda.com/pkgs/main/linux-64/more-itertools-4.3.0-py37_0.tar.bz2 179 | https://repo.anaconda.com/pkgs/main/linux-64/multipledispatch-0.6.0-py37_0.tar.bz2 180 | https://repo.anaconda.com/pkgs/main/linux-64/nltk-3.4-py37_1.tar.bz2 181 | https://repo.anaconda.com/pkgs/main/linux-64/openpyxl-2.5.12-py37_0.tar.bz2 182 | https://repo.anaconda.com/pkgs/main/linux-64/packaging-18.0-py37_0.tar.bz2 183 | https://repo.anaconda.com/pkgs/main/linux-64/partd-0.3.9-py37_0.tar.bz2 184 | https://repo.anaconda.com/pkgs/main/linux-64/pathlib2-2.3.3-py37_0.tar.bz2 185 | https://repo.anaconda.com/pkgs/main/linux-64/pexpect-4.6.0-py37_0.tar.bz2 186 | https://repo.anaconda.com/pkgs/main/linux-64/pillow-5.3.0-py37h34e0f95_0.tar.bz2 187 | https://repo.anaconda.com/pkgs/main/linux-64/pyqt-5.9.2-py37h05f1152_2.tar.bz2 188 | https://repo.anaconda.com/pkgs/main/linux-64/python-dateutil-2.7.5-py37_0.tar.bz2 189 | https://repo.anaconda.com/pkgs/main/linux-64/qtawesome-0.5.3-py37_0.tar.bz2 190 | https://repo.anaconda.com/pkgs/main/linux-64/setuptools-40.6.3-py37_0.tar.bz2 191 | https://repo.anaconda.com/pkgs/main/linux-64/singledispatch-3.4.0.3-py37_0.tar.bz2 192 | https://repo.anaconda.com/pkgs/main/linux-64/sortedcollections-1.0.1-py37_0.tar.bz2 193 | https://repo.anaconda.com/pkgs/main/linux-64/sphinxcontrib-websupport-1.1.0-py37_1.tar.bz2 194 | https://repo.anaconda.com/pkgs/main/linux-64/sympy-1.3-py37_0.tar.bz2 195 | https://repo.anaconda.com/pkgs/main/linux-64/terminado-0.8.1-py37_1.tar.bz2 196 | https://repo.anaconda.com/pkgs/main/linux-64/traitlets-4.3.2-py37_0.tar.bz2 197 | https://repo.anaconda.com/pkgs/main/linux-64/zict-0.1.3-py37_0.tar.bz2 198 | https://repo.anaconda.com/pkgs/main/linux-64/bleach-3.0.2-py37_0.tar.bz2 199 | https://repo.anaconda.com/pkgs/main/linux-64/clyent-1.2.2-py37_1.tar.bz2 200 | https://repo.anaconda.com/pkgs/main/linux-64/cryptography-2.4.2-py37h1ba5d50_0.tar.bz2 201 | https://repo.anaconda.com/pkgs/main/linux-64/cython-0.29.2-py37he6710b0_0.tar.bz2 202 | https://repo.anaconda.com/pkgs/main/linux-64/distributed-1.25.1-py37_0.tar.bz2 203 | https://repo.anaconda.com/pkgs/main/linux-64/get_terminal_size-1.0.0-haa9412d_0.tar.bz2 204 | https://repo.anaconda.com/pkgs/main/linux-64/gevent-1.3.7-py37h7b6447c_1.tar.bz2 205 | https://repo.anaconda.com/pkgs/main/linux-64/isort-4.3.4-py37_0.tar.bz2 206 | https://repo.anaconda.com/pkgs/main/linux-64/jinja2-2.10-py37_0.tar.bz2 207 | https://repo.anaconda.com/pkgs/main/linux-64/jsonschema-2.6.0-py37_0.tar.bz2 208 | https://repo.anaconda.com/pkgs/main/linux-64/jupyter_core-4.4.0-py37_0.tar.bz2 209 | https://repo.anaconda.com/pkgs/main/linux-64/navigator-updater-0.2.1-py37_0.tar.bz2 210 | https://repo.anaconda.com/pkgs/main/linux-64/networkx-2.2-py37_1.tar.bz2 211 | https://repo.anaconda.com/pkgs/main/linux-64/nose-1.3.7-py37_2.tar.bz2 212 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-1.15.4-py37h7e9f1db_0.tar.bz2 213 | https://repo.anaconda.com/pkgs/main/linux-64/pango-1.42.4-h049681c_0.tar.bz2 214 | https://repo.anaconda.com/pkgs/main/linux-64/path.py-11.5.0-py37_0.tar.bz2 215 | https://repo.anaconda.com/pkgs/main/linux-64/pygments-2.3.1-py37_0.tar.bz2 216 | https://repo.anaconda.com/pkgs/main/linux-64/pytest-4.0.2-py37_0.tar.bz2 217 | https://repo.anaconda.com/pkgs/main/linux-64/wheel-0.32.3-py37_0.tar.bz2 218 | https://repo.anaconda.com/pkgs/main/linux-64/bokeh-1.0.2-py37_0.tar.bz2 219 | https://repo.anaconda.com/pkgs/main/linux-64/bottleneck-1.2.1-py37h035aef0_1.tar.bz2 220 | https://repo.anaconda.com/pkgs/main/linux-64/conda-verify-3.1.1-py37_0.tar.bz2 221 | https://repo.anaconda.com/pkgs/main/linux-64/datashape-0.5.4-py37_1.tar.bz2 222 | https://repo.anaconda.com/pkgs/main/linux-64/flask-1.0.2-py37_1.tar.bz2 223 | https://repo.anaconda.com/pkgs/main/linux-64/h5py-2.8.0-py37h989c5e5_3.tar.bz2 224 | https://repo.anaconda.com/pkgs/main/linux-64/imageio-2.4.1-py37_0.tar.bz2 225 | https://repo.anaconda.com/pkgs/main/linux-64/jupyter_client-5.2.4-py37_0.tar.bz2 226 | https://repo.anaconda.com/pkgs/main/linux-64/matplotlib-3.0.2-py37h5429711_0.tar.bz2 227 | https://repo.anaconda.com/pkgs/main/linux-64/nbformat-4.4.0-py37_0.tar.bz2 228 | https://repo.anaconda.com/pkgs/main/linux-64/numba-0.41.0-py37h962f231_0.tar.bz2 229 | https://repo.anaconda.com/pkgs/main/linux-64/numexpr-2.6.8-py37h9e4a6bb_0.tar.bz2 230 | https://repo.anaconda.com/pkgs/main/linux-64/pandas-0.23.4-py37h04863e7_0.tar.bz2 231 | https://repo.anaconda.com/pkgs/main/linux-64/pip-18.1-py37_0.tar.bz2 232 | https://repo.anaconda.com/pkgs/main/linux-64/prompt_toolkit-2.0.7-py37_0.tar.bz2 233 | https://repo.anaconda.com/pkgs/main/linux-64/pylint-2.2.2-py37_0.tar.bz2 234 | https://repo.anaconda.com/pkgs/main/linux-64/pyopenssl-18.0.0-py37_0.tar.bz2 235 | https://repo.anaconda.com/pkgs/main/linux-64/pytest-arraydiff-0.3-py37h39e3cac_0.tar.bz2 236 | https://repo.anaconda.com/pkgs/main/linux-64/pytest-doctestplus-0.2.0-py37_0.tar.bz2 237 | https://repo.anaconda.com/pkgs/main/linux-64/pytest-openfiles-0.3.1-py37_0.tar.bz2 238 | https://repo.anaconda.com/pkgs/main/linux-64/pytest-remotedata-0.3.1-py37_0.tar.bz2 239 | https://repo.anaconda.com/pkgs/main/linux-64/pywavelets-1.0.1-py37hdd07704_0.tar.bz2 240 | https://repo.anaconda.com/pkgs/main/linux-64/scipy-1.1.0-py37h7c811a0_2.tar.bz2 241 | https://repo.anaconda.com/pkgs/main/linux-64/secretstorage-3.1.0-py37_0.tar.bz2 242 | https://repo.anaconda.com/pkgs/main/linux-64/bkcharts-0.2-py37_0.tar.bz2 243 | https://repo.anaconda.com/pkgs/main/linux-64/dask-1.0.0-py37_0.tar.bz2 244 | https://repo.anaconda.com/pkgs/main/linux-64/flask-cors-3.0.7-py37_0.tar.bz2 245 | https://repo.anaconda.com/pkgs/main/linux-64/ipython-7.2.0-py37h39e3cac_0.tar.bz2 246 | https://repo.anaconda.com/pkgs/main/linux-64/keyring-17.0.0-py37_0.tar.bz2 247 | https://repo.anaconda.com/pkgs/main/linux-64/nbconvert-5.4.0-py37_1.tar.bz2 248 | https://repo.anaconda.com/pkgs/main/linux-64/patsy-0.5.1-py37_0.tar.bz2 249 | https://repo.anaconda.com/pkgs/main/linux-64/pytables-3.4.4-py37ha205bf6_0.tar.bz2 250 | https://repo.anaconda.com/pkgs/main/linux-64/pytest-astropy-0.5.0-py37_0.tar.bz2 251 | https://repo.anaconda.com/pkgs/main/linux-64/scikit-image-0.14.1-py37he6710b0_0.tar.bz2 252 | https://repo.anaconda.com/pkgs/main/linux-64/scikit-learn-0.20.1-py37hd81dba3_0.tar.bz2 253 | https://repo.anaconda.com/pkgs/main/linux-64/urllib3-1.24.1-py37_0.tar.bz2 254 | https://repo.anaconda.com/pkgs/main/linux-64/astropy-3.1-py37h7b6447c_0.tar.bz2 255 | https://repo.anaconda.com/pkgs/main/linux-64/ipykernel-5.1.0-py37h39e3cac_0.tar.bz2 256 | https://repo.anaconda.com/pkgs/main/linux-64/odo-0.5.1-py37_0.tar.bz2 257 | https://repo.anaconda.com/pkgs/main/linux-64/requests-2.21.0-py37_0.tar.bz2 258 | https://repo.anaconda.com/pkgs/main/linux-64/statsmodels-0.9.0-py37h035aef0_0.tar.bz2 259 | https://repo.anaconda.com/pkgs/main/linux-64/anaconda-client-1.7.2-py37_0.tar.bz2 260 | https://repo.anaconda.com/pkgs/main/linux-64/blaze-0.11.3-py37_0.tar.bz2 261 | https://repo.anaconda.com/pkgs/main/linux-64/jupyter_console-6.0.0-py37_0.tar.bz2 262 | https://repo.anaconda.com/pkgs/main/linux-64/notebook-5.7.4-py37_0.tar.bz2 263 | https://repo.anaconda.com/pkgs/main/linux-64/qtconsole-4.4.3-py37_0.tar.bz2 264 | https://repo.anaconda.com/pkgs/main/linux-64/seaborn-0.9.0-py37_0.tar.bz2 265 | https://repo.anaconda.com/pkgs/main/linux-64/sphinx-1.8.2-py37_0.tar.bz2 266 | https://repo.anaconda.com/pkgs/main/linux-64/spyder-kernels-0.3.0-py37_0.tar.bz2 267 | https://repo.anaconda.com/pkgs/main/linux-64/anaconda-navigator-1.9.6-py37_0.tar.bz2 268 | https://repo.anaconda.com/pkgs/main/linux-64/anaconda-project-0.8.2-py37_0.tar.bz2 269 | https://repo.anaconda.com/pkgs/main/linux-64/jupyterlab_server-0.2.0-py37_0.tar.bz2 270 | https://repo.anaconda.com/pkgs/main/linux-64/numpydoc-0.8.0-py37_0.tar.bz2 271 | https://repo.anaconda.com/pkgs/main/linux-64/widgetsnbextension-3.4.2-py37_0.tar.bz2 272 | https://repo.anaconda.com/pkgs/main/linux-64/ipywidgets-7.4.2-py37_0.tar.bz2 273 | https://repo.anaconda.com/pkgs/main/linux-64/jupyterlab-0.35.3-py37_0.tar.bz2 274 | https://repo.anaconda.com/pkgs/main/linux-64/spyder-3.3.2-py37_0.tar.bz2 275 | https://repo.anaconda.com/pkgs/main/linux-64/_ipyw_jlab_nb_ext_conf-0.1.0-py37_0.tar.bz2 276 | https://repo.anaconda.com/pkgs/main/linux-64/jupyter-1.0.0-py37_7.tar.bz2 277 | https://repo.anaconda.com/pkgs/main/linux-64/anaconda-2018.12-py37_0.tar.bz2 278 | https://repo.anaconda.com/pkgs/main/linux-64/conda-4.5.12-py37_0.tar.bz2 279 | https://repo.anaconda.com/pkgs/main/linux-64/conda-build-3.17.6-py37_0.tar.bz2 280 | -------------------------------------------------------------------------------- /scrape_fivek.py: -------------------------------------------------------------------------------- 1 | """ 2 | Scrape the MIT/Adobe dataset website for downloading the images. 3 | Save semantic info and paths in csv file 'mitdatainfo.csv'. 4 | 5 | Create a directory hierarchy of the type: 6 | base_dir/ 7 | or_dir/ 8 | 0.png 9 | ... 10 | expert_dir[0]/ 11 | 0.png 12 | ... 13 | expert_dir[1]/ 14 | 0.png 15 | ... 16 | ... 17 | """ 18 | import requests 19 | from lxml import html 20 | import os 21 | from PIL import Image 22 | import pandas as pd 23 | import rawpy 24 | import shutil 25 | import argparse 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--base_link', default='https://data.csail.mit.edu/graphics/fivek', help="Base link of website") 29 | parser.add_argument('--base_dir', default='fivek', help="Path of the base directory for the dataset") 30 | parser.add_argument('--size', default=500, help="Size for the greatest dimension of images") 31 | parser.add_argument('--limit_number', default=0, type=int, help="Limit the number of images to scrape. 0 -> no limit") 32 | parser.add_argument('--reverse', action='store_true', help="Use reverse order for downloading the images") 33 | args = vars(parser.parse_args()) 34 | 35 | #Parameters 36 | base_link = args['base_link'] 37 | base_dir = args['base_dir'] 38 | or_dir = os.path.join(base_dir, 'original') 39 | expert_dirs = [os.path.join(base_dir, 'expert{}'.format(i)) for i in range(5)] 40 | size = args['size'] #maximum size for a dimension 41 | limit_number = args['limit_number'] 42 | reverse = args['reverse'] 43 | 44 | page = requests.get(base_link) 45 | tree = html.fromstring(page.content) 46 | numbers = tree.xpath("//table[@class='data']//tr/td[1]/text()") 47 | original_links = tree.xpath("//table[@class='data']//tr/td[3]/a/@href") 48 | #Create list of lists (number_of_ims, number_of_experts) 49 | expert_links = list(map(list, zip(*[tree.xpath("//table[@class='data']//tr/td[{}]/a/@href".format(i)) 50 | for i in range(4, 9)]))) 51 | subjects = [el.tail for el in tree.xpath("//table[@class='data']//tr/td[9]/br")] 52 | light = [el.tail for el in tree.xpath("//table[@class='data']//tr/td[10]/br")] 53 | location = [el.tail for el in tree.xpath("//table[@class='data']//tr/td[11]/br")] 54 | time = [el.tail for el in tree.xpath("//table[@class='data']//tr/td[12]/br")] 55 | exif = [el.tail for el in tree.xpath("//table[@class='data']//tr/td[13]/br")] 56 | 57 | info_dataframe = pd.DataFrame({ 58 | "number": numbers, 59 | "subject": subjects, 60 | "light": light, 61 | "location": location, 62 | "time": time, 63 | "exif": exif, 64 | "original_path": '', 65 | "expert0_path": '', 66 | "expert1_path": '', 67 | "expert2_path": '', 68 | "expert3_path": '', 69 | "expert4_path": '', 70 | }) 71 | 72 | os.makedirs(or_dir, exist_ok=True) 73 | for expert_dir in expert_dirs: 74 | os.makedirs(expert_dir, exist_ok=True) 75 | 76 | idxs = range(len(original_links)) 77 | if reverse: 78 | print("Downloading in reverse order") 79 | idxs = reversed(idxs) 80 | original_links = reversed(original_links) 81 | expert_links = reversed(expert_links) 82 | 83 | for im_count, (original_link, expert_link) in zip(idxs, zip(original_links, expert_links)): 84 | print("Processing original image {}...".format(im_count), end="\r") 85 | filename = os.path.join(or_dir, '{}.png'.format(im_count)) 86 | 87 | if not os.path.exists(filename): 88 | #Workaround for dng: first save dng, then convert to array, then to png 89 | image_link = os.path.join(base_link, original_link) 90 | response = requests.get(image_link, stream=True) 91 | with open('temp.dng', 'wb') as out_file: 92 | shutil.copyfileobj(response.raw, out_file) 93 | with rawpy.imread('temp.dng') as raw: 94 | rgb = raw.postprocess() 95 | del response 96 | #Clean temp dng 97 | os.remove('temp.dng') 98 | image = Image.fromarray(rgb) 99 | image.thumbnail((size, size), Image.ANTIALIAS) 100 | image.save(filename) 101 | 102 | info_dataframe.at[im_count, 'original_path'] = filename 103 | 104 | for expert_count, link in enumerate(expert_link): 105 | print("Processing image {} of expert {}...".format(im_count, expert_count), end="\r") 106 | filename = os.path.join(expert_dirs[expert_count], '{}.png'.format(im_count)) 107 | 108 | if not os.path.exists(filename): 109 | #Download the image and resize to desired max size 110 | image_link = os.path.join(base_link, link) 111 | response = requests.get(image_link, stream=True) 112 | image = Image.open(response.raw) 113 | image.thumbnail((size, size), Image.ANTIALIAS) 114 | image.save(filename) 115 | info_dataframe.at[im_count, 'expert{}_path'.format(expert_count)] = filename 116 | 117 | if limit_number != 0 and (im_count+1) == limit_number: 118 | break 119 | 120 | info_dataframe.to_csv("mitdatainfo.csv", index=False) 121 | -------------------------------------------------------------------------------- /semseg/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, MIT CSAIL Computer Vision 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /semseg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proceduralia/pytorch-neural-enhance/6c9fbc378eae3d4ad1317dd01a3cb1a157909d84/semseg/__init__.py -------------------------------------------------------------------------------- /semseg/lib/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 3 | -------------------------------------------------------------------------------- /semseg/lib/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /semseg/lib/nn/modules/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | # customed batch norm statistics 49 | self._moving_average_fraction = 1. - momentum 50 | self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features)) 51 | self.register_buffer('_tmp_running_var', torch.ones(self.num_features)) 52 | self.register_buffer('_running_iter', torch.ones(1)) 53 | self._tmp_running_mean = self.running_mean.clone() * self._running_iter 54 | self._tmp_running_var = self.running_var.clone() * self._running_iter 55 | 56 | def forward(self, input): 57 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 58 | if not (self._is_parallel and self.training): 59 | return F.batch_norm( 60 | input, self.running_mean, self.running_var, self.weight, self.bias, 61 | self.training, self.momentum, self.eps) 62 | 63 | # Resize the input to (B, C, -1). 64 | input_shape = input.size() 65 | input = input.view(input.size(0), self.num_features, -1) 66 | 67 | # Compute the sum and square-sum. 68 | sum_size = input.size(0) * input.size(2) 69 | input_sum = _sum_ft(input) 70 | input_ssum = _sum_ft(input ** 2) 71 | 72 | # Reduce-and-broadcast the statistics. 73 | if self._parallel_id == 0: 74 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 75 | else: 76 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 77 | 78 | # Compute the output. 79 | if self.affine: 80 | # MJY:: Fuse the multiplication for speed. 81 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 82 | else: 83 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 84 | 85 | # Reshape it. 86 | return output.view(input_shape) 87 | 88 | def __data_parallel_replicate__(self, ctx, copy_id): 89 | self._is_parallel = True 90 | self._parallel_id = copy_id 91 | 92 | # parallel_id == 0 means master device. 93 | if self._parallel_id == 0: 94 | ctx.sync_master = self._sync_master 95 | else: 96 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 97 | 98 | def _data_parallel_master(self, intermediates): 99 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 100 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 101 | 102 | to_reduce = [i[1][:2] for i in intermediates] 103 | to_reduce = [j for i in to_reduce for j in i] # flatten 104 | target_gpus = [i[1].sum.get_device() for i in intermediates] 105 | 106 | sum_size = sum([i[1].sum_size for i in intermediates]) 107 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 108 | 109 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 110 | 111 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 112 | 113 | outputs = [] 114 | for i, rec in enumerate(intermediates): 115 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 116 | 117 | return outputs 118 | 119 | def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0): 120 | """return *dest* by `dest := dest*alpha + delta*beta + bias`""" 121 | return dest * alpha + delta * beta + bias 122 | 123 | def _compute_mean_std(self, sum_, ssum, size): 124 | """Compute the mean and standard-deviation with sum and square-sum. This method 125 | also maintains the moving average on the master device.""" 126 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 127 | mean = sum_ / size 128 | sumvar = ssum - sum_ * mean 129 | unbias_var = sumvar / (size - 1) 130 | bias_var = sumvar / size 131 | 132 | self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction) 133 | self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction) 134 | self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction) 135 | 136 | self.running_mean = self._tmp_running_mean / self._running_iter 137 | self.running_var = self._tmp_running_var / self._running_iter 138 | 139 | return mean, bias_var.clamp(self.eps) ** -0.5 140 | 141 | 142 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 143 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 144 | mini-batch. 145 | 146 | .. math:: 147 | 148 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 149 | 150 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 151 | standard-deviation are reduced across all devices during training. 152 | 153 | For example, when one uses `nn.DataParallel` to wrap the network during 154 | training, PyTorch's implementation normalize the tensor on each device using 155 | the statistics only on that device, which accelerated the computation and 156 | is also easy to implement, but the statistics might be inaccurate. 157 | Instead, in this synchronized version, the statistics will be computed 158 | over all training samples distributed on multiple devices. 159 | 160 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 161 | as the built-in PyTorch implementation. 162 | 163 | The mean and standard-deviation are calculated per-dimension over 164 | the mini-batches and gamma and beta are learnable parameter vectors 165 | of size C (where C is the input size). 166 | 167 | During training, this layer keeps a running estimate of its computed mean 168 | and variance. The running sum is kept with a default momentum of 0.1. 169 | 170 | During evaluation, this running mean/variance is used for normalization. 171 | 172 | Because the BatchNorm is done over the `C` dimension, computing statistics 173 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 174 | 175 | Args: 176 | num_features: num_features from an expected input of size 177 | `batch_size x num_features [x width]` 178 | eps: a value added to the denominator for numerical stability. 179 | Default: 1e-5 180 | momentum: the value used for the running_mean and running_var 181 | computation. Default: 0.1 182 | affine: a boolean value that when set to ``True``, gives the layer learnable 183 | affine parameters. Default: ``True`` 184 | 185 | Shape: 186 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 187 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 188 | 189 | Examples: 190 | >>> # With Learnable Parameters 191 | >>> m = SynchronizedBatchNorm1d(100) 192 | >>> # Without Learnable Parameters 193 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 194 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 195 | >>> output = m(input) 196 | """ 197 | 198 | def _check_input_dim(self, input): 199 | if input.dim() != 2 and input.dim() != 3: 200 | raise ValueError('expected 2D or 3D input (got {}D input)' 201 | .format(input.dim())) 202 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 203 | 204 | 205 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 206 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 207 | of 3d inputs 208 | 209 | .. math:: 210 | 211 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 212 | 213 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 214 | standard-deviation are reduced across all devices during training. 215 | 216 | For example, when one uses `nn.DataParallel` to wrap the network during 217 | training, PyTorch's implementation normalize the tensor on each device using 218 | the statistics only on that device, which accelerated the computation and 219 | is also easy to implement, but the statistics might be inaccurate. 220 | Instead, in this synchronized version, the statistics will be computed 221 | over all training samples distributed on multiple devices. 222 | 223 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 224 | as the built-in PyTorch implementation. 225 | 226 | The mean and standard-deviation are calculated per-dimension over 227 | the mini-batches and gamma and beta are learnable parameter vectors 228 | of size C (where C is the input size). 229 | 230 | During training, this layer keeps a running estimate of its computed mean 231 | and variance. The running sum is kept with a default momentum of 0.1. 232 | 233 | During evaluation, this running mean/variance is used for normalization. 234 | 235 | Because the BatchNorm is done over the `C` dimension, computing statistics 236 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 237 | 238 | Args: 239 | num_features: num_features from an expected input of 240 | size batch_size x num_features x height x width 241 | eps: a value added to the denominator for numerical stability. 242 | Default: 1e-5 243 | momentum: the value used for the running_mean and running_var 244 | computation. Default: 0.1 245 | affine: a boolean value that when set to ``True``, gives the layer learnable 246 | affine parameters. Default: ``True`` 247 | 248 | Shape: 249 | - Input: :math:`(N, C, H, W)` 250 | - Output: :math:`(N, C, H, W)` (same shape as input) 251 | 252 | Examples: 253 | >>> # With Learnable Parameters 254 | >>> m = SynchronizedBatchNorm2d(100) 255 | >>> # Without Learnable Parameters 256 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 257 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 258 | >>> output = m(input) 259 | """ 260 | 261 | def _check_input_dim(self, input): 262 | if input.dim() != 4: 263 | raise ValueError('expected 4D input (got {}D input)' 264 | .format(input.dim())) 265 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 266 | 267 | 268 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 269 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 270 | of 4d inputs 271 | 272 | .. math:: 273 | 274 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 275 | 276 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 277 | standard-deviation are reduced across all devices during training. 278 | 279 | For example, when one uses `nn.DataParallel` to wrap the network during 280 | training, PyTorch's implementation normalize the tensor on each device using 281 | the statistics only on that device, which accelerated the computation and 282 | is also easy to implement, but the statistics might be inaccurate. 283 | Instead, in this synchronized version, the statistics will be computed 284 | over all training samples distributed on multiple devices. 285 | 286 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 287 | as the built-in PyTorch implementation. 288 | 289 | The mean and standard-deviation are calculated per-dimension over 290 | the mini-batches and gamma and beta are learnable parameter vectors 291 | of size C (where C is the input size). 292 | 293 | During training, this layer keeps a running estimate of its computed mean 294 | and variance. The running sum is kept with a default momentum of 0.1. 295 | 296 | During evaluation, this running mean/variance is used for normalization. 297 | 298 | Because the BatchNorm is done over the `C` dimension, computing statistics 299 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 300 | or Spatio-temporal BatchNorm 301 | 302 | Args: 303 | num_features: num_features from an expected input of 304 | size batch_size x num_features x depth x height x width 305 | eps: a value added to the denominator for numerical stability. 306 | Default: 1e-5 307 | momentum: the value used for the running_mean and running_var 308 | computation. Default: 0.1 309 | affine: a boolean value that when set to ``True``, gives the layer learnable 310 | affine parameters. Default: ``True`` 311 | 312 | Shape: 313 | - Input: :math:`(N, C, D, H, W)` 314 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 315 | 316 | Examples: 317 | >>> # With Learnable Parameters 318 | >>> m = SynchronizedBatchNorm3d(100) 319 | >>> # Without Learnable Parameters 320 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 321 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 322 | >>> output = m(input) 323 | """ 324 | 325 | def _check_input_dim(self, input): 326 | if input.dim() != 5: 327 | raise ValueError('expected 5D input (got {}D input)' 328 | .format(input.dim())) 329 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) 330 | -------------------------------------------------------------------------------- /semseg/lib/nn/modules/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def register_slave(self, identifier): 79 | """ 80 | Register an slave device. 81 | 82 | Args: 83 | identifier: an identifier, usually is the device id. 84 | 85 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 86 | 87 | """ 88 | if self._activated: 89 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 90 | self._activated = False 91 | self._registry.clear() 92 | future = FutureResult() 93 | self._registry[identifier] = _MasterRegistry(future) 94 | return SlavePipe(identifier, self._queue, future) 95 | 96 | def run_master(self, master_msg): 97 | """ 98 | Main entry for the master device in each forward pass. 99 | The messages were first collected from each devices (including the master device), and then 100 | an callback will be invoked to compute the message to be sent back to each devices 101 | (including the master device). 102 | 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | 107 | Returns: the message to be sent back to the master device. 108 | 109 | """ 110 | self._activated = True 111 | 112 | intermediates = [(0, master_msg)] 113 | for i in range(self.nr_slaves): 114 | intermediates.append(self._queue.get()) 115 | 116 | results = self._master_callback(intermediates) 117 | assert results[0][0] == 0, 'The first result should belongs to the master.' 118 | 119 | for i, res in results: 120 | if i == 0: 121 | continue 122 | self._registry[i].result.put(res) 123 | 124 | for i in range(self.nr_slaves): 125 | assert self._queue.get() is True 126 | 127 | return results[0][1] 128 | 129 | @property 130 | def nr_slaves(self): 131 | return len(self._registry) 132 | -------------------------------------------------------------------------------- /semseg/lib/nn/modules/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /semseg/lib/nn/modules/tests/test_numeric_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_numeric_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm.unittest import TorchTestCase 16 | 17 | 18 | def handy_var(a, unbias=True): 19 | n = a.size(0) 20 | asum = a.sum(dim=0) 21 | as_sum = (a ** 2).sum(dim=0) # a square sum 22 | sumvar = as_sum - asum * asum / n 23 | if unbias: 24 | return sumvar / (n - 1) 25 | else: 26 | return sumvar / n 27 | 28 | 29 | class NumericTestCase(TorchTestCase): 30 | def testNumericBatchNorm(self): 31 | a = torch.rand(16, 10) 32 | bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False) 33 | bn.train() 34 | 35 | a_var1 = Variable(a, requires_grad=True) 36 | b_var1 = bn(a_var1) 37 | loss1 = b_var1.sum() 38 | loss1.backward() 39 | 40 | a_var2 = Variable(a, requires_grad=True) 41 | a_mean2 = a_var2.mean(dim=0, keepdim=True) 42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) 43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) 44 | b_var2 = (a_var2 - a_mean2) / a_std2 45 | loss2 = b_var2.sum() 46 | loss2.backward() 47 | 48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0)) 49 | self.assertTensorClose(bn.running_var, handy_var(a)) 50 | self.assertTensorClose(a_var1.data, a_var2.data) 51 | self.assertTensorClose(b_var1.data, b_var2.data) 52 | self.assertTensorClose(a_var1.grad, a_var2.grad) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /semseg/lib/nn/modules/tests/test_sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_sync_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback 16 | from sync_batchnorm.unittest import TorchTestCase 17 | 18 | 19 | def handy_var(a, unbias=True): 20 | n = a.size(0) 21 | asum = a.sum(dim=0) 22 | as_sum = (a ** 2).sum(dim=0) # a square sum 23 | sumvar = as_sum - asum * asum / n 24 | if unbias: 25 | return sumvar / (n - 1) 26 | else: 27 | return sumvar / n 28 | 29 | 30 | def _find_bn(module): 31 | for m in module.modules(): 32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): 33 | return m 34 | 35 | 36 | class SyncTestCase(TorchTestCase): 37 | def _syncParameters(self, bn1, bn2): 38 | bn1.reset_parameters() 39 | bn2.reset_parameters() 40 | if bn1.affine and bn2.affine: 41 | bn2.weight.data.copy_(bn1.weight.data) 42 | bn2.bias.data.copy_(bn1.bias.data) 43 | 44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): 45 | """Check the forward and backward for the customized batch normalization.""" 46 | bn1.train(mode=is_train) 47 | bn2.train(mode=is_train) 48 | 49 | if cuda: 50 | input = input.cuda() 51 | 52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2)) 53 | 54 | input1 = Variable(input, requires_grad=True) 55 | output1 = bn1(input1) 56 | output1.sum().backward() 57 | input2 = Variable(input, requires_grad=True) 58 | output2 = bn2(input2) 59 | output2.sum().backward() 60 | 61 | self.assertTensorClose(input1.data, input2.data) 62 | self.assertTensorClose(output1.data, output2.data) 63 | self.assertTensorClose(input1.grad, input2.grad) 64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) 65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) 66 | 67 | def testSyncBatchNormNormalTrain(self): 68 | bn = nn.BatchNorm1d(10) 69 | sync_bn = SynchronizedBatchNorm1d(10) 70 | 71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) 72 | 73 | def testSyncBatchNormNormalEval(self): 74 | bn = nn.BatchNorm1d(10) 75 | sync_bn = SynchronizedBatchNorm1d(10) 76 | 77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) 78 | 79 | def testSyncBatchNormSyncTrain(self): 80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 83 | 84 | bn.cuda() 85 | sync_bn.cuda() 86 | 87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) 88 | 89 | def testSyncBatchNormSyncEval(self): 90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 93 | 94 | bn.cuda() 95 | sync_bn.cuda() 96 | 97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) 98 | 99 | def testSyncBatchNorm2DSyncTrain(self): 100 | bn = nn.BatchNorm2d(10) 101 | sync_bn = SynchronizedBatchNorm2d(10) 102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 103 | 104 | bn.cuda() 105 | sync_bn.cuda() 106 | 107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /semseg/lib/nn/modules/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /semseg/lib/nn/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 2 | -------------------------------------------------------------------------------- /semseg/lib/nn/parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf8 -*- 2 | 3 | import torch.cuda as cuda 4 | import torch.nn as nn 5 | import torch 6 | import collections 7 | from torch.nn.parallel._functions import Gather 8 | 9 | 10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] 11 | 12 | 13 | def async_copy_to(obj, dev, main_stream=None): 14 | if torch.is_tensor(obj): 15 | v = obj.cuda(dev, non_blocking=True) 16 | if main_stream is not None: 17 | v.data.record_stream(main_stream) 18 | return v 19 | elif isinstance(obj, collections.Mapping): 20 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} 21 | elif isinstance(obj, collections.Sequence): 22 | return [async_copy_to(o, dev, main_stream) for o in obj] 23 | else: 24 | return obj 25 | 26 | 27 | def dict_gather(outputs, target_device, dim=0): 28 | """ 29 | Gathers variables from different GPUs on a specified device 30 | (-1 means the CPU), with dictionary support. 31 | """ 32 | def gather_map(outputs): 33 | out = outputs[0] 34 | if torch.is_tensor(out): 35 | # MJY(20180330) HACK:: force nr_dims > 0 36 | if out.dim() == 0: 37 | outputs = [o.unsqueeze(0) for o in outputs] 38 | return Gather.apply(target_device, dim, *outputs) 39 | elif out is None: 40 | return None 41 | elif isinstance(out, collections.Mapping): 42 | return {k: gather_map([o[k] for o in outputs]) for k in out} 43 | elif isinstance(out, collections.Sequence): 44 | return type(out)(map(gather_map, zip(*outputs))) 45 | return gather_map(outputs) 46 | 47 | 48 | class DictGatherDataParallel(nn.DataParallel): 49 | def gather(self, outputs, output_device): 50 | return dict_gather(outputs, output_device, dim=self.dim) 51 | 52 | 53 | class UserScatteredDataParallel(DictGatherDataParallel): 54 | def scatter(self, inputs, kwargs, device_ids): 55 | assert len(inputs) == 1 56 | inputs = inputs[0] 57 | inputs = _async_copy_stream(inputs, device_ids) 58 | inputs = [[i] for i in inputs] 59 | assert len(kwargs) == 0 60 | kwargs = [{} for _ in range(len(inputs))] 61 | 62 | return inputs, kwargs 63 | 64 | 65 | def user_scattered_collate(batch): 66 | return batch 67 | 68 | 69 | def _async_copy(inputs, device_ids): 70 | nr_devs = len(device_ids) 71 | assert type(inputs) in (tuple, list) 72 | assert len(inputs) == nr_devs 73 | 74 | outputs = [] 75 | for i, dev in zip(inputs, device_ids): 76 | with cuda.device(dev): 77 | outputs.append(async_copy_to(i, dev)) 78 | 79 | return tuple(outputs) 80 | 81 | 82 | def _async_copy_stream(inputs, device_ids): 83 | nr_devs = len(device_ids) 84 | assert type(inputs) in (tuple, list) 85 | assert len(inputs) == nr_devs 86 | 87 | outputs = [] 88 | streams = [_get_stream(d) for d in device_ids] 89 | for i, dev, stream in zip(inputs, device_ids, streams): 90 | with cuda.device(dev): 91 | main_stream = cuda.current_stream() 92 | with cuda.stream(stream): 93 | outputs.append(async_copy_to(i, dev, main_stream=main_stream)) 94 | main_stream.wait_stream(stream) 95 | 96 | return outputs 97 | 98 | 99 | """Adapted from: torch/nn/parallel/_functions.py""" 100 | # background streams used for copying 101 | _streams = None 102 | 103 | 104 | def _get_stream(device): 105 | """Gets a background stream for copying between CPU and GPU""" 106 | global _streams 107 | if device == -1: 108 | return None 109 | if _streams is None: 110 | _streams = [None] * cuda.device_count() 111 | if _streams[device] is None: _streams[device] = cuda.Stream(device) 112 | return _streams[device] 113 | -------------------------------------------------------------------------------- /semseg/lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .th import * 2 | -------------------------------------------------------------------------------- /semseg/lib/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .dataset import Dataset, TensorDataset, ConcatDataset 3 | from .dataloader import DataLoader 4 | -------------------------------------------------------------------------------- /semseg/lib/utils/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.multiprocessing as multiprocessing 3 | from torch._C import _set_worker_signal_handlers, _update_worker_pids, \ 4 | _remove_worker_pids, _error_if_any_worker_fails 5 | from .sampler import SequentialSampler, RandomSampler, BatchSampler 6 | import signal 7 | import functools 8 | import collections 9 | import re 10 | import sys 11 | import threading 12 | import traceback 13 | from torch._six import string_classes, int_classes 14 | import numpy as np 15 | 16 | if sys.version_info[0] == 2: 17 | import Queue as queue 18 | else: 19 | import queue 20 | 21 | 22 | class ExceptionWrapper(object): 23 | r"Wraps an exception plus traceback to communicate across threads" 24 | 25 | def __init__(self, exc_info): 26 | self.exc_type = exc_info[0] 27 | self.exc_msg = "".join(traceback.format_exception(*exc_info)) 28 | 29 | 30 | _use_shared_memory = False 31 | """Whether to use shared memory in default_collate""" 32 | 33 | 34 | def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id): 35 | global _use_shared_memory 36 | _use_shared_memory = True 37 | 38 | # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal 39 | # module's handlers are executed after Python returns from C low-level 40 | # handlers, likely when the same fatal signal happened again already. 41 | # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1 42 | _set_worker_signal_handlers() 43 | 44 | torch.set_num_threads(1) 45 | torch.manual_seed(seed) 46 | np.random.seed(seed) 47 | 48 | if init_fn is not None: 49 | init_fn(worker_id) 50 | 51 | while True: 52 | r = index_queue.get() 53 | if r is None: 54 | break 55 | idx, batch_indices = r 56 | try: 57 | samples = collate_fn([dataset[i] for i in batch_indices]) 58 | except Exception: 59 | data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 60 | else: 61 | data_queue.put((idx, samples)) 62 | 63 | 64 | def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id): 65 | if pin_memory: 66 | torch.cuda.set_device(device_id) 67 | 68 | while True: 69 | try: 70 | r = in_queue.get() 71 | except Exception: 72 | if done_event.is_set(): 73 | return 74 | raise 75 | if r is None: 76 | break 77 | if isinstance(r[1], ExceptionWrapper): 78 | out_queue.put(r) 79 | continue 80 | idx, batch = r 81 | try: 82 | if pin_memory: 83 | batch = pin_memory_batch(batch) 84 | except Exception: 85 | out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 86 | else: 87 | out_queue.put((idx, batch)) 88 | 89 | numpy_type_map = { 90 | 'float64': torch.DoubleTensor, 91 | 'float32': torch.FloatTensor, 92 | 'float16': torch.HalfTensor, 93 | 'int64': torch.LongTensor, 94 | 'int32': torch.IntTensor, 95 | 'int16': torch.ShortTensor, 96 | 'int8': torch.CharTensor, 97 | 'uint8': torch.ByteTensor, 98 | } 99 | 100 | 101 | def default_collate(batch): 102 | "Puts each data field into a tensor with outer dimension batch size" 103 | 104 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" 105 | elem_type = type(batch[0]) 106 | if torch.is_tensor(batch[0]): 107 | out = None 108 | if _use_shared_memory: 109 | # If we're in a background process, concatenate directly into a 110 | # shared memory tensor to avoid an extra copy 111 | numel = sum([x.numel() for x in batch]) 112 | storage = batch[0].storage()._new_shared(numel) 113 | out = batch[0].new(storage) 114 | return torch.stack(batch, 0, out=out) 115 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 116 | and elem_type.__name__ != 'string_': 117 | elem = batch[0] 118 | if elem_type.__name__ == 'ndarray': 119 | # array of string classes and object 120 | if re.search('[SaUO]', elem.dtype.str) is not None: 121 | raise TypeError(error_msg.format(elem.dtype)) 122 | 123 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 124 | if elem.shape == (): # scalars 125 | py_type = float if elem.dtype.name.startswith('float') else int 126 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 127 | elif isinstance(batch[0], int_classes): 128 | return torch.LongTensor(batch) 129 | elif isinstance(batch[0], float): 130 | return torch.DoubleTensor(batch) 131 | elif isinstance(batch[0], string_classes): 132 | return batch 133 | elif isinstance(batch[0], collections.Mapping): 134 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]} 135 | elif isinstance(batch[0], collections.Sequence): 136 | transposed = zip(*batch) 137 | return [default_collate(samples) for samples in transposed] 138 | 139 | raise TypeError((error_msg.format(type(batch[0])))) 140 | 141 | 142 | def pin_memory_batch(batch): 143 | if torch.is_tensor(batch): 144 | return batch.pin_memory() 145 | elif isinstance(batch, string_classes): 146 | return batch 147 | elif isinstance(batch, collections.Mapping): 148 | return {k: pin_memory_batch(sample) for k, sample in batch.items()} 149 | elif isinstance(batch, collections.Sequence): 150 | return [pin_memory_batch(sample) for sample in batch] 151 | else: 152 | return batch 153 | 154 | 155 | _SIGCHLD_handler_set = False 156 | """Whether SIGCHLD handler is set for DataLoader worker failures. Only one 157 | handler needs to be set for all DataLoaders in a process.""" 158 | 159 | 160 | def _set_SIGCHLD_handler(): 161 | # Windows doesn't support SIGCHLD handler 162 | if sys.platform == 'win32': 163 | return 164 | # can't set signal in child threads 165 | if not isinstance(threading.current_thread(), threading._MainThread): 166 | return 167 | global _SIGCHLD_handler_set 168 | if _SIGCHLD_handler_set: 169 | return 170 | previous_handler = signal.getsignal(signal.SIGCHLD) 171 | if not callable(previous_handler): 172 | previous_handler = None 173 | 174 | def handler(signum, frame): 175 | # This following call uses `waitid` with WNOHANG from C side. Therefore, 176 | # Python can still get and update the process status successfully. 177 | _error_if_any_worker_fails() 178 | if previous_handler is not None: 179 | previous_handler(signum, frame) 180 | 181 | signal.signal(signal.SIGCHLD, handler) 182 | _SIGCHLD_handler_set = True 183 | 184 | 185 | class DataLoaderIter(object): 186 | "Iterates once over the DataLoader's dataset, as specified by the sampler" 187 | 188 | def __init__(self, loader): 189 | self.dataset = loader.dataset 190 | self.collate_fn = loader.collate_fn 191 | self.batch_sampler = loader.batch_sampler 192 | self.num_workers = loader.num_workers 193 | self.pin_memory = loader.pin_memory and torch.cuda.is_available() 194 | self.timeout = loader.timeout 195 | self.done_event = threading.Event() 196 | 197 | self.sample_iter = iter(self.batch_sampler) 198 | 199 | if self.num_workers > 0: 200 | self.worker_init_fn = loader.worker_init_fn 201 | self.index_queue = multiprocessing.SimpleQueue() 202 | self.worker_result_queue = multiprocessing.SimpleQueue() 203 | self.batches_outstanding = 0 204 | self.worker_pids_set = False 205 | self.shutdown = False 206 | self.send_idx = 0 207 | self.rcvd_idx = 0 208 | self.reorder_dict = {} 209 | 210 | base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0] 211 | self.workers = [ 212 | multiprocessing.Process( 213 | target=_worker_loop, 214 | args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn, 215 | base_seed + i, self.worker_init_fn, i)) 216 | for i in range(self.num_workers)] 217 | 218 | if self.pin_memory or self.timeout > 0: 219 | self.data_queue = queue.Queue() 220 | if self.pin_memory: 221 | maybe_device_id = torch.cuda.current_device() 222 | else: 223 | # do not initialize cuda context if not necessary 224 | maybe_device_id = None 225 | self.worker_manager_thread = threading.Thread( 226 | target=_worker_manager_loop, 227 | args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory, 228 | maybe_device_id)) 229 | self.worker_manager_thread.daemon = True 230 | self.worker_manager_thread.start() 231 | else: 232 | self.data_queue = self.worker_result_queue 233 | 234 | for w in self.workers: 235 | w.daemon = True # ensure that the worker exits on process exit 236 | w.start() 237 | 238 | _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) 239 | _set_SIGCHLD_handler() 240 | self.worker_pids_set = True 241 | 242 | # prime the prefetch loop 243 | for _ in range(2 * self.num_workers): 244 | self._put_indices() 245 | 246 | def __len__(self): 247 | return len(self.batch_sampler) 248 | 249 | def _get_batch(self): 250 | if self.timeout > 0: 251 | try: 252 | return self.data_queue.get(timeout=self.timeout) 253 | except queue.Empty: 254 | raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout)) 255 | else: 256 | return self.data_queue.get() 257 | 258 | def __next__(self): 259 | if self.num_workers == 0: # same-process loading 260 | indices = next(self.sample_iter) # may raise StopIteration 261 | batch = self.collate_fn([self.dataset[i] for i in indices]) 262 | if self.pin_memory: 263 | batch = pin_memory_batch(batch) 264 | return batch 265 | 266 | # check if the next sample has already been generated 267 | if self.rcvd_idx in self.reorder_dict: 268 | batch = self.reorder_dict.pop(self.rcvd_idx) 269 | return self._process_next_batch(batch) 270 | 271 | if self.batches_outstanding == 0: 272 | self._shutdown_workers() 273 | raise StopIteration 274 | 275 | while True: 276 | assert (not self.shutdown and self.batches_outstanding > 0) 277 | idx, batch = self._get_batch() 278 | self.batches_outstanding -= 1 279 | if idx != self.rcvd_idx: 280 | # store out-of-order samples 281 | self.reorder_dict[idx] = batch 282 | continue 283 | return self._process_next_batch(batch) 284 | 285 | next = __next__ # Python 2 compatibility 286 | 287 | def __iter__(self): 288 | return self 289 | 290 | def _put_indices(self): 291 | assert self.batches_outstanding < 2 * self.num_workers 292 | indices = next(self.sample_iter, None) 293 | if indices is None: 294 | return 295 | self.index_queue.put((self.send_idx, indices)) 296 | self.batches_outstanding += 1 297 | self.send_idx += 1 298 | 299 | def _process_next_batch(self, batch): 300 | self.rcvd_idx += 1 301 | self._put_indices() 302 | if isinstance(batch, ExceptionWrapper): 303 | raise batch.exc_type(batch.exc_msg) 304 | return batch 305 | 306 | def __getstate__(self): 307 | # TODO: add limited pickling support for sharing an iterator 308 | # across multiple threads for HOGWILD. 309 | # Probably the best way to do this is by moving the sample pushing 310 | # to a separate thread and then just sharing the data queue 311 | # but signalling the end is tricky without a non-blocking API 312 | raise NotImplementedError("DataLoaderIterator cannot be pickled") 313 | 314 | def _shutdown_workers(self): 315 | try: 316 | if not self.shutdown: 317 | self.shutdown = True 318 | self.done_event.set() 319 | # if worker_manager_thread is waiting to put 320 | while not self.data_queue.empty(): 321 | self.data_queue.get() 322 | for _ in self.workers: 323 | self.index_queue.put(None) 324 | # done_event should be sufficient to exit worker_manager_thread, 325 | # but be safe here and put another None 326 | self.worker_result_queue.put(None) 327 | finally: 328 | # removes pids no matter what 329 | if self.worker_pids_set: 330 | _remove_worker_pids(id(self)) 331 | self.worker_pids_set = False 332 | 333 | def __del__(self): 334 | if self.num_workers > 0: 335 | self._shutdown_workers() 336 | 337 | 338 | class DataLoader(object): 339 | """ 340 | Data loader. Combines a dataset and a sampler, and provides 341 | single- or multi-process iterators over the dataset. 342 | 343 | Arguments: 344 | dataset (Dataset): dataset from which to load the data. 345 | batch_size (int, optional): how many samples per batch to load 346 | (default: 1). 347 | shuffle (bool, optional): set to ``True`` to have the data reshuffled 348 | at every epoch (default: False). 349 | sampler (Sampler, optional): defines the strategy to draw samples from 350 | the dataset. If specified, ``shuffle`` must be False. 351 | batch_sampler (Sampler, optional): like sampler, but returns a batch of 352 | indices at a time. Mutually exclusive with batch_size, shuffle, 353 | sampler, and drop_last. 354 | num_workers (int, optional): how many subprocesses to use for data 355 | loading. 0 means that the data will be loaded in the main process. 356 | (default: 0) 357 | collate_fn (callable, optional): merges a list of samples to form a mini-batch. 358 | pin_memory (bool, optional): If ``True``, the data loader will copy tensors 359 | into CUDA pinned memory before returning them. 360 | drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, 361 | if the dataset size is not divisible by the batch size. If ``False`` and 362 | the size of dataset is not divisible by the batch size, then the last batch 363 | will be smaller. (default: False) 364 | timeout (numeric, optional): if positive, the timeout value for collecting a batch 365 | from workers. Should always be non-negative. (default: 0) 366 | worker_init_fn (callable, optional): If not None, this will be called on each 367 | worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as 368 | input, after seeding and before data loading. (default: None) 369 | 370 | .. note:: By default, each worker will have its PyTorch seed set to 371 | ``base_seed + worker_id``, where ``base_seed`` is a long generated 372 | by main process using its RNG. You may use ``torch.initial_seed()`` to access 373 | this value in :attr:`worker_init_fn`, which can be used to set other seeds 374 | (e.g. NumPy) before data loading. 375 | 376 | .. warning:: If ``spawn'' start method is used, :attr:`worker_init_fn` cannot be an 377 | unpicklable object, e.g., a lambda function. 378 | """ 379 | 380 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, 381 | num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, 382 | timeout=0, worker_init_fn=None): 383 | self.dataset = dataset 384 | self.batch_size = batch_size 385 | self.num_workers = num_workers 386 | self.collate_fn = collate_fn 387 | self.pin_memory = pin_memory 388 | self.drop_last = drop_last 389 | self.timeout = timeout 390 | self.worker_init_fn = worker_init_fn 391 | 392 | if timeout < 0: 393 | raise ValueError('timeout option should be non-negative') 394 | 395 | if batch_sampler is not None: 396 | if batch_size > 1 or shuffle or sampler is not None or drop_last: 397 | raise ValueError('batch_sampler is mutually exclusive with ' 398 | 'batch_size, shuffle, sampler, and drop_last') 399 | 400 | if sampler is not None and shuffle: 401 | raise ValueError('sampler is mutually exclusive with shuffle') 402 | 403 | if self.num_workers < 0: 404 | raise ValueError('num_workers cannot be negative; ' 405 | 'use num_workers=0 to disable multiprocessing.') 406 | 407 | if batch_sampler is None: 408 | if sampler is None: 409 | if shuffle: 410 | sampler = RandomSampler(dataset) 411 | else: 412 | sampler = SequentialSampler(dataset) 413 | batch_sampler = BatchSampler(sampler, batch_size, drop_last) 414 | 415 | self.sampler = sampler 416 | self.batch_sampler = batch_sampler 417 | 418 | def __iter__(self): 419 | return DataLoaderIter(self) 420 | 421 | def __len__(self): 422 | return len(self.batch_sampler) 423 | -------------------------------------------------------------------------------- /semseg/lib/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import warnings 3 | 4 | from torch._utils import _accumulate 5 | from torch import randperm 6 | 7 | 8 | class Dataset(object): 9 | """An abstract class representing a Dataset. 10 | 11 | All other datasets should subclass it. All subclasses should override 12 | ``__len__``, that provides the size of the dataset, and ``__getitem__``, 13 | supporting integer indexing in range from 0 to len(self) exclusive. 14 | """ 15 | 16 | def __getitem__(self, index): 17 | raise NotImplementedError 18 | 19 | def __len__(self): 20 | raise NotImplementedError 21 | 22 | def __add__(self, other): 23 | return ConcatDataset([self, other]) 24 | 25 | 26 | class TensorDataset(Dataset): 27 | """Dataset wrapping data and target tensors. 28 | 29 | Each sample will be retrieved by indexing both tensors along the first 30 | dimension. 31 | 32 | Arguments: 33 | data_tensor (Tensor): contains sample data. 34 | target_tensor (Tensor): contains sample targets (labels). 35 | """ 36 | 37 | def __init__(self, data_tensor, target_tensor): 38 | assert data_tensor.size(0) == target_tensor.size(0) 39 | self.data_tensor = data_tensor 40 | self.target_tensor = target_tensor 41 | 42 | def __getitem__(self, index): 43 | return self.data_tensor[index], self.target_tensor[index] 44 | 45 | def __len__(self): 46 | return self.data_tensor.size(0) 47 | 48 | 49 | class ConcatDataset(Dataset): 50 | """ 51 | Dataset to concatenate multiple datasets. 52 | Purpose: useful to assemble different existing datasets, possibly 53 | large-scale datasets as the concatenation operation is done in an 54 | on-the-fly manner. 55 | 56 | Arguments: 57 | datasets (iterable): List of datasets to be concatenated 58 | """ 59 | 60 | @staticmethod 61 | def cumsum(sequence): 62 | r, s = [], 0 63 | for e in sequence: 64 | l = len(e) 65 | r.append(l + s) 66 | s += l 67 | return r 68 | 69 | def __init__(self, datasets): 70 | super(ConcatDataset, self).__init__() 71 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 72 | self.datasets = list(datasets) 73 | self.cumulative_sizes = self.cumsum(self.datasets) 74 | 75 | def __len__(self): 76 | return self.cumulative_sizes[-1] 77 | 78 | def __getitem__(self, idx): 79 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 80 | if dataset_idx == 0: 81 | sample_idx = idx 82 | else: 83 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 84 | return self.datasets[dataset_idx][sample_idx] 85 | 86 | @property 87 | def cummulative_sizes(self): 88 | warnings.warn("cummulative_sizes attribute is renamed to " 89 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 90 | return self.cumulative_sizes 91 | 92 | 93 | class Subset(Dataset): 94 | def __init__(self, dataset, indices): 95 | self.dataset = dataset 96 | self.indices = indices 97 | 98 | def __getitem__(self, idx): 99 | return self.dataset[self.indices[idx]] 100 | 101 | def __len__(self): 102 | return len(self.indices) 103 | 104 | 105 | def random_split(dataset, lengths): 106 | """ 107 | Randomly split a dataset into non-overlapping new datasets of given lengths 108 | ds 109 | 110 | Arguments: 111 | dataset (Dataset): Dataset to be split 112 | lengths (iterable): lengths of splits to be produced 113 | """ 114 | if sum(lengths) != len(dataset): 115 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!") 116 | 117 | indices = randperm(sum(lengths)) 118 | return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)] 119 | -------------------------------------------------------------------------------- /semseg/lib/utils/data/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from .sampler import Sampler 4 | from torch.distributed import get_world_size, get_rank 5 | 6 | 7 | class DistributedSampler(Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | 10 | It is especially useful in conjunction with 11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 12 | process can pass a DistributedSampler instance as a DataLoader sampler, 13 | and load a subset of the original dataset that is exclusive to it. 14 | 15 | .. note:: 16 | Dataset is assumed to be of constant size. 17 | 18 | Arguments: 19 | dataset: Dataset used for sampling. 20 | num_replicas (optional): Number of processes participating in 21 | distributed training. 22 | rank (optional): Rank of the current process within num_replicas. 23 | """ 24 | 25 | def __init__(self, dataset, num_replicas=None, rank=None): 26 | if num_replicas is None: 27 | num_replicas = get_world_size() 28 | if rank is None: 29 | rank = get_rank() 30 | self.dataset = dataset 31 | self.num_replicas = num_replicas 32 | self.rank = rank 33 | self.epoch = 0 34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | 37 | def __iter__(self): 38 | # deterministically shuffle based on epoch 39 | g = torch.Generator() 40 | g.manual_seed(self.epoch) 41 | indices = list(torch.randperm(len(self.dataset), generator=g)) 42 | 43 | # add extra samples to make it evenly divisible 44 | indices += indices[:(self.total_size - len(indices))] 45 | assert len(indices) == self.total_size 46 | 47 | # subsample 48 | offset = self.num_samples * self.rank 49 | indices = indices[offset:offset + self.num_samples] 50 | assert len(indices) == self.num_samples 51 | 52 | return iter(indices) 53 | 54 | def __len__(self): 55 | return self.num_samples 56 | 57 | def set_epoch(self, epoch): 58 | self.epoch = epoch 59 | -------------------------------------------------------------------------------- /semseg/lib/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Sampler(object): 5 | """Base class for all Samplers. 6 | 7 | Every Sampler subclass has to provide an __iter__ method, providing a way 8 | to iterate over indices of dataset elements, and a __len__ method that 9 | returns the length of the returned iterators. 10 | """ 11 | 12 | def __init__(self, data_source): 13 | pass 14 | 15 | def __iter__(self): 16 | raise NotImplementedError 17 | 18 | def __len__(self): 19 | raise NotImplementedError 20 | 21 | 22 | class SequentialSampler(Sampler): 23 | """Samples elements sequentially, always in the same order. 24 | 25 | Arguments: 26 | data_source (Dataset): dataset to sample from 27 | """ 28 | 29 | def __init__(self, data_source): 30 | self.data_source = data_source 31 | 32 | def __iter__(self): 33 | return iter(range(len(self.data_source))) 34 | 35 | def __len__(self): 36 | return len(self.data_source) 37 | 38 | 39 | class RandomSampler(Sampler): 40 | """Samples elements randomly, without replacement. 41 | 42 | Arguments: 43 | data_source (Dataset): dataset to sample from 44 | """ 45 | 46 | def __init__(self, data_source): 47 | self.data_source = data_source 48 | 49 | def __iter__(self): 50 | return iter(torch.randperm(len(self.data_source)).long()) 51 | 52 | def __len__(self): 53 | return len(self.data_source) 54 | 55 | 56 | class SubsetRandomSampler(Sampler): 57 | """Samples elements randomly from a given list of indices, without replacement. 58 | 59 | Arguments: 60 | indices (list): a list of indices 61 | """ 62 | 63 | def __init__(self, indices): 64 | self.indices = indices 65 | 66 | def __iter__(self): 67 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 68 | 69 | def __len__(self): 70 | return len(self.indices) 71 | 72 | 73 | class WeightedRandomSampler(Sampler): 74 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights). 75 | 76 | Arguments: 77 | weights (list) : a list of weights, not necessary summing up to one 78 | num_samples (int): number of samples to draw 79 | replacement (bool): if ``True``, samples are drawn with replacement. 80 | If not, they are drawn without replacement, which means that when a 81 | sample index is drawn for a row, it cannot be drawn again for that row. 82 | """ 83 | 84 | def __init__(self, weights, num_samples, replacement=True): 85 | self.weights = torch.DoubleTensor(weights) 86 | self.num_samples = num_samples 87 | self.replacement = replacement 88 | 89 | def __iter__(self): 90 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement)) 91 | 92 | def __len__(self): 93 | return self.num_samples 94 | 95 | 96 | class BatchSampler(object): 97 | """Wraps another sampler to yield a mini-batch of indices. 98 | 99 | Args: 100 | sampler (Sampler): Base sampler. 101 | batch_size (int): Size of mini-batch. 102 | drop_last (bool): If ``True``, the sampler will drop the last batch if 103 | its size would be less than ``batch_size`` 104 | 105 | Example: 106 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False)) 107 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 108 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True)) 109 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 110 | """ 111 | 112 | def __init__(self, sampler, batch_size, drop_last): 113 | self.sampler = sampler 114 | self.batch_size = batch_size 115 | self.drop_last = drop_last 116 | 117 | def __iter__(self): 118 | batch = [] 119 | for idx in self.sampler: 120 | batch.append(idx) 121 | if len(batch) == self.batch_size: 122 | yield batch 123 | batch = [] 124 | if len(batch) > 0 and not self.drop_last: 125 | yield batch 126 | 127 | def __len__(self): 128 | if self.drop_last: 129 | return len(self.sampler) // self.batch_size 130 | else: 131 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 132 | -------------------------------------------------------------------------------- /semseg/lib/utils/th.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | import collections 5 | 6 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile'] 7 | 8 | def as_variable(obj): 9 | if isinstance(obj, Variable): 10 | return obj 11 | if isinstance(obj, collections.Sequence): 12 | return [as_variable(v) for v in obj] 13 | elif isinstance(obj, collections.Mapping): 14 | return {k: as_variable(v) for k, v in obj.items()} 15 | else: 16 | return Variable(obj) 17 | 18 | def as_numpy(obj): 19 | if isinstance(obj, collections.Sequence): 20 | return [as_numpy(v) for v in obj] 21 | elif isinstance(obj, collections.Mapping): 22 | return {k: as_numpy(v) for k, v in obj.items()} 23 | elif isinstance(obj, Variable): 24 | return obj.data.cpu().numpy() 25 | elif torch.is_tensor(obj): 26 | return obj.cpu().numpy() 27 | else: 28 | return np.array(obj) 29 | 30 | def mark_volatile(obj): 31 | if torch.is_tensor(obj): 32 | obj = Variable(obj) 33 | if isinstance(obj, Variable): 34 | obj.no_grad = True 35 | return obj 36 | elif isinstance(obj, collections.Mapping): 37 | return {k: mark_volatile(o) for k, o in obj.items()} 38 | elif isinstance(obj, collections.Sequence): 39 | return [mark_volatile(o) for o in obj] 40 | else: 41 | return obj 42 | -------------------------------------------------------------------------------- /semseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import ModelBuilder, SegmentationModule 2 | -------------------------------------------------------------------------------- /semseg/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This MobileNetV2 implementation is modified from the following repository: 3 | https://github.com/tonylins/pytorch-mobilenet-v2 4 | """ 5 | 6 | import os 7 | import sys 8 | import torch 9 | import torch.nn as nn 10 | import math 11 | from ..lib.nn import SynchronizedBatchNorm2d 12 | 13 | try: 14 | from urllib import urlretrieve 15 | except ImportError: 16 | from urllib.request import urlretrieve 17 | 18 | 19 | __all__ = ['mobilenetv2'] 20 | 21 | 22 | model_urls = { 23 | 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar', 24 | } 25 | 26 | 27 | def conv_bn(inp, oup, stride): 28 | return nn.Sequential( 29 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 30 | SynchronizedBatchNorm2d(oup), 31 | nn.ReLU6(inplace=True) 32 | ) 33 | 34 | 35 | def conv_1x1_bn(inp, oup): 36 | return nn.Sequential( 37 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 38 | SynchronizedBatchNorm2d(oup), 39 | nn.ReLU6(inplace=True) 40 | ) 41 | 42 | 43 | class InvertedResidual(nn.Module): 44 | def __init__(self, inp, oup, stride, expand_ratio): 45 | super(InvertedResidual, self).__init__() 46 | self.stride = stride 47 | assert stride in [1, 2] 48 | 49 | hidden_dim = round(inp * expand_ratio) 50 | self.use_res_connect = self.stride == 1 and inp == oup 51 | 52 | if expand_ratio == 1: 53 | self.conv = nn.Sequential( 54 | # dw 55 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 56 | SynchronizedBatchNorm2d(hidden_dim), 57 | nn.ReLU6(inplace=True), 58 | # pw-linear 59 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 60 | SynchronizedBatchNorm2d(oup), 61 | ) 62 | else: 63 | self.conv = nn.Sequential( 64 | # pw 65 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 66 | SynchronizedBatchNorm2d(hidden_dim), 67 | nn.ReLU6(inplace=True), 68 | # dw 69 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 70 | SynchronizedBatchNorm2d(hidden_dim), 71 | nn.ReLU6(inplace=True), 72 | # pw-linear 73 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 74 | SynchronizedBatchNorm2d(oup), 75 | ) 76 | 77 | def forward(self, x): 78 | if self.use_res_connect: 79 | return x + self.conv(x) 80 | else: 81 | return self.conv(x) 82 | 83 | 84 | class MobileNetV2(nn.Module): 85 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 86 | super(MobileNetV2, self).__init__() 87 | block = InvertedResidual 88 | input_channel = 32 89 | last_channel = 1280 90 | interverted_residual_setting = [ 91 | # t, c, n, s 92 | [1, 16, 1, 1], 93 | [6, 24, 2, 2], 94 | [6, 32, 3, 2], 95 | [6, 64, 4, 2], 96 | [6, 96, 3, 1], 97 | [6, 160, 3, 2], 98 | [6, 320, 1, 1], 99 | ] 100 | 101 | # building first layer 102 | assert input_size % 32 == 0 103 | input_channel = int(input_channel * width_mult) 104 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 105 | self.features = [conv_bn(3, input_channel, 2)] 106 | # building inverted residual blocks 107 | for t, c, n, s in interverted_residual_setting: 108 | output_channel = int(c * width_mult) 109 | for i in range(n): 110 | if i == 0: 111 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 112 | else: 113 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 114 | input_channel = output_channel 115 | # building last several layers 116 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 117 | # make it nn.Sequential 118 | self.features = nn.Sequential(*self.features) 119 | 120 | # building classifier 121 | self.classifier = nn.Sequential( 122 | nn.Dropout(0.2), 123 | nn.Linear(self.last_channel, n_class), 124 | ) 125 | 126 | self._initialize_weights() 127 | 128 | def forward(self, x): 129 | x = self.features(x) 130 | x = x.mean(3).mean(2) 131 | x = self.classifier(x) 132 | return x 133 | 134 | def _initialize_weights(self): 135 | for m in self.modules(): 136 | if isinstance(m, nn.Conv2d): 137 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 138 | m.weight.data.normal_(0, math.sqrt(2. / n)) 139 | if m.bias is not None: 140 | m.bias.data.zero_() 141 | elif isinstance(m, SynchronizedBatchNorm2d): 142 | m.weight.data.fill_(1) 143 | m.bias.data.zero_() 144 | elif isinstance(m, nn.Linear): 145 | n = m.weight.size(1) 146 | m.weight.data.normal_(0, 0.01) 147 | m.bias.data.zero_() 148 | 149 | 150 | def mobilenetv2(pretrained=False, **kwargs): 151 | """Constructs a MobileNet_V2 model. 152 | 153 | Args: 154 | pretrained (bool): If True, returns a model pre-trained on ImageNet 155 | """ 156 | model = MobileNetV2(n_class=1000, **kwargs) 157 | if pretrained: 158 | model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False) 159 | return model 160 | 161 | 162 | def load_url(url, model_dir='./pretrained', map_location=None): 163 | if not os.path.exists(model_dir): 164 | os.makedirs(model_dir) 165 | filename = url.split('/')[-1] 166 | cached_file = os.path.join(model_dir, filename) 167 | if not os.path.exists(cached_file): 168 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 169 | urlretrieve(url, cached_file) 170 | return torch.load(cached_file, map_location=map_location) 171 | 172 | -------------------------------------------------------------------------------- /semseg/models/resnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | from ..lib.nn import SynchronizedBatchNorm2d 7 | 8 | try: 9 | from urllib import urlretrieve 10 | except ImportError: 11 | from urllib.request import urlretrieve 12 | 13 | 14 | __all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101'] # resnet101 is coming soon! 15 | 16 | 17 | model_urls = { 18 | 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth', 19 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', 20 | 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth' 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1): 25 | "3x3 convolution with padding" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=1, bias=False) 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(inplanes, planes, stride) 36 | self.bn1 = SynchronizedBatchNorm2d(planes) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3(planes, planes) 39 | self.bn2 = SynchronizedBatchNorm2d(planes) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | residual = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | out += residual 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class Bottleneck(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = SynchronizedBatchNorm2d(planes) 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 70 | padding=1, bias=False) 71 | self.bn2 = SynchronizedBatchNorm2d(planes) 72 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 73 | self.bn3 = SynchronizedBatchNorm2d(planes * 4) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | residual = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv3(out) 90 | out = self.bn3(out) 91 | 92 | if self.downsample is not None: 93 | residual = self.downsample(x) 94 | 95 | out += residual 96 | out = self.relu(out) 97 | 98 | return out 99 | 100 | 101 | class ResNet(nn.Module): 102 | 103 | def __init__(self, block, layers, num_classes=1000): 104 | self.inplanes = 128 105 | super(ResNet, self).__init__() 106 | self.conv1 = conv3x3(3, 64, stride=2) 107 | self.bn1 = SynchronizedBatchNorm2d(64) 108 | self.relu1 = nn.ReLU(inplace=True) 109 | self.conv2 = conv3x3(64, 64) 110 | self.bn2 = SynchronizedBatchNorm2d(64) 111 | self.relu2 = nn.ReLU(inplace=True) 112 | self.conv3 = conv3x3(64, 128) 113 | self.bn3 = SynchronizedBatchNorm2d(128) 114 | self.relu3 = nn.ReLU(inplace=True) 115 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 116 | 117 | self.layer1 = self._make_layer(block, 64, layers[0]) 118 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 119 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 120 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 121 | self.avgpool = nn.AvgPool2d(7, stride=1) 122 | self.fc = nn.Linear(512 * block.expansion, num_classes) 123 | 124 | for m in self.modules(): 125 | if isinstance(m, nn.Conv2d): 126 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 127 | m.weight.data.normal_(0, math.sqrt(2. / n)) 128 | elif isinstance(m, SynchronizedBatchNorm2d): 129 | m.weight.data.fill_(1) 130 | m.bias.data.zero_() 131 | 132 | def _make_layer(self, block, planes, blocks, stride=1): 133 | downsample = None 134 | if stride != 1 or self.inplanes != planes * block.expansion: 135 | downsample = nn.Sequential( 136 | nn.Conv2d(self.inplanes, planes * block.expansion, 137 | kernel_size=1, stride=stride, bias=False), 138 | SynchronizedBatchNorm2d(planes * block.expansion), 139 | ) 140 | 141 | layers = [] 142 | layers.append(block(self.inplanes, planes, stride, downsample)) 143 | self.inplanes = planes * block.expansion 144 | for i in range(1, blocks): 145 | layers.append(block(self.inplanes, planes)) 146 | 147 | return nn.Sequential(*layers) 148 | 149 | def forward(self, x): 150 | x = self.relu1(self.bn1(self.conv1(x))) 151 | x = self.relu2(self.bn2(self.conv2(x))) 152 | x = self.relu3(self.bn3(self.conv3(x))) 153 | x = self.maxpool(x) 154 | 155 | x = self.layer1(x) 156 | x = self.layer2(x) 157 | x = self.layer3(x) 158 | x = self.layer4(x) 159 | 160 | x = self.avgpool(x) 161 | x = x.view(x.size(0), -1) 162 | x = self.fc(x) 163 | 164 | return x 165 | 166 | def resnet18(pretrained=False, **kwargs): 167 | """Constructs a ResNet-18 model. 168 | 169 | Args: 170 | pretrained (bool): If True, returns a model pre-trained on ImageNet 171 | """ 172 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 173 | if pretrained: 174 | model.load_state_dict(load_url(model_urls['resnet18'])) 175 | return model 176 | 177 | ''' 178 | def resnet34(pretrained=False, **kwargs): 179 | """Constructs a ResNet-34 model. 180 | 181 | Args: 182 | pretrained (bool): If True, returns a model pre-trained on ImageNet 183 | """ 184 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 185 | if pretrained: 186 | model.load_state_dict(load_url(model_urls['resnet34'])) 187 | return model 188 | ''' 189 | 190 | def resnet50(pretrained=False, **kwargs): 191 | """Constructs a ResNet-50 model. 192 | 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on ImageNet 195 | """ 196 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 197 | if pretrained: 198 | model.load_state_dict(load_url(model_urls['resnet50']), strict=False) 199 | return model 200 | 201 | 202 | def resnet101(pretrained=False, **kwargs): 203 | """Constructs a ResNet-101 model. 204 | 205 | Args: 206 | pretrained (bool): If True, returns a model pre-trained on ImageNet 207 | """ 208 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 209 | if pretrained: 210 | model.load_state_dict(load_url(model_urls['resnet101']), strict=False) 211 | return model 212 | 213 | # def resnet152(pretrained=False, **kwargs): 214 | # """Constructs a ResNet-152 model. 215 | # 216 | # Args: 217 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 218 | # """ 219 | # model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 220 | # if pretrained: 221 | # model.load_state_dict(load_url(model_urls['resnet152'])) 222 | # return model 223 | 224 | def load_url(url, model_dir='./pretrained', map_location=None): 225 | if not os.path.exists(model_dir): 226 | os.makedirs(model_dir) 227 | filename = url.split('/')[-1] 228 | cached_file = os.path.join(model_dir, filename) 229 | if not os.path.exists(cached_file): 230 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 231 | urlretrieve(url, cached_file) 232 | return torch.load(cached_file, map_location=map_location) 233 | -------------------------------------------------------------------------------- /semseg/models/resnext.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | from ..lib.nn import SynchronizedBatchNorm2d 7 | 8 | try: 9 | from urllib import urlretrieve 10 | except ImportError: 11 | from urllib.request import urlretrieve 12 | 13 | 14 | __all__ = ['ResNeXt', 'resnext101'] # support resnext 101 15 | 16 | 17 | model_urls = { 18 | #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth', 19 | 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth' 20 | } 21 | 22 | 23 | def conv3x3(in_planes, out_planes, stride=1): 24 | "3x3 convolution with padding" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 26 | padding=1, bias=False) 27 | 28 | 29 | class GroupBottleneck(nn.Module): 30 | expansion = 2 31 | 32 | def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None): 33 | super(GroupBottleneck, self).__init__() 34 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 35 | self.bn1 = SynchronizedBatchNorm2d(planes) 36 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 37 | padding=1, groups=groups, bias=False) 38 | self.bn2 = SynchronizedBatchNorm2d(planes) 39 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False) 40 | self.bn3 = SynchronizedBatchNorm2d(planes * 2) 41 | self.relu = nn.ReLU(inplace=True) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | residual = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv3(out) 57 | out = self.bn3(out) 58 | 59 | if self.downsample is not None: 60 | residual = self.downsample(x) 61 | 62 | out += residual 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | 68 | class ResNeXt(nn.Module): 69 | 70 | def __init__(self, block, layers, groups=32, num_classes=1000): 71 | self.inplanes = 128 72 | super(ResNeXt, self).__init__() 73 | self.conv1 = conv3x3(3, 64, stride=2) 74 | self.bn1 = SynchronizedBatchNorm2d(64) 75 | self.relu1 = nn.ReLU(inplace=True) 76 | self.conv2 = conv3x3(64, 64) 77 | self.bn2 = SynchronizedBatchNorm2d(64) 78 | self.relu2 = nn.ReLU(inplace=True) 79 | self.conv3 = conv3x3(64, 128) 80 | self.bn3 = SynchronizedBatchNorm2d(128) 81 | self.relu3 = nn.ReLU(inplace=True) 82 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 83 | 84 | self.layer1 = self._make_layer(block, 128, layers[0], groups=groups) 85 | self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups) 86 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups) 87 | self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups) 88 | self.avgpool = nn.AvgPool2d(7, stride=1) 89 | self.fc = nn.Linear(1024 * block.expansion, num_classes) 90 | 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv2d): 93 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups 94 | m.weight.data.normal_(0, math.sqrt(2. / n)) 95 | elif isinstance(m, SynchronizedBatchNorm2d): 96 | m.weight.data.fill_(1) 97 | m.bias.data.zero_() 98 | 99 | def _make_layer(self, block, planes, blocks, stride=1, groups=1): 100 | downsample = None 101 | if stride != 1 or self.inplanes != planes * block.expansion: 102 | downsample = nn.Sequential( 103 | nn.Conv2d(self.inplanes, planes * block.expansion, 104 | kernel_size=1, stride=stride, bias=False), 105 | SynchronizedBatchNorm2d(planes * block.expansion), 106 | ) 107 | 108 | layers = [] 109 | layers.append(block(self.inplanes, planes, stride, groups, downsample)) 110 | self.inplanes = planes * block.expansion 111 | for i in range(1, blocks): 112 | layers.append(block(self.inplanes, planes, groups=groups)) 113 | 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x): 117 | x = self.relu1(self.bn1(self.conv1(x))) 118 | x = self.relu2(self.bn2(self.conv2(x))) 119 | x = self.relu3(self.bn3(self.conv3(x))) 120 | x = self.maxpool(x) 121 | 122 | x = self.layer1(x) 123 | x = self.layer2(x) 124 | x = self.layer3(x) 125 | x = self.layer4(x) 126 | 127 | x = self.avgpool(x) 128 | x = x.view(x.size(0), -1) 129 | x = self.fc(x) 130 | 131 | return x 132 | 133 | 134 | ''' 135 | def resnext50(pretrained=False, **kwargs): 136 | """Constructs a ResNet-50 model. 137 | 138 | Args: 139 | pretrained (bool): If True, returns a model pre-trained on Places 140 | """ 141 | model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs) 142 | if pretrained: 143 | model.load_state_dict(load_url(model_urls['resnext50']), strict=False) 144 | return model 145 | ''' 146 | 147 | 148 | def resnext101(pretrained=False, **kwargs): 149 | """Constructs a ResNet-101 model. 150 | 151 | Args: 152 | pretrained (bool): If True, returns a model pre-trained on Places 153 | """ 154 | model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs) 155 | if pretrained: 156 | model.load_state_dict(load_url(model_urls['resnext101']), strict=False) 157 | return model 158 | 159 | 160 | # def resnext152(pretrained=False, **kwargs): 161 | # """Constructs a ResNeXt-152 model. 162 | # 163 | # Args: 164 | # pretrained (bool): If True, returns a model pre-trained on Places 165 | # """ 166 | # model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs) 167 | # if pretrained: 168 | # model.load_state_dict(load_url(model_urls['resnext152'])) 169 | # return model 170 | 171 | 172 | def load_url(url, model_dir='./pretrained', map_location=None): 173 | if not os.path.exists(model_dir): 174 | os.makedirs(model_dir) 175 | filename = url.split('/')[-1] 176 | cached_file = os.path.join(model_dir, filename) 177 | if not os.path.exists(cached_file): 178 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 179 | urlretrieve(url, cached_file) 180 | return torch.load(cached_file, map_location=map_location) 181 | -------------------------------------------------------------------------------- /ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 1 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | (_, channel, _, _) = img1.size() 49 | 50 | if channel == self.channel and self.window.data.type() == img1.data.type(): 51 | window = self.window 52 | else: 53 | window = create_window(self.window_size, channel) 54 | 55 | if img1.is_cuda: 56 | window = window.cuda(img1.get_device()) 57 | window = window.type_as(img1) 58 | 59 | self.window = window 60 | self.channel = channel 61 | 62 | 63 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 64 | 65 | def ssim(img1, img2, window_size = 11, size_average = True): 66 | (_, channel, _, _) = img1.size() 67 | window = create_window(window_size, channel) 68 | 69 | if img1.is_cuda: 70 | window = window.cuda(img1.get_device()) 71 | window = window.type_as(img1) 72 | 73 | return _ssim(img1, img2, window, window_size, channel, size_average) 74 | 75 | -------------------------------------------------------------------------------- /torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import os 4 | import re 5 | 6 | class JoinedDataLoader: 7 | """Loader for sampling from multiple loaders with probability proportional to their length. Stops when all loaders are exausthed. 8 | Useful in case you can't join samples of different datasets in a single batch. 9 | """ 10 | def __init__(self, loaderA, loaderB): 11 | self.loaderA = loaderA 12 | self.loaderB = loaderB 13 | self.probA = len(loaderA)/(len(loaderA)+len(loaderB)) 14 | self.loaderAiter, self.loaderBiter = iter(loaderA), iter(loaderB) 15 | 16 | def __iter__(self): 17 | return self 18 | 19 | def __next__(self): 20 | loader_choice = torch.rand(1).item() 21 | if loader_choice < self.probA: 22 | try: 23 | n = next(self.loaderAiter) 24 | except StopIteration: 25 | try: 26 | n = next(self.loaderBiter) 27 | except StopIteration: 28 | self._reset_iterators() 29 | raise StopIteration 30 | else: 31 | try: 32 | n = next(self.loaderBiter) 33 | except StopIteration: 34 | try: 35 | n = next(self.loaderAiter) 36 | except StopIteration: 37 | self._reset_iterators() 38 | raise StopIteration 39 | return n 40 | 41 | def _reset_iterators(self): 42 | self.loaderAiter, self.loaderBiter = iter(self.loaderA), iter(self.loaderB) 43 | 44 | def __len__(self): 45 | return len(self.loaderAiter) + len(self.loaderBiter) 46 | 47 | def conv_out_shape(dims, conv): 48 | """Computes the output shape for given convolution module 49 | Args: 50 | dims (tuples): a tuple of kind (w, h) 51 | conv (module): a pytorch convolutional module 52 | """ 53 | kernel_size, stride, pad, dilation = conv.kernel_size, conv.stride, conv.padding, conv.dilation 54 | return tuple(int(((dims[i] + (2 * pad[i]) - (dilation[i]*(kernel_size[i]-1))-1)/stride[i])+1) for i in range(len(dims))) 55 | 56 | def general_same_padding(i, k, d=1, s=1, dims=2): 57 | """Compute the padding to obtain the same output shape when using convolution 58 | Args: 59 | - input_size, kernel_size, dilation, stride (tuple or ints) 60 | - dims (int): number of dimensions for the padding 61 | """ 62 | #Convert i, k and d to tuples if they are int 63 | i = tuple([i for j in range(dims)]) if type(i) == int else i 64 | k = tuple([k for j in range(dims)]) if type(k) == int else k 65 | d = tuple([d for j in range(dims)]) if type(d) == int else d 66 | s = tuple([s for j in range(dims)]) if type(s) == int else s 67 | 68 | return tuple([int(0.5*(d[j]*(k[j]-1)-(1-i[j])*(s[j]-1))) for j in range(dims)]) 69 | 70 | def same_padding(k, d=1, dims=2): 71 | """Compute the padding to obtain the same output shape when using convolution, 72 | considering the case when the stride is unitary 73 | Args: 74 | - input_size, kernel_size, dilation, stride (tuple or ints) 75 | - dims (int): number of dimensions for the padding 76 | """ 77 | #Convert i, k and d to tuples if they are int 78 | k = tuple([k for j in range(dims)]) if type(k) == int else k 79 | d = tuple([d for j in range(dims)]) if type(d) == int else d 80 | 81 | return tuple([int(0.5*(d[j]*(k[j]-1))) for j in range(dims)]) 82 | 83 | def load_model(model,model_dir,run_tag): 84 | epoch = 0 85 | check = [f for f in os.listdir(model_dir) if f.startswith(run_tag)] 86 | if len(check)>0: 87 | epoch = max([int(re.findall('epoch\d+',c)[0][5:]) for c in check]) 88 | model.load_state_dict(torch.load(os.path.join(model_dir,run_tag+'_epoch'+str(epoch)+'.pt'))) 89 | print("Resuming trainig from epoch %d" % epoch) 90 | return model, epoch 91 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | from scipy.ndimage.filters import gaussian_filter 4 | from skimage import img_as_float 5 | 6 | def original(image): 7 | return image 8 | 9 | def _unsharp_mask_single_channel(image, radius, amount, vrange): 10 | """Single channel implementation of the unsharp masking filter.""" 11 | 12 | blurred = gaussian_filter(image, 13 | sigma=radius, 14 | mode='reflect') 15 | 16 | result = image + (image - blurred) * amount 17 | if vrange is not None: 18 | return np.clip(result, vrange[0], vrange[1], out=result) 19 | return result 20 | 21 | 22 | def unsharp_mask(image, radius=1.0, amount=1.0, multichannel=False, 23 | preserve_range=False): 24 | """Unsharp masking filter. 25 | The sharp details are identified as the difference between the original 26 | image and its blurred version. These details are then scaled, and added 27 | back to the original image. 28 | Parameters 29 | ---------- 30 | image : [P, ..., ]M[, N][, C] ndarray 31 | Input image. 32 | radius : scalar or sequence of scalars, optional 33 | If a scalar is given, then its value is used for all dimensions. 34 | If sequence is given, then there must be exactly one radius 35 | for each dimension except the last dimension for multichannel images. 36 | Note that 0 radius means no blurring, and negative values are 37 | not allowed. 38 | amount : scalar, optional 39 | The details will be amplified with this factor. The factor could be 0 40 | or negative. Typically, it is a small positive number, e.g. 1.0. 41 | multichannel : bool, optional 42 | If True, the last `image` dimension is considered as a color channel, 43 | otherwise as spatial. Color channels are processed individually. 44 | preserve_range: bool, optional 45 | Whether to keep the original range of values. Otherwise, the input 46 | image is converted according to the conventions of `img_as_float`. 47 | Also see https://scikit-image.org/docs/dev/user_guide/data_types.html 48 | Returns 49 | ------- 50 | output : [P, ..., ]M[, N][, C] ndarray of float 51 | Image with unsharp mask applied. 52 | Notes 53 | ----- 54 | Unsharp masking is an image sharpening technique. It is a linear image 55 | operation, and numerically stable, unlike deconvolution which is an 56 | ill-posed problem. Because of this stability, it is often 57 | preferred over deconvolution. 58 | The main idea is as follows: sharp details are identified as the 59 | difference between the original image and its blurred version. 60 | These details are added back to the original image after a scaling step: 61 | enhanced image = original + amount * (original - blurred) 62 | When applying this filter to several color layers independently, 63 | color bleeding may occur. More visually pleasing result can be 64 | achieved by processing only the brightness/lightness/intensity 65 | channel in a suitable color space such as HSV, HSL, YUV, or YCbCr. 66 | Unsharp masking is described in most introductory digital image 67 | processing books. This implementation is based on [1]_. 68 | Examples 69 | -------- 70 | >>> array = np.ones(shape=(5,5), dtype=np.uint8)*100 71 | >>> array[2,2] = 120 72 | >>> array 73 | array([[100, 100, 100, 100, 100], 74 | [100, 100, 100, 100, 100], 75 | [100, 100, 120, 100, 100], 76 | [100, 100, 100, 100, 100], 77 | [100, 100, 100, 100, 100]], dtype=uint8) 78 | >>> np.around(unsharp_mask(array, radius=0.5, amount=2),2) 79 | array([[ 0.39, 0.39, 0.39, 0.39, 0.39], 80 | [ 0.39, 0.39, 0.38, 0.39, 0.39], 81 | [ 0.39, 0.38, 0.53, 0.38, 0.39], 82 | [ 0.39, 0.39, 0.38, 0.39, 0.39], 83 | [ 0.39, 0.39, 0.39, 0.39, 0.39]]) 84 | >>> array = np.ones(shape=(5,5), dtype=np.int8)*100 85 | >>> array[2,2] = 127 86 | >>> np.around(unsharp_mask(array, radius=0.5, amount=2),2) 87 | array([[ 0.79, 0.79, 0.79, 0.79, 0.79], 88 | [ 0.79, 0.78, 0.75, 0.78, 0.79], 89 | [ 0.79, 0.75, 1. , 0.75, 0.79], 90 | [ 0.79, 0.78, 0.75, 0.78, 0.79], 91 | [ 0.79, 0.79, 0.79, 0.79, 0.79]]) 92 | >>> np.around(unsharp_mask(array, radius=0.5, amount=2, preserve_range=True), 2) 93 | array([[ 100. , 100. , 99.99, 100. , 100. ], 94 | [ 100. , 99.39, 95.48, 99.39, 100. ], 95 | [ 99.99, 95.48, 147.59, 95.48, 99.99], 96 | [ 100. , 99.39, 95.48, 99.39, 100. ], 97 | [ 100. , 100. , 99.99, 100. , 100. ]]) 98 | References 99 | ---------- 100 | .. [1] Maria Petrou, Costas Petrou 101 | "Image Processing: The Fundamentals", (2010), ed ii., page 357, 102 | ISBN 13: 9781119994398 :DOI:`10.1002/9781119994398` 103 | .. [2] Wikipedia. Unsharp masking 104 | https://en.wikipedia.org/wiki/Unsharp_masking 105 | """ 106 | vrange = None # Range for valid values; used for clipping. 107 | if preserve_range: 108 | fimg = image.astype(np.float) 109 | else: 110 | fimg = img_as_float(image) 111 | negative = np.any(fimg < 0) 112 | if negative: 113 | vrange = [-1., 1.] 114 | else: 115 | vrange = [0., 1.] 116 | 117 | if multichannel: 118 | result = np.empty_like(fimg, dtype=np.float) 119 | for channel in range(image.shape[-1]): 120 | result[..., channel] = _unsharp_mask_single_channel( 121 | fimg[..., channel], radius, amount, vrange) 122 | return result 123 | else: 124 | return _unsharp_mask_single_channel(fimg, radius, amount, vrange) 125 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | def display_transforms(im, ops, figsize=(12, 18)): 5 | """ 6 | Apply all the transforms in ops to the image and display the result 7 | 8 | Args: 9 | - im (array): image to be processed 10 | - ops (dict): dictionary of operations to be applied 11 | - figsize (tuple): tuple with size of the figure 12 | """ 13 | n_rows = (len(ops)+1)#//2 + (0 if (len(ops)+1)%2 == 0 else 1) 14 | n_cols = 2#n_rows 15 | plt.figure(figsize=figsize) 16 | 17 | for i, op_name in enumerate(ops): 18 | plt.subplot(n_rows, n_cols, i*2+1) 19 | plt.imshow(ops[op_name](im)) 20 | plt.axis('off') 21 | plt.title(op_name) 22 | # plot the histogram 23 | plt.subplot(n_rows, n_cols, i*2+2) 24 | plt.hist(ops[op_name](im).ravel(),256) 25 | --------------------------------------------------------------------------------