├── .idea ├── .gitignore ├── STAR_Stochastic_Classifiers_for_UDA.iml ├── deployment.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── README.md └── digit_signal_classification ├── README.md ├── datasets ├── __init__.pyc ├── base_data_loader.py ├── dataset_read.py ├── datasets.py ├── gtsrb.py ├── mnist.py ├── svhn.py ├── synth_traffic.py ├── unaligned_data_loader.py └── usps.py ├── doc ├── architecture.jpg ├── architecture_small.jpg └── architecture_small.png ├── main.py ├── model ├── __init__.pyc ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── build_gen.cpython-36.pyc │ └── build_gen.cpython-37.pyc ├── build_gen.py ├── build_gen.pyc ├── svhn2mnist.py ├── svhn2mnist.pyc ├── syn2gtrsb.py ├── syn2gtrsb.pyc ├── usps.py └── usps.pyc ├── record └── mnist_usps │ ├── test_2020-07-22-17-50-42.out │ └── train_2020-07-22-17-50-42.out ├── solver.py └── utils ├── avgmeter.py └── utils.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /workspace.xml 3 | -------------------------------------------------------------------------------- /.idea/STAR_Stochastic_Classifiers_for_UDA.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stochastic Classifiers for Unsupervised Domain Adaptation (CVPR2020) 2 | 3 | ## [[Paper Link]](https://openaccess.thecvf.com/content_CVPR_2020/papers/Lu_Stochastic_Classifiers_for_Unsupervised_Domain_Adaptation_CVPR_2020_paper.pdf) 4 | 5 | ## Short introduction 6 | 7 | This is the implementation of STAR (STochastic clAssifieRs). The main idea is to build a distribution over the weights of the classifiers. With that, an infinite number of classifiers can be sampled without extra parameters. 8 | 9 | ## Architecture 10 | 11 | 12 | ## Citation 13 | 14 | If you find this helpful, please cite it. 15 | 16 | ``` 17 | @InProceedings{Lu_2020_CVPR, 18 | author = {Lu, Zhihe and Yang, Yongxin and Zhu, Xiatian and Liu, Cong and Song, Yi-Zhe and Xiang, Tao}, 19 | title = {Stochastic Classifiers for Unsupervised Domain Adaptation}, 20 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 21 | month = {June}, 22 | year = {2020} 23 | } 24 | ``` 25 | 26 | ## Further discussion 27 | 28 | If you have any problems, please contact zhihe.lu@surrey.ac.uk or simply open an issue.
# --------------------------------------------------------------------------------
# /digit_signal_classification/README.md:
# --------------------------------------------------------------------------------
# ## Digit and Sign classification
#
# ### Getting Started
# #### Installation
# - Install PyTorch (Works on Version 0.2.0_3) and dependencies from http://pytorch.org.
#   NOTE(review): the models call nn.init.kaiming_uniform_ and
#   Distribution.rsample, which may require a newer PyTorch than 0.2 - confirm.
# - Install Python 2.7.
# - Install tensorboardX.
#
# ### Download Dataset
# Download MNIST Dataset [here](https://drive.google.com/file/d/1cZ4vSIS-IKoyKWPfcgxFMugw0LtMiqPf/view?usp=sharing).
# The resized image dataset is contained in the file.
# Place it in the directory /data/digit_sign (the dataset loaders read from that
# absolute path, e.g. /data/digit_sign/mnist_data.mat).
# All other datasets should be placed in the same directory.
#
# ### Train
# Here is an example for the task MNIST to USPS:
#
# ```
# python main.py \
#   --source mnist \
#   --target usps \
#   --num_classifiers_train 2 \
#   --lr 0.0002 \
#   --max_epoch 300 \
#   --all_use yes \
#   --optimizer adam
# ```
#
# --------------------------------------------------------------------------------
# /digit_signal_classification/datasets/__init__.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/datasets/__init__.pyc
# --------------------------------------------------------------------------------
# /digit_signal_classification/datasets/base_data_loader.py:
# --------------------------------------------------------------------------------
class BaseDataLoader():
    """Minimal base class that stores common data-loader settings."""

    def __init__(self):
        pass

    def initialize(self, batch_size):
        """Store ``batch_size`` and the default loader options.

        Args:
            batch_size: number of samples per batch.
        """
        self.batch_size = batch_size
        self.serial_batches = 0
        self.nThreads = 2
        self.max_dataset_size = float("inf")

    def load_data(self):
        """Return the loaded data; the base class has none and returns None.

        ``self`` was missing from the original signature, so calling this on an
        instance raised ``TypeError``.
        """
        return None
# --------------------------------------------------------------------------------
# /digit_signal_classification/datasets/dataset_read.py:
# --------------------------------------------------------------------------------
# NOTE: this block spans dataset_read.py, datasets.py and gtsrb.py. The
# ``from __future__`` import belongs to datasets.py; it is placed first because
# __future__ imports must precede every other statement. It only changes how
# ``print`` parses, and the single-argument print calls below behave the same.
from __future__ import print_function
import sys
from unaligned_data_loader import UnalignedDataLoader
from svhn import load_svhn
from mnist import load_mnist
from usps import load_usps
from gtsrb import load_gtsrb
from synth_traffic import load_syntraffic
sys.path.append('../loader')


def return_dataset(data, scale=False, usps=False, all_use='no'):
    """Load the train/test arrays for one dataset and print their shapes.

    Args:
        data: dataset name; one of 'svhn', 'mnist', 'usps', 'synth', 'gtsrb'.
        scale: forwarded to ``load_mnist`` (True selects the 32x32 images).
        usps: forwarded to ``load_mnist`` (MNIST is subsampled when adapting
            with USPS unless ``all_use == 'yes'``).
        all_use: forwarded to ``load_mnist`` / ``load_usps``; 'yes' uses the
            full training sets.

    Returns:
        ``(train_image, train_label, test_image, test_label)`` from the loader.

    Raises:
        ValueError: if ``data`` is not a supported dataset name.
    """
    if data == 'svhn':
        splits = load_svhn()
    elif data == 'mnist':
        splits = load_mnist(scale=scale, usps=usps, all_use=all_use)
    elif data == 'usps':
        splits = load_usps(all_use=all_use)
    elif data == 'synth':
        splits = load_syntraffic()
    elif data == 'gtsrb':
        splits = load_gtsrb()
    else:
        # Fail loudly instead of hitting an UnboundLocalError at the return.
        raise ValueError('Unknown dataset: {}'.format(data))

    train_image, train_label, test_image, test_label = splits
    print('The size of {} training dataset: {} and testing dataset: {}'.format(
        data, train_image.shape, test_image.shape))

    return train_image, train_label, test_image, test_label


def dataset_read(source, target, batch_size, scale=False, all_use='no'):
    """Build paired source/target loaders for training and evaluation.

    Args:
        source: source dataset name (see ``return_dataset``).
        target: target dataset name.
        batch_size: batch size used for both domains.
        scale: forwarded to ``return_dataset``.
        all_use: forwarded to ``return_dataset``.

    Returns:
        ``(dataset, dataset_test)``: the ``load_data()`` results of the train
        and test ``UnalignedDataLoader``.
    """
    usps = source == 'usps' or target == 'usps'

    train_source, s_label_train, test_source, s_label_test = return_dataset(
        source,
        scale=scale,
        usps=usps,
        all_use=all_use
    )
    train_target, t_label_train, test_target, t_label_test = return_dataset(
        target,
        scale=scale,
        usps=usps,
        all_use=all_use
    )

    S = {'imgs': train_source, 'labels': s_label_train}
    T = {'imgs': train_target, 'labels': t_label_train}

    # Input target samples for both as test target is not used
    S_test = {'imgs': test_target, 'labels': t_label_test}
    T_test = {'imgs': test_target, 'labels': t_label_test}

    # Side length handed to transforms.Resize. It is kept separate from the
    # boolean ``scale`` argument, which only selects the MNIST variant.
    if source == 'synth':
        image_size = 40
    elif usps:
        image_size = 28
    else:
        image_size = 32

    train_loader = UnalignedDataLoader()
    train_loader.initialize(S, T, batch_size, batch_size, scale=image_size, shuffle_=True)
    dataset = train_loader.load_data()

    test_loader = UnalignedDataLoader()
    test_loader.initialize(S_test, T_test, batch_size, batch_size, scale=image_size, shuffle_=True)
    dataset_test = test_loader.load_data()

    return dataset, dataset_test
# --------------------------------------------------------------------------------
# /digit_signal_classification/datasets/datasets.py:
# --------------------------------------------------------------------------------
import torch.utils.data as data
from PIL import Image
import numpy as np


class Dataset(data.Dataset):
    """Wrap pre-loaded image and label arrays as a torch ``Dataset``.

    Args:
        data: indexable image collection; each item is indexed channels-first
            (``img.shape[0]`` is treated as the channel count).
        label: indexable label collection aligned with ``data``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self, data, label, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform
        self.data = data
        self.labels = label

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index`` after applying the transforms.

        The stored array is cast to uint8 and converted to a channels-last PIL
        image; single-channel images are replicated to three channels first.
        """
        img, target = self.data[index], self.labels[index]
        if img.shape[0] == 1:
            im = np.uint8(np.asarray(img))
            im = np.vstack([im, im, im]).transpose((1, 2, 0))
            img = Image.fromarray(im)
        else:
            img = Image.fromarray(np.uint8(np.asarray(img.transpose((1, 2, 0)))))

        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.data)
# --------------------------------------------------------------------------------
# /digit_signal_classification/datasets/gtsrb.py:
# --------------------------------------------------------------------------------
import numpy as np
import cPickle as pkl


def load_gtsrb(num_train=31367):
    """Load the GTSRB pickle and split it randomly into train/test.

    Args:
        num_train: number of shuffled samples in the training split; the rest
            forms the test split (default 31367).

    Returns:
        ``(train_images, train_labels, test_images, test_labels)``. Images are
        transposed with axes (0, 3, 1, 2) and cast to float32 (presumably
        NHWC -> NCHW - confirm against the pickle).
    """
    # NOTE(review): unpickling runs code from the file; only load trusted data.
    with open('/data/digit_sign/data_gtsrb') as f:
        data_target = pkl.load(f)
    images = data_target['image']
    labels = data_target['label']

    perm = np.random.permutation(len(images))
    train_idx, test_idx = perm[:num_train], perm[num_train:]

    data_t_im = images[train_idx, :, :, :].transpose(0, 3, 1, 2).astype(np.float32)
    data_t_im_test = images[test_idx, :, :, :].transpose(0, 3, 1, 2).astype(np.float32)
    # NOTE(review): labels are shifted by +1 here but not in load_syntraffic,
    # while the synth->GTSRB head has 43 outputs - confirm the label range.
    data_t_label = labels[train_idx] + 1
    data_t_label_test = labels[test_idx] + 1
    return data_t_im, data_t_label, data_t_im_test, data_t_label_test
# --------------------------------------------------------------------------------
# /digit_signal_classification/datasets/mnist.py:
# --------------------------------------------------------------------------------
import numpy as np
from scipy.io import loadmat


