├── _version.py ├── torchtools ├── _version.py ├── tt │ ├── __pycache__ │ │ ├── arg.cpython-36.pyc │ │ ├── stat.cpython-36.pyc │ │ ├── layer.cpython-36.pyc │ │ ├── logger.cpython-36.pyc │ │ ├── utils.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ └── trainer.cpython-36.pyc │ ├── __init__.py │ ├── layer.py │ ├── stat.py │ ├── trainer.py │ ├── utils.py │ ├── arg.py │ └── logger.py ├── __pycache__ │ └── __init__.cpython-36.pyc └── __init__.py ├── __pycache__ └── __init__.cpython-36.pyc ├── .idea ├── vcs.xml ├── modules.xml ├── egnn_distribute.iml └── workspace.xml ├── __init__.py ├── LICENSE ├── eval.py ├── README.md ├── model.py ├── data.py └── train.py /_version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.4.0' # align version with pytorch 2 | 3 | -------------------------------------------------------------------------------- /torchtools/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.4.0' # align version with pytorch 2 | 3 | -------------------------------------------------------------------------------- /__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /torchtools/tt/__pycache__/arg.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/torchtools/tt/__pycache__/arg.cpython-36.pyc -------------------------------------------------------------------------------- /torchtools/tt/__pycache__/stat.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/torchtools/tt/__pycache__/stat.cpython-36.pyc -------------------------------------------------------------------------------- /torchtools/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/torchtools/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /torchtools/tt/__pycache__/layer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/torchtools/tt/__pycache__/layer.cpython-36.pyc -------------------------------------------------------------------------------- /torchtools/tt/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/torchtools/tt/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /torchtools/tt/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/torchtools/tt/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /torchtools/tt/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/torchtools/tt/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /torchtools/tt/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jmkim0309/fewshot-egnn/HEAD/torchtools/tt/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /torchtools/tt/__init__.py: -------------------------------------------------------------------------------- 1 | from torchtools.tt.arg import _parse_opts 2 | from torchtools.tt.utils import * 3 | from torchtools.tt.layer import * 4 | from torchtools.tt.logger import * 5 | from torchtools.tt.stat import * 6 | from torchtools.tt.trainer import * 7 | 8 | 9 | __author__ = 'namju.kim@kakaobrain.com' 10 | 11 | 12 | # global command line arguments 13 | arg = _parse_opts() 14 | -------------------------------------------------------------------------------- /torchtools/tt/layer.py: -------------------------------------------------------------------------------- 1 | from torchtools import nn 2 | 3 | 4 | # 5 | # Reshape layer for Sequential or ModuleList 6 | # 7 | class Reshape(nn.Module): 8 | 9 | def __init__(self, *shape): 10 | super(Reshape, self).__init__() 11 | self.shape = shape 12 | 13 | def forward(self, x): 14 | return x.reshape(self.shape) 15 | 16 | def extra_repr(self): 17 | return 'shape={}'.format(self.shape) -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch import optim 5 | from torch import cuda 6 | from torch import utils 7 | from torch.nn import functional as 
def accuracy(prob, label, ignore_index=-100):
    """Return the fraction of non-ignored labels that `prob` predicts correctly.

    Args:
        prob: score tensor; the predicted class is the argmax over dimension 1.
        label: tensor of target class indices.
        ignore_index: label value that is excluded from the computation.

    Returns:
        The accuracy as a Python float. When every label equals
        `ignore_index` there is nothing to score and 0.0 is returned
        (the previous code raised ZeroDivisionError in that case).
    """
    # argmax over dimension 1, cast so it can be compared with `label`
    pred = prob.max(1)[1].type_as(label)

    # keep only the positions whose label is not `ignore_index`
    mask = label.ne(ignore_index)
    pred = pred.masked_select(mask)
    label = label.masked_select(mask)

    # masked_select returns a 1-D tensor, so size(0) is the number of kept labels
    total = label.size(0)
    if total == 0:
        return 0.0

    # .item() on the 0-dim sum yields a plain Python int
    hit = pred.eq(label).long().sum().item()
    return hit / total
