# Repository layout (flattened dump of the pytorch_toolbox repository):
# ├── .gitignore
# ├── README.md
# ├── crossentropy2d.py
# ├── cv_dataset.py
# ├── focal_loss_wrong_implementation.py
# ├── focalloss2d.py
# ├── image_loader.py
# ├── numpy_loader.py
# ├── numpy_transforms.py
# ├── png_loader.py
# └── utils.py

# ---------------------------------------------------------------------------
# /.gitignore
# ---------------------------------------------------------------------------
# focal_loss_2d.py
# focal_loss_multiclass.py
# __pycache__

# ---------------------------------------------------------------------------
# /README.md  (translated from Japanese)
# ---------------------------------------------------------------------------
# # pytorch_toolbox
# A repository of handy tools for PyTorch.
#
# ### focal_loss_multiclass.py
# Focal loss class for multi-class problems.
# gamma, weight (corresponds to alpha in the original paper) and size_average
# can be set through the constructor arguments.
#
# ### numpy_loader.py
# Data loader for datasets stored as .npy files.

# ---------------------------------------------------------------------------
# /crossentropy2d.py
# ---------------------------------------------------------------------------
####################################################
##### 2-D (per-pixel) cross entropy loss class #####
##### University of Tokyo Doi Kento            #####
####################################################

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class CrossEntropy2d(nn.Module):
    """Cross-entropy loss for dense (e.g. per-pixel) predictions.

    Applies ``log_softmax`` to the raw scores along ``dim``, then the
    negative log-likelihood loss against integer class targets.

    Args:
        dim (int): dimension along which ``log_softmax`` is computed.
        weight (Tensor, optional): per-class rescaling weight for the NLL loss.
        size_average (bool): if True the loss is averaged, otherwise summed.
        ignore_index (int): target value that does not contribute to the loss.
    """

    def __init__(self, dim=1, weight=None, size_average=True, ignore_index=-100):
        super(CrossEntropy2d, self).__init__()
        self.dim = dim
        self.weight = weight
        self.size_average = size_average
        self.ignore_index = ignore_index

    def forward(self, input, target):
        # F.nll_loss with an explicit reduction replaces the deprecated
        # nn.NLLLoss2d / size_average API (and avoids building a module per call).
        reduction = 'mean' if self.size_average else 'sum'
        return F.nll_loss(
            F.log_softmax(input, dim=self.dim),
            target,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=reduction,
        )


# ---------------------------------------------------------------------------
# /cv_dataset.py
# ---------------------------------------------------------------------------
##### ################################# #####
##### This code is written by Doi Kento #####
##### ################################# #####

import torch
import torch.utils.data as data_utils
import numpy as np
from pathlib import Path


class CV_Dataset(data_utils.Dataset):
    """Dataset for k-fold cross validation driven by fold list files.

    ``<data_path>/train`` must contain ``cv_1.txt`` ... ``cv_<n_folds>.txt``;
    every non-empty line of a fold file is ``<image file name>,<integer label>``
    and the image file lives in ``<data_path>/train``.

    Args:
        data_path (str): path of the root directory.
        cv_num (int): index of the fold held out for validation (1-based).
        phase (str): 'train' uses every fold except ``cv_num``;
            'val' uses only fold ``cv_num``.
        transform (callable, optional): applied to the image array in
            ``__getitem__``.
        n_folds (int): number of fold files (default 5).
    Returns (per item):
        img (array): numpy array of the image (after ``transform``, if any).
        label (int): integer label read from the fold file.
    Raises:
        ValueError: if ``phase`` is neither 'train' nor 'val'.
    """

    def __init__(self, data_path, cv_num=1, phase='train', transform=None, n_folds=5):
        if phase not in ('train', 'val'):
            raise ValueError("phase must be 'train' or 'val', got {!r}".format(phase))
        self.root = Path(data_path)
        self.train_dir = self.root / 'train'
        self.cv_num = cv_num
        self.phase = phase
        self.transform = transform
        self.imgs = []
        for fold in range(1, n_folds + 1):
            held_out = fold == cv_num
            use_fold = held_out if phase == 'val' else not held_out
            if not use_fold:
                continue
            with open(self.train_dir / 'cv_{}.txt'.format(fold)) as fold_file:
                for line in fold_file:
                    line = line.strip()
                    if not line:  # tolerate blank / trailing lines
                        continue
                    img_name, label = line.split(',')
                    self.imgs.append([img_name, int(label)])

    def __getitem__(self, idx):
        # Imported lazily so the dataset (fold parsing, len) can be built
        # without the tifffile package installed.
        import tifffile

        img_name, label = self.imgs[idx]
        img = tifffile.imread(str(self.train_dir / img_name))
        if self.transform is not None:
            img = self.transform(img)
        return (img, label)

    def __len__(self):
        return len(self.imgs)


