├── .idea
│   ├── egnn_distribute.iml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── __init__.py
├── __pycache__
│   └── __init__.cpython-36.pyc
├── _version.py
├── torchtools
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-36.pyc
│   ├── _version.py
│   └── tt
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── arg.cpython-36.pyc
│       │   ├── layer.cpython-36.pyc
│       │   ├── logger.cpython-36.pyc
│       │   ├── stat.cpython-36.pyc
│       │   ├── trainer.cpython-36.pyc
│       │   └── utils.cpython-36.pyc
│       ├── arg.py
│       ├── layer.py
│       ├── logger.py
│       ├── stat.py
│       ├── trainer.py
│       └── utils.py
├── LICENSE
├── README.md
├── data.py
├── eval.py
├── model.py
└── train.py
/_version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.4.0' # align version with pytorch
2 |
3 |
--------------------------------------------------------------------------------
/torchtools/_version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.4.0' # align version with pytorch
2 |
3 |
--------------------------------------------------------------------------------
/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/torchtools/tt/__pycache__/arg.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/torchtools/tt/__pycache__/arg.cpython-36.pyc
--------------------------------------------------------------------------------
/torchtools/tt/__pycache__/stat.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/torchtools/tt/__pycache__/stat.cpython-36.pyc
--------------------------------------------------------------------------------
/torchtools/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/torchtools/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/torchtools/tt/__pycache__/layer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/torchtools/tt/__pycache__/layer.cpython-36.pyc
--------------------------------------------------------------------------------
/torchtools/tt/__pycache__/logger.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/torchtools/tt/__pycache__/logger.cpython-36.pyc
--------------------------------------------------------------------------------
/torchtools/tt/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/torchtools/tt/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/torchtools/tt/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/torchtools/tt/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/torchtools/tt/__pycache__/trainer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/torchtools/tt/__pycache__/trainer.cpython-36.pyc
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchtools/tt/__init__.py:
--------------------------------------------------------------------------------
1 | from torchtools.tt.arg import _parse_opts
2 | from torchtools.tt.utils import *
3 | from torchtools.tt.layer import *
4 | from torchtools.tt.logger import *
5 | from torchtools.tt.stat import *
6 | from torchtools.tt.trainer import *
7 |
8 |
9 | __author__ = 'namju.kim@kakaobrain.com'
10 |
11 |
12 | # global command line arguments
13 | arg = _parse_opts()
14 |
--------------------------------------------------------------------------------
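
Importing torchtools parses `sys.argv` once and exposes the result as the global `tt.arg` object above. A minimal sketch of what a consuming script sees; the flag names `foo`, `bar`, and `baz` are made up for illustration:

```python
# hypothetical demo.py, run as: python3 demo.py --foo 42 --bar hello
from torchtools import tt

print(tt.arg.foo)   # 42 -- values are coerced by _to_py_obj, so this is an int
print(tt.arg.bar)   # 'hello'
print(tt.arg.baz)   # None -- _Opt returns None for unknown keys instead of raising
```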
/torchtools/tt/layer.py:
--------------------------------------------------------------------------------
1 | from torchtools import nn
2 |
3 |
4 | #
5 | # Reshape layer for Sequential or ModuleList
6 | #
7 | class Reshape(nn.Module):
8 |
9 | def __init__(self, *shape):
10 | super(Reshape, self).__init__()
11 | self.shape = shape
12 |
13 | def forward(self, x):
14 | return x.reshape(self.shape)
15 |
16 | def extra_repr(self):
17 | return 'shape={}'.format(self.shape)
--------------------------------------------------------------------------------
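
A minimal, self-contained usage sketch of the `Reshape` layer above (copied here against plain `torch.nn`, which torchtools merely re-exports); the network sizes are arbitrary:

```python
import torch
import torch.nn as nn

# standalone copy of the Reshape layer above
class Reshape(nn.Module):
    def __init__(self, *shape):
        super(Reshape, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.reshape(self.shape)

# flatten a conv feature map before a linear head; -1 lets the batch size vary
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.MaxPool2d(2),
    Reshape(-1, 8 * 16 * 16),
    nn.Linear(8 * 16 * 16, 10),
)
print(model(torch.randn(4, 3, 32, 32)).shape)  # torch.Size([4, 10])
```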
/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torch import optim
5 | from torch import cuda
6 | from torch import utils
7 | from torch.nn import functional as F
8 | from torch.utils.data import *
9 | from torch.distributions import *
10 | from torchtools import tt
11 |
12 |
13 | __author__ = 'namju.kim@kakaobrain.com'
14 |
15 |
16 | # initialize seed
17 | if tt.arg.seed:
18 | np.random.seed(tt.arg.seed)
19 | torch.manual_seed(tt.arg.seed)
20 |
--------------------------------------------------------------------------------
/torchtools/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torch import optim
5 | from torch import cuda
6 | from torch import utils
7 | from torch.nn import functional as F
8 | from torch.utils.data import *
9 | from torch.distributions import *
10 | from torchtools import tt
11 |
12 |
13 | __author__ = 'namju.kim@kakaobrain.com'
14 |
15 |
16 | # initialize seed
17 | if tt.arg.seed:
18 | np.random.seed(tt.arg.seed)
19 | torch.manual_seed(tt.arg.seed)
20 |
--------------------------------------------------------------------------------
/torchtools/tt/stat.py:
--------------------------------------------------------------------------------
1 | from torchtools import tt
2 |
3 |
4 | __author__ = 'namju.kim@kakaobrain.com'
5 |
6 |
7 | def accuracy(prob, label, ignore_index=-100):
8 |
9 | # argmax
10 | pred = prob.max(1)[1].type_as(label)
11 |
12 | # masking
13 | mask = label.ne(ignore_index)
14 | pred = pred.masked_select(mask)
15 | label = label.masked_select(mask)
16 |
17 | # calc accuracy
18 | hit = tt.nvar(pred.eq(label).long().sum())
19 | acc = hit / label.size(0)
20 | return acc
21 |
--------------------------------------------------------------------------------
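
A quick self-contained check of the masking behavior of `accuracy`, re-implemented with plain torch (`.item()` stands in for `tt.nvar`):

```python
import torch

def accuracy(prob, label, ignore_index=-100):
    # same logic as tt.accuracy above, with .item() in place of tt.nvar
    pred = prob.max(1)[1].type_as(label)
    mask = label.ne(ignore_index)
    pred = pred.masked_select(mask)
    label = label.masked_select(mask)
    hit = pred.eq(label).long().sum().item()
    return hit / label.size(0)

prob = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])
label = torch.tensor([0, 1, -100])  # the last sample carries the ignore label
print(accuracy(prob, label))        # 1.0 -- both unmasked predictions are correct
```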
/.idea/egnn_distribute.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Jongmin Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/torchtools/tt/trainer.py:
--------------------------------------------------------------------------------
1 | from torchtools import nn, optim, tt
2 |
3 |
4 | __author__ = 'namju.kim@kakaobrain.com'
5 |
6 |
7 | class SupervisedTrainer(object):
8 |
9 | def __init__(self, model, data_loader, optimizer=None, criterion=None):
10 | self.global_step = 0
11 | self.model = model.to(tt.arg.device)
12 | self.data_loader = data_loader
13 | self.optimizer = optimizer or optim.Adam(model.parameters())
14 | self.criterion = criterion or nn.CrossEntropyLoss()
15 |
16 | def train(self, inputs):
17 |
18 | # split inputs
19 | x, y = inputs
20 |
21 | # forward
22 | if tt.arg.cuda:
23 | z = nn.DataParallel(self.model)(x)
24 | else:
25 | z = self.model(x)
26 |
27 | # loss
28 | loss = self.criterion(z, y)
29 |
30 | # accuracy
31 | acc = tt.accuracy(z, y)
32 |
33 | # update model
34 | self.optimizer.zero_grad()
35 | loss.backward()
36 | self.optimizer.step()
37 |
38 | # logging
39 | tt.log_scalar('loss', loss, self.global_step)
40 | tt.log_scalar('acc', acc, self.global_step)
41 |
42 | def epoch(self, ep_no=None):
43 | pass
44 |
45 | def run(self):
46 |
47 | # experiment name
48 | tt.arg.experiment = tt.arg.experiment or self.model.__class__.__name__.lower()
49 |
50 | # load model
51 | self.global_step = self.model.load_model()
52 | epoch, min_step = divmod(self.global_step, len(self.data_loader))
53 |
54 | # epochs
55 | while epoch < (tt.arg.epoch or 1):
56 | epoch += 1
57 |
58 | # iterations
59 | for step, inputs in enumerate(self.data_loader, min_step + 1):
60 |
61 | # check step counter
62 | if step > len(self.data_loader):
63 | break
64 |
65 | # increase global step count
66 | self.global_step += 1
67 |
68 | # update learning rate
69 | for param_group in self.optimizer.param_groups:
70 | param_group['lr'] = tt.arg.lr
71 |
72 | # call train func
73 | if type(inputs) in [list, tuple]:
74 | self.train([tt.var(d) for d in inputs])
75 | else:
76 | self.train(tt.var(inputs))
77 |
78 | # logging
79 | tt.log_weight(self.model, global_step=self.global_step)
80 | tt.log_gradient(self.model, global_step=self.global_step)
81 | tt.log_step(epoch=epoch, global_step=self.global_step,
82 | max_epoch=(tt.arg.epoch or 1), max_step=len(self.data_loader))
83 |
84 | # save model
85 | self.model.save_model(self.global_step)
86 |
87 | # epoch handler
88 | self.epoch(epoch)
89 |
90 | # save final model
91 | self.model.save_model(self.global_step, force=True)
92 |
--------------------------------------------------------------------------------
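
A hypothetical end-to-end sketch of `SupervisedTrainer`. It assumes this repo's torchtools package is on the path and that torchvision is installed (torchvision and MNIST are illustration choices, not repo dependencies). Note that `tt.arg.lr` must be set, since the trainer re-reads it on every step:

```python
# hypothetical usage sketch; run as a script so tt.arg parses cleanly
from torchtools import nn, tt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

loader = DataLoader(
    datasets.MNIST(tt.arg.data_dir, download=True, transform=transforms.ToTensor()),
    batch_size=64, shuffle=True)

# tt.Reshape comes from layer.py; nn.Flatten does not exist yet in pytorch 1.0.1
model = nn.Sequential(tt.Reshape(-1, 784), nn.Linear(784, 10))

tt.arg.lr = 1e-3    # required: the trainer sets this lr on every step
tt.arg.epoch = 1
tt.SupervisedTrainer(model, loader).run()
```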
/torchtools/tt/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | import time
4 | import pathlib
5 | from torchtools import torch, nn, tt
6 |
7 |
8 | __author__ = 'namju.kim@kakaobrain.com'
9 |
10 |
11 | # time stamp
12 | _tic_start = _last_saved = _last_archived = time.time()
 13 | # best statistic seen so far
14 | _best = -100000000.
15 |
16 |
17 | def tic():
18 | global _tic_start
19 | _tic_start = time.time()
20 | return _tic_start
21 |
22 |
23 | def toc(tic=None):
24 | global _tic_start
25 | if tic is None:
26 | return time.time() - _tic_start
27 | else:
28 | return time.time() - tic
29 |
30 |
31 | def sleep(seconds):
32 | time.sleep(seconds)
33 |
34 |
35 | #
36 | # automatic device-aware torch.tensor
37 | #
38 | def var(data, dtype=None, device=None, requires_grad=False):
 39 |     # return torch.tensor(data, dtype=dtype, device=(device or tt.arg.device), requires_grad=requires_grad)
 40 |     # the line above doesn't work reliably, so we work around it as follows (possibly a pytorch bug)
 41 |     return torch.tensor(data, dtype=dtype, requires_grad=requires_grad).to(device or tt.arg.device)
42 |
43 |
44 | def vars(x_list, dtype=None, device=None, requires_grad=False):
45 | return [var(x, dtype, device, requires_grad) for x in x_list]
46 |
47 |
48 | # for old torchtools compatibility
49 | def cvar(x):
50 | return x.detach()
51 |
52 |
53 | #
54 | # to python or numpy variable(s)
55 | #
56 | def nvar(x):
57 | if isinstance(x, torch.Tensor):
58 | x = x.detach().cpu()
59 | x = x.item() if x.dim() == 0 else x.numpy()
60 | return x
61 |
62 |
63 | def nvars(x_list):
64 | return [nvar(x) for x in x_list]
65 |
66 |
67 | def load_model(model, best=False, postfix=None, experiment=None):
68 | global _best
69 |
70 | # model file name
71 | filename = tt.arg.save_dir + '%s.pt' % (experiment or tt.arg.experiment or model.__class__.__name__.lower())
72 | if postfix is not None:
73 | filename = filename + '.%s' % postfix
74 |
75 | # load model
76 | global_step = 0
77 | if os.path.exists(filename):
78 | if best:
79 | global_step, model_state, _best = torch.load(filename + '.best', map_location=lambda storage, loc: storage)
80 | else:
81 | global_step, model_state = torch.load(filename, map_location=lambda storage, loc: storage)
82 | model.load_state_dict(model_state)
83 |
84 | # update best stat
85 | filename += '.best'
86 | if os.path.exists(filename):
87 | _, _, _best = torch.load(filename, map_location=lambda storage, loc: storage)
88 |
89 | return global_step
90 |
91 |
92 | def save_model(model, global_step, force=False, best=None, postfix=None):
93 | global _last_saved, _last_archived, _best
94 |
95 | # make directory
96 | pathlib.Path(tt.arg.save_dir).mkdir(parents=True, exist_ok=True)
97 |
98 | # filename to save
99 | filename = '%s.pt' % (tt.arg.experiment or model.__class__.__name__.lower())
100 | if postfix is not None:
101 | filename = filename + '.%s' % postfix
102 |
103 | # save model
104 | if force or (tt.arg.save_interval and time.time() - _last_saved >= tt.arg.save_interval) or \
105 | (tt.arg.save_step and global_step % tt.arg.save_step == 0):
106 | torch.save((global_step, model.state_dict()), tt.arg.save_dir + filename)
107 | _last_saved = time.time()
108 |
109 | # archive model
110 | if (tt.arg.archive_interval and time.time() - _last_archived >= tt.arg.archive_interval) or \
111 | (tt.arg.archive_step and global_step % tt.arg.archive_step == 0):
112 | # filename to archive
113 | if tt.arg.archive_interval:
114 | filename = filename + datetime.datetime.now().strftime('.%Y%m%d.%H%M%S')
115 | else:
116 | filename = filename + '.%d' % global_step
117 | torch.save((global_step, model.state_dict()), tt.arg.save_dir + filename)
118 | _last_archived = time.time()
119 |
120 | # save best model
121 | if best is not None and best > _best:
122 | _best = best
123 | filename = filename + '.best'
124 | torch.save((global_step, model.state_dict(), best), tt.arg.save_dir + filename)
125 |
126 |
127 | # patch Module
128 | nn.Module.load_model = load_model
129 | nn.Module.save_model = save_model
130 |
--------------------------------------------------------------------------------
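
A small sketch of the `var`/`nvar` round trip above, assuming the package is importable: `var` copies data onto `tt.arg.device`, and `nvar` brings it back to Python or numpy:

```python
import torch
from torchtools import tt

# device-aware creation: tt.var places the data on tt.arg.device
x = tt.var([[1.0, 2.0], [3.0, 4.0]])
print(x.device)           # cpu or cuda:0, depending on tt.arg.device

# nvar converts back: 0-dim tensors become Python scalars, others numpy arrays
print(tt.nvar(x.sum()))   # 10.0 (a plain float)
print(tt.nvar(x))         # a 2x2 numpy array
```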
/torchtools/tt/arg.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import configparser
3 | import torch
4 | import threading
5 | import time
6 | import os
7 |
8 |
9 | __author__ = 'namju.kim@kakaobrain.com'
10 |
11 |
12 | _config_time_stamp = 0
13 |
14 |
15 | class _Opt(object):
16 |
17 | def __len__(self):
18 | return len(self.__dict__)
19 |
20 | def __setitem__(self, key, value):
21 | self.__dict__[key] = value
22 |
23 | def __getitem__(self, item):
24 | if item in self.__dict__:
25 | return self.__dict__[item]
26 | else:
27 | return None
28 |
29 | def __getattr__(self, item):
30 | return self.__getitem__(item)
31 |
32 |
33 | def _to_py_obj(x):
34 | # check boolean first
35 | if x.lower() in ['true', 'yes', 'on']:
36 | return True
37 | if x.lower() in ['false', 'no', 'off']:
38 | return False
39 | # from string to python object if possible
40 | try:
41 | obj = eval(x)
42 | if type(obj).__name__ in ['int', 'float', 'tuple', 'list', 'dict', 'NoneType']:
43 | x = obj
 44 |     except Exception:  # not a python literal; keep the raw string
45 | pass
46 | return x
47 |
48 |
49 | def _parse_config(arg, file):
50 |
51 | # read config file
52 | config = configparser.ConfigParser()
53 | config.read(file)
54 | # traverse sections
55 | for section in config.sections():
56 | # traverse items
57 | opt = _Opt()
58 | for key in config[section]:
59 | opt[key] = _to_py_obj(config[section][key])
60 | # if default section, save items to global scope
61 | if section.lower() == 'default':
62 | for k, v in opt.__dict__.items():
63 | arg[k] = v
64 | else:
65 | arg['_'.join(section.split())] = opt
66 |
67 |
68 | def _parse_config_thread(arg, file):
69 |
70 | global _config_time_stamp
71 |
72 | while True:
73 | # check timestamp
74 | stamp = os.stat(file).st_mtime
75 | if not stamp == _config_time_stamp:
76 | # update timestamp
77 | _config_time_stamp = stamp
78 | # parse config file
79 | _parse_config(arg, file)
80 | # print result
81 | # _print_opts(arg, 'CONFIGURATION CHANGE DETECTED')
82 | # sleep
83 | time.sleep(1)
84 |
85 |
86 | def _print_opts(arg, header):
87 | print(header, flush=True)
88 | print('-' * 30, flush=True)
89 | for k, v in arg.__dict__.items():
90 | print('%s=%s' % (k, v), flush=True)
91 | print('-' * 30, flush=True)
92 |
93 |
94 | def _parse_opts():
95 |
96 | global _config_time_stamp
97 |
98 | # get command line arguments
99 | arg = _Opt()
100 | argv = sys.argv[1:]
101 |
102 | # check length
103 | assert len(argv) % 2 == 0, 'arguments should be paired with the format of --key value'
104 |
105 | # parse args
106 | for i in range(0, len(argv), 2):
107 |
108 | # check format
109 | assert argv[i].startswith('--'), 'arguments should be paired with the format of --key value'
110 |
111 | # save argument
112 | arg[argv[i][2:]] = _to_py_obj(argv[i + 1])
113 |
114 | # check config file
115 | if argv[i][2:].lower() == 'config':
116 | _parse_config(arg, argv[i + 1])
117 | _config_time_stamp = os.stat(argv[i + 1]).st_mtime
118 |
119 | #
120 | # inject default options
121 | #
122 |
123 | # device setting
124 | if arg.device is None:
125 | arg.device = 'cuda' if torch.cuda.is_available() else 'cpu'
126 | arg.device = torch.device(arg.device)
127 | arg.cuda = arg.device.type == 'cuda'
128 |
129 | # default learning rate
130 | #arg.lr = 1e-3
131 |
132 | # directories
133 | arg.log_dir = arg.log_dir or 'asset/log/'
134 | arg.data_dir = arg.data_dir or 'asset/data/'
135 | arg.save_dir = arg.save_dir or 'asset/train/'
136 | arg.log_dir += '' if arg.log_dir.endswith('/') else '/'
137 | arg.data_dir += '' if arg.data_dir.endswith('/') else '/'
138 | arg.save_dir += '' if arg.save_dir.endswith('/') else '/'
139 |
140 | # print arg option
141 | # _print_opts(arg, 'CONFIGURATION')
142 |
143 | # start config file watcher if config is defined
144 | if arg.config:
145 | t = threading.Thread(target=_parse_config_thread, args=(arg, arg.config))
146 | t.daemon = True
147 | t.start()
148 |
149 | return arg
150 |
--------------------------------------------------------------------------------
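
How `_to_py_obj` coerces command-line values, shown directly. This is a sketch: importing any torchtools module also triggers `_parse_opts` on `sys.argv`, which is harmless when the script is run without extra arguments:

```python
# run without extra argv, e.g.: python3 demo.py
from torchtools.tt.arg import _to_py_obj

print(_to_py_obj('5'))       # 5 (int)
print(_to_py_obj('1e-3'))    # 0.001 (float)
print(_to_py_obj('on'))      # True -- booleans are checked before eval
print(_to_py_obj('[1, 2]'))  # [1, 2] (list)
print(_to_py_obj('mini'))    # 'mini' -- eval raises NameError, so the string is kept
```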
/torchtools/tt/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import time
3 | from tensorboardX import SummaryWriter
4 | from torchtools import tt
5 |
6 |
7 | __author__ = 'namju.kim@kakaobrain.com'
8 |
9 |
10 | # tensorboard writer
11 | _writer = None
12 | _stats_scalar, _stats_image, _stats_audio, _stats_text, _stats_hist = {}, {}, {}, {}, {}
13 |
14 | # time stamp
15 | _last_logged = time.time()
16 |
17 |
18 | # general print wrapper
19 | def log(*args):
20 | print(*args, flush=True)
21 | # save to log_file
22 | if tt.arg.log_file:
23 | with open(tt.arg.log_dir + tt.arg.log_file, 'a') as f:
24 | print(*args, flush=True, file=f)
25 |
26 |
27 | # tensor board writer
28 | def _get_writer():
29 | global _writer
30 | if _writer is None:
31 | # logging directory
32 | tf_log_dir = tt.arg.log_dir
33 | tf_log_dir += '' if tf_log_dir.endswith('/') else '/'
34 | if tt.arg.experiment:
35 | tf_log_dir += tt.arg.experiment
36 | tf_log_dir += datetime.datetime.now().strftime('-%Y%m%d-%H%M%S')
37 | # create writer
38 | _writer = SummaryWriter(tf_log_dir)
39 | return _writer
40 |
41 |
42 | def log_scalar(tag, value, global_step=None):
43 | _stats_scalar[tag] = (tt.nvar(value), global_step)
44 |
45 |
46 | def log_audio(tag, audio, global_step=None):
47 | _stats_audio[tag] = (tt.nvar(audio), global_step)
48 |
49 |
50 | def log_image(tag, image, global_step=None):
51 | _stats_image[tag] = (tt.nvar(image), global_step)
52 |
53 |
54 | def log_text(tag, text, global_step=None):
55 | _stats_text[tag] = (text, global_step)
56 |
57 |
58 | def log_hist(tag, values, global_step=None):
59 | _stats_hist[tag] = (tt.nvar(values), global_step)
60 |
61 |
62 | def log_step(epoch=None, global_step=None, max_epoch=None, max_step=None):
63 |
64 | global _last_logged, _last_logged_step, _stats_scalar, _stats_image, _stats_audio, _stats_text, _stats_hist
65 |
66 | # logging
67 | if (tt.arg.log_interval is None and tt.arg.log_step is None) or \
68 | (tt.arg.log_interval and time.time() - _last_logged >= tt.arg.log_interval) or \
69 | (tt.arg.log_step and global_step % tt.arg.log_step == 0):
70 |
71 | # update logging time stamp
72 | _last_logged = time.time()
73 | _last_logged_step = global_step
74 |
75 | # console output string
76 | console_out = ''
77 | if epoch:
78 | console_out += 'ep: %d' % epoch
79 | if max_epoch:
80 | console_out += '/%d' % max_epoch
81 | if global_step:
82 | if max_step:
83 | step = global_step % max_step
84 | step = max_step if step == 0 else step
85 | console_out += ' step: %d/%d' % (step, max_step)
86 | else:
87 | console_out += ' step: %d' % global_step
88 |
89 | # add stats to tensor board
90 | for k, v in _stats_scalar.items():
91 | _get_writer().add_scalar(k, *v)
92 | # add to console output
93 | if not k.startswith('weight/') and not k.startswith('gradient/'):
94 | console_out += ' %s: %f' % (k, v[0])
95 | for k, v in _stats_image.items():
96 | _get_writer().add_image(k, *v)
97 | for k, v in _stats_audio.items():
98 | _get_writer().add_audio(k, *v)
99 | for k, v in _stats_text.items():
100 | _get_writer().add_text(k, *v)
101 | for k, v in _stats_hist.items():
102 | _get_writer().add_histogram(k, *v, 'auto')
103 |
104 | # flush
105 | _get_writer().file_writer.flush()
106 |
107 | # console out
108 | if len(console_out) > 0:
109 | log(console_out)
110 |
111 | # clear stats
112 |     _stats_scalar, _stats_image, _stats_audio, _stats_text, _stats_hist = {}, {}, {}, {}, {}
113 |
114 |
115 | def log_weight(model, global_step=None):
116 |     # weight statistics
117 |     if tt.arg.log_weight:
118 |         for k, v in model.named_parameters():
119 |             if 'weight' in k:  # weights only, not biases
120 |                 log_scalar('weight/' + k, v.norm(), global_step)
121 |
122 |
123 | def log_gradient(model, global_step=None):
124 |     # gradient statistics
125 |     if tt.arg.log_grad:
126 |         for k, v in model.named_parameters():
127 |             if 'weight' in k:  # weights only, not biases
128 |                 if v.grad is not None:
129 |                     log_scalar('gradient/' + k, v.grad.norm(), global_step)
130 |
--------------------------------------------------------------------------------
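
A sketch of the buffered logging flow above: `log_scalar` only stores values, and `log_step` decides when to write them to the console and to tensorboardX (output lands under `tt.arg.log_dir`). It assumes the package and tensorboardX are importable:

```python
from torchtools import tt

# log_scalar buffers; log_step flushes when tt.arg.log_step divides the step
tt.arg.log_step = 10          # emit every 10th global step
for step in range(1, 101):
    loss = 1.0 / step
    tt.log_scalar('loss', loss, step)
    tt.log_step(global_step=step, max_step=100)
```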
/eval.py:
--------------------------------------------------------------------------------
1 | from torchtools import *
2 | from data import MiniImagenetLoader, TieredImagenetLoader
3 | from model import EmbeddingImagenet, GraphNetwork, ConvNet
4 | import shutil
5 | import os
6 | import random
7 | from train import ModelTrainer
8 |
9 | if __name__ == '__main__':
10 |
11 | tt.arg.test_model = 'D-mini_N-5_K-1_U-0_L-3_B-40_T-True' if tt.arg.test_model is None else tt.arg.test_model
12 |
 13 |     param = {}
 14 |     for token in tt.arg.test_model.split("_"):
 15 |         key, value = token.split("-", 1)
 16 |         param[key] = value
17 | tt.arg.dataset = param['D']
18 | tt.arg.num_ways = int(param['N'])
19 | tt.arg.num_shots = int(param['K'])
20 | tt.arg.num_unlabeled = int(param['U'])
21 | tt.arg.num_layers = int(param['L'])
22 | tt.arg.meta_batch_size = int(param['B'])
23 | tt.arg.transductive = False if param['T'] == 'False' else True
24 |
25 |
26 | ####################
27 | tt.arg.device = 'cuda:0' if tt.arg.device is None else tt.arg.device
28 | # replace dataset_root with your own
29 | tt.arg.dataset_root = '/data/private/dataset'
30 | tt.arg.dataset = 'mini' if tt.arg.dataset is None else tt.arg.dataset
31 | tt.arg.num_ways = 5 if tt.arg.num_ways is None else tt.arg.num_ways
32 | tt.arg.num_shots = 1 if tt.arg.num_shots is None else tt.arg.num_shots
33 | tt.arg.num_unlabeled = 0 if tt.arg.num_unlabeled is None else tt.arg.num_unlabeled
34 | tt.arg.num_layers = 3 if tt.arg.num_layers is None else tt.arg.num_layers
35 | tt.arg.meta_batch_size = 40 if tt.arg.meta_batch_size is None else tt.arg.meta_batch_size
36 | tt.arg.transductive = False if tt.arg.transductive is None else tt.arg.transductive
37 | tt.arg.seed = 222 if tt.arg.seed is None else tt.arg.seed
38 | tt.arg.num_gpus = 1 if tt.arg.num_gpus is None else tt.arg.num_gpus
39 |
40 | tt.arg.num_ways_train = tt.arg.num_ways
41 | tt.arg.num_ways_test = tt.arg.num_ways
42 |
43 | tt.arg.num_shots_train = tt.arg.num_shots
44 | tt.arg.num_shots_test = tt.arg.num_shots
45 |
46 | tt.arg.train_transductive = tt.arg.transductive
47 | tt.arg.test_transductive = tt.arg.transductive
48 |
49 | # model parameter related
50 | tt.arg.num_edge_features = 96
51 | tt.arg.num_node_features = 96
52 | tt.arg.emb_size = 128
53 |
54 | # train, test parameters
55 | tt.arg.train_iteration = 100000 if tt.arg.dataset == 'mini' else 200000
56 | tt.arg.test_iteration = 10000
57 | tt.arg.test_interval = 5000
58 | tt.arg.test_batch_size = 10
59 | tt.arg.log_step = 1000
60 |
61 | tt.arg.lr = 1e-3
62 | tt.arg.grad_clip = 5
63 | tt.arg.weight_decay = 1e-6
64 | tt.arg.dec_lr = 15000 if tt.arg.dataset == 'mini' else 30000
65 | tt.arg.dropout = 0.1 if tt.arg.dataset == 'mini' else 0.0
66 |
67 | #set random seed
68 | np.random.seed(tt.arg.seed)
69 | torch.manual_seed(tt.arg.seed)
70 | torch.cuda.manual_seed_all(tt.arg.seed)
71 | random.seed(tt.arg.seed)
72 | torch.backends.cudnn.deterministic = True
73 | torch.backends.cudnn.benchmark = False
74 |
75 |
76 | enc_module = EmbeddingImagenet(emb_size=tt.arg.emb_size)
77 |
78 | # set random seed
79 | np.random.seed(tt.arg.seed)
80 | torch.manual_seed(tt.arg.seed)
81 | torch.cuda.manual_seed_all(tt.arg.seed)
82 | random.seed(tt.arg.seed)
83 | torch.backends.cudnn.deterministic = True
84 | torch.backends.cudnn.benchmark = False
85 |
 86 |     # rebuild the experiment name from the arguments and verify it matches test_model
87 | exp_name = 'D-{}'.format(tt.arg.dataset)
88 | exp_name += '_N-{}_K-{}_U-{}'.format(tt.arg.num_ways, tt.arg.num_shots, tt.arg.num_unlabeled)
89 | exp_name += '_L-{}_B-{}'.format(tt.arg.num_layers, tt.arg.meta_batch_size)
90 | exp_name += '_T-{}'.format(tt.arg.transductive)
91 |
92 |
 93 |     if exp_name != tt.arg.test_model:
 94 |         print(exp_name)
 95 |         print(tt.arg.test_model)
 96 |         # abort: the checkpoint name does not match the given arguments
 97 |         raise AssertionError('Test model and input arguments are mismatched!')
98 |
99 | gnn_module = GraphNetwork(in_features=tt.arg.emb_size,
100 |                               node_features=tt.arg.num_node_features,
101 |                               edge_features=tt.arg.num_edge_features,
102 | num_layers=tt.arg.num_layers,
103 | dropout=tt.arg.dropout)
104 |
105 | if tt.arg.dataset == 'mini':
106 | test_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='test')
107 | elif tt.arg.dataset == 'tiered':
108 | test_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='test')
109 |     else:
110 |         raise ValueError('Unknown dataset!')
111 |
112 |
113 | data_loader = {'test': test_loader}
114 |
115 | # create trainer
116 | tester = ModelTrainer(enc_module=enc_module,
117 | gnn_module=gnn_module,
118 | data_loader=data_loader)
119 |
120 |
121 | #checkpoint = torch.load('asset/checkpoints/{}/'.format(exp_name) + 'model_best.pth.tar')
122 | checkpoint = torch.load('./trained_models/{}/'.format(exp_name) + 'model_best.pth.tar')
123 |
124 |
125 | tester.enc_module.load_state_dict(checkpoint['enc_module_state_dict'])
126 | print("load pre-trained enc_nn done!")
127 |
128 | # initialize gnn pre-trained
129 | tester.gnn_module.load_state_dict(checkpoint['gnn_module_state_dict'])
130 | print("load pre-trained egnn done!")
131 |
132 | tester.val_acc = checkpoint['val_acc']
133 | tester.global_step = checkpoint['iteration']
134 |
135 | print(tester.global_step)
136 |
137 |
138 | tester.eval(partition='test')
139 |
140 |
141 |
142 |
143 |
144 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # fewshot-egnn
2 |
3 | ### Introduction
4 |
  5 | This project page provides PyTorch code that implements the following CVPR 2019 paper:
  6 | **Title:** "Edge-Labeling Graph Neural Network for Few-shot Learning"
  7 | **Authors:** Jongmin Kim, Taesup Kim, Sungwoong Kim, Chang D. Yoo
  8 |
  9 | **Institution:** KAIST, Kakao Brain
 10 | **Code:** https://github.com/khy0809/fewshot-egnn
 11 | **arXiv:** https://arxiv.org/abs/1905.01436
12 |
13 | **Abstract:**
14 | In this paper, we propose a novel edge-labeling graph
15 | neural network (EGNN), which adapts a deep neural network
16 | on the edge-labeling graph, for few-shot learning.
17 | The previous graph neural network (GNN) approaches in
18 | few-shot learning have been based on the node-labeling
19 | framework, which implicitly models the intra-cluster similarity
20 | and the inter-cluster dissimilarity. In contrast, the
21 | proposed EGNN learns to predict the edge-labels rather
22 | than the node-labels on the graph that enables the evolution
 23 | of an explicit clustering by iteratively updating the edge-labels
24 | with direct exploitation of both intra-cluster similarity
25 | and the inter-cluster dissimilarity. It is also well suited
26 | for performing on various numbers of classes without retraining,
27 | and can be easily extended to perform a transductive
28 | inference. The parameters of the EGNN are learned
29 | by episodic training with an edge-labeling loss to obtain a
30 | well-generalizable model for unseen low-data problem. On
31 | both of the supervised and semi-supervised few-shot image
32 | classification tasks with two benchmark datasets, the proposed
33 | EGNN significantly improves the performances over
34 | the existing GNNs.
35 |
36 | ### Citation
37 | If you find this code useful you can cite us using the following bibTex:
38 | ```
39 | @article{kim2019egnn,
40 | title={Edge-labeling Graph Neural Network for Few-shot Learning},
 41 |   author={Kim, Jongmin and Kim, Taesup and Kim, Sungwoong and Yoo, Chang D.},
42 | journal={arXiv preprint arXiv:1905.01436},
43 | year={2019}
44 | }
45 | ```
46 |
47 |
48 | ### Platform
 49 | This code was developed and tested with PyTorch 1.0.1.
50 |
51 | ### Setting
52 |
53 | You can download miniImagenet dataset from [here](https://drive.google.com/open?id=15WuREBvhEbSWo4fTr1r-vMY0C_6QWv4w).
54 |
 55 | Download 'mini_imagenet_train/val/test.pickle' and put them under
 56 | 'tt.arg.dataset_root/mini-imagenet/compacted_datasets/'.
57 |
 58 | In ```train.py``` and ```eval.py```, replace the dataset root directory with your own:
59 | tt.arg.dataset_root = '/data/private/dataset'
60 |
61 |
62 |
63 | ### Training
64 |
65 | ```
66 | # ************************** miniImagenet, 5way 1shot *****************************
67 | $ python3 train.py --dataset mini --num_ways 5 --num_shots 1 --transductive False
68 | $ python3 train.py --dataset mini --num_ways 5 --num_shots 1 --transductive True
69 |
70 | # ************************** miniImagenet, 5way 5shot *****************************
71 | $ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --transductive False
72 | $ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --transductive True
73 |
74 | # ************************** miniImagenet, 10way 5shot *****************************
75 | $ python3 train.py --dataset mini --num_ways 10 --num_shots 5 --meta_batch_size 20 --transductive True
76 |
77 | # ************************** tieredImagenet, 5way 5shot *****************************
78 | $ python3 train.py --dataset tiered --num_ways 5 --num_shots 5 --transductive False
79 | $ python3 train.py --dataset tiered --num_ways 5 --num_shots 5 --transductive True
80 |
81 | # **************** miniImagenet, 5way 5shot, 20% labeled (semi) *********************
82 | $ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --num_unlabeled 4 --transductive False
83 | $ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --num_unlabeled 4 --transductive True
84 |
85 | ```
86 |
87 | ### Evaluation
 88 | The trained models are saved under './asset/checkpoints/', with names of the form 'D-{dataset}_N-{ways}_K-{shots}_U-{num_unlabeled}_L-{num_layers}_B-{batch_size}_T-{transductive}'.
 89 | For example, to test the trained model for the 'miniImagenet, 5-way 1-shot, transductive' setting, pass the --test_model argument as follows:
90 | ```
91 | $ python3 eval.py --test_model D-mini_N-5_K-1_U-0_L-3_B-40_T-True
92 | ```
93 |
94 |
 95 | ### Results
96 | Here are some experimental results presented in the paper. You should be able to reproduce all the results by using the trained models which can be downloaded from [here](https://drive.google.com/open?id=15WuREBvhEbSWo4fTr1r-vMY0C_6QWv4w).
97 | #### miniImageNet, non-transductive
98 |
99 | | Model | 5-way 5-shot acc (%)|
100 | |--------------------------| ------------------: |
101 | | Matching Networks [1] | 55.30 |
102 | | Reptile [2] | 62.74 |
103 | | Prototypical Net [3] | 65.77 |
104 | | GNN [4] | 66.41 |
105 | | **(ours)** EGNN | **66.85** |
106 |
107 | #### miniImageNet, transductive
108 |
109 | | Model | 5-way 5-shot acc (%)|
110 | |--------------------------| ------------------: |
111 | | MAML [5] | 63.11 |
112 | | Reptile + BN [2] | 65.99 |
113 | | Relation Net [6] | 67.07 |
114 | | MAML + Transduction [5] | 66.19 |
115 | | TPN [7] | 69.43 |
116 | | TPN (Higher K) [7] | 69.86 |
117 | | **(ours)** EGNN | **76.37** |
118 |
119 | #### tieredImageNet, non-transductive
120 |
121 | | Model | 5-way 5-shot acc (%)|
122 | |--------------------------| ------------------: |
123 | | Reptile [2] | 66.47 |
124 | | Prototypical Net [3] | 69.57 |
125 | | **(ours)** EGNN | **70.98** |
126 |
127 | #### tieredImageNet, transductive
128 |
129 | | Model | 5-way 5-shot acc (%)|
130 | |--------------------------| ------------------: |
131 | | MAML [5] | 70.30 |
132 | | Reptile + BN [2] | 71.03 |
133 | | Relation Net [6] | 71.31 |
134 | | MAML + Transduction [5] | 70.83 |
135 | | TPN [7] | 72.58 |
136 | | **(ours)** EGNN | **80.15** |
137 |
138 |
139 | #### miniImageNet, semi-supervised, 5-way 5-shot
140 |
141 | | Model | 20% | 40% | 60% | 100% |
142 | |--------------------------| ------------------: | ------------------: | ------------------: | ------------------: |
143 | | GNN-LabeledOnly [4] | 50.33 | 56.91 | - | 66.41 |
144 | | GNN-Semi [4] | 52.45 | 58.76 | - | 66.41 |
145 | | EGNN-LabeledOnly | 52.86 | - | - | 66.85 |
146 | | EGNN-Semi | 61.88 | 62.52 | 63.53 | 66.85 |
147 | | EGNN-LabeledOnly (Transductive) | 59.18 | - | - | 76.37 |
148 | | EGNN-Semi (Transductive) | 63.62 | 64.32 | 66.37 | 76.37 |
149 |
150 |
151 | #### miniImageNet, cross-way experiment
152 | | Model | train way | test way | Accuracy |
153 | |--------------------------| ------------------: | ------------------: | ------------------: |
154 | | GNN | 5 | 5 | 66.41 |
155 | | GNN | 5 | 10 | N/A |
156 | | GNN | 10 | 10 | 51.75 |
157 | | GNN | 10 | 5 | N/A |
158 | | EGNN | 5 | 5 | 76.37 |
159 | | EGNN | 5 | 10 | 56.35 |
160 | | EGNN | 10 | 10 | 57.61 |
161 | | EGNN | 10 | 5 | 76.27 |
162 |
163 |
164 |
165 | ### References
166 | ```
167 | [1] O. Vinyals et al. Matching networks for one shot learning.
168 | [2] A Nichol, J Achiam, J Schulman, On first-order meta-learning algorithms.
169 | [3] J. Snell, K. Swersky, and R. S. Zemel. Prototypical networks for few-shot learning.
170 | [4] V Garcia, J Bruna, Few-shot learning with graph neural network.
171 | [5] C. Finn, P. Abbeel, and S. Levine. Model-agnostic meta-learning for fast adaptation of deep networks.
172 | [6] F. Sung et al, Learning to Compare: Relation Network for Few-Shot Learning.
173 | [7] Y Liu, J Lee, M Park, S Kim, Y Yang, Transductive propagation network for few-shot learning.
174 | ```
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from torchtools import *
2 | from collections import OrderedDict
3 | import math
4 | #import seaborn as sns
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 |
8 |
9 | class ConvBlock(nn.Module):
10 | def __init__(self, in_planes, out_planes, userelu=True, momentum=0.1, affine=True, track_running_stats=True):
11 | super(ConvBlock, self).__init__()
12 | self.layers = nn.Sequential()
13 | self.layers.add_module('Conv', nn.Conv2d(in_planes, out_planes,
14 | kernel_size=3, stride=1, padding=1, bias=False))
15 |
16 | if tt.arg.normtype == 'batch':
17 | self.layers.add_module('Norm', nn.BatchNorm2d(out_planes, momentum=momentum, affine=affine, track_running_stats=track_running_stats))
18 | elif tt.arg.normtype == 'instance':
19 | self.layers.add_module('Norm', nn.InstanceNorm2d(out_planes))
20 |
21 | if userelu:
22 | self.layers.add_module('ReLU', nn.ReLU(inplace=True))
23 |
24 | self.layers.add_module(
25 | 'MaxPool', nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
26 |
27 | def forward(self, x):
28 | out = self.layers(x)
29 | return out
30 |
31 | class ConvNet(nn.Module):
32 | def __init__(self, opt, momentum=0.1, affine=True, track_running_stats=True):
33 | super(ConvNet, self).__init__()
34 | self.in_planes = opt['in_planes']
35 | self.out_planes = opt['out_planes']
36 | self.num_stages = opt['num_stages']
37 | if type(self.out_planes) == int:
38 | self.out_planes = [self.out_planes for i in range(self.num_stages)]
39 | assert(type(self.out_planes)==list and len(self.out_planes)==self.num_stages)
40 |
41 | num_planes = [self.in_planes,] + self.out_planes
42 | userelu = opt['userelu'] if ('userelu' in opt) else True
43 |
44 | conv_blocks = []
45 | for i in range(self.num_stages):
46 | if i == (self.num_stages-1):
47 | conv_blocks.append(
48 | ConvBlock(num_planes[i], num_planes[i+1], userelu=userelu))
49 | else:
50 | conv_blocks.append(
51 | ConvBlock(num_planes[i], num_planes[i+1]))
52 | self.conv_blocks = nn.Sequential(*conv_blocks)
53 |
54 | for m in self.modules():
55 | if isinstance(m, nn.Conv2d):
56 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
57 | m.weight.data.normal_(0, math.sqrt(2. / n))
58 | elif isinstance(m, nn.BatchNorm2d):
59 | m.weight.data.fill_(1)
60 | m.bias.data.zero_()
61 |
62 | def forward(self, x):
63 | out = self.conv_blocks(x)
64 | out = out.view(out.size(0),-1)
65 | return out
66 |
67 |
68 |
69 | # encoder for imagenet dataset
70 | class EmbeddingImagenet(nn.Module):
71 | def __init__(self,
72 | emb_size):
73 | super(EmbeddingImagenet, self).__init__()
74 | # set size
75 | self.hidden = 64
76 | self.last_hidden = self.hidden * 25
77 | self.emb_size = emb_size
78 |
79 | # set layers
80 | self.conv_1 = nn.Sequential(nn.Conv2d(in_channels=3,
81 | out_channels=self.hidden,
82 | kernel_size=3,
83 | padding=1,
84 | bias=False),
85 | nn.BatchNorm2d(num_features=self.hidden),
86 | nn.MaxPool2d(kernel_size=2),
87 | nn.LeakyReLU(negative_slope=0.2, inplace=True))
88 | self.conv_2 = nn.Sequential(nn.Conv2d(in_channels=self.hidden,
89 | out_channels=int(self.hidden*1.5),
90 | kernel_size=3,
91 | bias=False),
92 | nn.BatchNorm2d(num_features=int(self.hidden*1.5)),
93 | nn.MaxPool2d(kernel_size=2),
94 | nn.LeakyReLU(negative_slope=0.2, inplace=True))
95 | self.conv_3 = nn.Sequential(nn.Conv2d(in_channels=int(self.hidden*1.5),
96 | out_channels=self.hidden*2,
97 | kernel_size=3,
98 | padding=1,
99 | bias=False),
100 | nn.BatchNorm2d(num_features=self.hidden * 2),
101 | nn.MaxPool2d(kernel_size=2),
102 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
103 | nn.Dropout2d(0.4))
104 | self.conv_4 = nn.Sequential(nn.Conv2d(in_channels=self.hidden*2,
105 | out_channels=self.hidden*4,
106 | kernel_size=3,
107 | padding=1,
108 | bias=False),
109 | nn.BatchNorm2d(num_features=self.hidden * 4),
110 | nn.MaxPool2d(kernel_size=2),
111 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
112 | nn.Dropout2d(0.5))
113 | self.layer_last = nn.Sequential(nn.Linear(in_features=self.last_hidden * 4,
114 | out_features=self.emb_size, bias=True),
115 | nn.BatchNorm1d(self.emb_size))
116 |
117 | def forward(self, input_data):
118 | output_data = self.conv_4(self.conv_3(self.conv_2(self.conv_1(input_data))))
119 | return self.layer_last(output_data.view(output_data.size(0), -1))
120 |
121 |
122 |
123 |
124 | class NodeUpdateNetwork(nn.Module):
125 | def __init__(self,
126 | in_features,
127 | num_features,
128 | ratio=[2, 1],
129 | dropout=0.0):
130 | super(NodeUpdateNetwork, self).__init__()
131 | # set size
132 | self.in_features = in_features
133 | self.num_features_list = [num_features * r for r in ratio]
134 | self.dropout = dropout
135 |
136 | # layers
137 | layer_list = OrderedDict()
138 | for l in range(len(self.num_features_list)):
139 |
140 | layer_list['conv{}'.format(l)] = nn.Conv2d(
141 | in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features * 3,
142 | out_channels=self.num_features_list[l],
143 | kernel_size=1,
144 | bias=False)
145 | layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l],
146 | )
147 | layer_list['relu{}'.format(l)] = nn.LeakyReLU()
148 |
149 | if self.dropout > 0 and l == (len(self.num_features_list) - 1):
150 | layer_list['drop{}'.format(l)] = nn.Dropout2d(p=self.dropout)
151 |
152 | self.network = nn.Sequential(layer_list)
153 |
154 | def forward(self, node_feat, edge_feat):
155 | # get size
156 | num_tasks = node_feat.size(0)
157 | num_data = node_feat.size(1)
158 |
159 | # get eye matrix (batch_size x 2 x node_size x node_size)
160 | diag_mask = 1.0 - torch.eye(num_data).unsqueeze(0).unsqueeze(0).repeat(num_tasks, 2, 1, 1).to(tt.arg.device)
161 |
162 | # set diagonal as zero and normalize
163 | edge_feat = F.normalize(edge_feat * diag_mask, p=1, dim=-1)
164 |
165 | # compute attention and aggregate
166 | aggr_feat = torch.bmm(torch.cat(torch.split(edge_feat, 1, 1), 2).squeeze(1), node_feat)
167 |
168 | node_feat = torch.cat([node_feat, torch.cat(aggr_feat.split(num_data, 1), -1)], -1).transpose(1, 2)
169 |
170 | # non-linear transform
171 | node_feat = self.network(node_feat.unsqueeze(-1)).transpose(1, 2).squeeze(-1)
172 | return node_feat
173 |
174 |
175 | class EdgeUpdateNetwork(nn.Module):
176 | def __init__(self,
177 | in_features,
178 | num_features,
179 | ratio=[2, 2, 1, 1],
180 | separate_dissimilarity=False,
181 | dropout=0.0):
182 | super(EdgeUpdateNetwork, self).__init__()
183 | # set size
184 | self.in_features = in_features
185 | self.num_features_list = [num_features * r for r in ratio]
186 | self.separate_dissimilarity = separate_dissimilarity
187 | self.dropout = dropout
188 |
189 | # layers
190 | layer_list = OrderedDict()
191 | for l in range(len(self.num_features_list)):
192 | # set layer
193 | layer_list['conv{}'.format(l)] = nn.Conv2d(in_channels=self.num_features_list[l-1] if l > 0 else self.in_features,
194 | out_channels=self.num_features_list[l],
195 | kernel_size=1,
196 | bias=False)
197 | layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l],
198 | )
199 | layer_list['relu{}'.format(l)] = nn.LeakyReLU()
200 |
201 | if self.dropout > 0:
202 | layer_list['drop{}'.format(l)] = nn.Dropout2d(p=self.dropout)
203 |
204 | layer_list['conv_out'] = nn.Conv2d(in_channels=self.num_features_list[-1],
205 | out_channels=1,
206 | kernel_size=1)
207 | self.sim_network = nn.Sequential(layer_list)
208 |
209 | if self.separate_dissimilarity:
210 | # layers
211 | layer_list = OrderedDict()
212 | for l in range(len(self.num_features_list)):
213 | # set layer
214 | layer_list['conv{}'.format(l)] = nn.Conv2d(in_channels=self.num_features_list[l-1] if l > 0 else self.in_features,
215 | out_channels=self.num_features_list[l],
216 | kernel_size=1,
217 | bias=False)
218 | layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l],
219 | )
220 | layer_list['relu{}'.format(l)] = nn.LeakyReLU()
221 |
222 | if self.dropout > 0:
223 | layer_list['drop{}'.format(l)] = nn.Dropout(p=self.dropout)
224 |
225 | layer_list['conv_out'] = nn.Conv2d(in_channels=self.num_features_list[-1],
226 | out_channels=1,
227 | kernel_size=1)
228 | self.dsim_network = nn.Sequential(layer_list)
229 |
230 | def forward(self, node_feat, edge_feat):
231 | # compute abs(x_i, x_j)
232 | x_i = node_feat.unsqueeze(2)
233 | x_j = torch.transpose(x_i, 1, 2)
234 | x_ij = torch.abs(x_i - x_j)
235 | x_ij = torch.transpose(x_ij, 1, 3)
236 |
237 | # compute similarity/dissimilarity (batch_size x feat_size x num_samples x num_samples)
238 |         sim_val = torch.sigmoid(self.sim_network(x_ij))  # F.sigmoid is deprecated
239 |
240 | if self.separate_dissimilarity:
241 |             dsim_val = torch.sigmoid(self.dsim_network(x_ij))
242 | else:
243 | dsim_val = 1.0 - sim_val
244 |
245 |
246 | diag_mask = 1.0 - torch.eye(node_feat.size(1)).unsqueeze(0).unsqueeze(0).repeat(node_feat.size(0), 2, 1, 1).to(tt.arg.device)
247 | edge_feat = edge_feat * diag_mask
248 | merge_sum = torch.sum(edge_feat, -1, True)
249 | # set diagonal as zero and normalize
250 | edge_feat = F.normalize(torch.cat([sim_val, dsim_val], 1) * edge_feat, p=1, dim=-1) * merge_sum
251 | force_edge_feat = torch.cat((torch.eye(node_feat.size(1)).unsqueeze(0), torch.zeros(node_feat.size(1), node_feat.size(1)).unsqueeze(0)), 0).unsqueeze(0).repeat(node_feat.size(0), 1, 1, 1).to(tt.arg.device)
252 | edge_feat = edge_feat + force_edge_feat
253 | edge_feat = edge_feat + 1e-6
254 | edge_feat = edge_feat / torch.sum(edge_feat, dim=1).unsqueeze(1).repeat(1, 2, 1, 1)
255 |
256 | return edge_feat
257 |
258 |
259 | class GraphNetwork(nn.Module):
260 | def __init__(self,
261 | in_features,
262 | node_features,
263 | edge_features,
264 | num_layers,
265 | dropout=0.0):
266 | super(GraphNetwork, self).__init__()
267 | # set size
268 | self.in_features = in_features
269 | self.node_features = node_features
270 | self.edge_features = edge_features
271 | self.num_layers = num_layers
272 | self.dropout = dropout
273 |
274 | # for each layer
275 | for l in range(self.num_layers):
276 | # set edge to node
277 | edge2node_net = NodeUpdateNetwork(in_features=self.in_features if l == 0 else self.node_features,
278 | num_features=self.node_features,
279 | dropout=self.dropout if l < self.num_layers-1 else 0.0)
280 |
281 | # set node to edge
282 | node2edge_net = EdgeUpdateNetwork(in_features=self.node_features,
283 | num_features=self.edge_features,
284 | separate_dissimilarity=False,
285 | dropout=self.dropout if l < self.num_layers-1 else 0.0)
286 |
287 | self.add_module('edge2node_net{}'.format(l), edge2node_net)
288 | self.add_module('node2edge_net{}'.format(l), node2edge_net)
289 |
290 | # forward
291 | def forward(self, node_feat, edge_feat):
292 | # for each layer
293 | edge_feat_list = []
294 | for l in range(self.num_layers):
295 | # (1) edge to node
296 | node_feat = self._modules['edge2node_net{}'.format(l)](node_feat, edge_feat)
297 |
298 | # (2) node to edge
299 | edge_feat = self._modules['node2edge_net{}'.format(l)](node_feat, edge_feat)
300 |
301 | # save edge feature
302 | edge_feat_list.append(edge_feat)
303 |
304 | # if tt.arg.visualization:
305 | # for l in range(self.num_layers):
306 | # ax = sns.heatmap(tt.nvar(edge_feat_list[l][0, 0, :, :]), xticklabels=False, yticklabels=False, linewidth=0.1, cmap="coolwarm", cbar=False, square=True)
307 | # ax.get_figure().savefig('./visualization/edge_feat_layer{}.png'.format(l))
308 |
309 |
310 | return edge_feat_list
311 |
312 |
--------------------------------------------------------------------------------
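
A shape-level sketch of `GraphNetwork` with dummy inputs (hypothetical sizes; run it from the repo root so `model.py` is importable, and pin `tt.arg.device` to CPU so the mask tensors created in `forward` match the unmoved parameters):

```python
import torch
from torchtools import tt
from model import GraphNetwork

tt.arg.device = torch.device('cpu')   # forward() builds masks on tt.arg.device

# hypothetical sizes: 2 tasks per batch, 10 nodes each, 128-dim node embeddings
gnn = GraphNetwork(in_features=128, node_features=96, edge_features=96,
                   num_layers=3, dropout=0.0)
node_feat = torch.randn(2, 10, 128)   # batch x nodes x features
edge_feat = torch.rand(2, 2, 10, 10)  # 2 channels: similarity / dissimilarity
edge_list = gnn(node_feat, edge_feat) # one refined edge map per layer
print(len(edge_list), edge_list[0].shape)  # 3 torch.Size([2, 2, 10, 10])
```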
/data.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from torchtools import *
3 | import torch.utils.data as data
4 | import random
5 | import os
6 | import numpy as np
7 | from PIL import Image as pil_image
8 | import pickle
9 | from itertools import islice
10 | from torchvision import transforms
11 |
12 |
13 | class MiniImagenetLoader(data.Dataset):
14 | def __init__(self, root, partition='train'):
15 | super(MiniImagenetLoader, self).__init__()
16 | # set dataset information
17 | self.root = root
18 | self.partition = partition
19 | self.data_size = [3, 84, 84]
20 |
21 | # set normalizer
22 | mean_pix = [x / 255.0 for x in [120.39586422, 115.59361427, 104.54012653]]
23 | std_pix = [x / 255.0 for x in [70.68188272, 68.27635443, 72.54505529]]
24 | normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
25 |
26 | # set transformer
27 | if self.partition == 'train':
28 | self.transform = transforms.Compose([transforms.RandomCrop(84, padding=4),
29 | lambda x: np.asarray(x),
30 | transforms.ToTensor(),
31 | normalize])
 32 |         else:  # 'val' or 'test'
33 | self.transform = transforms.Compose([lambda x: np.asarray(x),
34 | transforms.ToTensor(),
35 | normalize])
36 |
37 | # load data
38 | self.data = self.load_dataset()
39 |
40 | def load_dataset(self):
41 | # load data
42 | dataset_path = os.path.join(self.root, 'mini-imagenet/compacted_datasets', 'mini_imagenet_%s.pickle' % self.partition)
43 | with open(dataset_path, 'rb') as handle:
44 | data = pickle.load(handle)
45 |
46 | # for each class
47 | for c_idx in data:
48 | # for each image
49 | for i_idx in range(len(data[c_idx])):
50 | # resize
51 | image_data = pil_image.fromarray(np.uint8(data[c_idx][i_idx]))
52 | image_data = image_data.resize((self.data_size[2], self.data_size[1]))
53 | #image_data = np.array(image_data, dtype='float32')
54 |
55 | #image_data = np.transpose(image_data, (2, 0, 1))
56 |
57 | # save
58 | data[c_idx][i_idx] = image_data
59 | return data
60 |
61 | def get_task_batch(self,
62 | num_tasks=5,
63 | num_ways=20,
64 | num_shots=1,
65 | num_queries=1,
66 | seed=None):
67 |
68 | if seed is not None:
69 | random.seed(seed)
70 |
71 | # init task batch data
72 | support_data, support_label, query_data, query_label = [], [], [], []
73 | for _ in range(num_ways * num_shots):
74 | data = np.zeros(shape=[num_tasks] + self.data_size,
75 | dtype='float32')
76 | label = np.zeros(shape=[num_tasks],
77 | dtype='float32')
78 | support_data.append(data)
79 | support_label.append(label)
80 | for _ in range(num_ways * num_queries):
81 | data = np.zeros(shape=[num_tasks] + self.data_size,
82 | dtype='float32')
83 | label = np.zeros(shape=[num_tasks],
84 | dtype='float32')
85 | query_data.append(data)
86 | query_label.append(label)
87 |
88 | # get full class list in dataset
89 | full_class_list = list(self.data.keys())
90 |
91 | # for each task
92 | for t_idx in range(num_tasks):
93 | # define task by sampling classes (num_ways)
94 | task_class_list = random.sample(full_class_list, num_ways)
95 |
96 | # for each sampled class in task
97 | for c_idx in range(num_ways):
98 | # sample data for support and query (num_shots + num_queries)
99 | class_data_list = random.sample(self.data[task_class_list[c_idx]], num_shots + num_queries)
100 |
101 |
102 | # load sample for support set
103 | for i_idx in range(num_shots):
104 | # set data
105 | support_data[i_idx + c_idx * num_shots][t_idx] = self.transform(class_data_list[i_idx])
106 | support_label[i_idx + c_idx * num_shots][t_idx] = c_idx
107 |
108 | # load sample for query set
109 | for i_idx in range(num_queries):
110 | query_data[i_idx + c_idx * num_queries][t_idx] = self.transform(class_data_list[num_shots + i_idx])
111 | query_label[i_idx + c_idx * num_queries][t_idx] = c_idx
112 |
113 | # convert to tensor (num_tasks x (num_ways * (num_supports + num_queries)) x ...)
114 | support_data = torch.stack([torch.from_numpy(data).float().to(tt.arg.device) for data in support_data], 1)
115 | support_label = torch.stack([torch.from_numpy(label).float().to(tt.arg.device) for label in support_label], 1)
116 | query_data = torch.stack([torch.from_numpy(data).float().to(tt.arg.device) for data in query_data], 1)
117 | query_label = torch.stack([torch.from_numpy(label).float().to(tt.arg.device) for label in query_label], 1)
118 |
119 | return [support_data, support_label, query_data, query_label]
120 |
121 |
122 |
123 | class TieredImagenetLoader(data.Dataset):
124 | def __init__(self, root, partition='train'):
125 | self.root = root
126 | self.partition = partition # train/val/test
127 | #self.preprocess()
128 | self.data_size = [3, 84, 84]
129 |
130 | # load data
131 | self.data = self.load_dataset()
132 |
133 | # if not self._check_exists_():
134 | # self._init_folders_()
135 | # if self.check_decompress():
136 | # self._decompress_()
137 | # self._preprocess_()
138 |
139 |
140 | def get_image_paths(self, file):
141 | images_path, class_names = [], []
142 | with open(file, 'r') as f:
143 | f.readline()
144 | for line in f:
145 | name, class_ = line.split(',')
146 | class_ = class_[0:(len(class_)-1)]
147 | path = self.root + '/tiered-imagenet/images/'+name
148 | images_path.append(path)
149 | class_names.append(class_)
150 | return class_names, images_path
151 |
152 | def preprocess(self):
153 | print('\nPreprocessing Tiered-Imagenet images...')
154 | (class_names_train, images_path_train) = self.get_image_paths('%s/tiered-imagenet/train.csv' % self.root)
155 | (class_names_test, images_path_test) = self.get_image_paths('%s/tiered-imagenet/test.csv' % self.root)
156 | (class_names_val, images_path_val) = self.get_image_paths('%s/tiered-imagenet/val.csv' % self.root)
157 |
158 | keys_train = list(set(class_names_train))
159 | keys_test = list(set(class_names_test))
160 | keys_val = list(set(class_names_val))
161 | label_encoder = {}
162 | label_decoder = {}
163 | for i in range(len(keys_train)):
164 | label_encoder[keys_train[i]] = i
165 | label_decoder[i] = keys_train[i]
166 | for i in range(len(keys_train), len(keys_train)+len(keys_test)):
167 | label_encoder[keys_test[i-len(keys_train)]] = i
168 | label_decoder[i] = keys_test[i-len(keys_train)]
169 | for i in range(len(keys_train)+len(keys_test), len(keys_train)+len(keys_test)+len(keys_val)):
170 | label_encoder[keys_val[i-len(keys_train) - len(keys_test)]] = i
171 | label_decoder[i] = keys_val[i-len(keys_train)-len(keys_test)]
172 |
173 | counter = 0
174 | train_set = {}
175 |
176 | for class_, path in zip(class_names_train, images_path_train):
177 | img = pil_image.open(path)
178 | img = img.convert('RGB')
179 |             img = img.resize((84, 84), pil_image.LANCZOS)  # ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
180 | img = np.array(img, dtype='float32')
181 | if label_encoder[class_] not in train_set:
182 | train_set[label_encoder[class_]] = []
183 | train_set[label_encoder[class_]].append(img)
184 | counter += 1
185 | if counter % 1000 == 0:
186 | print("Counter "+str(counter) + " from " + str(len(images_path_train)))
187 |
188 |         counter, test_set = 0, {}  # reset the progress counter for the test split
189 | for class_, path in zip(class_names_test, images_path_test):
190 | img = pil_image.open(path)
191 | img = img.convert('RGB')
192 |             img = img.resize((84, 84), pil_image.LANCZOS)  # ANTIALIAS removed in Pillow 10
193 | img = np.array(img, dtype='float32')
194 |
195 | if label_encoder[class_] not in test_set:
196 | test_set[label_encoder[class_]] = []
197 | test_set[label_encoder[class_]].append(img)
198 | counter += 1
199 | if counter % 1000 == 0:
200 | print("Counter " + str(counter) + " from "+str(len(class_names_test)))
201 |
202 |         counter, val_set = 0, {}  # reset the progress counter for the val split
203 | for class_, path in zip(class_names_val, images_path_val):
204 | img = pil_image.open(path)
205 | img = img.convert('RGB')
206 |             img = img.resize((84, 84), pil_image.LANCZOS)  # ANTIALIAS removed in Pillow 10
207 | img = np.array(img, dtype='float32')
208 |
209 | if label_encoder[class_] not in val_set:
210 | val_set[label_encoder[class_]] = []
211 | val_set[label_encoder[class_]].append(img)
212 | counter += 1
213 | if counter % 1000 == 0:
214 | print("Counter "+str(counter) + " from " + str(len(class_names_val)))
215 |
216 | partition_count = 0
217 | for item in self.chunks(train_set, 20):
218 | partition_count = partition_count + 1
219 | with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_train_{}.pickle'.format(partition_count)), 'wb') as handle:
220 | pickle.dump(item, handle, protocol=2)
221 |
222 | partition_count = 0
223 | for item in self.chunks(test_set, 20):
224 | partition_count = partition_count + 1
225 | with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_test_{}.pickle'.format(partition_count)), 'wb') as handle:
226 | pickle.dump(item, handle, protocol=2)
227 |
228 | partition_count = 0
229 | for item in self.chunks(val_set, 20):
230 | partition_count = partition_count + 1
231 | with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_val_{}.pickle'.format(partition_count)), 'wb') as handle:
232 | pickle.dump(item, handle, protocol=2)
233 |
234 |
235 |
236 | label_encoder = {}
237 |         keys = list(train_set.keys()) + list(test_set.keys())  # note: val classes are not included in this saved encoder
238 | for id_key, key in enumerate(keys):
239 | label_encoder[key] = id_key
240 | with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_label_encoder.pickle'), 'wb') as handle:
241 | pickle.dump(label_encoder, handle, protocol=2)
242 |
243 | print('Images preprocessed')
244 |
245 | def load_dataset(self):
246 | print("Loading dataset")
247 | data = {}
248 | if self.partition == 'train':
249 | num_partition = 18
250 | elif self.partition == 'val':
251 | num_partition = 5
252 | elif self.partition == 'test':
253 |             num_partition = 8
        else:
            raise ValueError('Unknown partition: {}'.format(self.partition))
254 |
255 | partition_count = 0
256 | for i in range(num_partition):
257 |             partition_count += 1
258 | with open(os.path.join(self.root, 'tiered-imagenet/compacted_datasets', 'tiered_imagenet_{}_{}.pickle'.format(self.partition, partition_count)), 'rb') as handle:
259 | data.update(pickle.load(handle))
260 |
261 | # Resize images and normalize
262 | for class_ in data:
263 | for i in range(len(data[class_])):
264 | image2resize = pil_image.fromarray(np.uint8(data[class_][i]))
265 | image_resized = image2resize.resize((self.data_size[2], self.data_size[1]))
266 | image_resized = np.array(image_resized, dtype='float32')
267 |
268 | # Normalize
269 | image_resized = np.transpose(image_resized, (2, 0, 1))
270 | image_resized[0, :, :] -= 120.45 # R
271 | image_resized[1, :, :] -= 115.74 # G
272 | image_resized[2, :, :] -= 104.65 # B
273 | image_resized /= 127.5
274 |
275 | data[class_][i] = image_resized
276 |
277 | print("Num classes " + str(len(data)))
278 | num_images = 0
279 | for class_ in data:
280 | num_images += len(data[class_])
281 | print("Num images " + str(num_images))
282 | return data
283 |
284 |     def chunks(self, data, size=10000):  # yield successive sub-dicts with at most 'size' classes each
285 | it = iter(data)
286 | for i in range(0, len(data), size):
287 | yield {k: data[k] for k in islice(it, size)}
288 |
289 | def get_task_batch(self,
290 | num_tasks=5,
291 | num_ways=20,
292 | num_shots=1,
293 | num_queries=1,
294 | seed=None):
295 | if seed is not None:
296 | random.seed(seed)
297 |
298 | # init task batch data
299 | support_data, support_label, query_data, query_label = [], [], [], []
300 | for _ in range(num_ways * num_shots):
301 | data = np.zeros(shape=[num_tasks] + self.data_size,
302 | dtype='float32')
303 | label = np.zeros(shape=[num_tasks],
304 | dtype='float32')
305 | support_data.append(data)
306 | support_label.append(label)
307 | for _ in range(num_ways * num_queries):
308 | data = np.zeros(shape=[num_tasks] + self.data_size,
309 | dtype='float32')
310 | label = np.zeros(shape=[num_tasks],
311 | dtype='float32')
312 | query_data.append(data)
313 | query_label.append(label)
314 |
315 | # get full class list in dataset
316 | full_class_list = list(self.data.keys())
317 |
318 | # for each task
319 | for t_idx in range(num_tasks):
320 | # define task by sampling classes (num_ways)
321 | task_class_list = random.sample(full_class_list, num_ways)
322 |
323 | # for each sampled class in task
324 | for c_idx in range(num_ways):
325 | # sample data for support and query (num_shots + num_queries)
326 | class_data_list = random.sample(self.data[task_class_list[c_idx]], num_shots + num_queries)
327 |
328 | # load sample for support set
329 | for i_idx in range(num_shots):
330 | # set data
331 | support_data[i_idx + c_idx * num_shots][t_idx] = class_data_list[i_idx]
332 | support_label[i_idx + c_idx * num_shots][t_idx] = c_idx
333 |
334 | # load sample for query set
335 | for i_idx in range(num_queries):
336 | query_data[i_idx + c_idx * num_queries][t_idx] = class_data_list[num_shots + i_idx]
337 | query_label[i_idx + c_idx * num_queries][t_idx] = c_idx
338 |
339 |
340 |
341 |         # convert to tensors: support is num_tasks x (num_ways * num_shots) x ..., query is num_tasks x (num_ways * num_queries) x ...
342 | support_data = torch.stack([torch.from_numpy(data).float().to(tt.arg.device) for data in support_data], 1)
343 | support_label = torch.stack([torch.from_numpy(label).float().to(tt.arg.device) for label in support_label], 1)
344 | query_data = torch.stack([torch.from_numpy(data).float().to(tt.arg.device) for data in query_data], 1)
345 | query_label = torch.stack([torch.from_numpy(label).float().to(tt.arg.device) for label in query_label], 1)
346 |
347 | return [support_data, support_label, query_data, query_label]
--------------------------------------------------------------------------------
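Before the training script, a minimal sanity-check sketch for the episodic sampler above. This is not part of the repository: it assumes the preprocessed miniImagenet pickles already sit under the dataset root, and that `MiniImagenetLoader.get_task_batch` shares the signature of the tiered loader shown above (the `num_queries` keyword in particular is an assumption).

```python
# Hypothetical shape check for the episodic sampler; the dataset path and
# the num_queries keyword are assumptions based on the loaders above.
from torchtools import *
from data import MiniImagenetLoader

tt.arg.device = 'cpu'  # keep the check on CPU

loader = MiniImagenetLoader(root='/data/private/dataset', partition='val')
support_data, support_label, query_data, query_label = loader.get_task_batch(
    num_tasks=2, num_ways=5, num_shots=1, num_queries=1, seed=0)

# support: num_tasks x (num_ways * num_shots) x 3 x 84 x 84
assert tuple(support_data.shape) == (2, 5, 3, 84, 84)
assert tuple(support_label.shape) == (2, 5)
# query: num_tasks x (num_ways * num_queries) x 3 x 84 x 84
assert tuple(query_data.shape) == (2, 5, 3, 84, 84)
assert tuple(query_label.shape) == (2, 5)
```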
/train.py:
--------------------------------------------------------------------------------
1 | from torchtools import *
2 | from data import MiniImagenetLoader, TieredImagenetLoader
3 | from model import EmbeddingImagenet, GraphNetwork, ConvNet
4 | import shutil
5 | import os
6 | import random
7 | #import seaborn as sns
8 |
9 |
10 | class ModelTrainer(object):
11 | def __init__(self,
12 | enc_module,
13 | gnn_module,
14 | data_loader):
15 | # set encoder and gnn
16 | self.enc_module = enc_module.to(tt.arg.device)
17 | self.gnn_module = gnn_module.to(tt.arg.device)
18 |
19 | if tt.arg.num_gpus > 1:
20 | print('Construct multi-gpu model ...')
21 |             self.enc_module = nn.DataParallel(self.enc_module, device_ids=list(range(tt.arg.num_gpus)), dim=0)  # device_ids follow tt.arg.num_gpus
22 |             self.gnn_module = nn.DataParallel(self.gnn_module, device_ids=list(range(tt.arg.num_gpus)), dim=0)
23 |
24 | print('done!\n')
25 |
26 | # get data loader
27 | self.data_loader = data_loader
28 |
29 |         # collect all trainable parameters (encoder + GNN)
30 | self.module_params = list(self.enc_module.parameters()) + list(self.gnn_module.parameters())
31 |
32 | # set optimizer
33 | self.optimizer = optim.Adam(params=self.module_params,
34 | lr=tt.arg.lr,
35 | weight_decay=tt.arg.weight_decay)
36 |
37 | # set loss
38 | self.edge_loss = nn.BCELoss(reduction='none')
39 |
40 | self.node_loss = nn.CrossEntropyLoss(reduction='none')
41 |
42 | self.global_step = 0
43 | self.val_acc = 0
44 | self.test_acc = 0
45 |
46 | def train(self):
47 | val_acc = self.val_acc
48 |
49 | # set edge mask (to distinguish support and query edges)
50 | num_supports = tt.arg.num_ways_train * tt.arg.num_shots_train
51 |         num_queries = tt.arg.num_ways_train * 1  # one query per class
52 | num_samples = num_supports + num_queries
53 | support_edge_mask = torch.zeros(tt.arg.meta_batch_size, num_samples, num_samples).to(tt.arg.device)
54 | support_edge_mask[:, :num_supports, :num_supports] = 1
55 | query_edge_mask = 1 - support_edge_mask
56 |
57 | evaluation_mask = torch.ones(tt.arg.meta_batch_size, num_samples, num_samples).to(tt.arg.device)
58 | # for semi-supervised setting, ignore unlabeled support sets for evaluation
59 | for c in range(tt.arg.num_ways_train):
60 | evaluation_mask[:,
61 | ((c + 1) * tt.arg.num_shots_train - tt.arg.num_unlabeled):(c + 1) * tt.arg.num_shots_train,
62 | :num_supports] = 0
63 | evaluation_mask[:, :num_supports,
64 | ((c + 1) * tt.arg.num_shots_train - tt.arg.num_unlabeled):(c + 1) * tt.arg.num_shots_train] = 0
65 |
66 | # for each iteration
67 | for iter in range(self.global_step + 1, tt.arg.train_iteration + 1):
68 | # init grad
69 | self.optimizer.zero_grad()
70 |
71 | # set current step
72 | self.global_step = iter
73 |
74 | # load task data list
75 | [support_data,
76 | support_label,
77 | query_data,
78 | query_label] = self.data_loader['train'].get_task_batch(num_tasks=tt.arg.meta_batch_size,
79 | num_ways=tt.arg.num_ways_train,
80 | num_shots=tt.arg.num_shots_train,
81 | seed=iter + tt.arg.seed)
82 |
83 | # set as single data
84 | full_data = torch.cat([support_data, query_data], 1)
85 | full_label = torch.cat([support_label, query_label], 1)
86 | full_edge = self.label2edge(full_label)
87 |
88 | # set init edge
89 | init_edge = full_edge.clone() # batch_size x 2 x num_samples x num_samples
90 | init_edge[:, :, num_supports:, :] = 0.5
91 | init_edge[:, :, :, num_supports:] = 0.5
92 | for i in range(num_queries):
93 | init_edge[:, 0, num_supports + i, num_supports + i] = 1.0
94 | init_edge[:, 1, num_supports + i, num_supports + i] = 0.0
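            # The two edge channels now encode P(same class) and P(different
            # class): support-support pairs keep their ground-truth values,
            # every edge that touches a query is reset to the uninformative
            # 0.5, and each query's self-loop is pinned to (1.0, 0.0).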
95 |
96 | # for semi-supervised setting,
97 | for c in range(tt.arg.num_ways_train):
98 | init_edge[:, :, ((c+1) * tt.arg.num_shots_train - tt.arg.num_unlabeled):(c+1) * tt.arg.num_shots_train, :num_supports] = 0.5
99 | init_edge[:, :, :num_supports, ((c+1) * tt.arg.num_shots_train - tt.arg.num_unlabeled):(c+1) * tt.arg.num_shots_train] = 0.5
100 |
101 | # set as train mode
102 | self.enc_module.train()
103 | self.gnn_module.train()
104 |
105 | # (1) encode data
106 | full_data = [self.enc_module(data.squeeze(1)) for data in full_data.chunk(full_data.size(1), dim=1)]
107 | full_data = torch.stack(full_data, dim=1) # batch_size x num_samples x featdim
108 |
109 | # (2) predict edge logit (consider only the last layer logit, num_tasks x 2 x num_samples x num_samples)
110 | if tt.arg.train_transductive:
111 | full_logit_layers = self.gnn_module(node_feat=full_data, edge_feat=init_edge)
112 | else:
113 |                 evaluation_mask[:, num_supports:, num_supports:] = 0  # ignore query-query edges, since this is the non-transductive setting
114 | # input_node_feat: (batch_size x num_queries) x (num_support + 1) x featdim
115 | # input_edge_feat: (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)
116 | support_data = full_data[:, :num_supports] # batch_size x num_support x featdim
117 | query_data = full_data[:, num_supports:] # batch_size x num_query x featdim
118 | support_data_tiled = support_data.unsqueeze(1).repeat(1, num_queries, 1, 1) # batch_size x num_queries x num_support x featdim
119 | support_data_tiled = support_data_tiled.view(tt.arg.meta_batch_size * num_queries, num_supports, -1) # (batch_size x num_queries) x num_support x featdim
120 | query_data_reshaped = query_data.contiguous().view(tt.arg.meta_batch_size * num_queries, -1).unsqueeze(1) # (batch_size x num_queries) x 1 x featdim
121 | input_node_feat = torch.cat([support_data_tiled, query_data_reshaped], 1) # (batch_size x num_queries) x (num_support + 1) x featdim
122 |
123 | input_edge_feat = 0.5 * torch.ones(tt.arg.meta_batch_size, 2, num_supports + 1, num_supports + 1).to(tt.arg.device) # batch_size x 2 x (num_support + 1) x (num_support + 1)
124 |
125 | input_edge_feat[:, :, :num_supports, :num_supports] = init_edge[:, :, :num_supports, :num_supports] # batch_size x 2 x (num_support + 1) x (num_support + 1)
126 | input_edge_feat = input_edge_feat.repeat(num_queries, 1, 1, 1) #(batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)
127 |
128 | # logit: (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)
129 | logit_layers = self.gnn_module(node_feat=input_node_feat, edge_feat=input_edge_feat)
130 |
131 | logit_layers = [logit_layer.view(tt.arg.meta_batch_size, num_queries, 2, num_supports + 1, num_supports + 1) for logit_layer in logit_layers]
132 |
133 | # logit --> full_logit (batch_size x 2 x num_samples x num_samples)
134 | full_logit_layers = []
135 | for l in range(tt.arg.num_layers):
136 | full_logit_layers.append(torch.zeros(tt.arg.meta_batch_size, 2, num_samples, num_samples).to(tt.arg.device))
137 |
138 | for l in range(tt.arg.num_layers):
139 | full_logit_layers[l][:, :, :num_supports, :num_supports] = logit_layers[l][:, :, :, :num_supports, :num_supports].mean(1)
140 | full_logit_layers[l][:, :, :num_supports, num_supports:] = logit_layers[l][:, :, :, :num_supports, -1].transpose(1, 2).transpose(2, 3)
141 | full_logit_layers[l][:, :, num_supports:, :num_supports] = logit_layers[l][:, :, :, -1, :num_supports].transpose(1, 2)
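                # Each query was classified inside its own (num_supports + 1)-node
                # sub-graph above, so the per-query logits are scattered back into
                # one full adjacency: support-support scores are averaged over the
                # query copies, and the last row/column of each sub-graph supplies
                # the query-support scores. Query-query entries stay zero and are
                # excluded from the loss by evaluation_mask in this setting.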
142 |
143 |             # (3) compute loss
144 | full_edge_loss_layers = [self.edge_loss((1-full_logit_layer[:, 0]), (1-full_edge[:, 0])) for full_logit_layer in full_logit_layers]
145 |
146 | # weighted edge loss for balancing pos/neg
147 | pos_query_edge_loss_layers = [torch.sum(full_edge_loss_layer * query_edge_mask * full_edge[:, 0] * evaluation_mask) / torch.sum(query_edge_mask * full_edge[:, 0] * evaluation_mask) for full_edge_loss_layer in full_edge_loss_layers]
148 | neg_query_edge_loss_layers = [torch.sum(full_edge_loss_layer * query_edge_mask * (1-full_edge[:, 0]) * evaluation_mask) / torch.sum(query_edge_mask * (1-full_edge[:, 0]) * evaluation_mask) for full_edge_loss_layer in full_edge_loss_layers]
149 | query_edge_loss_layers = [pos_query_edge_loss_layer + neg_query_edge_loss_layer for (pos_query_edge_loss_layer, neg_query_edge_loss_layer) in zip(pos_query_edge_loss_layers, neg_query_edge_loss_layers)]
150 |
151 | # compute accuracy
152 | full_edge_accr_layers = [self.hit(full_logit_layer, 1-full_edge[:, 0].long()) for full_logit_layer in full_logit_layers]
153 | query_edge_accr_layers = [torch.sum(full_edge_accr_layer * query_edge_mask * evaluation_mask) / torch.sum(query_edge_mask * evaluation_mask) for full_edge_accr_layer in full_edge_accr_layers]
154 |
155 |             # compute node loss & accuracy (num_tasks x num_queries x num_ways)
156 |             query_node_pred_layers = [torch.bmm(full_logit_layer[:, 0, num_supports:, :num_supports], self.one_hot_encode(tt.arg.num_ways_train, support_label.long())) for full_logit_layer in full_logit_layers] # (num_tasks x num_queries x num_supports) * (num_tasks x num_supports x num_ways)
157 | query_node_accr_layers = [torch.eq(torch.max(query_node_pred_layer, -1)[1], query_label.long()).float().mean() for query_node_pred_layer in query_node_pred_layers]
158 |
159 | total_loss_layers = query_edge_loss_layers
160 |
161 | # update model
162 | total_loss = []
163 | for l in range(tt.arg.num_layers - 1):
164 |                 total_loss += [total_loss_layers[l].view(-1) * 0.5]  # intermediate GNN layers get half weight
165 |             total_loss += [total_loss_layers[-1].view(-1) * 1.0]  # the last layer gets full weight
166 | total_loss = torch.mean(torch.cat(total_loss, 0))
167 |
168 |             total_loss.backward()
            # apply the configured gradient clipping (tt.arg.grad_clip is set in __main__ but was never used)
            torch.nn.utils.clip_grad_norm_(self.module_params, tt.arg.grad_clip)
169 |
170 |             self.optimizer.step()
171 |
172 | # adjust learning rate
173 | self.adjust_learning_rate(optimizers=[self.optimizer],
174 | lr=tt.arg.lr,
175 | iter=self.global_step)
176 |
177 | # logging
178 | tt.log_scalar('train/edge_loss', query_edge_loss_layers[-1], self.global_step)
179 | tt.log_scalar('train/edge_accr', query_edge_accr_layers[-1], self.global_step)
180 | tt.log_scalar('train/node_accr', query_node_accr_layers[-1], self.global_step)
181 |
182 | # evaluation
183 | if self.global_step % tt.arg.test_interval == 0:
184 | val_acc = self.eval(partition='val')
185 |
186 | is_best = 0
187 |
188 | if val_acc >= self.val_acc:
189 | self.val_acc = val_acc
190 | is_best = 1
191 |
192 | tt.log_scalar('val/best_accr', self.val_acc, self.global_step)
193 |
194 | self.save_checkpoint({
195 | 'iteration': self.global_step,
196 | 'enc_module_state_dict': self.enc_module.state_dict(),
197 | 'gnn_module_state_dict': self.gnn_module.state_dict(),
198 | 'val_acc': val_acc,
199 | 'optimizer': self.optimizer.state_dict(),
200 | }, is_best)
201 |
202 | tt.log_step(global_step=self.global_step)
203 |
204 | def eval(self, partition='test', log_flag=True):
205 | best_acc = 0
206 | # set edge mask (to distinguish support and query edges)
207 | num_supports = tt.arg.num_ways_test * tt.arg.num_shots_test
208 |         num_queries = tt.arg.num_ways_test * 1  # one query per class
209 | num_samples = num_supports + num_queries
210 | support_edge_mask = torch.zeros(tt.arg.test_batch_size, num_samples, num_samples).to(tt.arg.device)
211 | support_edge_mask[:, :num_supports, :num_supports] = 1
212 | query_edge_mask = 1 - support_edge_mask
213 | evaluation_mask = torch.ones(tt.arg.test_batch_size, num_samples, num_samples).to(tt.arg.device)
214 | # for semi-supervised setting, ignore unlabeled support sets for evaluation
215 | for c in range(tt.arg.num_ways_test):
216 | evaluation_mask[:,
217 | ((c + 1) * tt.arg.num_shots_test - tt.arg.num_unlabeled):(c + 1) * tt.arg.num_shots_test,
218 | :num_supports] = 0
219 | evaluation_mask[:, :num_supports,
220 | ((c + 1) * tt.arg.num_shots_test - tt.arg.num_unlabeled):(c + 1) * tt.arg.num_shots_test] = 0
221 |
222 | query_edge_losses = []
223 | query_edge_accrs = []
224 | query_node_accrs = []
225 |
226 | # for each iteration
227 | for iter in range(tt.arg.test_iteration//tt.arg.test_batch_size):
228 | # load task data list
229 | [support_data,
230 | support_label,
231 | query_data,
232 | query_label] = self.data_loader[partition].get_task_batch(num_tasks=tt.arg.test_batch_size,
233 | num_ways=tt.arg.num_ways_test,
234 | num_shots=tt.arg.num_shots_test,
235 | seed=iter)
236 |
237 | # set as single data
238 | full_data = torch.cat([support_data, query_data], 1)
239 | full_label = torch.cat([support_label, query_label], 1)
240 | full_edge = self.label2edge(full_label)
241 |
242 | # set init edge
243 | init_edge = full_edge.clone()
244 | init_edge[:, :, num_supports:, :] = 0.5
245 | init_edge[:, :, :, num_supports:] = 0.5
246 | for i in range(num_queries):
247 | init_edge[:, 0, num_supports + i, num_supports + i] = 1.0
248 | init_edge[:, 1, num_supports + i, num_supports + i] = 0.0
249 |
250 | # for semi-supervised setting,
251 | for c in range(tt.arg.num_ways_test):
252 | init_edge[:, :, ((c+1) * tt.arg.num_shots_test - tt.arg.num_unlabeled):(c+1) * tt.arg.num_shots_test, :num_supports] = 0.5
253 | init_edge[:, :, :num_supports, ((c+1) * tt.arg.num_shots_test - tt.arg.num_unlabeled):(c+1) * tt.arg.num_shots_test] = 0.5
254 |
255 |             # set as eval mode
256 | self.enc_module.eval()
257 | self.gnn_module.eval()
258 |
259 | # (1) encode data
260 | full_data = [self.enc_module(data.squeeze(1)) for data in full_data.chunk(full_data.size(1), dim=1)]
261 | full_data = torch.stack(full_data, dim=1)
262 |
263 | # (2) predict edge logit (consider only the last layer logit, num_tasks x 2 x num_samples x num_samples)
264 | if tt.arg.test_transductive:
265 | full_logit_all = self.gnn_module(node_feat=full_data, edge_feat=init_edge)
266 | full_logit = full_logit_all[-1]
267 | else:
268 |                 evaluation_mask[:, num_supports:, num_supports:] = 0  # ignore query-query edges, since this is the non-transductive setting
269 |
270 | full_logit = torch.zeros(tt.arg.test_batch_size, 2, num_samples, num_samples).to(tt.arg.device)
271 |
272 | # input_node_feat: (batch_size x num_queries) x (num_support + 1) x featdim
273 | # input_edge_feat: (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)
274 | support_data = full_data[:, :num_supports] # batch_size x num_support x featdim
275 | query_data = full_data[:, num_supports:] # batch_size x num_query x featdim
276 | support_data_tiled = support_data.unsqueeze(1).repeat(1, num_queries, 1, 1) # batch_size x num_queries x num_support x featdim
277 | support_data_tiled = support_data_tiled.view(tt.arg.test_batch_size * num_queries, num_supports, -1) # (batch_size x num_queries) x num_support x featdim
278 | query_data_reshaped = query_data.contiguous().view(tt.arg.test_batch_size * num_queries, -1).unsqueeze(1) # (batch_size x num_queries) x 1 x featdim
279 | input_node_feat = torch.cat([support_data_tiled, query_data_reshaped], 1) # (batch_size x num_queries) x (num_support + 1) x featdim
280 |
281 | input_edge_feat = 0.5 * torch.ones(tt.arg.test_batch_size, 2, num_supports + 1, num_supports + 1).to(tt.arg.device) # batch_size x 2 x (num_support + 1) x (num_support + 1)
282 |
283 | input_edge_feat[:, :, :num_supports, :num_supports] = init_edge[:, :, :num_supports, :num_supports] # batch_size x 2 x (num_support + 1) x (num_support + 1)
284 | input_edge_feat = input_edge_feat.repeat(num_queries, 1, 1, 1) # (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)
285 |
286 | # logit: (batch_size x num_queries) x 2 x (num_support + 1) x (num_support + 1)
287 | logit = self.gnn_module(node_feat=input_node_feat, edge_feat=input_edge_feat)[-1]
288 |
289 | logit = logit.view(tt.arg.test_batch_size, num_queries, 2, num_supports + 1, num_supports + 1)
290 |
291 | # batch_size x num_queries x 2 x (num_support + 1) x (num_support + 1)
292 | # logit --> full_logit (batch_size x 2 x num_samples x num_samples)
293 | full_logit[:, :, :num_supports, :num_supports] = logit[:, :, :, :num_supports, :num_supports].mean(1)
294 | full_logit[:, :, :num_supports, num_supports:] = logit[:, :, :, :num_supports, -1].transpose(1, 2).transpose(2, 3)
295 | full_logit[:, :, num_supports:, :num_supports] = logit[:, :, :, -1, :num_supports].transpose(1, 2)
296 |
297 |             # (3) compute loss
298 | full_edge_loss = self.edge_loss(1-full_logit[:, 0], 1-full_edge[:, 0])
299 |
300 |             # query_edge_loss = torch.sum(full_edge_loss * query_edge_mask * evaluation_mask) / torch.sum(query_edge_mask * evaluation_mask)  # superseded by the class-balanced loss below
301 |
302 | # weighted loss for balancing pos/neg
303 | pos_query_edge_loss = torch.sum(full_edge_loss * query_edge_mask * full_edge[:, 0] * evaluation_mask) / torch.sum(query_edge_mask * full_edge[:, 0] * evaluation_mask)
304 | neg_query_edge_loss = torch.sum(full_edge_loss * query_edge_mask * (1-full_edge[:, 0]) * evaluation_mask) / torch.sum(query_edge_mask * (1-full_edge[:, 0]) * evaluation_mask)
305 | query_edge_loss = pos_query_edge_loss + neg_query_edge_loss
306 |
307 | # compute accuracy
308 | full_edge_accr = self.hit(full_logit, 1-full_edge[:, 0].long())
309 | query_edge_accr = torch.sum(full_edge_accr * query_edge_mask * evaluation_mask) / torch.sum(query_edge_mask * evaluation_mask)
310 |
311 |             # compute node accuracy (num_tasks x num_queries x num_ways)
312 |             query_node_pred = torch.bmm(full_logit[:, 0, num_supports:, :num_supports], self.one_hot_encode(tt.arg.num_ways_test, support_label.long())) # (num_tasks x num_queries x num_supports) * (num_tasks x num_supports x num_ways)
313 | query_node_accr = torch.eq(torch.max(query_node_pred, -1)[1], query_label.long()).float().mean()
314 |
315 | query_edge_losses += [query_edge_loss.item()]
316 | query_edge_accrs += [query_edge_accr.item()]
317 | query_node_accrs += [query_node_accr.item()]
318 |
319 | # logging
320 | if log_flag:
321 | tt.log('---------------------------')
322 | tt.log_scalar('{}/edge_loss'.format(partition), np.array(query_edge_losses).mean(), self.global_step)
323 | tt.log_scalar('{}/edge_accr'.format(partition), np.array(query_edge_accrs).mean(), self.global_step)
324 | tt.log_scalar('{}/node_accr'.format(partition), np.array(query_node_accrs).mean(), self.global_step)
325 |
326 | tt.log('evaluation: total_count=%d, accuracy: mean=%.2f%%, std=%.2f%%, ci95=%.2f%%' %
327 |                    (len(query_node_accrs),
328 |                     np.array(query_node_accrs).mean() * 100,
329 |                     np.array(query_node_accrs).std() * 100,
330 |                     1.96 * np.array(query_node_accrs).std() / np.sqrt(float(len(query_node_accrs))) * 100))
331 | tt.log('---------------------------')
332 |
333 | return np.array(query_node_accrs).mean()
334 |
335 | def adjust_learning_rate(self, optimizers, lr, iter):
336 |         new_lr = lr * (0.5 ** (int(iter / tt.arg.dec_lr)))  # halve the learning rate every dec_lr iterations
337 |
338 | for optimizer in optimizers:
339 | for param_group in optimizer.param_groups:
340 | param_group['lr'] = new_lr
341 |
342 | def label2edge(self, label):
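        # Builds the ground-truth edge tensor from node labels. A tiny worked
        # example: label = [[0, 0, 1]] yields channel 0 (same-class)
        #   [[1, 1, 0],
        #    [1, 1, 0],
        #    [0, 0, 1]]
        # and channel 1 is its complement.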
343 | # get size
344 | num_samples = label.size(1)
345 |
346 | # reshape
347 | label_i = label.unsqueeze(-1).repeat(1, 1, num_samples)
348 | label_j = label_i.transpose(1, 2)
349 |
350 | # compute edge
351 | edge = torch.eq(label_i, label_j).float().to(tt.arg.device)
352 |
353 | # expand
354 | edge = edge.unsqueeze(1)
355 | edge = torch.cat([edge, 1 - edge], 1)
356 | return edge
357 |
358 | def hit(self, logit, label):
359 | pred = logit.max(1)[1]
360 | hit = torch.eq(pred, label).float()
361 | return hit
362 |
363 | def one_hot_encode(self, num_classes, class_idx):
364 | return torch.eye(num_classes)[class_idx].to(tt.arg.device)
365 |
366 | def save_checkpoint(self, state, is_best):
367 | torch.save(state, 'asset/checkpoints/{}/'.format(tt.arg.experiment) + 'checkpoint.pth.tar')
368 | if is_best:
369 | shutil.copyfile('asset/checkpoints/{}/'.format(tt.arg.experiment) + 'checkpoint.pth.tar',
370 | 'asset/checkpoints/{}/'.format(tt.arg.experiment) + 'model_best.pth.tar')
371 |
372 | def set_exp_name():
373 | exp_name = 'D-{}'.format(tt.arg.dataset)
374 | exp_name += '_N-{}_K-{}_U-{}'.format(tt.arg.num_ways, tt.arg.num_shots, tt.arg.num_unlabeled)
375 | exp_name += '_L-{}_B-{}'.format(tt.arg.num_layers, tt.arg.meta_batch_size)
376 | exp_name += '_T-{}'.format(tt.arg.transductive)
377 | exp_name += '_SEED-{}'.format(tt.arg.seed)
378 |
379 | return exp_name
380 |
381 | if __name__ == '__main__':
382 |
383 | tt.arg.device = 'cuda:0' if tt.arg.device is None else tt.arg.device
384 | # replace dataset_root with your own
385 | tt.arg.dataset_root = '/data/private/dataset'
386 | tt.arg.dataset = 'mini' if tt.arg.dataset is None else tt.arg.dataset
387 | tt.arg.num_ways = 5 if tt.arg.num_ways is None else tt.arg.num_ways
388 | tt.arg.num_shots = 1 if tt.arg.num_shots is None else tt.arg.num_shots
389 | tt.arg.num_unlabeled = 0 if tt.arg.num_unlabeled is None else tt.arg.num_unlabeled
390 | tt.arg.num_layers = 3 if tt.arg.num_layers is None else tt.arg.num_layers
391 | tt.arg.meta_batch_size = 40 if tt.arg.meta_batch_size is None else tt.arg.meta_batch_size
392 | tt.arg.transductive = False if tt.arg.transductive is None else tt.arg.transductive
393 | tt.arg.seed = 222 if tt.arg.seed is None else tt.arg.seed
394 | tt.arg.num_gpus = 1 if tt.arg.num_gpus is None else tt.arg.num_gpus
395 |
396 | tt.arg.num_ways_train = tt.arg.num_ways
397 | tt.arg.num_ways_test = tt.arg.num_ways
398 |
399 | tt.arg.num_shots_train = tt.arg.num_shots
400 | tt.arg.num_shots_test = tt.arg.num_shots
401 |
402 | tt.arg.train_transductive = tt.arg.transductive
403 | tt.arg.test_transductive = tt.arg.transductive
404 |
405 | # model parameter related
406 | tt.arg.num_edge_features = 96
407 | tt.arg.num_node_features = 96
408 | tt.arg.emb_size = 128
409 |
410 | # train, test parameters
411 | tt.arg.train_iteration = 100000 if tt.arg.dataset == 'mini' else 200000
412 | tt.arg.test_iteration = 10000
413 | tt.arg.test_interval = 5000 if tt.arg.test_interval is None else tt.arg.test_interval
414 | tt.arg.test_batch_size = 10
415 | tt.arg.log_step = 1000 if tt.arg.log_step is None else tt.arg.log_step
416 |
417 | tt.arg.lr = 1e-3
418 | tt.arg.grad_clip = 5
419 | tt.arg.weight_decay = 1e-6
420 | tt.arg.dec_lr = 15000 if tt.arg.dataset == 'mini' else 30000
421 | tt.arg.dropout = 0.1 if tt.arg.dataset == 'mini' else 0.0
422 |
423 | tt.arg.experiment = set_exp_name() if tt.arg.experiment is None else tt.arg.experiment
424 |
425 | print(set_exp_name())
426 |
427 |     # set random seed
428 | np.random.seed(tt.arg.seed)
429 | torch.manual_seed(tt.arg.seed)
430 | torch.cuda.manual_seed_all(tt.arg.seed)
431 | random.seed(tt.arg.seed)
432 | torch.backends.cudnn.deterministic = True
433 | torch.backends.cudnn.benchmark = False
434 |
435 | tt.arg.log_dir_user = tt.arg.log_dir if tt.arg.log_dir_user is None else tt.arg.log_dir_user
436 | tt.arg.log_dir = tt.arg.log_dir_user
437 |
438 | if not os.path.exists('asset/checkpoints'):
439 | os.makedirs('asset/checkpoints')
440 | if not os.path.exists('asset/checkpoints/' + tt.arg.experiment):
441 | os.makedirs('asset/checkpoints/' + tt.arg.experiment)
442 |
443 |
444 | enc_module = EmbeddingImagenet(emb_size=tt.arg.emb_size)
445 |
446 | gnn_module = GraphNetwork(in_features=tt.arg.emb_size,
447 |                               node_features=tt.arg.num_node_features,
448 |                               edge_features=tt.arg.num_edge_features,
449 | num_layers=tt.arg.num_layers,
450 | dropout=tt.arg.dropout)
451 |
452 | if tt.arg.dataset == 'mini':
453 | train_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='train')
454 | valid_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='val')
455 | elif tt.arg.dataset == 'tiered':
456 | train_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='train')
457 | valid_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='val')
458 |     else:
459 |         raise ValueError('Unknown dataset: {}'.format(tt.arg.dataset))
460 |
461 | data_loader = {'train': train_loader,
462 | 'val': valid_loader
463 | }
464 |
465 | # create trainer
466 | trainer = ModelTrainer(enc_module=enc_module,
467 | gnn_module=gnn_module,
468 | data_loader=data_loader)
469 |
470 | trainer.train()
471 |
--------------------------------------------------------------------------------
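For completeness, a hedged sketch of how a checkpoint written by `save_checkpoint` above could be restored for offline evaluation. Nothing here is part of the repository itself; the experiment string is simply what `set_exp_name()` produces under the default arguments in train.py.

```python
import torch
from model import EmbeddingImagenet, GraphNetwork

# example value of set_exp_name() under the defaults in train.py
exp_name = 'D-mini_N-5_K-1_U-0_L-3_B-40_T-False_SEED-222'
ckpt = torch.load('asset/checkpoints/{}/model_best.pth.tar'.format(exp_name),
                  map_location='cpu')

# mirror the module construction in train.py (mini defaults assumed)
enc_module = EmbeddingImagenet(emb_size=128)
gnn_module = GraphNetwork(in_features=128, node_features=96,
                          edge_features=96, num_layers=3, dropout=0.1)

# note: checkpoints saved from an nn.DataParallel-wrapped model carry a
# 'module.' prefix on every state_dict key and would need it stripped first
enc_module.load_state_dict(ckpt['enc_module_state_dict'])
gnn_module.load_state_dict(ckpt['gnn_module_state_dict'])
print('restored iteration', ckpt['iteration'], 'val_acc', ckpt['val_acc'])
```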