| MIT License 2 | 3 | Copyright (c) 2019 Jongmin Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
class SupervisedTrainer(object):
    """Minimal supervised training loop driven by the global `tt.arg` options.

    The model is moved to `tt.arg.device`. The optimizer defaults to Adam over
    the model parameters and the criterion to cross-entropy when not supplied.
    """

    def __init__(self, model, data_loader, optimizer=None, criterion=None):
        self.global_step = 0
        self.model = model.to(tt.arg.device)
        self.data_loader = data_loader
        self.optimizer = optimizer or optim.Adam(model.parameters())
        self.criterion = criterion or nn.CrossEntropyLoss()

    def train(self, inputs):
        """Run one optimisation step on an `(x, y)` pair and buffer loss/accuracy for logging."""
        x, y = inputs

        # forward (wrapped in DataParallel when running on CUDA)
        if tt.arg.cuda:
            z = nn.DataParallel(self.model)(x)
        else:
            z = self.model(x)

        loss = self.criterion(z, y)
        acc = tt.accuracy(z, y)

        # update model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # logging
        tt.log_scalar('loss', loss, self.global_step)
        tt.log_scalar('acc', acc, self.global_step)

    def epoch(self, ep_no=None):
        """Hook called after each epoch; subclasses may override. Does nothing by default."""
        pass

    def run(self):
        """Train until `tt.arg.epoch` epochs (default 1) are done, resuming from a saved step."""
        # experiment name defaults to the lower-cased model class name
        tt.arg.experiment = tt.arg.experiment or self.model.__class__.__name__.lower()

        # load_model() returns the step stored in the checkpoint (0 if none was found)
        self.global_step = self.model.load_model()
        steps_per_epoch = len(self.data_loader)
        epoch, min_step = divmod(self.global_step, steps_per_epoch)
        max_epoch = tt.arg.epoch or 1

        while epoch < max_epoch:
            epoch += 1

            # start counting at min_step + 1 so a resumed epoch stops after the
            # remaining (steps_per_epoch - min_step) iterations
            for step, inputs in enumerate(self.data_loader, min_step + 1):
                if step > steps_per_epoch:
                    break

                self.global_step += 1

                # follow the (possibly live-reloaded) learning rate, but never
                # overwrite the optimizer's own value with None when --lr is unset
                if tt.arg.lr is not None:
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = tt.arg.lr

                # move the batch to the target device and train on it
                if isinstance(inputs, (list, tuple)):
                    self.train([tt.var(d) for d in inputs])
                else:
                    self.train(tt.var(inputs))

                # logging
                tt.log_weight(self.model, global_step=self.global_step)
                tt.log_gradient(self.model, global_step=self.global_step)
                tt.log_step(epoch=epoch, global_step=self.global_step,
                            max_epoch=max_epoch, max_step=steps_per_epoch)

                # periodic save (save_model decides whether it is due)
                self.model.save_model(self.global_step)

            # the resume offset applies only to the first epoch after loading;
            # without this reset every later epoch was cut short by min_step steps
            min_step = 0

            # epoch handler
            self.epoch(epoch)

        # save final model
        self.model.save_model(self.global_step, force=True)
def load_model(model, best=False, postfix=None, experiment=None):
    """Load a checkpoint written by `save_model` into `model`.

    The file name is `<save_dir><name>.pt[.<postfix>]`, where `name` is
    `experiment`, else `tt.arg.experiment`, else the lower-cased model class
    name. With `best=True` the `.best` variant of that file is loaded.

    Args:
        model: module whose `load_state_dict` receives the stored state.
        best: load the `.best` checkpoint instead of the regular one.
        postfix: optional suffix appended to the file name after `.pt`.
        experiment: optional experiment name overriding `tt.arg.experiment`.

    Returns:
        The global step stored in the loaded checkpoint, or 0 when the
        requested file does not exist.

    Side effects:
        Updates the module-level `_best` score whenever a `.best` file exists.
    """
    global _best

    name = experiment or tt.arg.experiment or model.__class__.__name__.lower()
    filename = tt.arg.save_dir + '%s.pt' % name
    if postfix is not None:
        filename = filename + '.%s' % postfix
    best_filename = filename + '.best'

    def _load(path):
        # always map storages to CPU; load_state_dict copies into the model's device
        return torch.load(path, map_location=lambda storage, loc: storage)

    global_step = 0
    if best:
        # Check the file that is actually loaded. The previous code tested the
        # plain checkpoint here and then opened the '.best' one.
        if os.path.exists(best_filename):
            global_step, model_state, _best = _load(best_filename)
            model.load_state_dict(model_state)
    else:
        if os.path.exists(filename):
            global_step, model_state = _load(filename)
            model.load_state_dict(model_state)

        # refresh the best score so later save_model(best=...) calls compare correctly
        if os.path.exists(best_filename):
            _, _, _best = _load(best_filename)

    return global_step
