├── _config.yml ├── images ├── gen.jpg ├── conc_data.png └── framework.jpg ├── model ├── models.py ├── pretrained_resnet.py ├── fusion.py ├── base_model.py ├── networks.py └── trecg_model.py ├── util ├── average_meter.py ├── confusion_matrix.py ├── splitImages.py ├── conc_modalities.py └── utils.py ├── train.sh ├── data ├── __init__.py └── aligned_conc_dataset.py ├── config ├── evalute_resnet18_config.py ├── resnet18_nyud2_config.py ├── default_config.py └── resnet_sunrgbd_config.py ├── evaluate.py ├── README.md ├── train_fusion.py └── train.py /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /images/gen.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ownstyledu/Translate-to-Recognize-Networks/HEAD/images/gen.jpg -------------------------------------------------------------------------------- /images/conc_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ownstyledu/Translate-to-Recognize-Networks/HEAD/images/conc_data.png -------------------------------------------------------------------------------- /images/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ownstyledu/Translate-to-Recognize-Networks/HEAD/images/framework.jpg -------------------------------------------------------------------------------- /model/models.py: -------------------------------------------------------------------------------- 1 | def create_model(cfg, writer=None): 2 | 3 | model_name = cfg.MODEL 4 | print(model_name) 5 | if model_name == 'trecg': 6 | from .trecg_model import TRecgNet 7 | model = TRecgNet(cfg, writer) 8 | elif model_name == 'fusion': 9 | from .fusion import Fusion 10 | model = Fusion(cfg, writer) 11 | else: 12 | raise ValueError('Model {0} not recognized'.format(model_name)) 13 | return model -------------------------------------------------------------------------------- /util/average_meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | def __init__(self): 4 | self.reset() 5 | 6 | def reset(self): 7 | self.val = 0 8 | self.avg = 0 9 | self.sum = 0 10 | self.count = 0 11 | 12 | def update(self, val, n=1): 13 | self.val = val 14 | self.sum += val * n 15 | self.count += n 16 | self.avg = round(self.sum / self.count, 3) -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | mkdir -p logs 2 | 3 | echo Using GPU Device "$1" 4 | export PYTHONUNBUFFERED="True" 5 | 6 | LOG="logs/traininginfo.txt.`date +'%Y-%m-%d_%H-%M-%S'`" 7 | exec &> >(tee -a "$LOG") 8 | echo Logging to "$LOG" 9 | 10 | 11 | starttime=`date +'%Y-%m-%d %H:%M:%S'` 12 | 13 | cat train.sh 14 | for i in $(seq 1 2) 15 | do 16 | 17 | ######## TrecgNet ############ 18 | python train.py 19 | 20 | done 21 | 22 | echo "------------" 23 | endtime=`date +'%Y-%m-%d %H:%M:%S'` 24 | start_seconds=$(date --date="$starttime" +%s); 25 | end_seconds=$(date --date="$endtime" +%s); 26 | echo "start time: "$((end_seconds-start_seconds))"s" 27 | -------------------------------------------------------------------------------- 
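For reference, a minimal usage sketch of the `AverageMeter` shown above, assuming the repository root is on `PYTHONPATH`; the loop and loss values are hypothetical:

```python
from util.average_meter import AverageMeter

losses = AverageMeter()
# toy loop: three batches with made-up per-batch loss values
for step, batch_size in enumerate([32, 32, 16]):
    loss_value = 1.0 / (step + 1)
    losses.update(loss_value, n=batch_size)  # running sum is weighted by the batch size
print(losses.val)   # most recent value (~0.333)
print(losses.avg)   # batch-size-weighted running mean, rounded to 3 decimals (0.667)
```

The same meter instances are what the training code reads back via `loss_meters[...].avg` when it logs to TensorBoard.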
/data/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | 3 | class DataProvider(): 4 | 5 | def __init__(self, cfg, dataset, batch_size=None, shuffle=True): 6 | super().__init__() 7 | self.dataset = dataset 8 | if batch_size is None: 9 | batch_size = cfg.BATCH_SIZE 10 | self.dataloader = torch.utils.data.DataLoader( 11 | self.dataset, 12 | batch_size=batch_size, 13 | shuffle=shuffle, 14 | num_workers=int(cfg.WORKERS), 15 | drop_last=False) 16 | 17 | def __len__(self): 18 | return len(self.dataset) 19 | 20 | def __iter__(self): 21 | for i, data in enumerate(self.dataloader): 22 | yield data -------------------------------------------------------------------------------- /config/evalute_resnet18_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import socket 3 | from datetime import datetime 4 | 5 | class EVALUATE_RESNET18_CONFIG: 6 | 7 | def args(self): 8 | args = {'ROOT_DIR': '/home/dudapeng/workspace/trecgnet/summary/resnet18'} 9 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 10 | modality = 'rgb' 11 | task_name = 'inference_' + \ 12 | 'trecgnet_' + modality 13 | log_path = os.path.join(args['ROOT_DIR'], 'sunrgbd', modality, ''.join([task_name, ]), current_time) 14 | 15 | return { 16 | 'LOG_DIR': log_path, 17 | # MODEL 18 | 'ARCH': 'resnet18', 19 | 'PRETRAINED': 'imagenet', 20 | 'NO_UPSAMPLE': False, 21 | 22 | 'NO_FC': False, 23 | 'INFERENCE': True, 24 | 25 | # 'WHICH_DIRECTION': 'AtoB', 26 | 'WHICH_DIRECTION': 'BtoA', # AtoB: RGB->depth, BtoA: depth->RGB 27 | 'GPU_IDS': '6, 7', 28 | 'NUM_CLASSES': 19, 29 | 'BATCH_SIZE': 180, 30 | 'RESUME_PATH': '/home/dudapeng/workspace/trecgnet/resnet18/sample_model/' + 31 | 'trecg_BtoA_best.pth', 32 | 33 | 'EVALUATE': True, 34 | } 35 | -------------------------------------------------------------------------------- /model/pretrained_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | 5 | 6 | class ResNet(nn.Module): 7 | 8 | def __init__(self, resnet=None, cfg=None): 9 | super(ResNet, self).__init__() 10 | 11 | if resnet == 'resnet18': 12 | 13 | if cfg.CONTENT_PRETRAINED == 'place': 14 | resnet_model = models.__dict__['resnet18'](num_classes=365) 15 | # places model downloaded from http://places2.csail.mit.edu/ 16 | checkpoint = torch.load(cfg.CONTENT_MODEL_PATH, map_location=lambda storage, loc: storage) 17 | state_dict = {str.replace(k, 'module.', ''): v for k, v in checkpoint['state_dict'].items()} 18 | resnet_model.load_state_dict(state_dict) 19 | print('content model pretrained using place') 20 | else: 21 | resnet_model = models.resnet18(True) 22 | print('content model pretrained using imagenet') 23 | 24 | self.conv1 = resnet_model.conv1 25 | self.bn1 = resnet_model.bn1 26 | self.relu = resnet_model.relu 27 | self.maxpool = resnet_model.maxpool 28 | self.layer1 = resnet_model.layer1 29 | self.layer2 = resnet_model.layer2 30 | self.layer3 = resnet_model.layer3 31 | self.layer4 = resnet_model.layer4 32 | 33 | def forward(self, x, out_keys, in_channel=3): 34 | 35 | out = {} 36 | out['0'] = self.relu(self.bn1(self.conv1(x))) 37 | out['1'] = self.layer1(self.maxpool(out['0'])) 38 | out['2'] = self.layer2(out['1']) 39 | out['3'] = self.layer3(out['2']) 40 | out['4'] = self.layer4(out['3']) 41 | return [out[key] for key in out_keys] 
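A minimal sketch of querying this encoder for multi-scale features through `out_keys`; the `_Cfg` stand-in below is hypothetical (the real runs use the classes in `config/`), and the ImageNet branch is assumed so no Places checkpoint is needed:

```python
import torch
from model.pretrained_resnet import ResNet

class _Cfg:
    # hypothetical stand-in for the real config objects in config/
    CONTENT_PRETRAINED = 'imagenet'
    CONTENT_MODEL_PATH = ''

encoder = ResNet(resnet='resnet18', cfg=_Cfg())   # loads torchvision ImageNet weights
x = torch.randn(1, 3, 224, 224)
feats = encoder(x, out_keys=['0', '1', '2', '3', '4'])  # stem output + the four residual stages
for key, f in zip(['0', '1', '2', '3', '4'], feats):
    print(key, tuple(f.shape))   # e.g. '4' -> (1, 512, 7, 7) for a 224x224 input
```

Requesting a subset such as `out_keys=['2', '3', '4']` returns only those stages, which is how the layer-wise semantic (content) loss picks the feature levels listed in `CONTENT_LAYERS`.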
-------------------------------------------------------------------------------- /util/confusion_matrix.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from sklearn.metrics import confusion_matrix 4 | 5 | 6 | def _plot_confusion_matrix(cm,labels, cmap=plt.cm.Blues): 7 | 8 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 9 | xlocations = np.array(range(len(labels))) 10 | plt.xticks([]) 11 | plt.yticks(xlocations, labels) 12 | 13 | def plot_confusion_matrix(y_true, y_pred, save_dir, labels): 14 | 15 | tick_marks = np.array(range(len(labels))) + 0.5 16 | cm = confusion_matrix(y_true, y_pred) 17 | np.set_printoptions(precision=2) 18 | cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 19 | # print (cm_normalized ) 20 | if len(labels) > 10: 21 | plt.figure(figsize=(12, 8), dpi=120) 22 | else: 23 | plt.figure(figsize=(6, 4), dpi=120) 24 | 25 | ind_array = np.arange(len(labels)) 26 | x, y = np.meshgrid(ind_array, ind_array) 27 | for x_val, y_val in zip(x.flatten(), y.flatten()): 28 | c = cm_normalized[y_val][x_val] 29 | if c > 0.01 and c<0.4 : 30 | plt.text(x_val, y_val, "%0.2f" % (c,), color='black', fontsize=6, va='center', ha='center') 31 | if c>=0.4 and c<=1 : 32 | plt.text(x_val, y_val, "%0.2f" % (c,), color='white', fontsize=6, va='center', ha='center') 33 | 34 | plt.gca().set_xticks(tick_marks, minor=False) 35 | plt.gca().set_yticks(tick_marks, minor=False) 36 | plt.gca().xaxis.set_ticks_position('bottom') 37 | plt.gca().yaxis.set_ticks_position('left') 38 | plt.grid(True, which='minor', linestyle='-') 39 | plt.gcf().subplots_adjust(bottom=0.15) 40 | _plot_confusion_matrix(cm_normalized, labels) 41 | plt.savefig(save_dir, format='png') 42 | # plt.show() 43 | -------------------------------------------------------------------------------- /util/splitImages.py: -------------------------------------------------------------------------------- 1 | import os, shutil 2 | from PIL import Image 3 | 4 | ## split images from train/val/..txt -> train/class1,class2 format 5 | 6 | scene_dir = '/dataset/sun_rgbd/scene/' 7 | data_dir = '/dataset/sun_rgbd/data/' 8 | data_types = ['rgb/', 'hha/'] 9 | target_path = '/dataset/sun_rgbd/data_in_class/' 10 | ext = '.png' 11 | 12 | 13 | def sun_rgbd(): 14 | for phase in [('train/', '19scenes_train.txt'), ('val/', '19scenes_val.txt'), ('test/', '19scenes_test.txt')]: 15 | for type in data_types: 16 | for image_name, label in read_annotation_sunrgbd(scene_dir, phase[1]): 17 | # class_dir = ''.join([target_path, label]) 18 | source_image_path = os.path.join(data_dir, type, image_name, ext) 19 | target_folder = os.path.join(target_path, type, phase[0], label) 20 | target_image_path = os.path.join(target_folder, '/', image_name, ext) 21 | if not os.path.exists(target_folder): 22 | os.makedirs(target_folder) 23 | shutil.copyfile(source_image_path, target_image_path) 24 | print('copying {0} to {1}'.format(source_image_path, target_image_path)) 25 | 26 | def read_annotation_sunrgbd(data_dir, file_path): 27 | 28 | with open(os.path.join(data_dir, file_path)) as f: 29 | for line in f.readlines(): 30 | [image_path, scene_label] = line.strip().split('\t') 31 | yield image_path, scene_label 32 | 33 | def read_annotation_mit67(data_dir, file_path): 34 | 35 | with open(os.path.join(data_dir, file_path)) as f: 36 | for line in f.readlines(): 37 | [scene_label, file_name] = line.strip().split('/') 38 | yield file_name, scene_label 39 | 40 | def jpg2png(): 41 | 
image_dir = os.path.join(data_dir, 'images') 42 | image_names = os.listdir(image_dir) 43 | for image_name in image_names: 44 | image_path = os.path.join(image_dir, image_name) 45 | if '.jpg' in image_name: 46 | print('processing {0} ...'.format(image_name)) 47 | im = Image.open(image_path) 48 | image_path_new = image_path.replace('.jpg', '.png') 49 | im.save(image_path_new) 50 | 51 | 52 | if __name__ == '__main__': 53 | sun_rgbd() 54 | print('finishied!') 55 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from functools import reduce 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.transforms as transforms 8 | from tensorboardX import SummaryWriter 9 | 10 | import data.aligned_conc_dataset as dataset 11 | import util.utils as util 12 | from config.default_config import DefaultConfig 13 | from config.evalute_resnet18_config import EVALUATE_RESNET18_CONFIG 14 | from data import DataProvider 15 | from model.networks import TRecgNet_Upsample_Resiual 16 | 17 | cfg = DefaultConfig() 18 | args = { 19 | 'resnet18': EVALUATE_RESNET18_CONFIG().args(), 20 | } 21 | 22 | # args for different backbones 23 | cfg.parse(args['resnet18']) 24 | 25 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.GPU_IDS 26 | device_ids = torch.cuda.device_count() 27 | print('device_ids:', device_ids) 28 | project_name = reduce(lambda x, y: str(x) + '/' + str(y), os.path.realpath(__file__).split(os.sep)[:-1]) 29 | util.mkdir('logs') 30 | 31 | val_dataset = dataset.AlignedConcDataset(cfg, data_dir=cfg.DATA_DIR_VAL, transform=transforms.Compose([ 32 | dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)), 33 | dataset.CenterCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)), 34 | dataset.ToTensor(), 35 | dataset.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 36 | 37 | ])) 38 | batch_size_val = cfg.BATCH_SIZE 39 | 40 | val_loader = DataProvider(cfg, dataset=val_dataset, batch_size=batch_size_val, shuffle=False) 41 | writer = SummaryWriter(log_dir=cfg.LOG_PATH) # tensorboard 42 | model = TRecgNet_Upsample_Resiual(cfg, writer) 43 | model.set_data_loader(None, val_loader, None) 44 | model.net = nn.DataParallel(model.net).to(model.device) 45 | model.set_log_data(cfg) 46 | 47 | def evaluate(): 48 | 49 | checkpoint_path = os.path.join(cfg.CHECKPOINTS_DIR, cfg.RESUME_PATH) 50 | checkpoint = torch.load(checkpoint_path) 51 | load_epoch = checkpoint['epoch'] 52 | model.load_checkpoint(model.net, checkpoint_path, checkpoint, data_para=False) 53 | cfg.START_EPOCH = load_epoch 54 | 55 | print('>>> task path is {0}'.format(project_name)) 56 | 57 | model.evaluate(cfg) 58 | 59 | if writer is not None: 60 | writer.close() 61 | 62 | if __name__ == '__main__': 63 | start_time = time.time() 64 | evaluate() 65 | print('time consumption: {0} secs', time.time() - start_time) 66 | -------------------------------------------------------------------------------- /util/conc_modalities.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import argparse 4 | import numpy as np 5 | 6 | parser = argparse.ArgumentParser('create image pairs') 7 | parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='/dataset/sun_rgbd/data/images') 8 | parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='/dataset/sun_rgbd/data/hha') 9 | 
parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='/dataset/sun_rgbd/data_in_class/conc_data') 10 | parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000) 11 | parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true') 12 | args = parser.parse_args() 13 | 14 | for arg in vars(args): 15 | print('[%s] = ' % arg, getattr(args, arg)) 16 | 17 | splits = os.listdir(args.fold_A) 18 | 19 | for sp in splits: 20 | print('processing phase {0} ...'.format(sp)) 21 | fold_A = os.path.join(args.fold_A, sp) 22 | fold_B = os.path.join(args.fold_B, sp) 23 | catogaries = os.listdir(fold_A) 24 | for cato in catogaries: 25 | img_fold_A = os.path.join(fold_A, cato) 26 | img_fold_B = os.path.join(fold_B, cato) 27 | img_list_A = os.listdir(img_fold_A) 28 | img_list_B = os.listdir(img_fold_B) 29 | if len(img_list_A) != len(img_list_B): 30 | raise ValueError('number of images in A is not equal to B\'s, A\'s path is {0}'.format(img_fold_A)) 31 | 32 | img_fold_AB = os.path.join(args.fold_AB, sp, cato) 33 | if not os.path.isdir(img_fold_AB): 34 | os.makedirs(img_fold_AB) 35 | 36 | # print('split = %s, number of images = %d' % (sp, num_imgs)) 37 | for n in range(len(img_list_A)): 38 | name_A = img_list_A[n] 39 | path_A = os.path.join(img_fold_A, name_A) 40 | if args.use_AB: 41 | name_B = name_A.replace('_A.', '_B.') 42 | else: 43 | name_B = name_A 44 | path_B = os.path.join(img_fold_B, name_B) 45 | if os.path.isfile(path_A) and os.path.isfile(path_B): 46 | name_AB = name_A 47 | if args.use_AB: 48 | name_AB = name_AB.replace('_A.', '.') # remove _A 49 | path_AB = os.path.join(img_fold_AB, name_AB) 50 | im_A = cv2.imread(path_A, cv2.IMREAD_COLOR) 51 | im_B = cv2.imread(path_B, cv2.IMREAD_COLOR) 52 | im_AB = np.concatenate([im_A, im_B], 1) 53 | cv2.imwrite(path_AB, im_AB) 54 | 55 | print('finished!') 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Translate-to-Recognize Networks 2 | 3 | Pytorch implementations of Translate-to-Recognize Networks for RGB-D Scene Recognition (CVPR 2019). 4 | 5 | ![blockchain](images/framework.jpg) 6 | ![blockchain](images/gen.jpg) 7 | ## Usage 8 | 1. Download Reset18 pre-trained on [Places dataset](https://github.com/CSAILVision/places365) if necessary. 9 | 2. Data processing. 10 | * We use ImageFolder format, i.e., [class1/images.., class2/images..], to store the data, 11 | use ***util.splitimages.py*** to help change the format if neccessary. 12 | * Use ***util.conc_modalities.py*** to concatenate each paired RGB and depth images to one image for more efficient data loading. An example is shown below(depth data is encoded using HHA format). 13 | ![blockchain](images/conc_data.png) 14 | * We provide links to download SUN RGB-D data in ImageFolder format and depth data has been encoded using HHA format. 15 | 1. RGB and depth data is concatenated, [link](http://mcg.nju.edu.cn/dataset/sun-rgbd_conc.tar) 16 | 2. RGB and depth data is stored separately, [link](http://mcg.nju.edu.cn/dataset/sun-rgbd_split.tar) 17 | 3. ***[Updated 2019.11.5]*** If the links above are not accessible, try (https://pan.baidu.com/s/1LZIF1hlT3k0oX76Ttp660w) The extraction code is: g5vp 18 | 19 | 20 | 3. Configuration. 21 | Almost all the settings of experiments are configurable by the files in the ***config*** package. 22 | 4. Train. 
23 | `python train.py` or `bash train.sh` 24 | 5. [~~\!\!New\!\!~~] New branch 'multi-gpu' has been uploaded, making losses calculated on each gpu for better balanced usage of multi gpus. 25 | You could use this version using this command: \ 26 | `git clone -b multi-gpu https://github.com/ownstyledu/Translate-to-Recognize-Networks.git TrecgNet` 27 | 6. [~~\!\!New\!\!~~] In multi-gpu brach, we add more loss types in the training, e.g., GAN, pixel2pixel intensity. You could easily add these losses by modifying the config file. 28 | 7. ***[\!\!New 2019.7.31\!\!]*** We added the fusion model, and reconstructed the code, including reusing the base model and making some modifications on decoder structure and config file. 29 | ***Due to the time limitation, we don't simultaneously update the multi-gpu branch. But you could still refer to it if you want.*** 30 | ## Development Environment 31 | * NVIDIA TITAN XP 32 | * cuda 9.0 33 | * python 3.6.5 34 | * pytorch 0.4.1 35 | * torchvision 0.2.1 36 | * tensorboardX 37 | 38 | ## Citation 39 | Please cite the following paper if you feel this repository useful. 40 | ``` 41 | @inproceedings{du2019translate, 42 | title={Translate-to-Recognize Networks for RGB-D Scene Recognition}, 43 | author={Du, Dapeng and Wang, Limin and Wang, Huiling and Zhao, Kai and Wu, Gangshan}, 44 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 45 | pages={11836--11845}, 46 | year={2019} 47 | } 48 | 49 | ``` 50 | -------------------------------------------------------------------------------- /config/resnet18_nyud2_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | 5 | class RESNET18_SUNRGBD_CONFIG: 6 | 7 | def args(self): 8 | args = {'ROOT_DIR': '/home/dudapeng/workspace/trecgnet/summary/resnet18'} 9 | current_time = datetime.now().strftime('%b%d_%H-%M-%S') 10 | 11 | ########### Quick Setup ############ 12 | 13 | modality = 'rgb' 14 | task_name = 'specific_task_name' 15 | lr_schedule = 'lambda' # lambda|step|plateau 16 | pretrained = 'place' 17 | content_pretrained = 'place' 18 | gpus = '7' # gpu no. 
you can add more gpus with comma, e.g., '0,1,2' 19 | batch_size = 46 20 | direction = 'AtoB' # AtoB: RGB->Depth 21 | # direction = 'BtoA' 22 | loss = ['CLS','SEMANTIC'] # remove 'SEMANTIC' if trained with unlabeled data 23 | no_upsample = False # True for removing Decoder network 24 | unlabeled = False # True for training with unlabeled data 25 | content_layers = '0,1,2,3,4' # layer-wise semantic layers 26 | 27 | len_gpu = str(len(gpus.split(','))) 28 | 29 | # use generated data while training 30 | use_fake = False 31 | sample_path = os.path.join('/home/dudapeng/workspace/trecgnet/resnet18/sample_model/', content_pretrained, 32 | 'trecg_AtoB_best.pth') 33 | resume = False 34 | resume_path = os.path.join('/home/dudapeng/workspace/trecgnet/resnet18/sample_model/', content_pretrained, 35 | '10k_place_AtoB.pth') 36 | 37 | log_path = os.path.join(args['ROOT_DIR'], 'nyud2', modality, content_pretrained, 38 | ''.join([task_name, '_', lr_schedule, '_', 'gpu('+len_gpu+')' 39 | ]), current_time) 40 | return { 41 | 42 | 'GPU_IDS': gpus, 43 | 'WHICH_DIRECTION': direction, 44 | 'BATCH_SIZE': batch_size, 45 | 'LOSS_TYPES': loss, 46 | 'PRETRAINED': pretrained, 47 | 48 | 'LOG_PATH': log_path, 49 | 50 | # MODEL 51 | 'ARCH': 'resnet18', 52 | 'SAVE_BEST': True, 53 | 'NO_UPSAMPLE': no_upsample, 54 | 55 | #### DATA 56 | 'DATA_DIR_TRAIN': '/home/dudapeng/workspace/datasets/nyud2/conc_data/train', 57 | 'DATA_DIR_VAL': '/home/dudapeng/workspace/datasets/nyud2/conc_data/test', 58 | 'NUM_CLASSES': 10, 59 | 'UNLABELED': unlabeled, 60 | 'USE_FAKE_DATA': use_fake, 61 | 'SAMPLE_MODEL_PATH': sample_path, 62 | 63 | # TRAINING / TEST 64 | 'RESUME': resume, 65 | 'INIT_EPOCH': True, 66 | 'RESUME_PATH': resume_path, 67 | 'LR_POLICY': lr_schedule, 68 | 69 | 'NITER': 20, 70 | 'NITER_DECAY': 80, 71 | 'NITER_TOTAL': 100, 72 | 'EVALUATE': False, 73 | 74 | # translation task 75 | 'WHICH_CONTENT_NET': 'resnet18', 76 | 'CONTENT_LAYERS': content_layers, 77 | 'CONTENT_PRETRAINED': content_pretrained, 78 | 'ALPHA_CONTENT': 10 79 | } 80 | -------------------------------------------------------------------------------- /config/default_config.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | # base configuration, might be considered as the abstract class 4 | class DefaultConfig: 5 | # GPU / CPU 6 | GPU_IDS = None # slipt different gpus with comma 7 | nTHREADS = 8 8 | WORKERS = 8 9 | 10 | # MODEL 11 | MODEL = 'trecg' 12 | ARCH = 'vgg11_bn' 13 | PRETRAINED = 'imagenet' 14 | CONTENT_PRETRAINED = 'imagenet' 15 | NO_UPSAMPLE = False # set True when evaluating baseline 16 | FIX_GRAD = False 17 | IN_CONC = False # if True, change input_nc from 3 to specific ones 18 | 19 | # PATH 20 | DATA_DIR_TRAIN = '/home/dudapeng/workspace/datasets/sun_rgbd/data_in_class_mix/conc_data/train' 21 | DATA_DIR_VAL = '/home/dudapeng/workspace/datasets/sun_rgbd/data_in_class_mix/conc_data/test' 22 | DATA_DIR_UNLABELED = '/home/dudapeng/workspace/datasets/nyud2/mix/conc_data/10k_conc_bak' 23 | SAMPLE_MODEL_PATH = None 24 | CHECKPOINTS_DIR = './checkpoints' 25 | ROOT_DIR = '/home/dudapeng/workspace/trecgnet/' 26 | SUMMARY_DIR_ROOT = '/home/dudapeng/workspace/trecgnet/summary/' 27 | LOG_PATH = None 28 | CONTENT_MODEL_PATH = '' 29 | 30 | # DATA 31 | DATA_TYPE = 'pair' # pair | single 32 | WHICH_DIRECTION = None 33 | NUM_CLASSES = 19 34 | BATCH_SIZE = 48 35 | LOAD_SIZE = 256 36 | FINE_SIZE = 224 37 | FLIP = True 38 | UNLABELED = False 39 | FIVE_CROP = False 40 | FAKE_DATA_RATE = 0.3 41 | 42 | # OPTIMIZATION 
43 | LR = 2e-4 44 | WEIGHT_DECAY = 1e-4 45 | MOMENTUM = 0.9 46 | LR_POLICY = 'plateau' # lambda|step|plateau 47 | 48 | # TRAINING / TEST 49 | PHASE = 'train' 50 | RESUME = False 51 | RESUME_PATH = None 52 | RESUME_PATH_A = None 53 | RESUME_PATH_B = None 54 | NO_FC = True 55 | INIT_EPOCH = True # True for load pretrained parameters, False for resume the last training 56 | START_EPOCH = 1 57 | ROUND = 1 58 | MANUAL_SEED = None 59 | NITER = 10 60 | NITER_DECAY = 40 61 | NITER_TOTAL = 50 62 | LOSS_TYPES = [] # SEMANTIC_CONTENT, PIX2PIX, GAN 63 | EVALUATE = True 64 | USE_FAKE_DATA = False 65 | CLASS_WEIGHTS_TRAIN = None 66 | PRINT_FREQ = 100 67 | NO_VIS = False 68 | CAL_LOSS = True 69 | SAVE_BEST = False 70 | INFERENCE = False 71 | 72 | # classfication task 73 | ALPHA_CLS = 1 74 | 75 | # translation task 76 | WHICH_CONTENT_NET = 'vgg11_bn' 77 | CONTENT_LAYERS = ['l0', 'l1', 'l2'] 78 | NITER_START_CONTENT = 1 79 | NITER_END_CONTENT = 200 80 | ALPHA_CONTENT = 10 81 | 82 | # GAN task 83 | NO_LSGAN = True # False: least square gan loss, True: BCE loss 84 | NITER_START_GAN = 1 85 | NITER_END_GAN = 200 86 | ALPHA_GAN = 1 87 | 88 | # Pix2Pix 89 | NITER_START_PIX2PIX = 1 90 | NITER_END_PIX2PIX = 200 91 | ALPHA_PIX2PIX = 5 92 | 93 | def parse(self, kwargs): 94 | for k, v in kwargs.items(): 95 | if not hasattr(self, k): 96 | warnings.warn("Warning: opt has not attribut {0}".format(k)) 97 | setattr(self, k, v) 98 | 99 | print('user config:') 100 | for k, v in self.__class__.__dict__.items(): 101 | if not k.startswith('__'): 102 | print(k, ':', getattr(self, k)) 103 | -------------------------------------------------------------------------------- /train_fusion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from collections import Counter 4 | from functools import reduce 5 | 6 | import torch 7 | import torchvision.transforms as transforms 8 | from tensorboardX import SummaryWriter 9 | 10 | import data.aligned_conc_dataset as conc_dataset 11 | import util.utils as util 12 | from config.default_config import DefaultConfig 13 | from config.resnet_sunrgbd_config import RESNET_SUNRGBD_CONFIG 14 | from data import DataProvider 15 | from model.models import create_model 16 | 17 | cfg = DefaultConfig() 18 | args = { 19 | # model should be defined as 'fusion' and set paths for RESUME_PATH_A and RESUME_PATH_B 20 | 'resnet_sunrgbd': RESNET_SUNRGBD_CONFIG().args(), 21 | } 22 | 23 | # Setting random seed 24 | if cfg.MANUAL_SEED is None: 25 | cfg.MANUAL_SEED = random.randint(1, 10000) 26 | random.seed(cfg.MANUAL_SEED) 27 | torch.manual_seed(cfg.MANUAL_SEED) 28 | 29 | # args for different backbones 30 | cfg.parse(args['resnet_sunrgbd']) 31 | 32 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.GPU_IDS 33 | device_ids = torch.cuda.device_count() 34 | print('device_ids:', device_ids) 35 | project_name = reduce(lambda x, y: str(x) + '/' + str(y), os.path.realpath(__file__).split(os.sep)[:-1]) 36 | util.mkdir('logs') 37 | 38 | train_dataset = None 39 | val_dataset = None 40 | unlabeled_dataset = None 41 | train_loader = None 42 | val_loader = None 43 | unlabeled_loader = None 44 | 45 | train_transforms = list() 46 | train_transforms.append(conc_dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE))) 47 | train_transforms.append(conc_dataset.RandomCrop((cfg.FINE_SIZE, cfg.FINE_SIZE))) 48 | train_transforms.append(conc_dataset.RandomHorizontalFlip()) 49 | train_transforms.append(conc_dataset.ToTensor()) 50 | train_transforms.append(conc_dataset.Normalize(mean=[0.5, 
0.5, 0.5], std=[0.5, 0.5, 0.5])) 51 | 52 | val_transforms = list() 53 | val_transforms.append(conc_dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE))) 54 | val_transforms.append(conc_dataset.CenterCrop((cfg.FINE_SIZE, cfg.FINE_SIZE))) 55 | val_transforms.append(conc_dataset.ToTensor()) 56 | val_transforms.append(conc_dataset.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])) 57 | 58 | if cfg.DATA_TYPE == 'pair': 59 | 60 | train_dataset = conc_dataset.AlignedConcDataset(cfg, data_dir=cfg.DATA_DIR_TRAIN, 61 | transform=transforms.Compose(train_transforms)) 62 | val_dataset = conc_dataset.AlignedConcDataset(cfg, data_dir=cfg.DATA_DIR_VAL, 63 | transform=transforms.Compose(val_transforms)) 64 | 65 | train_loader = DataProvider(cfg, dataset=train_dataset) 66 | val_loader = DataProvider(cfg, dataset=val_dataset, shuffle=False) 67 | # class weights 68 | num_classes_train = list(Counter([i[1] for i in train_loader.dataset.imgs]).values()) 69 | cfg.CLASS_WEIGHTS_TRAIN = torch.FloatTensor(num_classes_train) 70 | 71 | writer = SummaryWriter(log_dir=cfg.LOG_PATH) # tensorboard 72 | model = create_model(cfg, writer) 73 | model.set_data_loader(train_loader, val_loader, unlabeled_loader) 74 | 75 | 76 | def train(): 77 | 78 | print('>>> task path is {0}'.format(project_name)) 79 | 80 | # train 81 | model.train_parameters(cfg) 82 | 83 | print('save model ...') 84 | model_filename = '{0}_{1}_finish.pth'.format(cfg.MODEL, cfg.WHICH_DIRECTION) 85 | model.save_checkpoint(cfg.NITER_TOTAL, model_filename) 86 | 87 | if writer is not None: 88 | writer.close() 89 | 90 | 91 | if __name__ == '__main__': 92 | train() 93 | -------------------------------------------------------------------------------- /config/resnet_sunrgbd_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class RESNET_SUNRGBD_CONFIG: 4 | 5 | def args(self): 6 | 7 | ########### Quick Setup ############ 8 | model = 'fusion' # | fusion 9 | arch = 'resnet18' # | resnet50 10 | content_model = 'resnet18' # | resnet50 11 | pretrained = 'place' # places model downloaded from http://places2.csail.mit.edu/ | imagenet 12 | content_pretrained = 'place' # | imagenet 13 | 14 | gpus = '0' # gpu no. 
you can add more gpus with comma, e.g., '0,1,2' 15 | batch_size = 40 16 | 17 | log_path = 'summary' # path for tensorboardX log file 18 | lr_schedule = 'lambda' # lambda|step|plateau 19 | lr = 2e-4 20 | 21 | direction = 'AtoB' # AtoB: RGB->Depth 22 | loss = ['CLS', 'SEMANTIC'] # remove 'CLS' if trained with unlabeled data 23 | no_upsample = False # True for removing Decoder network 24 | unlabeled = False # True for training with unlabeled data 25 | content_model_path = None # places model downloaded from http://places2.csail.mit.edu/ 26 | content_layers = '0,1,2,3,4' # layer-wise semantic layers, you can change it to better adapt your task 27 | alpha_content = 10 # coefficient for content loss 28 | fix_grad = False 29 | 30 | # use generated data while training 31 | use_fake = False 32 | sample_path = None # path of saved TrecgNet model for generating fake images 33 | resume = False 34 | resume_path = None # path of loading TrecgNet model 35 | 36 | # if we do fusion, we need two tregnets 37 | resume_path_A = None # the path for RGB TrecgNet 38 | resume_path_B = None # the path for Depth TrecgNet 39 | 40 | # current_time = datetime.now().strftime('%b%d_%H-%M-%S') 41 | # summary_dir_root = '/home/dudapeng/workspace/trecgnet/summary/resnet18' # dir for saving files from tensorboardX 42 | # modality = 'rgb' 43 | # task_name = 'your_task_name' # name for tensorboardX log file 44 | # len_gpu = str(len(gpus.split(','))) 45 | # log_path = os.path.join(summary_dir_root, 'sunrgbd', modality, content_pretrained, 46 | # ''.join([task_name, '_', lr_schedule, '_', 'gpu('+len_gpu+')' 47 | # ]), current_time) 48 | return { 49 | 50 | 'GPU_IDS': gpus, 51 | 'WHICH_DIRECTION': direction, 52 | 'BATCH_SIZE': batch_size, 53 | 'LOSS_TYPES': loss, 54 | 'PRETRAINED': pretrained, 55 | 56 | 'LOG_PATH': log_path, 57 | 58 | # MODEL 59 | 'MODEL': model, 60 | 'ARCH': arch, 61 | 'SAVE_BEST': True, 62 | 'NO_UPSAMPLE': no_upsample, 63 | 'FIX_GRAD': fix_grad, 64 | 65 | # DATA 66 | 'NUM_CLASSES': 19, 67 | 'UNLABELED': unlabeled, 68 | 'USE_FAKE_DATA': use_fake, 69 | 'SAMPLE_MODEL_PATH': sample_path, 70 | 'CONTENT_MODEL_PATH': content_model_path, 71 | 72 | # TRAINING / TEST 73 | 'RESUME': resume, 74 | 'INIT_EPOCH': True, 75 | 'RESUME_PATH': resume_path, 76 | 'RESUME_PATH_A': resume_path_A, 77 | 'RESUME_PATH_B': resume_path_B, 78 | 'LR_POLICY': lr_schedule, 79 | 'LR': lr, 80 | 81 | 'NITER': 20, 82 | 'NITER_DECAY': 80, 83 | 'NITER_TOTAL': 100, 84 | 'EVALUATE': True, # True if you want to check the test result after each epoch 85 | 86 | # TRANSLATION TASK 87 | 'WHICH_CONTENT_NET': content_model, 88 | 'CONTENT_LAYERS': content_layers, 89 | 'CONTENT_PRETRAINED': content_pretrained, 90 | 'ALPHA_CONTENT': alpha_content 91 | } 92 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from collections import Counter 4 | from functools import reduce 5 | 6 | import torch 7 | import torchvision.transforms as transforms 8 | from tensorboardX import SummaryWriter 9 | 10 | import data.aligned_conc_dataset as dataset 11 | import util.utils as util 12 | from config.default_config import DefaultConfig 13 | from config.resnet_sunrgbd_config import RESNET_SUNRGBD_CONFIG 14 | from data import DataProvider 15 | from model.models import create_model 16 | 17 | cfg = DefaultConfig() 18 | args = { 19 | 'resnet_sunrgbd': RESNET_SUNRGBD_CONFIG().args(), 20 | } 21 | 22 | # Setting random seed 23 | if 
cfg.MANUAL_SEED is None: 24 | cfg.MANUAL_SEED = random.randint(1, 10000) 25 | random.seed(cfg.MANUAL_SEED) 26 | torch.manual_seed(cfg.MANUAL_SEED) 27 | 28 | # args for different backbones 29 | cfg.parse(args['resnet_sunrgbd']) 30 | 31 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.GPU_IDS 32 | device_ids = torch.cuda.device_count() 33 | print('device_ids:', device_ids) 34 | project_name = reduce(lambda x, y: str(x) + '/' + str(y), os.path.realpath(__file__).split(os.sep)[:-1]) 35 | util.mkdir('logs') 36 | 37 | # data 38 | train_dataset = dataset.AlignedConcDataset(cfg, data_dir=cfg.DATA_DIR_TRAIN, transform=transforms.Compose([ 39 | dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)), 40 | dataset.RandomCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)), 41 | dataset.RandomHorizontalFlip(), 42 | dataset.ToTensor(), 43 | dataset.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 44 | 45 | ])) 46 | 47 | val_dataset = dataset.AlignedConcDataset(cfg, data_dir=cfg.DATA_DIR_VAL, transform=transforms.Compose([ 48 | dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)), 49 | dataset.CenterCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)), 50 | dataset.ToTensor(), 51 | dataset.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 52 | 53 | ])) 54 | batch_size_val = cfg.BATCH_SIZE 55 | 56 | unlabeled_loader = None 57 | if cfg.UNLABELED: 58 | unlabeled_dataset = dataset.AlignedConcDataset(cfg, data_dir=cfg.DATA_DIR_UNLABELED, transform=transforms.Compose([ 59 | dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)), 60 | dataset.RandomCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)), 61 | dataset.RandomHorizontalFlip(), 62 | dataset.ToTensor(), 63 | dataset.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5] 64 | 65 | )]), labeled=False) 66 | 67 | unlabeled_loader = DataProvider(cfg, dataset=unlabeled_dataset) 68 | 69 | train_loader = DataProvider(cfg, dataset=train_dataset) 70 | val_loader = DataProvider(cfg, dataset=val_dataset, batch_size=batch_size_val, shuffle=False) 71 | 72 | # class weights 73 | num_classes_train = list(Counter([i[1] for i in train_loader.dataset.imgs]).values()) 74 | cfg.CLASS_WEIGHTS_TRAIN = torch.FloatTensor(num_classes_train) 75 | 76 | writer = SummaryWriter(log_dir=cfg.LOG_PATH) # tensorboard 77 | model = create_model(cfg, writer) 78 | model.set_data_loader(train_loader, val_loader, unlabeled_loader) 79 | 80 | 81 | def train(): 82 | 83 | if cfg.RESUME: 84 | checkpoint_path = os.path.join(cfg.CHECKPOINTS_DIR, cfg.RESUME_PATH) 85 | checkpoint = model.load_checkpoint(model.net, checkpoint_path, keep_kw_module=False, keep_fc=True) 86 | load_epoch = checkpoint['epoch'] 87 | cfg.START_EPOCH = load_epoch 88 | 89 | if cfg.INIT_EPOCH: 90 | # just load pretrained parameters 91 | print('load checkpoint from another source') 92 | cfg.START_EPOCH = 1 93 | 94 | print('>>> task path is {0}'.format(project_name)) 95 | 96 | # train 97 | model.train_parameters(cfg) 98 | 99 | print('save model ...') 100 | model_filename = '{0}_{1}_{2}.pth'.format(cfg.MODEL, cfg.WHICH_DIRECTION, cfg.NITER_TOTAL) 101 | model.save_checkpoint(cfg.NITER_TOTAL, model_filename) 102 | 103 | if writer is not None: 104 | writer.close() 105 | 106 | 107 | if __name__ == '__main__': 108 | train() 109 | -------------------------------------------------------------------------------- /util/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from sklearn.metrics import accuracy_score 6 | 7 | def mkdirs(paths): 8 | if isinstance(paths, list) and not 
isinstance(paths, str): 9 | for path in paths: 10 | mkdir(path) 11 | else: 12 | mkdir(paths) 13 | 14 | 15 | def mkdir(path): 16 | if not os.path.exists(path): 17 | os.makedirs(path) 18 | 19 | 20 | def get_images(dir, extensions): 21 | images = [] 22 | dir = os.path.expanduser(dir) 23 | image_names = [d for d in os.listdir(dir)] 24 | for image_name in image_names: 25 | if has_file_allowed_extension(image_name, extensions): 26 | file = os.path.join(dir, image_name) 27 | images.append(file) 28 | return images 29 | 30 | 31 | #Checks if a file is an allowed extension. 32 | def has_file_allowed_extension(filename, extensions): 33 | filename_lower = filename.lower() 34 | return any(filename_lower.endswith(ext) for ext in extensions) 35 | 36 | 37 | def normalize_batch(batch): 38 | # normalize using imagenet mean and std 39 | mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1) 40 | std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) 41 | batch = batch.div_(255.0) 42 | return (batch - mean) / std 43 | 44 | 45 | def plot_confusion_matrix(cm, classes, 46 | normalize=False, 47 | title='Confusion matrix', 48 | cmap=plt.cm.Blues): 49 | if normalize: 50 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 51 | print("Normalized confusion matrix") 52 | else: 53 | print('Confusion matrix, without normalization') 54 | 55 | print(cm) 56 | 57 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 58 | plt.title(title) 59 | plt.colorbar() 60 | tick_marks = np.arange(len(classes)) 61 | plt.xticks(tick_marks, classes, rotation=45) 62 | plt.yticks(tick_marks, classes) 63 | 64 | fmt = '.2f' if normalize else 'd' 65 | thresh = cm.max() / 2. 66 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 67 | plt.text(j, i, format(cm[i, j], fmt), 68 | horizontalalignment="center", 69 | color="white" if cm[i, j] > thresh else "black") 70 | 71 | plt.ylabel('True label') 72 | plt.xlabel('Predicted label') 73 | plt.tight_layout() 74 | 75 | 76 | def accuracy(output, target, topk=(1,)): 77 | """Computes the precision@k for the specified values of k""" 78 | maxk = max(topk) 79 | batch_size = target.size(0) 80 | _, pred = output.topk(maxk, 1, True, True) 81 | pred = pred.t() 82 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 83 | 84 | res = [] 85 | for k in topk: 86 | correct_k = correct[:k].view(-1).float().sum(0) 87 | res.append(correct_k.mul_(100.0 / batch_size)) 88 | return res 89 | 90 | def mean_acc(target_indice, pred_indice, num_classes, classes=None): 91 | assert(num_classes == len(classes)) 92 | acc = 0. 
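As a quick, self-contained illustration of the `accuracy` helper defined above (toy logits and targets, hypothetical values):

```python
import torch
from util.utils import accuracy

logits = torch.tensor([[2.0, 0.5, 0.1],    # predicted class 0 (correct)
                       [0.2, 1.5, 0.3],    # predicted class 1 (correct)
                       [0.1, 0.9, 0.4]])   # predicted class 1 (wrong, target is 2)
target = torch.tensor([0, 1, 2])
top1, = accuracy(logits, target, topk=(1,))
print(top1.item())   # ~66.67: two of the three samples are correct at top-1
```

`mean_acc` below reports something different: it averages the per-class accuracies, so rare classes weigh as much as frequent ones.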
93 | print('{0} Class Acc Report {1}'.format('#' * 10, '#' * 10)) 94 | for i in range(num_classes): 95 | idx = np.where(target_indice == i)[0] 96 | # acc = acc + accuracy_score(target_indice[idx], pred_indice[idx]) 97 | class_correct = accuracy_score(target_indice[idx], pred_indice[idx]) 98 | acc += class_correct 99 | print('acc {0}: {1:.3f}'.format(classes[i], class_correct * 100)) 100 | 101 | # class report 102 | # y_tpye, y_true, y_pred = _check_targets(target_indice[idx], pred_indice[idx]) 103 | # score = y_true == y_pred 104 | # wrong_index = np.where(score == False)[0] 105 | # for j in idx[wrong_index]: 106 | # print("Wrong for class [%s]: predicted as: <%s>, image_id--<%s>" % 107 | # (int_to_class[i], int_to_class[pred[j]], image_paths[j])) 108 | # 109 | # print("[class] %s accuracy is %.3f" % (int_to_class[i], class_correct)) 110 | print('#' * 30) 111 | return (acc / num_classes) * 100 112 | 113 | def process_output(output): 114 | # Computes the result and argmax index 115 | pred, index = output.topk(1, 1, largest=True) 116 | 117 | return pred.cpu().float().numpy().flatten(), index.cpu().numpy().flatten() -------------------------------------------------------------------------------- /data/aligned_conc_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | 4 | import torchvision.transforms as transforms 5 | from PIL import Image 6 | from PIL import ImageFile 7 | from torchvision.datasets.folder import find_classes 8 | from torchvision.datasets.folder import make_dataset 9 | 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | import torch 12 | 13 | import util.utils as utils 14 | from torchvision.transforms import functional as F 15 | import copy 16 | 17 | 18 | class AlignedConcDataset: 19 | 20 | def __init__(self, cfg, data_dir=None, transform=None, labeled=True): 21 | self.cfg = cfg 22 | self.transform = transform 23 | self.data_dir = data_dir 24 | self.labeled = labeled 25 | 26 | if labeled: 27 | self.classes, self.class_to_idx = find_classes(self.data_dir) 28 | self.int_to_class = dict(zip(range(len(self.classes)), self.classes)) 29 | self.imgs = make_dataset(self.data_dir, self.class_to_idx, ['jpg','png']) 30 | else: 31 | self.imgs = utils.get_images(self.data_dir, ['jpg', 'png']) 32 | 33 | def __len__(self): 34 | return len(self.imgs) 35 | 36 | def __getitem__(self, index): 37 | if self.labeled: 38 | img_path, label = self.imgs[index] 39 | else: 40 | img_path = self.imgs[index] 41 | 42 | img_name = os.path.basename(img_path) 43 | AB_conc = Image.open(img_path).convert('RGB') 44 | 45 | # split RGB and Depth as A and B 46 | w, h = AB_conc.size 47 | w2 = int(w / 2) 48 | if w2 > self.cfg.FINE_SIZE: 49 | A = AB_conc.crop((0, 0, w2, h)).resize((self.cfg.LOAD_SIZE, self.cfg.LOAD_SIZE), Image.BICUBIC) 50 | B = AB_conc.crop((w2, 0, w, h)).resize((self.cfg.LOAD_SIZE, self.cfg.LOAD_SIZE), Image.BICUBIC) 51 | else: 52 | A = AB_conc.crop((0, 0, w2, h)) 53 | B = AB_conc.crop((w2, 0, w, h)) 54 | 55 | if self.labeled: 56 | sample = {'A': A, 'B': B, 'img_name': img_name, 'label': label} 57 | else: 58 | sample = {'A': A, 'B': B, 'img_name': img_name} 59 | 60 | if self.transform: 61 | sample = self.transform(sample) 62 | 63 | return sample 64 | 65 | 66 | class RandomCrop(transforms.RandomCrop): 67 | 68 | def __call__(self, sample): 69 | A, B = sample['A'], sample['B'] 70 | 71 | if self.padding > 0: 72 | A = F.pad(A, self.padding) 73 | B = F.pad(B, self.padding) 74 | 75 | # pad the width if needed 76 | if 
self.pad_if_needed and A.size[0] < self.size[1]: 77 | A = F.pad(A, (int((1 + self.size[1] - A.size[0]) / 2), 0)) 78 | B = F.pad(B, (int((1 + self.size[1] - B.size[0]) / 2), 0)) 79 | # pad the height if needed 80 | if self.pad_if_needed and A.size[1] < self.size[0]: 81 | A = F.pad(A, (0, int((1 + self.size[0] - A.size[1]) / 2))) 82 | B = F.pad(B, (0, int((1 + self.size[0] - B.size[1]) / 2))) 83 | 84 | i, j, h, w = self.get_params(A, self.size) 85 | sample['A'] = F.crop(A, i, j, h, w) 86 | sample['B'] = F.crop(B, i, j, h, w) 87 | 88 | # _i, _j, _h, _w = self.get_params(A, self.size) 89 | # sample['A'] = F.crop(A, i, j, h, w) 90 | # sample['B'] = F.crop(B, _i, _j, _h, _w) 91 | 92 | return sample 93 | 94 | 95 | class CenterCrop(transforms.CenterCrop): 96 | 97 | def __call__(self, sample): 98 | A, B = sample['A'], sample['B'] 99 | sample['A'] = F.center_crop(A, self.size) 100 | sample['B'] = F.center_crop(B, self.size) 101 | return sample 102 | 103 | 104 | class RandomHorizontalFlip(transforms.RandomHorizontalFlip): 105 | 106 | def __call__(self, sample): 107 | A, B = sample['A'], sample['B'] 108 | if random.random() > 0.5: 109 | A = F.hflip(A) 110 | B = F.hflip(B) 111 | 112 | sample['A'] = A 113 | sample['B'] = B 114 | 115 | return sample 116 | 117 | 118 | class Resize(transforms.Resize): 119 | 120 | def __call__(self, sample): 121 | A, B = sample['A'], sample['B'] 122 | h = self.size[0] 123 | w = self.size[1] 124 | 125 | sample['A'] = F.resize(A, (h, w)) 126 | sample['B'] = F.resize(B, (h, w)) 127 | 128 | return sample 129 | 130 | 131 | class ToTensor(object): 132 | def __call__(self, sample): 133 | 134 | A, B = sample['A'], sample['B'] 135 | 136 | # if isinstance(sample, dict): 137 | # for key, value in sample: 138 | # _list = sample[key] 139 | # sample[key] = [F.to_tensor(item) for item in _list] 140 | 141 | sample['A'] = F.to_tensor(A) 142 | sample['B'] = F.to_tensor(B) 143 | 144 | return sample 145 | 146 | class Normalize(transforms.Normalize): 147 | 148 | def __call__(self, sample): 149 | A, B = sample['A'], sample['B'] 150 | sample['A'] = F.normalize(A, self.mean, self.std) 151 | sample['B'] = F.normalize(B, self.mean, self.std) 152 | 153 | return sample 154 | 155 | class Lambda(transforms.Lambda): 156 | 157 | def __call__(self, sample): 158 | return self.lambd(sample) -------------------------------------------------------------------------------- /model/fusion.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import os 4 | import random 5 | import time 6 | from collections import defaultdict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | import util.utils as util 12 | from util.average_meter import AverageMeter 13 | from . 
import networks 14 | from .base_model import BaseModel 15 | 16 | 17 | class Fusion(BaseModel): 18 | 19 | def __init__(self, cfg, writer): 20 | 21 | super(Fusion, self).__init__(cfg) 22 | cfg_tmp = copy.deepcopy(cfg) 23 | cfg_tmp.model = 'trecg' 24 | self.net_rgb = networks.define_TrecgNet(cfg_tmp, upsample=False, device=self.device) 25 | self.net_depth = networks.define_TrecgNet(cfg_tmp, upsample=False, device=self.device) 26 | 27 | # load parameters 28 | checkpoint_path_A = os.path.join(self.cfg.CHECKPOINTS_DIR, self.cfg.RESUME_PATH_A) 29 | checkpoint_path_B = os.path.join(self.cfg.CHECKPOINTS_DIR, self.cfg.RESUME_PATH_B) 30 | super().load_checkpoint(net=self.net_rgb, checkpoint_path=checkpoint_path_A, keep_kw_module=False) 31 | super().load_checkpoint(net=self.net_depth, checkpoint_path=checkpoint_path_B, keep_kw_module=False) 32 | 33 | # fusion model 34 | self.net = networks.Fusion(cfg, self.net_rgb, self.net_depth) 35 | networks.print_network(self.net) 36 | self.criterion_cls = torch.nn.CrossEntropyLoss() 37 | self.net = nn.DataParallel(self.net).to(self.device) 38 | 39 | self.set_optimizer(cfg) 40 | self.set_schedulers(cfg) 41 | self.set_log_data(cfg) 42 | self.writer = writer 43 | 44 | if cfg.USE_FAKE_DATA: 45 | print('Use fake data: sample model is {0}'.format(cfg.SAMPLE_MODEL_PATH)) 46 | print('fake ratio:', cfg.FAKE_DATA_RATE) 47 | cfg_tmp.USE_FAKE_DATA = False 48 | self.sample_model_AtoB = networks.define_TrecgNet(cfg_tmp, upsample=True, device=self.device) 49 | self.sample_model_BtoA = networks.define_TrecgNet(cfg_tmp, upsample=True, device=self.device) 50 | self.sample_model_AtoB.eval() 51 | self.sample_model_BtoA.eval() 52 | self.sample_model_AtoB = nn.DataParallel(self.sample_model_AtoB).to(self.device) 53 | self.sample_model_BtoA = nn.DataParallel(self.sample_model_BtoA).to(self.device) 54 | super().load_checkpoint(net=self.sample_model_AtoB, checkpoint_path=checkpoint_path_A) 55 | super().load_checkpoint(net=self.sample_model_BtoA, checkpoint_path=checkpoint_path_B) 56 | 57 | def set_input(self, data, d_type='pair'): 58 | 59 | input_A = data['A'] 60 | input_B = data['B'] 61 | self.img_names = data['img_name'] 62 | self.imgs_all.extend(data['img_name']) 63 | self.input_rgb = input_A.to(self.device) 64 | self.input_depth = input_B.to(self.device) 65 | 66 | self.batch_size = input_A.size(0) 67 | 68 | if 'label' in data.keys(): 69 | self._label = data['label'] 70 | self.label = torch.LongTensor(self._label).to(self.device) 71 | 72 | def _forward(self): 73 | 74 | # # use fake data to train 75 | if self.cfg.USE_FAKE_DATA: 76 | with torch.no_grad(): 77 | out_keys = self.build_output_keys(gen_img=True, cls=False) 78 | [self.fake_depth] = self.sample_model_AtoB(source=self.input_rgb, out_keys=out_keys) 79 | [self.fake_rgb] = self.sample_model_BtoA(source=self.input_depth, out_keys=out_keys) 80 | input_num = len(self.fake_depth) 81 | indexes = [i for i in range(input_num)] 82 | rgb_random_index = random.sample(indexes, int(len(self.fake_rgb) * self.cfg.FAKE_DATA_RATE)) 83 | depth_random_index = random.sample(indexes, int(len(self.fake_depth) * self.cfg.FAKE_DATA_RATE)) 84 | 85 | for i in rgb_random_index: 86 | self.input_rgb[i, :] = self.fake_rgb.data[i, :] 87 | for j in depth_random_index: 88 | self.input_depth[j, :] = self.fake_depth.data[j, :] 89 | 90 | out_keys = self.build_output_keys(gen_img=False, cls=True) 91 | [self.cls] = self.net(self.input_rgb, self.input_depth, label=self.label, out_keys=out_keys) 92 | 93 | def train_parameters(self, cfg): 94 | 95 | train_total_steps = 0 
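A standalone sketch of the fake-data mixing performed in `_forward` above, with hypothetical tensors standing in for the real batch and the TrecgNet-generated images:

```python
import random
import torch

FAKE_DATA_RATE = 0.3                      # cfg.FAKE_DATA_RATE in default_config.py
input_rgb = torch.randn(8, 3, 224, 224)   # real RGB batch
fake_rgb = torch.randn(8, 3, 224, 224)    # RGB translated from depth by a sample TrecgNet

indexes = list(range(input_rgb.size(0)))
picked = random.sample(indexes, int(len(indexes) * FAKE_DATA_RATE))
for i in picked:
    input_rgb[i] = fake_rgb[i]            # swap a random subset of the batch with generated images
```

The same replacement is applied symmetrically to the depth batch, so roughly `FAKE_DATA_RATE` of each modality is generated rather than real during fusion training.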
96 | train_total_iter = 0 97 | best_prec = 0 98 | 99 | for epoch in range(cfg.START_EPOCH, cfg.NITER_TOTAL + 1): 100 | 101 | self.phase = 'train' 102 | self.net.train() 103 | 104 | start_time = time.time() 105 | 106 | if cfg.LR_POLICY != 'plateau': 107 | self.update_learning_rate(epoch=train_total_iter) 108 | else: 109 | self.update_learning_rate(val=self.loss_meters['VAL_CLS_MEAN_ACC'].avg) 110 | 111 | for key in self.loss_meters: 112 | self.loss_meters[key].reset() 113 | 114 | iters = 0 115 | for i, data in enumerate(self.train_loader): 116 | 117 | self.set_input(data, self.cfg.DATA_TYPE) 118 | iter_start_time = time.time() 119 | train_total_steps += self.batch_size 120 | train_total_iter += 1 121 | iters += 1 122 | 123 | self._forward() 124 | loss = self._cal_loss(epoch) 125 | self._optimize(loss) 126 | 127 | if train_total_steps % cfg.PRINT_FREQ == 0: 128 | errors = self.get_current_errors() 129 | t = (time.time() - iter_start_time) 130 | self.print_current_errors(errors, epoch, i, t) 131 | 132 | print('iters in one epoch:', iters) 133 | print('gpu_ids:', cfg.GPU_IDS) 134 | 135 | self._write_loss(phase=self.phase, global_step=train_total_iter) 136 | 137 | train_errors = self.get_current_errors(current=False) 138 | print('#' * 10) 139 | self.print_current_errors(train_errors, epoch) 140 | 141 | print('Training Time: {0} sec'.format(time.time() - start_time)) 142 | 143 | # if self.cfg.USE_FAKE_DATA: 144 | # print('Fake data usage: {0} / {1}'.format(self.fake_image_num, self.train_image_num)) 145 | 146 | # Validate cls 147 | if cfg.EVALUATE: 148 | 149 | self.imgs_all = [] 150 | self.pred_index_all = [] 151 | self.target_index_all = [] 152 | self.fake_image_num = 0 153 | 154 | mean_acc = self.evaluate(cfg=self.cfg) 155 | 156 | print('Mean Acc Epoch <{epoch}> * Prec@1 <{mean_acc:.3f}> ' 157 | .format(epoch=epoch, mean_acc=mean_acc)) 158 | 159 | if not cfg.INFERENCE: 160 | self.loss_meters['VAL_CLS_MEAN_ACC'].update(mean_acc) 161 | self._write_loss(phase=self.phase, global_step=train_total_iter) 162 | 163 | assert (len(self.pred_index_all) == len(self.val_loader)) 164 | 165 | if cfg.SAVE_BEST and epoch >= total_epoch - 10: 166 | # save model 167 | is_best = mean_acc > best_prec 168 | best_prec = max(mean_acc, best_prec) 169 | 170 | if is_best: 171 | # confusion matrix 172 | # save_dir = os.path.join(self.save_dir, 'confusion_matrix' + '.png') 173 | # plot_confusion_matrix(self.target_index_all, self.pred_index_all, save_dir, 174 | # self.val_loader.dataset.classes) 175 | 176 | model_filename = '{0}_{1}_best.pth'.format(cfg.MODEL, cfg.WHICH_DIRECTION) 177 | self.save_checkpoint(epoch, model_filename) 178 | print('best mean acc is {0}, epoch is {1}'.format(best_prec, epoch)) 179 | 180 | print('End of iter {0} / {1} \t ' 181 | 'Time Taken: {2} sec'.format(train_total_iter, cfg.NITER_TOTAL, time.time() - start_time)) 182 | print('-' * 80) 183 | 184 | def _cal_loss(self, epoch=None): 185 | 186 | loss_total = torch.zeros(1) 187 | if self.use_gpu: 188 | loss_total = loss_total.cuda() 189 | 190 | cls_loss = self.criterion_cls(self.cls, self.label) * self.cfg.ALPHA_CLS 191 | loss_total = loss_total + cls_loss 192 | 193 | cls_loss = round(cls_loss.item(), 4) 194 | self.loss_meters['TRAIN_CLS_LOSS'].update(cls_loss, self.batch_size) 195 | 196 | prec1 = util.accuracy(self.cls.data, self.label, topk=(1,)) 197 | self.loss_meters['TRAIN_CLS_ACC'].update(prec1[0].item(), self.batch_size) 198 | 199 | # total loss 200 | return loss_total 201 | 202 | def set_log_data(self, cfg): 203 | 204 | self.loss_meters = 
defaultdict() 205 | self.log_keys = [ 206 | 'TRAIN_CLS_ACC', 207 | 'VAL_CLS_ACC', # classification 208 | 'TRAIN_CLS_LOSS', 209 | 'VAL_CLS_LOSS', 210 | 'TRAIN_CLS_MEAN_ACC', 211 | 'VAL_CLS_MEAN_ACC' 212 | ] 213 | for item in self.log_keys: 214 | self.loss_meters[item] = AverageMeter() 215 | 216 | def _write_loss(self, phase, global_step): 217 | 218 | if phase == 'train': 219 | 220 | self.writer.add_scalar('LR', self.optimizer.param_groups[0]['lr'], global_step=global_step) 221 | 222 | self.writer.add_scalar('TRAIN_CLS_LOSS', self.loss_meters['TRAIN_CLS_LOSS'].avg, 223 | global_step=global_step) 224 | self.writer.add_scalar('TRAIN_CLS_MEAN_ACC', self.loss_meters['TRAIN_CLS_MEAN_ACC'].avg, 225 | global_step=global_step) 226 | 227 | if phase == 'test': 228 | 229 | if self.cfg.EVALUATE: 230 | self.writer.add_scalar('VAL_CLS_LOSS', self.loss_meters['VAL_CLS_LOSS'].avg, 231 | global_step=global_step) 232 | self.writer.add_scalar('VAL_CLS_ACC', self.loss_meters['VAL_CLS_ACC'].avg, 233 | global_step=global_step) 234 | 235 | self.writer.add_scalar('VAL_CLS_MEAN_ACC_FUSION', self.loss_meters['VAL_CLS_MEAN_ACC'].avg, 236 | global_step=global_step) 237 | -------------------------------------------------------------------------------- /model/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.optim import lr_scheduler 9 | import numpy as np 10 | import util.utils as util 11 | 12 | 13 | # some common actions are abstracted here 14 | # customized ones could be implemented in the corresponding model 15 | class BaseModel(nn.Module): 16 | 17 | def __init__(self, cfg): 18 | super(BaseModel, self).__init__() 19 | self.cfg = cfg 20 | self.gpu_ids = cfg.GPU_IDS 21 | num_gpu = len(cfg.GPU_IDS.split(',')) 22 | self.use_gpu = num_gpu > 0 23 | self.multi_gpu = num_gpu > 1 24 | self.model = None 25 | self.device = torch.device('cuda' if self.gpu_ids else 'cpu') 26 | self.save_dir = os.path.join(self.cfg.CHECKPOINTS_DIR, self.cfg.MODEL, 27 | str(time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime(time.time())))) 28 | if os.path.exists(self.save_dir): 29 | shutil.rmtree(self.save_dir) 30 | os.mkdir(self.save_dir) 31 | self.imgs_all = [] 32 | 33 | # schedule for modifying learning rate 34 | def set_schedulers(self, cfg): 35 | 36 | if cfg.PHASE == 'train': 37 | self.schedulers = [self._get_scheduler(optimizer, cfg, cfg.LR_POLICY) for optimizer in self.optimizers] 38 | 39 | def _get_scheduler(self, optimizer, cfg, lr_policy, decay_start=None, decay_epochs=None): 40 | if lr_policy == 'lambda': 41 | print('use lambda lr') 42 | if decay_start is None: 43 | decay_start = cfg.NITER 44 | decay_epochs = cfg.NITER_DECAY 45 | 46 | def lambda_rule(epoch): 47 | lr_l = 1 - max(0, epoch - decay_start - 1) / float(decay_epochs) 48 | # if lr_l < 1: 49 | # lr_l = 0.5 * lr_l 50 | # if epoch < decay_epochs + decay_start: 51 | # lr_l = 1 - max(0, epoch - decay_start) / float(decay_epochs) 52 | # else: 53 | # lr_l = 0.01 54 | return lr_l 55 | 56 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 57 | elif lr_policy == 'step': 58 | print('use step lr') 59 | scheduler = lr_scheduler.StepLR(optimizer, step_size=cfg.LR_DECAY_ITERS, gamma=0.1) 60 | elif lr_policy == 'plateau': 61 | print('use plateau lr') 62 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', verbose=True, 63 | threshold=0.001, factor=0.5, patience=5, 
eps=1e-7) 64 | else: 65 | return NotImplementedError('learning rate policy [%s] is not implemented', lr_policy) 66 | return scheduler 67 | 68 | def set_input(self, input): 69 | 70 | pass 71 | 72 | # choose labelled or unlabelled dataset for training 73 | def get_train_loader(self, cfg): 74 | if cfg.UNLABELED: 75 | print('Training with no labeled data..., training image: {0}'.format( 76 | len(self.unlabeled_loader.dataset.imgs))) 77 | dataset = self.unlabeled_loader 78 | self.unlabeled_flag = True 79 | else: 80 | dataset = self.train_loader 81 | self.unlabeled_flag = False 82 | print('# Training images num = {0}'.format(self.train_image_num)) 83 | # classes = zip(self.train_loader.dataset.classes, (self.cfg.CLASS_WEIGHTS_TRAIN.cpu().numpy() / self.train_image_num) * 100) 84 | # print('Class weight:') 85 | # print(['{0}:{1}'.format(cla[0], round(cla[1])) for cla in classes]) 86 | 87 | return dataset 88 | 89 | def set_data_loader(self, train_loader=None, val_loader=None, unlabeled_loader=None): 90 | 91 | if train_loader is not None: 92 | self.train_loader = train_loader 93 | self.train_image_num = len(train_loader.dataset.imgs) 94 | if val_loader is not None: 95 | self.val_loader = val_loader 96 | self.val_image_num = len(val_loader.dataset.imgs) 97 | if unlabeled_loader is not None: 98 | self.unlabeled_loader = unlabeled_loader 99 | self.unlabled_train_image_num = len(unlabeled_loader.dataset.imgs) 100 | 101 | def get_current_errors(self, current=True): 102 | 103 | loss_dict = OrderedDict() 104 | for key, value in sorted(self.loss_meters.items(), reverse=True): 105 | 106 | if 'TEST' in key or 'VAL' in key or 'ACC' in key or value.val == 0 or 'LAYER' in key: 107 | continue 108 | if current: 109 | loss_dict[key] = value.val 110 | else: 111 | loss_dict[key] = value.avg 112 | return loss_dict 113 | 114 | def save_checkpoint(self, state, filename='checkpoint.pth.tar'): 115 | pass 116 | 117 | def update_learning_rate(self, val=None, epoch=None): 118 | for scheduler in self.schedulers: 119 | if val is not None: 120 | scheduler.step(val) 121 | else: 122 | scheduler.step(epoch) 123 | 124 | for optimizer in self.optimizers: 125 | print('default lr', optimizer.defaults['lr']) 126 | for param_group in optimizer.param_groups: 127 | lr = param_group['lr'] 128 | print('/////////learning rate = %.7f' % lr) 129 | 130 | def set_log_data(self, cfg): 131 | pass 132 | 133 | def print_current_errors(self, errors, epoch, i=None, t=None): 134 | if i is None: 135 | message = '(Training Loss_avg [Epoch:{0}]) '.format(epoch) 136 | else: 137 | message = '(epoch: {epoch}, iters: {iter}, time: {time:.3f}) '.format(epoch=epoch, iter=i, time=t) 138 | 139 | for k, v in errors.items(): 140 | if 'CLS' in k and i is None: 141 | message += '{key}: [{value:.3f}] '.format(key=k, value=v) 142 | else: 143 | message += '{key}: {value:.3f} '.format(key=k, value=v) 144 | print(message) 145 | 146 | def set_optimizer(self, cfg): 147 | 148 | self.optimizers = [] 149 | # self.optimizer = torch.optim.Adam([{'params': self.net.fc.parameters(), 'lr': cfg.LR}], lr=cfg.LR / 10, betas=(0.5, 0.999)) 150 | 151 | self.optimizer = torch.optim.Adam(self.net.parameters(), lr=cfg.LR, betas=(0.5, 0.999)) 152 | print('optimizer: ', self.optimizer) 153 | self.optimizers.append(self.optimizer) 154 | 155 | def _optimize(self, loss): 156 | 157 | self.optimizer.zero_grad() 158 | loss.backward() 159 | self.optimizer.step() 160 | 161 | def _cal_loss(self, epoch): 162 | 163 | pass 164 | 165 | def build_output_keys(self, gen_img=True, cls=True): 166 | 167 | 
out_keys = [] 168 | 169 | if gen_img: 170 | out_keys.append('gen_img') 171 | 172 | if cls: 173 | out_keys.append('cls') 174 | 175 | return out_keys 176 | 177 | def load_checkpoint(self, net=None, checkpoint_path=None, keep_kw_module=True, keep_fc=None): 178 | 179 | keep_fc = keep_fc if keep_fc is not None else not self.cfg.NO_FC 180 | 181 | if os.path.isfile(checkpoint_path): 182 | checkpoint = torch.load(checkpoint_path) 183 | state_model = net.state_dict() 184 | state_checkpoint = checkpoint['state_dict'] 185 | 186 | # the weights of ckpt are stored when data-paralleled, remove 'module' if you 187 | # update the raw model with such ckpt 188 | if not keep_kw_module: 189 | new_state_dict = OrderedDict() 190 | for k, v in state_checkpoint.items(): 191 | name = k[7:] 192 | new_state_dict[name] = v 193 | state_checkpoint = new_state_dict 194 | 195 | if keep_fc: 196 | states_ckp = {k: v for k, v in state_checkpoint.items() if k in state_model} 197 | else: 198 | states_ckp = {k: v for k, v in state_checkpoint.items() if k in state_model and 'fc' not in k} 199 | 200 | # if successfully load weights 201 | assert (len(states_ckp) > 0) 202 | 203 | state_model.update(states_ckp) 204 | net.load_state_dict(state_model) 205 | print('load ckpt {0}'.format(checkpoint_path)) 206 | return checkpoint 207 | 208 | else: 209 | print("=> !!! No checkpoint found at '{}'".format(checkpoint_path)) 210 | return 211 | 212 | def evaluate(self, cfg): 213 | 214 | self.phase = 'test' 215 | 216 | # switch to evaluate mode 217 | self.net.eval() 218 | 219 | self.imgs_all = [] 220 | self.pred_index_all = [] 221 | self.target_index_all = [] 222 | self.fake_image_num = 0 223 | 224 | with torch.no_grad(): 225 | 226 | print('# Cls val images num = {0}'.format(self.val_image_num)) 227 | # batch_index = int(self.val_image_num / cfg.BATCH_SIZE) 228 | # random_id = random.randint(0, batch_index) 229 | 230 | for i, data in enumerate(self.val_loader): 231 | self.set_input(data, self.cfg.DATA_TYPE) 232 | 233 | self._forward() 234 | self._process_fc() 235 | 236 | # accuracy 237 | prec1 = util.accuracy(self.cls.data, self.label, topk=(1,)) 238 | self.loss_meters['VAL_CLS_ACC'].update(prec1[0].item(), self.batch_size) 239 | 240 | # Mean ACC 241 | mean_acc = self._cal_mean_acc(cfg=cfg, data_loader=self.val_loader) 242 | print('mean_acc: [{0}]'.format(mean_acc)) 243 | return mean_acc 244 | 245 | def _process_fc(self): 246 | 247 | pred, self.pred_index = util.process_output(self.cls.data) 248 | 249 | self.pred_index_all.extend(list(self.pred_index)) 250 | self.target_index_all.extend(list(self._label.numpy())) 251 | 252 | def _cal_mean_acc(self, cfg, data_loader): 253 | 254 | mean_acc = util.mean_acc(np.array(self.target_index_all), np.array(self.pred_index_all), 255 | cfg.NUM_CLASSES, 256 | data_loader.dataset.classes) 257 | return mean_acc 258 | 259 | def _forward(self): 260 | 261 | pass 262 | 263 | -------------------------------------------------------------------------------- /model/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | from torch.nn import init 6 | from torchvision.models.resnet import resnet18 7 | 8 | 9 | def init_weights(net, init_type='normal', gain=0.02): 10 | def init_func(m): 11 | classname = m.__class__.__name__ 12 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 13 | if init_type == 'normal': 14 | 
init.normal_(m.weight.data, 0.0, gain) 15 | elif init_type == 'xavier': 16 | init.xavier_normal(m.weight.data, gain=gain) 17 | elif init_type == 'kaiming': 18 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 19 | elif init_type == 'orthogonal': 20 | init.orthogonal(m.weight.data, gain=gain) 21 | else: 22 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 23 | if hasattr(m, 'bias') and m.bias is not None: 24 | init.constant_(m.bias.data, 0.0) 25 | elif classname.find('BatchNorm2d') != -1: 26 | init.normal_(m.weight.data, 1.0, gain) 27 | init.constant_(m.bias.data, 0.0) 28 | 29 | print('initialize network with %s' % init_type) 30 | net.apply(init_func) 31 | 32 | 33 | def fix_grad(net): 34 | def fix_func(m): 35 | classname = m.__class__.__name__ 36 | if classname.find('Conv') != -1 or classname.find('BatchNorm2d') != -1: 37 | m.weight.requires_grad = False 38 | if m.bias is not None: 39 | m.bias.requires_grad = False 40 | 41 | net.apply(fix_func) 42 | 43 | 44 | def unfix_grad(net): 45 | def fix_func(m): 46 | classname = m.__class__.__name__ 47 | if classname.find('Conv') != -1 or classname.find('BatchNorm2d') != -1 or classname.find('Linear') != -1: 48 | m.weight.requires_grad = True 49 | if m.bias is not None: 50 | m.bias.requires_grad = True 51 | 52 | net.apply(fix_func) 53 | 54 | 55 | def define_TrecgNet(cfg, upsample=None, device=None): 56 | 57 | if upsample is None: 58 | upsample = not cfg.NO_UPSAMPLE 59 | 60 | model = TRecgNet_Upsample_Resiual(cfg, encoder=cfg.ARCH, upsample=upsample, device=device) 61 | 62 | return model 63 | 64 | def print_network(net): 65 | num_params = 0 66 | for param in net.parameters(): 67 | num_params += param.numel() 68 | print(net) 69 | print('Total number of parameters: %d' % num_params) 70 | 71 | def conv3x3(in_planes, out_planes, stride=1): 72 | 73 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 74 | padding=1, bias=False) 75 | 76 | def conv_norm_relu(dim_in, dim_out, kernel_size=3, norm=nn.BatchNorm2d, stride=1, padding=1, 77 | use_leakyRelu=False, use_bias=False, is_Sequential=True): 78 | if use_leakyRelu: 79 | act = nn.LeakyReLU(0.2, True) 80 | else: 81 | act = nn.ReLU(True) 82 | 83 | if is_Sequential: 84 | result = nn.Sequential( 85 | nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=stride, 86 | padding=padding, bias=use_bias), 87 | norm(dim_out, affine=True), 88 | act 89 | ) 90 | return result 91 | return [nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=stride, 92 | padding=padding, bias=False), 93 | norm(dim_out, affine=True), 94 | act] 95 | 96 | 97 | ############################################################################## 98 | # Moduels 99 | ############################################################################## 100 | class Upsample_Interpolate(nn.Module): 101 | 102 | def __init__(self, dim_in, dim_out, kernel_size=1, padding=0, norm=nn.BatchNorm2d, scale=2, mode='bilinear', activate=True): 103 | super(Upsample_Interpolate, self).__init__() 104 | self.scale = scale 105 | self.mode = mode 106 | self.conv = nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=1, padding=padding, bias=False) 107 | self.bn = norm(dim_out) 108 | 109 | def forward(self, x, activate=True): 110 | x = nn.functional.interpolate(x, scale_factor=self.scale, mode=self.mode, align_corners=True) 111 | conv_out = self.conv(x) 112 | conv_out = self.bn(conv_out) 113 | if activate: 114 | conv_out = nn.ReLU(True)(conv_out) 115 | return x, conv_out 116 | 117 | 118 | class 
UpsampleBasicBlock(nn.Module): 119 | 120 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, norm=nn.BatchNorm2d, scale=2, mode='bilinear', upsample=True): 121 | super(UpsampleBasicBlock, self).__init__() 122 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) 123 | self.bn1 = norm(planes) 124 | self.relu = nn.ReLU(inplace=True) 125 | self.conv2 = conv3x3(planes, planes) 126 | self.bn2 = norm(planes) 127 | 128 | if upsample: 129 | if inplanes != planes: 130 | kernel_size, padding = 1, 0 131 | else: 132 | kernel_size, padding = 3, 1 133 | 134 | self.upsample = nn.Sequential( 135 | nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=1, padding=padding, bias=False), 136 | norm(planes)) 137 | else: 138 | self.upsample = None 139 | 140 | self.scale = scale 141 | self.mode = mode 142 | 143 | def forward(self, x): 144 | 145 | if self.upsample is not None: 146 | x = nn.functional.interpolate(x, scale_factor=self.scale, mode=self.mode, align_corners=True) 147 | residual = self.upsample(x) 148 | else: 149 | residual = x 150 | 151 | out = self.conv1(x) 152 | out = self.bn1(out) 153 | out = self.relu(out) 154 | 155 | out = self.conv2(out) 156 | out = self.bn2(out) 157 | 158 | out += residual 159 | out = self.relu(out) 160 | 161 | return out 162 | 163 | ############################################################################## 164 | # Translate to recognize 165 | ############################################################################## 166 | class Content_Model(nn.Module): 167 | 168 | def __init__(self, cfg, criterion=None): 169 | super(Content_Model, self).__init__() 170 | self.cfg = cfg 171 | self.criterion = criterion 172 | self.net = cfg.WHICH_CONTENT_NET 173 | 174 | if 'resnet' in self.net: 175 | from .pretrained_resnet import ResNet 176 | self.model = ResNet(self.net, cfg) 177 | 178 | fix_grad(self.model) 179 | # print_network(self) 180 | 181 | def forward(self, x, in_channel=3, layers=None): 182 | 183 | self.model.eval() 184 | 185 | if layers is None: 186 | layers = self.cfg.CONTENT_LAYERS.split(',') 187 | 188 | layer_wise_features = self.model(x, layers) 189 | return layer_wise_features 190 | 191 | 192 | class TRecgNet_Upsample_Resiual(nn.Module): 193 | 194 | def __init__(self, cfg, encoder='resnet18', upsample=True, device=None): 195 | super(TRecgNet_Upsample_Resiual, self).__init__() 196 | 197 | self.encoder = encoder 198 | self.cfg = cfg 199 | self.upsample = upsample 200 | self.dim_noise = 128 201 | self.device = device 202 | self.avg_pool_size = 14 203 | 204 | dims = [32, 64, 128, 256, 512, 1024, 2048] 205 | 206 | if cfg.PRETRAINED == 'imagenet' or cfg.PRETRAINED == 'place': 207 | pretrained = True 208 | else: 209 | pretrained = False 210 | 211 | if cfg.PRETRAINED == 'place': 212 | resnet = models.__dict__['resnet18'](num_classes=365) 213 | # places model downloaded from http://places2.csail.mit.edu/ 214 | checkpoint = torch.load(self.cfg.CONTENT_MODEL_PATH, map_location=lambda storage, loc: storage) 215 | state_dict = {str.replace(k, 'module.', ''): v for k, v in checkpoint['state_dict'].items()} 216 | resnet.load_state_dict(state_dict) 217 | print('place resnet18 loaded....') 218 | else: 219 | resnet = resnet18(pretrained=pretrained) 220 | print('{0} pretrained:{1}'.format(encoder, str(pretrained))) 221 | 222 | self.conv1 = resnet.conv1 223 | self.bn1 = resnet.bn1 224 | self.relu = resnet.relu 225 | self.maxpool = resnet.maxpool # 1/4 226 | self.layer1 = resnet.layer1 # 1/4 227 | 
self.layer2 = resnet.layer2 # 1/8 228 | self.layer3 = resnet.layer3 # 1/16 229 | self.layer4 = resnet.layer4 # 1/32 230 | 231 | self.build_upsample_layers(dims) 232 | 233 | self.avgpool = nn.AvgPool2d(self.avg_pool_size, 1) 234 | self.fc = nn.Linear(dims[4], cfg.NUM_CLASSES) 235 | 236 | if pretrained and upsample: 237 | 238 | init_weights(self.up1, 'normal') 239 | init_weights(self.up2, 'normal') 240 | init_weights(self.up3, 'normal') 241 | init_weights(self.up4, 'normal') 242 | init_weights(self.skip_3, 'normal') 243 | init_weights(self.skip_2, 'normal') 244 | init_weights(self.skip_1, 'normal') 245 | init_weights(self.up_image, 'normal') 246 | 247 | elif not pretrained: 248 | 249 | init_weights(self, 'normal') 250 | 251 | def _make_upsample(self, block, planes, norm=nn.BatchNorm2d, is_upsample=False): 252 | 253 | upsample = None 254 | if self.inplanes != planes or is_upsample: 255 | upsample = Upsample_Interpolate(self.inplanes, planes, kernel_size=1, padding=0, norm=norm, activate=False) 256 | 257 | layer = block(self.inplanes, planes, norm=norm, upsample=upsample) 258 | self.inplanes = planes 259 | 260 | return layer 261 | 262 | def build_upsample_layers(self, dims): 263 | 264 | # norm = nn.BatchNorm2d 265 | norm = nn.InstanceNorm2d 266 | 267 | # self.inplanes = dims[4] + self.dim_noise if self.use_noise else dims[4] 268 | self.up1 = UpsampleBasicBlock(dims[4], dims[3], kernel_size=1, padding=0, norm=norm) 269 | self.up2 = UpsampleBasicBlock(dims[3], dims[2], kernel_size=1, padding=0, norm=norm) 270 | self.up3 = UpsampleBasicBlock(dims[2], dims[1], kernel_size=1, padding=0, norm=norm) 271 | self.up4 = UpsampleBasicBlock(dims[1], dims[1], kernel_size=3, padding=1, norm=norm) 272 | 273 | self.skip_3 = conv_norm_relu(dims[3], dims[3], kernel_size=1, padding=0, norm=norm) 274 | self.skip_2 = conv_norm_relu(dims[2], dims[2], kernel_size=1, padding=0, norm=norm) 275 | self.skip_1 = conv_norm_relu(dims[1], dims[1], kernel_size=1, padding=0, norm=norm) 276 | 277 | self.up_image = nn.Sequential( 278 | nn.Conv2d(64, 3, 7, 1, 3, bias=False), 279 | nn.Tanh() 280 | ) 281 | 282 | def forward(self, source=None, out_keys=None, phase='train', content_layers=None, return_losses=True): 283 | out = {} 284 | 285 | out['0'] = self.relu(self.bn1(self.conv1(source))) 286 | out['1'] = self.layer1(out['0']) 287 | out['2'] = self.layer2(out['1']) 288 | out['3'] = self.layer3(out['2']) 289 | out['4'] = self.layer4(out['3']) 290 | 291 | if self.upsample and 'gen_img' in out_keys: 292 | skip1 = self.skip_1(out['1']) # 64 / 128 293 | skip2 = self.skip_2(out['2']) # 128 / 256 294 | skip3 = self.skip_3(out['3']) # 256 / 512 295 | 296 | upconv4 = self.up1(out['4']) 297 | upconv3 = self.up2(upconv4 + skip3) 298 | upconv2 = self.up3(upconv3 + skip2) 299 | upconv1 = self.up4(upconv2 + skip1) 300 | 301 | out['gen_img'] = self.up_image(upconv1) 302 | 303 | out['avgpool'] = self.avgpool(out['4']) 304 | avgpool = out['avgpool'].view(source.size(0), -1) 305 | out['cls'] = self.fc(avgpool) 306 | 307 | result = [] 308 | for key in out_keys: 309 | if isinstance(key, list): 310 | item = [out[subkey] for subkey in key] 311 | else: 312 | item = out[key] 313 | result.append(item) 314 | 315 | return result 316 | 317 | 318 | class Fusion(nn.Module): 319 | 320 | def __init__(self, cfg, rgb_model=None, depth_model=None, device='cuda'): 321 | super(Fusion, self).__init__() 322 | self.cfg = cfg 323 | self.device = device 324 | self.rgb_model = rgb_model 325 | self.depth_model = depth_model 326 | self.net_RGB = 
self.construct_single_modal_net(rgb_model) 327 | self.net_depth = self.construct_single_modal_net(depth_model) 328 | 329 | if cfg.FIX_GRAD: 330 | fix_grad(self.net_RGB) 331 | fix_grad(self.net_depth) 332 | 333 | self.avgpool = nn.AvgPool2d(14, 1) 334 | self.fc = nn.Sequential( 335 | nn.Linear(1024, 1024), 336 | nn.ReLU(), 337 | nn.Dropout(0.2), 338 | nn.Linear(1024, 1024), 339 | nn.ReLU(), 340 | nn.Dropout(0.2), 341 | nn.Linear(1024, cfg.NUM_CLASSES) 342 | ) 343 | 344 | init_weights(self.fc, 'normal') 345 | 346 | # only keep the classification branch 347 | def construct_single_modal_net(self, model): 348 | if isinstance(model, nn.DataParallel): 349 | model = model.module 350 | 351 | ops = [model.conv1, model.bn1, model.relu, model.layer1, model.layer2, 352 | model.layer3, model.layer4] 353 | return nn.Sequential(*ops) 354 | 355 | def set_cls_criterion(self, criterion): 356 | self.cls_criterion = criterion.to(self.device) 357 | 358 | def forward(self, input_rgb, input_depth, label, out_keys=None): 359 | 360 | out = {} 361 | 362 | rgb_specific = self.net_RGB(input_rgb) 363 | depth_specific = self.net_depth(input_depth) 364 | 365 | concat = torch.cat((rgb_specific, depth_specific), 1).to(self.device) 366 | x = self.avgpool(concat) 367 | x = x.view(x.size(0), -1) 368 | out['cls'] = self.fc(x) 369 | 370 | result = [] 371 | for key in out_keys: 372 | if isinstance(key, list): 373 | item = [out[subkey] for subkey in key] 374 | else: 375 | item = out[key] 376 | result.append(item) 377 | 378 | return result 379 | 380 | -------------------------------------------------------------------------------- /model/trecg_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import copy 4 | from collections import OrderedDict 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torchvision 11 | 12 | import util.utils as util 13 | from util.average_meter import AverageMeter 14 | from util.confusion_matrix import plot_confusion_matrix 15 | from . 
import networks 16 | from .base_model import BaseModel 17 | 18 | 19 | class TRecgNet(BaseModel): 20 | 21 | def __init__(self, cfg, writer=None): 22 | super(TRecgNet, self).__init__(cfg) 23 | 24 | util.mkdir(self.save_dir) 25 | assert (self.cfg.WHICH_DIRECTION is not None) 26 | self.AtoB = self.cfg.WHICH_DIRECTION == 'AtoB' 27 | self.modality = 'rgb' if self.AtoB else 'depth' 28 | self.sample_model = None 29 | self.phase = cfg.PHASE 30 | self.upsample = not cfg.NO_UPSAMPLE 31 | self.content_model = None 32 | self.content_layers = [] 33 | 34 | self.writer = writer 35 | 36 | # networks 37 | self.use_noise = cfg.WHICH_DIRECTION == 'BtoA' 38 | self.net = networks.define_TrecgNet(cfg, device=self.device) 39 | networks.print_network(self.net) 40 | 41 | 42 | def set_input(self, data, type='pair'): 43 | 44 | if type == 'pair': 45 | 46 | input_A = data['A'] 47 | input_B = data['B'] 48 | self.img_names = data['img_name'] 49 | self.real_A = input_A.to(self.device) 50 | self.real_B = input_B.to(self.device) 51 | 52 | AtoB = self.AtoB 53 | self.source_modal = self.real_A if AtoB else self.real_B 54 | self.target_modal = self.real_B if AtoB else self.real_A 55 | self.source_modal_original = self.source_modal 56 | 57 | self.batch_size = input_A.size(0) 58 | 59 | if 'label' in data.keys(): 60 | self._label = data['label'] 61 | self.label = torch.LongTensor(self._label).to(self.device) 62 | 63 | def build_output_keys(self, gen_img=True, cls=True): 64 | 65 | out_keys = [] 66 | 67 | if gen_img: 68 | out_keys.append('gen_img') 69 | 70 | if cls: 71 | out_keys.append('cls') 72 | 73 | return out_keys 74 | 75 | def train_parameters(self, cfg): 76 | 77 | assert (self.cfg.LOSS_TYPES) 78 | 79 | if 'CLS' in self.cfg.LOSS_TYPES or self.cfg.EVALUATE: 80 | self.criterion_cls = torch.nn.CrossEntropyLoss(self.cfg.CLASS_WEIGHTS_TRAIN.to(self.device)) 81 | 82 | if 'SEMANTIC' in self.cfg.LOSS_TYPES: 83 | self.criterion_content = torch.nn.L1Loss() 84 | self.content_model = networks.Content_Model(cfg, self.criterion_content).to(self.device) 85 | assert(self.cfg.CONTENT_LAYERS) 86 | self.content_layers = self.cfg.CONTENT_LAYERS.split(',') 87 | 88 | self.set_optimizer(cfg) 89 | self.set_log_data(cfg) 90 | self.set_schedulers(cfg) 91 | self.net = nn.DataParallel(self.net).to(self.device) 92 | 93 | if cfg.USE_FAKE_DATA: 94 | print('Use fake data: sample model is {0}'.format(cfg.SAMPLE_MODEL_PATH)) 95 | sample_model_path = os.path.join(cfg.CHECKPOINTS_DIR, cfg.SAMPLE_MODEL_PATH) 96 | cfg_sample = copy.deepcopy(cfg) 97 | cfg_sample.USE_FAKE_DATA = False # since it passes the same forward used by the main model 98 | sample_model = networks.define_TrecgNet(cfg_sample, use_noise=not self.use_noise, upsample=True, device=self.device) 99 | self.load_checkpoint(sample_model, sample_model_path, keep_kw_module=True, keep_fc=False) 100 | self.sample_model = sample_model.to(self.device) 101 | self.sample_model.eval() 102 | 103 | train_total_steps = 0 104 | train_total_iter = 0 105 | best_prec = 0 106 | 107 | for epoch in range(cfg.START_EPOCH, cfg.NITER_TOTAL + 1): 108 | 109 | self.imgs_all = [] 110 | self.pred_index_all = [] 111 | self.target_index_all = [] 112 | 113 | start_time = time.time() 114 | data_loader = self.get_train_loader(cfg) 115 | 116 | if cfg.LR_POLICY != 'plateau': 117 | self.update_learning_rate(epoch=epoch) 118 | else: 119 | self.update_learning_rate(val=self.loss_meters['VAL_CLS_MEAN_ACC'].avg) 120 | 121 | self.phase = 'train' 122 | self.net.train() 123 | 124 | for key in self.loss_meters: 125 | 
self.loss_meters[key].reset() 126 | 127 | if self.sample_model is not None: 128 | self.fake_image_num = 0 129 | 130 | iters = 0 131 | for i, data in enumerate(data_loader): 132 | 133 | self.set_input(data, self.cfg.DATA_TYPE) 134 | iter_start_time = time.time() 135 | train_total_steps += self.batch_size 136 | train_total_iter += 1 137 | iters += 1 138 | 139 | self._forward() 140 | loss = self._cal_loss(epoch) 141 | self._optimize(loss) 142 | 143 | if train_total_steps % cfg.PRINT_FREQ == 0: 144 | errors = self.get_current_errors() 145 | t = (time.time() - iter_start_time) 146 | self.print_current_errors(errors, epoch, i, t) 147 | 148 | model_filename = '{0}_{1}_{2}.pth'.format(cfg.MODEL, cfg.WHICH_DIRECTION, cfg.NITER_TOTAL) 149 | self.save_checkpoint(cfg.NITER_TOTAL, model_filename) 150 | 151 | print('iters in one epoch:', iters) 152 | 153 | self._write_loss(phase=self.phase, global_step=epoch) 154 | 155 | train_errors = self.get_current_errors(current=False) 156 | print('#' * 10) 157 | self.print_current_errors(train_errors, epoch) 158 | 159 | if self.sample_model is not None: 160 | print('Fake data usage: {0} / {1}'.format(self.fake_image_num, self.train_image_num)) 161 | print('Training Time: {0} sec'.format(time.time() - start_time)) 162 | 163 | # Validate cls 164 | if cfg.EVALUATE: 165 | 166 | mean_acc = self.evaluate(cfg=self.cfg) 167 | 168 | print('Mean Acc Epoch <{epoch}> * Prec@1 <{mean_acc:.3f}> ' 169 | .format(epoch=epoch, mean_acc=mean_acc)) 170 | 171 | if not cfg.INFERENCE: 172 | self.loss_meters['VAL_CLS_MEAN_ACC'].update(mean_acc) 173 | self._write_loss(phase=self.phase, global_step=epoch) 174 | 175 | assert (len(self.pred_index_all) == len(self.val_loader)) 176 | 177 | if cfg.SAVE_BEST and epoch >= self.cfg.NITER_TOTAL - 10: 178 | # save model 179 | is_best = mean_acc > best_prec 180 | best_prec = max(mean_acc, best_prec) 181 | 182 | if is_best: 183 | # confusion matrix 184 | save_dir = os.path.join(self.save_dir, 'confusion_matrix' + '.png') 185 | plot_confusion_matrix(self.target_index_all, self.pred_index_all, save_dir, 186 | self.val_loader.dataset.classes) 187 | 188 | model_filename = '{0}_{1}_best.pth'.format(cfg.MODEL, cfg.WHICH_DIRECTION) 189 | self.save_checkpoint(epoch, model_filename) 190 | print('best mean acc is {0}, epoch is {1}'.format(best_prec, epoch)) 191 | 192 | print('End of Epoch {0} / {1} \t ' 193 | 'Time Taken: {2} sec'.format(epoch, cfg.NITER_TOTAL, time.time() - start_time)) 194 | print('-' * 80) 195 | 196 | # encoder-decoder branch 197 | def _forward(self): 198 | 199 | self.gen = None 200 | self.source_modal_show = None 201 | self.target_modal_show = None 202 | self.cls_loss = None 203 | 204 | if self.phase == 'train': 205 | 206 | # # use fake data to train 207 | if self.sample_model is not None: 208 | with torch.no_grad(): 209 | out_keys = self.build_output_keys(gen_img=True, cls=False) 210 | [fake_source] = self.sample_model(source=self.target_modal, 211 | out_keys=out_keys, return_losses=False) 212 | input_num = len(fake_source) 213 | index = [i for i in range(0, input_num) if np.random.uniform() > 1 - self.cfg.FAKE_DATA_RATE] 214 | for j in index: 215 | self.source_modal[j, :] = fake_source.data[j, :] 216 | self.fake_image_num += len(index) 217 | 218 | if 'CLS' not in self.cfg.LOSS_TYPES or self.cfg.UNLABELED: 219 | 220 | out_keys = self.build_output_keys(gen_img=True, cls=False) 221 | [self.gen] = self.net(source=self.source_modal, out_keys=out_keys, content_layers=self.content_layers) 222 | 223 | elif self.upsample: 224 | out_keys = 
self.build_output_keys(gen_img=True, cls=True)
225 |                 [self.gen, self.cls] = self.net(source=self.source_modal, out_keys=out_keys, content_layers=self.content_layers)
226 |             else:
227 |                 out_keys = self.build_output_keys(gen_img=False, cls=True)
228 |                 [self.cls] = self.net(source=self.source_modal, out_keys=out_keys)
229 |
230 |             self.source_modal_show = self.source_modal
231 |             self.target_modal_show = self.target_modal
232 |
233 |         else:
234 |
235 |             if self.upsample:
236 |
237 |                 out_keys = self.build_output_keys(gen_img=True, cls=True)
238 |                 [self.gen, self.cls] = self.net(self.source_modal, out_keys=out_keys)
239 |                 self.source_modal_show = self.source_modal
240 |                 self.target_modal_show = self.target_modal
241 |
242 |             else:
243 |                 out_keys = self.build_output_keys(gen_img=False, cls=True)
244 |                 [self.cls] = self.net(self.source_modal, out_keys=out_keys)
245 |
246 |     def _cal_loss(self, epoch=None):
247 |
248 |         loss_total = torch.zeros(1)
249 |         if self.use_gpu:
250 |             loss_total = loss_total.cuda()
251 |
252 |         if self.gen is not None:
253 |             assert (self.gen.size(-1) == self.cfg.FINE_SIZE)
254 |
255 |         if 'CLS' in self.cfg.LOSS_TYPES:
256 |             cls_loss = self.criterion_cls(self.cls, self.label) * self.cfg.ALPHA_CLS
257 |             loss_total = loss_total + cls_loss
258 |
259 |             cls_loss = round(cls_loss.item(), 4)
260 |             self.loss_meters['TRAIN_CLS_LOSS'].update(cls_loss, self.batch_size)
261 |
262 |             prec1 = util.accuracy(self.cls.data, self.label, topk=(1,))
263 |             self.loss_meters['TRAIN_CLS_ACC'].update(prec1[0].item(), self.batch_size)
264 |
265 |         # content supervised
266 |         if self.cfg.NITER_START_CONTENT <= epoch <= self.cfg.NITER_END_CONTENT:
267 |
268 |             if 'SEMANTIC' in self.cfg.LOSS_TYPES:
269 |                 source_features = self.content_model((self.gen + 1) / 2, layers=self.content_layers)
270 |                 target_features = self.content_model((self.target_modal + 1) / 2, layers=self.content_layers)
271 |                 len_layers = len(self.content_layers)
272 |                 loss_fns = [self.criterion_content] * len_layers
273 |                 alpha = [1] * len_layers
274 |
275 |                 layer_wise_losses = [alpha[i] * loss_fns[i](source_feature, target_features[i])
276 |                                      for i, source_feature in enumerate(source_features)]
277 |
278 |                 content_loss = sum(layer_wise_losses) * self.cfg.ALPHA_CONTENT
279 |                 loss_total = loss_total + content_loss
280 |
281 |                 self.loss_meters['TRAIN_SEMANTIC_LOSS'].update(content_loss.item(), self.batch_size)
282 |
283 |         # total loss
284 |         return loss_total
285 |
286 |     def set_log_data(self, cfg):
287 |
288 |         self.loss_meters = defaultdict()
289 |         self.log_keys = [
290 |             'TRAIN_SEMANTIC_LOSS',  # semantic
291 |             'TRAIN_CLS_ACC',
292 |             'VAL_CLS_ACC',  # classification
293 |             'TRAIN_CLS_LOSS',
294 |             'VAL_CLS_LOSS',
295 |             'TRAIN_CLS_MEAN_ACC',
296 |             'VAL_CLS_MEAN_ACC'
297 |         ]
298 |         for item in self.log_keys:
299 |             self.loss_meters[item] = AverageMeter()
300 |
301 |     def save_checkpoint(self, epoch, filename=None):
302 |
303 |         if filename is None:
304 |             filename = 'TRecg2Net_{0}_{1}.pth'.format(self.cfg.WHICH_DIRECTION, epoch)
305 |
306 |         net_state_dict = self.net.state_dict()
307 |         save_state_dict = {}
308 |         for k, v in net_state_dict.items():
309 |             if 'content_model' in k:
310 |                 continue
311 |             save_state_dict[k] = v
312 |
313 |         state = {
314 |             'epoch': epoch,
315 |             'state_dict': save_state_dict,
316 |             'optimizer': self.optimizer.state_dict(),
317 |         }
318 |
319 |         filepath = os.path.join(self.save_dir, filename)
320 |         torch.save(state, filepath)
321 |
322 |     def _load_checkpoint(self, net, checkpoint_path, optimizer=None, keep_kw_module=True, keep_fc=None):
323 |
324 |         checkpoint = super().load_checkpoint(net, checkpoint_path=checkpoint_path, keep_kw_module=keep_kw_module,
325 |                                              keep_fc=keep_fc)
326 |
327 |         if self.phase == 'train' and not self.cfg.INIT_EPOCH:
328 |             optimizer.load_state_dict(checkpoint['optimizer'])
329 |
330 |         print("=> loaded checkpoint '{}' (epoch {})"
331 |               .format(checkpoint_path, checkpoint['epoch']))
332 |
333 |     def _write_loss(self, phase, global_step):
334 |
335 |         loss_types = self.cfg.LOSS_TYPES
336 |
337 |         if phase == 'train':
338 |
339 |             self.writer.add_scalar('LR', self.optimizer.param_groups[0]['lr'], global_step=global_step)
340 |
341 |             if 'CLS' in loss_types:
342 |                 self.writer.add_scalar('TRAIN_CLS_LOSS', self.loss_meters['TRAIN_CLS_LOSS'].avg,
343 |                                        global_step=global_step)
344 |                 self.writer.add_scalar('TRAIN_CLS_MEAN_ACC', self.loss_meters['TRAIN_CLS_MEAN_ACC'].avg,
345 |                                        global_step=global_step)
346 |
347 |             if 'SEMANTIC' in loss_types:
348 |                 self.writer.add_scalar('TRAIN_SEMANTIC_LOSS', self.loss_meters['TRAIN_SEMANTIC_LOSS'].avg,
349 |                                        global_step=global_step)
350 |
351 |             if self.upsample and self.gen is not None and not self.cfg.NO_VIS:
352 |                 self.writer.add_image('Train_Source',
353 |                                       torchvision.utils.make_grid(self.source_modal_show[:6].clone().cpu().data, 3,
354 |                                                                   normalize=True), global_step=global_step)
355 |                 self.writer.add_image('Train_Gen', torchvision.utils.make_grid(self.gen[:6].clone().cpu().data, 3,
356 |                                                                                 normalize=True),
357 |                                       global_step=global_step)
358 |                 self.writer.add_image('Train_Target',
359 |                                       torchvision.utils.make_grid(self.target_modal_show[:6].clone().cpu().data, 3,
360 |                                                                   normalize=True), global_step=global_step)
361 |
362 |         if phase == 'test':
363 |
364 |             if self.cfg.EVALUATE and self.cfg.CAL_LOSS:
365 |                 self.writer.add_scalar('VAL_CLS_LOSS', self.loss_meters['VAL_CLS_LOSS'].avg,
366 |                                        global_step=global_step)
367 |                 self.writer.add_scalar('VAL_CLS_ACC', self.loss_meters['VAL_CLS_ACC'].avg,
368 |                                        global_step=global_step)
369 |                 self.writer.add_scalar('VAL_CLS_MEAN_ACC', self.loss_meters['VAL_CLS_MEAN_ACC'].avg,
370 |                                        global_step=global_step)
371 |
372 |             if self.upsample and self.gen is not None and not self.cfg.NO_VIS:
373 |                 self.writer.add_image('Val_Source',
374 |                                       torchvision.utils.make_grid(self.source_modal_show[:6].clone().cpu().data, 3,
375 |                                                                   normalize=True), global_step=global_step)
376 |                 self.writer.add_image('Val_Gen', torchvision.utils.make_grid(self.gen[:6].clone().cpu().data,
377 |                                                                               3, normalize=True), global_step=global_step)
378 |                 self.writer.add_image('Val_Target', torchvision.utils.make_grid(self.target_modal_show[:6].clone().cpu()
379 |                                                                                 .data, 3, normalize=True), global_step=global_step)
380 |
--------------------------------------------------------------------------------
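For reference, the following is a minimal, hypothetical sketch of how the classes above fit together for single-modality training and evaluation. It is not taken from train.py (whose contents are not shown here): the cfg object and the two data loaders are assumed to come from this repository's config/ and data/ packages and to expose the keys and attributes that base_model.py and trecg_model.py actually read (PHASE, EVALUATE, LOG_DIR, RESUME_PATH, loaders with .dataset.imgs and .dataset.classes, and so on), and tensorboardX is assumed as the summary writer.

# Hypothetical wiring of TRecgNet (sketch only, not copied from train.py).
import torch.nn as nn
from tensorboardX import SummaryWriter  # assumed writer; anything with add_scalar/add_image works

from model.trecg_model import TRecgNet


def run_trecgnet(cfg, train_loader, val_loader):
    writer = SummaryWriter(cfg.LOG_DIR)      # consumed by _write_loss for scalars and image grids
    model = TRecgNet(cfg, writer)            # builds the translate-to-recognize network
    model.set_data_loader(train_loader, val_loader)

    if cfg.PHASE == 'train':
        # sets optimizer, schedulers and loss meters, wraps the net in DataParallel,
        # then runs the epoch loop, calling evaluate() whenever cfg.EVALUATE is set
        model.train_parameters(cfg)
    else:
        # inference only: wrap in DataParallel first so the 'module.' keys written by
        # save_checkpoint match, restore the weights, then compute mean accuracy
        model.set_log_data(cfg)
        model.net = nn.DataParallel(model.net).to(model.device)
        model.load_checkpoint(model.net, cfg.RESUME_PATH)
        model.evaluate(cfg)

    writer.close()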