├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── arch.png
│   └── gzsl.png
├── config.py
├── main.py
├── requirements.txt
├── train.py
├── train_clswgan.py
└── utils
    ├── AWADataset.py
    ├── nn.py
    └── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
/lab/
/images/
/data.csv
/models/
/AWA2/
/model/
/.idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 mkara44

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Feature Generating Networks for Zero-Shot Learning
An unofficial PyTorch implementation of [Feature Generating Networks for Zero-Shot Learning](https://arxiv.org/abs/1712.00981)

![GZSL](./assets/gzsl.png "GZSL")
*Figure from the official paper*

## Generalized Zero-Shot Learning (GZSL)
- Zero-shot learning aims to recognize objects whose instances may not have been seen during training. [1]
- GZSL leverages semantic information about both seen (source) and unseen (target) classes to bridge the gap between them. [2]

## Model Architecture
![Model Architecture](./assets/arch.png "Model Architecture")
*f-CLSWGAN architecture, figure from the official paper*

## Dependencies
- Python 3.6+
- `pip install -r requirements.txt`

## Dataset
- The Animals with Attributes 2 [3] dataset is used. It contains 50 classes with 37,322 images; 40 classes are used as seen classes and the remaining 10 as unseen classes.
- The authors of the original paper have shared ResNet101 feature maps of the images [4]; these feature maps, rather than raw images, are used for training.
- Seen classes are split into train and test sets: the `trainval_loc` indexes are used for training and the `test_seen_loc` indexes for testing. [4]
- Unseen classes are not split: the `test_unseen_loc` indexes are used both during projection training (only their attributes, from which features are generated) and for testing. [4]

## Training
- After training, models are saved under the file names defined in `config.py`.
- To train from scratch:
    - `python main.py --train`
- Pretrained models for any part of the proposed approach can be loaded for fine-tuning:
    - `python main.py --train --g_cls_path <g_cls_path>`
    - `python main.py --train --g_cls_path <g_cls_path> --wgan_G_path <wgan_G_path> --wgan_D_path <wgan_D_path> --projection_path <projection_path>`

## Evaluation
- Pretrained models can be used for evaluation:
    - `python main.py --g_cls_path <g_cls_path> --wgan_G_path <wgan_G_path> --wgan_D_path <wgan_D_path> --projection_path <projection_path>`

## References
- [1] [Zero-Shot Learning - A Comprehensive Evaluation of the Good, the Bad and the Ugly](https://arxiv.org/pdf/1707.00600.pdf)
- [2] [A Review of Generalized Zero-Shot Learning Methods](https://arxiv.org/pdf/2011.08641.pdf)
- [3] [Animals with Attributes 2](https://cvml.ist.ac.at/AwA2/)
- [4] [ResNet101 Feature Maps of the AWA2 Dataset](https://datasets.d2.mpi-inf.mpg.de/xian/xlsa17.zip)
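A note on the reported metric: `test()` in `train.py` scores the seen and unseen test sets by per-class averaged accuracy and combines them with the harmonic mean H, as in [1]. A minimal sketch of that combination; the accuracy values below are made-up placeholders, not results from this repository:

```python
# Harmonic mean H used for GZSL evaluation, mirroring test() in train.py.
ays = 0.61  # per-class averaged accuracy on seen test classes (made up)
ayu = 0.32  # per-class averaged accuracy on unseen test classes (made up)
H = (2 * ayu * ays) / (ayu + ays)
print(f'H = {H:.3f}')  # H = 0.420
```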
--------------------------------------------------------------------------------
/assets/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkara44/f-clswgan_pytorch/641cd00834ec5b801775fce01c94911d10cb33b4/assets/arch.png

--------------------------------------------------------------------------------
/assets/gzsl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkara44/f-clswgan_pytorch/641cd00834ec5b801775fce01c94911d10cb33b4/assets/gzsl.png

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
from easydict import EasyDict

cfg = EasyDict()
cfg.beta = 0.01
cfg.lambd = 10.
cfg.latent_dim = 128
cfg.batch_size = 64
cfg.attr_number = 85
cfg.output_size = 224
cfg.seen_class_number = 40
cfg.unseen_class_number = 10
cfg.atts_path = './AWA2/att_splits.mat'
cfg.res_path = './AWA2/res101.mat'

# g_cls settings
cfg.g_cls = EasyDict()
cfg.g_cls.epoch = 30
cfg.g_cls.learning_rate = 1e-4
cfg.g_cls.model_name = 'g_cls_model_1e4.pt'

# wgan settings
cfg.wgan = EasyDict()
cfg.wgan.epoch = 100
cfg.wgan.n_step = 5
cfg.wgan.learning_rate = 1e-4
cfg.wgan.G_model_name = 'wgan_G_model_1e4.pt'
cfg.wgan.D_model_name = 'wgan_D_model_1e4.pt'

# projection settings
cfg.projection = EasyDict()
cfg.projection.epoch = 30
cfg.projection.learning_rate = 1e-4
cfg.projection.model_name = 'projection_model_1e4.pt'
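In `train_clswgan.py`, `cfg.beta` weights the classification loss in the generator objective and `cfg.lambd` scales the WGAN-GP gradient penalty. Since `cfg` is an `EasyDict`, entries can be overridden with plain attribute assignment before training; a minimal sketch (the path below is an example, not a required location):

```python
from config import cfg

cfg.batch_size = 128                          # larger batches
cfg.wgan.epoch = 200                          # train the WGAN longer
cfg.atts_path = '/data/AWA2/att_splits.mat'   # example path, adjust to your setup
```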
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
import argparse

# Additional Scripts
from train import TrainTestPipe


def main_pipeline(args):
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    ttp = TrainTestPipe(device)
    ttp.load_model(args.g_cls_path, 'g_cls')
    ttp.load_model([args.wgan_G_path, args.wgan_D_path], ['wgan_G', 'wgan_D'])
    ttp.load_model(args.projection_path, 'projection')

    if args.train:
        print('G_cls training process has been started!')
        ttp.train_g_cls()

        print('WGAN training process has been started!')
        ttp.train_wgan()

        print('Projection training process has been started!')
        ttp.train_projection()

    print('Test has been started!')
    ttp.test()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--g_cls_path', type=str, default=None)
    parser.add_argument('--wgan_G_path', type=str, default=None)
    parser.add_argument('--wgan_D_path', type=str, default=None)
    parser.add_argument('--projection_path', type=str, default=None)
    args = parser.parse_args()

    main_pipeline(args)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
easydict==1.9
numpy==1.19.2
scipy==1.5.2
torch==1.7.0
tqdm==4.49.0
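`train.py` and `train_clswgan.py` below wrap tensors in `torch.autograd.Variable`, which the pinned `torch==1.7.0` still accepts but which has been deprecated since the Tensor/Variable merge in PyTorch 0.4. A minimal sketch of the modern equivalent, should you upgrade:

```python
import torch

x = torch.randn(4, 2048)

v = torch.autograd.Variable(x, requires_grad=True)  # deprecated style used in this repo
t = x.clone().requires_grad_()                      # modern equivalent
```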
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader

# Additional Scripts
from train_clswgan import TrainCLSWGAN
from utils.utils import calculate_correct_cls, calculate_label_acc, EpochCallback
from utils.AWADataset import AWADataset
from config import cfg


class TrainTestPipe:
    def __init__(self, device):
        self.device = device

        self.train_seen_loader = self.__load_dataset('trainval_loc')
        self.val_seen_loader = self.__load_dataset('test_seen_loc')
        self.unseen_loader = self.__load_dataset('test_unseen_loc')

        self.clswgan = TrainCLSWGAN(self.device)

    def __load_dataset(self, zsl_set):
        animal_set = AWADataset(zsl_set)
        return DataLoader(animal_set, batch_size=cfg.batch_size, shuffle=False)

    def __loop_train(self, loader, step_func, t, val=False, split=None):
        total_loss = None
        total_correct = None

        for step, data in enumerate(loader):
            feat, atts, cls_true = data['feature'], data['attribute'], data['label']
            feat = torch.autograd.Variable(feat, requires_grad=True).to(self.device)
            atts = torch.autograd.Variable(atts, requires_grad=True).to(self.device)
            cls_true = cls_true.to(self.device).squeeze_()

            loss, cls_pred = step_func(feat=feat, atts=atts, cls_true=cls_true, val=val, split=split, step=step)

            if cls_pred is not None:
                n_correct = calculate_correct_cls(cls_pred, cls_true)
                if total_correct is None:
                    total_correct = 0

                total_correct += n_correct

            if isinstance(loss, list):
                if total_loss is None:
                    total_loss = [0] * len(loss)

                total_loss = np.add(total_loss, loss).tolist()
            else:
                if total_loss is None:
                    total_loss = 0

                total_loss += loss

            t.update()

        return total_loss, total_correct

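    # Note on __loop_train's return values: `total_loss` accumulates summed
    # per-batch losses (a list when the step function reports several, e.g.
    # step_wgan's [d_loss, g_loss]) and `total_correct` accumulates correct
    # predictions; the callers below normalize losses by the number of
    # batches and accuracies by the dataset size.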
    def load_model(self, paths, model_types):
        if isinstance(paths, str) or paths is None:
            paths = [paths]
            model_types = [model_types]

        for model_type, path in zip(model_types, paths):
            if path is None:
                print(f'{model_type}_path cannot be loaded, it is not defined!')
                continue

            elif not os.path.exists(path):
                print(f'Path ({path}) does not exist!')
                continue

            if model_type == 'g_cls':
                model = self.clswgan.G_cls
                optimizer = self.clswgan.G_cls_optimizer
            elif model_type == 'wgan_G':
                model = self.clswgan.G
                optimizer = self.clswgan.G_optimizer
            elif model_type == 'wgan_D':
                model = self.clswgan.D
                optimizer = self.clswgan.D_optimizer
            elif model_type == 'projection':
                model = self.clswgan.projection
                optimizer = self.clswgan.projection_optimizer
            else:
                print(f'Unexpected model_type! ({model_type})')
                continue

            ckpt = torch.load(path)
            model.load_state_dict(ckpt['model_state_dict'])
            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            print(f'{model_type} model has been loaded!')

    def train_g_cls(self):
        callback = EpochCallback(cfg.g_cls.model_name, cfg.g_cls.epoch,
                                 self.clswgan.G_cls, self.clswgan.G_cls_optimizer, 'val_loss')

        for epoch in range(cfg.g_cls.epoch):
            with tqdm(total=len(self.train_seen_loader) + len(self.val_seen_loader)) as t:
                train_loss, train_correct = self.__loop_train(self.train_seen_loader, self.clswgan.step_g_cls, t)

                val_loss, val_correct = self.__loop_train(self.val_seen_loader, self.clswgan.step_g_cls, t, val=True)

                # Accuracies are normalized by the dataset size, since the
                # last batch may be smaller than cfg.batch_size.
                callback.epoch_end(epoch + 1,
                                   {'loss': train_loss / len(self.train_seen_loader),
                                    'acc': train_correct / len(self.train_seen_loader.dataset),
                                    'val_loss': val_loss / len(self.val_seen_loader),
                                    'val_acc': val_correct / len(self.val_seen_loader.dataset)})

    def train_wgan(self):
        callback = EpochCallback([cfg.wgan.G_model_name, cfg.wgan.D_model_name], cfg.wgan.epoch,
                                 [self.clswgan.G, self.clswgan.D],
                                 [self.clswgan.G_optimizer, self.clswgan.D_optimizer])

        for epoch in range(cfg.wgan.epoch):
            with tqdm(total=len(self.train_seen_loader)) as t:
                loss, _ = self.__loop_train(self.train_seen_loader, self.clswgan.step_wgan, t)

                callback.epoch_end(epoch + 1,
                                   {'d_loss': loss[0] / len(self.train_seen_loader),
                                    'g_loss': loss[1] / (len(self.train_seen_loader) / cfg.wgan.n_step)})

    def train_projection(self):
        callback = EpochCallback(cfg.projection.model_name, cfg.projection.epoch,
                                 self.clswgan.projection, self.clswgan.projection_optimizer,
                                 'unseen_loss')

        for epoch in range(cfg.projection.epoch):
            with tqdm(total=len(self.train_seen_loader) + len(self.unseen_loader)) as t:
                seen_train_loss, seen_train_correct = self.__loop_train(self.train_seen_loader,
                                                                        self.clswgan.step_projection, t, split='seen')

                unseen_loss, unseen_correct = self.__loop_train(self.unseen_loader,
                                                                self.clswgan.step_projection, t, split='unseen')

                callback.epoch_end(epoch + 1,
                                   {'seen_train_loss': seen_train_loss / len(self.train_seen_loader),
                                    'seen_train_acc': seen_train_correct / len(self.train_seen_loader.dataset),
                                    'unseen_loss': unseen_loss / len(self.unseen_loader),
                                    'unseen_acc': unseen_correct / len(self.unseen_loader.dataset)})

    def __loop_test(self, loader, t):
        label_acc = {}
        for data in loader:
            feat, atts, cls_true = data['feature'], data['attribute'], data['label']
            feat = feat.to(self.device)
            cls_true = cls_true.to(self.device).squeeze_()

            cls_pred = self.clswgan.inference(feat=feat)
            label_acc = calculate_label_acc(cls_pred, cls_true, label_acc)
            t.update()

        # Per-class averaged accuracy, as in the GZSL evaluation protocol [1].
        ay = sum(n_correct / n for n_correct, n in label_acc.values()) / len(label_acc)
        return ay

    def test(self):
        with tqdm(total=len(self.val_seen_loader) + len(self.unseen_loader)) as t:
            ays = self.__loop_test(self.val_seen_loader, t)
            ayu = self.__loop_test(self.unseen_loader, t)

        # Harmonic mean of seen and unseen per-class accuracies.
        H = (2 * ayu * ays) / (ayu + ays)

        print(f'Seen Set Accuracy: {ays}\nUnseen Set Accuracy: {ayu}\nH: {H}')
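`__loop_test` scores accuracy per class rather than per sample, which matches the evaluation protocol of [1]. A toy, self-contained check of `calculate_label_acc` from `utils/utils.py`; the logits below are made up, and the import assumes you run from the repository root:

```python
import torch
from utils.utils import calculate_label_acc

# Toy logits for 3 classes over 4 samples (made-up numbers, not model output).
cls_pred = torch.tensor([[0.9, 0.05, 0.05],
                         [0.1, 0.8, 0.1],
                         [0.2, 0.7, 0.1],
                         [0.5, 0.3, 0.2]])   # predicted 0, true label is 2
cls_true = torch.tensor([0, 1, 1, 2])

label_acc = calculate_label_acc(cls_pred, cls_true, {})
per_class = {c: n_ok / n for c, (n_ok, n) in label_acc.items()}
print(per_class)  # {0: 1.0, 1: 1.0, 2: 0.0}
```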
--------------------------------------------------------------------------------
/train_clswgan.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torch.distributions import uniform, normal

# Additional Scripts
from utils.nn import Generator, Discriminator, MLP
from config import cfg


class TrainCLSWGAN:
    eps = uniform.Uniform(0, 1)
    Z_sampler = normal.Normal(0, 1)
    beta = cfg.beta
    lambd = cfg.lambd

    def __init__(self, device):
        self.device = device

        # self.G_cls = MLP(cfg.seen_class_number).to(self.device)
        self.G_cls = MLP(cfg.seen_class_number + cfg.unseen_class_number).to(self.device)
        self.G = Generator(cfg.attr_number + cfg.latent_dim).to(self.device)
        self.D = Discriminator(cfg.attr_number).to(self.device)
        self.projection = MLP(cfg.seen_class_number + cfg.unseen_class_number).to(self.device)

        self.projection_criterion = nn.NLLLoss()
        self.G_cls_criterion = nn.NLLLoss()

        self.G_cls_optimizer = optim.Adam(self.G_cls.parameters(), lr=cfg.g_cls.learning_rate)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=cfg.wgan.learning_rate)
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=cfg.wgan.learning_rate)
        self.projection_optimizer = optim.Adam(self.projection.parameters(), lr=cfg.projection.learning_rate)

    def get_noise(self, batch_size):
        return torch.autograd.Variable(self.Z_sampler.sample(torch.Size([batch_size, cfg.latent_dim])).to(self.device))

    def get_gradient_penalty(self, real_feat, fake_feat, batch_size, atts):
        eps = self.eps.sample(torch.Size([batch_size, 1])).to(self.device)
        X_penalty = eps * real_feat + (1 - eps) * fake_feat

        X_penalty = autograd.Variable(X_penalty, requires_grad=True).to(self.device)
        d_pred = self.D(X_penalty, atts)
        grad_outputs = torch.ones(d_pred.size()).to(self.device)
        gradients = autograd.grad(
            outputs=d_pred, inputs=X_penalty,
            grad_outputs=grad_outputs,
            create_graph=True, retain_graph=True, only_inputs=True
        )[0]

        grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.lambd
        return grad_penalty

    def step_g_cls(self, **params):
        if params['val']:
            self.G_cls.eval()
            with torch.no_grad():
                cls_pred = self.G_cls(params['feat'])
                loss = self.G_cls_criterion(F.log_softmax(cls_pred, dim=1), params['cls_true'])

            return loss.item(), cls_pred

        self.G_cls.train()
        self.G_cls_optimizer.zero_grad()
        cls_pred = self.G_cls(params['feat'])
        loss = self.G_cls_criterion(F.log_softmax(cls_pred, dim=1), params['cls_true'])
        loss.backward()
        self.G_cls_optimizer.step()

        return loss.item(), cls_pred

    def step_wgan(self, **params):
        loss_g = None
        self.G_cls.eval()
        for p in self.D.parameters():
            p.requires_grad = True

        # Critic step: maximize D(real) - D(fake) minus the gradient penalty.
        batch_size = params['atts'].shape[0]
        self.D_optimizer.zero_grad()
        d_real = self.D(params['feat'], params['atts'])
        d_real = torch.mean(d_real)
        d_real.backward(torch.tensor(-1.))

        Z = self.get_noise(batch_size)
        fake_feat = self.G(Z, params['atts']).detach()  # no generator gradients during the critic step

        d_fake = self.D(fake_feat, params['atts'])
        d_fake = torch.mean(d_fake)
        d_fake.backward(torch.tensor(1.))

        gradient_penalty = self.get_gradient_penalty(params['feat'], fake_feat, batch_size, params['atts'])
        gradient_penalty.backward()

        loss_d = d_fake - d_real + gradient_penalty
        self.D_optimizer.step()

        # Generator step every cfg.wgan.n_step batches: minimize
        # -D(G(z, a)) + beta * classification loss on the generated features.
        if params['step'] % cfg.wgan.n_step == 0:
            for p in self.D.parameters():
                p.requires_grad = False
            self.G_optimizer.zero_grad()
            Z = self.get_noise(batch_size)
            fake_feat = self.G(Z, params['atts'])

            d_fake = self.D(fake_feat, params['atts'])
            d_fake = -1 * torch.mean(d_fake)

            g_cls_pred = self.G_cls(fake_feat)
            loss_cls = self.G_cls_criterion(F.log_softmax(g_cls_pred, dim=1), params['cls_true'])
            loss_g = d_fake + self.beta * loss_cls

            loss_g.backward()
            self.G_optimizer.step()

        return [loss_d.item(), loss_g.item() if loss_g is not None else 0], None

    def step_projection(self, **params):
        with torch.no_grad():
            if params['split'] == 'seen':
                feat = params['feat']

                if params['val']:
                    self.projection.eval()
                    cls_pred = self.projection(params['feat'])
                    loss = self.projection_criterion(F.log_softmax(cls_pred, dim=1), params['cls_true'])

                    return loss.item(), cls_pred

            elif params['split'] == 'unseen':
                # Only attributes are available for unseen classes; features are generated.
                batch_size = params['atts'].shape[0]
                Z = self.Z_sampler.sample(torch.Size([batch_size, cfg.latent_dim])).to(self.device)
                feat = self.G(Z, params['atts'])

        self.projection.train()
        self.projection_optimizer.zero_grad()
        cls_pred = self.projection(feat)
        loss = self.projection_criterion(F.log_softmax(cls_pred, dim=1), params['cls_true'])
        loss.backward()
        self.projection_optimizer.step()

        return loss.item(), cls_pred

    def inference(self, **params):
        self.projection.eval()
        with torch.no_grad():
            cls_pred = self.projection(params['feat'])

        return cls_pred
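`step_wgan` trains a WGAN-GP critic: the interpolation between real and generated features is pushed toward unit gradient norm, which enforces the 1-Lipschitz constraint of the Wasserstein objective. A self-contained sketch of the penalty computed by `get_gradient_penalty` above; the linear "critic" and random features are toy stand-ins, not the repo's `Discriminator` or real ResNet101 features:

```python
import torch
from torch import autograd, nn

def gradient_penalty(critic, real, fake, lambd=10.0):
    eps = torch.rand(real.size(0), 1)                      # U(0, 1), one per sample
    interp = (eps * real + (1 - eps) * fake).requires_grad_()
    scores = critic(interp)
    grads = autograd.grad(outputs=scores, inputs=interp,
                          grad_outputs=torch.ones_like(scores),
                          create_graph=True)[0]
    return lambd * ((grads.norm(2, dim=1) - 1) ** 2).mean()

critic = nn.Linear(2048, 1)                                # toy critic
real, fake = torch.randn(8, 2048), torch.randn(8, 2048)
print(gradient_penalty(critic, real, fake))
```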
--------------------------------------------------------------------------------
/utils/AWADataset.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset
import scipy.io as sio

# Additional Scripts
from config import cfg


class AWADataset(Dataset):
    # Loaded once at class-definition time and shared by all split instances.
    res_mat = sio.loadmat(cfg.res_path)
    atts_mat = sio.loadmat(cfg.atts_path)

    def __init__(self, zsl_set):
        super().__init__()

        # Split indexes in the .mat files are 1-based; shift them to 0-based.
        loc = self.atts_mat[zsl_set].squeeze() - 1

        self.features = torch.from_numpy(self.res_mat['features'][..., loc]).float().T
        self.atts = torch.from_numpy(self.atts_mat['att']).float().T
        self.labels = torch.from_numpy((self.res_mat['labels'] - 1)[loc]).long()

    def __getitem__(self, idx):
        return {'feature': self.features[idx, :],
                'label': self.labels[idx],
                'attribute': self.atts[self.labels[idx][0]]}

    def __len__(self):
        return self.labels.shape[0]
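A minimal sketch of inspecting one sample, assuming the AWA2 .mat files from [4] sit at the paths configured in `config.py` and that you run from the repository root:

```python
from utils.AWADataset import AWADataset

train_set = AWADataset('trainval_loc')
sample = train_set[0]
print(sample['feature'].shape)    # torch.Size([2048]), ResNet101 feature
print(sample['attribute'].shape)  # torch.Size([85]), class attribute vector
print(sample['label'])            # zero-based class index, shape (1,)
```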
--------------------------------------------------------------------------------
/utils/nn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(self, out_dim, x_dim=2048):
        super().__init__()

        self.fc1 = nn.Linear(x_dim, int(x_dim / 2))
        self.fc2 = nn.Linear(int(x_dim / 2), int(x_dim / 2))
        self.fc3 = nn.Linear(int(x_dim / 2), int(x_dim / 2))
        self.fc4 = nn.Linear(int(x_dim / 2), out_dim)

    def forward(self, x):
        for fc in [self.fc1, self.fc2, self.fc3]:
            x = F.leaky_relu(fc(x))
            x = F.dropout(x, training=self.training)  # only drop units in train mode

        x = self.fc4(x)
        return x


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)


class Generator(nn.Module):
    def __init__(self, attr_dim):
        super().__init__()

        self.fc1 = nn.Linear(attr_dim, 4096)
        self.fc2 = nn.Linear(4096, 2048)

        self.apply(weights_init)

    def forward(self, noise, atts):
        x = torch.cat((noise, atts), 1)
        x = F.leaky_relu(self.fc1(x))
        x = F.relu(self.fc2(x))  # ReLU output: ResNet101 features are non-negative

        return x


class Discriminator(nn.Module):
    def __init__(self, attr_dim, x_dim=2048):
        super().__init__()

        self.fc1 = nn.Linear(x_dim + attr_dim, 4096)
        self.fc2 = nn.Linear(4096, 1)

        self.apply(weights_init)

    def forward(self, feat, atts):
        x = torch.cat((feat, atts), 1)
        x = F.leaky_relu(self.fc1(x))
        x = self.fc2(x)

        return x

--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
import torch
import numpy as np


def calculate_correct_cls(cls_pred, cls_true):
    _, cls_pred = torch.max(cls_pred.data, 1)
    return (cls_pred == cls_true).sum().item()


def calculate_label_acc(cls_pred, cls_true, label_acc):
    _, cls_pred = torch.max(cls_pred.data, 1)
    for c_p, c_t in zip(cls_pred, cls_true):
        c_p = int(c_p)
        c_t = int(c_t)

        if label_acc.get(c_t) is None:
            label_acc[c_t] = [0, 0]

        label_acc[c_t][1] += 1
        if c_p == c_t:
            label_acc[c_t][0] += 1

    return label_acc


class EpochCallback:
    monitor_value = np.inf

    def __init__(self, model_name, total_epoch_num, model, optimizer, monitor=None):
        if isinstance(model_name, str):
            model_name = [model_name]
            model = [model]
            optimizer = [optimizer]

        self.model_name = model_name
        self.total_epoch_num = total_epoch_num
        self.monitor = monitor
        self.model = model
        self.optimizer = optimizer

    def __save_model(self):
        for m_name, m, opt in zip(self.model_name, self.model, self.optimizer):
            torch.save({'model_state_dict': m.state_dict(),
                        'optimizer_state_dict': opt.state_dict()},
                       m_name)

            print(f'Model saved to {m_name}')

    def epoch_end(self, epoch_num, metrics):
        epoch_end_str = f'Epoch {epoch_num}/{self.total_epoch_num} - '
        for name, value in metrics.items():
            epoch_end_str += f'{name}: {round(value, 3)} '

        print(epoch_end_str)

        if self.monitor is None:
            self.__save_model()

        elif metrics[self.monitor] < self.monitor_value:
            print(f'{self.monitor} decreased from {round(self.monitor_value, 4)} to {round(metrics[self.monitor], 4)}')

            self.monitor_value = metrics[self.monitor]
            self.__save_model()
        else:
            print(f'{self.monitor} did not decrease from {round(self.monitor_value, 4)}, model was not saved!')
--------------------------------------------------------------------------------
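A minimal sketch of `EpochCallback` with a toy model: it checkpoints whenever the monitored metric decreases. The metric values and file name below are made up for illustration:

```python
import torch
from utils.utils import EpochCallback

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
cb = EpochCallback('toy_model.pt', total_epoch_num=2,
                   model=model, optimizer=opt, monitor='val_loss')

cb.epoch_end(1, {'loss': 0.9, 'val_loss': 0.8})  # saves: 0.8 < inf
cb.epoch_end(2, {'loss': 0.7, 'val_loss': 0.9})  # no save: 0.9 > 0.8
```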