def load_mnist(scale=True, usps=False, all_use='no'):
    """Load MNIST from ``/data/digit_sign/mnist_data.mat``.

    Args:
        scale: if true, use the 32x32 images replicated to 3 channels;
            otherwise use the 28x28 images.
        usps: when true and ``all_use != 'yes'``, keep only 2000 shuffled
            training samples.
        all_use: 'yes' keeps the full training set in the USPS setting.

    Returns:
        ``(train_images, train_labels, test_images, test_labels)``. Images are
        float32 and channels-first; labels are the argmax of the stored label
        arrays. The training split is shuffled.
    """
    mnist_data = loadmat('/data/digit_sign/mnist_data.mat')
    if scale:
        # Infer the sample count (-1) instead of hard-coding 55000 / 10000.
        mnist_train = np.reshape(mnist_data['train_32'], (-1, 32, 32, 1))
        mnist_test = np.reshape(mnist_data['test_32'], (-1, 32, 32, 1))
        # Replicate the single grey channel to get 3-channel images.
        mnist_train = np.concatenate([mnist_train, mnist_train, mnist_train], 3)
        mnist_test = np.concatenate([mnist_test, mnist_test, mnist_test], 3)
        mnist_train = mnist_train.transpose(0, 3, 1, 2).astype(np.float32)
        mnist_test = mnist_test.transpose(0, 3, 1, 2).astype(np.float32)
    else:
        mnist_train = mnist_data['train_28'].astype(np.float32).transpose((0, 3, 1, 2))
        mnist_test = mnist_data['test_28'].astype(np.float32).transpose((0, 3, 1, 2))

    train_label = np.argmax(mnist_data['label_train'], axis=1)
    test_label = np.argmax(mnist_data['label_test'], axis=1)

    inds = np.random.permutation(mnist_train.shape[0])
    mnist_train = mnist_train[inds]
    train_label = train_label[inds]

    if usps and all_use != 'yes':
        mnist_train = mnist_train[:2000]
        train_label = train_label[:2000]

    return mnist_train, train_label, mnist_test, test_label
# --------------------------------------------------------------------------------
# /digit_signal_classification/datasets/svhn.py:
# --------------------------------------------------------------------------------
from scipy.io import loadmat
import numpy as np
import sys
# NOTE(review): '../utils/' is appended, yet the import below is
# 'utils.utils', which needs the parent directory on sys.path - confirm.
sys.path.append('../utils/')
from utils.utils import dense_to_one_hot


def load_svhn():
    """Load the SVHN train/test ``.mat`` files from /data/digit_sign.

    Images are transposed with axes (3, 2, 0, 1) and cast to float32
    (presumably (H, W, C, N) -> (N, C, H, W) - confirm). Labels are passed
    through ``dense_to_one_hot``; see that helper for the label encoding.
    """
    svhn_train = loadmat('/data/digit_sign/train_32x32.mat')
    svhn_test = loadmat('/data/digit_sign/test_32x32.mat')

    svhn_train_im = svhn_train['X'].transpose(3, 2, 0, 1).astype(np.float32)
    svhn_label = dense_to_one_hot(svhn_train['y'])
    svhn_test_im = svhn_test['X'].transpose(3, 2, 0, 1).astype(np.float32)
    svhn_label_test = dense_to_one_hot(svhn_test['y'])

    return svhn_train_im, svhn_label, svhn_test_im, svhn_label_test
# --------------------------------------------------------------------------------
# /digit_signal_classification/datasets/synth_traffic.py:
# --------------------------------------------------------------------------------
import numpy as np
import cPickle as pkl


def load_syntraffic():
    """Load the synthetic traffic-sign pickle and shuffle it.

    The training split is the whole shuffled dataset. The test split is the
    last 2000 shuffled samples, so every test sample also appears in the
    training split.

    Returns:
        ``(train_images, train_labels, test_images, test_labels)``. Images are
        transposed with axes (0, 3, 1, 2) and cast to float32.
    """
    # NOTE(review): unpickling runs code from the file; only load trusted data.
    with open('/data/digit_sign/data_synthetic') as f:
        data_source = pkl.load(f)
    images = data_source['image']
    labels = data_source['label']

    perm = np.random.permutation(len(images))
    test_idx = perm[len(images) - 2000:]

    data_s_im = images[perm, :, :, :].transpose(0, 3, 1, 2).astype(np.float32)
    data_s_im_test = images[test_idx, :, :, :].transpose(0, 3, 1, 2).astype(np.float32)
    data_s_label = labels[perm]
    data_s_label_test = labels[test_idx]
    return data_s_im, data_s_label, data_s_im_test, data_s_label_test
# --------------------------------------------------------------------------------
# /digit_signal_classification/datasets/unaligned_data_loader.py:
# --------------------------------------------------------------------------------
import torch.utils.data
from builtins import object
import torchvision.transforms as transforms
from datasets import Dataset


class PairedData(object):
    """Iterate over two data loaders in lockstep.

    When a loader runs out it is restarted and marked as exhausted. Iteration
    stops once both have been exhausted at least once, or once more than
    ``max_dataset_size`` steps have been taken. Each step yields a dict with
    keys 'S', 'S_label', 'T' and 'T_label'.
    """

    def __init__(self, data_loader_A, data_loader_B, max_dataset_size):
        self.data_loader_A = data_loader_A
        self.data_loader_B = data_loader_B
        self.stop_A = False
        self.stop_B = False
        self.max_dataset_size = max_dataset_size

    def __iter__(self):
        self.stop_A = False
        self.stop_B = False
        self.data_loader_A_iter = iter(self.data_loader_A)
        self.data_loader_B_iter = iter(self.data_loader_B)
        self.iter = 0
        return self

    def __next__(self):
        try:
            A, A_paths = next(self.data_loader_A_iter)
        except StopIteration:
            # A is exhausted: remember that and start another pass over it.
            self.stop_A = True
            self.data_loader_A_iter = iter(self.data_loader_A)
            A, A_paths = next(self.data_loader_A_iter)

        try:
            B, B_paths = next(self.data_loader_B_iter)
        except StopIteration:
            self.stop_B = True
            self.data_loader_B_iter = iter(self.data_loader_B)
            B, B_paths = next(self.data_loader_B_iter)

        if (self.stop_A and self.stop_B) or self.iter > self.max_dataset_size:
            self.stop_A = False
            self.stop_B = False
            raise StopIteration()

        self.iter += 1
        return {'S': A, 'S_label': A_paths,
                'T': B, 'T_label': B_paths}


class UnalignedDataLoader():
    def initialize(self, source, target, batch_size1, batch_size2, scale=32, shuffle_=False,
                   num_workers=4):
        """Build the source/target ``DataLoader``s and pair them.

        Args:
            source: dict with 'imgs' and 'labels' for the source domain.
            target: dict with 'imgs' and 'labels' for the target domain.
            batch_size1: source batch size.
            batch_size2: target batch size.
            scale: size passed to ``transforms.Resize``.
            shuffle_: whether both loaders shuffle.
            num_workers: worker processes per ``DataLoader`` (default 4).
        """
        transform = transforms.Compose([
            transforms.Resize(scale),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        dataset_source = Dataset(source['imgs'], source['labels'], transform=transform)
        dataset_target = Dataset(target['imgs'], target['labels'], transform=transform)

        data_loader_s = torch.utils.data.DataLoader(
            dataset_source,
            batch_size=batch_size1,
            shuffle=shuffle_,
            num_workers=num_workers
        )
        data_loader_t = torch.utils.data.DataLoader(
            dataset_target,
            batch_size=batch_size2,
            shuffle=shuffle_,
            num_workers=num_workers
        )

        self.dataset_s = dataset_source
        self.dataset_t = dataset_target
        self.paired_data = PairedData(data_loader_s, data_loader_t, float("inf"))

    @staticmethod
    def name():
        return 'UnalignedDataLoader'

    def load_data(self):
        return self.paired_data

    def __len__(self):
        # Number of samples in the larger of the two datasets.
        return max(len(self.dataset_s), len(self.dataset_t))
# --------------------------------------------------------------------------------
# /digit_signal_classification/datasets/usps.py:
# --------------------------------------------------------------------------------
import numpy as np
import gzip
import cPickle


def load_usps(all_use='no'):
    """Load USPS from the gzip pickle ``/data/digit_sign/usps_28x28.pkl``.

    Args:
        all_use: 'yes' keeps 6562 shuffled training images, otherwise 1800.

    Returns:
        ``(train_images, train_labels, test_images, test_labels)``. Images are
        multiplied by 255 and reshaped to ``(N, 1, 28, 28)``.
    """
    # NOTE(review): unpickling runs code from the file; only load trusted data.
    with gzip.open('/data/digit_sign/usps_28x28.pkl', 'rb') as f:
        data_set = cPickle.load(f)
    img_train, label_train = data_set[0][0], data_set[0][1]
    img_test, label_test = data_set[1][0], data_set[1][1]

    inds = np.random.permutation(img_train.shape[0])
    num_train = 6562 if all_use == 'yes' else 1800
    img_train = img_train[inds][:num_train]
    label_train = label_train[inds][:num_train]

    img_train = img_train * 255
    img_test = img_test * 255
    img_train = img_train.reshape((img_train.shape[0], 1, 28, 28))
    img_test = img_test.reshape((img_test.shape[0], 1, 28, 28))
    return img_train, label_train, img_test, label_test
# --------------------------------------------------------------------------------
# /digit_signal_classification/doc/architecture.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/doc/architecture.jpg
# --------------------------------------------------------------------------------
# /digit_signal_classification/doc/architecture_small.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/doc/architecture_small.jpg
# --------------------------------------------------------------------------------
# /digit_signal_classification/doc/architecture_small.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/doc/architecture_small.png
# --------------------------------------------------------------------------------
# /digit_signal_classification/main.py:
# --------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import torch
from solver import Solver
import os
import time

# Training settings
parser = argparse.ArgumentParser(description='PyTorch Stochastic Classifiers Implementation')
parser.add_argument('--all_use', type=str, default='no', metavar='N',
                    help='use all training data? in usps adaptation')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', metavar='N',
                    help='directory where checkpoints are stored')
parser.add_argument('--eval_only', action='store_true', default=False,
                    help='evaluation only option')
parser.add_argument('--lr', type=float, default=0.0002, metavar='LR',
                    help='learning rate (default: 0.0002)')
parser.add_argument('--max_epoch', type=int, default=100, metavar='N',
                    help='how many epochs')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--num_k', type=int, default=4, metavar='N',
                    help='hyper parameter for generator update')
parser.add_argument('--one_step', action='store_true', default=False,
                    help='one step training with gradient reversal layer')
parser.add_argument('--optimizer', type=str, default='momentum', metavar='N',
                    help='which optimizer')
parser.add_argument('--resume_epoch', type=int, default=100, metavar='N',
                    help='epoch to resume')
parser.add_argument('--save_epoch', type=int, default=10, metavar='N',
                    help='when to restore the model')
parser.add_argument('--save_model', action='store_true', default=False,
                    help='save_model or not')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--source', type=str, default='svhn', metavar='N',
                    help='source dataset')
parser.add_argument('--target', type=str, default='mnist', metavar='N',
                    help='target dataset')
parser.add_argument('--use_abs_diff', action='store_true', default=False,
                    help='use absolute difference value as a measurement')