_last_saved >= tt.arg.save_interval) or \ 105 | (tt.arg.save_step and global_step % tt.arg.save_step == 0): 106 | torch.save((global_step, model.state_dict()), tt.arg.save_dir + filename) 107 | _last_saved = time.time() 108 | 109 | # archive model 110 | if (tt.arg.archive_interval and time.time() - _last_archived >= tt.arg.archive_interval) or \ 111 | (tt.arg.archive_step and global_step % tt.arg.archive_step == 0): 112 | # filename to archive 113 | if tt.arg.archive_interval: 114 | filename = filename + datetime.datetime.now().strftime('.%Y%m%d.%H%M%S') 115 | else: 116 | filename = filename + '.%d' % global_step 117 | torch.save((global_step, model.state_dict()), tt.arg.save_dir + filename) 118 | _last_archived = time.time() 119 | 120 | # save best model 121 | if best is not None and best > _best: 122 | _best = best 123 | filename = filename + '.best' 124 | torch.save((global_step, model.state_dict(), best), tt.arg.save_dir + filename) 125 | 126 | 127 | # patch Module 128 | nn.Module.load_model = load_model 129 | nn.Module.save_model = save_model 130 | -------------------------------------------------------------------------------- /torchtools/tt/arg.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import configparser 3 | import torch 4 | import threading 5 | import time 6 | import os 7 | 8 | 9 | __author__ = 'namju.kim@kakaobrain.com' 10 | 11 | 12 | _config_time_stamp = 0 13 | 14 | 15 | class _Opt(object): 16 | 17 | def __len__(self): 18 | return len(self.__dict__) 19 | 20 | def __setitem__(self, key, value): 21 | self.__dict__[key] = value 22 | 23 | def __getitem__(self, item): 24 | if item in self.__dict__: 25 | return self.__dict__[item] 26 | else: 27 | return None 28 | 29 | def __getattr__(self, item): 30 | return self.__getitem__(item) 31 | 32 | 33 | def _to_py_obj(x): 34 | # check boolean first 35 | if x.lower() in ['true', 'yes', 'on']: 36 | return True 37 | if x.lower() in ['false', 'no', 'off']: 38 | 
return False 39 | # from string to python object if possible 40 | try: 41 | obj = eval(x) 42 | if type(obj).__name__ in ['int', 'float', 'tuple', 'list', 'dict', 'NoneType']: 43 | x = obj 44 | except: 45 | pass 46 | return x 47 | 48 | 49 | def _parse_config(arg, file): 50 | 51 | # read config file 52 | config = configparser.ConfigParser() 53 | config.read(file) 54 | # traverse sections 55 | for section in config.sections(): 56 | # traverse items 57 | opt = _Opt() 58 | for key in config[section]: 59 | opt[key] = _to_py_obj(config[section][key]) 60 | # if default section, save items to global scope 61 | if section.lower() == 'default': 62 | for k, v in opt.__dict__.items(): 63 | arg[k] = v 64 | else: 65 | arg['_'.join(section.split())] = opt 66 | 67 | 68 | def _parse_config_thread(arg, file): 69 | 70 | global _config_time_stamp 71 | 72 | while True: 73 | # check timestamp 74 | stamp = os.stat(file).st_mtime 75 | if not stamp == _config_time_stamp: 76 | # update timestamp 77 | _config_time_stamp = stamp 78 | # parse config file 79 | _parse_config(arg, file) 80 | # print result 81 | # _print_opts(arg, 'CONFIGURATION CHANGE DETECTED') 82 | # sleep 83 | time.sleep(1) 84 | 85 | 86 | def _print_opts(arg, header): 87 | print(header, flush=True) 88 | print('-' * 30, flush=True) 89 | for k, v in arg.__dict__.items(): 90 | print('%s=%s' % (k, v), flush=True) 91 | print('-' * 30, flush=True) 92 | 93 | 94 | def _parse_opts(): 95 | 96 | global _config_time_stamp 97 | 98 | # get command line arguments 99 | arg = _Opt() 100 | argv = sys.argv[1:] 101 | 102 | # check length 103 | assert len(argv) % 2 == 0, 'arguments should be paired with the format of --key value' 104 | 105 | # parse args 106 | for i in range(0, len(argv), 2): 107 | 108 | # check format 109 | assert argv[i].startswith('--'), 'arguments should be paired with the format of --key value' 110 | 111 | # save argument 112 | arg[argv[i][2:]] = _to_py_obj(argv[i + 1]) 113 | 114 | # check config file 115 | if 
argv[i][2:].lower() == 'config': 116 | _parse_config(arg, argv[i + 1]) 117 | _config_time_stamp = os.stat(argv[i + 1]).st_mtime 118 | 119 | # 120 | # inject default options 121 | # 122 | 123 | # device setting 124 | if arg.device is None: 125 | arg.device = 'cuda' if torch.cuda.is_available() else 'cpu' 126 | arg.device = torch.device(arg.device) 127 | arg.cuda = arg.device.type == 'cuda' 128 | 129 | # default learning rate 130 | #arg.lr = 1e-3 131 | 132 | # directories 133 | arg.log_dir = arg.log_dir or 'asset/log/' 134 | arg.data_dir = arg.data_dir or 'asset/data/' 135 | arg.save_dir = arg.save_dir or 'asset/train/' 136 | arg.log_dir += '' if arg.log_dir.endswith('/') else '/' 137 | arg.data_dir += '' if arg.data_dir.endswith('/') else '/' 138 | arg.save_dir += '' if arg.save_dir.endswith('/') else '/' 139 | 140 | # print arg option 141 | # _print_opts(arg, 'CONFIGURATION') 142 | 143 | # start config file watcher if config is defined 144 | if arg.config: 145 | t = threading.Thread(target=_parse_config_thread, args=(arg, arg.config)) 146 | t.daemon = True 147 | t.start() 148 | 149 | return arg 150 | -------------------------------------------------------------------------------- /torchtools/tt/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import time 3 | from tensorboardX import SummaryWriter 4 | from torchtools import tt 5 | 6 | 7 | __author__ = 'namju.kim@kakaobrain.com' 8 | 9 | 10 | # tensorboard writer 11 | _writer = None 12 | _stats_scalar, _stats_image, _stats_audio, _stats_text, _stats_hist = {}, {}, {}, {}, {} 13 | 14 | # time stamp 15 | _last_logged = time.time() 16 | 17 | 18 | # general print wrapper 19 | def log(*args): 20 | print(*args, flush=True) 21 | # save to log_file 22 | if tt.arg.log_file: 23 | with open(tt.arg.log_dir + tt.arg.log_file, 'a') as f: 24 | print(*args, flush=True, file=f) 25 | 26 | 27 | # tensor board writer 28 | def _get_writer(): 29 | global _writer 30 | if 
def log_step(epoch=None, global_step=None, max_epoch=None, max_step=None):
    """Flush the buffered statistics to TensorBoard and print a console summary.

    Logging happens on every call when neither `tt.arg.log_interval` nor
    `tt.arg.log_step` is set. Otherwise it happens when `log_interval`
    seconds have passed since the last log, or when `global_step` is a
    multiple of `log_step`. After flushing, all statistic buffers are reset.

    Args:
        epoch: current epoch number, shown as `ep: <epoch>[/<max_epoch>]`.
        global_step: total step count; also used for the step trigger.
        max_epoch: total number of epochs, for the console summary.
        max_step: steps per epoch; when given, the step is shown within the epoch.
    """
    global _last_logged, _last_logged_step, _stats_scalar, _stats_image, _stats_audio, _stats_text, _stats_hist

    due = (tt.arg.log_interval is None and tt.arg.log_step is None) or \
        (tt.arg.log_interval and time.time() - _last_logged >= tt.arg.log_interval) or \
        (tt.arg.log_step and global_step % tt.arg.log_step == 0)
    if not due:
        return

    # update logging time stamp
    _last_logged = time.time()
    _last_logged_step = global_step

    # console output string
    console_out = ''
    if epoch:
        console_out += 'ep: %d' % epoch
        if max_epoch:
            console_out += '/%d' % max_epoch
    if global_step:
        if max_step:
            # show the step within the epoch, with the last step as max_step rather than 0
            step = global_step % max_step
            step = max_step if step == 0 else step
            console_out += ' step: %d/%d' % (step, max_step)
        else:
            console_out += ' step: %d' % global_step

    # add stats to tensor board; each buffered value is a (data, global_step) pair
    writer = _get_writer()
    for k, v in _stats_scalar.items():
        writer.add_scalar(k, *v)
        # weight/gradient norms go to TensorBoard only, not to the console
        if not k.startswith('weight/') and not k.startswith('gradient/'):
            console_out += ' %s: %f' % (k, v[0])
    for k, v in _stats_image.items():
        writer.add_image(k, *v)
    for k, v in _stats_audio.items():
        writer.add_audio(k, *v)
    for k, v in _stats_text.items():
        writer.add_text(k, *v)
    for k, v in _stats_hist.items():
        writer.add_histogram(k, *v, 'auto')

    # flush
    writer.file_writer.flush()

    # console out
    if len(console_out) > 0:
        log(console_out)

    # Clear every buffer. _stats_hist was previously left out, so old
    # histograms were written again on each subsequent log step.
    _stats_scalar, _stats_image, _stats_audio, _stats_text, _stats_hist = {}, {}, {}, {}, {}
