├── .idea
│   ├── .gitignore
│   ├── STAR_Stochastic_Classifiers_for_UDA.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
└── digit_signal_classification
    ├── README.md
    ├── datasets
    │   ├── __init__.pyc
    │   ├── base_data_loader.py
    │   ├── dataset_read.py
    │   ├── datasets.py
    │   ├── gtsrb.py
    │   ├── mnist.py
    │   ├── svhn.py
    │   ├── synth_traffic.py
    │   ├── unaligned_data_loader.py
    │   └── usps.py
    ├── doc
    │   ├── architecture.jpg
    │   ├── architecture_small.jpg
    │   └── architecture_small.png
    ├── main.py
    ├── model
    │   ├── __init__.pyc
    │   ├── __pycache__
    │   │   ├── __init__.cpython-36.pyc
    │   │   ├── __init__.cpython-37.pyc
    │   │   ├── build_gen.cpython-36.pyc
    │   │   └── build_gen.cpython-37.pyc
    │   ├── build_gen.py
    │   ├── build_gen.pyc
    │   ├── svhn2mnist.py
    │   ├── svhn2mnist.pyc
    │   ├── syn2gtrsb.py
    │   ├── syn2gtrsb.pyc
    │   ├── usps.py
    │   └── usps.pyc
    ├── record
    │   └── mnist_usps
    │       ├── test_2020-07-22-17-50-42.out
    │       └── train_2020-07-22-17-50-42.out
    ├── solver.py
    └── utils
        ├── avgmeter.py
        └── utils.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Stochastic Classifiers for Unsupervised Domain Adaptation (CVPR 2020)
2 |
3 | ## [[Paper Link]](https://openaccess.thecvf.com/content_CVPR_2020/papers/Lu_Stochastic_Classifiers_for_Unsupervised_Domain_Adaptation_CVPR_2020_paper.pdf)
4 |
5 | ## Short introduction
6 |
7 | This is the implementation of STAR (STochastic clAssifieRs). The main idea is to build a distribution over the weights of the classifier, so that an infinite number of classifiers can be sampled without extra parameters.
8 |
9 | ## Architecture
10 |
11 | ![STAR architecture](digit_signal_classification/doc/architecture_small.png)
12 | ## Citation
13 |
14 | If you find this work helpful, please cite it:
15 |
16 | ```
17 | @InProceedings{Lu_2020_CVPR,
18 | author = {Lu, Zhihe and Yang, Yongxin and Zhu, Xiatian and Liu, Cong and Song, Yi-Zhe and Xiang, Tao},
19 | title = {Stochastic Classifiers for Unsupervised Domain Adaptation},
20 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
21 | month = {June},
22 | year = {2020}
23 | }
24 | ```
25 |
26 | ## Further discussion
27 |
28 | If you have any questions, please contact zhihe.lu@surrey.ac.uk or simply open an issue.
--------------------------------------------------------------------------------
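The core mechanism described above is small enough to sketch. Below is a minimal, simplified rendering of the `Predictor` heads in `digit_signal_classification/model/` (names and sizes here are illustrative, not the repository's exact modules): the classifier weight matrix is a learnable Normal distribution, and every forward pass draws fresh weights via the reparameterization trick, so any number of classifiers can be sampled without adding parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StochasticClassifier(nn.Module):
    """Sketch of a STAR-style head: the weight matrix is a distribution."""

    def __init__(self, in_features, num_classes):
        super(StochasticClassifier, self).__init__()
        self.mu = nn.Parameter(torch.randn(num_classes, in_features))
        self.sigma = nn.Parameter(torch.zeros(num_classes, in_features))
        self.bias = nn.Parameter(torch.zeros(num_classes))

    def forward(self, x, num_samples=2):
        # softplus keeps the scale positive; the -2 shift keeps it small at init
        dist = torch.distributions.Normal(self.mu, F.softplus(self.sigma - 2))
        # rsample() uses the reparameterization trick, so gradients reach mu/sigma
        return [F.linear(x, dist.rsample(), self.bias) for _ in range(num_samples)]
```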
/digit_signal_classification/README.md:
--------------------------------------------------------------------------------
1 | ## Digit and Sign classification
2 |
3 | ### Getting Started
4 | #### Installation
5 | - Install PyTorch (works with version 0.2.0_3) and dependencies from http://pytorch.org.
6 | - Install Python 2.7.
7 | - Install tensorboardX.
8 |
9 | ### Download Dataset
10 | Download the MNIST dataset [here](https://drive.google.com/file/d/1cZ4vSIS-IKoyKWPfcgxFMugw0LtMiqPf/view?usp=sharing). The file contains the resized image dataset.
11 | Place it in the directory ./data.
12 | All other datasets should be placed in the same directory.
13 |
14 | ### Train
15 | Here is an example for the MNIST-to-USPS task:
16 |
17 | ```
18 | python main.py \
19 | --source mnist \
20 | --target usps \
21 | --num_classifiers_train 2 \
22 | --lr 0.0002 \
23 | --max_epoch 300 \
24 | --all_use yes \
25 | --optimizer adam
26 | ```
27 |
--------------------------------------------------------------------------------
/digit_signal_classification/datasets/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/datasets/__init__.pyc
--------------------------------------------------------------------------------
/digit_signal_classification/datasets/base_data_loader.py:
--------------------------------------------------------------------------------
1 | class BaseDataLoader():
2 |     def __init__(self):
3 |         pass
4 |
5 |     def initialize(self, batch_size):
6 |         self.batch_size = batch_size
7 |         self.serial_batches = 0
8 |         self.nThreads = 2
9 |         self.max_dataset_size = float("inf")
10 |
11 |     def load_data(self):
12 |         return None
13 |
--------------------------------------------------------------------------------
/digit_signal_classification/datasets/dataset_read.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../loader')
3 | from unaligned_data_loader import UnalignedDataLoader
4 | from svhn import load_svhn
5 | from mnist import load_mnist
6 | from usps import load_usps
7 | from gtsrb import load_gtsrb
8 | from synth_traffic import load_syntraffic
9 |
10 |
11 | def return_dataset(data, scale=False, usps=False, all_use='no'):
12 | if data == 'svhn':
13 | train_image, train_label, test_image, test_label = load_svhn()
14 | print('The size of {} training dataset: {} and testing dataset: {}'.format(data, train_image.shape,
15 | test_image.shape))
16 |
17 | if data == 'mnist':
18 | train_image, train_label, test_image, test_label = load_mnist(scale=scale, usps=usps, all_use=all_use)
19 | print('The size of {} training dataset: {} and testing dataset: {}'.format(data, train_image.shape,
20 | test_image.shape))
21 |
22 | if data == 'usps':
23 | train_image, train_label, test_image, test_label = load_usps(all_use=all_use)
24 | print('The size of {} training dataset: {} and testing dataset: {}'.format(data, train_image.shape,
25 | test_image.shape))
26 |
27 | if data == 'synth':
28 | train_image, train_label, test_image, test_label = load_syntraffic()
29 | print('The size of {} training dataset: {} and testing dataset: {}'.format(data, train_image.shape,
30 | test_image.shape))
31 |
32 | if data == 'gtsrb':
33 | train_image, train_label, test_image, test_label = load_gtsrb()
34 | print('The size of {} training dataset: {} and testing dataset: {}'.format(data, train_image.shape,
35 | test_image.shape))
36 |
37 | return train_image, train_label, test_image, test_label
38 |
39 |
40 | def dataset_read(source, target, batch_size, scale=False, all_use='no'):
41 | S = {}
42 | S_test = {}
43 | T = {}
44 | T_test = {}
45 | usps = False
46 | if source == 'usps' or target == 'usps':
47 | usps = True
48 |
49 | train_source, s_label_train, test_source, s_label_test = return_dataset(
50 | source,
51 | scale=scale,
52 | usps=usps,
53 | all_use=all_use
54 | )
55 |
56 | train_target, t_label_train, test_target, t_label_test = return_dataset(
57 | target,
58 | scale=scale,
59 | usps=usps,
60 | all_use=all_use
61 | )
62 |
63 | S['imgs'] = train_source
64 | S['labels'] = s_label_train
65 | T['imgs'] = train_target
66 | T['labels'] = t_label_train
67 |
68 |     # Use target test samples for both; the source test set is not used
69 | S_test['imgs'] = test_target
70 | S_test['labels'] = t_label_test
71 | T_test['imgs'] = test_target
72 | T_test['labels'] = t_label_test
73 |
74 | scale = 40 if source == 'synth' else 28 if source == 'usps' or target == 'usps' else 32
75 |
76 | train_loader = UnalignedDataLoader()
77 | train_loader.initialize(S, T, batch_size, batch_size, scale=scale, shuffle_=True)
78 | dataset = train_loader.load_data()
79 |
80 | test_loader = UnalignedDataLoader()
81 | test_loader.initialize(S_test, T_test, batch_size, batch_size, scale=scale, shuffle_=True)
82 | dataset_test = test_loader.load_data()
83 |
84 | return dataset, dataset_test
85 |
--------------------------------------------------------------------------------
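For orientation, a hedged usage sketch of `dataset_read` (it assumes the `/data/digit_sign` files expected by the loaders above are in place; the real call site is `solver.py`, which is not included in this listing):

```python
from dataset_read import dataset_read

# Paired source/target loaders; each batch is the dict yielded by PairedData.
dataset, dataset_test = dataset_read('mnist', 'usps', batch_size=128, all_use='yes')
for batch in dataset:
    imgs_s, labels_s = batch['S'], batch['S_label']  # labeled source batch
    imgs_t = batch['T']  # target batch; its labels are unused during training
    break
```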
/digit_signal_classification/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch.utils.data as data
3 | from PIL import Image
4 | import numpy as np
5 |
6 |
7 | class Dataset(data.Dataset):
8 | """Args:
9 | transform (callable, optional): A function/transform that takes in an PIL image
10 | and returns a transformed version. E.g, ``transforms.RandomCrop``
11 | target_transform (callable, optional): A function/transform that takes in the
12 | target and transforms it.
13 | download (bool, optional): If true, downloads the dataset from the internet and
14 | puts it in root directory. If dataset is already downloaded, it is not
15 | downloaded again.
16 | """
17 | def __init__(self, data, label, transform=None,target_transform=None):
18 |
19 | self.transform = transform
20 | self.target_transform = target_transform
21 | self.data = data
22 | self.labels = label
23 |
24 | def __getitem__(self, index):
25 | """
26 | Args:
27 | index (int): Index
28 | Returns:
29 | tuple: (image, target) where target is index of the target class.
30 | """
31 |
32 | img, target = self.data[index], self.labels[index]
33 |         if img.shape[0] != 1:
34 |             # multi-channel CHW array: convert to HWC for PIL
35 |             img = Image.fromarray(np.uint8(np.asarray(img.transpose((1, 2, 0)))))
36 |         else:  # single-channel: replicate the channel three times
37 |             im = np.uint8(np.asarray(img))
38 |             im = np.vstack([im, im, im]).transpose((1, 2, 0))
39 |             img = Image.fromarray(im)
40 |
41 | if self.target_transform is not None:
42 | target = self.target_transform(target)
43 | if self.transform is not None:
44 | img = self.transform(img)
45 |
46 | return img, target
47 |
48 | def __len__(self):
49 | return len(self.data)
50 |
--------------------------------------------------------------------------------
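A small self-contained example of using `Dataset` directly; the CHW layout and the 0.5-mean/0.5-std normalization mirror what `unaligned_data_loader.py` below passes in, while the random arrays are purely illustrative:

```python
import numpy as np
import torchvision.transforms as transforms
from datasets import Dataset

transform = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

imgs = np.random.randint(0, 255, size=(10, 3, 32, 32))  # CHW arrays, like the loaders provide
labels = np.random.randint(0, 10, size=(10,))
ds = Dataset(imgs, labels, transform=transform)
img0, target0 = ds[0]  # a normalized 3x28x28 tensor and its integer label
```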
/digit_signal_classification/datasets/gtsrb.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cPickle as pkl
3 |
4 |
5 | def load_gtsrb():
6 |     data_target = pkl.load(open('/data/digit_sign/data_gtsrb', 'rb'))
7 | target_train = np.random.permutation(len(data_target['image']))
8 | data_t_im = data_target['image'][target_train[:31367], :, :, :]
9 | data_t_im_test = data_target['image'][target_train[31367:], :, :, :]
10 | data_t_label = data_target['label'][target_train[:31367]] + 1
11 | data_t_label_test = data_target['label'][target_train[31367:]] + 1
12 | data_t_im = data_t_im.transpose(0, 3, 1, 2).astype(np.float32)
13 | data_t_im_test = data_t_im_test.transpose(0, 3, 1, 2).astype(np.float32)
14 | return data_t_im, data_t_label, data_t_im_test, data_t_label_test
15 |
--------------------------------------------------------------------------------
/digit_signal_classification/datasets/mnist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.io import loadmat
3 |
4 |
5 | def load_mnist(scale=True, usps=False, all_use='no'):
6 | mnist_data = loadmat('/data/digit_sign/mnist_data.mat')
7 | if scale:
8 | mnist_train = np.reshape(mnist_data['train_32'], (55000, 32, 32, 1))
9 | mnist_test = np.reshape(mnist_data['test_32'], (10000, 32, 32, 1))
10 | mnist_train = np.concatenate([mnist_train, mnist_train, mnist_train], 3)
11 | mnist_test = np.concatenate([mnist_test, mnist_test, mnist_test], 3)
12 | mnist_train = mnist_train.transpose(0, 3, 1, 2).astype(np.float32)
13 | mnist_test = mnist_test.transpose(0, 3, 1, 2).astype(np.float32)
14 | mnist_labels_train = mnist_data['label_train']
15 | mnist_labels_test = mnist_data['label_test']
16 | else:
17 | mnist_train = mnist_data['train_28']
18 | mnist_test = mnist_data['test_28']
19 | mnist_labels_train = mnist_data['label_train']
20 | mnist_labels_test = mnist_data['label_test']
21 | mnist_train = mnist_train.astype(np.float32)
22 | mnist_test = mnist_test.astype(np.float32)
23 | mnist_train = mnist_train.transpose((0, 3, 1, 2))
24 | mnist_test = mnist_test.transpose((0, 3, 1, 2))
25 | train_label = np.argmax(mnist_labels_train, axis=1)
26 | inds = np.random.permutation(mnist_train.shape[0])
27 | mnist_train = mnist_train[inds]
28 | train_label = train_label[inds]
29 | test_label = np.argmax(mnist_labels_test, axis=1)
30 | if usps and all_use != 'yes':
31 | mnist_train = mnist_train[:2000]
32 | train_label = train_label[:2000]
33 |
34 | return mnist_train, train_label, mnist_test, test_label
35 |
--------------------------------------------------------------------------------
/digit_signal_classification/datasets/svhn.py:
--------------------------------------------------------------------------------
1 | from scipy.io import loadmat
2 | import numpy as np
3 | import sys
4 | sys.path.append('../utils/')
5 | from utils.utils import dense_to_one_hot
6 |
7 |
8 | def load_svhn():
9 | svhn_train = loadmat('/data/digit_sign/train_32x32.mat')
10 | svhn_test = loadmat('/data/digit_sign/test_32x32.mat')
11 | svhn_train_im = svhn_train['X']
12 | svhn_train_im = svhn_train_im.transpose(3, 2, 0, 1).astype(np.float32)
13 | svhn_label = dense_to_one_hot(svhn_train['y'])
14 | svhn_test_im = svhn_test['X']
15 | svhn_test_im = svhn_test_im.transpose(3, 2, 0, 1).astype(np.float32)
16 | svhn_label_test = dense_to_one_hot(svhn_test['y'])
17 |
18 | return svhn_train_im, svhn_label, svhn_test_im, svhn_label_test
19 |
--------------------------------------------------------------------------------
/digit_signal_classification/datasets/synth_traffic.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cPickle as pkl
3 |
4 |
5 | def load_syntraffic():
6 |     data_source = pkl.load(open('/data/digit_sign/data_synthetic', 'rb'))
7 | source_train = np.random.permutation(len(data_source['image']))
8 | data_s_im = data_source['image'][source_train[:len(data_source['image'])], :, :, :]
9 | data_s_im_test = data_source['image'][source_train[len(data_source['image']) - 2000:], :, :, :]
10 | data_s_label = data_source['label'][source_train[:len(data_source['image'])]]
11 | data_s_label_test = data_source['label'][source_train[len(data_source['image']) - 2000:]]
12 | data_s_im = data_s_im.transpose(0, 3, 1, 2).astype(np.float32)
13 | data_s_im_test = data_s_im_test.transpose(0, 3, 1, 2).astype(np.float32)
14 | return data_s_im, data_s_label, data_s_im_test, data_s_label_test
--------------------------------------------------------------------------------
/digit_signal_classification/datasets/unaligned_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from builtins import object
3 | import torchvision.transforms as transforms
4 | from datasets import Dataset
5 |
6 |
7 | class PairedData(object):
8 | def __init__(self, data_loader_A, data_loader_B, max_dataset_size):
9 | self.data_loader_A = data_loader_A
10 | self.data_loader_B = data_loader_B
11 | self.stop_A = False
12 | self.stop_B = False
13 | self.max_dataset_size = max_dataset_size
14 |
15 | def __iter__(self):
16 | self.stop_A = False
17 | self.stop_B = False
18 | self.data_loader_A_iter = iter(self.data_loader_A)
19 | self.data_loader_B_iter = iter(self.data_loader_B)
20 | self.iter = 0
21 | return self
22 |
23 | def __next__(self):
24 | A, A_paths = None, None
25 | B, B_paths = None, None
26 | try:
27 | A, A_paths = next(self.data_loader_A_iter)
28 | except StopIteration:
29 | if A is None or A_paths is None:
30 | self.stop_A = True
31 | self.data_loader_A_iter = iter(self.data_loader_A)
32 | A, A_paths = next(self.data_loader_A_iter)
33 |
34 | try:
35 | B, B_paths = next(self.data_loader_B_iter)
36 | except StopIteration:
37 | if B is None or B_paths is None:
38 | self.stop_B = True
39 | self.data_loader_B_iter = iter(self.data_loader_B)
40 | B, B_paths = next(self.data_loader_B_iter)
41 |
42 | if (self.stop_A and self.stop_B) or self.iter > self.max_dataset_size:
43 | self.stop_A = False
44 | self.stop_B = False
45 | raise StopIteration()
46 | else:
47 | self.iter += 1
48 | return {'S': A, 'S_label': A_paths,
49 | 'T': B, 'T_label': B_paths}
50 |
51 |
52 | class UnalignedDataLoader():
53 | def initialize(self, source, target, batch_size1, batch_size2, scale=32, shuffle_=False):
54 | transform = transforms.Compose([
55 | transforms.Resize(scale),
56 | transforms.ToTensor(),
57 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
58 | ])
59 |
60 | dataset_source = Dataset(source['imgs'], source['labels'], transform=transform)
61 | dataset_target = Dataset(target['imgs'], target['labels'], transform=transform)
62 |
63 | data_loader_s = torch.utils.data.DataLoader(
64 | dataset_source,
65 | batch_size=batch_size1,
66 | shuffle=shuffle_,
67 | num_workers=4
68 | )
69 |
70 | data_loader_t = torch.utils.data.DataLoader(
71 | dataset_target,
72 | batch_size=batch_size2,
73 | shuffle=shuffle_,
74 | num_workers=4
75 | )
76 |
77 | self.dataset_s = dataset_source
78 | self.dataset_t = dataset_target
79 | self.paired_data = PairedData(data_loader_s, data_loader_t, float("inf"))
80 |
81 | @staticmethod
82 | def name():
83 | return 'UnalignedDataLoader'
84 |
85 | def load_data(self):
86 | return self.paired_data
87 |
88 | def __len__(self):
89 | return min(max(len(self.dataset_s), len(self.dataset_t)), float("inf"))
90 |
--------------------------------------------------------------------------------
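One subtlety worth noting: `PairedData` keeps yielding pairs until *both* loaders have been exhausted at least once, restarting whichever runs out first. A toy check (hypothetical, with plain tensor loaders standing in for the image loaders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from unaligned_data_loader import PairedData

a = DataLoader(TensorDataset(torch.arange(3), torch.arange(3)), batch_size=1)
b = DataLoader(TensorDataset(torch.arange(5), torch.arange(5)), batch_size=1)

# The 3-batch loader is restarted once; iteration stops after the 5th pair,
# once both loaders have raised StopIteration.
pairs = list(PairedData(a, b, max_dataset_size=float("inf")))
print(len(pairs))  # 5
```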
/digit_signal_classification/datasets/usps.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import gzip
3 | import cPickle
4 |
5 |
6 | def load_usps(all_use='no'):
7 | f = gzip.open('/data/digit_sign/usps_28x28.pkl', 'rb')
8 | data_set = cPickle.load(f)
9 | f.close()
10 | img_train = data_set[0][0]
11 | label_train = data_set[0][1]
12 | img_test = data_set[1][0]
13 | label_test = data_set[1][1]
14 | inds = np.random.permutation(img_train.shape[0])
15 | if all_use == 'yes':
16 | img_train = img_train[inds][:6562]
17 | label_train = label_train[inds][:6562]
18 | else:
19 | img_train = img_train[inds][:1800]
20 | label_train = label_train[inds][:1800]
21 | img_train = img_train * 255
22 | img_test = img_test * 255
23 | img_train = img_train.reshape((img_train.shape[0], 1, 28, 28))
24 | img_test = img_test.reshape((img_test.shape[0], 1, 28, 28))
25 | return img_train, label_train, img_test, label_test
26 |
--------------------------------------------------------------------------------
/digit_signal_classification/doc/architecture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/doc/architecture.jpg
--------------------------------------------------------------------------------
/digit_signal_classification/doc/architecture_small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/doc/architecture_small.jpg
--------------------------------------------------------------------------------
/digit_signal_classification/doc/architecture_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/doc/architecture_small.png
--------------------------------------------------------------------------------
/digit_signal_classification/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import torch
4 | from solver import Solver
5 | import os
6 | import time
7 |
8 | # Training settings
9 | parser = argparse.ArgumentParser(description='PyTorch Stochastic Classifiers Implementation')
10 | parser.add_argument('--all_use', type=str, default='no', metavar='N',
help='use all training data in USPS adaptation (yes or no)')
12 | parser.add_argument('--batch-size', type=int, default=128, metavar='N',
help='input batch size for training (default: 128)')
14 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', metavar='N',
help='directory for saving checkpoints')
16 | parser.add_argument('--eval_only', action='store_true', default=False,
17 | help='evaluation only option')
18 | parser.add_argument('--lr', type=float, default=0.0002, metavar='LR',
help='learning rate (default: 0.0002)')
20 | parser.add_argument('--max_epoch', type=int, default=100, metavar='N',
21 | help='how many epochs')
22 | parser.add_argument('--no-cuda', action='store_true', default=False,
23 | help='disables CUDA training')
24 | parser.add_argument('--num_k', type=int, default=4, metavar='N',
25 | help='hyper parameter for generator update')
26 | parser.add_argument('--one_step', action='store_true', default=False,
27 | help='one step training with gradient reversal layer')
28 | parser.add_argument('--optimizer', type=str, default='momentum', metavar='N',
29 | help='which optimizer')
30 | parser.add_argument('--resume_epoch', type=int, default=100, metavar='N',
31 | help='epoch to resume')
32 | parser.add_argument('--save_epoch', type=int, default=10, metavar='N',
help='how often (in epochs) to save the model')
34 | parser.add_argument('--save_model', action='store_true', default=False,
35 | help='save_model or not')
36 | parser.add_argument('--seed', type=int, default=1, metavar='S',
37 | help='random seed (default: 1)')
38 | parser.add_argument('--source', type=str, default='svhn', metavar='N',
39 | help='source dataset')
40 | parser.add_argument('--target', type=str, default='mnist', metavar='N',
41 | help='target dataset')
42 | parser.add_argument('--use_abs_diff', action='store_true', default=False,
43 | help='use absolute difference value as a measurement')
44 | parser.add_argument('--num_classifiers_train', type=int, default=2, metavar='N',
45 | help='the number of classifiers used in training')
46 | parser.add_argument('--num_classifiers_test', type=int, default=20, metavar='N',
47 | help='the number of classifiers used in testing')
48 | parser.add_argument('--gpu_devices', type=str, default='0', help='the device you use')
49 | parser.add_argument('--loss_process', type=str, default='sum',
50 | help='mean or sum of the loss')
51 | parser.add_argument('--log_dir', type=str, default='record', metavar='N',
52 | help='the place to store the logs')
53 | parser.add_argument('--init', type=str, default='kaiming_u', metavar='N',
54 | help='the initialization method')
55 | parser.add_argument('--use_init', action='store_true', default=False,
help='whether to apply weight initialization')
57 |
58 | args = parser.parse_args()
59 | args.cuda = not args.no_cuda and torch.cuda.is_available()
60 | torch.manual_seed(args.seed)
61 |
62 | if args.cuda and args.gpu_devices:
63 | torch.cuda.manual_seed(args.seed)
64 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
65 |
66 | print(args)
67 |
68 |
69 | def main():
70 | solver = Solver(
71 | args, source=args.source,
72 | target=args.target,
73 | learning_rate=args.lr,
74 | batch_size=args.batch_size,
75 | optimizer=args.optimizer,
76 | num_k=args.num_k,
77 | all_use=args.all_use,
78 | checkpoint_dir=args.checkpoint_dir,
79 | save_epoch=args.save_epoch,
80 | num_classifiers_train=args.num_classifiers_train,
81 | num_classifiers_test=args.num_classifiers_test,
82 | init=args.init,
83 | use_init=args.use_init
84 | )
85 |
86 | record_time = time.strftime('%Y-%m-%d-%H-%M-%S')
87 |
88 |     record_dir = '{}/{}_{}'.format(
89 |         args.log_dir,
90 |         args.source,
91 |         args.target
92 |     )
100 |
101 | if not os.path.exists(args.checkpoint_dir):
102 | os.mkdir(args.checkpoint_dir)
103 | if not os.path.exists(args.log_dir):
104 | os.mkdir(args.log_dir)
105 | if not os.path.exists(record_dir):
106 | os.mkdir(record_dir)
107 |
108 | record_train = '{}/train_{}.out'.format(record_dir, record_time)
109 | record_test = '{}/test_{}.out'.format(record_dir, record_time)
110 |
111 |     # Log the configuration into the log file
112 | record = open(record_test, 'a')
113 | record.write('Configures: {} \n'.format(args))
114 | record.close()
115 |
116 | if args.eval_only:
117 | solver.test(0)
118 | else:
119 | for epoch in range(args.max_epoch):
120 | solver.train(
121 | epoch,
122 | record_file=record_train,
123 | loss_process=args.loss_process
124 | )
125 |
126 | if epoch % 1 == 0:
127 | solver.test(
128 | epoch,
129 | record_file=record_test,
130 | save_model=args.save_model
131 | )
132 |
133 |
134 | if __name__ == '__main__':
135 | main()
136 |
--------------------------------------------------------------------------------
/digit_signal_classification/model/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/model/__init__.pyc
--------------------------------------------------------------------------------
/digit_signal_classification/model/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/model/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/digit_signal_classification/model/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/model/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/digit_signal_classification/model/__pycache__/build_gen.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/model/__pycache__/build_gen.cpython-36.pyc
--------------------------------------------------------------------------------
/digit_signal_classification/model/__pycache__/build_gen.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/model/__pycache__/build_gen.cpython-37.pyc
--------------------------------------------------------------------------------
/digit_signal_classification/model/build_gen.py:
--------------------------------------------------------------------------------
1 | import svhn2mnist
2 | import usps
3 | import syn2gtrsb
4 |
5 |
6 | def Generator(source, target):
7 | if source == 'usps' or target == 'usps':
8 | return usps.Feature()
9 | elif source == 'svhn':
10 | return svhn2mnist.Feature()
11 | elif source == 'synth':
12 | return syn2gtrsb.Feature()
13 |
14 |
15 | def Classifier(
16 | source,
17 | target,
18 | num_classifiers_train=2,
19 | num_classifiers_test=1,
20 | init='kaiming_u',
21 | use_init=False
22 | ):
23 |
24 | if source == 'usps' or target == 'usps':
25 | return usps.Predictor(
26 | num_classifiers_train,
27 | num_classifiers_test,
28 | init,
29 | use_init
30 | )
31 |
32 | if source == 'svhn':
33 | return svhn2mnist.Predictor(
34 | num_classifiers_train,
35 | num_classifiers_test,
36 | init,
37 | use_init
38 | )
39 |
40 | if source == 'synth':
41 | return syn2gtrsb.Predictor(
42 | num_classifiers_train,
43 | num_classifiers_test,
44 | init,
45 | use_init
46 | )
47 |
48 |
--------------------------------------------------------------------------------
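A hedged sketch of how these factories are presumably consumed (the actual call site is `solver.py`, which is not shown in this listing):

```python
import torch
from build_gen import Generator, Classifier

G = Generator('svhn', 'mnist')           # feature extractor for the task pair
C = Classifier('svhn', 'mnist',
               num_classifiers_train=2,  # classifiers sampled per training step
               num_classifiers_test=20)  # classifiers sampled at test time

images = torch.randn(8, 3, 32, 32)       # dummy SVHN-sized batch
outputs = C(G(images))                   # a list of logits, one per sampled classifier
```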
/digit_signal_classification/model/build_gen.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/model/build_gen.pyc
--------------------------------------------------------------------------------
/digit_signal_classification/model/svhn2mnist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.parameter import Parameter
5 | import torch.distributions.normal as normal
6 |
7 |
8 | class Feature(nn.Module):
9 | def __init__(self):
10 | super(Feature, self).__init__()
11 | self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2)
12 | self.bn1 = nn.BatchNorm2d(64)
13 | self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2)
14 | self.bn2 = nn.BatchNorm2d(64)
15 | self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2)
16 | self.bn3 = nn.BatchNorm2d(128)
17 | self.fc1 = nn.Linear(8192, 3072)
18 | self.bn1_fc = nn.BatchNorm1d(3072)
19 |
20 | def forward(self, x):
21 | x = F.max_pool2d(
22 | F.relu(self.bn1(self.conv1(x))),
23 | stride=2,
24 | kernel_size=3,
25 | padding=1
26 | )
27 | x = F.max_pool2d(
28 | F.relu(self.bn2(self.conv2(x))),
29 | stride=2,
30 | kernel_size=3,
31 | padding=1
32 | )
33 | x = F.relu(self.bn3(self.conv3(x)))
34 | x = x.view(x.size(0), 8192)
35 | x = F.relu(self.bn1_fc(self.fc1(x)))
36 | x = F.dropout(x, training=self.training)
37 | return x
38 |
39 |
40 | class Predictor(nn.Module):
41 | def __init__(
42 | self, num_classifiers_train=2,
43 | num_classifiers_test=20,
44 | init='kaiming_u',
45 | use_init=False
46 | ):
47 | super(Predictor, self).__init__()
48 | self.num_classifiers_train = num_classifiers_train
49 | self.num_classifiers_test = num_classifiers_test
50 | self.init = init
51 |
52 | function_init = {
53 | 'kaiming_u': nn.init.kaiming_uniform_,
54 | 'kaiming_n': nn.init.kaiming_normal_,
55 | 'xavier': nn.init.xavier_normal_
56 | }
57 |
58 | self.fc1 = nn.Linear(3072, 2048)
59 |
60 | self.mu2 = Parameter(torch.randn(10, 2048))
61 | self.sigma2 = Parameter(torch.zeros(10, 2048))
62 |
63 | if use_init:
64 | all_parameters = [self.mu2, self.sigma2]
65 | for item in all_parameters:
66 | function_init[self.init](item)
67 |
68 | self.b2 = Parameter(torch.zeros(10))
69 | self.bn1_fc = nn.BatchNorm1d(2048)
70 |
71 | def forward(self, x, only_mu=True):
72 |
73 | x = self.fc1(x)
74 | x = F.relu(self.bn1_fc(x))
75 |
76 | sigma2_pos = torch.sigmoid(self.sigma2)
77 | fc2_distribution = normal.Normal(self.mu2, sigma2_pos)
78 |
79 | if self.training:
80 | classifiers = []
81 | for index in range(self.num_classifiers_train):
82 | fc2_w = fc2_distribution.rsample()
83 | one_classifier = [fc2_w, self.b2]
84 |
85 | classifiers.append(one_classifier)
86 |
87 | outputs = []
88 | for index in range(self.num_classifiers_train):
89 | out = F.linear(x, classifiers[index][0], classifiers[index][1])
90 | outputs.append(out)
91 |
92 | return outputs
93 | else:
94 | if only_mu:
95 | # Only use mu for classification
96 | out = F.linear(x, self.mu2, self.b2)
97 | return [out]
98 | else:
99 | classifiers = []
100 | for index in range(self.num_classifiers_test):
101 | fc2_w = fc2_distribution.rsample()
102 | one_classifier = [fc2_w, self.b2]
103 |
104 | classifiers.append(one_classifier)
105 |
106 | outputs = []
107 | for index in range(self.num_classifiers_test):
108 | out = F.linear(x, classifiers[index][0], classifiers[index][1])
109 | outputs.append(out)
110 |
111 | return outputs
112 |
--------------------------------------------------------------------------------
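`Predictor.forward` returns a list of logits, one per sampled classifier. How that list is consumed is decided in `solver.py` (not shown); a hedged, MCD-style sketch of the discrepancy term suggested by the `--loss_process` flag in `main.py`:

```python
import torch
import torch.nn.functional as F


def discrepancy(outputs, loss_process='mean'):
    """Pairwise L1 distance between the softmax outputs of the sampled
    classifiers; 'mean' or 'sum' mirrors the --loss_process flag."""
    terms = []
    for i in range(len(outputs)):
        for j in range(i + 1, len(outputs)):
            terms.append(torch.mean(torch.abs(
                F.softmax(outputs[i], dim=1) - F.softmax(outputs[j], dim=1))))
    terms = torch.stack(terms)
    return terms.mean() if loss_process == 'mean' else terms.sum()
```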
/digit_signal_classification/model/svhn2mnist.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/model/svhn2mnist.pyc
--------------------------------------------------------------------------------
/digit_signal_classification/model/syn2gtrsb.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.parameter import Parameter
5 | import torch.distributions.normal as normal
6 |
7 |
8 | class Feature(nn.Module):
9 | def __init__(self):
10 | super(Feature, self).__init__()
11 | self.conv1 = nn.Conv2d(3, 96, kernel_size=5, stride=1, padding=2)
12 | self.bn1 = nn.BatchNorm2d(96)
13 | self.conv2 = nn.Conv2d(96, 144, kernel_size=3, stride=1, padding=1)
14 | self.bn2 = nn.BatchNorm2d(144)
15 | self.conv3 = nn.Conv2d(144, 256, kernel_size=5, stride=1, padding=2)
16 | self.bn3 = nn.BatchNorm2d(256)
17 |
18 | def forward(self, x):
19 | x = F.max_pool2d(
20 | F.relu(self.bn1(self.conv1(x))),
21 | stride=2,
22 | kernel_size=2,
23 | padding=0
24 | )
25 | x = F.max_pool2d(
26 | F.relu(self.bn2(self.conv2(x))),
27 | stride=2,
28 | kernel_size=2,
29 | padding=0
30 | )
31 | x = F.max_pool2d(
32 | F.relu(self.bn3(self.conv3(x))),
33 | stride=2,
34 | kernel_size=2,
35 | padding=0
36 | )
37 | x = x.view(x.size(0), 6400)
38 | return x
39 |
40 |
41 | class Predictor(nn.Module):
42 | def __init__(
43 | self,
44 | num_classifiers_train=2,
45 | num_classifiers_test=1,
46 |         init='kaiming_u', use_init=False
47 | ):
48 | super(Predictor, self).__init__()
49 | self.num_classifiers_train = num_classifiers_train
50 | self.num_classifiers_test = num_classifiers_test
51 | self.init = init
52 | function_init = {
53 | 'kaiming_u': nn.init.kaiming_uniform_,
54 | 'kaiming_n': nn.init.kaiming_normal_,
55 | 'xavier': nn.init.xavier_normal_
56 | }
57 |
58 | self.fc1 = nn.Linear(6400, 512)
59 | self.bn1_fc = nn.BatchNorm1d(512)
60 |
61 | self.fc2_mu = Parameter(torch.randn(43, 512))
62 | self.fc2_sigma = Parameter(torch.zeros(43, 512))
63 | self.fc2_bias = Parameter(torch.zeros(43))
64 |
65 | if use_init:
66 | function_init[self.init](self.fc2_mu)
67 |
68 | def forward(self, x, only_mu=True):
69 |
70 | x = F.relu(self.bn1_fc(self.fc1(x)))
71 | x = F.dropout(x, training=self.training)
72 |
73 | fc2_sigma_pos = F.softplus(self.fc2_sigma - 2)
74 | fc2_distribution = normal.Normal(self.fc2_mu, fc2_sigma_pos)
75 |
76 | if self.training:
77 | classifiers = []
78 | for index in range(self.num_classifiers_train):
79 | fc2_w = fc2_distribution.rsample()
80 | classifiers.append(fc2_w)
81 |
82 | outputs = []
83 | for index in range(self.num_classifiers_train):
84 | out = F.linear(x, classifiers[index], self.fc2_bias)
85 | outputs.append(out)
86 | return outputs
87 | else:
88 | if only_mu:
89 | # Only use mu for classification
90 | out = F.linear(x, self.fc2_mu, self.fc2_bias)
91 | return [out]
92 | else:
93 | classifiers = []
94 | for index in range(self.num_classifiers_test):
95 | fc2_w = fc2_distribution.rsample()
96 | classifiers.append(fc2_w)
97 |
98 | outputs = []
99 | for index in range(self.num_classifiers_test):
100 | out = F.linear(x, classifiers[index], self.fc2_bias)
101 | outputs.append(out)
102 | return outputs
103 |
--------------------------------------------------------------------------------
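A note on the `F.softplus(self.fc2_sigma - 2)` parameterization used here and in `model/usps.py`: softplus maps any real value to a positive standard deviation, and because the sigma parameters are initialized to zeros, the initial scale is small, so the sampled classifiers start close to the mean weights. A quick numeric check:

```python
import math

# fc2_sigma starts at zero, so the initial std is softplus(0 - 2)
print(math.log(1 + math.exp(-2)))  # ~0.127
```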
/digit_signal_classification/model/syn2gtrsb.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/model/syn2gtrsb.pyc
--------------------------------------------------------------------------------
/digit_signal_classification/model/usps.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.parameter import Parameter
5 | import torch.distributions.normal as normal
6 |
7 |
8 | class Feature(nn.Module):
9 | def __init__(self):
10 | super(Feature, self).__init__()
11 | self.conv1 = nn.Conv2d(1, 32, kernel_size=5, stride=1)
12 | self.bn1 = nn.BatchNorm2d(32)
13 | self.conv2 = nn.Conv2d(32, 48, kernel_size=5, stride=1)
14 | self.bn2 = nn.BatchNorm2d(48)
15 |
16 | def forward(self, x):
17 | x = torch.mean(x, 1).view(x.size()[0], 1, x.size()[2], x.size()[3])
18 |
19 | x = F.max_pool2d(
20 | F.relu(self.bn1(self.conv1(x))),
21 | stride=2,
22 | kernel_size=2,
23 | dilation=(1, 1)
24 | )
25 |
26 | x = F.max_pool2d(
27 | F.relu(self.bn2(self.conv2(x))),
28 | stride=2,
29 | kernel_size=2,
30 | dilation=(1, 1)
31 | )
32 |
33 | x = x.view(x.size(0), 48*4*4)
34 | return x
35 |
36 |
37 | class Predictor(nn.Module):
38 | def __init__(
39 | self,
40 | num_classifiers_train=2,
41 | num_classifiers_test=20,
42 | init='kaiming_u',
43 | use_init=False,
44 | prob=0.5
45 | ):
46 | super(Predictor, self).__init__()
47 | self.num_classifiers_train = num_classifiers_train
48 | self.num_classifiers_test = num_classifiers_test
49 | self.prob = prob
50 | self.init = init
51 |
52 | function_init = {
53 | 'kaiming_u': nn.init.kaiming_uniform_,
54 | 'kaiming_n': nn.init.kaiming_normal_,
55 | 'xavier': nn.init.xavier_normal_
56 | }
57 |
58 | self.fc1 = nn.Linear(48*4*4, 100)
59 | self.bn1_fc = nn.BatchNorm1d(100)
60 |
61 | self.fc2 = nn.Linear(100, 100)
62 | self.bn2_fc = nn.BatchNorm1d(100)
63 |
64 | # Use distribution in the last layer
65 | self.fc3_mu = Parameter(torch.randn(10, 100))
66 | self.fc3_sigma = Parameter(torch.zeros(10, 100))
67 | self.fc3_bias = Parameter(torch.zeros(10))
68 |
69 | if use_init:
70 | function_init[init](self.fc3_mu)
71 |
72 | def forward(self, x, only_mu=True):
73 | x = F.dropout(x, training=self.training, p=self.prob)
74 | x = F.relu(self.bn1_fc(self.fc1(x)))
75 | x = F.dropout(x, training=self.training, p=self.prob)
76 | x = F.relu(self.bn2_fc(self.fc2(x)))
77 |
78 | # Distribution sample for the fc layer
79 | fc3_sigma_pos = F.softplus(self.fc3_sigma - 2)
80 | fc3_distribution = normal.Normal(self.fc3_mu, fc3_sigma_pos)
81 |
82 | if self.training:
83 | classifiers = []
84 | for index in range(self.num_classifiers_train):
85 | fc3_w = fc3_distribution.rsample()
86 | classifiers.append(fc3_w)
87 |
88 | outputs = []
89 | for index in range(self.num_classifiers_train):
90 | out = F.linear(x, classifiers[index], self.fc3_bias)
91 | outputs.append(out)
92 |
93 | return outputs
94 |
95 | else:
96 | if only_mu:
97 | # Only use mu for classification
98 | out = F.linear(x, self.fc3_mu, self.fc3_bias)
99 | return [out]
100 | else:
101 | classifiers = []
102 | for index in range(self.num_classifiers_test):
103 | fc3_w = fc3_distribution.rsample()
104 | classifiers.append(fc3_w)
105 |
106 | outputs = []
107 | for index in range(self.num_classifiers_test):
108 | out = F.linear(x, classifiers[index], self.fc3_bias)
109 | outputs.append(out)
110 |
111 | return outputs
112 |
--------------------------------------------------------------------------------
/digit_signal_classification/model/usps.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhiheLu/STAR_Stochastic_Classifiers_for_UDA/9ea65735ae2cd58ff805f766366bb57619132dd7/digit_signal_classification/model/usps.pyc
--------------------------------------------------------------------------------
/digit_signal_classification/record/mnist_usps/test_2020-07-22-17-50-42.out:
--------------------------------------------------------------------------------
1 | Configures: Namespace(all_use='yes', batch_size=128, checkpoint_dir='checkpoint', constraint='softplus', cuda=True, eval_only=False, gpu_devices='0', init='kaiming_u', log_dir='record', loss_process='mean', lr=0.0002, max_epoch=200, no_cuda=False, num_classifiers_test=20, num_classifiers_train=2, num_k=4, one_step=False, optimizer='adam', resume_epoch=100, save_epoch=10, save_model=False, seed=1, source='mnist', target='usps', use_abs_diff=False, use_init=False)
2 | Accuracy: 58.66
3 | Accuracy: 69.14
4 | Accuracy: 73.76
5 | Accuracy: 77.37
6 | Accuracy: 80.91
7 | Accuracy: 81.88
8 | Accuracy: 84.19
9 | Accuracy: 85.16
10 | Accuracy: 86.94
11 | Accuracy: 88.33
12 | Accuracy: 88.60
13 | Accuracy: 88.98
14 | Accuracy: 89.73
15 | Accuracy: 89.89
16 | Accuracy: 89.95
17 | Accuracy: 90.22
18 | Accuracy: 91.08
19 | Accuracy: 91.83
20 | Accuracy: 92.42
21 | Accuracy: 91.56
22 | Accuracy: 92.80
23 | Accuracy: 93.12
24 | Accuracy: 92.85
25 | Accuracy: 93.23
26 | Accuracy: 93.98
27 | Accuracy: 93.87
28 | Accuracy: 94.03
29 | Accuracy: 93.55
30 | Accuracy: 94.30
31 | Accuracy: 93.98
32 | Accuracy: 94.25
33 | Accuracy: 94.14
34 | Accuracy: 94.62
35 | Accuracy: 94.68
36 | Accuracy: 94.68
37 | Accuracy: 95.22
38 | Accuracy: 95.11
39 | Accuracy: 95.48
40 | Accuracy: 95.11
41 | Accuracy: 95.22
42 | Accuracy: 94.73
43 | Accuracy: 95.43
44 | Accuracy: 95.16
45 | Accuracy: 95.27
46 | Accuracy: 95.16
47 | Accuracy: 94.95
48 | Accuracy: 95.11
49 | Accuracy: 95.97
50 | Accuracy: 95.86
51 | Accuracy: 95.59
52 | Accuracy: 95.81
53 | Accuracy: 95.75
54 | Accuracy: 95.65
55 | Accuracy: 94.95
56 | Accuracy: 95.81
57 | Accuracy: 95.70
58 | Accuracy: 95.91
59 | Accuracy: 96.34
60 | Accuracy: 95.48
61 | Accuracy: 96.18
62 | Accuracy: 96.24
63 | Accuracy: 96.34
64 | Accuracy: 96.61
65 | Accuracy: 96.45
66 | Accuracy: 96.29
67 | Accuracy: 96.45
68 | Accuracy: 96.40
69 | Accuracy: 96.02
70 | Accuracy: 95.75
71 | Accuracy: 96.18
72 | Accuracy: 96.13
73 | Accuracy: 96.99
74 | Accuracy: 96.51
75 | Accuracy: 96.51
76 | Accuracy: 96.24
77 | Accuracy: 96.45
78 | Accuracy: 97.04
79 | Accuracy: 96.67
80 | Accuracy: 97.10
81 | Accuracy: 97.37
82 | Accuracy: 97.04
83 | Accuracy: 96.45
84 | Accuracy: 96.99
85 | Accuracy: 96.40
86 | Accuracy: 97.31
87 | Accuracy: 96.45
88 | Accuracy: 96.88
89 | Accuracy: 97.42
90 | Accuracy: 96.83
91 | Accuracy: 97.20
92 | Accuracy: 96.99
93 | Accuracy: 96.83
94 | Accuracy: 96.72
95 | Accuracy: 97.26
96 | Accuracy: 97.15
97 | Accuracy: 96.83
98 | Accuracy: 96.99
99 | Accuracy: 97.20
100 | Accuracy: 97.04
101 | Accuracy: 97.42
102 | Accuracy: 97.26
103 | Accuracy: 97.04
104 | Accuracy: 96.94
105 | Accuracy: 96.94
106 | Accuracy: 97.42
107 | Accuracy: 97.42
108 | Accuracy: 97.42
109 | Accuracy: 97.15
110 | Accuracy: 97.31
111 | Accuracy: 97.26
112 | Accuracy: 97.04
113 | Accuracy: 97.15
114 | Accuracy: 97.10
115 | Accuracy: 97.37
116 | Accuracy: 97.26
117 | Accuracy: 96.83
118 | Accuracy: 97.53
119 | Accuracy: 97.15
120 | Accuracy: 97.47
121 | Accuracy: 97.31
122 | Accuracy: 97.10
123 | Accuracy: 97.10
124 | Accuracy: 96.94
125 | Accuracy: 97.37
126 | Accuracy: 97.15
127 | Accuracy: 96.94
128 | Accuracy: 97.20
129 | Accuracy: 97.69
130 | Accuracy: 96.99
131 | Accuracy: 97.42
132 | Accuracy: 97.26
133 | Accuracy: 97.42
134 | Accuracy: 97.53
135 | Accuracy: 97.20
136 | Accuracy: 97.20
137 | Accuracy: 97.26
138 | Accuracy: 97.31
139 | Accuracy: 97.69
140 | Accuracy: 97.42
141 | Accuracy: 97.63
142 | Accuracy: 96.99
143 | Accuracy: 97.37
144 | Accuracy: 97.20
145 | Accuracy: 97.47
146 | Accuracy: 97.69
147 | Accuracy: 97.42
148 | Accuracy: 97.63
149 | Accuracy: 97.15
150 | Accuracy: 97.63
151 | Accuracy: 96.99
152 | Accuracy: 96.94
153 | Accuracy: 97.42
154 | Accuracy: 97.10
155 | Accuracy: 97.58
156 | Accuracy: 97.26
157 | Accuracy: 97.26
158 | Accuracy: 97.15
159 | Accuracy: 97.53
160 | Accuracy: 97.15
161 | Accuracy: 97.74
162 | Accuracy: 97.15
163 | Accuracy: 97.15
164 | Accuracy: 97.63
165 | Accuracy: 97.37
166 | Accuracy: 97.63
167 | Accuracy: 97.20
168 | Accuracy: 97.85
169 | Accuracy: 96.94
170 | Accuracy: 97.53
171 | Accuracy: 97.31
172 | Accuracy: 97.63
173 | Accuracy: 97.15
174 | Accuracy: 97.37
175 | Accuracy: 97.53
176 | Accuracy: 97.69
177 | Accuracy: 97.26
178 | Accuracy: 97.20
179 | Accuracy: 97.37
180 | Accuracy: 97.20
181 | Accuracy: 97.53
182 | Accuracy: 97.80
183 | Accuracy: 97.69
184 | Accuracy: 97.85
185 | Accuracy: 97.63
186 | Accuracy: 97.85
187 | Accuracy: 97.63
188 | Accuracy: 97.80
189 | Accuracy: 97.69
190 | Accuracy: 97.69
191 | Accuracy: 97.47
192 | Accuracy: 97.15
193 | Accuracy: 97.74
194 | Accuracy: 97.74
195 | Accuracy: 97.37
196 | Accuracy: 97.26
197 | Accuracy: 97.15
198 | Accuracy: 97.15
199 | Accuracy: 97.74
200 | Accuracy: 97.26
201 | Accuracy: 97.80
202 |
--------------------------------------------------------------------------------
/digit_signal_classification/record/mnist_usps/train_2020-07-22-17-50-42.out:
--------------------------------------------------------------------------------
1 | Dis Loss: 0.0253108087927, Cls Loss: 7.70617485046, Lr C: 0.0002, Lr G: 0.0002
2 | Dis Loss: 0.0257472600788, Cls Loss: 2.13100028038, Lr C: 0.0002, Lr G: 0.0002
3 | Dis Loss: 0.0211079176515, Cls Loss: 0.869371533394, Lr C: 0.0002, Lr G: 0.0002
4 | Dis Loss: 0.0197001192719, Cls Loss: 0.90205258131, Lr C: 0.0002, Lr G: 0.0002
5 | Dis Loss: 0.0156597942114, Cls Loss: 0.794815421104, Lr C: 0.0002, Lr G: 0.0002
6 | Dis Loss: 0.0165562909096, Cls Loss: 0.698212444782, Lr C: 0.0002, Lr G: 0.0002
7 | Dis Loss: 0.0161991883069, Cls Loss: 0.591216087341, Lr C: 0.0002, Lr G: 0.0002
8 | Dis Loss: 0.0116136167198, Cls Loss: 0.447472929955, Lr C: 0.0002, Lr G: 0.0002
9 | Dis Loss: 0.0123001784086, Cls Loss: 0.585250616074, Lr C: 0.0002, Lr G: 0.0002
10 | Dis Loss: 0.0130175296217, Cls Loss: 0.244773492217, Lr C: 0.0002, Lr G: 0.0002
11 | Dis Loss: 0.0131531087682, Cls Loss: 0.251306712627, Lr C: 0.0002, Lr G: 0.0002
12 | Dis Loss: 0.0155279133469, Cls Loss: 0.4324208498, Lr C: 0.0002, Lr G: 0.0002
13 | Dis Loss: 0.00917468499392, Cls Loss: 0.191932618618, Lr C: 0.0002, Lr G: 0.0002
14 | Dis Loss: 0.0103086763993, Cls Loss: 0.259280085564, Lr C: 0.0002, Lr G: 0.0002
15 | Dis Loss: 0.0101417507976, Cls Loss: 0.349472016096, Lr C: 0.0002, Lr G: 0.0002
16 | Dis Loss: 0.0107553582639, Cls Loss: 0.242658004165, Lr C: 0.0002, Lr G: 0.0002
17 | Dis Loss: 0.00794821791351, Cls Loss: 0.169440254569, Lr C: 0.0002, Lr G: 0.0002
18 | Dis Loss: 0.00926570873708, Cls Loss: 0.170461535454, Lr C: 0.0002, Lr G: 0.0002
19 | Dis Loss: 0.00805878080428, Cls Loss: 0.371224373579, Lr C: 0.0002, Lr G: 0.0002
20 | Dis Loss: 0.00677244085819, Cls Loss: 0.183727711439, Lr C: 0.0002, Lr G: 0.0002
21 | Dis Loss: 0.00799888651818, Cls Loss: 0.167165458202, Lr C: 0.0002, Lr G: 0.0002
22 | Dis Loss: 0.00736241042614, Cls Loss: 0.178331822157, Lr C: 0.0002, Lr G: 0.0002
23 | Dis Loss: 0.00778476614505, Cls Loss: 0.0710266232491, Lr C: 0.0002, Lr G: 0.0002
24 | Dis Loss: 0.0095592122525, Cls Loss: 0.206102699041, Lr C: 0.0002, Lr G: 0.0002
25 | Dis Loss: 0.00676157185808, Cls Loss: 0.160276055336, Lr C: 0.0002, Lr G: 0.0002
26 | Dis Loss: 0.00798108335584, Cls Loss: 0.234084278345, Lr C: 0.0002, Lr G: 0.0002
27 | Dis Loss: 0.00671329069883, Cls Loss: 0.249807089567, Lr C: 0.0002, Lr G: 0.0002
28 | Dis Loss: 0.00961916334927, Cls Loss: 0.10863583535, Lr C: 0.0002, Lr G: 0.0002
29 | Dis Loss: 0.00781802367419, Cls Loss: 0.116392239928, Lr C: 0.0002, Lr G: 0.0002
30 | Dis Loss: 0.00667194742709, Cls Loss: 0.123922757804, Lr C: 0.0002, Lr G: 0.0002
31 | Dis Loss: 0.00642341328785, Cls Loss: 0.179728776217, Lr C: 0.0002, Lr G: 0.0002
32 | Dis Loss: 0.00476458016783, Cls Loss: 0.134883105755, Lr C: 0.0002, Lr G: 0.0002
33 | Dis Loss: 0.00620850687847, Cls Loss: 0.146140903234, Lr C: 0.0002, Lr G: 0.0002
34 | Dis Loss: 0.00488848984241, Cls Loss: 0.105417221785, Lr C: 0.0002, Lr G: 0.0002
35 | Dis Loss: 0.00591717753559, Cls Loss: 0.0877575054765, Lr C: 0.0002, Lr G: 0.0002
36 | Dis Loss: 0.0071969195269, Cls Loss: 0.0746075063944, Lr C: 0.0002, Lr G: 0.0002
37 | Dis Loss: 0.00616848794743, Cls Loss: 0.193884968758, Lr C: 0.0002, Lr G: 0.0002
38 | Dis Loss: 0.00574999628589, Cls Loss: 0.165285438299, Lr C: 0.0002, Lr G: 0.0002
39 | Dis Loss: 0.00534277455881, Cls Loss: 0.122619472444, Lr C: 0.0002, Lr G: 0.0002
40 | Dis Loss: 0.00775020429865, Cls Loss: 0.211391568184, Lr C: 0.0002, Lr G: 0.0002
41 | Dis Loss: 0.00501205166802, Cls Loss: 0.0465583950281, Lr C: 0.0002, Lr G: 0.0002
42 | Dis Loss: 0.00527259055525, Cls Loss: 0.140332207084, Lr C: 0.0002, Lr G: 0.0002
43 | Dis Loss: 0.00588731607422, Cls Loss: 0.146271795034, Lr C: 0.0002, Lr G: 0.0002
44 | Dis Loss: 0.00705251982436, Cls Loss: 0.071382895112, Lr C: 0.0002, Lr G: 0.0002
45 | Dis Loss: 0.00553363654763, Cls Loss: 0.14066195488, Lr C: 0.0002, Lr G: 0.0002
46 | Dis Loss: 0.00580461835489, Cls Loss: 0.102398604155, Lr C: 0.0002, Lr G: 0.0002
47 | Dis Loss: 0.00497902650386, Cls Loss: 0.0818369686604, Lr C: 0.0002, Lr G: 0.0002
48 | Dis Loss: 0.0048161377199, Cls Loss: 0.129988700151, Lr C: 0.0002, Lr G: 0.0002
49 | Dis Loss: 0.00550159625709, Cls Loss: 0.0817996561527, Lr C: 0.0002, Lr G: 0.0002
50 | Dis Loss: 0.00509450910613, Cls Loss: 0.0380837135017, Lr C: 0.0002, Lr G: 0.0002
51 | Dis Loss: 0.00466236053035, Cls Loss: 0.0444295965135, Lr C: 0.0002, Lr G: 0.0002
52 | Dis Loss: 0.00520153809339, Cls Loss: 0.0979211330414, Lr C: 0.0002, Lr G: 0.0002
53 | Dis Loss: 0.00325969606638, Cls Loss: 0.038551196456, Lr C: 0.0002, Lr G: 0.0002
54 | Dis Loss: 0.00417832052335, Cls Loss: 0.284563630819, Lr C: 0.0002, Lr G: 0.0002
55 | Dis Loss: 0.00407133577392, Cls Loss: 0.147980719805, Lr C: 0.0002, Lr G: 0.0002
56 | Dis Loss: 0.005461496301, Cls Loss: 0.0501725636423, Lr C: 0.0002, Lr G: 0.0002
57 | Dis Loss: 0.00335832685232, Cls Loss: 0.143918901682, Lr C: 0.0002, Lr G: 0.0002
58 | Dis Loss: 0.00428040558472, Cls Loss: 0.0378617122769, Lr C: 0.0002, Lr G: 0.0002
59 | Dis Loss: 0.00598321203142, Cls Loss: 0.0431411862373, Lr C: 0.0002, Lr G: 0.0002
60 | Dis Loss: 0.00505055859685, Cls Loss: 0.0629364699125, Lr C: 0.0002, Lr G: 0.0002
61 | Dis Loss: 0.00350512797013, Cls Loss: 0.0760557428002, Lr C: 0.0002, Lr G: 0.0002
62 | Dis Loss: 0.00371325272135, Cls Loss: 0.0363791808486, Lr C: 0.0002, Lr G: 0.0002
63 | Dis Loss: 0.00418157037348, Cls Loss: 0.12404447794, Lr C: 0.0002, Lr G: 0.0002
64 | Dis Loss: 0.00380450929515, Cls Loss: 0.0583959147334, Lr C: 0.0002, Lr G: 0.0002
65 | Dis Loss: 0.00499662337825, Cls Loss: 0.0246373452246, Lr C: 0.0002, Lr G: 0.0002
66 | Dis Loss: 0.00492464518175, Cls Loss: 0.152146041393, Lr C: 0.0002, Lr G: 0.0002
67 | Dis Loss: 0.00411057798192, Cls Loss: 0.0738733112812, Lr C: 0.0002, Lr G: 0.0002
68 | Dis Loss: 0.00432006036863, Cls Loss: 0.0773489102721, Lr C: 0.0002, Lr G: 0.0002
69 | Dis Loss: 0.005784294568, Cls Loss: 0.0629475861788, Lr C: 0.0002, Lr G: 0.0002
70 | Dis Loss: 0.00267514935695, Cls Loss: 0.10301348567, Lr C: 0.0002, Lr G: 0.0002
71 | Dis Loss: 0.00302348122932, Cls Loss: 0.0238332692534, Lr C: 0.0002, Lr G: 0.0002
72 | Dis Loss: 0.00361108197831, Cls Loss: 0.0807729959488, Lr C: 0.0002, Lr G: 0.0002
73 | Dis Loss: 0.00284019764513, Cls Loss: 0.12972868979, Lr C: 0.0002, Lr G: 0.0002
74 | Dis Loss: 0.00362093793228, Cls Loss: 0.0568101257086, Lr C: 0.0002, Lr G: 0.0002
75 | Dis Loss: 0.00546729750931, Cls Loss: 0.0592128634453, Lr C: 0.0002, Lr G: 0.0002
76 | Dis Loss: 0.00342573830858, Cls Loss: 0.134608700871, Lr C: 0.0002, Lr G: 0.0002
77 | Dis Loss: 0.00497301761061, Cls Loss: 0.0699812322855, Lr C: 0.0002, Lr G: 0.0002
78 | Dis Loss: 0.00336416368373, Cls Loss: 0.0736046805978, Lr C: 0.0002, Lr G: 0.0002
79 | Dis Loss: 0.00429906789213, Cls Loss: 0.0790677592158, Lr C: 0.0002, Lr G: 0.0002
80 | Dis Loss: 0.00298569747247, Cls Loss: 0.107853844762, Lr C: 0.0002, Lr G: 0.0002
81 | Dis Loss: 0.0031755536329, Cls Loss: 0.156428828835, Lr C: 0.0002, Lr G: 0.0002
82 | Dis Loss: 0.00451165623963, Cls Loss: 0.0192727744579, Lr C: 0.0002, Lr G: 0.0002
83 | Dis Loss: 0.00421656994149, Cls Loss: 0.0594312474132, Lr C: 0.0002, Lr G: 0.0002
84 | Dis Loss: 0.0050105410628, Cls Loss: 0.156177505851, Lr C: 0.0002, Lr G: 0.0002
85 | Dis Loss: 0.00470406422392, Cls Loss: 0.0755702778697, Lr C: 0.0002, Lr G: 0.0002
86 | Dis Loss: 0.00336070358753, Cls Loss: 0.0805172324181, Lr C: 0.0002, Lr G: 0.0002
87 | Dis Loss: 0.004280324094, Cls Loss: 0.0691900476813, Lr C: 0.0002, Lr G: 0.0002
88 | Dis Loss: 0.00417717173696, Cls Loss: 0.130094155669, Lr C: 0.0002, Lr G: 0.0002
89 | Dis Loss: 0.00336504052393, Cls Loss: 0.0347976088524, Lr C: 0.0002, Lr G: 0.0002
90 | Dis Loss: 0.00330058042891, Cls Loss: 0.0670301020145, Lr C: 0.0002, Lr G: 0.0002
91 | Dis Loss: 0.00478171044961, Cls Loss: 0.0160610731691, Lr C: 0.0002, Lr G: 0.0002
92 | Dis Loss: 0.00342379207723, Cls Loss: 0.0681992769241, Lr C: 0.0002, Lr G: 0.0002
93 | Dis Loss: 0.00253219390288, Cls Loss: 0.0226874481887, Lr C: 0.0002, Lr G: 0.0002
94 | Dis Loss: 0.00355479470454, Cls Loss: 0.0486062765121, Lr C: 0.0002, Lr G: 0.0002
95 | Dis Loss: 0.00282041728497, Cls Loss: 0.0907536819577, Lr C: 0.0002, Lr G: 0.0002
96 | Dis Loss: 0.00399726582691, Cls Loss: 0.0737735033035, Lr C: 0.0002, Lr G: 0.0002
97 | Dis Loss: 0.00471308920532, Cls Loss: 0.0530179291964, Lr C: 0.0002, Lr G: 0.0002
98 | Dis Loss: 0.00450990861282, Cls Loss: 0.0423623323441, Lr C: 0.0002, Lr G: 0.0002
99 | Dis Loss: 0.00480007519946, Cls Loss: 0.0181928016245, Lr C: 0.0002, Lr G: 0.0002
100 | Dis Loss: 0.0035850845743, Cls Loss: 0.166106507182, Lr C: 0.0002, Lr G: 0.0002
101 | Dis Loss: 0.00450746295974, Cls Loss: 0.0182992517948, Lr C: 0.0002, Lr G: 0.0002
102 | Dis Loss: 0.00498038670048, Cls Loss: 0.0431930832565, Lr C: 0.0002, Lr G: 0.0002
103 | Dis Loss: 0.00271419575438, Cls Loss: 0.104042746127, Lr C: 0.0002, Lr G: 0.0002
104 | Dis Loss: 0.00427831569687, Cls Loss: 0.0295623596758, Lr C: 0.0002, Lr G: 0.0002
105 | Dis Loss: 0.00321475090459, Cls Loss: 0.0229330658913, Lr C: 0.0002, Lr G: 0.0002
106 | Dis Loss: 0.00369507516734, Cls Loss: 0.0861396640539, Lr C: 0.0002, Lr G: 0.0002
107 | Dis Loss: 0.00369632034563, Cls Loss: 0.0460176132619, Lr C: 0.0002, Lr G: 0.0002
108 | Dis Loss: 0.00286093493924, Cls Loss: 0.0172299966216, Lr C: 0.0002, Lr G: 0.0002
109 | Dis Loss: 0.00490171322599, Cls Loss: 0.0776693746448, Lr C: 0.0002, Lr G: 0.0002
110 | Dis Loss: 0.00411299988627, Cls Loss: 0.0282558891922, Lr C: 0.0002, Lr G: 0.0002
111 | Dis Loss: 0.00458105280995, Cls Loss: 0.116378813982, Lr C: 0.0002, Lr G: 0.0002
112 | Dis Loss: 0.00599916884676, Cls Loss: 0.0564890913665, Lr C: 0.0002, Lr G: 0.0002
113 | Dis Loss: 0.00354321091436, Cls Loss: 0.0543925464153, Lr C: 0.0002, Lr G: 0.0002
114 | Dis Loss: 0.00155441684183, Cls Loss: 0.101817071438, Lr C: 0.0002, Lr G: 0.0002
115 | Dis Loss: 0.00411869958043, Cls Loss: 0.0668775439262, Lr C: 0.0002, Lr G: 0.0002
116 | Dis Loss: 0.00649306504056, Cls Loss: 0.0597132444382, Lr C: 0.0002, Lr G: 0.0002
117 | Dis Loss: 0.0048693343997, Cls Loss: 0.0541714504361, Lr C: 0.0002, Lr G: 0.0002
118 | Dis Loss: 0.00391428312287, Cls Loss: 0.0604442432523, Lr C: 0.0002, Lr G: 0.0002
119 | Dis Loss: 0.00543193845078, Cls Loss: 0.0975139811635, Lr C: 0.0002, Lr G: 0.0002
120 | Dis Loss: 0.00529876584187, Cls Loss: 0.074589818716, Lr C: 0.0002, Lr G: 0.0002
121 | Dis Loss: 0.00587701890618, Cls Loss: 0.046294644475, Lr C: 0.0002, Lr G: 0.0002
122 | Dis Loss: 0.00575475534424, Cls Loss: 0.0465457551181, Lr C: 0.0002, Lr G: 0.0002
123 | Dis Loss: 0.00320053333417, Cls Loss: 0.087657853961, Lr C: 0.0002, Lr G: 0.0002
124 | Dis Loss: 0.00315615464933, Cls Loss: 0.0458883568645, Lr C: 0.0002, Lr G: 0.0002
125 | Dis Loss: 0.00321982288733, Cls Loss: 0.0363517850637, Lr C: 0.0002, Lr G: 0.0002
126 | Dis Loss: 0.00380116188899, Cls Loss: 0.0939439088106, Lr C: 0.0002, Lr G: 0.0002
127 | Dis Loss: 0.00293942703865, Cls Loss: 0.0320774838328, Lr C: 0.0002, Lr G: 0.0002
128 | Dis Loss: 0.0045956983231, Cls Loss: 0.0672373920679, Lr C: 0.0002, Lr G: 0.0002
129 | Dis Loss: 0.00609593093395, Cls Loss: 0.0583413168788, Lr C: 0.0002, Lr G: 0.0002
130 | Dis Loss: 0.00370764220133, Cls Loss: 0.121347114444, Lr C: 0.0002, Lr G: 0.0002
131 | Dis Loss: 0.00603236770257, Cls Loss: 0.105289764702, Lr C: 0.0002, Lr G: 0.0002
132 | Dis Loss: 0.00454461202025, Cls Loss: 0.0347447767854, Lr C: 0.0002, Lr G: 0.0002
133 | Dis Loss: 0.00476987520233, Cls Loss: 0.0154802519828, Lr C: 0.0002, Lr G: 0.0002
134 | Dis Loss: 0.00467550195754, Cls Loss: 0.0853082090616, Lr C: 0.0002, Lr G: 0.0002
135 | Dis Loss: 0.00362639478408, Cls Loss: 0.0213281717151, Lr C: 0.0002, Lr G: 0.0002
136 | Dis Loss: 0.00576998572797, Cls Loss: 0.088567301631, Lr C: 0.0002, Lr G: 0.0002
137 | Dis Loss: 0.00389611953869, Cls Loss: 0.0775500386953, Lr C: 0.0002, Lr G: 0.0002
138 | Dis Loss: 0.00364692509174, Cls Loss: 0.0360069274902, Lr C: 0.0002, Lr G: 0.0002
139 | Dis Loss: 0.00672007678077, Cls Loss: 0.0251312982291, Lr C: 0.0002, Lr G: 0.0002
140 | Dis Loss: 0.00412650499493, Cls Loss: 0.143935501575, Lr C: 0.0002, Lr G: 0.0002
141 | Dis Loss: 0.00546471495181, Cls Loss: 0.0206368491054, Lr C: 0.0002, Lr G: 0.0002
142 | Dis Loss: 0.00304111326113, Cls Loss: 0.126500666142, Lr C: 0.0002, Lr G: 0.0002
143 | Dis Loss: 0.0051850057207, Cls Loss: 0.0578130409122, Lr C: 0.0002, Lr G: 0.0002
144 | Dis Loss: 0.00496583385393, Cls Loss: 0.0406974591315, Lr C: 0.0002, Lr G: 0.0002
145 | Dis Loss: 0.00398027477786, Cls Loss: 0.129645079374, Lr C: 0.0002, Lr G: 0.0002
146 | Dis Loss: 0.00436131609604, Cls Loss: 0.0291619319469, Lr C: 0.0002, Lr G: 0.0002
147 | Dis Loss: 0.00447409320623, Cls Loss: 0.0951287448406, Lr C: 0.0002, Lr G: 0.0002
148 | Dis Loss: 0.00567617500201, Cls Loss: 0.0566911026835, Lr C: 0.0002, Lr G: 0.0002
149 | Dis Loss: 0.00418065488338, Cls Loss: 0.0225957930088, Lr C: 0.0002, Lr G: 0.0002
150 | Dis Loss: 0.00386782549322, Cls Loss: 0.0340443402529, Lr C: 0.0002, Lr G: 0.0002
151 | Dis Loss: 0.00331043382175, Cls Loss: 0.0555837228894, Lr C: 0.0002, Lr G: 0.0002
152 | Dis Loss: 0.00312547618523, Cls Loss: 0.0342776626348, Lr C: 0.0002, Lr G: 0.0002
153 | Dis Loss: 0.00458795810118, Cls Loss: 0.0738851577044, Lr C: 0.0002, Lr G: 0.0002
154 | Dis Loss: 0.00356465089135, Cls Loss: 0.070107460022, Lr C: 0.0002, Lr G: 0.0002
155 | Dis Loss: 0.00446801865473, Cls Loss: 0.0710966438055, Lr C: 0.0002, Lr G: 0.0002
156 | Dis Loss: 0.00504029775038, Cls Loss: 0.0303533878177, Lr C: 0.0002, Lr G: 0.0002
157 | Dis Loss: 0.00510559463874, Cls Loss: 0.0256489515305, Lr C: 0.0002, Lr G: 0.0002
158 | Dis Loss: 0.00315452227369, Cls Loss: 0.0374249257147, Lr C: 0.0002, Lr G: 0.0002
159 | Dis Loss: 0.00650208070874, Cls Loss: 0.0266140140593, Lr C: 0.0002, Lr G: 0.0002
160 | Dis Loss: 0.00455721607432, Cls Loss: 0.0496471635997, Lr C: 0.0002, Lr G: 0.0002
161 | Dis Loss: 0.00424830336124, Cls Loss: 0.0775512009859, Lr C: 0.0002, Lr G: 0.0002
162 | Dis Loss: 0.00599554041401, Cls Loss: 0.0285082086921, Lr C: 0.0002, Lr G: 0.0002
163 | Dis Loss: 0.0052908747457, Cls Loss: 0.0163153167814, Lr C: 0.0002, Lr G: 0.0002
164 | Dis Loss: 0.00422675767913, Cls Loss: 0.0244442783296, Lr C: 0.0002, Lr G: 0.0002
165 | Dis Loss: 0.00468214182183, Cls Loss: 0.058757096529, Lr C: 0.0002, Lr G: 0.0002
166 | Dis Loss: 0.00464393757284, Cls Loss: 0.0295814089477, Lr C: 0.0002, Lr G: 0.0002
167 | Dis Loss: 0.0040175607428, Cls Loss: 0.0164988674223, Lr C: 0.0002, Lr G: 0.0002
168 | Dis Loss: 0.00507428310812, Cls Loss: 0.0277650821954, Lr C: 0.0002, Lr G: 0.0002
169 | Dis Loss: 0.00443507311866, Cls Loss: 0.0347390472889, Lr C: 0.0002, Lr G: 0.0002
170 | Dis Loss: 0.00319409067743, Cls Loss: 0.0735673382878, Lr C: 0.0002, Lr G: 0.0002
171 | Dis Loss: 0.00557447317988, Cls Loss: 0.129221603274, Lr C: 0.0002, Lr G: 0.0002
172 | Dis Loss: 0.00489986827597, Cls Loss: 0.0250022243708, Lr C: 0.0002, Lr G: 0.0002
173 | Dis Loss: 0.00367637211457, Cls Loss: 0.037488501519, Lr C: 0.0002, Lr G: 0.0002
174 | Dis Loss: 0.00279768346809, Cls Loss: 0.0477290600538, Lr C: 0.0002, Lr G: 0.0002
175 | Dis Loss: 0.00308806146495, Cls Loss: 0.0397023484111, Lr C: 0.0002, Lr G: 0.0002
176 | Dis Loss: 0.00338975223713, Cls Loss: 0.0628024786711, Lr C: 0.0002, Lr G: 0.0002
177 | Dis Loss: 0.00590053154156, Cls Loss: 0.0783554241061, Lr C: 0.0002, Lr G: 0.0002
178 | Dis Loss: 0.00350930891, Cls Loss: 0.102638430893, Lr C: 0.0002, Lr G: 0.0002
179 | Dis Loss: 0.00412848358974, Cls Loss: 0.0205194931477, Lr C: 0.0002, Lr G: 0.0002
180 | Dis Loss: 0.00548736704513, Cls Loss: 0.0289202295244, Lr C: 0.0002, Lr G: 0.0002
181 | Dis Loss: 0.00462893256918, Cls Loss: 0.0628123059869, Lr C: 0.0002, Lr G: 0.0002
182 | Dis Loss: 0.00720173725858, Cls Loss: 0.0199466813356, Lr C: 0.0002, Lr G: 0.0002
183 | Dis Loss: 0.00340282986872, Cls Loss: 0.0598597824574, Lr C: 0.0002, Lr G: 0.0002
184 | Dis Loss: 0.00293013406917, Cls Loss: 0.0531075187027, Lr C: 0.0002, Lr G: 0.0002
185 | Dis Loss: 0.00399965513498, Cls Loss: 0.029729001224, Lr C: 0.0002, Lr G: 0.0002
186 | Dis Loss: 0.00509276194498, Cls Loss: 0.112514361739, Lr C: 0.0002, Lr G: 0.0002
187 | Dis Loss: 0.00507830735296, Cls Loss: 0.0705089122057, Lr C: 0.0002, Lr G: 0.0002
188 | Dis Loss: 0.00379161047749, Cls Loss: 0.0800604820251, Lr C: 0.0002, Lr G: 0.0002
189 | Dis Loss: 0.00571160251275, Cls Loss: 0.0446820110083, Lr C: 0.0002, Lr G: 0.0002
190 | Dis Loss: 0.00470952829346, Cls Loss: 0.0398795008659, Lr C: 0.0002, Lr G: 0.0002
191 | Dis Loss: 0.00690906727687, Cls Loss: 0.0992165058851, Lr C: 0.0002, Lr G: 0.0002
192 | Dis Loss: 0.00446418579668, Cls Loss: 0.0821895599365, Lr C: 0.0002, Lr G: 0.0002
193 | Dis Loss: 0.0044578476809, Cls Loss: 0.0322776995599, Lr C: 0.0002, Lr G: 0.0002
194 | Dis Loss: 0.00570461526513, Cls Loss: 0.0650118067861, Lr C: 0.0002, Lr G: 0.0002
195 | Dis Loss: 0.00549853919074, Cls Loss: 0.0504232347012, Lr C: 0.0002, Lr G: 0.0002
196 | Dis Loss: 0.00388650037348, Cls Loss: 0.129654437304, Lr C: 0.0002, Lr G: 0.0002
197 | Dis Loss: 0.00780218094587, Cls Loss: 0.0459640324116, Lr C: 0.0002, Lr G: 0.0002
198 | Dis Loss: 0.00588699011132, Cls Loss: 0.0635501295328, Lr C: 0.0002, Lr G: 0.0002
199 | Dis Loss: 0.00327239348553, Cls Loss: 0.0425560176373, Lr C: 0.0002, Lr G: 0.0002
200 | Dis Loss: 0.00459404895082, Cls Loss: 0.0682985782623, Lr C: 0.0002, Lr G: 0.0002
201 |
--------------------------------------------------------------------------------
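Note: the record file above is plain text, one line per logging interval. A minimal parsing sketch for turning it into loss curves (assuming only the `Dis Loss: ..., Cls Loss: ...` line format shown above; `parse_record` is a hypothetical helper, not part of the repo):

```python
import re

# Hypothetical helper: extract the two loss curves from a record file
# written by Solver.train() (line format as in the log above).
LINE_RE = re.compile(r'Dis Loss: ([0-9.eE+-]+), Cls Loss: ([0-9.eE+-]+)')

def parse_record(path):
    dis_losses, cls_losses = [], []
    with open(path) as f:
        for line in f:
            match = LINE_RE.search(line)
            if match:
                dis_losses.append(float(match.group(1)))
                cls_losses.append(float(match.group(2)))
    return dis_losses, cls_losses
```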
/digit_signal_classification/solver.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | from model.build_gen import *  # provides the Generator and Classifier factories
7 | from datasets.dataset_read import dataset_read
8 | from utils.avgmeter import AverageMeter
9 | import time
10 | from tensorboardX import SummaryWriter
11 |
12 |
13 | # Training settings
14 | class Solver(object):
15 | def __init__(
16 | self,
17 | args,
18 | batch_size=64,
19 | source='svhn',
20 | target='mnist',
21 | learning_rate=0.0002,
22 | interval=100,
23 | optimizer='adam',
24 | num_k=4,
25 | all_use=False,
26 | checkpoint_dir=None,
27 | save_epoch=10,
28 | num_classifiers_train=2,
29 | num_classifiers_test=20,
30 | init='kaiming_u',
31 | use_init=False,
32 | dis_metric='L1'
33 | ):
34 |
35 | self.batch_size = batch_size
36 | self.source = source
37 | self.target = target
38 | self.num_k = num_k
39 | self.checkpoint_dir = checkpoint_dir
40 | self.save_epoch = save_epoch
41 | self.use_abs_diff = args.use_abs_diff
42 | self.all_use = all_use
43 | self.num_classifiers_train = num_classifiers_train
44 | self.num_classifiers_test = num_classifiers_test
45 | self.init = init
46 | self.dis_metric = dis_metric
47 | self.use_init = use_init
48 |
49 | if self.source == 'svhn':
50 | self.scale = True
51 | else:
52 | self.scale = False
53 |
54 | print('dataset loading')
55 | self.datasets, self.dataset_test = dataset_read(
56 | source, target,
57 | self.batch_size,
58 | scale=self.scale,
59 | all_use=self.all_use
60 | )
61 | print('load finished!')
62 |
63 | self.G = Generator(source=source, target=target)
64 | self.C = Classifier(
65 | source=source, target=target,
66 | num_classifiers_train=self.num_classifiers_train,
67 | num_classifiers_test=self.num_classifiers_test,
68 | init=self.init,
69 | use_init=self.use_init
70 | )
71 |
72 | if args.eval_only:
73 | # torch.load is a module-level function, so its result must be assigned
74 | self.G = torch.load('{}/{}_to_{}_model_epoch{}_G.pt'.format(
75 | self.checkpoint_dir,
76 | self.source,
77 | self.target,
78 | args.resume_epoch)
79 | )
80 | self.C = torch.load('{}/{}_to_{}_model_epoch{}_C.pt'.format(
81 | self.checkpoint_dir,
82 | self.source,
83 | self.target,
84 | args.resume_epoch)
85 | )
86 |
87 | self.G.cuda()
88 | self.C.cuda()
89 | self.interval = interval
90 | self.writer = SummaryWriter()
91 |
92 | self.opt_c, self.opt_g = self.set_optimizer(
93 | which_opt=optimizer,
94 | lr=learning_rate
95 | )
96 | self.lr = learning_rate
97 |
98 | # Learning rate scheduler
99 | self.scheduler_g = optim.lr_scheduler.CosineAnnealingLR(
100 | self.opt_g,
101 | float(args.max_epoch)
102 | )
103 | self.scheduler_c = optim.lr_scheduler.CosineAnnealingLR(
104 | self.opt_c,
105 | float(args.max_epoch)
106 | )
107 |
108 | def set_optimizer(self, which_opt='momentum', lr=0.001, momentum=0.9):
109 | if which_opt == 'momentum':
110 | self.opt_g = optim.SGD(
111 | self.G.parameters(),
112 | lr=lr,
113 | weight_decay=0.0005,
114 | momentum=momentum
115 | )
116 |
117 | self.opt_c = optim.SGD(
118 | self.C.parameters(),
119 | lr=lr,
120 | weight_decay=0.0005,
121 | momentum=momentum
122 | )
123 |
124 | if which_opt == 'adam':
125 | self.opt_g = optim.Adam(
126 | self.G.parameters(),
127 | lr=lr,
128 | weight_decay=0.0005,
129 | amsgrad=False
130 | )
131 |
132 | self.opt_c = optim.Adam(
133 | self.C.parameters(),
134 | lr=lr,
135 | weight_decay=0.0005,
136 | amsgrad=False
137 | )
138 |
139 | return self.opt_c, self.opt_g
140 |
141 | def reset_grad(self):
142 | self.opt_g.zero_grad()
143 | self.opt_c.zero_grad()
144 |
145 | @staticmethod
146 | def entropy(x):
147 | b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
148 | b = -1.0 * b.sum()  # entropy of the softmax predictions, summed over the batch
149 | return b
150 |
151 | @staticmethod
152 | def discrepancy(out1, out2):
153 | l1loss = torch.nn.L1Loss()
154 | return l1loss(F.softmax(out1, dim=1), F.softmax(out2, dim=1))
155 |
156 | @staticmethod
157 | def discrepancy_mse(out1, out2):
158 | mseloss = torch.nn.MSELoss()
159 | return mseloss(F.softmax(out1, dim=1), F.softmax(out2, dim=1))
160 |
161 | @staticmethod
162 | def discrepancy_cos(out1, out2):
163 | cosloss = torch.nn.CosineSimilarity()
164 | return 1 - cosloss(F.softmax(out1, dim=1), F.softmax(out2, dim=1))
165 |
166 | @staticmethod
167 | def discrepancy_slice_wasserstein(p1, p2):
168 | p1 = torch.sigmoid(p1)
169 | p2 = torch.sigmoid(p2)
170 | s = p1.shape
171 | if s[1] > 1:
172 | proj = torch.randn(s[1], 128).cuda()  # 128 random 1-D projection directions
173 | proj *= torch.rsqrt(torch.sum(torch.mul(proj, proj), 0, keepdim=True))  # unit-norm columns
174 | p1 = torch.matmul(p1, proj)
175 | p2 = torch.matmul(p2, proj)
176 | p1 = torch.topk(p1, s[0], dim=0)[0]  # full-batch top-k == sort each projection descending
177 | p2 = torch.topk(p2, s[0], dim=0)[0]  # sorting aligns the two projected marginals
178 | dist = p1 - p2
179 | wdist = torch.mean(torch.mul(dist, dist))  # mean squared gap between sorted projections
180 |
181 | return wdist
182 |
183 | def train(self, epoch, record_file=None, loss_process='mean', func='L1'):
184 | criterion = nn.CrossEntropyLoss().cuda()
185 |
186 | # Various measurements for the discrepancy
187 | dis_dict = {
188 | 'L1': self.discrepancy,
189 | 'MSE': self.discrepancy_mse,
190 | 'Cosine': self.discrepancy_cos,
191 | 'SWD': self.discrepancy_slice_wasserstein
192 | }
193 |
194 | self.G.train()
195 | self.C.train()
196 | torch.cuda.manual_seed(1)  # fixed seed; note this reseeds the CUDA RNG at the start of every epoch
197 | batch_time = AverageMeter()
198 | data_time = AverageMeter()
199 |
200 | batch_num = min(
201 | len(self.datasets.data_loader_A),
202 | len(self.datasets.data_loader_B)
203 | )
204 |
205 | end = time.time()
206 |
207 | for batch_idx, data in enumerate(self.datasets):
208 | data_time.update(time.time() - end)
209 | img_t = data['T']
210 | img_s = data['S']
211 | label_s = data['S_label']
212 |
213 | if img_s.size()[0] < self.batch_size or img_t.size()[0] < self.batch_size:
214 | break
215 |
216 | img_s = img_s.cuda()
217 | img_t = img_t.cuda()
218 |
219 | imgs_st = torch.cat((img_s, img_t), dim=0)
220 |
221 | label_s = label_s.long().cuda()
222 |
223 | # Step1: update the whole network using source data
224 | self.reset_grad()
225 | feat_s = self.G(img_s)
226 | outputs_s = self.C(feat_s)
227 |
228 | loss_s = []
229 | for index_tr in range(self.num_classifiers_train):
230 | loss_s.append(criterion(outputs_s[index_tr], label_s))
231 |
232 | if loss_process == 'mean':
233 | loss_s = torch.stack(loss_s).mean()
234 | else:
235 | loss_s = torch.stack(loss_s).sum()
236 |
237 | loss_s.backward()
238 | self.opt_g.step()
239 | self.opt_c.step()
240 |
241 | # Step2: update the classifiers using target data
242 | self.reset_grad()
243 | feat_st = self.G(imgs_st)
244 | outputs_st = self.C(feat_st)
245 | # Split the joint batch back into its source and target halves
246 | # for every classifier head, rather than hard-coding heads 0 and 1,
247 | # so that num_classifiers_train > 2 also works.
248 | outputs_s = [out[:self.batch_size]
249 | for out in outputs_st]
250 | outputs_t = [
251 | out[self.batch_size:]
252 | for out in outputs_st]
253 |
254 | loss_s = []
255 | loss_dis = []
256 | for index_tr in range(self.num_classifiers_train):
257 | loss_s.append(criterion(outputs_s[index_tr], label_s))
258 |
259 | if loss_process == 'mean':
260 | loss_s = torch.stack(loss_s).mean()
261 | else:
262 | loss_s = torch.stack(loss_s).sum()
263 |
264 | for index_tr in range(self.num_classifiers_train):
265 | for index_tre in range(index_tr + 1, self.num_classifiers_train):
266 | loss_dis.append(dis_dict[func](outputs_t[index_tr], outputs_t[index_tre]))
267 |
268 | if loss_process == 'mean':
269 | loss_dis = torch.stack(loss_dis).mean()
270 | else:
271 | loss_dis = torch.stack(loss_dis).sum()
272 |
273 | loss = loss_s - loss_dis  # classifier step: minimize source loss while maximizing target discrepancy
274 |
275 | loss.backward()
276 | self.opt_c.step()
277 |
278 | # Step3: update the generator using target data
279 | self.reset_grad()
280 |
281 | for index in range(self.num_k + 1):  # num_k + 1 generator updates per batch
282 | loss_dis = []
283 | feat_t = self.G(img_t)
284 | outputs_t = self.C(feat_t)
285 |
286 | for index_tr in range(self.num_classifiers_train):
287 | for index_tre in range(index_tr + 1, self.num_classifiers_train):
288 | loss_dis.append(dis_dict[func](outputs_t[index_tr], outputs_t[index_tre]))
289 |
290 | if loss_process == 'mean':
291 | loss_dis = torch.stack(loss_dis).mean()
292 | else:
293 | loss_dis = torch.stack(loss_dis).sum()
294 |
295 | loss_dis.backward()
296 | self.opt_g.step()
297 | self.reset_grad()
298 |
299 | batch_time.update(time.time() - end)
300 |
301 | if batch_idx % self.interval == 0:
302 | print('Train Epoch: {} [{}/{}]\t '
303 | 'Loss: {:.6f}\t '
304 | 'Discrepancy: {:.6f} \t '
305 | 'Lr C: {:.6f}\t'
306 | 'Lr G: {:.6f}\t'
307 | 'Time: {:.3f}({:.3f})\t'
308 | .format(epoch + 1, batch_idx,
309 | batch_num, loss_s.data,
310 | loss_dis.data,
311 | self.opt_c.param_groups[0]['lr'],
312 | self.opt_g.param_groups[0]['lr'],
313 | batch_time.val, batch_time.avg))
314 |
315 | if record_file:
316 | # Append this interval's losses and learning rates to the record file.
317 | with open(record_file, 'a') as record:
318 | record.write('Dis Loss: {}, Cls Loss: {}, Lr C: {}, Lr G: {} \n'
319 | .format(loss_dis.data.cpu().numpy(),
320 | loss_s.data.cpu().numpy(),
321 | self.opt_c.param_groups[0]['lr'],
322 | self.opt_g.param_groups[0]['lr']))
323 |
324 | def test(self, epoch, record_file=None, save_model=False):
325 | criterion = nn.CrossEntropyLoss().cuda()
326 | self.G.eval()
327 | self.C.eval()
328 | test_loss = 0
329 | correct = 0
330 | size = 0
331 | with torch.no_grad():
332 | for batch_idx, data in enumerate(self.dataset_test):
333 | img = data['T']
334 | label = data['T_label']
335 | img, label = img.cuda(), label.long().cuda()
336 | feat = self.G(img)
337 | outputs = self.C(feat)
338 | k = label.data.size()[0]
339 | test_loss += criterion(outputs[0], label).data * k  # weight the batch-mean loss by batch size so dividing by `size` gives a per-sample average
340 | output_ensemble = torch.zeros(outputs[0].shape).cuda()
341 |
342 | for index in range(len(outputs)):
343 | output_ensemble += outputs[index]
344 |
345 | pred_ensemble = output_ensemble.data.max(1)[1]
346 | correct += pred_ensemble.eq(label.data).cpu().sum()
347 | size += k
348 | test_loss = test_loss / size
349 |
350 | print('\nTest set: Average loss: {:.4f}\t Ensemble Accuracy: {}/{} ({:.2f}%)'
351 | .format(test_loss, correct, size, 100. * float(correct) / size))
352 |
353 | if save_model and epoch % self.save_epoch == 0:
354 | torch.save(self.G, '{}/{}_to_{}_model_epoch{}_G.pt'
355 | .format(self.checkpoint_dir, self.source, self.target, epoch))
356 | torch.save(self.C, '{}/{}_to_{}_model_epoch{}_C.pt'
357 | .format(self.checkpoint_dir, self.source, self.target, epoch))
358 |
359 | if record_file:
360 | print('Recording {}'.format(record_file))
361 | # Append this epoch's ensemble accuracy to the record file.
362 | with open(record_file, 'a') as record:
363 | record.write('Accuracy: {:.2f}'.format(100. * float(correct) / size))
364 | record.write('\n')
365 |
366 | self.writer.add_scalar('Test/loss', test_loss, epoch)
367 | self.writer.add_scalar('Test/ACC_en', 100. * float(correct) / size, epoch)
368 |
--------------------------------------------------------------------------------
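For orientation, a minimal sketch of how this `Solver` is driven, one `train`/`test` pass per epoch with the cosine schedulers stepped in between (a hypothetical driver; the repo's actual entry point is `main.py`, and the argument object is assumed to carry `max_epoch`):

```python
# Hypothetical driver loop; main.py plays this role in the repo.
solver = Solver(args, source='mnist', target='usps',
                learning_rate=0.0002, optimizer='adam', num_k=4)

for epoch in range(args.max_epoch):
    solver.train(epoch, record_file='record/train.out',
                 loss_process='mean', func=solver.dis_metric)
    solver.test(epoch, record_file='record/test.out', save_model=True)
    # Step the cosine-annealing schedulers once per epoch.
    solver.scheduler_g.step()
    solver.scheduler_c.step()
```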
/digit_signal_classification/utils/avgmeter.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 |
4 | __all__ = ['AverageMeter']
5 |
6 |
7 | class AverageMeter(object):
8 | """Computes and stores the average and current value.
9 | Examples::
10 | >>> # Initialize a meter to record loss
11 | >>> losses = AverageMeter()
12 | >>> # Update meter after every minibatch update
13 | >>> losses.update(loss_value, batch_size)
14 | """
15 | def __init__(self):
16 | self.reset()
17 |
18 | def reset(self):
19 | self.val = 0
20 | self.avg = 0
21 | self.sum = 0
22 | self.count = 0
23 |
24 | def update(self, val, n=1):
25 | self.val = val
26 | self.sum += val * n
27 | self.count += n
28 | self.avg = self.sum / self.count
--------------------------------------------------------------------------------
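A short worked example of the running average that `update` maintains (values chosen purely for illustration):

```python
m = AverageMeter()
m.update(2.0, n=3)   # sum = 6.0,  count = 3, avg = 2.0
m.update(4.0)        # sum = 10.0, count = 4, avg = 2.5
print(m.val, m.avg)  # 4.0 2.5
```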
/digit_signal_classification/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def weights_init(m):
5 | classname = m.__class__.__name__
6 | if classname.find('Conv') != -1:
7 | m.weight.data.normal_(0.0, 0.01)
8 | m.bias.data.normal_(0.0, 0.01)
9 | elif classname.find('BatchNorm') != -1:
10 | m.weight.data.normal_(1.0, 0.01)
11 | m.bias.data.fill_(0)
12 |
13 | def dense_to_one_hot(labels_dense):
14 | """Remap dense class labels: SVHN stores the digit 0 as class 10,
15 | so map label 10 back to 0 and keep every other label unchanged.
16 | (Despite its name, this returns dense labels, not one-hot vectors.)
17 | """
18 | labels_one_hot = np.zeros((len(labels_dense),))
19 | for i, t in enumerate(labels_dense):
20 | labels_one_hot[i] = 0 if t == 10 else t
21 | return labels_one_hot
22 |
--------------------------------------------------------------------------------
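For clarity, the remapping `dense_to_one_hot` actually performs: only label 10 changes, following the SVHN convention of storing digit 0 as class 10.

```python
>>> dense_to_one_hot([3, 10, 7])
array([3., 0., 7.])
```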