def normalize(img, mean, std):
    """Return ``(img - mean) / std`` (element-wise standardisation)."""
    return (img - mean) / std


# ---------------------------------------------------------------------------
# /focal_loss_wrong_implementation.py
# ---------------------------------------------------------------------------
####################################################
##### This is focal loss class for multi class #####
##### University of Tokyo Doi Kento #####
####################################################

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
# I referred to https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py


class MultiClassFocalLoss(nn.Module):
    """Multi-class focal loss.

    Per element: ``-w[t] * (1 - p_t) ** gamma * log(p_t)``, where ``p_t`` is the
    softmax probability of the target class ``t``.

    Args:
        gamma (float): focusing parameter; 0 gives (weighted) cross entropy.
        weight (Tensor, optional): per-class weights (alpha in the paper),
            indexed by the target class of each element.
        size_average (bool): mean over elements if True, otherwise sum.
    """

    def __init__(self, gamma=0, weight=None, size_average=True):
        super(MultiClassFocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight
        self.size_average = size_average

    def forward(self, input, target):
        """
        Args:
            input (Tensor): raw class scores, (N, C) or (N, C, d1, d2, ...).
            target (Tensor): class indices, (N,), (N, d1, ...) or (N, 1, d1, ...).
        """
        if input.dim() > 2:
            # (N, C, d1, ...) -> (N * d1 * ..., C), rows ordered as (n, d1, ...)
            input = input.contiguous().view(input.size(0), input.size(1), -1)
            input = input.transpose(1, 2)
            input = input.contiguous().view(-1, input.size(2))
        # A plain row-major flatten of (N, d1, ...) or (N, 1, d1, ...) targets
        # yields exactly the row order produced above.
        target = target.contiguous().view(-1, 1).long()

        # log-probability / probability of the target class for each row
        logpt = F.log_softmax(input, dim=1).gather(1, target).view(-1)
        pt = logpt.exp()

        loss = -((1 - pt) ** self.gamma) * logpt
        if self.weight is not None:
            weight = self.weight.to(dtype=input.dtype, device=input.device)
            loss = loss * weight[target.view(-1)]

        if self.size_average:
            return loss.mean()
        return loss.sum()


# ---------------------------------------------------------------------------
# /focalloss2d.py
# ---------------------------------------------------------------------------
####################################################
##### This is focal loss class for multi class #####
##### University of Tokyo Doi Kento #####
####################################################

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
# I referred to https://github.com/c0nn3r/RetinaNet/blob/master/focal_loss.py


class FocalLoss2d(nn.Module):
    """Focal loss for (N, C) or dense (N, C, d1, ...) predictions.

    Per element: ``-w[t] * (1 - p_t) ** gamma * log(p_t)``.

    Args:
        gamma (float): focusing parameter; 0 gives (weighted) cross entropy.
        weight (Tensor, optional): per-class weights (alpha in the paper).
        size_average (bool): mean over elements if True, otherwise sum.
    """

    def __init__(self, gamma=0, weight=None, size_average=True):
        super(FocalLoss2d, self).__init__()
        self.gamma = gamma
        self.weight = weight
        self.size_average = size_average

    def forward(self, input, target):
        """
        Args:
            input (Tensor): raw class scores, (N, C) or (N, C, d1, d2, ...).
            target (Tensor): class indices, (N,), (N, d1, ...) or (N, 1, d1, ...).
        """
        if input.dim() > 2:
            # (N, C, d1, ...) -> (N * d1 * ..., C); no squeeze(), which would
            # collapse a single-row input to 1-D.
            input = input.contiguous().view(input.size(0), input.size(1), -1)
            input = input.transpose(1, 2)
            input = input.contiguous().view(-1, input.size(2))
        target = target.contiguous().view(-1).long()

        # log p_t must be computed per element: reducing the cross entropy first
        # would apply (1 - p_t) ** gamma to the batch mean instead of each element.
        logpt = F.log_softmax(input, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)
        pt = logpt.exp()

        loss = -((1 - pt) ** self.gamma) * logpt
        if self.weight is not None:
            weight = self.weight.to(dtype=input.dtype, device=input.device)
            loss = loss * weight[target]

        if self.size_average:
            return loss.mean()
        return loss.sum()


# ---------------------------------------------------------------------------
# /image_loader.py
# ---------------------------------------------------------------------------
import numpy as np
import random
import torch
import torch.utils.data as data_utils
import glob
import os


class ImageDataset(data_utils.Dataset):
    """Dataset over every file in a directory, opened with PIL.

    Args:
        img_dir (str): path of the image directory.
        transform (callable, optional): applied to each PIL image.
    """

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.img_list = sorted(glob.glob(os.path.join(img_dir, '*')))
        self.transform = transform

    def __getitem__(self, idx):
        # Pillow is imported lazily so the numpy-only tools stay importable
        # without it.
        from PIL import Image

        img = Image.open(self.img_list[idx])
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.img_list)