18 | tt.arg.num_ways = int(param['N']) 19 | tt.arg.num_shots = int(param['K']) 20 | tt.arg.num_unlabeled = int(param['U']) 21 | tt.arg.num_layers = int(param['L']) 22 | tt.arg.meta_batch_size = int(param['B']) 23 | tt.arg.transductive = False if param['T'] == 'False' else True 24 | 25 | 26 | #################### 27 | tt.arg.device = 'cuda:0' if tt.arg.device is None else tt.arg.device 28 | # replace dataset_root with your own 29 | tt.arg.dataset_root = '/data/private/dataset' 30 | tt.arg.dataset = 'mini' if tt.arg.dataset is None else tt.arg.dataset 31 | tt.arg.num_ways = 5 if tt.arg.num_ways is None else tt.arg.num_ways 32 | tt.arg.num_shots = 1 if tt.arg.num_shots is None else tt.arg.num_shots 33 | tt.arg.num_unlabeled = 0 if tt.arg.num_unlabeled is None else tt.arg.num_unlabeled 34 | tt.arg.num_layers = 3 if tt.arg.num_layers is None else tt.arg.num_layers 35 | tt.arg.meta_batch_size = 40 if tt.arg.meta_batch_size is None else tt.arg.meta_batch_size 36 | tt.arg.transductive = False if tt.arg.transductive is None else tt.arg.transductive 37 | tt.arg.seed = 222 if tt.arg.seed is None else tt.arg.seed 38 | tt.arg.num_gpus = 1 if tt.arg.num_gpus is None else tt.arg.num_gpus 39 | 40 | tt.arg.num_ways_train = tt.arg.num_ways 41 | tt.arg.num_ways_test = tt.arg.num_ways 42 | 43 | tt.arg.num_shots_train = tt.arg.num_shots 44 | tt.arg.num_shots_test = tt.arg.num_shots 45 | 46 | tt.arg.train_transductive = tt.arg.transductive 47 | tt.arg.test_transductive = tt.arg.transductive 48 | 49 | # model parameter related 50 | tt.arg.num_edge_features = 96 51 | tt.arg.num_node_features = 96 52 | tt.arg.emb_size = 128 53 | 54 | # train, test parameters 55 | tt.arg.train_iteration = 100000 if tt.arg.dataset == 'mini' else 200000 56 | tt.arg.test_iteration = 10000 57 | tt.arg.test_interval = 5000 58 | tt.arg.test_batch_size = 10 59 | tt.arg.log_step = 1000 60 | 61 | tt.arg.lr = 1e-3 62 | tt.arg.grad_clip = 5 63 | tt.arg.weight_decay = 1e-6 64 | tt.arg.dec_lr = 15000 if tt.arg.dataset 
== 'mini' else 30000 65 | tt.arg.dropout = 0.1 if tt.arg.dataset == 'mini' else 0.0 66 | 67 | #set random seed 68 | np.random.seed(tt.arg.seed) 69 | torch.manual_seed(tt.arg.seed) 70 | torch.cuda.manual_seed_all(tt.arg.seed) 71 | random.seed(tt.arg.seed) 72 | torch.backends.cudnn.deterministic = True 73 | torch.backends.cudnn.benchmark = False 74 | 75 | 76 | enc_module = EmbeddingImagenet(emb_size=tt.arg.emb_size) 77 | 78 | # set random seed 79 | np.random.seed(tt.arg.seed) 80 | torch.manual_seed(tt.arg.seed) 81 | torch.cuda.manual_seed_all(tt.arg.seed) 82 | random.seed(tt.arg.seed) 83 | torch.backends.cudnn.deterministic = True 84 | torch.backends.cudnn.benchmark = False 85 | 86 | # to check 87 | exp_name = 'D-{}'.format(tt.arg.dataset) 88 | exp_name += '_N-{}_K-{}_U-{}'.format(tt.arg.num_ways, tt.arg.num_shots, tt.arg.num_unlabeled) 89 | exp_name += '_L-{}_B-{}'.format(tt.arg.num_layers, tt.arg.meta_batch_size) 90 | exp_name += '_T-{}'.format(tt.arg.transductive) 91 | 92 | 93 | if not exp_name == tt.arg.test_model: 94 | print(exp_name) 95 | print(tt.arg.test_model) 96 | print('Test model and input arguments are mismatched!') 97 | AssertionError() 98 | 99 | gnn_module = GraphNetwork(in_features=tt.arg.emb_size, 100 | node_features=tt.arg.num_edge_features, 101 | edge_features=tt.arg.num_node_features, 102 | num_layers=tt.arg.num_layers, 103 | dropout=tt.arg.dropout) 104 | 105 | if tt.arg.dataset == 'mini': 106 | test_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='test') 107 | elif tt.arg.dataset == 'tiered': 108 | test_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='test') 109 | else: 110 | print('Unknown dataset!') 111 | 112 | 113 | data_loader = {'test': test_loader} 114 | 115 | # create trainer 116 | tester = ModelTrainer(enc_module=enc_module, 117 | gnn_module=gnn_module, 118 | data_loader=data_loader) 119 | 120 | 121 | #checkpoint = torch.load('asset/checkpoints/{}/'.format(exp_name) + 'model_best.pth.tar') 122 | 
checkpoint = torch.load('./trained_models/{}/'.format(exp_name) + 'model_best.pth.tar') 123 | 124 | 125 | tester.enc_module.load_state_dict(checkpoint['enc_module_state_dict']) 126 | print("load pre-trained enc_nn done!") 127 | 128 | # initialize gnn pre-trained 129 | tester.gnn_module.load_state_dict(checkpoint['gnn_module_state_dict']) 130 | print("load pre-trained egnn done!") 131 | 132 | tester.val_acc = checkpoint['val_acc'] 133 | tester.global_step = checkpoint['iteration'] 134 | 135 | print(tester.global_step) 136 | 137 | 138 | tester.eval(partition='test') 139 | 140 | 141 | 142 | 143 | 144 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fewshot-egnn 2 | 3 | ### Introduction 4 | 5 | The current project page provides pytorch code that implements the following CVPR2019 paper: 6 | **Title:** "Edge-labeling Graph Neural Network for Few-shot Learning" 7 | **Authors:** Jongmin Kim, Taesup Kim, Sungwoong Kim, Chang D.Yoo 8 | 9 | **Institution:** KAIST, KaKaoBrain 10 | **Code:** https://github.com/khy0809/fewshot-egnn 11 | **Arxiv:** https://arxiv.org/abs/1905.01436 12 | 13 | **Abstract:** 14 | In this paper, we propose a novel edge-labeling graph 15 | neural network (EGNN), which adapts a deep neural network 16 | on the edge-labeling graph, for few-shot learning. 17 | The previous graph neural network (GNN) approaches in 18 | few-shot learning have been based on the node-labeling 19 | framework, which implicitly models the intra-cluster similarity 20 | and the inter-cluster dissimilarity. In contrast, the 21 | proposed EGNN learns to predict the edge-labels rather 22 | than the node-labels on the graph that enables the evolution 23 | of an explicit clustering by iteratively updating the edgelabels 24 | with direct exploitation of both intra-cluster similarity 25 | and the inter-cluster dissimilarity. 
It is also well suited 26 | for performing on various numbers of classes without retraining, 27 | and can be easily extended to perform a transductive 28 | inference. The parameters of the EGNN are learned 29 | by episodic training with an edge-labeling loss to obtain a 30 | well-generalizable model for unseen low-data problem. On 31 | both of the supervised and semi-supervised few-shot image 32 | classification tasks with two benchmark datasets, the proposed 33 | EGNN significantly improves the performances over 34 | the existing GNNs. 35 | 36 | ### Citation 37 | If you find this code useful you can cite us using the following BibTeX: 38 | ``` 39 | @article{kim2019egnn, 40 | title={Edge-labeling Graph Neural Network for Few-shot Learning}, 41 | author={Jongmin Kim and Taesup Kim and Sungwoong Kim and Chang D. Yoo}, 42 | journal={arXiv preprint arXiv:1905.01436}, 43 | year={2019} 44 | } 45 | ``` 46 | 47 | 48 | ### Platform 49 | This code was developed and tested with PyTorch version 1.0.1. 50 | 51 | ### Setting 52 | 53 | You can download the miniImagenet dataset from [here](https://drive.google.com/open?id=15WuREBvhEbSWo4fTr1r-vMY0C_6QWv4w). 
54 | 55 | Download 'mini_imagenet_train/val/test.pickle', and put them in the path 56 | 'tt.arg.dataset_root/mini-imagenet/compacted_dataset/' 57 | 58 | In ```train.py```, replace the dataset root directory with your own: 59 | tt.arg.dataset_root = '/data/private/dataset' 60 | 61 | 62 | 63 | ### Training 64 | 65 | ``` 66 | # ************************** miniImagenet, 5way 1shot ***************************** 67 | $ python3 train.py --dataset mini --num_ways 5 --num_shots 1 --transductive False 68 | $ python3 train.py --dataset mini --num_ways 5 --num_shots 1 --transductive True 69 | 70 | # ************************** miniImagenet, 5way 5shot ***************************** 71 | $ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --transductive False 72 | $ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --transductive True 73 | 74 | # ************************** miniImagenet, 10way 5shot ***************************** 75 | $ python3 train.py --dataset mini --num_ways 10 --num_shots 5 --meta_batch_size 20 --transductive True 76 | 77 | # ************************** tieredImagenet, 5way 5shot ***************************** 78 | $ python3 train.py --dataset tiered --num_ways 5 --num_shots 5 --transductive False 79 | $ python3 train.py --dataset tiered --num_ways 5 --num_shots 5 --transductive True 80 | 81 | # **************** miniImagenet, 5way 5shot, 20% labeled (semi) ********************* 82 | $ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --num_unlabeled 4 --transductive False 83 | $ python3 train.py --dataset mini --num_ways 5 --num_shots 5 --num_unlabeled 4 --transductive True 84 | 85 | ``` 86 | 87 | ### Evaluation 88 | The trained models are saved in the path './asset/checkpoints/', with the name of 'D-{dataset}_N-{ways}_K-{shots}_U-{num_unlabeled}_L-{num_layers}_B-{meta_batch_size}_T-{transductive}'. Note that ```eval.py``` loads './trained_models/{name}/model_best.pth.tar', so place the checkpoint there (or edit the path in ```eval.py```) before evaluating. 
89 | So, for example, if you want to test the model trained in the 'miniImagenet, 5way 1shot, transductive' setting, you can pass the --test_model argument as follows: 90 | ``` 91 | $ python3 eval.py --test_model D-mini_N-5_K-1_U-0_L-3_B-40_T-True 92 | ``` 93 | 94 | 95 | ## Result 96 | Here are some experimental results presented in the paper. You should be able to reproduce all the results by using the trained models, which can be downloaded from [here](https://drive.google.com/open?id=15WuREBvhEbSWo4fTr1r-vMY0C_6QWv4w). 97 | #### miniImageNet, non-transductive 98 | 99 | | Model | 5-way 5-shot acc (%)| 100 | |--------------------------| ------------------: | 101 | | Matching Networks [1] | 55.30 | 102 | | Reptile [2] | 62.74 | 103 | | Prototypical Net [3] | 65.77 | 104 | | GNN [4] | 66.41 | 105 | | **(ours)** EGNN | **66.85** | 106 | 107 | #### miniImageNet, transductive 108 | 109 | | Model | 5-way 5-shot acc (%)| 110 | |--------------------------| ------------------: | 111 | | MAML [5] | 63.11 | 112 | | Reptile + BN [2] | 65.99 | 113 | | Relation Net [6] | 67.07 | 114 | | MAML + Transduction [5] | 66.19 | 115 | | TPN [7] | 69.43 | 116 | | TPN (Higher K) [7] | 69.86 | 117 | | **(ours)** EGNN | **76.37** | 118 | 119 | #### tieredImageNet, non-transductive 120 | 121 | | Model | 5-way 5-shot acc (%)| 122 | |--------------------------| ------------------: | 123 | | Reptile [2] | 66.47 | 124 | | Prototypical Net [3] | 69.57 | 125 | | **(ours)** EGNN | **70.98** | 126 | 127 | #### tieredImageNet, transductive 128 | 129 | | Model | 5-way 5-shot acc (%)| 130 | |--------------------------| ------------------: | 131 | | MAML [5] | 70.30 | 132 | | Reptile + BN [2] | 71.03 | 133 | | Relation Net [6] | 71.31 | 134 | | MAML + Transduction [5] | 70.83 | 135 | | TPN [7] | 72.58 | 136 | | **(ours)** EGNN | **80.15** | 137 | 138 | 139 | #### miniImageNet, semi-supervised, 5-way 5-shot 140 | 141 | | Model | 20% | 40% | 60% | 100% | 142 | |--------------------------| ------------------: | 
------------------: | ------------------: | ------------------: | 143 | | GNN-LabeledOnly [4] | 50.33 | 56.91 | - | 66.41 | 144 | | GNN-Semi [4] | 52.45 | 58.76 | - | 66.41 | 145 | | EGNN-LabeledOnly | 52.86 | - | - | 66.85 | 146 | | EGNN-Semi | 61.88 | 62.52 | 63.53 | 66.85 | 147 | | EGNN-LabeledOnly (Transductive) | 59.18 | - | - | 76.37 | 148 | | EGNN-Semi (Transductive) | 63.62 | 64.32 | 66.37 | 76.37 | 149 | 150 | 151 | #### miniImageNet, cross-way experiment 152 | | Model | train way | test way | Accuracy | 153 | |--------------------------| ------------------: | ------------------: | ------------------: | 154 | | GNN | 5 | 5 | 66.41 | 155 | | GNN | 5 | 10 | N/A | 156 | | GNN | 10 | 10 | 51.75 | 157 | | GNN | 10 | 5 | N/A | 158 | | EGNN | 5 | 5 | 76.37 | 159 | | EGNN | 5 | 10 | 56.35 | 160 | | EGNN | 10 | 10 | 57.61 | 161 | | EGNN | 10 | 5 | 76.27 | 162 | 163 | 164 | 165 | ### References 166 | ``` 167 | [1] O. Vinyals et al. Matching networks for one shot learning. 168 | [2] A Nichol, J Achiam, J Schulman, On first-order meta-learning algorithms. 169 | [3] J. Snell, K. Swersky, and R. S. Zemel. Prototypical networks for few-shot learning. 170 | [4] V Garcia, J Bruna, Few-shot learning with graph neural network. 171 | [5] C. Finn, P. Abbeel, and S. Levine. Model-agnostic meta-learning for fast adaptation of deep networks. 172 | [6] F. Sung et al, Learning to Compare: Relation Network for Few-Shot Learning. 173 | [7] Y Liu, J Lee, M Park, S Kim, Y Yang, Transductive propagation network for few-shot learning. 
174 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | tt.arg.inter_dea 100 | inter_deactivate 101 | 102 | 103 | 104 | 106 | 107 | 114 | 115 | 116 | 117 | 118 | true 119 | DEFINITION_ORDER 120 | 121 | 122 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 |