parser.add_argument('--num_classifiers_train', type=int, default=2, metavar='N',
                    help='the number of classifiers used in training')
parser.add_argument('--num_classifiers_test', type=int, default=20, metavar='N',
                    help='the number of classifiers used in testing')
parser.add_argument('--gpu_devices', type=str, default='0', help='the device you use')
parser.add_argument('--loss_process', type=str, default='sum',
                    help='mean or sum of the loss')
parser.add_argument('--log_dir', type=str, default='record', metavar='N',
                    help='the place to store the logs')
parser.add_argument('--init', type=str, default='kaiming_u', metavar='N',
                    help='the initialization method')
parser.add_argument('--use_init', action='store_true', default=False,
                    help='whether use initialization')

args = parser.parse_args()

# CUDA reads CUDA_VISIBLE_DEVICES when it initialises, and
# torch.cuda.is_available() can trigger that initialisation. Export the
# variable before the availability check so it actually takes effect.
if not args.no_cuda and args.gpu_devices:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)

if args.cuda:
    torch.cuda.manual_seed(args.seed)

print(args)


def main():
    """Run training with per-epoch evaluation, or a single evaluation pass.

    Creates the checkpoint, log and per-task record directories when missing,
    and appends the parsed configuration to the test log before starting.
    """
    solver = Solver(
        args, source=args.source,
        target=args.target,
        learning_rate=args.lr,
        batch_size=args.batch_size,
        optimizer=args.optimizer,
        num_k=args.num_k,
        all_use=args.all_use,
        checkpoint_dir=args.checkpoint_dir,
        save_epoch=args.save_epoch,
        num_classifiers_train=args.num_classifiers_train,
        num_classifiers_test=args.num_classifiers_test,
        init=args.init,
        use_init=args.use_init
    )

    record_time = time.strftime('%Y-%m-%d-%H-%M-%S')
    # One record directory per source/target task; it does not depend on --use_init.
    record_dir = '{}/{}_{}'.format(args.log_dir, args.source, args.target)

    # Order matters: record_dir lives inside log_dir.
    for directory in (args.checkpoint_dir, args.log_dir, record_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    record_train = '{}/train_{}.out'.format(record_dir, record_time)
    record_test = '{}/test_{}.out'.format(record_dir, record_time)

    # Log the configures into log file
    with open(record_test, 'a') as record:
        record.write('Configures: {} \n'.format(args))

    if args.eval_only:
        solver.test(0)
    else:
        for epoch in range(args.max_epoch):
            solver.train(
                epoch,
                record_file=record_train,
                loss_process=args.loss_process
            )
            # Evaluate after every epoch.
            solver.test(
                epoch,
                record_file=record_test,
                save_model=args.save_model
            )


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------
# /digit_signal_classification/model/__init__.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/model/__init__.pyc
# --------------------------------------------------------------------------------
# /digit_signal_classification/model/__pycache__/__init__.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/model/__pycache__/__init__.cpython-36.pyc
# --------------------------------------------------------------------------------
# /digit_signal_classification/model/__pycache__/__init__.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/model/__pycache__/__init__.cpython-37.pyc
# --------------------------------------------------------------------------------
# /digit_signal_classification/model/__pycache__/build_gen.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/model/__pycache__/build_gen.cpython-36.pyc
# --------------------------------------------------------------------------------
# /digit_signal_classification/model/__pycache__/build_gen.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/model/__pycache__/build_gen.cpython-37.pyc
# --------------------------------------------------------------------------------
# /digit_signal_classification/model/build_gen.py:
# --------------------------------------------------------------------------------
import svhn2mnist
import usps
import syn2gtrsb


def Generator(source, target):
    """Return the feature extractor for a source/target pair.

    Any pair involving 'usps' uses ``usps.Feature``. Otherwise a 'svhn' source
    uses ``svhn2mnist.Feature`` and a 'synth' source uses ``syn2gtrsb.Feature``.

    Raises:
        ValueError: if no feature extractor matches the pair (previously this
            silently returned None).
    """
    if source == 'usps' or target == 'usps':
        return usps.Feature()
    if source == 'svhn':
        return svhn2mnist.Feature()
    if source == 'synth':
        return syn2gtrsb.Feature()
    raise ValueError('Unsupported source/target pair: {} -> {}'.format(source, target))


def Classifier(
    source,
    target,
    num_classifiers_train=2,
    num_classifiers_test=1,
    init='kaiming_u',
    use_init=False
):
    """Return the stochastic classifier head for a source/target pair.

    Selection mirrors ``Generator``. The remaining arguments are forwarded
    positionally as ``(num_classifiers_train, num_classifiers_test, init,
    use_init)``.

    Raises:
        ValueError: if no classifier matches the pair (previously this
            silently returned None).
    """
    if source == 'usps' or target == 'usps':
        return usps.Predictor(
            num_classifiers_train,
            num_classifiers_test,
            init,
            use_init
        )

    if source == 'svhn':
        return svhn2mnist.Predictor(
            num_classifiers_train,
            num_classifiers_test,
            init,
            use_init
        )

    if source == 'synth':
        return syn2gtrsb.Predictor(
            num_classifiers_train,
            num_classifiers_test,
            init,
            use_init
        )

    raise ValueError('Unsupported source/target pair: {} -> {}'.format(source, target))
# --------------------------------------------------------------------------------
# --------------------------------------------------------------------------------
# /digit_signal_classification/model/build_gen.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/model/build_gen.pyc
# --------------------------------------------------------------------------------
# /digit_signal_classification/model/svhn2mnist.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torch.distributions.normal as normal


class Feature(nn.Module):
    """Convolutional feature extractor producing 3072-d features.

    The conv3 output is flattened to 8192 values (128 x 8 x 8). With the
    padding and pooling below, that corresponds to 3-channel 32x32 inputs.
    """

    def __init__(self):
        super(Feature, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2)
        self.bn3 = nn.BatchNorm2d(128)
        self.fc1 = nn.Linear(8192, 3072)
        self.bn1_fc = nn.BatchNorm1d(3072)

    def forward(self, x):
        x = F.max_pool2d(
            F.relu(self.bn1(self.conv1(x))),
            stride=2,
            kernel_size=3,
            padding=1
        )
        x = F.max_pool2d(
            F.relu(self.bn2(self.conv2(x))),
            stride=2,
            kernel_size=3,
            padding=1
        )
        x = F.relu(self.bn3(self.conv3(x)))
        x = x.view(x.size(0), 8192)
        x = F.relu(self.bn1_fc(self.fc1(x)))
        x = F.dropout(x, training=self.training)
        return x


class Predictor(nn.Module):
    """Stochastic 10-way classifier with Gaussian weights N(mu2, sigmoid(sigma2)).

    Args:
        num_classifiers_train: weight samples drawn per forward pass in
            training mode.
        num_classifiers_test: weight samples drawn in eval mode when
            ``only_mu`` is False.
        init: key into ``function_init`` ('kaiming_u', 'kaiming_n', 'xavier').
        use_init: if true, re-initialise ``mu2`` and ``sigma2`` with ``init``.
    """

    def __init__(
        self, num_classifiers_train=2,
        num_classifiers_test=20,
        init='kaiming_u',
        use_init=False
    ):
        super(Predictor, self).__init__()
        self.num_classifiers_train = num_classifiers_train
        self.num_classifiers_test = num_classifiers_test
        self.init = init

        function_init = {
            'kaiming_u': nn.init.kaiming_uniform_,
            'kaiming_n': nn.init.kaiming_normal_,
            'xavier': nn.init.xavier_normal_
        }

        # Creation order is kept so parameter registration and RNG use match.
        self.fc1 = nn.Linear(3072, 2048)

        self.mu2 = Parameter(torch.randn(10, 2048))
        self.sigma2 = Parameter(torch.zeros(10, 2048))

        if use_init:
            for item in (self.mu2, self.sigma2):
                function_init[self.init](item)

        self.b2 = Parameter(torch.zeros(10))
        self.bn1_fc = nn.BatchNorm1d(2048)

    def _sample_outputs(self, x, num_samples):
        """Return logits from ``num_samples`` weight matrices sampled from the distribution."""
        # sigmoid keeps the std positive (0.5 for the zero-initialised sigma2).
        sigma2_pos = torch.sigmoid(self.sigma2)
        fc2_distribution = normal.Normal(self.mu2, sigma2_pos)
        weights = [fc2_distribution.rsample() for _ in range(num_samples)]
        return [F.linear(x, w, self.b2) for w in weights]

    def forward(self, x, only_mu=True):
        """Return a list of logit tensors.

        In training mode the list holds ``num_classifiers_train`` sampled
        classifiers. In eval mode it holds only the mean classifier when
        ``only_mu`` is true, else ``num_classifiers_test`` sampled classifiers.
        """
        x = self.fc1(x)
        x = F.relu(self.bn1_fc(x))

        if self.training:
            return self._sample_outputs(x, self.num_classifiers_train)
        if only_mu:
            # Only use mu for classification
            return [F.linear(x, self.mu2, self.b2)]
        return self._sample_outputs(x, self.num_classifiers_test)
# --------------------------------------------------------------------------------
# /digit_signal_classification/model/svhn2mnist.pyc:
# --------------------------------------------------------------------------------
# digit_signal_classification/model/syn2gtrsb.py
# Feature extractor and stochastic classifier used by the syn2gtrsb models.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torch.distributions.normal as normal


class Feature(nn.Module):
    """Three conv/BN/ReLU/max-pool stages producing a flat 6400-d feature."""

    def __init__(self):
        super(Feature, self).__init__()
        self.conv1 = nn.Conv2d(3, 96, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(96)
        self.conv2 = nn.Conv2d(96, 144, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(144)
        self.conv3 = nn.Conv2d(144, 256, kernel_size=5, stride=1, padding=2)
        self.bn3 = nn.BatchNorm2d(256)

    def forward(self, x):
        """Return a ``(batch, 6400)`` feature tensor.

        6400 = 256 channels * 5 * 5, consistent with 40x40 inputs after three
        stride-2 pools (TODO confirm against the data loader).
        """
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, bn in stages:
            x = F.max_pool2d(
                F.relu(bn(conv(x))),
                stride=2,
                kernel_size=2,
                padding=0
            )
        return x.view(x.size(0), 6400)


class Predictor(nn.Module):
    """43-way classifier whose last layer is a Gaussian over the weights.

    Args:
        num_classifiers_train: weight samples drawn per training forward pass.
        num_classifiers_test: weight samples drawn in eval mode when
            ``only_mu`` is False.
        use_init: if True, re-initialise ``fc2_mu`` with the ``init`` scheme.
        init: key into the initialiser table ('kaiming_u', 'kaiming_n',
            'xavier'); only looked up when ``use_init`` is True.
    """

    def __init__(
            self,
            num_classifiers_train=2,
            num_classifiers_test=1,
            use_init=False,
            init='kaiming_u'
    ):
        super(Predictor, self).__init__()
        self.num_classifiers_train = num_classifiers_train
        self.num_classifiers_test = num_classifiers_test
        # ``init`` was previously read here without being a parameter, so
        # every construction raised NameError. It is the LAST parameter so
        # existing positional calls keep their meaning.
        self.init = init
        function_init = {
            'kaiming_u': nn.init.kaiming_uniform_,
            'kaiming_n': nn.init.kaiming_normal_,
            'xavier': nn.init.xavier_normal_
        }

        self.fc1 = nn.Linear(6400, 512)
        self.bn1_fc = nn.BatchNorm1d(512)

        # Mean, (pre-softplus) scale and shared bias of the 43x512 last layer.
        self.fc2_mu = Parameter(torch.randn(43, 512))
        self.fc2_sigma = Parameter(torch.zeros(43, 512))
        self.fc2_bias = Parameter(torch.zeros(43))

        if use_init:
            function_init[self.init](self.fc2_mu)

    def _sample_logits(self, x, distribution, count):
        """Draw ``count`` weight matrices and return one logit tensor per draw."""
        weights = [distribution.rsample() for _ in range(count)]
        return [F.linear(x, w, self.fc2_bias) for w in weights]

    def forward(self, x, only_mu=True):
        """Return a list of logit tensors.

        Training mode: ``num_classifiers_train`` sampled classifiers.
        Eval mode: the mean classifier only (``only_mu=True``) or
        ``num_classifiers_test`` sampled classifiers.
        """
        x = F.relu(self.bn1_fc(self.fc1(x)))
        x = F.dropout(x, training=self.training)

        # softplus keeps the std positive; the -2 shift gives an initial std
        # of softplus(-2) ~= 0.127 because fc2_sigma starts at zero.
        fc2_sigma_pos = F.softplus(self.fc2_sigma - 2)
        fc2_distribution = normal.Normal(self.fc2_mu, fc2_sigma_pos)

        if self.training:
            return self._sample_logits(
                x, fc2_distribution, self.num_classifiers_train
            )
        if only_mu:
            # Deterministic prediction using the mean weights.
            return [F.linear(x, self.fc2_mu, self.fc2_bias)]
        return self._sample_logits(
            x, fc2_distribution, self.num_classifiers_test
        )
# digit_signal_classification/model/usps.py
# Feature extractor and stochastic classifier used by the usps models.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torch.distributions.normal as normal


class Feature(nn.Module):
    """Channel-averaged input followed by two conv/BN/ReLU/max-pool stages."""

    def __init__(self):
        super(Feature, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, stride=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 48, kernel_size=5, stride=1)
        self.bn2 = nn.BatchNorm2d(48)

    def forward(self, x):
        """Return a ``(batch, 768)`` feature tensor (768 = 48 * 4 * 4).

        The 4x4 spatial size is consistent with 28x28 inputs (TODO confirm
        against the data loader).
        """
        n, h, w = x.size(0), x.size(2), x.size(3)
        # Average over the channel axis so the convs see a single channel.
        feat = torch.mean(x, 1).view(n, 1, h, w)

        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            feat = F.max_pool2d(
                F.relu(bn(conv(feat))),
                stride=2,
                kernel_size=2,
                dilation=(1, 1)
            )

        return feat.view(feat.size(0), 48*4*4)


class Predictor(nn.Module):
    """Two hidden fc layers, then a 10-way Gaussian-weight output layer.

    Args:
        num_classifiers_train: weight samples drawn per training forward pass.
        num_classifiers_test: weight samples drawn in eval mode when
            ``only_mu`` is False.
        init: key into the initialiser table ('kaiming_u', 'kaiming_n',
            'xavier'); only looked up when ``use_init`` is True.
        use_init: if True, re-initialise ``fc3_mu`` with the ``init`` scheme.
        prob: dropout probability applied before ``fc1`` and before ``fc2``
            in training mode.
    """

    def __init__(
            self,
            num_classifiers_train=2,
            num_classifiers_test=20,
            init='kaiming_u',
            use_init=False,
            prob=0.5
    ):
        super(Predictor, self).__init__()
        self.num_classifiers_train = num_classifiers_train
        self.num_classifiers_test = num_classifiers_test
        self.prob = prob
        self.init = init

        initializers = {
            'kaiming_u': nn.init.kaiming_uniform_,
            'kaiming_n': nn.init.kaiming_normal_,
            'xavier': nn.init.xavier_normal_
        }

        self.fc1 = nn.Linear(48*4*4, 100)
        self.bn1_fc = nn.BatchNorm1d(100)

        self.fc2 = nn.Linear(100, 100)
        self.bn2_fc = nn.BatchNorm1d(100)

        # Mean, (pre-softplus) scale and shared bias of the output layer.
        self.fc3_mu = Parameter(torch.randn(10, 100))
        self.fc3_sigma = Parameter(torch.zeros(10, 100))
        self.fc3_bias = Parameter(torch.zeros(10))

        if use_init:
            initializers[init](self.fc3_mu)

    def _sample_logits(self, h, distribution, count):
        """Draw ``count`` weight matrices and return one logit tensor per draw."""
        draws = [distribution.rsample() for _ in range(count)]
        return [F.linear(h, w, self.fc3_bias) for w in draws]

    def forward(self, x, only_mu=True):
        """Return a list of logit tensors.

        Training mode: ``num_classifiers_train`` sampled classifiers.
        Eval mode: the mean classifier only (``only_mu=True``) or
        ``num_classifiers_test`` sampled classifiers.
        """
        h = x
        for fc, bn in ((self.fc1, self.bn1_fc), (self.fc2, self.bn2_fc)):
            h = F.dropout(h, training=self.training, p=self.prob)
            h = F.relu(bn(fc(h)))

        # softplus keeps the std positive; the -2 shift gives an initial std
        # of softplus(-2) ~= 0.127 because fc3_sigma starts at zero.
        std = F.softplus(self.fc3_sigma - 2)
        weight_dist = normal.Normal(self.fc3_mu, std)

        if self.training:
            return self._sample_logits(h, weight_dist, self.num_classifiers_train)
        if only_mu:
            # Deterministic prediction using the mean weights.
            return [F.linear(h, self.fc3_mu, self.fc3_bias)]
        return self._sample_logits(h, weight_dist, self.num_classifiers_test)
num_classifiers_train=2, num_k=4, one_step=False, optimizer='adam', resume_epoch=100, save_epoch=10, save_model=False, seed=1, source='mnist', target='usps', use_abs_diff=False, use_init=False) 2 | Accuracy: 58.66 3 | Accuracy: 69.14 4 | Accuracy: 73.76 5 | Accuracy: 77.37 6 | Accuracy: 80.91 7 | Accuracy: 81.88 8 | Accuracy: 84.19 9 | Accuracy: 85.16 10 | Accuracy: 86.94 11 | Accuracy: 88.33 12 | Accuracy: 88.60 13 | Accuracy: 88.98 14 | Accuracy: 89.73 15 | Accuracy: 89.89 16 | Accuracy: 89.95 17 | Accuracy: 90.22 18 | Accuracy: 91.08 19 | Accuracy: 91.83 20 | Accuracy: 92.42 21 | Accuracy: 91.56 22 | Accuracy: 92.80 23 | Accuracy: 93.12 24 | Accuracy: 92.85 25 | Accuracy: 93.23 26 | Accuracy: 93.98 27 | Accuracy: 93.87 28 | Accuracy: 94.03 29 | Accuracy: 93.55 30 | Accuracy: 94.30 31 | Accuracy: 93.98 32 | Accuracy: 94.25 33 | Accuracy: 94.14 34 | Accuracy: 94.62 35 | Accuracy: 94.68 36 | Accuracy: 94.68 37 | Accuracy: 95.22 38 | Accuracy: 95.11 39 | Accuracy: 95.48 40 | Accuracy: 95.11 41 | Accuracy: 95.22 42 | Accuracy: 94.73 43 | Accuracy: 95.43 44 | Accuracy: 95.16 45 | Accuracy: 95.27 46 | Accuracy: 95.16 47 | Accuracy: 94.95 48 | Accuracy: 95.11 49 | Accuracy: 95.97 50 | Accuracy: 95.86 51 | Accuracy: 95.59 52 | Accuracy: 95.81 53 | Accuracy: 95.75 54 | Accuracy: 95.65 55 | Accuracy: 94.95 56 | Accuracy: 95.81 57 | Accuracy: 95.70 58 | Accuracy: 95.91 59 | Accuracy: 96.34 60 | Accuracy: 95.48 61 | Accuracy: 96.18 62 | Accuracy: 96.24 63 | Accuracy: 96.34 64 | Accuracy: 96.61 65 | Accuracy: 96.45 66 | Accuracy: 96.29 67 | Accuracy: 96.45 68 | Accuracy: 96.40 69 | Accuracy: 96.02 70 | Accuracy: 95.75 71 | Accuracy: 96.18 72 | Accuracy: 96.13 73 | Accuracy: 96.99 74 | Accuracy: 96.51 75 | Accuracy: 96.51 76 | Accuracy: 96.24 77 | Accuracy: 96.45 78 | Accuracy: 97.04 79 | Accuracy: 96.67 80 | Accuracy: 97.10 81 | Accuracy: 97.37 82 | Accuracy: 97.04 83 | Accuracy: 96.45 84 | Accuracy: 96.99 85 | Accuracy: 96.40 86 | Accuracy: 97.31 87 | Accuracy: 96.45 88 | 
Accuracy: 96.88 89 | Accuracy: 97.42 90 | Accuracy: 96.83 91 | Accuracy: 97.20 92 | Accuracy: 96.99 93 | Accuracy: 96.83 94 | Accuracy: 96.72 95 | Accuracy: 97.26 96 | Accuracy: 97.15 97 | Accuracy: 96.83 98 | Accuracy: 96.99 99 | Accuracy: 97.20 100 | Accuracy: 97.04 101 | Accuracy: 97.42 102 | Accuracy: 97.26 103 | Accuracy: 97.04 104 | Accuracy: 96.94 105 | Accuracy: 96.94 106 | Accuracy: 97.42 107 | Accuracy: 97.42 108 | Accuracy: 97.42 109 | Accuracy: 97.15 110 | Accuracy: 97.31 111 | Accuracy: 97.26 112 | Accuracy: 97.04 113 | Accuracy: 97.15 114 | Accuracy: 97.10 115 | Accuracy: 97.37 116 | Accuracy: 97.26 117 | Accuracy: 96.83 118 | Accuracy: 97.53 119 | Accuracy: 97.15 120 | Accuracy: 97.47 121 | Accuracy: 97.31 122 | Accuracy: 97.10 123 | Accuracy: 97.10 124 | Accuracy: 96.94 125 | Accuracy: 97.37 126 | Accuracy: 97.15 127 | Accuracy: 96.94 128 | Accuracy: 97.20 129 | Accuracy: 97.69 130 | Accuracy: 96.99 131 | Accuracy: 97.42 132 | Accuracy: 97.26 133 | Accuracy: 97.42 134 | Accuracy: 97.53 135 | Accuracy: 97.20 136 | Accuracy: 97.20 137 | Accuracy: 97.26 138 | Accuracy: 97.31 139 | Accuracy: 97.69 140 | Accuracy: 97.42 141 | Accuracy: 97.63 142 | Accuracy: 96.99 143 | Accuracy: 97.37 144 | Accuracy: 97.20 145 | Accuracy: 97.47 146 | Accuracy: 97.69 147 | Accuracy: 97.42 148 | Accuracy: 97.63 149 | Accuracy: 97.15 150 | Accuracy: 97.63 151 | Accuracy: 96.99 152 | Accuracy: 96.94 153 | Accuracy: 97.42 154 | Accuracy: 97.10 155 | Accuracy: 97.58 156 | Accuracy: 97.26 157 | Accuracy: 97.26 158 | Accuracy: 97.15 159 | Accuracy: 97.53 160 | Accuracy: 97.15 161 | Accuracy: 97.74 162 | Accuracy: 97.15 163 | Accuracy: 97.15 164 | Accuracy: 97.63 165 | Accuracy: 97.37 166 | Accuracy: 97.63 167 | Accuracy: 97.20 168 | Accuracy: 97.85 169 | Accuracy: 96.94 170 | Accuracy: 97.53 171 | Accuracy: 97.31 172 | Accuracy: 97.63 173 | Accuracy: 97.15 174 | Accuracy: 97.37 175 | Accuracy: 97.53 176 | Accuracy: 97.69 177 | Accuracy: 97.26 178 | Accuracy: 97.20 179 | 
Accuracy: 97.37 180 | Accuracy: 97.20 181 | Accuracy: 97.53 182 | Accuracy: 97.80 183 | Accuracy: 97.69 184 | Accuracy: 97.85 185 | Accuracy: 97.63 186 | Accuracy: 97.85 187 | Accuracy: 97.63 188 | Accuracy: 97.80 189 | Accuracy: 97.69 190 | Accuracy: 97.69 191 | Accuracy: 97.47 192 | Accuracy: 97.15 193 | Accuracy: 97.74 194 | Accuracy: 97.74 195 | Accuracy: 97.37 196 | Accuracy: 97.26 197 | Accuracy: 97.15 198 | Accuracy: 97.15 199 | Accuracy: 97.74 200 | Accuracy: 97.26 201 | Accuracy: 97.80 202 | -------------------------------------------------------------------------------- /digit_signal_classification/record/mnist_usps/train_2020-07-22-17-50-42.out: -------------------------------------------------------------------------------- 1 | Dis Loss: 0.0253108087927, Cls Loss: 7.70617485046, Lr C: 0.0002, Lr G: 0.0002 2 | Dis Loss: 0.0257472600788, Cls Loss: 2.13100028038, Lr C: 0.0002, Lr G: 0.0002 3 | Dis Loss: 0.0211079176515, Cls Loss: 0.869371533394, Lr C: 0.0002, Lr G: 0.0002 4 | Dis Loss: 0.0197001192719, Cls Loss: 0.90205258131, Lr C: 0.0002, Lr G: 0.0002 5 | Dis Loss: 0.0156597942114, Cls Loss: 0.794815421104, Lr C: 0.0002, Lr G: 0.0002 6 | Dis Loss: 0.0165562909096, Cls Loss: 0.698212444782, Lr C: 0.0002, Lr G: 0.0002 7 | Dis Loss: 0.0161991883069, Cls Loss: 0.591216087341, Lr C: 0.0002, Lr G: 0.0002 8 | Dis Loss: 0.0116136167198, Cls Loss: 0.447472929955, Lr C: 0.0002, Lr G: 0.0002 9 | Dis Loss: 0.0123001784086, Cls Loss: 0.585250616074, Lr C: 0.0002, Lr G: 0.0002 10 | Dis Loss: 0.0130175296217, Cls Loss: 0.244773492217, Lr C: 0.0002, Lr G: 0.0002 11 | Dis Loss: 0.0131531087682, Cls Loss: 0.251306712627, Lr C: 0.0002, Lr G: 0.0002 12 | Dis Loss: 0.0155279133469, Cls Loss: 0.4324208498, Lr C: 0.0002, Lr G: 0.0002 13 | Dis Loss: 0.00917468499392, Cls Loss: 0.191932618618, Lr C: 0.0002, Lr G: 0.0002 14 | Dis Loss: 0.0103086763993, Cls Loss: 0.259280085564, Lr C: 0.0002, Lr G: 0.0002 15 | Dis Loss: 0.0101417507976, Cls Loss: 0.349472016096, Lr C: 0.0002, Lr 
G: 0.0002 16 | Dis Loss: 0.0107553582639, Cls Loss: 0.242658004165, Lr C: 0.0002, Lr G: 0.0002 17 | Dis Loss: 0.00794821791351, Cls Loss: 0.169440254569, Lr C: 0.0002, Lr G: 0.0002 18 | Dis Loss: 0.00926570873708, Cls Loss: 0.170461535454, Lr C: 0.0002, Lr G: 0.0002 19 | Dis Loss: 0.00805878080428, Cls Loss: 0.371224373579, Lr C: 0.0002, Lr G: 0.0002 20 | Dis Loss: 0.00677244085819, Cls Loss: 0.183727711439, Lr C: 0.0002, Lr G: 0.0002 21 | Dis Loss: 0.00799888651818, Cls Loss: 0.167165458202, Lr C: 0.0002, Lr G: 0.0002 22 | Dis Loss: 0.00736241042614, Cls Loss: 0.178331822157, Lr C: 0.0002, Lr G: 0.0002 23 | Dis Loss: 0.00778476614505, Cls Loss: 0.0710266232491, Lr C: 0.0002, Lr G: 0.0002 24 | Dis Loss: 0.0095592122525, Cls Loss: 0.206102699041, Lr C: 0.0002, Lr G: 0.0002 25 | Dis Loss: 0.00676157185808, Cls Loss: 0.160276055336, Lr C: 0.0002, Lr G: 0.0002 26 | Dis Loss: 0.00798108335584, Cls Loss: 0.234084278345, Lr C: 0.0002, Lr G: 0.0002 27 | Dis Loss: 0.00671329069883, Cls Loss: 0.249807089567, Lr C: 0.0002, Lr G: 0.0002 28 | Dis Loss: 0.00961916334927, Cls Loss: 0.10863583535, Lr C: 0.0002, Lr G: 0.0002 29 | Dis Loss: 0.00781802367419, Cls Loss: 0.116392239928, Lr C: 0.0002, Lr G: 0.0002 30 | Dis Loss: 0.00667194742709, Cls Loss: 0.123922757804, Lr C: 0.0002, Lr G: 0.0002 31 | Dis Loss: 0.00642341328785, Cls Loss: 0.179728776217, Lr C: 0.0002, Lr G: 0.0002 32 | Dis Loss: 0.00476458016783, Cls Loss: 0.134883105755, Lr C: 0.0002, Lr G: 0.0002 33 | Dis Loss: 0.00620850687847, Cls Loss: 0.146140903234, Lr C: 0.0002, Lr G: 0.0002 34 | Dis Loss: 0.00488848984241, Cls Loss: 0.105417221785, Lr C: 0.0002, Lr G: 0.0002 35 | Dis Loss: 0.00591717753559, Cls Loss: 0.0877575054765, Lr C: 0.0002, Lr G: 0.0002 36 | Dis Loss: 0.0071969195269, Cls Loss: 0.0746075063944, Lr C: 0.0002, Lr G: 0.0002 37 | Dis Loss: 0.00616848794743, Cls Loss: 0.193884968758, Lr C: 0.0002, Lr G: 0.0002 38 | Dis Loss: 0.00574999628589, Cls Loss: 0.165285438299, Lr C: 0.0002, Lr G: 0.0002 39 | Dis 
Loss: 0.00534277455881, Cls Loss: 0.122619472444, Lr C: 0.0002, Lr G: 0.0002 40 | Dis Loss: 0.00775020429865, Cls Loss: 0.211391568184, Lr C: 0.0002, Lr G: 0.0002 41 | Dis Loss: 0.00501205166802, Cls Loss: 0.0465583950281, Lr C: 0.0002, Lr G: 0.0002 42 | Dis Loss: 0.00527259055525, Cls Loss: 0.140332207084, Lr C: 0.0002, Lr G: 0.0002 43 | Dis Loss: 0.00588731607422, Cls Loss: 0.146271795034, Lr C: 0.0002, Lr G: 0.0002 44 | Dis Loss: 0.00705251982436, Cls Loss: 0.071382895112, Lr C: 0.0002, Lr G: 0.0002 45 | Dis Loss: 0.00553363654763, Cls Loss: 0.14066195488, Lr C: 0.0002, Lr G: 0.0002 46 | Dis Loss: 0.00580461835489, Cls Loss: 0.102398604155, Lr C: 0.0002, Lr G: 0.0002 47 | Dis Loss: 0.00497902650386, Cls Loss: 0.0818369686604, Lr C: 0.0002, Lr G: 0.0002 48 | Dis Loss: 0.0048161377199, Cls Loss: 0.129988700151, Lr C: 0.0002, Lr G: 0.0002 49 | Dis Loss: 0.00550159625709, Cls Loss: 0.0817996561527, Lr C: 0.0002, Lr G: 0.0002 50 | Dis Loss: 0.00509450910613, Cls Loss: 0.0380837135017, Lr C: 0.0002, Lr G: 0.0002 51 | Dis Loss: 0.00466236053035, Cls Loss: 0.0444295965135, Lr C: 0.0002, Lr G: 0.0002 52 | Dis Loss: 0.00520153809339, Cls Loss: 0.0979211330414, Lr C: 0.0002, Lr G: 0.0002 53 | Dis Loss: 0.00325969606638, Cls Loss: 0.038551196456, Lr C: 0.0002, Lr G: 0.0002 54 | Dis Loss: 0.00417832052335, Cls Loss: 0.284563630819, Lr C: 0.0002, Lr G: 0.0002 55 | Dis Loss: 0.00407133577392, Cls Loss: 0.147980719805, Lr C: 0.0002, Lr G: 0.0002 56 | Dis Loss: 0.005461496301, Cls Loss: 0.0501725636423, Lr C: 0.0002, Lr G: 0.0002 57 | Dis Loss: 0.00335832685232, Cls Loss: 0.143918901682, Lr C: 0.0002, Lr G: 0.0002 58 | Dis Loss: 0.00428040558472, Cls Loss: 0.0378617122769, Lr C: 0.0002, Lr G: 0.0002 59 | Dis Loss: 0.00598321203142, Cls Loss: 0.0431411862373, Lr C: 0.0002, Lr G: 0.0002 60 | Dis Loss: 0.00505055859685, Cls Loss: 0.0629364699125, Lr C: 0.0002, Lr G: 0.0002 61 | Dis Loss: 0.00350512797013, Cls Loss: 0.0760557428002, Lr C: 0.0002, Lr G: 0.0002 62 | Dis Loss: 
0.00371325272135, Cls Loss: 0.0363791808486, Lr C: 0.0002, Lr G: 0.0002 63 | Dis Loss: 0.00418157037348, Cls Loss: 0.12404447794, Lr C: 0.0002, Lr G: 0.0002 64 | Dis Loss: 0.00380450929515, Cls Loss: 0.0583959147334, Lr C: 0.0002, Lr G: 0.0002 65 | Dis Loss: 0.00499662337825, Cls Loss: 0.0246373452246, Lr C: 0.0002, Lr G: 0.0002 66 | Dis Loss: 0.00492464518175, Cls Loss: 0.152146041393, Lr C: 0.0002, Lr G: 0.0002 67 | Dis Loss: 0.00411057798192, Cls Loss: 0.0738733112812, Lr C: 0.0002, Lr G: 0.0002 68 | Dis Loss: 0.00432006036863, Cls Loss: 0.0773489102721, Lr C: 0.0002, Lr G: 0.0002 69 | Dis Loss: 0.005784294568, Cls Loss: 0.0629475861788, Lr C: 0.0002, Lr G: 0.0002 70 | Dis Loss: 0.00267514935695, Cls Loss: 0.10301348567, Lr C: 0.0002, Lr G: 0.0002 71 | Dis Loss: 0.00302348122932, Cls Loss: 0.0238332692534, Lr C: 0.0002, Lr G: 0.0002 72 | Dis Loss: 0.00361108197831, Cls Loss: 0.0807729959488, Lr C: 0.0002, Lr G: 0.0002 73 | Dis Loss: 0.00284019764513, Cls Loss: 0.12972868979, Lr C: 0.0002, Lr G: 0.0002 74 | Dis Loss: 0.00362093793228, Cls Loss: 0.0568101257086, Lr C: 0.0002, Lr G: 0.0002 75 | Dis Loss: 0.00546729750931, Cls Loss: 0.0592128634453, Lr C: 0.0002, Lr G: 0.0002 76 | Dis Loss: 0.00342573830858, Cls Loss: 0.134608700871, Lr C: 0.0002, Lr G: 0.0002 77 | Dis Loss: 0.00497301761061, Cls Loss: 0.0699812322855, Lr C: 0.0002, Lr G: 0.0002 78 | Dis Loss: 0.00336416368373, Cls Loss: 0.0736046805978, Lr C: 0.0002, Lr G: 0.0002 79 | Dis Loss: 0.00429906789213, Cls Loss: 0.0790677592158, Lr C: 0.0002, Lr G: 0.0002 80 | Dis Loss: 0.00298569747247, Cls Loss: 0.107853844762, Lr C: 0.0002, Lr G: 0.0002 81 | Dis Loss: 0.0031755536329, Cls Loss: 0.156428828835, Lr C: 0.0002, Lr G: 0.0002 82 | Dis Loss: 0.00451165623963, Cls Loss: 0.0192727744579, Lr C: 0.0002, Lr G: 0.0002 83 | Dis Loss: 0.00421656994149, Cls Loss: 0.0594312474132, Lr C: 0.0002, Lr G: 0.0002 84 | Dis Loss: 0.0050105410628, Cls Loss: 0.156177505851, Lr C: 0.0002, Lr G: 0.0002 85 | Dis Loss: 
0.00470406422392, Cls Loss: 0.0755702778697, Lr C: 0.0002, Lr G: 0.0002 86 | Dis Loss: 0.00336070358753, Cls Loss: 0.0805172324181, Lr C: 0.0002, Lr G: 0.0002 87 | Dis Loss: 0.004280324094, Cls Loss: 0.0691900476813, Lr C: 0.0002, Lr G: 0.0002 88 | Dis Loss: 0.00417717173696, Cls Loss: 0.130094155669, Lr C: 0.0002, Lr G: 0.0002 89 | Dis Loss: 0.00336504052393, Cls Loss: 0.0347976088524, Lr C: 0.0002, Lr G: 0.0002 90 | Dis Loss: 0.00330058042891, Cls Loss: 0.0670301020145, Lr C: 0.0002, Lr G: 0.0002 91 | Dis Loss: 0.00478171044961, Cls Loss: 0.0160610731691, Lr C: 0.0002, Lr G: 0.0002 92 | Dis Loss: 0.00342379207723, Cls Loss: 0.0681992769241, Lr C: 0.0002, Lr G: 0.0002 93 | Dis Loss: 0.00253219390288, Cls Loss: 0.0226874481887, Lr C: 0.0002, Lr G: 0.0002 94 | Dis Loss: 0.00355479470454, Cls Loss: 0.0486062765121, Lr C: 0.0002, Lr G: 0.0002 95 | Dis Loss: 0.00282041728497, Cls Loss: 0.0907536819577, Lr C: 0.0002, Lr G: 0.0002 96 | Dis Loss: 0.00399726582691, Cls Loss: 0.0737735033035, Lr C: 0.0002, Lr G: 0.0002 97 | Dis Loss: 0.00471308920532, Cls Loss: 0.0530179291964, Lr C: 0.0002, Lr G: 0.0002 98 | Dis Loss: 0.00450990861282, Cls Loss: 0.0423623323441, Lr C: 0.0002, Lr G: 0.0002 99 | Dis Loss: 0.00480007519946, Cls Loss: 0.0181928016245, Lr C: 0.0002, Lr G: 0.0002 100 | Dis Loss: 0.0035850845743, Cls Loss: 0.166106507182, Lr C: 0.0002, Lr G: 0.0002 101 | Dis Loss: 0.00450746295974, Cls Loss: 0.0182992517948, Lr C: 0.0002, Lr G: 0.0002 102 | Dis Loss: 0.00498038670048, Cls Loss: 0.0431930832565, Lr C: 0.0002, Lr G: 0.0002 103 | Dis Loss: 0.00271419575438, Cls Loss: 0.104042746127, Lr C: 0.0002, Lr G: 0.0002 104 | Dis Loss: 0.00427831569687, Cls Loss: 0.0295623596758, Lr C: 0.0002, Lr G: 0.0002 105 | Dis Loss: 0.00321475090459, Cls Loss: 0.0229330658913, Lr C: 0.0002, Lr G: 0.0002 106 | Dis Loss: 0.00369507516734, Cls Loss: 0.0861396640539, Lr C: 0.0002, Lr G: 0.0002 107 | Dis Loss: 0.00369632034563, Cls Loss: 0.0460176132619, Lr C: 0.0002, Lr G: 0.0002 108 | Dis 
Loss: 0.00286093493924, Cls Loss: 0.0172299966216, Lr C: 0.0002, Lr G: 0.0002 109 | Dis Loss: 0.00490171322599, Cls Loss: 0.0776693746448, Lr C: 0.0002, Lr G: 0.0002 110 | Dis Loss: 0.00411299988627, Cls Loss: 0.0282558891922, Lr C: 0.0002, Lr G: 0.0002 111 | Dis Loss: 0.00458105280995, Cls Loss: 0.116378813982, Lr C: 0.0002, Lr G: 0.0002 112 | Dis Loss: 0.00599916884676, Cls Loss: 0.0564890913665, Lr C: 0.0002, Lr G: 0.0002 113 | Dis Loss: 0.00354321091436, Cls Loss: 0.0543925464153, Lr C: 0.0002, Lr G: 0.0002 114 | Dis Loss: 0.00155441684183, Cls Loss: 0.101817071438, Lr C: 0.0002, Lr G: 0.0002 115 | Dis Loss: 0.00411869958043, Cls Loss: 0.0668775439262, Lr C: 0.0002, Lr G: 0.0002 116 | Dis Loss: 0.00649306504056, Cls Loss: 0.0597132444382, Lr C: 0.0002, Lr G: 0.0002 117 | Dis Loss: 0.0048693343997, Cls Loss: 0.0541714504361, Lr C: 0.0002, Lr G: 0.0002 118 | Dis Loss: 0.00391428312287, Cls Loss: 0.0604442432523, Lr C: 0.0002, Lr G: 0.0002 119 | Dis Loss: 0.00543193845078, Cls Loss: 0.0975139811635, Lr C: 0.0002, Lr G: 0.0002 120 | Dis Loss: 0.00529876584187, Cls Loss: 0.074589818716, Lr C: 0.0002, Lr G: 0.0002 121 | Dis Loss: 0.00587701890618, Cls Loss: 0.046294644475, Lr C: 0.0002, Lr G: 0.0002 122 | Dis Loss: 0.00575475534424, Cls Loss: 0.0465457551181, Lr C: 0.0002, Lr G: 0.0002 123 | Dis Loss: 0.00320053333417, Cls Loss: 0.087657853961, Lr C: 0.0002, Lr G: 0.0002 124 | Dis Loss: 0.00315615464933, Cls Loss: 0.0458883568645, Lr C: 0.0002, Lr G: 0.0002 125 | Dis Loss: 0.00321982288733, Cls Loss: 0.0363517850637, Lr C: 0.0002, Lr G: 0.0002 126 | Dis Loss: 0.00380116188899, Cls Loss: 0.0939439088106, Lr C: 0.0002, Lr G: 0.0002 127 | Dis Loss: 0.00293942703865, Cls Loss: 0.0320774838328, Lr C: 0.0002, Lr G: 0.0002 128 | Dis Loss: 0.0045956983231, Cls Loss: 0.0672373920679, Lr C: 0.0002, Lr G: 0.0002 129 | Dis Loss: 0.00609593093395, Cls Loss: 0.0583413168788, Lr C: 0.0002, Lr G: 0.0002 130 | Dis Loss: 0.00370764220133, Cls Loss: 0.121347114444, Lr C: 0.0002, Lr G: 
0.0002 131 | Dis Loss: 0.00603236770257, Cls Loss: 0.105289764702, Lr C: 0.0002, Lr G: 0.0002 132 | Dis Loss: 0.00454461202025, Cls Loss: 0.0347447767854, Lr C: 0.0002, Lr G: 0.0002 133 | Dis Loss: 0.00476987520233, Cls Loss: 0.0154802519828, Lr C: 0.0002, Lr G: 0.0002 134 | Dis Loss: 0.00467550195754, Cls Loss: 0.0853082090616, Lr C: 0.0002, Lr G: 0.0002 135 | Dis Loss: 0.00362639478408, Cls Loss: 0.0213281717151, Lr C: 0.0002, Lr G: 0.0002 136 | Dis Loss: 0.00576998572797, Cls Loss: 0.088567301631, Lr C: 0.0002, Lr G: 0.0002 137 | Dis Loss: 0.00389611953869, Cls Loss: 0.0775500386953, Lr C: 0.0002, Lr G: 0.0002 138 | Dis Loss: 0.00364692509174, Cls Loss: 0.0360069274902, Lr C: 0.0002, Lr G: 0.0002 139 | Dis Loss: 0.00672007678077, Cls Loss: 0.0251312982291, Lr C: 0.0002, Lr G: 0.0002 140 | Dis Loss: 0.00412650499493, Cls Loss: 0.143935501575, Lr C: 0.0002, Lr G: 0.0002 141 | Dis Loss: 0.00546471495181, Cls Loss: 0.0206368491054, Lr C: 0.0002, Lr G: 0.0002 142 | Dis Loss: 0.00304111326113, Cls Loss: 0.126500666142, Lr C: 0.0002, Lr G: 0.0002 143 | Dis Loss: 0.0051850057207, Cls Loss: 0.0578130409122, Lr C: 0.0002, Lr G: 0.0002 144 | Dis Loss: 0.00496583385393, Cls Loss: 0.0406974591315, Lr C: 0.0002, Lr G: 0.0002 145 | Dis Loss: 0.00398027477786, Cls Loss: 0.129645079374, Lr C: 0.0002, Lr G: 0.0002 146 | Dis Loss: 0.00436131609604, Cls Loss: 0.0291619319469, Lr C: 0.0002, Lr G: 0.0002 147 | Dis Loss: 0.00447409320623, Cls Loss: 0.0951287448406, Lr C: 0.0002, Lr G: 0.0002 148 | Dis Loss: 0.00567617500201, Cls Loss: 0.0566911026835, Lr C: 0.0002, Lr G: 0.0002 149 | Dis Loss: 0.00418065488338, Cls Loss: 0.0225957930088, Lr C: 0.0002, Lr G: 0.0002 150 | Dis Loss: 0.00386782549322, Cls Loss: 0.0340443402529, Lr C: 0.0002, Lr G: 0.0002 151 | Dis Loss: 0.00331043382175, Cls Loss: 0.0555837228894, Lr C: 0.0002, Lr G: 0.0002 152 | Dis Loss: 0.00312547618523, Cls Loss: 0.0342776626348, Lr C: 0.0002, Lr G: 0.0002 153 | Dis Loss: 0.00458795810118, Cls Loss: 0.0738851577044, 
Lr C: 0.0002, Lr G: 0.0002 154 | Dis Loss: 0.00356465089135, Cls Loss: 0.070107460022, Lr C: 0.0002, Lr G: 0.0002 155 | Dis Loss: 0.00446801865473, Cls Loss: 0.0710966438055, Lr C: 0.0002, Lr G: 0.0002 156 | Dis Loss: 0.00504029775038, Cls Loss: 0.0303533878177, Lr C: 0.0002, Lr G: 0.0002 157 | Dis Loss: 0.00510559463874, Cls Loss: 0.0256489515305, Lr C: 0.0002, Lr G: 0.0002 158 | Dis Loss: 0.00315452227369, Cls Loss: 0.0374249257147, Lr C: 0.0002, Lr G: 0.0002 159 | Dis Loss: 0.00650208070874, Cls Loss: 0.0266140140593, Lr C: 0.0002, Lr G: 0.0002 160 | Dis Loss: 0.00455721607432, Cls Loss: 0.0496471635997, Lr C: 0.0002, Lr G: 0.0002 161 | Dis Loss: 0.00424830336124, Cls Loss: 0.0775512009859, Lr C: 0.0002, Lr G: 0.0002 162 | Dis Loss: 0.00599554041401, Cls Loss: 0.0285082086921, Lr C: 0.0002, Lr G: 0.0002 163 | Dis Loss: 0.0052908747457, Cls Loss: 0.0163153167814, Lr C: 0.0002, Lr G: 0.0002 164 | Dis Loss: 0.00422675767913, Cls Loss: 0.0244442783296, Lr C: 0.0002, Lr G: 0.0002 165 | Dis Loss: 0.00468214182183, Cls Loss: 0.058757096529, Lr C: 0.0002, Lr G: 0.0002 166 | Dis Loss: 0.00464393757284, Cls Loss: 0.0295814089477, Lr C: 0.0002, Lr G: 0.0002 167 | Dis Loss: 0.0040175607428, Cls Loss: 0.0164988674223, Lr C: 0.0002, Lr G: 0.0002 168 | Dis Loss: 0.00507428310812, Cls Loss: 0.0277650821954, Lr C: 0.0002, Lr G: 0.0002 169 | Dis Loss: 0.00443507311866, Cls Loss: 0.0347390472889, Lr C: 0.0002, Lr G: 0.0002 170 | Dis Loss: 0.00319409067743, Cls Loss: 0.0735673382878, Lr C: 0.0002, Lr G: 0.0002 171 | Dis Loss: 0.00557447317988, Cls Loss: 0.129221603274, Lr C: 0.0002, Lr G: 0.0002 172 | Dis Loss: 0.00489986827597, Cls Loss: 0.0250022243708, Lr C: 0.0002, Lr G: 0.0002 173 | Dis Loss: 0.00367637211457, Cls Loss: 0.037488501519, Lr C: 0.0002, Lr G: 0.0002 174 | Dis Loss: 0.00279768346809, Cls Loss: 0.0477290600538, Lr C: 0.0002, Lr G: 0.0002 175 | Dis Loss: 0.00308806146495, Cls Loss: 0.0397023484111, Lr C: 0.0002, Lr G: 0.0002 176 | Dis Loss: 0.00338975223713, Cls 
Loss: 0.0628024786711, Lr C: 0.0002, Lr G: 0.0002 177 | Dis Loss: 0.00590053154156, Cls Loss: 0.0783554241061, Lr C: 0.0002, Lr G: 0.0002 178 | Dis Loss: 0.00350930891, Cls Loss: 0.102638430893, Lr C: 0.0002, Lr G: 0.0002 179 | Dis Loss: 0.00412848358974, Cls Loss: 0.0205194931477, Lr C: 0.0002, Lr G: 0.0002 180 | Dis Loss: 0.00548736704513, Cls Loss: 0.0289202295244, Lr C: 0.0002, Lr G: 0.0002 181 | Dis Loss: 0.00462893256918, Cls Loss: 0.0628123059869, Lr C: 0.0002, Lr G: 0.0002 182 | Dis Loss: 0.00720173725858, Cls Loss: 0.0199466813356, Lr C: 0.0002, Lr G: 0.0002 183 | Dis Loss: 0.00340282986872, Cls Loss: 0.0598597824574, Lr C: 0.0002, Lr G: 0.0002 184 | Dis Loss: 0.00293013406917, Cls Loss: 0.0531075187027, Lr C: 0.0002, Lr G: 0.0002 185 | Dis Loss: 0.00399965513498, Cls Loss: 0.029729001224, Lr C: 0.0002, Lr G: 0.0002 186 | Dis Loss: 0.00509276194498, Cls Loss: 0.112514361739, Lr C: 0.0002, Lr G: 0.0002 187 | Dis Loss: 0.00507830735296, Cls Loss: 0.0705089122057, Lr C: 0.0002, Lr G: 0.0002 188 | Dis Loss: 0.00379161047749, Cls Loss: 0.0800604820251, Lr C: 0.0002, Lr G: 0.0002 189 | Dis Loss: 0.00571160251275, Cls Loss: 0.0446820110083, Lr C: 0.0002, Lr G: 0.0002 190 | Dis Loss: 0.00470952829346, Cls Loss: 0.0398795008659, Lr C: 0.0002, Lr G: 0.0002 191 | Dis Loss: 0.00690906727687, Cls Loss: 0.0992165058851, Lr C: 0.0002, Lr G: 0.0002 192 | Dis Loss: 0.00446418579668, Cls Loss: 0.0821895599365, Lr C: 0.0002, Lr G: 0.0002 193 | Dis Loss: 0.0044578476809, Cls Loss: 0.0322776995599, Lr C: 0.0002, Lr G: 0.0002 194 | Dis Loss: 0.00570461526513, Cls Loss: 0.0650118067861, Lr C: 0.0002, Lr G: 0.0002 195 | Dis Loss: 0.00549853919074, Cls Loss: 0.0504232347012, Lr C: 0.0002, Lr G: 0.0002 196 | Dis Loss: 0.00388650037348, Cls Loss: 0.129654437304, Lr C: 0.0002, Lr G: 0.0002 197 | Dis Loss: 0.00780218094587, Cls Loss: 0.0459640324116, Lr C: 0.0002, Lr G: 0.0002 198 | Dis Loss: 0.00588699011132, Cls Loss: 0.0635501295328, Lr C: 0.0002, Lr G: 0.0002 199 | Dis Loss: 
0.00327239348553, Cls Loss: 0.0425560176373, Lr C: 0.0002, Lr G: 0.0002 200 | Dis Loss: 0.00459404895082, Cls Loss: 0.0682985782623, Lr C: 0.0002, Lr G: 0.0002 201 | -------------------------------------------------------------------------------- /digit_signal_classification/solver.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from model.build_gen import * 7 | from datasets.dataset_read import dataset_read 8 | from utils.avgmeter import AverageMeter 9 | import time 10 | from tensorboardX import SummaryWriter 11 | 12 | 13 | # Training settings 14 | class Solver(object): 15 | def __init__( 16 | self, 17 | args, 18 | batch_size=64, 19 | source='svhn', 20 | target='mnist', 21 | learning_rate=0.0002, 22 | interval=100, 23 | optimizer='adam', 24 | num_k=4, 25 | all_use=False, 26 | checkpoint_dir=None, 27 | save_epoch=10, 28 | num_classifiers_train=2, 29 | num_classifiers_test=20, 30 | init='kaiming_u', 31 | use_init=False, 32 | dis_metric='L1' 33 | ): 34 | 35 | self.batch_size = batch_size 36 | self.source = source 37 | self.target = target 38 | self.num_k = num_k 39 | self.checkpoint_dir = checkpoint_dir 40 | self.save_epoch = save_epoch 41 | self.use_abs_diff = args.use_abs_diff 42 | self.all_use = all_use 43 | self.num_classifiers_train = num_classifiers_train 44 | self.num_classifiers_test = num_classifiers_test 45 | self.init = init 46 | self.dis_metric = dis_metric 47 | self.use_init = use_init 48 | 49 | if self.source == 'svhn': 50 | self.scale = True 51 | else: 52 | self.scale = False 53 | 54 | print('dataset loading') 55 | self.datasets, self.dataset_test = dataset_read( 56 | source, target, 57 | self.batch_size, 58 | scale=self.scale, 59 | all_use=self.all_use 60 | ) 61 | print('load finished!') 62 | 63 | self.G = Generator(source=source, target=target) 64 | self.C = 
Classifier( 65 | source=source, target=target, 66 | num_classifiers_train=self.num_classifiers_train, 67 | num_classifiers_test=self.num_classifiers_test, 68 | init=self.init, 69 | use_init=self.use_init 70 | ) 71 | 72 | if args.eval_only: 73 | self.G.torch.load('{}/{}_to_{}_model_epoch{}_G.pt'.format( 74 | self.checkpoint_dir, 75 | self.source, 76 | self.target, 77 | args.resume_epoch) 78 | ) 79 | 80 | self.C.torch.load('{}/{}_to_{}_model_epoch{}_C.pt'.format( 81 | self.checkpoint_dir, 82 | self.source, 83 | self.target, 84 | args.resume_epoch) 85 | ) 86 | 87 | self.G.cuda() 88 | self.C.cuda() 89 | self.interval = interval 90 | self.writer = SummaryWriter() 91 | 92 | self.opt_c, self.opt_g = self.set_optimizer( 93 | which_opt=optimizer, 94 | lr=learning_rate 95 | ) 96 | self.lr = learning_rate 97 | 98 | # Learning rate scheduler 99 | self.scheduler_g = optim.lr_scheduler.CosineAnnealingLR( 100 | self.opt_g, 101 | float(args.max_epoch) 102 | ) 103 | self.scheduler_c = optim.lr_scheduler.CosineAnnealingLR( 104 | self.opt_c, 105 | float(args.max_epoch) 106 | ) 107 | 108 | def set_optimizer(self, which_opt='momentum', lr=0.001, momentum=0.9): 109 | if which_opt == 'momentum': 110 | self.opt_g = optim.SGD( 111 | self.G.parameters(), 112 | lr=lr, 113 | weight_decay=0.0005, 114 | momentum=momentum 115 | ) 116 | 117 | self.opt_c = optim.SGD( 118 | self.C.parameters(), 119 | lr=lr, 120 | weight_decay=0.0005, 121 | momentum=momentum 122 | ) 123 | 124 | if which_opt == 'adam': 125 | self.opt_g = optim.Adam( 126 | self.G.parameters(), 127 | lr=lr, 128 | weight_decay=0.0005, 129 | amsgrad=False 130 | ) 131 | 132 | self.opt_c = optim.Adam( 133 | self.C.parameters(), 134 | lr=lr, 135 | weight_decay=0.0005, 136 | amsgrad=False 137 | ) 138 | 139 | return self.opt_c, self.opt_g 140 | 141 | def reset_grad(self): 142 | self.opt_g.zero_grad() 143 | self.opt_c.zero_grad() 144 | 145 | @staticmethod 146 | def entropy(x): 147 | b = F.softmax(x) * F.log_softmax(x) 148 | b = -1.0 * b.sum() 
149 | return b 150 | 151 | @staticmethod 152 | def discrepancy(out1, out2): 153 | l1loss = torch.nn.L1Loss() 154 | return l1loss(F.softmax(out1, dim=1), F.softmax(out2, dim=1)) 155 | 156 | @staticmethod 157 | def discrepancy_mse(out1, out2): 158 | mseloss = torch.nn.MSELoss() 159 | return mseloss(F.softmax(out1, dim=1), F.softmax(out2, dim=1)) 160 | 161 | @staticmethod 162 | def discrepancy_cos(out1, out2): 163 | cosloss = torch.nn.CosineSimilarity() 164 | return 1 - cosloss(F.softmax(out1, dim=1), F.softmax(out2, dim=1)) 165 | 166 | @staticmethod 167 | def discrepancy_slice_wasserstein(p1, p2): 168 | p1 = torch.sigmoid(p1) 169 | p2 = torch.sigmoid(p2) 170 | s = p1.shape 171 | if s[1] > 1: 172 | proj = torch.randn(s[1], 128).cuda() 173 | proj *= torch.rsqrt(torch.sum(torch.mul(proj, proj), 0, keepdim=True)) 174 | p1 = torch.matmul(p1, proj) 175 | p2 = torch.matmul(p2, proj) 176 | p1 = torch.topk(p1, s[0], dim=0)[0] 177 | p2 = torch.topk(p2, s[0], dim=0)[0] 178 | dist = p1 - p2 179 | wdist = torch.mean(torch.mul(dist, dist)) 180 | 181 | return wdist 182 | 183 | def train(self, epoch, record_file=None, loss_process='mean', func='L1'): 184 | criterion = nn.CrossEntropyLoss().cuda() 185 | 186 | # Various measurements for the discrepancy 187 | dis_dict = { 188 | 'L1': self.discrepancy, 189 | 'MSE': self.discrepancy_mse, 190 | 'Cosine': self.discrepancy_cos, 191 | 'SWD': self.discrepancy_slice_wasserstein 192 | } 193 | 194 | self.G.train() 195 | self.C.train() 196 | torch.cuda.manual_seed(1) 197 | batch_time = AverageMeter() 198 | data_time = AverageMeter() 199 | 200 | batch_num = min( 201 | len(self.datasets.data_loader_A), 202 | len(self.datasets.data_loader_B) 203 | ) 204 | 205 | end = time.time() 206 | 207 | for batch_idx, data in enumerate(self.datasets): 208 | data_time.update(time.time() - end) 209 | img_t = data['T'] 210 | img_s = data['S'] 211 | label_s = data['S_label'] 212 | 213 | if img_s.size()[0] < self.batch_size or img_t.size()[0] < self.batch_size: 214 | 
break 215 | 216 | img_s = img_s.cuda() 217 | img_t = img_t.cuda() 218 | 219 | imgs_st = torch.cat((img_s, img_t), dim=0) 220 | 221 | label_s = label_s.long().cuda() 222 | 223 | # Step1: update the whole network using source data 224 | self.reset_grad() 225 | feat_s = self.G(img_s) 226 | outputs_s = self.C(feat_s) 227 | 228 | loss_s = [] 229 | for index_tr in range(self.num_classifiers_train): 230 | loss_s.append(criterion(outputs_s[index_tr], label_s)) 231 | 232 | if loss_process == 'mean': 233 | loss_s = torch.stack(loss_s).mean() 234 | else: 235 | loss_s = torch.stack(loss_s).sum() 236 | 237 | loss_s.backward() 238 | self.opt_g.step() 239 | self.opt_c.step() 240 | 241 | # Step2: update the classifiers using target data 242 | self.reset_grad() 243 | feat_st = self.G(imgs_st) 244 | outputs_st = self.C(feat_st) 245 | outputs_s = [ 246 | outputs_st[0][:self.batch_size], 247 | outputs_st[1][:self.batch_size] 248 | ] 249 | outputs_t = [ 250 | outputs_st[0][self.batch_size:], 251 | outputs_st[1][self.batch_size:] 252 | ] 253 | 254 | loss_s = [] 255 | loss_dis = [] 256 | for index_tr in range(self.num_classifiers_train): 257 | loss_s.append(criterion(outputs_s[index_tr], label_s)) 258 | 259 | if loss_process == 'mean': 260 | loss_s = torch.stack(loss_s).mean() 261 | else: 262 | loss_s = torch.stack(loss_s).sum() 263 | 264 | for index_tr in range(self.num_classifiers_train): 265 | for index_tre in range(index_tr + 1, self.num_classifiers_train): 266 | loss_dis.append(dis_dict[func](outputs_t[index_tr], outputs_t[index_tre])) 267 | 268 | if loss_process == 'mean': 269 | loss_dis = torch.stack(loss_dis).mean() 270 | else: 271 | loss_dis = torch.stack(loss_dis).sum() 272 | 273 | loss = loss_s - loss_dis 274 | 275 | loss.backward() 276 | self.opt_c.step() 277 | 278 | # Step3: update the generator using target data 279 | self.reset_grad() 280 | 281 | for index in range(self.num_k+1): 282 | loss_dis = [] 283 | feat_t = self.G(img_t) 284 | outputs_t = self.C(feat_t) 285 | 286 | 
for index_tr in range(self.num_classifiers_train): 287 | for index_tre in range(index_tr + 1, self.num_classifiers_train): 288 | loss_dis.append(dis_dict[func](outputs_t[index_tr], outputs_t[index_tre])) 289 | 290 | if loss_process == 'mean': 291 | loss_dis = torch.stack(loss_dis).mean() 292 | else: 293 | loss_dis = torch.stack(loss_dis).sum() 294 | 295 | loss_dis.backward() 296 | self.opt_g.step() 297 | self.reset_grad() 298 | 299 | batch_time.update(time.time() - end) 300 | 301 | if batch_idx % self.interval == 0: 302 | print('Train Epoch: {} [{}/{}]\t ' 303 | 'Loss: {:.6f}\t ' 304 | 'Discrepancy: {:.6f} \t ' 305 | 'Lr C: {:.6f}\t' 306 | 'Lr G: {:.6f}\t' 307 | 'Time: {:.3f}({:.3f})\t' 308 | .format(epoch + 1, batch_idx, 309 | batch_num, loss_s.data, 310 | loss_dis.data, 311 | self.opt_c.param_groups[0]['lr'], 312 | self.opt_g.param_groups[0]['lr'], 313 | batch_time.val, batch_time.avg)) 314 | 315 | if record_file: 316 | record = open(record_file, 'a') 317 | record.write('Dis Loss: {}, Cls Loss: {}, Lr C: {}, Lr G: {} \n' 318 | .format(loss_dis.data.cpu().numpy(), 319 | loss_s.data.cpu().numpy(), 320 | self.opt_c.param_groups[0]['lr'], 321 | self.opt_g.param_groups[0]['lr'])) 322 | record.close() 323 | 324 | def test(self, epoch, record_file=None, save_model=False): 325 | criterion = nn.CrossEntropyLoss().cuda() 326 | self.G.eval() 327 | self.C.eval() 328 | test_loss = 0 329 | correct = 0 330 | size = 0 331 | with torch.no_grad(): 332 | for batch_idx, data in enumerate(self.dataset_test): 333 | img = data['T'] 334 | label = data['T_label'] 335 | img, label = img.cuda(), label.long().cuda() 336 | feat = self.G(img) 337 | outputs = self.C(feat) 338 | test_loss += criterion(outputs[0], label).data 339 | k = label.data.size()[0] 340 | output_ensemble = torch.zeros(outputs[0].shape).cuda() 341 | 342 | for index in range(len(outputs)): 343 | output_ensemble += outputs[index] 344 | 345 | pred_ensemble = output_ensemble.data.max(1)[1] 346 | correct += 
pred_ensemble.eq(label.data).cpu().sum() 347 | size += k 348 | test_loss = test_loss / size 349 | 350 | print('\nTest set: Average loss: {:.4f}\t Ensemble Accuracy: {}/{} ({:.2f}%)' 351 | .format(test_loss, correct, size, 100. * float(correct) / size)) 352 | 353 | if save_model and epoch % self.save_epoch == 0: 354 | torch.save(self.G, '{}/{}_to_{}_model_epoch{}_G.pt' 355 | .format(self.checkpoint_dir, self.source, self.target, epoch)) 356 | torch.save(self.C, '{}/{}_to_{}_model_epoch{}_C.pt' 357 | .format(self.checkpoint_dir, self.source, self.target, epoch)) 358 | 359 | if record_file: 360 | record = open(record_file, 'a') 361 | print('Recording {}'.format(record_file)) 362 | record.write('Accuracy: {:.2f}'.format(100. * float(correct) / size)) 363 | record.write('\n') 364 | record.close() 365 | 366 | self.writer.add_scalar('Test/loss', test_loss, epoch) 367 | self.writer.add_scalar('Test/ACC_en', 100. * float(correct) / size, epoch) 368 | -------------------------------------------------------------------------------- /digit_signal_classification/utils/avgmeter.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | __all__ = ['AverageMeter'] 5 | 6 | 7 | class AverageMeter(object): 8 | """Computes and stores the average and current value. 
9 | Examples:: 10 | >>> # Initialize a meter to record loss 11 | >>> losses = AverageMeter() 12 | >>> # Update meter after every minibatch update 13 | >>> losses.update(loss_value, batch_size) 14 | """ 15 | def __init__(self): 16 | self.reset() 17 | 18 | def reset(self): 19 | self.val = 0 20 | self.avg = 0 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.val = val 26 | self.sum += val * n 27 | self.count += n 28 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /digit_signal_classification/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def weights_init(m): 5 | classname = m.__class__.__name__ 6 | if classname.find('Conv') != -1: 7 | m.weight.data.normal_(0.0, 0.01) 8 | m.bias.data.normal_(0.0, 0.01) 9 | elif classname.find('BatchNorm') != -1: 10 | m.weight.data.normal_(1.0, 0.01) 11 | m.bias.data.fill_(0) 12 | 13 | def dense_to_one_hot(labels_dense): 14 | """Convert class labels from scalars to one-hot vectors.""" 15 | labels_one_hot = np.zeros((len(labels_dense),)) 16 | labels_dense = list(labels_dense) 17 | for i, t in enumerate(labels_dense): 18 | if t == 10: 19 | t = 0 20 | labels_one_hot[i] = t 21 | else: 22 | labels_one_hot[i] = t 23 | return labels_one_hot 24 | --------------------------------------------------------------------------------