# ---------------------------------------------------------------------------
# /numpy_loader.py
# ---------------------------------------------------------------------------
import numpy as np
import os
import random
import torch
import torch.utils.data as data_utils


class Numpy_SegmentationDataset(data_utils.Dataset):
    """Segmentation dataset of paired ``.npy`` image / ground-truth arrays.

    Images and GT files are paired by position after sorting each directory's
    file names, so both directories should hold matching file names.

    Args:
        img_dir (str): path of the image directory.
        GT_dir (str): path of the ground-truth directory.
        transform (callable, optional): called with ``[image_arr, GT_arr]`` and
            must return the transformed pair.
    """

    def __init__(self, img_dir, GT_dir, transform=None):
        self.img_dir = img_dir
        self.GT_dir = GT_dir
        self.img_list = sorted(os.path.join(img_dir, name) for name in os.listdir(img_dir))
        self.GT_list = sorted(os.path.join(GT_dir, name) for name in os.listdir(GT_dir))
        self.transform = transform

    def __getitem__(self, idx):
        arr_list = [np.load(self.img_list[idx]), np.load(self.GT_list[idx])]
        if self.transform:
            arr_list = self.transform(arr_list)
        image_arr, GT_arr = arr_list
        return image_arr, GT_arr

    def __len__(self):
        return len(self.img_list)


class RandomCrop_Segmentation(object):
    """Randomly crop an image and its label with the same square window.

    The last two axes are the spatial (height, width) axes, so (C, H, W) and
    (H, W) arrays are both accepted.

    Args:
        output_size (int): side length of the square crop.
    Raises:
        ValueError: if the crop is larger than the image.
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, arr_list):
        image_arr, GT_arr = arr_list
        h, w = image_arr.shape[-2:]
        new_h = new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError(
                'crop size {} exceeds array size ({}, {})'.format(self.output_size, h, w))

        # randint's upper bound is exclusive: +1 so every valid offset can be
        # drawn, including 0 when the crop equals the array size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image_arr = image_arr[..., top:top + new_h, left:left + new_w]
        GT_arr = GT_arr[..., top:top + new_h, left:left + new_w]
        return [image_arr, GT_arr]


class Flip_Segmentation(object):
    """Randomly flip an image and its label together along each spatial axis.

    Args:
        p (float): probability of flipping along each of the last two axes.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, arr_list):
        image_arr, GT_arr = arr_list
        # an independent draw per axis decides each flip with probability p
        for axis in (-2, -1):
            if random.random() < self.p:
                image_arr = np.flip(image_arr, axis).copy()
                GT_arr = np.flip(GT_arr, axis).copy()
        return [image_arr, GT_arr]


class Rotate_Segmentation(object):
    """Rotate an image and its label together by a random multiple of 90 degrees
    in the plane of the last two axes."""

    def __call__(self, arr_list):
        image_arr, GT_arr = arr_list
        k = random.randrange(4)
        image_arr = np.rot90(image_arr, k, axes=(-2, -1)).copy()
        GT_arr = np.rot90(GT_arr, k, axes=(-2, -1)).copy()
        return [image_arr, GT_arr]


# ---------------------------------------------------------------------------
# /numpy_transforms.py
# ---------------------------------------------------------------------------
import torch

import numpy as np
import random


class Numpy_Flip(object):
    """Randomly flip a 3-D array along axis 2 and axis 1 (each with probability p)."""

    def __call__(self, image_array, p=0.5):
        if random.random() < p:
            image_array = image_array[:, :, ::-1]
        if random.random() < p:
            image_array = image_array[:, ::-1, :]
        # copy() materialises the (possibly flipped) view as a fresh array
        return image_array.copy()


class Numpy_Rotate(object):
    """Rotate a 3-D array by a random multiple of 90 degrees in the (1, 2) plane."""

    def __call__(self, image_array):
        k = random.randrange(4)
        return np.rot90(image_array, k, (1, 2)).copy()


# ---------------------------------------------------------------------------
# /png_loader.py
# ---------------------------------------------------------------------------
import numpy as np
import os
import random
import torch
import torch.utils.data as data_utils


class PNG_SegmentationDataset(data_utils.Dataset):
    """Segmentation dataset of paired image / ground-truth files read with PIL.

    Images and GT files are paired by position after sorting each directory's
    file names. Arrays are returned as ``np.array(Image.open(path))``, i.e. in
    PIL's layout ((H, W) or (H, W, C)).
    NOTE(review): the transforms below treat the last two axes as spatial, so
    multi-channel images presumably need to be made channel-first before they
    are applied — confirm the intended layout.

    Args:
        img_dir (str): path of the image directory.
        GT_dir (str): path of the ground-truth directory.
        transform (callable, optional): called with ``[img, GT]`` and must
            return the transformed pair.
    """

    def __init__(self, img_dir, GT_dir, transform=None):
        self.img_dir = img_dir
        self.GT_dir = GT_dir
        self.img_list = sorted(os.path.join(img_dir, name) for name in os.listdir(img_dir))
        self.GT_list = sorted(os.path.join(GT_dir, name) for name in os.listdir(GT_dir))
        self.transform = transform

    def __getitem__(self, idx):
        # Pillow is imported lazily so the numpy-only tools stay importable
        # without it.
        from PIL import Image

        # `with` closes each file as soon as its pixels are copied into numpy.
        with Image.open(self.img_list[idx]) as pil_img:
            img = np.array(pil_img)
        with Image.open(self.GT_list[idx]) as pil_gt:
            GT = np.array(pil_gt)

        arr_list = [img, GT]
        if self.transform:
            arr_list = self.transform(arr_list)
        img, GT = arr_list
        return img, GT

    def __len__(self):
        return len(self.img_list)


class RandomCrop_Segmentation(object):
    """Randomly crop an image and its label with the same square window.

    The last two axes are the spatial (height, width) axes, so (C, H, W) and
    (H, W) arrays are both accepted.

    Args:
        output_size (int): side length of the square crop.
    Raises:
        ValueError: if the crop is larger than the image.
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, arr_list):
        image_arr, GT_arr = arr_list
        h, w = image_arr.shape[-2:]
        new_h = new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError(
                'crop size {} exceeds array size ({}, {})'.format(self.output_size, h, w))

        # randint's upper bound is exclusive: +1 so every valid offset can be
        # drawn, including 0 when the crop equals the array size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image_arr = image_arr[..., top:top + new_h, left:left + new_w]
        GT_arr = GT_arr[..., top:top + new_h, left:left + new_w]
        return [image_arr, GT_arr]


class Flip_Segmentation(object):
    """Randomly flip an image and its label together along each spatial axis.

    Args:
        p (float): probability of flipping along each of the last two axes.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, arr_list):
        image_arr, GT_arr = arr_list
        # an independent draw per axis decides each flip with probability p
        for axis in (-2, -1):
            if random.random() < self.p:
                image_arr = np.flip(image_arr, axis).copy()
                GT_arr = np.flip(GT_arr, axis).copy()
        return [image_arr, GT_arr]


class Rotate_Segmentation(object):
    """Rotate an image and its label together by a random multiple of 90 degrees
    in the plane of the last two axes."""

    def __call__(self, arr_list):
        image_arr, GT_arr = arr_list
        k = random.randrange(4)
        image_arr = np.rot90(image_arr, k, axes=(-2, -1)).copy()
        GT_arr = np.rot90(GT_arr, k, axes=(-2, -1)).copy()
        return [image_arr, GT_arr]


# ---------------------------------------------------------------------------
# /utils.py
# ---------------------------------------------------------------------------
import csv
from pathlib import Path


def writeout_args(args, out_dir):
    """Write every public attribute of ``args`` to ``<out_dir>/arguments.csv``.

    Each CSV row is ``[name, str(value)]``; the same pairs are echoed to stdout.
    Attributes are visited in ``dir(args)`` order, skipping names that start
    with an underscore.
    """
    # newline='' lets the csv module control line endings; `with` guarantees
    # the file is flushed and closed.
    with open(Path(out_dir).joinpath('arguments.csv'), 'w', newline='') as fout:
        csvout = csv.writer(fout)
        print('*' * 20)
        print('Write out Arguments...')
        print('*' * 20)
        for arg in dir(args):
            if arg.startswith('_'):
                continue
            value = str(getattr(args, arg))
            csvout.writerow([arg, value])
            print('%-25s %-25s' % (arg, value))