├── metrics ├── .gitignore ├── configs │ └── salad+cookgan.yaml ├── utils_metrics.py ├── medR.py ├── calc_inception.py ├── fid.py ├── datasets_inception.py └── inception.py ├── cookgan ├── .gitignore ├── run.sh ├── utils_cookgan.py ├── args_cookgan.py ├── generate_batch.py ├── datasets_cookgan.py ├── eval_cookgan.py ├── datasets_cookgan1.py ├── train_cookgan.py └── models_cookgan.py ├── retrieval_model ├── modules │ ├── .gitignore │ ├── __init__.py │ ├── l2net.py │ ├── hardnet_loss.py │ └── dynamic_soft_margin.py ├── .gitignore ├── pretrain_upmc │ ├── .gitignore │ ├── README.md │ ├── utils_upmc.py │ ├── dataset_upmc.py │ └── train_upmc.py ├── run_retrieval.sh ├── train_word2vec.py ├── args_retrieval.py ├── utils_retrieval.py ├── triplet_loss.py ├── datasets_retrieval.py ├── val_retrieval.py ├── explore_attention.py ├── train_retrieval.py ├── eval_ingr_retrieval.py ├── models_retrieval.py └── models_retrieval.py.bak ├── .gitignore ├── requirements.txt ├── LICENSE ├── clean_recipes_with_canonical_ingrs.py ├── common.py └── README.md /metrics/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pkl -------------------------------------------------------------------------------- /cookgan/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | metrics/ 3 | wandb/ 4 | *.pkl -------------------------------------------------------------------------------- /retrieval_model/modules/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | metrics/ 3 | wandb/ -------------------------------------------------------------------------------- /retrieval_model/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | metrics/ 3 | wandb/ 4 | models/ -------------------------------------------------------------------------------- /retrieval_model/pretrain_upmc/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | metrics/ 3 | wandb/ 4 | UPMC-Food-101 5 | models/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .ipynb_checkpoints/ 3 | .DS_Store 4 | __pycache__/ 5 | 6 | explore_data.ipynb 7 | data/ -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.6.0 2 | torchvision==0.7.0 3 | numpy==1.18.5 4 | tqdm==4.47.0 5 | Pillow==7.2.0 6 | wandb==0.10.11 7 | gensim==3.8.3 -------------------------------------------------------------------------------- /retrieval_model/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .dynamic_soft_margin import DynamicSoftMarginLoss 2 | from .hardnet_loss import HardNetLoss 3 | from .l2net import L2Net 4 | -------------------------------------------------------------------------------- /retrieval_model/run_retrieval.sh: -------------------------------------------------------------------------------- 1 | python train_retrieval.py \ 2 | --recipe_file='/common/home/fh199/CookGAN/data/Recipe1M/original_withImage.json' \ 3 | --img_dir='/common/home/fh199/CookGAN/data/Recipe1M/images/' \ 4 | 
--word2vec_file='/common/home/fh199/CookGAN/retrieval_model/models/word2vec_recipes.bin' \ 5 | --text_info='010' \ 6 | --with_attention=2 \ 7 | --loss_type='hardmining+hinge' \ 8 | --batch_size=64 -------------------------------------------------------------------------------- /cookgan/run.sh: -------------------------------------------------------------------------------- 1 | python train_cookgan.py \ 2 | --recipe_file='/common/home/fh199/CookGAN/data/Recipe1M/recipes_withImage.json' \ 3 | --img_dir='/common/home/fh199/CookGAN/data/Recipe1M/images/' \ 4 | --retrieval_model='/common/home/fh199/CookGAN/retrieval_model/wandb/run-20201204_174135-6w1fft7l/files/00000000.ckpt' \ 5 | --levels=3 \ 6 | --food_type='salad' \ 7 | --base_size=64 \ 8 | --batch_size=16 \ 9 | --workers=16 -------------------------------------------------------------------------------- /metrics/configs/salad+cookgan.yaml: -------------------------------------------------------------------------------- 1 | retrieval_model: /common/home/fh199/CookGAN/retrieval_model/wandb/run-20201204_174135-6w1fft7l/files/00000000.ckpt 2 | ckpt_path: /common/home/fh199/CookGAN/cookgan/wandb/run-20201208_132237-179ysfpv/files/000000.ckpt 3 | ckpt_dir: /common/home/fh199/CookGAN/cookgan/wandb/run-20201208_132237-179ysfpv/files/ 4 | batch_size: 64 5 | n_sample: 5000 6 | inception: inception_Recipe1M_salad.pkl 7 | device: cuda -------------------------------------------------------------------------------- /retrieval_model/pretrain_upmc/README.md: -------------------------------------------------------------------------------- 1 | ## Prepare dataset 2 | Download UPMC-Food-101 from http://visiir.lip6.fr/ 3 | 4 | > 2021-06-16: UPMC-Food-101 has since been updated to a newer version; the version we used is [HERE](https://drive.google.com/drive/folders/1URwnLMVKx3avmUI0ITjxzgjFkpvmiS3Q?usp=sharing). 5 | 6 | ## Train 7 | ``` 8 | CUDA_VISIBLE_DEVICES=0 python train_upmc.py --data_dir=your/data_dir --batch_size=128 9 | ``` 10 | 11 | ## Note 12 | The code was tested on a single Tesla K80 with 12 GB of memory (batch_size=128 uses roughly the full 12 GB).
You can also run it on multiple GPUs: 13 | ``` 14 | CUDA_VISIBLE_DEVICES=0,1 python train_upmc.py --data_dir=your/data_dir --batch_size=256 15 | ``` 16 | 17 | -------------------------------------------------------------------------------- /retrieval_model/pretrain_upmc/utils_upmc.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def get_classes(data_dir): 4 | with open(os.path.join(data_dir, 'meta/classes.txt')) as f: 5 | lines = f.readlines() 6 | return [line.rstrip() for line in lines] 7 | 8 | def gen_filelist(data_dir, part): 9 | assert part in ('train', 'test'), 'part should be train|test' 10 | classes = get_classes(data_dir) 11 | table = {c:i for i,c in enumerate(classes)} 12 | with open(os.path.join(data_dir, 'meta/'+part+'.txt'), 'r') as f: 13 | lines = f.readlines() 14 | output = '' 15 | for line in lines: 16 | c = line.split('/')[0] 17 | temp = line.rstrip() + '.jpg ' + str(table[c]) + '\n' 18 | output += temp 19 | with open(os.path.join(data_dir, part+'.txt'), 'w') as f1: 20 | f1.write(output) 21 | 22 | if __name__ == '__main__': 23 | for part in ('train', 'test'): 24 | gen_filelist('UPMC-Food-101', part) -------------------------------------------------------------------------------- /metrics/utils_metrics.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from types import SimpleNamespace 3 | import argparse 4 | from torch.nn import functional as F 5 | import sys 6 | sys.path.append('../') 7 | import common 8 | 9 | def load_args(): 10 | parser = argparse.ArgumentParser(description='Calculate Inception v3 features for datasets') 11 | parser.add_argument('--config', type=str, default=f'{common.root}/metrics/configs/salad+cookgan.yaml') 12 | args = parser.parse_args() 13 | with open(args.config) as f: 14 | data = yaml.load(f, Loader=yaml.FullLoader) 15 | args = SimpleNamespace(**data) 16 | return args 17 | 18 | 19 | def normalize(img): 20 | img = (img-img.min())/(img.max()-img.min()) 21 | means = [0.485, 0.456, 0.406] 22 | stds = [0.229, 0.224, 0.225] 23 | for i in range(3): 24 | img[:,i] = (img[:,i]-means[i])/stds[i] 25 | return img 26 | 27 | def resize(img, size=224): 28 | return F.interpolate(img, size=(size, size), mode='bilinear', align_corners=False) 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Fangda Han 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /retrieval_model/pretrain_upmc/dataset_upmc.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import os 4 | import os.path 5 | 6 | def default_loader(path): 7 | return Image.open(path).convert('RGB') 8 | 9 | def default_flist_reader(flist): 10 | """ 11 | flist format: impath label\nimpath label\n ...(same to caffe's filelist) 12 | """ 13 | imlist = [] 14 | classes = set() 15 | with open(flist, 'r') as rf: 16 | for line in rf.readlines(): 17 | impath, imlabel = line.strip().split() 18 | imlist.append( (impath, int(imlabel)) ) 19 | classes.add(int(imlabel)) 20 | return imlist, classes 21 | 22 | class Dataset(data.Dataset): 23 | def __init__(self, root, flist, transform=None, target_transform=None, 24 | flist_reader=default_flist_reader, loader=default_loader): 25 | self.root = root 26 | self.imlist, self.classes = flist_reader(flist) 27 | self.transform = transform 28 | self.target_transform = target_transform 29 | self.loader = loader 30 | 31 | def __getitem__(self, index): 32 | impath, target = self.imlist[index] 33 | img = self.loader(os.path.join(self.root,impath)) 34 | if self.transform is not None: 35 | img = self.transform(img) 36 | if self.target_transform is not None: 37 | target = self.target_transform(target) 38 | 39 | return img, target 40 | 41 | def __len__(self): 42 | return len(self.imlist) -------------------------------------------------------------------------------- /cookgan/utils_cookgan.py: -------------------------------------------------------------------------------- 1 | from types import SimpleNamespace 2 | import json 3 | import sys 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | import torchvision.utils as vutils 8 | 9 | def prepare_data(data, device): 10 | txt, imgs, w_imgs, _ = data 11 | real_vimgs, wrong_vimgs = [], [] 12 | for i in range(len(imgs)): 13 | real_vimgs.append(imgs[i].to(device)) 14 | wrong_vimgs.append(w_imgs[i].to(device)) 15 | vtxt = [x.to(device) for x in txt] 16 | return vtxt, real_vimgs, wrong_vimgs 17 | 18 | def compute_txt_feat(txt, txt_encoder): 19 | txt_feat, _ = txt_encoder(*txt) 20 | return txt_feat 21 | 22 | def compute_img_feat(img, img_encoder): 23 | mean = [0.485, 0.456, 0.406] 24 | std = [0.229, 0.224, 0.225] 25 | img = img/2 + 0.5 26 | img = F.interpolate(img, [224, 224], mode='bilinear', align_corners=True) 27 | for i in range(img.shape[1]): 28 | img[:,i] = (img[:,i]-mean[i])/std[i] 29 | feat = img_encoder(img) 30 | return feat 31 | 32 | def save_img_results(real_imgs, fake_imgs, save_dir, epoch, level=-1): 33 | num = 64 34 | real_img = real_imgs[level][0:num] 35 | fake_img = fake_imgs[level][0:num] 36 | real_fake = torch.stack([real_img, fake_img]).permute(1,0,2,3,4).contiguous() 37 | real_fake = real_fake.view(-1, real_fake.shape[-3], real_fake.shape[-2], real_fake.shape[-1]) 38 | vutils.save_image( 39 | real_fake, 40 | '{}/e{}_real_fake.png'.format(save_dir, epoch), 41 | normalize=True, scale_each=True) 42 | real_fake = vutils.make_grid(real_fake, normalize=True, scale_each=True) 43 | vutils.save_image( 44 | fake_img, 45 | 
'{}/e{}_fake_samples.png'.format(save_dir, epoch), 46 | normalize=True, scale_each=True) 47 | return real_fake -------------------------------------------------------------------------------- /retrieval_model/train_word2vec.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | from common import load_recipes 4 | from tqdm import tqdm 5 | import os 6 | from gensim.models import Word2Vec 7 | from gensim.models.callbacks import CallbackAny2Vec 8 | import argparse 9 | import pdb 10 | 11 | parser = argparse.ArgumentParser(description='train word2vec model') 12 | parser.add_argument('--recipe_file', default='../data/Recipe1M/recipes.json') 13 | args = parser.parse_args() 14 | 15 | print('Load documents...') 16 | file_path = args.recipe_file 17 | recipes = load_recipes(file_path, 'train') 18 | 19 | print('Tokenize...') 20 | all_sentences = [] 21 | for entry in tqdm(recipes): 22 | all_sentences.append(entry['title'].split()) 23 | insts = entry['instructions'] 24 | sentences = [inst.split() for inst in insts] 25 | all_sentences.extend(sentences) 26 | all_sentences.append(entry['ingredients']) 27 | print('number of sentences =', len(all_sentences)) 28 | 29 | print('Train Word2Vec model...') 30 | class EpochLogger(CallbackAny2Vec): 31 | '''Callback to log information about training''' 32 | def __init__(self): 33 | self.epoch = 0 34 | def on_epoch_begin(self, model): 35 | print('-' * 40) 36 | print("Epoch #{} start".format(self.epoch)) 37 | print('vocab_size = {}'.format(len(model.wv.index2word))) 38 | def on_epoch_end(self, model): 39 | print('total_train_time = {:.2f} s'.format(model.total_train_time)) 40 | print('loss = {:.2f}'.format(model.get_latest_training_loss())) 41 | print("Epoch #{} end".format(self.epoch)) 42 | self.epoch += 1 43 | 44 | epoch_logger = EpochLogger() 45 | model = Word2Vec( 46 | all_sentences, size=300, window=10, min_count=10, 47 | workers=20, iter=10, callbacks=[epoch_logger], 48 | compute_loss=True) 49 | 50 | suffix = os.path.basename(file_path).split('.')[0] 51 | if not os.path.exists('models'): 52 | os.makedirs('models') 53 | 54 | model.wv.save(os.path.join('models/word2vec_{}.bin'.format(suffix))) 55 | -------------------------------------------------------------------------------- /cookgan/args_cookgan.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | sys.path.append('../') 4 | from common import root 5 | 6 | def get_parser(): 7 | parser = argparse.ArgumentParser(description='Train a GAN network') 8 | 9 | parser.add_argument('--seed', type=int, default=8) 10 | parser.add_argument("--device", type=str, default='cuda', choices=['cuda', 'cpu']) 11 | parser.add_argument('--workers', type=int, default=16) 12 | parser.add_argument('--num_batches', type=int, default=200_000) 13 | 14 | parser.add_argument('--batch_size', type=int, default=24) 15 | parser.add_argument('--base_size', type=int, default=64) 16 | 17 | parser.add_argument('--input_dim', type=int, default=1024) 18 | parser.add_argument('--embedding_dim', type=int, default=128) 19 | parser.add_argument('--z_dim', type=int, default=100) 20 | 21 | parser.add_argument('--labels', type=str, default='original', choices=['original', 'R-smooth', 'R-flip', 'R-flip-smooth']) 22 | parser.add_argument("--input_noise", type=int, default=0) 23 | parser.add_argument('--uncond', type=float, default=1.0) 24 | parser.add_argument('--cycle_txt', type=float, default=0.0) 25 | 
parser.add_argument('--cycle_img', type=float, default=0.0) 26 | # parser.add_argument('--tri_loss', type=float, default=0.0) 27 | parser.add_argument('--kl', type=float, default=2.0) 28 | 29 | parser.add_argument('--lr_g', type=float, default=2e-4) 30 | parser.add_argument('--lr_d', type=float, default=2e-4) 31 | 32 | 33 | parser.add_argument('--ckpt_path', type=str, default='') 34 | parser.add_argument('--food_type', type=str, default='salad') 35 | 36 | parser.add_argument('--recipe_file', type=str, default=f'{root}/data/Recipe1M/recipes_withImage.json') 37 | parser.add_argument('--img_dir', type=str, default=f'{root}/data/Recipe1M/images') 38 | parser.add_argument('--levels', type=int, default=3) 39 | parser.add_argument('--retrieval_model', type=str, default=f'{root}/retrieval_model/wandb/run-20201204_174135-6w1fft7l/files/00000000.ckpt') 40 | parser.add_argument('--word2vec_file', type=str, default=f'{root}/retrieval_model/models/word2vec_recipes.bin') 41 | 42 | parser.add_argument("--debug", type=int, default=0) 43 | 44 | # These are only for test_StackGAN.py 45 | parser.add_argument('--level', type=int, default=2) 46 | 47 | # These are only for interpolate.py 48 | parser.add_argument('--key_ingr', type=str, default='tomato') 49 | return parser -------------------------------------------------------------------------------- /retrieval_model/modules/l2net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class L2Net(nn.Module): 7 | """ 8 | L2-Net: Deep Learning of Discriminative Patch Descriptor in Euclidean Space 9 | """ 10 | 11 | def __init__(self, out_dim=128, binary=False, dropout_rate=0.1): 12 | super().__init__() 13 | self._binary = binary 14 | 15 | self.features = nn.Sequential( 16 | nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False), 17 | nn.BatchNorm2d(32, affine=False), 18 | nn.ReLU(), 19 | nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False), 20 | nn.BatchNorm2d(32, affine=False), 21 | nn.ReLU(), 22 | nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False), 23 | nn.BatchNorm2d(64, affine=False), 24 | nn.ReLU(), 25 | nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False), 26 | nn.BatchNorm2d(64, affine=False), 27 | nn.ReLU(), 28 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False), 29 | nn.BatchNorm2d(128, affine=False), 30 | nn.ReLU(), 31 | nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False), 32 | nn.BatchNorm2d(128, affine=False), 33 | nn.ReLU(), 34 | nn.Dropout(dropout_rate), 35 | nn.Conv2d(128, out_dim, kernel_size=8, bias=False), 36 | nn.BatchNorm2d(out_dim, affine=False), 37 | ) 38 | 39 | if self._binary: 40 | self.binarizer = nn.Tanh() 41 | 42 | self.features.apply(weights_init) 43 | 44 | def input_norm(self, x): 45 | flat = x.view(x.size(0), -1) 46 | mp = torch.mean(flat, dim=1) 47 | sp = torch.std(flat, dim=1) + 1e-7 48 | return ( 49 | x - mp.detach().unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(x) 50 | ) / sp.detach().unsqueeze(-1).unsqueeze(-1).unsqueeze(1).expand_as(x) 51 | 52 | def forward(self, input): 53 | input = self.input_norm(input) 54 | x = self.features(input) 55 | x = x.view(x.size(0), -1) 56 | if self._binary: 57 | return self.binarizer(x) 58 | else: 59 | return F.normalize(x, p=2, dim=1) 60 | 61 | 62 | def weights_init(m): 63 | if isinstance(m, nn.Conv2d): 64 | nn.init.orthogonal_(m.weight.data, gain=0.6) 65 | try: 66 | nn.init.constant_(m.bias.data, 0.01) 67 | except: 68 | pass 69 
| return 70 | -------------------------------------------------------------------------------- /retrieval_model/args_retrieval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | sys.path.append('../') 4 | from common import root 5 | 6 | def get_parser(): 7 | parser = argparse.ArgumentParser(description='retrieval model parameters') 8 | parser.add_argument('--seed', default=8, type=int) 9 | parser.add_argument('--workers', default=16, type=int) 10 | parser.add_argument("--device", type=str, default='cuda', choices=['cuda', 'cpu']) 11 | parser.add_argument('--word2vec_dim', default=300, type=int) 12 | parser.add_argument('--rnn_hid_dim', default=300, type=int) 13 | parser.add_argument('--feature_dim', default=1024, type=int) 14 | parser.add_argument('--batch_size', default=64, type=int) 15 | parser.add_argument('--batches', default=400_000, type=int) 16 | parser.add_argument('--lr', default=1e-4, type=float) 17 | parser.add_argument('--margin', default=0.3, type=float) 18 | parser.add_argument('--classes_file', default=f'{root}/data/Recipe1M/classes1M.pkl') 19 | parser.add_argument('--img_dir', default=f'{root}/data/Recipe1M/images') 20 | 21 | parser.add_argument('--retrieved_type', default='recipe', choices=['recipe', 'image']) 22 | parser.add_argument('--retrieved_range', default=1000, type=int) 23 | parser.add_argument('--val_freq', default=1, type=int) 24 | parser.add_argument('--save_freq', default=1, type=int) 25 | parser.add_argument('--ckpt_path', default='') 26 | 27 | parser.add_argument('--loss_type', default='hardmining+hinge', choices=['hinge', 'hardmining+hinge', 'dynamic_soft_margin']) 28 | # TODO: train on other modalities 29 | parser.add_argument('--text_info', default='010', choices=['111', '010', '100', '001'], 30 | help='3-bit to represent [title, ingredients, instructions]') 31 | parser.add_argument('--word2vec_file', default=f'{root}/retrieval_model/models/word2vec_recipes.bin') 32 | parser.add_argument('--recipe_file', default=f'{root}/data/Recipe1M/recipes_withImage.json') 33 | parser.add_argument('--ingrs_enc_type', default='rnn', choices=['rnn', 'fc']) 34 | # upmc 35 | parser.add_argument('--upmc_model', default=f'') 36 | # permute ingredients 37 | parser.add_argument("--permute_ingrs", type=int, default=0, choices=[0,1], help="permute ingredients") 38 | # self attention on text 39 | parser.add_argument("--with_attention", type=int, default=2, choices=[0,1,2]) 40 | 41 | # in debug mode 42 | parser.add_argument("--debug", type=int, default=0, choices=[0,1], help="in debug mode or not") 43 | 44 | # val_retrieval.py 45 | parser.add_argument('--ckpt_dir', default='') 46 | 47 | # These are only for predict_key_ingredients.py 48 | parser.add_argument('--food_type', type=str, default='salad') 49 | parser.add_argument('--key_ingr', type=str, default='tomato') 50 | return parser -------------------------------------------------------------------------------- /retrieval_model/utils_retrieval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import json 5 | from matplotlib import pyplot as plt 6 | 7 | def compute_statistics(rcps, imgs, retrieved_type='recipe', retrieved_range=1000, draw_hist=False, verbose=True): 8 | if verbose: 9 | print('retrieved_range =', retrieved_range) 10 | N = retrieved_range 11 | data_size = imgs.shape[0] 12 | glob_medR = [] 13 | glob_recall = {1:0.0, 5:0.0, 10:0.0} 14 | if draw_hist: 15 | 
plt.figure(figsize=(16, 6)) 16 | # average over 10 sets 17 | for i in range(10): 18 | ids_sub = np.random.choice(data_size, N, replace=False) 19 | imgs_sub = imgs[ids_sub, :] 20 | rcps_sub = rcps[ids_sub, :] 21 | imgs_sub = imgs_sub / np.linalg.norm(imgs_sub, axis=1)[:, None] 22 | rcps_sub = rcps_sub / np.linalg.norm(rcps_sub, axis=1)[:, None] 23 | if retrieved_type == 'recipe': 24 | queries = imgs_sub 25 | values = rcps_sub 26 | else: 27 | queries = rcps_sub 28 | values = imgs_sub 29 | ranks = compute_ranks(queries, values) 30 | recall = {1:0.0, 5:0.0, 10:0.0} 31 | recall[1] = (ranks<=1).sum() / N 32 | recall[5] = (ranks<=5).sum() / N 33 | recall[10] = (ranks<=10).sum() / N 34 | medR = int(np.median(ranks)) 35 | for ii in recall.keys(): 36 | glob_recall[ii] += recall[ii] 37 | glob_medR.append(medR) 38 | if draw_hist: 39 | ranks = np.array(ranks) 40 | plt.subplot(2,5,i+1) 41 | n, bins, patches = plt.hist(x=ranks, bins='auto', color='#0504aa', alpha=0.7, rwidth=0.85) 42 | plt.grid(axis='y', alpha=0.75) 43 | plt.ylim(top=300) 44 | # plt.xlabel('Rank') 45 | # plt.ylabel('Frequency') 46 | # plt.title('Rank Distribution') 47 | plt.text(23, 45, 'avgR(std) = {:.2f}({:.2f})\nmedR={:.2f}\n#<{:d}:{:d}|#={:d}:{:d}|#>{:d}:{:d}'.format( 48 | np.mean(ranks), np.std(ranks), np.median(ranks), 49 | medR,(ranks<medR).sum(),medR,(ranks==medR).sum(),medR,(ranks>medR).sum())) 50 | if draw_hist: 51 | plt.savefig('hist.png') 52 | 53 | for i in glob_recall.keys(): 54 | glob_recall[i] = glob_recall[i]/10 55 | 56 | glob_medR = np.array(glob_medR) 57 | if verbose: 58 | print('MedR = {:.4f}({:.4f})'.format(glob_medR.mean(), glob_medR.std())) 59 | for k,v in glob_recall.items(): 60 | print('Recall@{} = {:.4f}'.format(k, v)) 61 | return glob_medR, glob_recall 62 | 63 | def compute_ranks(queries, values): 64 | """compute the ranks for queries and values 65 | 66 | Arguments: 67 | queries {np.array} -- text feats (or image feats) 68 | values {np.array} -- image feats (or text feats) 69 | """ 70 | sims = np.dot(queries, values.T) 71 | ranks = [] 72 | # loop through the N similarities for images 73 | for ii in range(sims.shape[0]): 74 | # get a column of similarities for image ii 75 | sim = sims[ii,:] 76 | # sort indices in descending order 77 | sorting = np.argsort(sim)[::-1].tolist() 78 | # find where the index of the pair sample ended up in the sorting 79 | pos = sorting.index(ii) 80 | # store the position 81 | ranks.append(pos+1) 82 | return np.array(ranks) -------------------------------------------------------------------------------- /cookgan/generate_batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch._C import device 3 | from torchvision import transforms 4 | from PIL import Image 5 | import os 6 | import lmdb 7 | from torch.utils import data 8 | import json 9 | from io import BytesIO 10 | import numpy as np 11 | 12 | import sys 13 | sys.path.append('../') 14 | import common 15 | sys.path.append('../retrieval_model') 16 | import train_retrieval 17 | sys.path.append('../cookgan') 18 | import train_cookgan 19 | from utils_cookgan import compute_txt_feat 20 | from datasets_cookgan import FoodDataset 21 | 22 | class BatchGenerator(): 23 | def __init__(self, args): 24 | device = args.device 25 | _, _, txt_encoder, _, _ = train_retrieval.load_model(args.retrieval_model, device) 26 | ckpt_args, _, netG, _, _, _ = train_cookgan.load_model(args.ckpt_path, device) 27 | netG = netG.eval().to(device) 28 | 29 | txt_encoder = txt_encoder.eval().to(device) 30 | 31 | imsize = ckpt_args.base_size * (2 **
(ckpt_args.levels-1)) 32 | train_transform = transforms.Compose([ 33 | transforms.Resize(int(imsize * 76 / 64)), 34 | transforms.CenterCrop(imsize)]) 35 | dataset = FoodDataset( 36 | recipe_file=ckpt_args.recipe_file, 37 | img_dir=ckpt_args.img_dir, 38 | levels=ckpt_args.levels, 39 | part='val', 40 | food_type=ckpt_args.food_type, 41 | base_size=ckpt_args.base_size, 42 | transform=train_transform) 43 | dataloader = torch.utils.data.DataLoader( 44 | dataset, batch_size=args.batch_size, num_workers=4) 45 | 46 | self.ckpt_args = ckpt_args 47 | self.netG = netG 48 | self.txt_encoder = txt_encoder 49 | self.dataloader = dataloader 50 | self.batch_size = args.batch_size 51 | self.device = device 52 | self.fixed_noise = torch.randn(self.batch_size, self.ckpt_args.z_dim).to(self.device) 53 | 54 | def generate_fid(self): 55 | _, _, batch_fake_img = self.generate_all() 56 | return batch_fake_img 57 | 58 | def generate_MedR(self): 59 | batch_txt, _, batch_fake_img = self.generate_all() 60 | return batch_txt, batch_fake_img 61 | 62 | def generate_all(self): 63 | batch_txt, batch_imgs, _, _ = next(common.sample_data(self.dataloader)) 64 | with torch.no_grad(): 65 | txt_feat = compute_txt_feat(batch_txt, self.txt_encoder) 66 | fakes, _, _ = self.netG(self.fixed_noise, txt_feat) 67 | return batch_txt, batch_imgs[-1], fakes[-1] 68 | # batch_txt: CookGAN txt input 69 | # batch_img: [BS, 3, size, size] 70 | # batch_fake_img: [BS, 3, size, size] 71 | 72 | 73 | if __name__ == '__main__': 74 | import pdb 75 | from types import SimpleNamespace 76 | args = SimpleNamespace( 77 | ckpt_path=f'{common.root}/cookgan/wandb/run-20201208_132237-179ysfpv/files/000000.ckpt', 78 | retrieval_model=f'{common.root}/retrieval_model/wandb/run-20201204_174135-6w1fft7l/files/00000000.ckpt', 79 | batch_size=16, 80 | size=256, 81 | device='cuda', 82 | ) 83 | 84 | # Recipe1M 85 | batch_generator = BatchGenerator(args) 86 | 87 | txt, img, fake_img = batch_generator.generate_all() 88 | print(len(txt)) 89 | for t in txt: 90 | print(t.shape) 91 | print(img.shape) 92 | print(fake_img.shape) 93 | -------------------------------------------------------------------------------- /retrieval_model/modules/hardnet_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def compute_distance_matrix_unit_l2(a, b, eps=1e-6): 6 | """ 7 | computes pairwise Euclidean distance and return a N x N matrix 8 | """ 9 | 10 | dmat = torch.matmul(a, torch.transpose(b, 0, 1)) 11 | dmat = ((1.0 - dmat + eps) * 2.0).pow(0.5) 12 | return dmat 13 | 14 | 15 | def compute_distance_matrix_hamming(a, b): 16 | """ 17 | computes pairwise Hamming distance and return a N x N matrix 18 | """ 19 | 20 | dims = a.size(1) 21 | dmat = torch.matmul(a, torch.transpose(b, 0, 1)) 22 | dmat = (dims - dmat) * 0.5 23 | return dmat 24 | 25 | 26 | def find_hard_negatives(dmat, output_index=True, empirical_thresh=0.0): 27 | """ 28 | a = A * P' 29 | A: N * ndim 30 | P: N * ndim 31 | 32 | a1p1 a1p2 a1p3 a1p4 ... 33 | a2p1 a2p2 a2p3 a2p4 ... 34 | a3p1 a3p2 a3p3 a3p4 ... 35 | a4p1 a4p2 a4p3 a4p4 ... 36 | ... ... ... ... 
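The diagonal entries a1p1, a2p2, ... are the distances of the matched (anchor, positive) pairs; after the diagonal is masked with a large constant, the row/column minima give the hardest negative for each anchor.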
37 | """ 38 | 39 | cnt = dmat.size(0) 40 | 41 | if not output_index: 42 | pos = dmat.diag() 43 | 44 | dmat = dmat + torch.eye(cnt).to(dmat.device) * 99999 # filter diagonal 45 | # dmat[dmat < empirical_thresh] = 99999 # filter outliers in brown dataset 46 | min_a, min_a_idx = torch.min(dmat, dim=0) 47 | min_p, min_p_idx = torch.min(dmat, dim=1) 48 | 49 | if not output_index: 50 | neg = torch.min(min_a, min_p) 51 | # import pdb; pdb.set_trace() 52 | return pos, neg 53 | 54 | mask = min_a < min_p 55 | a_idx = torch.cat( 56 | (mask.nonzero().view(-1) + cnt, (~mask).nonzero().view(-1)) 57 | ) # use p as anchor 58 | p_idx = torch.cat( 59 | (mask.nonzero().view(-1), (~mask).nonzero().view(-1) + cnt) 60 | ) # use a as anchor 61 | n_idx = torch.cat((min_a_idx[mask], min_p_idx[~mask] + cnt)) 62 | return a_idx, p_idx, n_idx 63 | 64 | 65 | def approx_hamming_distance(a, p): 66 | return (1.0 - a * p).sum(dim=1) * 0.5 67 | 68 | 69 | class HardNetLoss(nn.Module): 70 | def __init__(self, margin, is_binary): 71 | super().__init__() 72 | self._margin = margin 73 | self._is_binary = is_binary 74 | 75 | def _forward_float(self, x): 76 | cnt = x.size(0) // 2 77 | a = x[:cnt, :] 78 | p = x[cnt:, :] 79 | 80 | dmat = compute_distance_matrix_unit_l2(a, p) 81 | pos, neg = find_hard_negatives(dmat, output_index=False, empirical_thresh=0.008) 82 | return (self._margin - neg + pos).clamp(0).mean() 83 | 84 | def _forward_binary(self, x): 85 | cnt = x.size(0) // 2 86 | ndim = x.size(1) 87 | a = x[:cnt, :] 88 | p = x[cnt:, :] 89 | 90 | dmat = compute_distance_matrix_hamming( 91 | (a > 0).float() * 2.0 - 1.0, (p > 0).float() * 2.0 - 1.0 92 | ) 93 | a_idx, p_idx, n_idx = find_hard_negatives( 94 | dmat, output_index=True, empirical_thresh=2 95 | ) 96 | 97 | a = x[a_idx, :] 98 | p = x[p_idx, :] 99 | n = x[n_idx, :] 100 | 101 | pos_dist = approx_hamming_distance(a, p) 102 | neg_dist = approx_hamming_distance(a, n) 103 | 104 | pos_dist = pos_dist / ndim 105 | neg_dist = neg_dist / ndim 106 | 107 | return (self._margin - neg_dist + pos_dist).clamp(0).mean() 108 | 109 | def forward(self, x): 110 | if self._is_binary: 111 | return self._forward_binary(x) 112 | return self._forward_float(x) 113 | -------------------------------------------------------------------------------- /retrieval_model/triplet_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | from torch import nn 4 | from torch.autograd import Variable 5 | import pdb 6 | 7 | class TripletLoss(object): 8 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). 9 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet 10 | Loss for Person Re-Identification'.""" 11 | def __init__(self, margin=None): 12 | self.margin = margin 13 | if margin is not None: 14 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 15 | else: 16 | self.ranking_loss = nn.SoftMarginLoss() 17 | 18 | def __call__(self, dist_ap, dist_an): 19 | y = torch.ones_like(dist_ap) 20 | if self.margin is not None: 21 | loss = self.ranking_loss(dist_an, dist_ap, y) 22 | else: 23 | loss = self.ranking_loss(dist_an - dist_ap, y) 24 | return loss 25 | 26 | def normalize(x, axis=-1): 27 | 28 | x = 1. 
* x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 29 | return x 30 | 31 | 32 | def euclidean_dist(x, y): 33 | """ 34 | Args: 35 | x: pytorch Variable, with shape [m, d] 36 | y: pytorch Variable, with shape [n, d] 37 | Returns: 38 | dist: pytorch Variable, with shape [m, n] 39 | """ 40 | m, n = x.size(0), y.size(0) 41 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 42 | # import pdb; pdb.set_trace() 43 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 44 | dist = xx + yy 45 | # dist.addmm_(1, -2, x, y.t()) 46 | dist.addmm_(x, y.t(), beta=1, alpha=-2) 47 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 48 | return dist 49 | 50 | 51 | def hard_example_mining(dist_mat, labels, return_inds=False): 52 | 53 | assert len(dist_mat.size()) == 2 54 | assert dist_mat.size(0) == dist_mat.size(1) 55 | N = dist_mat.size(0) 56 | 57 | # shape [N, N] 58 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 59 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 60 | 61 | # `dist_ap` means distance(anchor, positive) 62 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 63 | dist_ap, relative_p_inds = torch.max( 64 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 65 | # `dist_an` means distance(anchor, negative) 66 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 67 | dist_an, relative_n_inds = torch.min( 68 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 69 | # shape [N] 70 | dist_ap = dist_ap.squeeze(1) 71 | dist_an = dist_an.squeeze(1) 72 | # import pdb; pdb.set_trace() 73 | if return_inds: 74 | # shape [N, N] 75 | ind = (labels.new().resize_as_(labels) 76 | .copy_(torch.arange(0, N).long()) 77 | .unsqueeze( 0).expand(N, N)) 78 | # shape [N, 1] 79 | p_inds = torch.gather( 80 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 81 | n_inds = torch.gather( 82 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 83 | # shape [N] 84 | p_inds = p_inds.squeeze(1) 85 | n_inds = n_inds.squeeze(1) 86 | return dist_ap, dist_an, p_inds, n_inds 87 | 88 | return dist_ap, dist_an 89 | 90 | 91 | def global_loss(tri_loss, global_feat, labels, normalize_feature=False): 92 | if normalize_feature: 93 | global_feat = normalize(global_feat, axis=-1) 94 | # shape [N, N] 95 | dist_mat = euclidean_dist(global_feat, global_feat) 96 | # import pdb; pdb.set_trace() 97 | dist_ap, dist_an = hard_example_mining( 98 | dist_mat, labels, return_inds=False) 99 | loss = tri_loss(dist_ap, dist_an) 100 | return loss, dist_ap, dist_an, dist_mat -------------------------------------------------------------------------------- /clean_recipes_with_canonical_ingrs.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import json 4 | import argparse 5 | import pickle 6 | import re 7 | from common import root, tok, remove_numbers 8 | import common 9 | import numpy as np 10 | 11 | parser = argparse.ArgumentParser( 12 | description='clean recipes') 13 | parser.add_argument( 14 | '--data_dir', default=f'{root}/data/Recipe1M', 15 | help='the folder which contains Recipe1M text files') 16 | parser.add_argument("--lower", type=int, default=0, choices=[0,1]) 17 | parser.add_argument("--remove_numbers", type=int, default=0, choices=[0,1]) 18 | args = parser.parse_args() 19 | data_dir = args.data_dir 20 | 21 | print('load recipes (20 seconds)') 22 | recipes_original = common.Layer.merge( 23 | [common.Layer.L1, common.Layer.L2, common.Layer.INGRS], 24 | 
os.path.join(data_dir, 'texts')) 25 | 26 | for rcp in recipes_original: 27 | rcp['instructions'] = [x['text'] for x in rcp['instructions']] 28 | rcp['ingredients'] = [x['text'] for x in rcp['ingredients']] 29 | 30 | with open(f'{root}/manual_files/replacement_dict.pkl', 'rb') as f: 31 | replace_dict = pickle.load(f) 32 | 33 | print('start processing') 34 | cvgs = [] 35 | recipes = [] 36 | recipes_withImage = [] 37 | for rcp in tqdm(recipes_original): 38 | insts = [] 39 | for inst in rcp['instructions']: 40 | # words = tok(inst['text']).split() 41 | words = tok(inst).split() 42 | inst_ = ' '.join(words) 43 | insts.append(inst_) 44 | insts = '\n'.join(insts) 45 | if len(insts) == 0: 46 | continue 47 | 48 | title = rcp['title'] 49 | words = tok(title).split() 50 | title = ' '.join(words) 51 | 52 | if args.lower: 53 | insts = insts.lower() 54 | title = title.lower() 55 | if args.remove_numbers: 56 | insts = remove_numbers(insts) 57 | title = remove_numbers(title) 58 | 59 | ingrs = [] 60 | N = len(rcp['ingredients']) 61 | n = 0 62 | for ingr in rcp['ingredients']: 63 | # 1. add 'space' before and after 12 punctuation 64 | # 2. change 'space' to 'underscore' 65 | # ingr_name = ingr['text'] 66 | ingr_name = ingr 67 | name = re.sub(' +', ' ', tok(ingr_name)).replace(' ', '_') 68 | if name in replace_dict: 69 | final_name = replace_dict[name] 70 | ingrs.append(final_name) 71 | name1 = final_name.replace('_',' ') 72 | if args.lower: 73 | ingr_name = ingr_name.lower() 74 | name1 = name1.lower() 75 | insts = insts.replace(ingr_name, final_name) 76 | insts = insts.replace(name1, final_name) 77 | title = title.replace(ingr_name, final_name) 78 | title = title.replace(name1, final_name) 79 | n += 1 80 | 81 | if n==0: 82 | print('no ingredients, discard') 83 | continue 84 | cvg = n/N 85 | cvgs.append(cvg) 86 | 87 | rcp['title'] = title 88 | rcp['ingredients'] = ingrs 89 | rcp['instructions'] = insts.split('\n') 90 | recipes.append(rcp) 91 | if 'images' in rcp and len(rcp['images'])>0: 92 | recipes_withImage.append(rcp) 93 | 94 | cvgs = np.array(cvgs) 95 | print('cvg = {:.2f} -- {:.2f}'.format(cvgs.mean(), cvgs.std())) 96 | print(len(recipes), len(recipes_withImage)) 97 | 98 | print('saving...') 99 | if args.lower and not args.remove_numbers: 100 | filename = 'recipes_lower' 101 | elif not args.lower and args.remove_numbers: 102 | filename = 'recipes_noNumbers' 103 | elif args.remove_numbers and args.lower: 104 | filename = 'recipes_lower_noNumbers' 105 | else: 106 | filename = 'recipes' 107 | 108 | with open(os.path.join(data_dir, '{}.json'.format(filename)), 'w') as f: 109 | json.dump(recipes, f, indent=2) 110 | 111 | with open(os.path.join(data_dir, '{}_withImage.json'.format(filename)), 'w') as f: 112 | json.dump(recipes_withImage, f, indent=2) -------------------------------------------------------------------------------- /metrics/medR.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torch import nn 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | import pdb 8 | import os 9 | import csv 10 | from glob import glob 11 | import math 12 | from torchvision import transforms 13 | from torch.nn import functional as F 14 | from matplotlib import pyplot as plt 15 | import sys 16 | sys.path.append('../retrieval_model') 17 | from utils_retrieval import compute_statistics 18 | import train_retrieval 19 | sys.path.append('../') 20 | from common import requires_grad 21 | 22 | if __name__ == '__main__': 23 | from utils_metrics 
import load_args, normalize, resize 24 | args = load_args() 25 | 26 | # assertations 27 | assert 'ckpt_dir' in args.__dict__ 28 | assert 'retrieval_model' in args.__dict__ 29 | assert 'device' in args.__dict__ 30 | assert 'batch_size' in args.__dict__ 31 | 32 | sys.path.append('../cookgan/') 33 | from generate_batch import BatchGenerator 34 | 35 | device = args.device 36 | _, _, txt_encoder, img_encoder, _ = train_retrieval.load_model(args.retrieval_model, device) 37 | requires_grad(txt_encoder, False) 38 | requires_grad(img_encoder, False) 39 | txt_encoder = txt_encoder.eval() 40 | img_encoder = img_encoder.eval() 41 | 42 | filename = os.path.join(args.ckpt_dir, 'medR.csv') 43 | 44 | # load values that are already computed 45 | computed = [] 46 | if os.path.exists(filename): 47 | with open(filename, 'r') as f: 48 | reader = csv.reader(f, delimiter=',') 49 | for row in reader: 50 | computed += [row[0]] 51 | 52 | # prepare to write 53 | f = open(filename, mode='a') 54 | writer = csv.writer(f, delimiter=',') 55 | 56 | # find checkpoints 57 | ckpt_paths = glob(os.path.join(args.ckpt_dir, '*.ckpt')) + glob(os.path.join(args.ckpt_dir, '*.pt'))+glob(os.path.join(args.ckpt_dir, '*.pth')) 58 | ckpt_paths = sorted(ckpt_paths) 59 | print('records:', ckpt_paths) 60 | print('computed:', computed) 61 | for ckpt_path in ckpt_paths: 62 | print() 63 | print(f'working on {ckpt_path}') 64 | iteration = os.path.basename(ckpt_path).split('.')[0] 65 | if iteration in computed: 66 | print('already computed') 67 | continue 68 | 69 | print('==> computing MedR') 70 | args.ckpt_path = ckpt_path 71 | batch_generator = BatchGenerator(args) 72 | 73 | txt_outputs = [] 74 | img_outputs = [] 75 | with torch.no_grad(): 76 | for _ in tqdm(range(1000//args.batch_size+1)): 77 | # generate 78 | txt, fake_img = batch_generator.generate_MedR() 79 | # fake_img: normalize 80 | fake_img = normalize(fake_img) 81 | # fake_img: resize 82 | fake_img = resize(fake_img, size=224) 83 | # retrieve 84 | txt_output, _ = txt_encoder(*txt) 85 | img_output = img_encoder(fake_img) 86 | txt_outputs.append(txt_output.detach().cpu()) 87 | img_outputs.append(img_output.detach().cpu()) 88 | txt_outputs = torch.cat(txt_outputs, dim=0).numpy() 89 | img_outputs = torch.cat(img_outputs, dim=0).numpy() 90 | retrieved_range = min(txt_outputs.shape[0], 1000) 91 | medR, recalls = compute_statistics( 92 | txt_outputs, img_outputs, retrieved_type='image', 93 | retrieved_range=retrieved_range, verbose=True) 94 | print(f'{iteration}, MedR={medR.mean()}') 95 | writer.writerow ([iteration, medR.mean()]) 96 | 97 | f.close() 98 | medRs = [] 99 | with open(filename, 'r') as f: 100 | reader = csv.reader(f, delimiter=',') 101 | for row in reader: 102 | medR = float(row[1]) 103 | medRs += [medR] 104 | fig = plt.figure(figsize=(6,6)) 105 | plt.plot(medRs) 106 | plt.savefig(os.path.join(args.ckpt_dir, 'medR.png')) -------------------------------------------------------------------------------- /metrics/calc_inception.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import os 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | from torchvision.models import inception_v3, Inception3 11 | import numpy as np 12 | from tqdm import tqdm 13 | 14 | from inception import InceptionV3 15 | 16 | import sys 17 | sys.path.append('../') 18 | import common 19 | 
sys.path.append('../retrieval_model') 20 | import train_retrieval 21 | sys.path.append('../cookgan') 22 | import train_cookgan 23 | from utils_cookgan import compute_txt_feat 24 | from datasets_cookgan import FoodDataset 25 | 26 | 27 | class Inception3Feature(Inception3): 28 | def forward(self, x): 29 | if x.shape[2] != 299 or x.shape[3] != 299: 30 | x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True) 31 | 32 | x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3 33 | x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32 34 | x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32 35 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64 36 | 37 | x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64 38 | x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80 39 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192 40 | 41 | x = self.Mixed_5b(x) # 35 x 35 x 192 42 | x = self.Mixed_5c(x) # 35 x 35 x 256 43 | x = self.Mixed_5d(x) # 35 x 35 x 288 44 | 45 | x = self.Mixed_6a(x) # 35 x 35 x 288 46 | x = self.Mixed_6b(x) # 17 x 17 x 768 47 | x = self.Mixed_6c(x) # 17 x 17 x 768 48 | x = self.Mixed_6d(x) # 17 x 17 x 768 49 | x = self.Mixed_6e(x) # 17 x 17 x 768 50 | 51 | x = self.Mixed_7a(x) # 17 x 17 x 768 52 | x = self.Mixed_7b(x) # 8 x 8 x 1280 53 | x = self.Mixed_7c(x) # 8 x 8 x 2048 54 | 55 | x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048 56 | 57 | return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048 58 | 59 | 60 | def load_patched_inception_v3(): 61 | # inception = inception_v3(pretrained=True) 62 | # inception_feat = Inception3Feature() 63 | # inception_feat.load_state_dict(inception.state_dict()) 64 | inception_feat = InceptionV3([3], normalize_input=False) 65 | 66 | return inception_feat 67 | 68 | 69 | @torch.no_grad() 70 | def extract_features(loader, inception, device): 71 | pbar = tqdm(loader) 72 | 73 | feature_list = [] 74 | 75 | for _, imgs, _, _ in pbar: 76 | img = imgs[-1].to(device) 77 | feature = inception(img)[0].view(img.shape[0], -1) 78 | feature_list.append(feature.to('cpu')) 79 | 80 | features = torch.cat(feature_list, 0) 81 | 82 | return features 83 | 84 | 85 | if __name__ == '__main__': 86 | from utils_metrics import load_args 87 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 88 | args = load_args() 89 | 90 | _, _, txt_encoder, _, _ = train_retrieval.load_model(args.retrieval_model, device) 91 | txt_encoder = txt_encoder.eval().to(device) 92 | ckpt_args, _, netG, _, _, _ = train_cookgan.load_model(args.ckpt_path, device) 93 | netG = netG.eval().to(device) 94 | 95 | inception = load_patched_inception_v3() 96 | inception = nn.DataParallel(inception).eval().to(device) 97 | 98 | imsize = ckpt_args.base_size * (2 ** (ckpt_args.levels-1)) 99 | train_transform = transforms.Compose([ 100 | transforms.Resize(int(imsize * 76 / 64)), 101 | transforms.CenterCrop(imsize)]) 102 | dataset = FoodDataset( 103 | recipe_file=ckpt_args.recipe_file, 104 | img_dir=ckpt_args.img_dir, 105 | levels=ckpt_args.levels, 106 | part='val', 107 | food_type=ckpt_args.food_type, 108 | base_size=ckpt_args.base_size, 109 | transform=train_transform) 110 | 111 | dataset_name = 'Recipe1M' 112 | if ckpt_args.food_type: 113 | dataset_name += f'_{ckpt_args.food_type}' 114 | 115 | loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=4) 116 | 117 | features = extract_features(loader, inception, device).numpy() 118 | 119 | features = features[: args.n_sample] 120 | 121 | print(f'extracted {features.shape[0]} features') 122 | 123 | mean = np.mean(features, 0) 124 | cov = np.cov(features, rowvar=False) 125 | 
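# Cache the real-image Inception statistics computed above; fid.py later loads this pickle (the `inception` entry in the metrics config, e.g. inception_Recipe1M_salad.pkl) as the reference mean/covariance when computing FID for generated samples.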
126 | with open(f'inception_{dataset_name}.pkl', 'wb') as f: 127 | pickle.dump({'mean': mean, 'cov': cov, 'dataset_name': dataset_name}, f) -------------------------------------------------------------------------------- /metrics/fid.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | from torch import nn 5 | import numpy as np 6 | from scipy import linalg 7 | from tqdm import tqdm 8 | 9 | from calc_inception import load_patched_inception_v3 10 | import pdb 11 | import os 12 | import csv 13 | from glob import glob 14 | import math 15 | from torch.nn import functional as F 16 | from matplotlib import pyplot as plt 17 | 18 | @torch.no_grad() 19 | def extract_features(batch_generator, inception, args): 20 | n_batches = args.n_sample // args.batch_size 21 | features = [] 22 | for _ in tqdm(range(n_batches)): 23 | img = batch_generator.generate_fid() 24 | feat = inception(img)[0].view(img.shape[0], -1) 25 | features.append(feat.to("cpu")) 26 | features = torch.cat(features, 0) 27 | return features.numpy() 28 | 29 | def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6): 30 | cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False) 31 | 32 | if not np.isfinite(cov_sqrt).all(): 33 | print('product of cov matrices is singular') 34 | offset = np.eye(sample_cov.shape[0]) * eps 35 | cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset)) 36 | 37 | if np.iscomplexobj(cov_sqrt): 38 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 39 | m = np.max(np.abs(cov_sqrt.imag)) 40 | 41 | raise ValueError(f'Imaginary component {m}') 42 | 43 | cov_sqrt = cov_sqrt.real 44 | 45 | mean_diff = sample_mean - real_mean 46 | mean_norm = mean_diff @ mean_diff 47 | 48 | trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt) 49 | 50 | fid = mean_norm + trace 51 | 52 | return fid 53 | 54 | 55 | if __name__ == '__main__': 56 | from utils_metrics import load_args 57 | args = load_args() 58 | 59 | # assertations 60 | assert 'ckpt_dir' in args.__dict__ 61 | assert 'inception' in args.__dict__ 62 | assert 'device' in args.__dict__ 63 | assert 'n_sample' in args.__dict__ 64 | assert 'batch_size' in args.__dict__ 65 | 66 | import sys 67 | if 'cookgan' in args.ckpt_dir: 68 | sys.path.append('../cookgan/') 69 | from generate_batch import BatchGenerator 70 | 71 | device = args.device 72 | 73 | print(f'load real image statistics from {args.inception}') 74 | with open(args.inception, 'rb') as f: 75 | embeds = pickle.load(f) 76 | real_mean = embeds['mean'] 77 | real_cov = embeds['cov'] 78 | 79 | filename = os.path.join(args.ckpt_dir, f'fid_{args.n_sample}.csv') 80 | # load values that are already computed 81 | computed = [] 82 | if os.path.exists(filename): 83 | with open(filename, 'r') as f: 84 | reader = csv.reader(f, delimiter=',') 85 | for row in reader: 86 | computed += [row[0]] 87 | 88 | # prepare to write 89 | f = open(filename, mode='a') 90 | writer = csv.writer(f, delimiter=',') 91 | 92 | # load inception model 93 | inception = load_patched_inception_v3() 94 | inception = inception.eval().to(device) 95 | 96 | ckpt_paths = glob(os.path.join(args.ckpt_dir, '*.ckpt')) + glob(os.path.join(args.ckpt_dir, '*.pt'))+glob(os.path.join(args.ckpt_dir, '*.pth')) 97 | ckpt_paths = sorted(ckpt_paths) 98 | print('records:', ckpt_paths) 99 | print('computed:', computed) 100 | for ckpt_path in ckpt_paths: 101 | print() 102 | print(f'working on {ckpt_path}') 103 | iteration = 
os.path.basename(ckpt_path).split('.')[0] 104 | if iteration in computed: 105 | print('already computed') 106 | continue 107 | 108 | args.ckpt_path = ckpt_path 109 | batch_generator = BatchGenerator(args) 110 | 111 | features = extract_features(batch_generator, inception, args) 112 | 113 | print(f'extracted {features.shape[0]} features') 114 | sample_mean = np.mean(features, 0) 115 | sample_cov = np.cov(features, rowvar=False) 116 | fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov) 117 | print(f'{iteration}, fid={fid}') 118 | writer.writerow([iteration, fid]) 119 | 120 | 121 | f.close() 122 | fids = [] 123 | with open(filename, 'r') as f: 124 | reader = csv.reader(f, delimiter=',') 125 | for row in reader: 126 | fid = float(row[1]) 127 | fids += [fid] 128 | fig = plt.figure(figsize=(6,6)) 129 | plt.plot(fids) 130 | plt.savefig(os.path.join(args.ckpt_dir, f'fid_{args.n_sample}.png')) 131 | -------------------------------------------------------------------------------- /retrieval_model/pretrain_upmc/train_upmc.py: -------------------------------------------------------------------------------- 1 | # Python 3.6, PyTorch 0.4 2 | import torch 3 | from torch.utils.data import Subset 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.optim import lr_scheduler 7 | from torchvision import datasets, models, transforms 8 | import os 9 | import argparse 10 | import copy 11 | from tqdm import tqdm 12 | from dataset_upmc import Dataset 13 | import pdb 14 | import wandb 15 | 16 | import sys 17 | sys.path.append('../../') 18 | from common import param_counter, root 19 | from utils_upmc import gen_filelist 20 | 21 | 22 | # arguments 23 | parser = argparse.ArgumentParser(description='Resnet50 UMPC-Food-101 Classifier') 24 | parser.add_argument('--batch_size', type=int, default=64) 25 | parser.add_argument('--epochs', type=int, default=25) 26 | parser.add_argument('--data_dir', default=f'{root}/retrieval_model/pretrain_upmc/UPMC-Food-101/') 27 | args = parser.parse_args() 28 | print(args) 29 | 30 | batch_size = args.batch_size 31 | epochs = args.epochs 32 | data_dir = args.data_dir 33 | 34 | # load data 35 | parts = ('train', 'test') 36 | for part in parts: 37 | gen_filelist(data_dir, part) 38 | 39 | normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 40 | data_transforms = { 41 | parts[0]: transforms.Compose([ 42 | transforms.RandomResizedCrop(224), 43 | transforms.RandomHorizontalFlip(), 44 | transforms.ToTensor(), 45 | normalize 46 | ]), 47 | parts[1]: transforms.Compose([ 48 | transforms.Resize((224,224)), 49 | transforms.ToTensor(), 50 | normalize 51 | ]), 52 | } 53 | 54 | datasets = { 55 | x: Dataset( 56 | root=os.path.join(data_dir, 'images'), 57 | flist=os.path.join(data_dir, x+".txt"), 58 | transform=data_transforms[x]) 59 | for x in parts} 60 | 61 | # datasets = {x: Subset(datasets[x], range(200)) for x in parts} 62 | 63 | dataloaders = { 64 | x: torch.utils.data.DataLoader( 65 | datasets[x], 66 | batch_size=batch_size, 67 | shuffle=True, num_workers=4, pin_memory=False) 68 | for x in parts} 69 | 70 | dataset_sizes = {x: len(datasets[x]) for x in parts} 71 | dataloader_sizes = {x: len(dataloaders[x]) for x in parts} 72 | print('datasets', dataset_sizes) 73 | print('dataloaders', dataloader_sizes) 74 | 75 | # load model 76 | model = models.resnet50(pretrained=True) 77 | num_feat = model.fc.in_features 78 | model.fc = nn.Linear(num_feat, 101) 79 | criterion = nn.CrossEntropyLoss() 80 | optimizer = optim.Adam(model.parameters()) 81 | 
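# ReduceLROnPlateau lowers the learning rate (by its default factor of 0.1) once the monitored loss stops improving for `patience` epochs; see the scheduler.step(loss_epoch) call in the training loop below.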
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, patience=5) 82 | 83 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 84 | model = model.to(device) 85 | if device == 'cuda': 86 | model = nn.DataParallel(model) 87 | model_to_save = model.module 88 | else: 89 | model_to_save = model 90 | 91 | print('# parameters:', param_counter(model.parameters())) 92 | 93 | wandb.init(project="cookgan_pretrain_upmc") 94 | wandb.config.update(args) 95 | 96 | # train 97 | for epoch in range(epochs): 98 | print('-' * 10) 99 | print('Epoch {}/{}'.format(epoch, epochs - 1)) 100 | running_loss = 0.0 101 | running_correct = 0.0 102 | for part in parts: 103 | if part == 'train': 104 | model.train() 105 | else: 106 | model.eval() 107 | 108 | pbar = tqdm(dataloaders[part]) 109 | for inputs, labels in pbar: 110 | inputs = inputs.to(device) 111 | labels = labels.to(device) 112 | with torch.set_grad_enabled(part == 'train'): 113 | outputs = model(inputs) 114 | _, preds = torch.max(outputs, 1) 115 | loss = criterion(outputs, labels) 116 | pbar.set_description(f'loss={loss:.4f}') 117 | if part == 'train': 118 | optimizer.zero_grad() 119 | loss.backward() 120 | optimizer.step() 121 | 122 | running_loss += loss.item() * inputs.shape[0] 123 | running_correct += torch.sum(preds == labels) 124 | 125 | loss_epoch = running_loss / dataset_sizes[part] 126 | acc_epoch = running_correct.double() / dataset_sizes[part] 127 | log = { 128 | 'epoch': epoch, 129 | f'loss_{part}': loss_epoch, 130 | f'accuracy_{part}': acc_epoch, 131 | 'lr': optimizer.param_groups[0]['lr'] 132 | } 133 | 134 | scheduler.step(loss_epoch) 135 | if epoch % 5 == 0: 136 | print('save checkpoint...') 137 | torch.save(model_to_save.state_dict(), f'{wandb.run.dir}/{epoch:>06d}.ckpt') -------------------------------------------------------------------------------- /retrieval_model/modules/dynamic_soft_margin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .hardnet_loss import ( 4 | compute_distance_matrix_hamming, 5 | compute_distance_matrix_unit_l2, 6 | find_hard_negatives, 7 | ) 8 | 9 | 10 | class DynamicSoftMarginLoss(nn.Module): 11 | def __init__(self, is_binary=False, momentum=0.01, max_dist=None, nbins=512): 12 | """ 13 | is_binary: true if learning binary descriptor 14 | momentum: weight assigned to the histogram computed from the current batch 15 | max_dist: maximum possible distance in the feature space 16 | nbins: number of bins to discretize the PDF 17 | """ 18 | super(DynamicSoftMarginLoss, self).__init__() 19 | self._is_binary = is_binary 20 | 21 | if max_dist is None: 22 | # max_dist = 256 if self._is_binary else 2.0 23 | max_dist = 2.0 24 | 25 | self._momentum = momentum 26 | self._max_val = max_dist 27 | self._min_val = -max_dist 28 | self.register_buffer("histogram", torch.ones(nbins)) 29 | 30 | self._stats_initialized = False 31 | self.current_step = None 32 | 33 | def _compute_distances(self, x): 34 | if self._is_binary: 35 | return self._compute_hamming_distances(x) 36 | else: 37 | return self._compute_l2_distances(x) 38 | 39 | def _compute_l2_distances(self, x): 40 | cnt = x.size(0) // 2 41 | a = x[:cnt, :] 42 | p = x[cnt:, :] 43 | # import pdb; pdb.set_trace() 44 | dmat = compute_distance_matrix_unit_l2(a, p) 45 | return find_hard_negatives(dmat, output_index=False, empirical_thresh=0.008) 46 | 47 | def _compute_hamming_distances(self, x): 48 | cnt = x.size(0) // 2 49 | ndims = x.size(1) 50 | a = x[:cnt, :] 51 | p = x[cnt:, :] 52 | 53 | dmat 
= compute_distance_matrix_hamming( 54 | (a > 0).float() * 2.0 - 1.0, (p > 0).float() * 2.0 - 1.0 55 | ) 56 | a_idx, p_idx, n_idx = find_hard_negatives( 57 | dmat, output_index=True, empirical_thresh=2 58 | ) 59 | 60 | # differentiable Hamming distance 61 | a = x[a_idx, :] 62 | p = x[p_idx, :] 63 | n = x[n_idx, :] 64 | 65 | pos_dist = (1.0 - a * p).sum(dim=1) / ndims 66 | neg_dist = (1.0 - a * n).sum(dim=1) / ndims 67 | 68 | # non-differentiable Hamming distance 69 | a_b = (a > 0).float() * 2.0 - 1.0 70 | p_b = (p > 0).float() * 2.0 - 1.0 71 | n_b = (n > 0).float() * 2.0 - 1.0 72 | 73 | pos_dist_b = (1.0 - a_b * p_b).sum(dim=1) / ndims 74 | neg_dist_b = (1.0 - a_b * n_b).sum(dim=1) / ndims 75 | 76 | return pos_dist, neg_dist, pos_dist_b, neg_dist_b 77 | 78 | def _compute_histogram(self, x, momentum): 79 | """ 80 | update the histogram using the current batch 81 | """ 82 | num_bins = self.histogram.size(0) 83 | x_detached = x.detach() 84 | self.bin_width = (self._max_val - self._min_val) / (num_bins - 1) 85 | lo = torch.floor((x_detached - self._min_val) / self.bin_width).long() 86 | hi = (lo + 1).clamp(min=0, max=num_bins - 1) 87 | hist = x.new_zeros(num_bins) 88 | alpha = ( 89 | 1.0 90 | - (x_detached - self._min_val - lo.float() * self.bin_width) 91 | / self.bin_width 92 | ) 93 | hist.index_add_(0, lo, alpha) 94 | hist.index_add_(0, hi, 1.0 - alpha) 95 | hist = hist / (hist.sum() + 1e-6) 96 | self.histogram = (1.0 - momentum) * self.histogram + momentum * hist 97 | 98 | def _compute_stats(self, pos_dist, neg_dist): 99 | hist_val = pos_dist - neg_dist 100 | if self._stats_initialized: 101 | self._compute_histogram(hist_val, self._momentum) 102 | else: 103 | self._compute_histogram(hist_val, 1.0) 104 | self._stats_initialized = True 105 | 106 | def forward(self, x): 107 | distances = self._compute_distances(x) 108 | if not self._is_binary: 109 | pos_dist, neg_dist = distances 110 | self._compute_stats(pos_dist, neg_dist) 111 | hist_var = pos_dist - neg_dist 112 | else: 113 | pos_dist, neg_dist, pos_dist_b, neg_dist_b = distances 114 | self._compute_stats(pos_dist_b, neg_dist_b) 115 | hist_var = pos_dist_b - neg_dist_b 116 | 117 | PDF = self.histogram / self.histogram.sum() 118 | CDF = PDF.cumsum(0) 119 | 120 | # lookup weight from the CDF 121 | bin_idx = torch.floor((hist_var - self._min_val) / self.bin_width).long() 122 | weight = CDF[bin_idx] 123 | 124 | loss = -(neg_dist * weight).mean() + (pos_dist * weight).mean() 125 | return loss 126 | -------------------------------------------------------------------------------- /retrieval_model/datasets_retrieval.py: -------------------------------------------------------------------------------- 1 | import json, pickle 2 | import os 3 | from glob import glob 4 | import numpy as np 5 | from torchvision import transforms 6 | from torch.utils import data 7 | from gensim.models.keyedvectors import KeyedVectors 8 | from PIL import Image 9 | 10 | import sys 11 | sys.path.append('../') 12 | from common import load_recipes, get_title_wordvec, get_ingredients_wordvec, get_instructions_wordvec 13 | 14 | mean = [0.485, 0.456, 0.406] 15 | std = [0.229, 0.224, 0.225] 16 | normalize = transforms.Normalize(mean=mean, std=std) 17 | train_transform = transforms.Compose([ 18 | transforms.Resize(224), 19 | transforms.RandomCrop(224), 20 | transforms.ToTensor(), 21 | normalize 22 | ]) 23 | 24 | val_transform = transforms.Compose([ 25 | transforms.Resize(224), 26 | transforms.CenterCrop(224), 27 | transforms.ToTensor(), 28 | normalize 29 | ]) 30 | 31 | def 
default_loader(path): 32 | try: 33 | im = Image.open(path).convert('RGB') 34 | return im 35 | except: 36 | print('error to open image:', path) 37 | return Image.new('RGB', (224, 224), 'white') 38 | 39 | def choose_one_image(rcp, img_dir): 40 | """ 41 | Arguments: 42 | rcp: recipe 43 | img_dir: image directory 44 | Returns: 45 | PIL image 46 | """ 47 | part = rcp['partition'] 48 | image_infos = rcp['images'] 49 | if part == 'train': 50 | # We do only use the first five images per recipe during training 51 | imgIdx = np.random.choice(range(min(5, len(image_infos)))) 52 | else: 53 | imgIdx = 0 54 | loader_path = [image_infos[imgIdx]['id'][i] for i in range(4)] 55 | loader_path = os.path.join(*loader_path) 56 | if 'plus' in img_dir: 57 | path = os.path.join(img_dir, loader_path, image_infos[imgIdx]['id']) 58 | else: 59 | path = os.path.join(img_dir, part, loader_path, image_infos[imgIdx]['id']) 60 | return default_loader(path) 61 | 62 | 63 | class Dataset(data.Dataset): 64 | def __init__( 65 | self, 66 | part, 67 | recipe_file, 68 | img_dir, 69 | word2vec_file, 70 | transform=None, 71 | permute_ingrs=False): 72 | assert part in ('train', 'val', 'test'), \ 73 | 'part must be one of [train, val, test]' 74 | self.recipes = load_recipes(recipe_file, part) 75 | 76 | wv = KeyedVectors.load(word2vec_file, mmap='r') 77 | w2i = {w: i+2 for i, w in enumerate(wv.index2word)} 78 | w2i[''] = 1 79 | self.w2i = w2i 80 | print('vocab size =', len(self.w2i)) 81 | 82 | self.img_dir = img_dir 83 | self.transform = transform 84 | self.permute_ingrs = permute_ingrs 85 | 86 | def _prepare_one_recipe(self, index): 87 | rcp = self.recipes[index] 88 | 89 | title, n_words_in_title = get_title_wordvec(rcp, self.w2i) # np.int [max_len] 90 | ingredients, n_ingrs = get_ingredients_wordvec(rcp, self.w2i, self.permute_ingrs) # np.int [max_len] 91 | instructions, n_insts, n_words_each_inst = get_instructions_wordvec(rcp, self.w2i) # np.int [max_len, max_len] 92 | 93 | pil_img = choose_one_image(rcp, self.img_dir) # PIL [3, 224, 224] 94 | if self.transform: 95 | img = self.transform(pil_img) 96 | return [title, n_words_in_title, ingredients, n_ingrs, instructions, n_insts, n_words_each_inst], img 97 | 98 | def __getitem__(self, index): 99 | txt, img = self._prepare_one_recipe(index) 100 | return txt, img 101 | 102 | def __len__(self): 103 | return len(self.recipes) 104 | 105 | 106 | if __name__ == '__main__': 107 | dataset = Dataset( 108 | part='train', 109 | recipe_file='../data/Recipe1M/recipes_withImage.json', 110 | img_dir='../data/Recipe1M/images', 111 | word2vec_file='models/word2vec_recipes.bin', 112 | transform=train_transform, 113 | permute_ingrs=False) 114 | 115 | for data in dataset: 116 | txt, img = data 117 | i2w = {i:w for w,i in dataset.w2i.items()} 118 | def get_words(vec, length, i2w): 119 | words = [] 120 | for i in vec[:length]: 121 | words.append(i2w[i]) 122 | return words 123 | 124 | print('[title] = {}'.format(' '.join(get_words(txt[0], txt[1], i2w)))) 125 | print('[ingredients] = {}'.format(get_words(txt[2], txt[3], i2w))) 126 | print('[instructions]') 127 | instructions, n_insts, n_words_each_inst = txt[4], txt[5], txt[6] 128 | for i in range(n_insts): 129 | inst = instructions[i] 130 | inst_length = n_words_each_inst[i] 131 | print('[{:>2d}] {}'.format(i+1, ' '.join(get_words(inst, inst_length, i2w)))) 132 | 133 | from matplotlib import pyplot as plt 134 | def show(img): 135 | npimg = img.numpy() 136 | npimg = (npimg-npimg.min()) / (npimg.max()-npimg.min()) 137 | plt.imshow(np.transpose(npimg, 
(1,2,0)), interpolation='nearest') 138 | plt.show() 139 | 140 | show(img) 141 | print() -------------------------------------------------------------------------------- /retrieval_model/val_retrieval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.utils.data import DataLoader 4 | from tqdm import tqdm 5 | import numpy as np 6 | import os 7 | import csv 8 | from glob import glob 9 | 10 | from args_retrieval import get_parser 11 | from utils_retrieval import compute_statistics, compute_ranks 12 | from datasets_retrieval import Dataset, val_transform 13 | from train_retrieval import load_model 14 | 15 | def extract_features(text_encoder, image_encoder, ckpt_args, data_loader): 16 | text_encoder.eval() 17 | image_encoder.eval() 18 | txt_feats = [] 19 | img_feats = [] 20 | if ckpt_args.text_info[0] == '1': 21 | title_attn = [] 22 | if ckpt_args.text_info[1] == '1': 23 | ingr_attn = [] 24 | if ckpt_args.text_info[2] == '1': 25 | inst_attn = [] 26 | inst_word_attn = [] 27 | for data in tqdm(data_loader): 28 | txt, img = data 29 | for i in range(len(txt)): 30 | txt[i] = txt[i].to(device) 31 | img = img.to(device) 32 | 33 | with torch.no_grad(): 34 | txt_feat, attentions = text_encoder(*txt) 35 | if ckpt_args.with_attention: 36 | if ckpt_args.text_info[0] == '1': 37 | title_attn.append(attentions[0]) 38 | if ckpt_args.text_info[1] == '1': 39 | ingr_attn.append(attentions[1]) 40 | if ckpt_args.text_info[2] == '1': 41 | inst_attn.append(attentions[2]) 42 | inst_word_attn.append(attentions[3]) 43 | 44 | img_feat = image_encoder(img) 45 | txt_feats.append(txt_feat.detach().cpu()) 46 | img_feats.append(img_feat.detach().cpu()) 47 | 48 | txt_feats = torch.cat(txt_feats, dim=0) 49 | img_feats = torch.cat(img_feats, dim=0) 50 | attentions = [None, None, None, None] 51 | if ckpt_args.with_attention: 52 | if ckpt_args.text_info[0] == '1': 53 | title_attn = torch.cat(title_attn, dim=0).cpu().numpy() 54 | attentions[0] = title_attn 55 | if ckpt_args.text_info[1] == '1': 56 | ingr_attn = torch.cat(ingr_attn, dim=0).cpu().numpy() 57 | attentions[1] = ingr_attn 58 | if ckpt_args.text_info[2] == '1': 59 | inst_attn = torch.cat(inst_attn, dim=0).cpu().numpy() 60 | attentions[2] = inst_attn 61 | 62 | return txt_feats.numpy(), img_feats.numpy(), attentions 63 | 64 | 65 | if __name__ == '__main__': 66 | args = get_parser().parse_args() 67 | torch.manual_seed(args.seed) 68 | np.random.seed(args.seed) 69 | device = args.device 70 | if not args.ckpt_dir: 71 | args.ckpt_dir = '/common/home/fh199/CookGAN/retrieval_model/wandb/run-20201202_174456-3kh60es7/files' 72 | 73 | filename = os.path.join(args.ckpt_dir, f'metrics.csv') 74 | # load values that are already computed 75 | computed = [] 76 | if os.path.exists(filename): 77 | with open(filename, 'r') as f: 78 | reader = csv.reader(f, delimiter=',') 79 | for row in reader: 80 | computed += [row[0]] 81 | 82 | # prepare to write 83 | f = open(filename, mode='a') 84 | writer = csv.writer(f, delimiter=',') 85 | 86 | ckpt_paths = glob(os.path.join(args.ckpt_dir, '*.ckpt')) 87 | ckpt_paths = sorted(ckpt_paths) 88 | print('records:', ckpt_paths) 89 | print('computed:', computed) 90 | data_loader = None 91 | w2i = None 92 | for ckpt_path in ckpt_paths: 93 | print() 94 | print(f'working on {ckpt_path}') 95 | 96 | ckpt_args, _, text_encoder, image_encoder, _ = load_model(ckpt_path, device) 97 | 98 | if not data_loader: 99 | print('loading dataset') 100 | dataset = Dataset( 101 | part='val', 
102 | recipe_file=ckpt_args.recipe_file, 103 | img_dir=ckpt_args.img_dir, 104 | word2vec_file=ckpt_args.word2vec_file, 105 | permute_ingrs=ckpt_args.permute_ingrs, 106 | transform=val_transform, 107 | ) 108 | w2i = dataset.w2i 109 | dataset = torch.utils.data.Subset(dataset, indices=np.random.choice(len(dataset), 5000)) 110 | data_loader = DataLoader( 111 | dataset, batch_size=args.batch_size, shuffle=False, 112 | num_workers=args.workers, pin_memory=True, drop_last=False) 113 | print('data info:', len(dataset), len(data_loader)) 114 | 115 | txt_feats, img_feats, attentions = extract_features(text_encoder, image_encoder, ckpt_args, data_loader) 116 | title_attn, ingr_attn, inst_attn, _ = attentions 117 | 118 | retrieved_range = min(txt_feats.shape[0], args.retrieved_range) 119 | medRs, recalls = compute_statistics( 120 | txt_feats, 121 | img_feats, 122 | retrieved_type=args.retrieved_type, 123 | retrieved_range=retrieved_range, 124 | draw_hist=False) 125 | 126 | writer.writerow([ckpt_path, medRs.mean(), medRs.std(), recalls[1], recalls[5], recalls[10]]) -------------------------------------------------------------------------------- /retrieval_model/explore_attention.py: -------------------------------------------------------------------------------- 1 | # TODO: finish this script 2 | import torch 3 | from torch import nn 4 | from torch.utils.data import DataLoader 5 | from tqdm import tqdm 6 | import numpy as np 7 | import os 8 | import csv 9 | from glob import glob 10 | 11 | from args_retrieval import get_parser 12 | from utils_retrieval import compute_ranks 13 | from datasets_retrieval import Dataset, val_transform 14 | from train_retrieval import load_model 15 | from val_retrieval import extract_features 16 | 17 | if __name__ == '__main__': 18 | args = get_parser().parse_args() 19 | torch.manual_seed(args.seed) 20 | np.random.seed(args.seed) 21 | device = args.device 22 | assert args.ckpt_path!='' 23 | ckpt_dir = os.path.dirname(args.ckpt_path) 24 | 25 | ckpt_args, _, text_encoder, image_encoder, _ = load_model(args.ckpt_path, device) 26 | 27 | print('loading dataset') 28 | dataset = Dataset( 29 | part='val', 30 | recipe_file=ckpt_args.recipe_file, 31 | img_dir=ckpt_args.img_dir, 32 | word2vec_file=ckpt_args.word2vec_file, 33 | permute_ingrs=ckpt_args.permute_ingrs, 34 | transform=val_transform, 35 | ) 36 | w2i = dataset.w2i 37 | dataset = torch.utils.data.Subset(dataset, indices=np.random.choice(len(dataset), 5000)) 38 | data_loader = DataLoader( 39 | dataset, batch_size=args.batch_size, shuffle=False, 40 | num_workers=args.workers, pin_memory=True, drop_last=False) 41 | print('data info:', len(dataset), len(data_loader)) 42 | 43 | txt_feats, img_feats, attentions = extract_features(text_encoder, image_encoder, ckpt_args, data_loader) 44 | title_attn, ingr_attn, inst_attn, _ = attentions 45 | 46 | # draw attention if possible: 47 | def save_attention_result(index, dataset, i2w, ranks, ingr_attn, ckpt_dir): 48 | fig = plt.figure(figsize=(12,6)) 49 | [title, n_words_in_title, ingredients, n_ingrs, instructions, n_insts, n_words_each_inst], img = dataset[index] 50 | ingr_alpha = ingr_attn[index] 51 | 52 | title_disp = ' '.join([i2w[idx] for idx in title[:n_words_in_title]]) 53 | fig.suptitle(title_disp) 54 | 55 | # # title 56 | # one_vector = title[i] 57 | # one_alpha = alpha_title[i] 58 | # length = len(one_vector.nonzero()[0]) 59 | # one_word_list = [i2w[idx] for idx in one_vector[:length]] 60 | # ind = np.arange(length) 61 | # plt.subplot(411) 62 | # # pdb.set_trace() 63 | # 
plt.barh(ind, one_alpha[:length]) 64 | # plt.yticks(ind, one_word_list) 65 | 66 | # ingredients 67 | one_vector = ingredients 68 | one_alpha = ingr_alpha 69 | one_word_list = [i2w[idx] for idx in one_vector[:n_ingrs]] 70 | ind = np.arange(n_ingrs) 71 | plt.subplot(121) 72 | plt.barh(ind, one_alpha[:n_ingrs]) 73 | plt.yticks(ind, one_word_list, fontsize=12) 74 | 75 | # # instructions 76 | # one_matrix = instructions[i] 77 | # one_alpha = alpha_instructions[i] 78 | # # pdb.set_trace() 79 | # length = one_matrix.nonzero()[0].max() + 1 80 | # one_sentence_list = [] 81 | # for k in range(length): 82 | # one_vector = one_matrix[k] 83 | # one_vector_length = len(one_vector.nonzero()[0]) 84 | # one_sentence = ' '.join([i2w[idx] for idx in one_vector[:one_vector_length]]) 85 | # one_sentence_list.append(one_sentence) 86 | # ind = np.arange(length) 87 | # plt.subplot(413) 88 | # plt.barh(ind, one_alpha[:length]) 89 | # plt.yticks(ind, one_sentence_list) 90 | 91 | # images 92 | plt.subplot(122) 93 | one_img = img.permute(1,2,0).detach().cpu().numpy() 94 | scale = one_img.max() - one_img.min() 95 | one_img = (one_img - one_img.min()) / scale 96 | plt.imshow(one_img) 97 | plt.axis('off') 98 | 99 | plt.savefig(os.path.join(ckpt_dir, 'rank={}_{}.jpg'.format(ranks[index], title_disp))) 100 | 101 | if ckpt_args.with_attention: 102 | from matplotlib import pyplot as plt 103 | ranks = compute_ranks(txt_feats[:1000], img_feats[:1000]) 104 | 105 | print('plot ranks') 106 | medR = np.median(ranks).astype(int) 107 | plt.figure(figsize=(6,6)) 108 | n, bins, patches = plt.hist(x=ranks, bins='auto', color='#0504aa', alpha=0.7, rwidth=0.85) 109 | plt.grid(axis='y', alpha=0.75) 110 | plt.xlabel('Rank') 111 | plt.ylabel('Frequency') 112 | plt.title('Rank Distribution') 113 | plt.text(23, 45, 'avgR(std) = {:.2f}({:.2f})\nmedR={:d}\n#<{:d}:{:d}|#={:d}:{:d}|#>{:d}:{:d}'.format( 114 | np.mean(ranks), np.std(ranks), medR, 115 | medR, (ranks<medR).sum(), medR, (ranks==medR).sum(), medR, (ranks>medR).sum())) 116 | plt.savefig(os.path.join(args.ckpt_dir, 'batch_ranks.jpg')) 117 | 118 | print('plot attentions') 119 | ingr_attn = ingr_attn[:1000] 120 | sorted_idxs = np.argsort(ranks).tolist() 121 | i2w = {i:w for w,i in w2i.items()} 122 | for i in range(10): 123 | idx = sorted_idxs[i] 124 | save_attention_result(idx, dataset, i2w, ranks, ingr_attn, args.ckpt_dir) 125 | for i in range(-10, 0): 126 | idx = sorted_idxs[i] 127 | save_attention_result(idx, dataset, i2w, ranks, ingr_attn, args.ckpt_dir) -------------------------------------------------------------------------------- /metrics/datasets_inception.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import lmdb 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | import math 7 | import torch 8 | import json 9 | import numpy as np 10 | import os 11 | 12 | transform = transforms.Compose([ 13 | transforms.ToTensor(), 14 | transforms.Normalize(0.5, 0.5)]) 15 | 16 | 17 | class Recipe1MDataset(Dataset): 18 | def __init__( 19 | self, 20 | lmdb_file='/dresden/users/fh199/food_project/data/Recipe1M/Recipe1M.lmdb', 21 | food_type='', transform=transform, resolution=256): 22 | 23 | assert food_type in ['', 'salad', 'cookie', 'muffin'], "part has to be in ['', 'salad', 'cookie', 'muffin']" 24 | 25 | dirname = os.path.dirname(lmdb_file) 26 | path = os.path.join(dirname, 'keys.json') 27 | with open(path, 'r') as f: 28 | self.keys = json.load(f) 29 | if food_type: 30 | self.keys = [x for x in self.keys if
food_type.lower() in x['title'].lower()] 31 | 32 | self.env = lmdb.open( 33 | lmdb_file, 34 | max_readers=32, 35 | readonly=True, 36 | lock=False, 37 | readahead=False, 38 | meminit=False, 39 | ) 40 | 41 | if not self.env: 42 | raise IOError('Cannot open lmdb dataset', lmdb_file) 43 | 44 | self.resolution = resolution 45 | 46 | assert transform!=None, 'transform can not be None!' 47 | self.transform = transform 48 | 49 | def __len__(self): 50 | return len(self.keys) 51 | 52 | def _load_recipe(self, rcp): 53 | rcp_id = rcp['id'] 54 | 55 | with self.env.begin(write=False) as txn: 56 | key = f'title-{rcp_id}'.encode('utf-8') 57 | title = txn.get(key).decode('utf-8') 58 | 59 | key = f'ingredients-{rcp_id}'.encode('utf-8') 60 | ingredients = txn.get(key).decode('utf-8') 61 | 62 | key = f'instructions-{rcp_id}'.encode('utf-8') 63 | instructions = txn.get(key).decode('utf-8') 64 | 65 | key = f'{self.resolution}-{rcp_id}'.encode('utf-8') 66 | img_bytes = txn.get(key) 67 | 68 | txt = title 69 | txt += '\n' 70 | txt += ingredients 71 | txt += '\n' 72 | txt += instructions 73 | 74 | buffer = BytesIO(img_bytes) 75 | img = Image.open(buffer) 76 | img = self.transform(img) 77 | return txt, img 78 | 79 | def __getitem__(self, index): 80 | rcp = self.keys[index] 81 | txt, img = self._load_recipe(rcp) 82 | 83 | return txt, img 84 | 85 | 86 | class PizzaGANDataset(Dataset): 87 | def __init__( 88 | self, 89 | lmdb_file='/dresden/users/fh199/food_project/data/pizzaGANdata_new_concise/pizzaGANdata.lmdb', 90 | transform=transform, resolution=64): 91 | 92 | dirname = os.path.dirname(lmdb_file) 93 | label_file = os.path.join(dirname, 'imageLabels.txt') 94 | with open(label_file, 'r') as f: 95 | self.labels = f.read().strip().split('\n') 96 | 97 | self.env = lmdb.open( 98 | lmdb_file, 99 | max_readers=32, 100 | readonly=True, 101 | lock=False, 102 | readahead=False, 103 | meminit=False, 104 | ) 105 | 106 | if not self.env: 107 | raise IOError('Cannot open lmdb dataset', lmdb_file) 108 | 109 | with self.env.begin(write=False) as txn: 110 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 111 | 112 | self.resolution = resolution 113 | 114 | assert transform!=None, 'transform can not be None!' 
115 | self.transform = transform 116 | 117 | def __len__(self): 118 | return self.length 119 | 120 | def _load_pizza(self, idx): 121 | with self.env.begin(write=False) as txn: 122 | key = f'{idx}'.encode('utf-8') 123 | ingrs = txn.get(key).decode('utf-8') 124 | if not ingrs: 125 | ingrs = 'empty' 126 | key = f'{self.resolution}-{idx}'.encode('utf-8') 127 | img_bytes = txn.get(key) 128 | 129 | txt = ingrs 130 | buffer = BytesIO(img_bytes) 131 | img = Image.open(buffer) 132 | img = self.transform(img) 133 | return txt, img 134 | 135 | def __getitem__(self, idx): 136 | txt, img = self._load_pizza(idx) 137 | return txt, img 138 | 139 | 140 | if __name__ == '__main__': 141 | import pdb 142 | from matplotlib import pyplot as plt 143 | 144 | def show(img): 145 | npimg = img.numpy() 146 | npimg = (npimg-npimg.min()) / (npimg.max()-npimg.min()) 147 | plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest') 148 | plt.show() 149 | 150 | res = 256 151 | 152 | dataset = PizzaGANDataset( 153 | lmdb_file='/dresden/users/fh199/food_project/data/pizzaGANdata_new_concise/pizzaGANdata.lmdb', 154 | transform=transform, resolution=res) 155 | 156 | # dataset = Recipe1MDataset( 157 | # lmdb_file='/dresden/users/fh199/food_project/data/Recipe1M/Recipe1M.lmdb', 158 | # food_type='', transform=transform, resolution=res) 159 | 160 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, num_workers=8, shuffle=False) 161 | print(len(dataset), len(dataloader)) 162 | for txt, img in dataloader: 163 | print(len(txt)) 164 | print(txt[0]) 165 | print(img.shape) 166 | # show(img) 167 | pdb.set_trace() 168 | break -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import string 3 | import json 4 | import numpy as np 5 | import re 6 | import copy 7 | from datetime import datetime 8 | import json 9 | import argparse 10 | 11 | root = '/common/home/fh199/CookGAN' 12 | 13 | def clean_state_dict(state_dict): 14 | # create new OrderedDict that does not contain `module.` 15 | from collections import OrderedDict 16 | new_state_dict = OrderedDict() 17 | for k, v in state_dict.items(): 18 | name = k[7:] if k[:min(6,len(k))] == 'module' else k # remove `module.` 19 | new_state_dict[name] = v 20 | return new_state_dict 21 | 22 | def sample_data(loader): 23 | """ 24 | arguments: 25 | loader: torch.utils.data.DataLoader 26 | return: 27 | one batch of data 28 | usage: 29 | data = next(sample_data(loader)) 30 | """ 31 | while True: 32 | for batch in loader: 33 | yield batch 34 | 35 | def str2bool(v): 36 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 37 | return True 38 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 39 | return False 40 | else: 41 | raise argparse.ArgumentTypeError('Boolean value expected.') 42 | 43 | 44 | def dspath(ext, ROOT, **kwargs): 45 | return os.path.join(ROOT,ext) 46 | 47 | class Layer(object): 48 | L1 = 'layer1' 49 | L2 = 'layer2' 50 | L3 = 'layer3' 51 | INGRS = 'det_ingrs' 52 | 53 | @staticmethod 54 | def load(name, ROOT, **kwargs): 55 | with open(dspath(name + '.json',ROOT, **kwargs)) as f_layer: 56 | return json.load(f_layer) 57 | 58 | @staticmethod 59 | def merge(layers, ROOT,copy_base=False, **kwargs): 60 | layers = [l if isinstance(l, list) else Layer.load(l, ROOT, **kwargs) for l in layers] 61 | base = copy.deepcopy(layers[0]) if copy_base else layers[0] 62 | entries_by_id = {entry['id']: entry for entry in base} 63 | for layer in layers[1:]: 
64 | for entry in layer: 65 | base_entry = entries_by_id.get(entry['id']) 66 | if not base_entry: 67 | continue 68 | base_entry.update(entry) 69 | return base 70 | 71 | def remove_numbers(s): 72 | ''' 73 | remove numbers in a sentence. 74 | - 1.1: \d+\.\d+ 75 | - 1 1/2 or 1-1/2 or 1 -1/2 or 1- 1/2 or 1 - 1/2: (\d+ *-* *)?\d+/\d+ 76 | - 1: \d+' 77 | 78 | Arguments: 79 | s {str} -- the string to operate on 80 | 81 | Returns: 82 | str -- the modified string without numbers 83 | ''' 84 | return re.sub(r'\d+\.\d+|(\d+ *-* *)?\d+/\d+|\d+', 'some', s) 85 | 86 | def tok(text, ts=False): 87 | if not ts: 88 | ts = [',','.',';','(',')','?','!','&','%',':','*','"'] 89 | for t in ts: 90 | text = text.replace(t,' ' + t + ' ') 91 | return text 92 | 93 | 94 | param_counter = lambda params: sum(p.numel() for p in params if p.requires_grad) 95 | 96 | 97 | def load_recipes(file_path, part=None): 98 | with open(file_path, 'r') as f: 99 | info = json.load(f) 100 | if part: 101 | info = [x for x in info if x['partition']==part] 102 | return info 103 | 104 | 105 | def get_title_wordvec(recipe, w2i, max_len=20): 106 | ''' 107 | get the title wordvec for the recipe, the 108 | number of items might be different for different 109 | recipe 110 | ''' 111 | title = recipe['title'] 112 | words = title.split() 113 | vec = np.zeros([max_len], dtype=np.int) 114 | num_words = min(max_len, len(words)) 115 | for i in range(num_words): 116 | word = words[i] 117 | if word not in w2i: 118 | word = '' 119 | vec[i] = w2i[word] 120 | return vec, num_words 121 | 122 | 123 | def get_instructions_wordvec(recipe, w2i, max_len=20): 124 | ''' 125 | get the instructions wordvec for the recipe, the 126 | number of items might be different for different 127 | recipe 128 | ''' 129 | instructions = recipe['instructions'] 130 | # each recipe has at most max_len sentences 131 | # each sentence has at most max_len words 132 | vec = np.zeros([max_len, max_len], dtype=np.int) 133 | num_insts = min(max_len, len(instructions)) 134 | num_words_each_inst = np.zeros(max_len, dtype=np.int) 135 | for row in range(num_insts): 136 | inst = instructions[row] 137 | words = inst.split() 138 | num_words = min(max_len, len(words)) 139 | num_words_each_inst[row] = num_words 140 | for col in range(num_words): 141 | word = words[col] 142 | if word not in w2i: 143 | word = '' 144 | vec[row, col] = w2i[word] 145 | return vec, num_insts, num_words_each_inst 146 | 147 | 148 | def get_ingredients_wordvec(recipe, w2i, permute_ingrs=False, max_len=20): 149 | ''' 150 | get the ingredients wordvec for the recipe, the 151 | number of items might be different for different 152 | recipe 153 | ''' 154 | ingredients = recipe['ingredients'] 155 | if permute_ingrs: 156 | ingredients = np.random.permutation(ingredients).tolist() 157 | vec = np.zeros([max_len], dtype=np.int) 158 | num_words = min(max_len, len(ingredients)) 159 | 160 | for i in range(num_words): 161 | word = ingredients[i] 162 | if word not in w2i: 163 | word = '' 164 | vec[i] = w2i[word] 165 | 166 | return vec, num_words 167 | 168 | 169 | def get_ingredients_wordvec_withClasses(recipe, w2i, ingr2i, permute_ingrs=False, max_len=20): 170 | ''' 171 | get the ingredients wordvec for the recipe, the 172 | number of items might be different for different 173 | recipe 174 | ''' 175 | ingredients = recipe['ingredients'] 176 | if permute_ingrs: 177 | ingredients = np.random.permutation(ingredients).tolist() 178 | 179 | label = np.zeros([len(ingr2i)], dtype=np.float32) 180 | 181 | vec = np.zeros([max_len], dtype=np.int) 182 
| num_words = min(max_len, len(ingredients)) 183 | 184 | for i in range(num_words): 185 | word = ingredients[i] 186 | if word not in w2i: 187 | word = '' 188 | vec[i] = w2i[word] 189 | 190 | if word in ingr2i: 191 | label[ingr2i[word]] = 1 192 | 193 | return vec, num_words, label 194 | 195 | def requires_grad(model, flag=True): 196 | for p in model.parameters(): 197 | p.requires_grad = flag -------------------------------------------------------------------------------- /cookgan/datasets_cookgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.utils import data 4 | from torch.nn import functional as F 5 | from torchvision import transforms 6 | import pickle 7 | import numpy as np 8 | import os 9 | import json 10 | from gensim.models.keyedvectors import KeyedVectors 11 | from PIL import Image 12 | 13 | import sys 14 | sys.path.append('../') 15 | from common import load_recipes, get_title_wordvec, get_ingredients_wordvec_withClasses, get_instructions_wordvec 16 | 17 | def get_imgs(img_path, imsize, bbox=None, 18 | transform=None, normalize=None, levels=3): 19 | img = Image.open(img_path).convert('RGB') 20 | width, height = img.size 21 | if bbox is not None: 22 | r = int(np.maximum(bbox[2], bbox[3]) * 0.75) 23 | center_x = int((2 * bbox[0] + bbox[2]) / 2) 24 | center_y = int((2 * bbox[1] + bbox[3]) / 2) 25 | y1 = np.maximum(0, center_y - r) 26 | y2 = np.minimum(height, center_y + r) 27 | x1 = np.maximum(0, center_x - r) 28 | x2 = np.minimum(width, center_x + r) 29 | img = img.crop([x1, y1, x2, y2]) 30 | 31 | if transform is not None: 32 | img = transform(img) 33 | 34 | ret = [] 35 | for i in range(levels): 36 | if i < (levels - 1): 37 | re_img = transforms.Resize(imsize[i])(img) 38 | else: 39 | re_img = img 40 | ret.append(normalize(re_img)) 41 | 42 | return ret 43 | 44 | def choose_one_image_path(rcp, img_dir): 45 | part = rcp['partition'] 46 | image_infos = rcp['images'] 47 | if part == 'train': 48 | # We do only use the first five images per recipe during training 49 | imgIdx = np.random.choice(range(min(5, len(image_infos)))) 50 | else: 51 | imgIdx = 0 52 | 53 | loader_path = [image_infos[imgIdx]['id'][i] for i in range(4)] 54 | loader_path = os.path.join(*loader_path) 55 | if 'plus' in img_dir: 56 | path = os.path.join(img_dir, loader_path, image_infos[imgIdx]['id']) 57 | else: 58 | path = os.path.join(img_dir, part, loader_path, image_infos[imgIdx]['id']) 59 | return path 60 | 61 | 62 | class FoodDataset(data.Dataset): 63 | def __init__( 64 | self, 65 | recipe_file, 66 | img_dir, 67 | levels=3, 68 | word2vec_file='../retrieval_model/models/word2vec_recipes.bin', 69 | vocab_ingrs_file='../manual_files/list_of_merged_ingredients.txt', 70 | part='train', 71 | food_type='salad', 72 | base_size=64, 73 | transform=None, 74 | num_samples=None): 75 | self.transform = transform 76 | self.norm = transforms.Compose([ 77 | transforms.ToTensor(), 78 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 79 | 80 | self.imsize = [] 81 | self.levels = levels 82 | self.recipe_file = recipe_file 83 | self.img_dir = img_dir 84 | for _ in range(levels): 85 | self.imsize.append(base_size) 86 | base_size = base_size * 2 87 | 88 | self.recipes = load_recipes(recipe_file, part) 89 | if food_type: 90 | self.recipes = [x for x in self.recipes if food_type.lower() in x['title'].lower()] 91 | if num_samples: 92 | N = min(len(self.recipes), num_samples) 93 | self.recipes = np.random.choice(self.recipes, N, 
replace=False) 94 | 95 | wv = KeyedVectors.load(word2vec_file, mmap='r') 96 | w2i = {w: i+2 for i, w in enumerate(wv.index2word)} 97 | w2i[''] = 1 98 | self.w2i = w2i 99 | 100 | with open(vocab_ingrs_file, 'r') as f: 101 | vocab_ingrs = f.read().strip().split('\n') 102 | self.ingr2i = {ingr:i for i,ingr in enumerate(vocab_ingrs)} 103 | 104 | def __getitem__(self, index): 105 | rcp = self.recipes[index] 106 | 107 | title, n_words_in_title = get_title_wordvec(rcp, self.w2i) # np.int [max_len] 108 | ingredients, n_ingrs, _ = get_ingredients_wordvec_withClasses(rcp, self.w2i, self.ingr2i) # np.int [max_len] 109 | instructions, n_insts, n_words_each_inst = get_instructions_wordvec(rcp, self.w2i) # np.int [max_len, max_len] 110 | txt = (title, n_words_in_title, ingredients, n_ingrs, instructions, n_insts, n_words_each_inst) 111 | 112 | img_name = choose_one_image_path(rcp, self.img_dir) 113 | imgs = get_imgs(img_name, self.imsize, transform=self.transform, normalize=self.norm, levels=self.levels) 114 | 115 | all_idx = range(len(self.recipes)) 116 | wrong_idx = np.random.choice(all_idx) 117 | while wrong_idx == index: 118 | wrong_idx = np.random.choice(all_idx) 119 | wrong_img_name = choose_one_image_path(self.recipes[wrong_idx], self.img_dir) 120 | wrong_imgs = get_imgs(wrong_img_name, self.imsize, transform=self.transform, normalize=self.norm, levels=self.levels) 121 | 122 | return txt, imgs, wrong_imgs, rcp['title'] 123 | 124 | def __len__(self): 125 | return len(self.recipes) 126 | 127 | if __name__ == '__main__': 128 | class Args: pass 129 | args = Args() 130 | args.base_size = 64 131 | args.levels = 3 132 | args.recipe_file = '../data/Recipe1M/recipes_withImage.json' 133 | args.img_dir = '../data/Recipe1M/images' 134 | args.food_type = 'salad' 135 | args.batch_size = 32 136 | args.workers = 4 137 | 138 | imsize = args.base_size * (2 ** (args.levels-1)) 139 | train_transform = transforms.Compose([ 140 | transforms.Resize(int(imsize * 76 / 64)), 141 | transforms.RandomCrop(imsize), 142 | transforms.RandomHorizontalFlip()]) 143 | train_set = FoodDataset( 144 | recipe_file=args.recipe_file, 145 | img_dir=args.img_dir, 146 | levels=args.levels, 147 | part='train', 148 | food_type=args.food_type, 149 | base_size=args.base_size, 150 | transform=train_transform) 151 | train_loader = torch.utils.data.DataLoader( 152 | train_set, batch_size=args.batch_size, 153 | drop_last=False, shuffle=False, num_workers=int(args.workers)) 154 | 155 | for txt, imgs, w_imgs, title in train_loader: 156 | print(len(txt)) 157 | for one_txt in txt: 158 | print(one_txt.shape) 159 | 160 | print(len(imgs)) 161 | for img in imgs: 162 | print(img.shape) 163 | 164 | print(len(w_imgs)) 165 | for img in w_imgs: 166 | print(img.shape) 167 | 168 | print(len(title)) 169 | print(title[0]) 170 | input() -------------------------------------------------------------------------------- /cookgan/eval_cookgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import os 6 | from torchvision import transforms 7 | from torchvision.utils import save_image 8 | from tqdm import tqdm 9 | from args_cookgan import args 10 | from models_StackGANv2 import G_NET 11 | from datasets_cookgan import Dataset 12 | import sys 13 | sys.path.append('../') 14 | from common import make_saveDir, load_retrieval_model, load_generation_model, mean, std, rank 15 | from scipy.spatial.distance import cdist, pdist 16 | import pdb 
17 | 18 | assert args.resume != '' 19 | args.batch_size = 64 20 | print(args) 21 | torch.manual_seed(args.seed) 22 | np.random.seed(args.seed) 23 | 24 | device = torch.device('cuda' \ 25 | if torch.cuda.is_available() and args.cuda 26 | else 'cpu') 27 | print('device:', device) 28 | 29 | netG = load_generation_model(args.resume, device) 30 | abspath = os.path.abspath(__file__) 31 | dname = os.path.dirname(abspath) 32 | os.chdir('../') 33 | TxtEnc, ImgEnc = load_retrieval_model(args.retrieval_model, device) 34 | os.chdir(dname) 35 | 36 | imsize = 256 37 | image_transform = transforms.Compose([ 38 | transforms.Resize(256), 39 | transforms.CenterCrop(imsize)]) 40 | dataset = Dataset( 41 | args.data_dir, args.img_dir, food_type=args.food_type, 42 | levels=args.levels, part='test', 43 | base_size=args.base_size, transform=image_transform) 44 | # dataset = torch.utils.data.Subset(dataset, range(500)) 45 | dataloader = torch.utils.data.DataLoader( 46 | dataset, batch_size=args.batch_size, 47 | drop_last=True, shuffle=False, num_workers=int(args.workers)) 48 | print('=> dataset dataloader =', len(dataset), len(dataloader)) 49 | 50 | generation_model_name = args.resume.rsplit('/', 1)[-1].rsplit('.', 1)[0] 51 | save_dir = 'experiments/{}'.format(generation_model_name) 52 | if not os.path.exists(save_dir): 53 | os.makedirs(save_dir) 54 | 55 | def prepare_data(data): 56 | imgs, w_imgs, txt, _ = data 57 | 58 | real_vimgs, wrong_vimgs = [], [] 59 | for i in range(args.levels): 60 | real_vimgs.append(imgs[i].to(device)) 61 | wrong_vimgs.append(w_imgs[i].to(device)) 62 | 63 | vtxt = [x.to(device) for x in txt] 64 | return real_vimgs, wrong_vimgs, vtxt 65 | 66 | # fixed_noise = torch.zeros(args.batch_size, args.z_dim).to(device) 67 | fixed_noise = torch.FloatTensor(1, args.z_dim).normal_(0, 1).to(device) 68 | fixed_noise = fixed_noise.repeat(args.batch_size, 1) 69 | batch = 0 70 | 71 | txt_feats_real = [] 72 | img_feats_real = [] 73 | img_feats_fake = [] 74 | 75 | def _get_img_embeddings(img, ImgEnc): 76 | img = img/2 + 0.5 77 | img = F.interpolate(img, [224, 224], mode='bilinear', align_corners=True) 78 | for i in range(img.shape[1]): 79 | img[:,i] = (img[:,i]-mean[i])/std[i] 80 | with torch.no_grad(): 81 | img_feats = ImgEnc(img).detach().cpu() 82 | return img_feats 83 | 84 | for data in tqdm(dataloader): 85 | real_imgs, _, txt = prepare_data(data) 86 | txt_embedding = TxtEnc(txt) 87 | with torch.no_grad(): 88 | fake_imgs, _, _ = netG(fixed_noise, txt_embedding) 89 | 90 | txt_feats_real.append(txt_embedding.detach().cpu()) 91 | img_fake = fake_imgs[-1] 92 | img_embedding_fake = _get_img_embeddings(img_fake, ImgEnc) 93 | img_feats_fake.append(img_embedding_fake.detach().cpu()) 94 | img_real = real_imgs[-1] 95 | img_embedding_real = _get_img_embeddings(img_real, ImgEnc) 96 | img_feats_real.append(img_embedding_real.detach().cpu()) 97 | 98 | if batch == 0: 99 | noise = torch.FloatTensor(args.batch_size, args.z_dim).normal_(0, 1).to(device) 100 | one_txt_feat = txt_embedding[0:1] 101 | one_txt_feat = one_txt_feat.repeat(args.batch_size, 1) 102 | fakes, _, _ = netG(noise, one_txt_feat) 103 | save_image( 104 | fakes[-1], 105 | os.path.join(save_dir, 'random_noise_image0.jpg'), 106 | normalize=True, scale_each=True) 107 | 108 | save_image( 109 | fake_imgs[0], 110 | os.path.join(save_dir, 'batch{}_fake0.jpg'.format(batch)), 111 | normalize=True, scale_each=True) 112 | save_image( 113 | fake_imgs[1], 114 | os.path.join(save_dir, 'batch{}_fake1.jpg'.format(batch)), 115 | normalize=True, scale_each=True) 116 | 
save_image( 117 | fake_imgs[2], 118 | os.path.join(save_dir, 'batch{}_fake2.jpg'.format(batch)), 119 | normalize=True, scale_each=True) 120 | save_image( 121 | real_imgs[-1], 122 | os.path.join(save_dir, 'batch{}_real.jpg'.format(batch)), 123 | normalize=True) 124 | 125 | real_fake = torch.stack([real_imgs[-1], fake_imgs[-1]]).permute(1,0,2,3,4).contiguous() 126 | real_fake = real_fake.view(-1, real_fake.shape[-3], real_fake.shape[-2], real_fake.shape[-1]) 127 | save_image( 128 | real_fake, 129 | os.path.join(save_dir, 'batch{}_real_fake.jpg'.format(batch)), 130 | normalize=True, scale_each=True) 131 | batch += 1 132 | 133 | txt_feats_real = torch.cat(txt_feats_real, dim=0) 134 | img_feats_real = torch.cat(img_feats_real, dim=0) 135 | img_feats_fake = torch.cat(img_feats_fake, dim=0) 136 | cos = torch.nn.CosineSimilarity(dim=1) 137 | dists = cos(txt_feats_real, img_feats_real) 138 | print('=> Real txt and real img cosine (N={}): {:.4f}({:.4f})'.format(dists.shape[0], dists.mean().item(), dists.std().item())) 139 | dists = cos(txt_feats_real, img_feats_fake) 140 | print('=> Real txt and fake img cosine (N={}): {:.4f}({:.4f})'.format(dists.shape[0], dists.mean().item(), dists.std().item())) 141 | 142 | N = min(1000, txt_feats_real.shape[0]) 143 | 144 | idxs = np.random.choice(img_feats_real.shape[0], N, replace=False) 145 | sub = img_feats_real.numpy()[idxs] 146 | Y = 1-pdist(sub, 'cosine') 147 | print('=> Two random real images cosine (N={}): {:.4f}({:.4f})'.format(Y.shape[0], Y.mean().item(), Y.std().item())) 148 | 149 | idxs = np.random.choice(txt_feats_real.shape[0], N, replace=False) 150 | sub = txt_feats_real.numpy()[idxs] 151 | Y = 1-pdist(sub, 'cosine') 152 | print('=> Two random real texts cosine (N={}): {:.4f}({:.4f})'.format(Y.shape[0], Y.mean().item(), Y.std().item())) 153 | 154 | idxs = np.random.choice(img_feats_fake.shape[0], N, replace=False) 155 | sub = img_feats_fake.numpy()[idxs] 156 | Y = 1-pdist(sub, 'cosine') 157 | print('=> Two random fake images cosine (N={}): {:.4f}({:.4f})'.format(Y.shape[0], Y.mean().item(), Y.std().item())) 158 | 159 | 160 | print('=> computing ranks...') 161 | retrieved_range = min(900, len(dataloader)*args.batch_size) 162 | medR, medR_std, recalls = rank(txt_feats_real.numpy(), img_feats_real.numpy(), retrieved_type='recipe', retrieved_range=retrieved_range) 163 | print('=> Real MedR: {:.4f}({:.4f})'.format(medR, medR_std)) 164 | for k, v in recalls.items(): 165 | print('Real Recall@{} = {:.4f}'.format(k, v)) 166 | 167 | medR, medR_std, recalls = rank(txt_feats_real.numpy(), img_feats_fake.numpy(), retrieved_type='recipe', retrieved_range=retrieved_range) 168 | print('=> Fake MedR: {:.4f}({:.4f})'.format(medR, medR_std)) 169 | for k, v in recalls.items(): 170 | print('Fake Recall@{} = {:.4f}'.format(k, v)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is the official repository for the WACV paper [CookGAN: Meal Image Synthesis from Ingredients](https://openaccess.thecvf.com/content_WACV_2020/papers/Han_CookGAN_Meal_Image_Synthesis_from_Ingredients_WACV_2020_paper.pdf). The code is tested with Python 3.8, PyTorch 1.6, and CUDA 10.2 on Ubuntu 18.04. 2 | 3 | # Prepare Dataset 4 | ## Download the original dataset 5 | Download the Recipe1M dataset from http://pic2recipe.csail.mit.edu/ and make sure you have downloaded and unzipped all images and the files `det_ingrs.json`, `layer1.json`, `layer2.json`.
Your data folder should look like the one shown below: 6 | ``` 7 | CookGAN/data/Recipe1M/ 8 | images/ 9 | train/ 10 | val/ 11 | test/ 12 | recipe1M/ 13 | det_ingrs.json 14 | layer1.json 15 | layer2.json 16 | ``` 17 | Create an environment using Python 3.8 and install the required packages: 18 | ``` 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | ## Simplify dataset 23 | Run `python clean_recipes_with_canonical_ingrs.py` to generate `./data/Recipe1M/recipes_withImage.json`, which contains only the *simplified* recipes with images (N=402760), e.g. 24 | ``` 25 | { 26 | "id": "00003a70b1", 27 | "url": "http://www.food.com/recipe/crunchy-onion-potato-bake-479149", 28 | "partition": "test", 29 | "title": "Crunchy Onion Potato Bake", 30 | "instructions": [ 31 | "Preheat oven to 350 degrees Fahrenheit.", 32 | "Spray pan with non stick cooking spray.", 33 | "Heat milk, water and butter to boiling; stir in contents of both pouches of potatoes; let stand one minute.", 34 | "Stir in corn.", 35 | "Spoon half the potato mixture in pan.", 36 | "Sprinkle half each of cheese and onions; top with remaining potatoes.", 37 | "Sprinkle with remaining cheese and onions.", 38 | "Bake 10 to 15 minutes until cheese is melted.", 39 | "Enjoy !" 40 | ], 41 | "ingredients": [ 42 | "milk", 43 | "water", 44 | "butter", 45 | "mashed potatoes", 46 | "whole kernel corn", 47 | "cheddar cheese", 48 | "French - fried onions" 49 | ], 50 | "valid": [ 51 | true, 52 | true, 53 | true, 54 | true, 55 | true, 56 | true, 57 | true 58 | ], 59 | "images": [ 60 | "3/e/2/3/3e233001e2.jpg", 61 | "7/f/7/4/7f749987f9.jpg", 62 | "a/a/f/6/aaf6b2dcd3.jpg" 63 | ] 64 | } 65 | ``` 66 | 67 | # Train Models 68 | All models (except word2vec) can be monitored using [wandb](https://www.wandb.com/). 69 | 70 | ## Train word2vec 71 | Go to `retrieval_model` and run `python train_word2vec.py` to generate `models/word2vec_recipes.bin`. 72 | 73 | ## Pre-train UPMC-Food-101 classifier 74 | Go to `./pretrain_upmc` and follow `./pretrain_upmc/README` to pretrain the image encoder on the UPMC-Food-101 dataset. 75 | 76 | ## Train the attention-based retrieval model 77 | Run 78 | ``` 79 | CUDA_VISIBLE_DEVICES=0 bash run_retrieval.sh 80 | ``` 81 | to train the [attention-based recipe retrieval model](https://dl.acm.org/citation.cfm?id=3240627). Here, `010` means only ingredients are used to train the model. The code also supports training on all three text domains via `--text_info=111` (title+ingredients+instructions). 82 | 83 | ## Train CookGAN 84 | Go to `CookGAN/cookgan` and run 85 | ``` 86 | CUDA_VISIBLE_DEVICES=0 bash run.sh 87 | ``` 88 | to train CookGAN on salad. 89 | 90 | ## Test Models 91 | Go to `CookGAN/metrics/`, then: 92 | 93 | * Update the configurations following `configs/salad+cookgan.yaml`. 94 | * Run `python calc_inception.py` to generate statistics for real images. 95 | * Run `python fid.py` to compute the FIDs under a certain checkpoint directory. 96 | * Run `python medR.py` to compute the median ranks under a certain checkpoint directory. 97 | 98 | ### Generate an image from the pre-trained model 99 | 100 | 1. Download the trained model from the [Google drive folder](https://drive.google.com/drive/folders/1URwnLMVKx3avmUI0ITjxzgjFkpvmiS3Q?usp=sharing). 101 | 2. Run the notebook `test_model.ipynb` to generate an image (a minimal script sketch is shown below).
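For reference, here is a minimal, untested sketch of the same generation steps. It only uses code shown in this repository plus a few loudly-marked assumptions: a `G_NET` generator class in `cookgan/models_cookgan.py`, a CookGAN checkpoint that stores the generator weights under the key `netG`, and illustrative checkpoint paths and `z_dim`. The notebook remains the authoritative reference.

```
# Minimal sketch (not the official script); run from the repo root. Assumptions are marked below.
import sys
import torch
from torchvision import transforms
from torchvision.utils import save_image

sys.path.append('retrieval_model')
sys.path.append('cookgan')
from train_retrieval import load_model    # shown in retrieval_model/train_retrieval.py
from datasets_cookgan import FoodDataset  # shown in cookgan/datasets_cookgan.py
from models_cookgan import G_NET          # ASSUMPTION: generator class/module name

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Text encoder from the pre-trained retrieval model (checkpoint path is illustrative).
_, _, text_encoder, _, _ = load_model('path/to/retrieval.ckpt', device)
text_encoder.eval()

# Encode one salad recipe; FoodDataset returns (txt, imgs, wrong_imgs, title),
# where txt is a tuple of 7 word-vector arrays and lengths.
dataset = FoodDataset(
    recipe_file='data/Recipe1M/recipes_withImage.json',
    img_dir='data/Recipe1M/images',
    word2vec_file='retrieval_model/models/word2vec_recipes.bin',
    vocab_ingrs_file='manual_files/list_of_merged_ingredients.txt',
    part='test', food_type='salad',
    transform=transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256)]))
txt, _, _, title = dataset[0]
txt = [torch.as_tensor(x).unsqueeze(0).to(device) for x in txt]  # batch of 1

# Generator (ASSUMPTION: constructor arguments and checkpoint key; see cookgan/train_cookgan.py).
netG = G_NET().to(device).eval()
ckpt = torch.load('path/to/cookgan.ckpt', map_location=device)
netG.load_state_dict(ckpt['netG'])

z_dim = 100  # illustrative; use the z_dim of your training run
with torch.no_grad():
    txt_feat, _ = text_encoder(*txt)
    noise = torch.randn(1, z_dim, device=device)
    fake_imgs, _, _ = netG(noise, txt_feat)  # list of images from coarse to fine

save_image(fake_imgs[-1], 'generated_salad.jpg', normalize=True)
print('saved generated_salad.jpg for recipe:', title)
```

The call convention `fake_imgs, _, _ = netG(noise, txt_feat)` follows `cookgan/eval_cookgan.py`; everything marked as an assumption above may differ from the released notebook.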
102 | 103 | 186 | 187 | # License 188 | MIT -------------------------------------------------------------------------------- /retrieval_model/train_retrieval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.utils.data import DataLoader 4 | from torch import optim 5 | from tqdm import tqdm 6 | import numpy as np 7 | import os 8 | import pdb 9 | import wandb 10 | 11 | from args_retrieval import get_parser 12 | from datasets_retrieval import Dataset, train_transform 13 | from models_retrieval import TextEncoder, ImageEncoder 14 | from triplet_loss import global_loss, TripletLoss 15 | from modules import DynamicSoftMarginLoss 16 | import sys 17 | sys.path.append('../') 18 | from common import param_counter, sample_data 19 | 20 | def create_model(ckpt_args, device='cuda'): 21 | text_encoder = TextEncoder( 22 | emb_dim=ckpt_args.word2vec_dim, 23 | hid_dim=ckpt_args.rnn_hid_dim, 24 | z_dim=ckpt_args.feature_dim, 25 | word2vec_file=ckpt_args.word2vec_file, 26 | text_info=ckpt_args.text_info, 27 | with_attention=ckpt_args.with_attention, 28 | ingrs_enc_type=ckpt_args.ingrs_enc_type) 29 | image_encoder = ImageEncoder( 30 | z_dim=ckpt_args.feature_dim) 31 | text_encoder, image_encoder = [x.to(device) for x in [text_encoder, image_encoder]] 32 | print('# text_encoder', param_counter(text_encoder.parameters())) 33 | print('# image_encoder', param_counter(image_encoder.parameters())) 34 | if device == 'cuda': 35 | text_encoder, image_encoder = [nn.DataParallel(x) for x in [text_encoder, image_encoder]] 36 | optimizer = torch.optim.Adam([ 37 | {'params': text_encoder.parameters()}, 38 | {'params': image_encoder.parameters()}, 39 | ], lr=ckpt_args.lr, betas=(0.5, 0.999)) 40 | return text_encoder, image_encoder, optimizer 41 | 42 | def load_model(ckpt_path, device='cuda'): 43 | print('load retrieval model from:', ckpt_path) 44 | ckpt = torch.load(ckpt_path) 45 | ckpt_args = ckpt['args'] 46 | batch_idx = ckpt['batch_idx'] 47 | text_encoder, image_encoder, optimizer = create_model(ckpt_args, device) 48 | if device=='cpu': 49 | text_encoder.load_state_dict(ckpt['text_encoder']) 50 | image_encoder.load_state_dict(ckpt['image_encoder']) 51 | else: 52 | text_encoder.module.load_state_dict(ckpt['text_encoder']) 53 | image_encoder.module.load_state_dict(ckpt['image_encoder']) 54 | optimizer.load_state_dict(ckpt['optimizer']) 55 | 56 | return ckpt_args, batch_idx, text_encoder, image_encoder, optimizer 57 | 58 | 59 | def save_model(args, batch_idx, text_encoder, image_encoder, optimizer, ckpt_path): 60 | print('save retrieval model to:', ckpt_path) 61 | ckpt = { 62 | 'args': args, 63 | 'batch_idx': batch_idx, 64 | 'text_encoder': text_encoder.state_dict(), 65 | 'image_encoder': image_encoder.state_dict(), 66 | 'optimizer': optimizer.state_dict(), 67 | } 68 | torch.save(ckpt, ckpt_path) 69 | 70 | # hinge loss 71 | def compute_loss(txt_feat, img_feat, device='cuda'): 72 | BS = txt_feat.shape[0] 73 | denom = img_feat.norm(p=2, dim=1, keepdim=True) @ txt_feat.norm(p=2, dim=1, keepdim=True).t() 74 | numer = img_feat @ txt_feat.t() 75 | sim = numer / (denom + 1e-12) 76 | margin = 0.3 * torch.ones_like(sim) 77 | mask = torch.eye(margin.shape[0], margin.shape[1]).bool().to(device) 78 | margin.masked_fill_(mask, 0) 79 | pos_sim = (torch.diag(sim) * torch.ones(BS, BS).to(device)).t() # [BS, BS] 80 | loss_retrieve_txt = torch.max( 81 | torch.tensor(0.0).to(device), 82 | margin + sim - pos_sim) 83 | loss_retrieve_img = torch.max( 84 | 
torch.tensor(0.0).to(device), 85 | margin + sim.t() - pos_sim) 86 | loss = loss_retrieve_img + loss_retrieve_txt 87 | # effective number of pairs is BS*BS-BS, those on the diagnal are never counted and always zero 88 | loss = loss.sum() / (BS*BS-BS) / 2.0 89 | return loss 90 | 91 | def train(args, start_batch_idx, text_encoder, image_encoder, optimizer, train_loader, device='cuda'): 92 | if args.loss_type == 'hinge': 93 | criterion = compute_loss 94 | elif args.loss_type == 'hardmining+hinge': 95 | triplet_loss = TripletLoss(margin=args.margin) 96 | elif args.loss_type == 'dynamic_soft_margin': 97 | criterion = DynamicSoftMarginLoss(is_binary=False, nbins=args.batch_size // 2) 98 | criterion = criterion.to(device) 99 | 100 | ##################### 101 | # train 102 | ##################### 103 | wandb.init(project="cookgan_retrieval_model") 104 | wandb.config.update(args) 105 | 106 | pbar = range(args.batches) 107 | pbar = tqdm(pbar, initial=start_batch_idx, dynamic_ncols=True, smoothing=0.3) 108 | 109 | text_encoder.train() 110 | image_encoder.train() 111 | if device=='cuda': 112 | text_module = text_encoder.module 113 | image_module = image_encoder.module 114 | else: 115 | text_module = text_encoder 116 | image_module = image_encoder 117 | train_loader = sample_data(train_loader) 118 | 119 | for batch_idx in pbar: 120 | txt, img = next(train_loader) 121 | for i in range(len(txt)): 122 | txt[i] = txt[i].to(device) 123 | img = img.to(device) 124 | 125 | txt_feat, _ = text_encoder(*txt) 126 | img_feat = image_encoder(img) 127 | bs = img.shape[0] 128 | if args.loss_type == 'hinge': 129 | loss = criterion(img_feat, txt_feat, device) 130 | elif args.loss_type == 'hardmining+hinge': 131 | label = list(range(0, bs)) 132 | label.extend(label) 133 | label = np.array(label) 134 | label = torch.tensor(label).long().to(device) 135 | loss = global_loss(triplet_loss, torch.cat((img_feat, txt_feat)), label, normalize_feature=True)[0] 136 | elif args.loss_type == 'dynamic_soft_margin': 137 | out = torch.cat((img_feat, txt_feat)) 138 | loss = criterion(out) 139 | 140 | optimizer.zero_grad() 141 | loss.backward() 142 | optimizer.step() 143 | 144 | wandb.log({ 145 | 'training loss': loss, 146 | 'batch_idx': batch_idx 147 | }) 148 | 149 | if batch_idx % 10_000 == 0: 150 | ckpt_path = f'{wandb.run.dir}/{batch_idx:>08d}.ckpt' 151 | save_model(args, batch_idx, text_module, image_module, optimizer, ckpt_path) 152 | 153 | if __name__ == '__main__': 154 | ############################## 155 | # setup 156 | ############################## 157 | args = get_parser().parse_args() 158 | torch.manual_seed(args.seed) 159 | np.random.seed(args.seed) 160 | torch.backends.cudnn.benchmark = True 161 | device = args.device 162 | 163 | ############################## 164 | # dataset 165 | ############################## 166 | print('loading datasets') 167 | train_set = Dataset( 168 | part='train', 169 | recipe_file=args.recipe_file, 170 | img_dir=args.img_dir, 171 | word2vec_file=args.word2vec_file, 172 | transform=train_transform, 173 | permute_ingrs=args.permute_ingrs) 174 | 175 | if args.debug: 176 | print('in debug mode') 177 | train_set = torch.utils.data.Subset(train_set, range(2000)) 178 | 179 | train_loader = DataLoader( 180 | train_set, batch_size=args.batch_size, shuffle=True, 181 | num_workers=args.workers, pin_memory=True, drop_last=False) 182 | print('train data:', len(train_set), len(train_loader)) 183 | 184 | ########################## 185 | # model 186 | ########################## 187 | if args.ckpt_path: 188 | 
ckpt_args, batch_idx, text_encoder, image_encoder, optimizer = load_model(args.ckpt_path, device) 189 | else: 190 | text_encoder, image_encoder, optimizer = create_model(args, device) 191 | batch_idx = 0 192 | 193 | train(args, batch_idx, text_encoder, image_encoder, optimizer, train_loader, device='cuda') 194 | -------------------------------------------------------------------------------- /cookgan/datasets_cookgan1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.utils import data 4 | from torch.nn import functional as F 5 | from torchvision import transforms 6 | import pickle 7 | import numpy as np 8 | import os 9 | import json 10 | from gensim.models.keyedvectors import KeyedVectors 11 | from PIL import Image 12 | from tqdm import tqdm 13 | import matplotlib.pyplot as plt 14 | 15 | import sys 16 | sys.path.append('../') 17 | from common import load_recipes, get_title_wordvec, get_ingredients_wordvec_withClasses, get_instructions_wordvec 18 | 19 | def get_imgs(img_path, imsize, bbox=None, 20 | transform=None, normalize=None, levels=3): 21 | img = Image.open(img_path).convert('RGB') 22 | width, height = img.size 23 | if bbox is not None: 24 | r = int(np.maximum(bbox[2], bbox[3]) * 0.75) 25 | center_x = int((2 * bbox[0] + bbox[2]) / 2) 26 | center_y = int((2 * bbox[1] + bbox[3]) / 2) 27 | y1 = np.maximum(0, center_y - r) 28 | y2 = np.minimum(height, center_y + r) 29 | x1 = np.maximum(0, center_x - r) 30 | x2 = np.minimum(width, center_x + r) 31 | img = img.crop([x1, y1, x2, y2]) 32 | 33 | if transform is not None: 34 | img = transform(img) 35 | 36 | ret = [] 37 | for i in range(levels): 38 | if i < (levels - 1): 39 | re_img = transforms.Resize(imsize[i])(img) 40 | else: 41 | re_img = img 42 | ret.append(normalize(re_img)) 43 | 44 | return ret 45 | 46 | def choose_one_image_path(rcp, img_dir): 47 | part = rcp['partition'] 48 | image_infos = rcp['images'] 49 | if part == 'train': 50 | # We do only use the first five images per recipe during training 51 | imgIdx = np.random.choice(range(min(5, len(image_infos)))) 52 | else: 53 | imgIdx = 0 54 | 55 | loader_path = [image_infos[imgIdx]['id'][i] for i in range(4)] 56 | loader_path = os.path.join(*loader_path) 57 | if 'plus' in img_dir: 58 | path = os.path.join(img_dir, loader_path, image_infos[imgIdx]['id']) 59 | else: 60 | path = os.path.join(img_dir, part, loader_path, image_infos[imgIdx]['id']) 61 | return path 62 | 63 | 64 | class FoodDataset(data.Dataset): 65 | def __init__( 66 | self, 67 | recipe_file, 68 | img_dir, 69 | levels=3, 70 | word2vec_file='../retrieval_model/models/word2vec_recipes.bin', 71 | vocab_ingrs_file='../manual_files/list_of_merged_ingredients.txt', 72 | part='train', 73 | food_type='salad', 74 | base_size=64, 75 | transform=None, 76 | num_samples=None): 77 | self.transform = transform 78 | self.norm = transforms.Compose([ 79 | transforms.ToTensor(), 80 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 81 | 82 | self.imsize = [] 83 | self.levels = levels 84 | self.recipe_file = recipe_file 85 | self.img_dir = img_dir 86 | for _ in range(levels): 87 | self.imsize.append(base_size) 88 | base_size = base_size * 2 89 | 90 | self.recipes = load_recipes(recipe_file, part) 91 | if food_type: 92 | self.recipes = [x for x in self.recipes if food_type.lower() in x['title'].lower()] 93 | if num_samples: 94 | N = min(len(self.recipes), num_samples) 95 | self.recipes = np.random.choice(self.recipes, N, replace=False) 96 | 97 | wv 
= KeyedVectors.load(word2vec_file, mmap='r') 98 | w2i = {w: i+2 for i, w in enumerate(wv.index2word)} 99 | w2i[''] = 1 100 | self.w2i = w2i 101 | 102 | with open(vocab_ingrs_file, 'r') as f: 103 | vocab_ingrs = f.read().strip().split('\n') 104 | self.ingr2i = {ingr:i for i,ingr in enumerate(vocab_ingrs)} 105 | 106 | def __getitem__(self, index): 107 | rcp = self.recipes[index] 108 | id_=rcp['id'] 109 | ''' 110 | title, n_words_in_title = get_title_wordvec(rcp, self.w2i) # np.int [max_len] 111 | ingredients, n_ingrs, _ = get_ingredients_wordvec_withClasses(rcp, self.w2i, self.ingr2i) # np.int [max_len] 112 | instructions, n_insts, n_words_each_inst = get_instructions_wordvec(rcp, self.w2i) # np.int [max_len, max_len] 113 | txt = (title, n_words_in_title, ingredients, n_ingrs, instructions, n_insts, n_words_each_inst) 114 | 115 | img_name = choose_one_image_path(rcp, self.img_dir) 116 | imgs = get_imgs(img_name, self.imsize, transform=self.transform, normalize=self.norm, levels=self.levels) 117 | 118 | all_idx = range(len(self.recipes)) 119 | wrong_idx = np.random.choice(all_idx) 120 | while wrong_idx == index: 121 | wrong_idx = np.random.choice(all_idx) 122 | wrong_img_name = choose_one_image_path(self.recipes[wrong_idx], self.img_dir) 123 | wrong_imgs = get_imgs(wrong_img_name, self.imsize, transform=self.transform, normalize=self.norm, levels=self.levels) 124 | 125 | return txt, imgs, wrong_imgs, rcp['title'],id_ 126 | ''' 127 | return id_ 128 | 129 | 130 | def __len__(self): 131 | return len(self.recipes) 132 | 133 | if __name__ == '__main__': 134 | class Args: pass 135 | args = Args() 136 | args.base_size = 64 137 | args.levels = 3 138 | args.recipe_file = '../data/Recipe1M/recipes_withImage.json' 139 | args.img_dir = '../data/Recipe1M/images' 140 | args.food_type = 'salad' 141 | args.batch_size = 32 142 | args.workers = 4 143 | 144 | imsize = args.base_size * (2 ** (args.levels-1)) 145 | train_transform = transforms.Compose([ 146 | transforms.Resize(int(imsize * 76 / 64)), 147 | transforms.RandomCrop(imsize), 148 | transforms.RandomHorizontalFlip()]) 149 | train_set = FoodDataset( 150 | recipe_file=args.recipe_file, 151 | img_dir=args.img_dir, 152 | levels=args.levels, 153 | part='train', 154 | food_type=args.food_type, 155 | base_size=args.base_size, 156 | transform=train_transform) 157 | train_loader = torch.utils.data.DataLoader( 158 | train_set, batch_size=args.batch_size, 159 | drop_last=False, shuffle=False, num_workers=int(args.workers)) 160 | 161 | l=list() 162 | #for txt, imgs, w_imgs, title,id_ in tqdm(train_loader): 163 | for id_ in tqdm(train_loader): 164 | l=l+list(id_) 165 | print(len(l)) 166 | 167 | train_set = FoodDataset( 168 | recipe_file=args.recipe_file, 169 | img_dir=args.img_dir, 170 | levels=args.levels, 171 | part='train', 172 | food_type='cookie', 173 | base_size=args.base_size, 174 | transform=train_transform) 175 | train_loader = torch.utils.data.DataLoader( 176 | train_set, batch_size=args.batch_size, 177 | drop_last=False, shuffle=False, num_workers=int(args.workers)) 178 | #for txt, imgs, w_imgs, title,id_ in tqdm(train_loader): 179 | for id_ in tqdm(train_loader): 180 | l=l+list(id_) 181 | print(len(l)) 182 | train_set = FoodDataset( 183 | recipe_file=args.recipe_file, 184 | img_dir=args.img_dir, 185 | levels=args.levels, 186 | part='train', 187 | food_type='muffin', 188 | base_size=args.base_size, 189 | transform=train_transform) 190 | train_loader = torch.utils.data.DataLoader( 191 | train_set, batch_size=args.batch_size, 192 | drop_last=False, 
shuffle=False, num_workers=int(args.workers)) 193 | #for txt, imgs, w_imgs, title,id_ in tqdm(train_loader): 194 | for id_ in tqdm(train_loader): 195 | l=l+list(id_) 196 | print(len(l)) 197 | D=pickle.load(open('recipe_vector1.pkl','rb')) 198 | D1=pickle.load(open('id_vector.pkl','rb')) 199 | counter=0 200 | cvg=list() 201 | idt=np.zeros(1362) 202 | N=0 203 | for id_ in l: 204 | if id_ in D: 205 | counter=counter+1 206 | idn=np.sum(D1[id_]>0) 207 | an=np.sum(D[id_]['number']>0) 208 | cvg.append(an/idn) 209 | idt=idt+(D1[id_]>0) 210 | N=N+idn 211 | ''' 212 | cvg=np.array(cvg) 213 | print(np.sum(cvg==1)) 214 | print(np.mean(cvg)) 215 | plt.hist(cvg) 216 | plt.show() 217 | ''' 218 | print(np.sum(idt>0)) 219 | print(N/counter) 220 | #print(counter) 221 | #print(counter/len(l)) -------------------------------------------------------------------------------- /retrieval_model/eval_ingr_retrieval.py: -------------------------------------------------------------------------------- 1 | # TODO: finish this script 2 | ''' 3 | CUDA_VISIBLE_DEVICES=0 python eval_ingr_retrieval.py \ 4 | --batch_size=32 --resume=models/010.ckpt \ 5 | --food_type=salad --hot_ingr=tomato --save_dir=experiments \ 6 | --generation_model=generative_model/models/salad.ckpt 7 | ''' 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from networks import TextEncoder, ImageEncoder 13 | from dataset import Dataset 14 | from torch.utils.data import DataLoader 15 | from torchvision import transforms 16 | from torchvision.utils import save_image 17 | from args import args 18 | from tqdm import tqdm 19 | import numpy as np 20 | import os 21 | import sys 22 | import math 23 | import pdb 24 | from copy import deepcopy 25 | from glob import glob 26 | from PIL import Image 27 | import json 28 | from utils import load_recipes, load_dict, load_retrieval_model, load_generation_model 29 | from utils import compute_img_feature, compute_txt_feature, transform 30 | from inflection import singularize, pluralize 31 | 32 | # type_ = 'salad' 33 | # hot_ingr = ['tomato', 'cucumber', 'black_olife', 'avocado', 'carrot', 'red_pepper'] 34 | 35 | # type_ = 'cookie' 36 | # hot_ingr = ['walnut', 'chocolate', 'coconut', 'molass', 'orange'] 37 | 38 | # type_ = 'muffin' 39 | # hot_ingr = ['blueberry', 'chocolate', 'oat', 'banana', 'cranberry'] 40 | 41 | assert args.resume != '' 42 | 43 | food_type = args.food_type 44 | hot_ingr = args.hot_ingr 45 | tops = 5 46 | 47 | generation_model_name = args.generation_model.rsplit('/', 1)[-1].rsplit('.', 1)[0] 48 | save_dir = os.path.join(args.save_dir, generation_model_name) 49 | if not os.path.exists(save_dir): 50 | print('create directory:', save_dir) 51 | os.makedirs(save_dir) 52 | 53 | recipes = load_recipes(os.path.join(args.data_dir,'recipesV1.json'), 'val') 54 | recipes = [x for x in recipes if food_type.lower() in x['title'].lower()] 55 | print('# {} recipes:'.format(food_type), len(recipes)) 56 | vocab_inst = load_dict(os.path.join(args.data_dir, 'vocab_inst.txt')) 57 | print('#vocab_inst:', len(vocab_inst)) 58 | vocab_ingr = load_dict(os.path.join(args.data_dir, 'vocab_ingr.txt')) 59 | print('#vocab_ingr:', len(vocab_ingr)) 60 | 61 | for recipe in recipes: 62 | ingrediens_list = recipe['ingredients'] 63 | recipe['new_ingrs'] = [] 64 | for name in ingrediens_list: 65 | if hot_ingr in name: 66 | recipe['new_ingrs'].append(hot_ingr) 67 | else: 68 | recipe['new_ingrs'].append(name) 69 | 70 | device = torch.device('cuda' \ 71 | if torch.cuda.is_available() and args.cuda 72 | 
else 'cpu') 73 | print('device:', device) 74 | if device.__str__() == 'cpu': 75 | args.batch_size = 16 76 | 77 | TxtEnc, ImgEnc = load_retrieval_model(args.resume, device) 78 | 79 | mean = [0.485, 0.456, 0.406] 80 | std = [0.229, 0.224, 0.225] 81 | 82 | print('\nhot_ingr:', hot_ingr) 83 | hot_recipes = [x for x in recipes if hot_ingr in x['new_ingrs']] 84 | cold_recipes = [x for x in recipes if hot_ingr not in x['new_ingrs']] 85 | print('#with={}/{} = {:.2f}, #without={}/{} = {:.2f}'.format( 86 | len(hot_recipes), len(recipes), 1.0*len(hot_recipes)/len(recipes), 87 | len(cold_recipes), len(recipes), 1.0*len(cold_recipes)/len(recipes))) 88 | 89 | recipes_a = [] 90 | recipes_b = [] 91 | threshold = 0.7 92 | for rcp_a in hot_recipes: 93 | tmp = deepcopy(rcp_a['new_ingrs']) 94 | tmp.remove(hot_ingr) 95 | ingrs_a = set(tmp) 96 | for rcp_b in cold_recipes: 97 | ingrs_b = set(rcp_b['new_ingrs']) 98 | union = ingrs_a.union(ingrs_b) 99 | common = ingrs_a.intersection(ingrs_b) 100 | if 1.0*len(common)/len(union) >= threshold: 101 | recipes_a.append(rcp_a) 102 | recipes_b.append(rcp_b) 103 | print('#{} pairs (IoU={:.2f}) = {}'.format(hot_ingr, threshold, len(recipes_a))) 104 | 105 | ids_a = set() 106 | uniques_a = [] 107 | for rcp in recipes_a: 108 | if rcp['id'] not in ids_a: 109 | uniques_a.append(rcp) 110 | ids_a.add(rcp['id']) 111 | ids_b = set() 112 | uniques_b = [] 113 | for rcp in recipes_b: 114 | if rcp['id'] not in ids_b: 115 | uniques_b.append(rcp) 116 | ids_b.add(rcp['id']) 117 | print('#unique = {}, #unique_ = {}'.format(len(uniques_a), len(uniques_b))) 118 | 119 | ##################################### 120 | print('-' * 40) 121 | print('compute REAL image features for interesting recipes') 122 | ##################################### 123 | if len(uniques_a)>1 and len(uniques_b)>1: 124 | print('compute REAL image features for recipes with {}'.format(hot_ingr)) 125 | img_a, img_feat_a = compute_img_feature(uniques_a, args.img_dir, ImgEnc, transform, device) 126 | save_image( 127 | img_a[:64], 128 | os.path.join(save_dir, '{}_with.jpg'.format(hot_ingr)), 129 | normalize=True) 130 | img_feat_a = img_feat_a.detach().cpu().numpy() 131 | 132 | print('compute REAL image features for recipes without {}'.format(hot_ingr)) 133 | img_b, img_feat_b = compute_img_feature(uniques_b, args.img_dir, ImgEnc, transform, device) 134 | save_image( 135 | img_b[:64], 136 | os.path.join(save_dir, '{}_without.jpg'.format(hot_ingr)), 137 | normalize=True) 138 | img_feat_b = img_feat_b.detach().cpu().numpy() 139 | else: 140 | print('unable to compute') 141 | sys.exit(-1) 142 | 143 | 144 | # ****************************************************** 145 | print('-' * 40) 146 | print('compute text features for all recipes') 147 | # ***************************************************** 148 | _, txt_feats = compute_txt_feature(recipes, TxtEnc, vocab_inst, vocab_ingr, device) 149 | txt_feats = txt_feats.cpu().numpy() 150 | 151 | # ****************************************************** 152 | print('-' * 40) 153 | print('compute coverage among the top {} retrieved recipes'.format(tops)) 154 | # ***************************************************** 155 | def compute_ingredient_retrival_score(imgs, txts, tops): 156 | imgs = imgs / np.linalg.norm(imgs, axis=1)[:, None] 157 | txts = txts / np.linalg.norm(txts, axis=1)[:, None] 158 | # retrieve recipe 159 | sims = np.dot(imgs, txts.T) # [N, N] 160 | # loop through the N similarities for images 161 | cvgs = [] 162 | for ii in range(imgs.shape[0]): 163 | # get a row of similarities 
for image ii 164 | sim = sims[ii,:] 165 | # sort indices in descending order 166 | sorting = np.argsort(sim)[::-1].tolist() 167 | topk_idxs = sorting[:tops] 168 | success = 0.0 169 | for rcp_idx in topk_idxs: 170 | rcp = recipes[rcp_idx] 171 | ingrs = rcp['new_ingrs'] 172 | if hot_ingr in ingrs: 173 | success += 1 174 | cvgs.append(success / tops) 175 | return np.array(cvgs) 176 | 177 | cvgs = compute_ingredient_retrival_score(img_feat_a, txt_feats, tops) 178 | print('Top {} avg coverage with {} (#={}) = {:.2f} ({:.2f})'.format( 179 | tops, hot_ingr, len(uniques_a), cvgs.mean(), cvgs.std())) 180 | cvgs = compute_ingredient_retrival_score(img_feat_b, txt_feats, tops) 181 | print('Top {} avg coverage without {} (#={}) = {:.2f} ({:.2f})'.format( 182 | tops, hot_ingr, len(uniques_b), cvgs.mean(), cvgs.std())) 183 | 184 | 185 | # ****************************************************** 186 | print('-' * 40) 187 | print('compute coverage for interpolation between with and without {}'.format(hot_ingr)) 188 | # ***************************************************** 189 | print('load pretrained generative model') 190 | netG = load_generation_model(args.generation_model, device) 191 | 192 | with open(os.path.join(save_dir, '{}_with.json'.format(hot_ingr)), 'w') as f: 193 | json.dump(uniques_a, f, indent=2) 194 | with open(os.path.join(save_dir, '{}_without.json'.format(hot_ingr)), 'w') as f: 195 | json.dump(uniques_b, f, indent=2) 196 | N = min(len(uniques_a), len(uniques_b)) 197 | N = min(N, 128) 198 | print('compute text features for recipes with {}'.format(hot_ingr)) 199 | _, txt_feat_y = compute_txt_feature(uniques_a[:N], TxtEnc, vocab_inst, vocab_ingr, device) 200 | print('compute text features for recipes without {}'.format(hot_ingr)) 201 | _, txt_feat_n = compute_txt_feature(uniques_b[:N], TxtEnc, vocab_inst, vocab_ingr, device) 202 | interpolate_points = [1.0, 0.75, 0.5, 0.25, 0.0] 203 | print('interpolate points:', interpolate_points) 204 | if args.cuda: 205 | fixed_noise = torch.zeros(1, 100).to(device) 206 | fixed_noise = fixed_noise.repeat(N, 1) 207 | 208 | imgs_all = [] 209 | for w_y in interpolate_points: 210 | txt_embedding = w_y*txt_feat_y + (1-w_y)*txt_feat_n 211 | with torch.no_grad(): 212 | fake_imgs, _, _ = netG(fixed_noise, txt_embedding) 213 | imgs = fake_imgs[-1] # those 256x256 images 214 | imgs_all.append(imgs[:12]) 215 | imgs = imgs/2 + 0.5 216 | imgs = F.interpolate(imgs, [224, 224], mode='bilinear', align_corners=True) 217 | for i in range(imgs.shape[1]): 218 | imgs[:,i] = (imgs[:,i]-mean[i])/std[i] 219 | with torch.no_grad(): 220 | img_feats = ImgEnc(imgs).detach().cpu().numpy() 221 | cvgs = compute_ingredient_retrival_score(img_feats, txt_feats, tops) 222 | print('with/without={:.2f}/{:.2f}, avg cvg (over {} recipes)={:.2f} ({:.2f})'.format( 223 | w_y, 1-w_y, N, cvgs.mean(), cvgs.std())) 224 | 225 | imgs_all = torch.stack(imgs_all) # [5, 8, 3, 256, 256] 226 | imgs_all = imgs_all.permute(1,0,2,3,4).contiguous() # [8, 5, 3, 256, 256] 227 | imgs_all = imgs_all.view(-1, 3, 256, 256) # [40, 3, 256, 256] 228 | save_image( 229 | imgs_all, 230 | os.path.join(save_dir, '{}_interpolations.jpg'.format(hot_ingr)), 231 | nrow=5, 232 | normalize=True, 233 | scale_each=True) -------------------------------------------------------------------------------- /metrics/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | try: 7 | 
from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=[DEFAULT_BLOCK_INDEX], 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | 39 | Parameters 40 | ---------- 41 | output_blocks : list of int 42 | Indices of blocks to return features of. Possible values are: 43 | - 0: corresponds to output of first max pooling 44 | - 1: corresponds to output of second max pooling 45 | - 2: corresponds to output which is fed to aux classifier 46 | - 3: corresponds to output of final average pooling 47 | resize_input : bool 48 | If true, bilinearly resizes input to width and height 299 before 49 | feeding input to model. As the network without fully connected 50 | layers is fully convolutional, it should be able to handle inputs 51 | of arbitrary size, so resizing might not be strictly needed 52 | normalize_input : bool 53 | If true, scales the input from range (0, 1) to the range the 54 | pretrained Inception network expects, namely (-1, 1) 55 | requires_grad : bool 56 | If true, parameters of the model require gradients. Possibly useful 57 | for finetuning the network 58 | use_fid_inception : bool 59 | If true, uses the pretrained Inception model used in Tensorflow's 60 | FID implementation. If false, uses the pretrained Inception model 61 | available in torchvision. The FID Inception model has different 62 | weights and a slightly different structure from torchvision's 63 | Inception model. If you want to compute FID scores, you are 64 | strongly advised to set this parameter to true to get comparable 65 | results. 
66 | """ 67 | super(InceptionV3, self).__init__() 68 | 69 | self.resize_input = resize_input 70 | self.normalize_input = normalize_input 71 | self.output_blocks = sorted(output_blocks) 72 | self.last_needed_block = max(output_blocks) 73 | 74 | assert self.last_needed_block <= 3, \ 75 | 'Last possible output block index is 3' 76 | 77 | self.blocks = nn.ModuleList() 78 | 79 | if use_fid_inception: 80 | inception = fid_inception_v3() 81 | else: 82 | inception = models.inception_v3(pretrained=True, init_weights=True) 83 | 84 | # Block 0: input to maxpool1 85 | block0 = [ 86 | inception.Conv2d_1a_3x3, 87 | inception.Conv2d_2a_3x3, 88 | inception.Conv2d_2b_3x3, 89 | nn.MaxPool2d(kernel_size=3, stride=2) 90 | ] 91 | self.blocks.append(nn.Sequential(*block0)) 92 | 93 | # Block 1: maxpool1 to maxpool2 94 | if self.last_needed_block >= 1: 95 | block1 = [ 96 | inception.Conv2d_3b_1x1, 97 | inception.Conv2d_4a_3x3, 98 | nn.MaxPool2d(kernel_size=3, stride=2) 99 | ] 100 | self.blocks.append(nn.Sequential(*block1)) 101 | 102 | # Block 2: maxpool2 to aux classifier 103 | if self.last_needed_block >= 2: 104 | block2 = [ 105 | inception.Mixed_5b, 106 | inception.Mixed_5c, 107 | inception.Mixed_5d, 108 | inception.Mixed_6a, 109 | inception.Mixed_6b, 110 | inception.Mixed_6c, 111 | inception.Mixed_6d, 112 | inception.Mixed_6e, 113 | ] 114 | self.blocks.append(nn.Sequential(*block2)) 115 | 116 | # Block 3: aux classifier to final avgpool 117 | if self.last_needed_block >= 3: 118 | block3 = [ 119 | inception.Mixed_7a, 120 | inception.Mixed_7b, 121 | inception.Mixed_7c, 122 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 123 | ] 124 | self.blocks.append(nn.Sequential(*block3)) 125 | 126 | for param in self.parameters(): 127 | param.requires_grad = requires_grad 128 | 129 | def forward(self, inp): 130 | """Get Inception feature maps 131 | 132 | Parameters 133 | ---------- 134 | inp : torch.autograd.Variable 135 | Input tensor of shape Bx3xHxW. Values are expected to be in 136 | range (0, 1) 137 | 138 | Returns 139 | ------- 140 | List of torch.autograd.Variable, corresponding to the selected output 141 | block, sorted ascending by index 142 | """ 143 | outp = [] 144 | x = inp 145 | 146 | if self.resize_input: 147 | x = F.interpolate(x, 148 | size=(299, 299), 149 | mode='bilinear', 150 | align_corners=False) 151 | 152 | if self.normalize_input: 153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 154 | 155 | for idx, block in enumerate(self.blocks): 156 | x = block(x) 157 | if idx in self.output_blocks: 158 | outp.append(x) 159 | 160 | if idx == self.last_needed_block: 161 | break 162 | 163 | return outp 164 | 165 | 166 | def fid_inception_v3(): 167 | """Build pretrained Inception model for FID computation 168 | 169 | The Inception model for FID computation uses a different set of weights 170 | and has a slightly different structure than torchvision's Inception. 171 | 172 | This method first constructs torchvision's Inception and then patches the 173 | necessary parts that are different in the FID Inception model. 
174 | """ 175 | inception = models.inception_v3(num_classes=1008, 176 | aux_logits=False, 177 | pretrained=False, 178 | init_weights=True) 179 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 180 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 181 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 182 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 183 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 184 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 185 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 186 | inception.Mixed_7b = FIDInceptionE_1(1280) 187 | inception.Mixed_7c = FIDInceptionE_2(2048) 188 | 189 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 190 | inception.load_state_dict(state_dict) 191 | return inception 192 | 193 | 194 | class FIDInceptionA(models.inception.InceptionA): 195 | """InceptionA block patched for FID computation""" 196 | def __init__(self, in_channels, pool_features): 197 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 198 | 199 | def forward(self, x): 200 | branch1x1 = self.branch1x1(x) 201 | 202 | branch5x5 = self.branch5x5_1(x) 203 | branch5x5 = self.branch5x5_2(branch5x5) 204 | 205 | branch3x3dbl = self.branch3x3dbl_1(x) 206 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 207 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 208 | 209 | # Patch: Tensorflow's average pool does not use the padded zero's in 210 | # its average calculation 211 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 212 | count_include_pad=False) 213 | branch_pool = self.branch_pool(branch_pool) 214 | 215 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 216 | return torch.cat(outputs, 1) 217 | 218 | 219 | class FIDInceptionC(models.inception.InceptionC): 220 | """InceptionC block patched for FID computation""" 221 | def __init__(self, in_channels, channels_7x7): 222 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 223 | 224 | def forward(self, x): 225 | branch1x1 = self.branch1x1(x) 226 | 227 | branch7x7 = self.branch7x7_1(x) 228 | branch7x7 = self.branch7x7_2(branch7x7) 229 | branch7x7 = self.branch7x7_3(branch7x7) 230 | 231 | branch7x7dbl = self.branch7x7dbl_1(x) 232 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 233 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 234 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 235 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 236 | 237 | # Patch: Tensorflow's average pool does not use the padded zero's in 238 | # its average calculation 239 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 240 | count_include_pad=False) 241 | branch_pool = self.branch_pool(branch_pool) 242 | 243 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 244 | return torch.cat(outputs, 1) 245 | 246 | 247 | class FIDInceptionE_1(models.inception.InceptionE): 248 | """First InceptionE block patched for FID computation""" 249 | def __init__(self, in_channels): 250 | super(FIDInceptionE_1, self).__init__(in_channels) 251 | 252 | def forward(self, x): 253 | branch1x1 = self.branch1x1(x) 254 | 255 | branch3x3 = self.branch3x3_1(x) 256 | branch3x3 = [ 257 | self.branch3x3_2a(branch3x3), 258 | self.branch3x3_2b(branch3x3), 259 | ] 260 | branch3x3 = torch.cat(branch3x3, 1) 261 | 262 | branch3x3dbl = self.branch3x3dbl_1(x) 263 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 264 | branch3x3dbl = [ 265 | self.branch3x3dbl_3a(branch3x3dbl), 266 | 
self.branch3x3dbl_3b(branch3x3dbl), 267 | ] 268 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 269 | 270 | # Patch: Tensorflow's average pool does not use the padded zero's in 271 | # its average calculation 272 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 273 | count_include_pad=False) 274 | branch_pool = self.branch_pool(branch_pool) 275 | 276 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 277 | return torch.cat(outputs, 1) 278 | 279 | 280 | class FIDInceptionE_2(models.inception.InceptionE): 281 | """Second InceptionE block patched for FID computation""" 282 | def __init__(self, in_channels): 283 | super(FIDInceptionE_2, self).__init__(in_channels) 284 | 285 | def forward(self, x): 286 | branch1x1 = self.branch1x1(x) 287 | 288 | branch3x3 = self.branch3x3_1(x) 289 | branch3x3 = [ 290 | self.branch3x3_2a(branch3x3), 291 | self.branch3x3_2b(branch3x3), 292 | ] 293 | branch3x3 = torch.cat(branch3x3, 1) 294 | 295 | branch3x3dbl = self.branch3x3dbl_1(x) 296 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 297 | branch3x3dbl = [ 298 | self.branch3x3dbl_3a(branch3x3dbl), 299 | self.branch3x3dbl_3b(branch3x3dbl), 300 | ] 301 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 302 | 303 | # Patch: The FID Inception model uses max pooling instead of average 304 | # pooling. This is likely an error in this specific Inception 305 | # implementation, as other Inception models use average pooling here 306 | # (which matches the description in the paper). 307 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 308 | branch_pool = self.branch_pool(branch_pool) 309 | 310 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 311 | return torch.cat(outputs, 1) -------------------------------------------------------------------------------- /retrieval_model/models_retrieval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.nn.utils import rnn 7 | from torchvision import models 8 | from gensim.models.keyedvectors import KeyedVectors 9 | import pdb 10 | 11 | 12 | def clean_state_dict(state_dict): 13 | # create new OrderedDict that does not contain `module.` 14 | from collections import OrderedDict 15 | new_state_dict = OrderedDict() 16 | for k, v in state_dict.items(): 17 | name = k[7:] if k[:min(6,len(k))] == 'module' else k # remove `module.` 18 | new_state_dict[name] = v 19 | return new_state_dict 20 | 21 | class AttentionLayer(nn.Module): 22 | def __init__(self, input_dim): 23 | super(AttentionLayer, self).__init__() 24 | self.u = torch.nn.Parameter(torch.randn(input_dim)) # u = [2*hid_dim] 25 | self.u.requires_grad = True 26 | self.fc = nn.Linear(input_dim, input_dim) 27 | def forward(self, x): 28 | # x = [BS, num_vec, 2*hid_dim] 29 | mask = (x!=0) 30 | # a trick used to find the mask for the softmax 31 | mask = mask[:,:,0].bool() 32 | h = torch.tanh(self.fc(x)) # h = [BS, num_vec, 2*hid_dim] 33 | tmp = h @ self.u # tmp = [BS, num_vec], unnormalized importance 34 | masked_tmp = tmp.masked_fill(~mask, -1e32) 35 | alpha = F.softmax(masked_tmp, dim=1) # alpha = [BS, num_vec], normalized importance 36 | alpha = alpha.unsqueeze(-1) # alpha = [BS, num_vec, 1] 37 | out = x * alpha # out = [BS, num_vec, 2*hid_dim] 38 | out = out.sum(dim=1) # out = [BS, 2*hid_dim] 39 | # pdb.set_trace() 40 | return out 41 | 42 | 43 | class InstEmbedLayer(nn.Module): 44 | def __init__(self, data_dir, emb_dim): 45 | 
super(InstEmbedLayer, self).__init__() 46 | self.data_dir = data_dir 47 | path = os.path.join(self.data_dir, 'word2vec.bin') 48 | # model = KeyedVectors.load_word2vec_format(path, binary=True) 49 | wv = KeyedVectors.load(path, mmap='r') 50 | vec = torch.from_numpy(wv.vectors).float() 51 | # first three index has special meaning, see utils.py 52 | emb = nn.Embedding(vec.shape[0]+3, vec.shape[1], padding_idx=0) 53 | emb.weight.data[3:].copy_(vec) 54 | for p in emb.parameters(): 55 | p.requires_grad = False 56 | self.embed_layer = emb 57 | # print('==> Inst embed layer', emb) 58 | 59 | def forward(self, sent_list): 60 | # sent_list [BS, max_len] 61 | return self.embed_layer(sent_list) # x=[BS, max_len, emb_dim] 62 | 63 | class IngrEmbedLayer(nn.Module): 64 | def __init__(self, data_dir, emb_dim): 65 | super(IngrEmbedLayer, self).__init__() 66 | path = os.path.join(data_dir, 'vocab_ingr.txt') 67 | with open(path, 'r') as f: 68 | num_ingr = len(f.read().split('\n')) 69 | # first three index has special meaning, see utils.py 70 | emb = nn.Embedding(num_ingr+3, emb_dim, padding_idx=0) 71 | self.embed_layer = emb 72 | # print('==> Ingr embed layer', emb) 73 | 74 | def forward(self, sent_list): 75 | # sent_list [BS, max_len] 76 | return self.embed_layer(sent_list) # x=[BS, max_len, emb_dim] 77 | 78 | class SentEncoder(nn.Module): 79 | def __init__( 80 | self, 81 | data_dir, 82 | emb_dim, 83 | hid_dim, 84 | with_attention=True, 85 | source='inst'): 86 | assert source in ('inst', 'ingr') 87 | super(SentEncoder, self).__init__() 88 | if source=='inst': 89 | self.embed_layer = InstEmbedLayer(data_dir=data_dir, emb_dim=emb_dim) 90 | elif source=='ingr': 91 | self.embed_layer = IngrEmbedLayer(data_dir=data_dir, emb_dim=emb_dim) 92 | self.rnn = nn.LSTM( 93 | input_size=emb_dim, 94 | hidden_size=hid_dim, 95 | bidirectional=True, 96 | batch_first=True) 97 | if with_attention: 98 | self.atten_layer = AttentionLayer(2*hid_dim) 99 | self.with_attention = with_attention 100 | 101 | def forward(self, sent_list): 102 | # sent_list [BS, max_len] 103 | x = self.embed_layer(sent_list) # x=[BS, max_len, emb_dim] 104 | # print(sent_list) 105 | # lens = (sent_list==1).nonzero()[:,1] + 1 106 | lens = sent_list.count_nonzero(dim=1) + 1 107 | # print(lens.shape) 108 | sorted_len, sorted_idx = lens.sort(0, descending=True) # sorted_idx=[BS], for sorting 109 | _, original_idx = sorted_idx.sort(0, descending=False) # original_idx=[BS], for unsorting 110 | # print(sorted_idx.shape, x.shape) 111 | index_sorted_idx = sorted_idx.view(-1,1,1).expand_as(x) # sorted_idx=[BS, max_len, emb_dim] 112 | sorted_inputs = x.gather(0, index_sorted_idx.long()) # sort by num_words 113 | packed_seq = rnn.pack_padded_sequence( 114 | sorted_inputs, sorted_len.cpu().numpy(), batch_first=True) 115 | 116 | if self.with_attention: 117 | out, _ = self.rnn(packed_seq) 118 | y, _ = rnn.pad_packed_sequence( 119 | out, batch_first=True) # y=[BS, max_len, 2*hid_dim], currently in WRONG order! 120 | unsorted_idx = original_idx.view(-1,1,1).expand_as(y) 121 | output = y.gather(0, unsorted_idx).contiguous() # [BS, max_len, 2*hid_dim], now in correct order 122 | feat = self.atten_layer(output) 123 | else: 124 | _, (h,_) = self.rnn(packed_seq) # [2, BS, hid_dim], currently in WRONG order! 125 | h = h.transpose(0,1) # [BS, 2, hid_dim], still in WRONG order! 
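            # the batch was sorted by length above so pack_padded_sequence could be used; h holds the
            # final hidden state of each LSTM direction per sample, so the gather below restores the
            # original batch order before both directions are flattened into one 2*hid_dim sentence feature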
126 | # unsort the output 127 | unsorted_idx = original_idx.view(-1,1,1).expand_as(h) 128 | output = h.gather(0, unsorted_idx).contiguous() # [BS, 2, hid_dim], now in correct order 129 | feat = output.view(output.size(0), output.size(1)*output.size(2)) # [BS, 2*hid_dim] 130 | 131 | # print('sent', feat.shape) # [BS, 2*hid_dim] 132 | return feat 133 | 134 | 135 | class SentEncoderFC(nn.Module): 136 | def __init__( 137 | self, 138 | data_dir, 139 | emb_dim, 140 | hid_dim, 141 | with_attention=True, 142 | source='inst'): 143 | assert source in ('inst', 'ingr') 144 | super(SentEncoderFC, self).__init__() 145 | if source=='inst': 146 | self.embed_layer = InstEmbedLayer(data_dir=data_dir, emb_dim=emb_dim) 147 | elif source=='ingr': 148 | self.embed_layer = IngrEmbedLayer(data_dir=data_dir, emb_dim=emb_dim) 149 | self.fc = nn.Linear(emb_dim, 2*hid_dim) 150 | if with_attention: 151 | self.atten_layer = AttentionLayer(2*hid_dim) 152 | self.with_attention = with_attention 153 | 154 | def forward(self, sent_list): 155 | # sent_list [BS, max_len] 156 | x = self.embed_layer(sent_list) # x=[BS, max_len, emb_dim] 157 | x = self.fc(x) # [BS, max_len, 2*hid_dim] 158 | if not self.with_attention: 159 | feat = x.sum(dim=1) # [BS, 2*hid_dim] 160 | else: 161 | feat = self.atten_layer(x) # [BS, 2*hid_dim] 162 | # print('ingredients', feat.shape) 163 | return feat 164 | 165 | 166 | class DocEncoder(nn.Module): 167 | def __init__(self, sent_encoder, hid_dim, with_attention): 168 | super(DocEncoder, self).__init__() 169 | self.sent_encoder = sent_encoder 170 | self.rnn = nn.LSTM( 171 | input_size=2*hid_dim, 172 | hidden_size=hid_dim, 173 | bidirectional=True, 174 | batch_first=True) 175 | self.atten_layer_sent = AttentionLayer(2*hid_dim) 176 | self.with_attention = with_attention 177 | 178 | def forward(self, doc_list): 179 | # doc_list=[BS, max_len, max_len] 180 | embs = [] 181 | lens = [] 182 | for doc in doc_list: 183 | len_doc = doc.nonzero()[:,0].max().item() + 1 184 | lens.append(len_doc) 185 | emb_doc = self.sent_encoder(doc[:len_doc]) # [?, 2*hid_dim] 186 | embs.append(emb_doc) 187 | 188 | embs = sorted(embs, key=lambda x: -x.shape[0]) # [BS, [?, 2*hid_dim]] 189 | packed_seq = rnn.pack_sequence(embs) 190 | lens = torch.tensor(lens).long().to(embs[0].device) 191 | _, sorted_idx = lens.sort(0, descending=True) # sorted_idx=[BS], for sorting 192 | _, original_idx = sorted_idx.sort(0, descending=False) # original_idx=[BS], for unsorting 193 | 194 | if not self.with_attention: 195 | _, (h,_) = self.rnn(packed_seq) # [2, BS, hid_dim], currently in WRONG order! 196 | h = h.transpose(0,1) # [BS, 2, hid_dim], still in WRONG order! 197 | # unsort the output 198 | unsorted_idx = original_idx.view(-1,1,1).expand_as(h) 199 | output = h.gather(0, unsorted_idx).contiguous() # [BS, 2, hid_dim], now in correct order 200 | feat = output.view(output.size(0), output.size(1)*output.size(2)) # [BS, 2*hid_dim] 201 | else: 202 | out, _ = self.rnn(packed_seq) 203 | y, _ = rnn.pad_packed_sequence( 204 | out, batch_first=True) # y=[BS, max_valid_len, 2*hid_dim], currently in WRONG order! 
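            # each instruction step was embedded by sent_encoder, the per-step embeddings were run
            # through this second bi-LSTM, and pad_packed_sequence re-padded them; the gather below
            # undoes the length-sorting so attention pooling sees recipes in their original batch order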
205 | unsorted_idx = original_idx.view(-1,1,1).expand_as(y) 206 | output = y.gather(0, unsorted_idx).contiguous() # [BS, 2, hid_dim], now in correct order 207 | feat = self.atten_layer_sent(output) 208 | 209 | # print('instructions', feat.shape) 210 | return feat 211 | 212 | 213 | class TextEncoder(nn.Module): 214 | def __init__( 215 | self, data_dir, text_info, hid_dim, emb_dim, z_dim, with_attention, ingr_enc_type): 216 | super(TextEncoder, self).__init__() 217 | self.text_info = text_info 218 | if self.text_info == '111': 219 | self.sent_encoder = SentEncoder( 220 | data_dir, 221 | emb_dim, 222 | hid_dim, 223 | with_attention, 224 | source='inst') 225 | self.doc_encoder = DocEncoder( 226 | self.sent_encoder, 227 | hid_dim, 228 | with_attention 229 | ) 230 | if ingr_enc_type=='rnn': 231 | self.ingr_encoder = SentEncoder( 232 | data_dir, 233 | emb_dim, 234 | hid_dim, 235 | with_attention, 236 | source='ingr') 237 | elif ingr_enc_type == 'fc': 238 | self.ingr_encoder = SentEncoderFC( 239 | data_dir, 240 | emb_dim, 241 | hid_dim, 242 | with_attention, 243 | source='ingr') 244 | self.bn = nn.BatchNorm1d((2+2+2)*hid_dim) 245 | self.fc = nn.Linear((2+2+2)*hid_dim, z_dim) 246 | 247 | elif self.text_info == '010': 248 | if ingr_enc_type=='rnn': 249 | self.ingr_encoder = SentEncoder( 250 | data_dir, 251 | emb_dim, 252 | hid_dim, 253 | with_attention, 254 | source='ingr') 255 | elif ingr_enc_type == 'fc': 256 | self.ingr_encoder = SentEncoderFC( 257 | data_dir, 258 | emb_dim, 259 | hid_dim, 260 | with_attention, 261 | source='ingr') 262 | self.bn = nn.BatchNorm1d(2*hid_dim) 263 | self.fc = nn.Linear(2*hid_dim, z_dim) 264 | 265 | def forward(self, recipe_list): 266 | title_list, ingredients_list, instructions_list = recipe_list 267 | if self.text_info == '111': 268 | feat_title = self.sent_encoder(title_list) 269 | feat_ingredients = self.ingr_encoder(ingredients_list) 270 | feat_instructions = self.doc_encoder(instructions_list) 271 | feat = torch.cat([feat_title, feat_ingredients, feat_instructions], dim=1) 272 | feat = torch.tanh(self.fc(self.bn(feat))) 273 | elif self.text_info == '010': 274 | feat_ingredients = self.ingr_encoder(ingredients_list) 275 | feat = torch.tanh(self.fc(self.bn(feat_ingredients))) 276 | # print('recipe', feat.shape) 277 | return feat 278 | 279 | 280 | class Resnet(nn.Module): 281 | def __init__(self, ckpt_path=None): 282 | super(Resnet, self).__init__() 283 | resnet = models.resnet50(pretrained=False) 284 | num_feat = resnet.fc.in_features 285 | resnet.fc = nn.Linear(num_feat, 101) 286 | if ckpt_path: 287 | resnet.load_state_dict(clean_state_dict(torch.load(ckpt_path))) 288 | modules = list(resnet.children())[:-1] # we do not use the last fc layer. 
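        # without the classification head, the encoder ends at ResNet-50's global average pool and
        # outputs a 2048-d feature per image; ImageEncoder.bottleneck below projects it to z_dim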
289 | self.encoder = nn.Sequential(*modules) 290 | 291 | def forward(self, image_list): 292 | BS = image_list.shape[0] 293 | return self.encoder(image_list).view(BS, -1) 294 | 295 | class ImageEncoder(nn.Module): 296 | def __init__(self, z_dim, ckpt_path=None): 297 | super(ImageEncoder, self).__init__() 298 | self.resnet = Resnet(ckpt_path) 299 | self.bottleneck = nn.Sequential( 300 | nn.BatchNorm1d(2048), 301 | nn.Linear(2048, z_dim), 302 | nn.Tanh() 303 | ) 304 | 305 | def forward(self, image_list): 306 | feat = self.resnet(image_list) 307 | feat = self.bottleneck(feat) 308 | # print('image', feat.shape) 309 | return feat -------------------------------------------------------------------------------- /cookgan/train_cookgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # from torch.utils.tensorboard import SummaryWriter 3 | import wandb 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torchvision import transforms 7 | import torchvision.utils as vutils 8 | from torch import optim 9 | import numpy as np 10 | import random 11 | import pprint 12 | import os 13 | from tqdm import tqdm 14 | import json 15 | import pdb 16 | 17 | from args_cookgan import get_parser 18 | from datasets_cookgan import FoodDataset 19 | from models_cookgan import G_NET, D_NET64, D_NET128, D_NET256 20 | from utils_cookgan import prepare_data, compute_img_feat, compute_txt_feat, save_img_results 21 | 22 | import sys 23 | sys.path.append('../') 24 | from common import param_counter, sample_data, clean_state_dict, requires_grad 25 | sys.path.append('../retrieval_model') 26 | import train_retrieval 27 | import utils_retrieval 28 | 29 | def weights_init(m): 30 | classname = m.__class__.__name__ 31 | if classname.find('Conv') != -1: 32 | nn.init.orthogonal_(m.weight.data, 1.0) 33 | elif classname.find('BatchNorm') != -1: 34 | m.weight.data.normal_(1.0, 0.02) 35 | m.bias.data.fill_(0) 36 | elif classname.find('Linear') != -1: 37 | nn.init.orthogonal_(m.weight.data, 1.0) 38 | if m.bias is not None: 39 | m.bias.data.fill_(0.0) 40 | 41 | 42 | def save_model(args, batch_idx, netG, optimizerG, netsD, optimizersD, ckpt_path): 43 | print('save to:', ckpt_path) 44 | ckpt = {} 45 | ckpt['args'] = args 46 | ckpt['batch_idx'] = batch_idx 47 | ckpt['netG'] = netG.state_dict() 48 | ckpt['optimizerG'] = optimizerG.state_dict() 49 | for i in range(len(netsD)): 50 | netD = netsD[i] 51 | optimizerD = optimizersD[i] 52 | ckpt['netD_{}'.format(i)] = netD.state_dict() 53 | ckpt['optimizerD_{}'.format(i)] = optimizerD.state_dict() 54 | torch.save(ckpt, ckpt_path) 55 | 56 | 57 | def create_model(args, device='cuda'): 58 | netG = G_NET( 59 | gf_dim=64, z_dim=args.z_dim, 60 | text_dim=args.input_dim, embedding_dim=args.embedding_dim, 61 | r_num=2, levels=args.levels, b_condition=True, ca=True).to(device) 62 | netG.apply(weights_init) 63 | print('# params in netG =', param_counter(netG.parameters())) 64 | 65 | netsD = [] 66 | netsD.append(D_NET64()) 67 | if args.levels >= 2: 68 | netsD.append(D_NET128()) 69 | if args.levels >= 3: 70 | netsD.append(D_NET256()) 71 | for i in range(len(netsD)): 72 | netsD[i] = netsD[i].to(device) 73 | netsD[i].apply(weights_init) 74 | print('# params in netD_{} ='.format(i), param_counter(netsD[i].parameters())) 75 | 76 | if device=='cuda': 77 | netG = nn.DataParallel(netG) 78 | for i in range(len(netsD)): 79 | netsD[i] = nn.DataParallel(netsD[i]) 80 | 81 | optimizerG = optim.Adam(netG.parameters(), 82 | lr=args.lr_g, 83 | 
betas=(0.5, 0.999)) 84 | optimizersD = [] 85 | num_Ds = len(netsD) 86 | for i in range(num_Ds): 87 | opt = optim.Adam(netsD[i].parameters(), 88 | lr=args.lr_d, 89 | betas=(0.5, 0.999)) 90 | optimizersD.append(opt) 91 | return netG, netsD, optimizerG, optimizersD 92 | 93 | def load_model(ckpt_path, device='cuda'): 94 | print('load CookGAN model from:', ckpt_path) 95 | ckpt = torch.load(ckpt_path) 96 | ckpt_args = ckpt['args'] 97 | batch = ckpt['batch_idx'] 98 | 99 | netG, netsD, optimizerG, optimizersD = create_model(ckpt_args, device) 100 | if device=='cuda': 101 | netG.module.load_state_dict(ckpt['netG']) 102 | else: 103 | netG.load_state_dict(ckpt['netG']) 104 | optimizerG.load_state_dict(ckpt['optimizerG']) 105 | for i in range(len(netsD)): 106 | if device=='cuda': 107 | netsD[i].module.load_state_dict(ckpt['netD_{}'.format(i)]) 108 | else: 109 | netsD[i].load_state_dict(ckpt['netD_{}'.format(i)]) 110 | optimizersD[i].load_state_dict(ckpt['optimizerD_{}'.format(i)]) 111 | return ckpt_args, batch, netG, optimizerG, netsD, optimizersD 112 | 113 | def compute_cycle_loss(feat1, feat2, paired=True, device='cuda'): 114 | if paired: 115 | loss = nn.CosineEmbeddingLoss(0.3)(feat1, feat2, torch.ones(feat1.shape[0]).to(device)) 116 | else: 117 | loss = nn.CosineEmbeddingLoss(0.3)(feat1, feat2, -torch.ones(feat1.shape[0]).to(device)) 118 | return loss 119 | 120 | def compute_kl(mu, logvar, embedding_dim=128): 121 | # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 122 | KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1) 123 | # return KLD.mean() # correct 124 | return KLD.mean() / embedding_dim # not correct, this is just to follow the official code 125 | 126 | 127 | def train( 128 | args, batch_start, train_loader, device, 129 | txt_encoder, img_encoder, 130 | netG, optimizerG, netsD, optimizersD, criterion, 131 | fixed_noise, fixed_txt, fixed_img, fixed_title, save_dir): 132 | 133 | if device=='cuda': 134 | netG_to_save = netG.module 135 | netsD_to_save = [] 136 | for i in range(len(netsD)): 137 | netsD_to_save.append(netsD[i].module) 138 | else: 139 | netG_to_save = netG 140 | netsD_to_save = [] 141 | for i in range(len(netsD)): 142 | netsD_to_save.append(netsD[i]) 143 | 144 | noise = torch.FloatTensor(args.batch_size, args.z_dim).to(device) 145 | loader = sample_data(train_loader) 146 | for batch_idx in tqdm(range(batch_start, batch_start+args.num_batches)): 147 | if args.labels == 'original': 148 | real_labels = torch.FloatTensor(args.batch_size).fill_( 149 | 1) # (torch.FloatTensor(args.batch_size).uniform_() < 0.9).float() # 150 | fake_labels = torch.FloatTensor(args.batch_size).fill_( 151 | 0) # (torch.FloatTensor(args.batch_size).uniform_() > 0.9).float() # 152 | elif args.labels == 'R-smooth': 153 | real_labels = torch.FloatTensor(args.batch_size).fill_(1) - ( 154 | torch.FloatTensor(args.batch_size).uniform_() * 0.1) 155 | fake_labels = (torch.FloatTensor(args.batch_size).uniform_() * 0.1) 156 | elif args.labels == 'R-flip': 157 | real_labels = (torch.FloatTensor(args.batch_size).uniform_() < 0.9).float() # 158 | fake_labels = (torch.FloatTensor(args.batch_size).uniform_() > 0.9).float() # 159 | elif args.labels == 'R-flip-smooth': 160 | real_labels = torch.abs((torch.FloatTensor(args.batch_size).uniform_() > 0.9).float() - ( 161 | torch.FloatTensor(args.batch_size).fill_(1) - ( 162 | torch.FloatTensor(args.batch_size).uniform_() * 0.1))) 163 | fake_labels = torch.abs((torch.FloatTensor(args.batch_size).uniform_() > 0.9).float() - ( 164 | 
torch.FloatTensor(args.batch_size).uniform_() * 0.1)) 165 | 166 | real_labels = real_labels.to(device) 167 | fake_labels = fake_labels.to(device) 168 | 169 | data = next(loader) 170 | txt, real_imgs, wrong_imgs = prepare_data(data, device) 171 | with torch.no_grad(): 172 | txt_feat = compute_txt_feat(txt, txt_encoder) 173 | noise.normal_(0, 1).to(device) 174 | fake_imgs, mu, logvar = netG(noise, txt_feat) 175 | 176 | ###################### 177 | # train Discriminators 178 | ###################### 179 | errD_total = 0 180 | for level in range(args.levels): 181 | if args.input_noise: 182 | sigma = np.clip(1.0 - batch_idx/80_000, 0, 1) * 0.1 183 | real_img_noise = torch.empty_like(real_imgs[level]).normal_(0, sigma) 184 | wrong_img_noise = torch.empty_like(wrong_imgs[level]).normal_(0, sigma) 185 | fake_img_noise = torch.empty_like(fake_imgs[level]).normal_(0, sigma) 186 | else: 187 | real_img_noise = torch.zeros_like(real_imgs[level]) 188 | wrong_img_noise = torch.zeros_like(wrong_imgs[level]) 189 | fake_img_noise = torch.zeros_like(fake_imgs[level]) 190 | 191 | netD = netsD[level] 192 | optD = optimizersD[level] 193 | real_logits = netD(real_imgs[level]+real_img_noise, mu.detach()) 194 | wrong_logits = netD(wrong_imgs[level]+wrong_img_noise, mu.detach()) 195 | fake_logits = netD(fake_imgs[level].detach()+fake_img_noise, mu.detach()) 196 | 197 | errD_real = criterion(real_logits[0], real_labels) # cond_real --> 1 198 | errD_wrong = criterion(wrong_logits[0], fake_labels) # cond_wrong --> 0 199 | errD_fake = criterion(fake_logits[0], fake_labels) # cond_fake --> 0 200 | errD_cond = errD_real + errD_wrong + errD_fake 201 | 202 | if len(real_logits)>1: 203 | errD_real_uncond = criterion(real_logits[1], real_labels) # uncond_real --> 1 204 | errD_wrong_uncond = criterion(wrong_logits[1], real_labels) # uncond_wrong --> 1 205 | errD_fake_uncond = criterion(fake_logits[1], fake_labels) # uncond_fake --> 0 206 | errD_uncond = errD_real_uncond + errD_wrong_uncond + errD_fake_uncond 207 | else: # back to GAN-INT-CLS 208 | errD_cond = errD_real + 0.5 * (errD_wrong + errD_fake) 209 | errD_uncond = 0.0 210 | 211 | errD = errD_cond + args.uncond * errD_uncond 212 | 213 | optD.zero_grad() 214 | errD.backward() 215 | optD.step() 216 | 217 | # record 218 | errD_total += errD 219 | 220 | wandb.log({ 221 | f'errD_cond{level}': errD_cond, 222 | f'errD_uncond{level}': errD_uncond, 223 | f'errD{level}': errD, 224 | f'batch_idx': batch_idx, 225 | }) 226 | 227 | ###################### 228 | # train Generator 229 | ###################### 230 | errG_total = 0.0 231 | for level in range(args.levels): 232 | if args.input_noise: 233 | sigma = np.clip(1.0 - batch_idx/80_000, 0, 1) * 0.1 234 | fake_img_noise = torch.empty_like(fake_imgs[level]).normal_(0, sigma) 235 | else: 236 | fake_img_noise = torch.zeros_like(fake_imgs[level]) 237 | 238 | outputs = netsD[level](fake_imgs[level] + fake_img_noise, mu) 239 | errG_cond = criterion(outputs[0], real_labels) # cond_fake --> 1 240 | errG_uncond = criterion(outputs[1], real_labels) # uncond_fake --> 1 241 | 242 | fake_img_feat = compute_img_feat(fake_imgs[level], img_encoder) 243 | errG_cycle_txt = compute_cycle_loss(fake_img_feat, txt_feat) 244 | 245 | real_img_feat = compute_img_feat(real_imgs[level], img_encoder) 246 | errG_cycle_img = compute_cycle_loss(fake_img_feat, real_img_feat) 247 | 248 | # rightRcp_vs_rightImg = compute_cycle_loss(txt_feat, real_img_feat) 249 | # wrong_img_feat = compute_img_feat(wrong_imgs[level], img_encoder) 250 | # rightRcp_vs_wrongImg = 
compute_cycle_loss(txt_feat, wrong_img_feat, paired=False) 251 | # tri_loss = rightRcp_vs_rightImg + rightRcp_vs_wrongImg 252 | 253 | errG = errG_cond \ 254 | + args.uncond * errG_uncond \ 255 | + args.cycle_txt * errG_cycle_txt \ 256 | + args.cycle_img * errG_cycle_img \ 257 | # + args.tri_loss * tri_loss 258 | 259 | # record 260 | errG_total += errG 261 | 262 | wandb.log({ 263 | f'errG_cond{level}': errG_cond, 264 | f'errG_uncond{level}': errG_uncond, 265 | f'errG_cycle_txt{level}': errG_cycle_txt, 266 | f'errG_cycle_img{level}': errG_cycle_img, 267 | f'errG{level}': errG, 268 | f'batch_idx': batch_idx, 269 | }) 270 | 271 | errG_kl = compute_kl(mu, logvar) 272 | errG_total += args.kl * errG_kl 273 | 274 | optimizerG.zero_grad() 275 | errG_total.backward() 276 | optimizerG.step() 277 | 278 | 279 | wandb.log({ 280 | f'errG_kl': errG_kl, 281 | f'errD_total': errD_total, 282 | f'errG_total': errG_total, 283 | f'batch_idx': batch_idx, 284 | }) 285 | 286 | if batch_idx % 1000 == 0: 287 | netG.eval() 288 | # save img and ckpt 289 | with torch.no_grad(): 290 | fixed_txt_feat = compute_txt_feat(fixed_txt, txt_encoder) 291 | fake_imgs, mu, logvar = netG(fixed_noise, fixed_txt_feat) 292 | fake_img = fake_imgs[-1] 293 | real_fake = torch.stack([fixed_img.detach().cpu(), fake_img.detach().cpu()]).permute(1,0,2,3,4).contiguous() 294 | output_img = [] 295 | for item, title in zip(real_fake, fixed_title): 296 | item = vutils.make_grid(item, normalize=True, scale_each=True) 297 | output_img.append(wandb.Image(item, caption=title[:100])) 298 | 299 | wandb.log({ 300 | f'img': output_img, 301 | f'batch_idx': batch_idx, 302 | }) 303 | netG.train() 304 | if batch_idx % 10000 == 0: 305 | ckpt_path = os.path.join(save_dir, f'{batch_idx:>06d}.ckpt') 306 | save_model(args, batch_idx, netG_to_save, optimizerG, netsD_to_save, optimizersD, ckpt_path) 307 | 308 | batch_idx += 1 309 | 310 | 311 | if __name__ == '__main__': 312 | args = get_parser().parse_args() 313 | ############################## 314 | # setup 315 | ############################## 316 | if not args.seed: 317 | args.seed = random.randint(1, 10000) 318 | random.seed(args.seed) 319 | torch.manual_seed(args.seed) 320 | np.random.seed(args.seed) 321 | torch.backends.cudnn.benchmark = True 322 | 323 | device = args.device 324 | if device.__str__() == 'cpu': 325 | args.batch_size = 16 326 | 327 | ############################## 328 | # dataset 329 | ############################## 330 | imsize = args.base_size * (2 ** (args.levels-1)) 331 | train_transform = transforms.Compose([ 332 | transforms.Resize(int(imsize * 76 / 64)), 333 | transforms.RandomCrop(imsize), 334 | transforms.RandomHorizontalFlip()]) 335 | train_set = FoodDataset( 336 | recipe_file=args.recipe_file, 337 | img_dir=args.img_dir, 338 | levels=args.levels, 339 | part='train', 340 | food_type=args.food_type, 341 | base_size=args.base_size, 342 | transform=train_transform) 343 | 344 | train_loader = torch.utils.data.DataLoader( 345 | train_set, batch_size=args.batch_size, 346 | drop_last=True, shuffle=True, num_workers=int(args.workers)) 347 | print('train data info:', len(train_set), len(train_loader)) 348 | 349 | ############################## 350 | # model 351 | ############################## 352 | ckpt_args, _, txt_encoder, img_encoder, _ = train_retrieval.load_model(args.retrieval_model, device) 353 | requires_grad(txt_encoder, False) 354 | requires_grad(img_encoder, False) 355 | txt_encoder = txt_encoder.eval() 356 | img_encoder = img_encoder.eval() 357 | 358 | if args.ckpt_path: 359 | 
ckpt_args, batch, netG, optimizerG, netsD, optimizersD = load_model(args.ckpt_path, device) 360 | wandb_run_id = args.ckpt_path.split('/')[-2] 361 | batch_start = batch + 1 362 | else: 363 | netG, netsD, optimizerG, optimizersD = create_model(args, device) 364 | wandb_run_id = '' 365 | batch_start = 0 366 | 367 | ############################## 368 | # train 369 | ############################## 370 | # define loss 371 | criterion = nn.BCELoss() 372 | 373 | # define fixed noise 374 | fixed_noise_part1 = torch.FloatTensor(1, args.z_dim).normal_(0, 1) 375 | fixed_noise_part1 = fixed_noise_part1.repeat(args.batch_size//2, 1) 376 | fixed_noise_part2 = torch.FloatTensor(args.batch_size//2, args.z_dim).normal_(0, 1) 377 | fixed_noise = torch.cat([fixed_noise_part1, fixed_noise_part2], dim=0).to(device) 378 | 379 | fixed_txt, fixed_imgs, _, fixed_title = next(iter(train_loader)) 380 | fixed_img = fixed_imgs[-1] 381 | 382 | # setup saving directory 383 | project_name = "cookgan" 384 | wandb.init(project=project_name, config=args, resume=wandb_run_id) 385 | wandb.config.update(args) 386 | 387 | train( 388 | args, batch_start, train_loader, device, 389 | txt_encoder, img_encoder, 390 | netG, optimizerG, netsD, optimizersD, criterion, 391 | fixed_noise, fixed_txt, fixed_img, fixed_title, wandb.run.dir) 392 | -------------------------------------------------------------------------------- /retrieval_model/models_retrieval.py.bak: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.nn.utils import rnn 7 | from torchvision import models 8 | import torch.utils.model_zoo as model_zoo 9 | from gensim.models.keyedvectors import KeyedVectors 10 | import pdb 11 | import torchvision 12 | import math 13 | import numpy as np 14 | 15 | class AttentionLayer(nn.Module): 16 | def __init__(self, input_dim, with_attention): 17 | super(AttentionLayer, self).__init__() 18 | self.u = torch.nn.Parameter(torch.randn(input_dim)) # u = [2*hid_dim] 19 | self.u.requires_grad = True 20 | self.fc = nn.Linear(input_dim, input_dim) 21 | self.with_attention = with_attention 22 | def forward(self, x): 23 | # x = [BS, max_len, 2*hid_dim] 24 | # a trick used to find the mask for the softmax 25 | mask = (x!=0) 26 | mask = mask[:,:,0].bool() 27 | h = torch.tanh(self.fc(x)) # h = [BS, max_len, 2*hid_dim] 28 | if self.with_attention == 1: # softmax 29 | scores = h @ self.u # scores = [BS, max_len], unnormalized importance 30 | elif self.with_attention == 2: # Transformer 31 | scores = h @ self.u / math.sqrt(h.shape[-1]) # scores = [BS, max_len], unnormalized importance 32 | masked_scores = scores.masked_fill(~mask, -1e32) 33 | alpha = F.softmax(masked_scores, dim=1) # alpha = [BS, max_len], normalized importance 34 | 35 | alpha = alpha.unsqueeze(-1) # alpha = [BS, max_len, 1] 36 | out = x * alpha # out = [BS, max_len, 2*hid_dim] 37 | out = out.sum(dim=1) # out = [BS, 2*hid_dim] 38 | # pdb.set_trace() 39 | return out, alpha.squeeze(-1) 40 | 41 | 42 | class IngredientsEncoderRNN(nn.Module): 43 | def __init__( 44 | self, 45 | emb_dim, 46 | hid_dim, 47 | z_dim, 48 | word2vec_file='data/word2vec_recipes.bin', 49 | with_attention=True): 50 | super(IngredientsEncoderRNN, self).__init__() 51 | 52 | wv = KeyedVectors.load(word2vec_file, mmap='r') 53 | vec = torch.from_numpy(np.copy(wv.vectors)).float() 54 | # first two index has special meaning, see load_dict() in utils.py 55 | emb = 
nn.Embedding(vec.shape[0]+2, vec.shape[1], padding_idx=0) 56 | emb.weight.data[2:].copy_(vec) 57 | # for p in emb.parameters(): 58 | # p.requires_grad = False 59 | self.embed_layer = emb 60 | print('IngredientsEncoderRNN:', emb) 61 | 62 | self.rnn = nn.GRU( 63 | input_size=emb_dim, 64 | hidden_size=hid_dim, 65 | bidirectional=True, 66 | batch_first=True) 67 | 68 | self.with_attention = with_attention 69 | if with_attention: 70 | self.atten_layer = AttentionLayer(2*hid_dim, with_attention) 71 | 72 | def forward(self, sent_list, lengths): 73 | # sent_list [BS, max_len] 74 | # lengths [BS] 75 | x = self.embed_layer(sent_list) # x=[BS, max_len, emb_dim] 76 | sorted_len, sorted_idx = lengths.sort(0, descending=True) # sorted_idx=[BS], for sorting 77 | _, original_idx = sorted_idx.sort(0, descending=False) # original_idx=[BS], for unsorting 78 | index_sorted_idx = sorted_idx.view(-1,1,1).expand_as(x) # sorted_idx=[BS, max_len, emb_dim] 79 | sorted_inputs = x.gather(0, index_sorted_idx.long()) # sort by num_words 80 | packed_seq = rnn.pack_padded_sequence( 81 | sorted_inputs, sorted_len.cpu().numpy(), batch_first=True) 82 | self.rnn.flatten_parameters() 83 | # if not self.with_attention: 84 | if self.with_attention: 85 | out, _ = self.rnn(packed_seq) 86 | # pdb.set_trace() 87 | y, _ = rnn.pad_packed_sequence( 88 | out, batch_first=True, total_length=20) # y=[BS, max_len, 2*hid_dim], currently in WRONG order! 89 | unsorted_idx = original_idx.view(-1,1,1).expand_as(y) 90 | output = y.gather(0, unsorted_idx).contiguous() # [BS, max_len, 2*hid_dim], now in correct order 91 | feat, alpha = self.atten_layer(output) # [BS, 2*hid_dim] 92 | # print('sent', feat.shape) # [BS, 2*hid_dim] 93 | return feat, alpha 94 | else: 95 | _, h = self.rnn(packed_seq) # [2, BS, hid_dim], currently in WRONG order! 96 | # pdb.set_trace() 97 | h = h.transpose(0,1) # [BS, 2, hid_dim], still in WRONG order! 
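            # h was returned as [2, BS, hid_dim] (final forward/backward GRU states) and transposed above;
            # the unsorting gather below restores batch order and view() flattens both directions into a
            # single 2*hid_dim ingredient feature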
98 | # unsort the output 99 | unsorted_idx = original_idx.view(-1,1,1).expand_as(h) 100 | output = h.gather(0, unsorted_idx).contiguous() # [BS, 2, hid_dim], now in correct order 101 | feat = output.view(output.size(0), output.size(1)*output.size(2)) # [BS, 2*hid_dim] 102 | return feat 103 | 104 | 105 | class IngredientsEncoderFC(nn.Module): 106 | def __init__( 107 | self, 108 | emb_dim, 109 | hid_dim, 110 | z_dim, 111 | word2vec_file='data/word2vec_recipes.bin', 112 | with_attention=True): 113 | super(IngredientsEncoderFC, self).__init__() 114 | 115 | wv = KeyedVectors.load(word2vec_file, mmap='r') 116 | vec = torch.from_numpy(wv.vectors).float() 117 | # first two index has special meaning, see load_dict() in utils.py 118 | emb = nn.Embedding(vec.shape[0]+2, vec.shape[1], padding_idx=0) 119 | emb.weight.data[2:].copy_(vec) 120 | # for p in emb.parameters(): 121 | # p.requires_grad = False 122 | self.embed_layer = emb 123 | print('IngredientsEncoderRNN:', emb) 124 | 125 | self.fc1 = nn.Linear(emb_dim, 2*hid_dim) 126 | self.fc2 = nn.Linear(2*hid_dim, 2*hid_dim) 127 | 128 | self.with_attention = with_attention 129 | if with_attention: 130 | self.atten_layer = AttentionLayer(2*hid_dim, with_attention) 131 | 132 | def forward(self, sent_list, lengths): 133 | # sent_list [BS, max_len] 134 | # lengths [BS] 135 | # sent_list [BS, max_len] 136 | x = self.embed_layer(sent_list) # x=[BS, max_len, emb_dim] 137 | x = self.fc2(F.relu(self.fc1(x))) # [BS, max_len, 2*hid_dim] 138 | if not self.with_attention: 139 | feat = x.sum(dim=1) # [BS, 2*hid_dim] 140 | return feat 141 | else: 142 | feat, alpha = self.atten_layer(x) # [BS, 2*hid_dim] 143 | # print('ingredients', feat.shape) 144 | return feat, alpha 145 | 146 | 147 | class SentenceEncoder(nn.Module): 148 | def __init__( 149 | self, 150 | emb_dim, 151 | hid_dim, 152 | z_dim, 153 | word2vec_file='data/word2vec_recipes.bin', 154 | with_attention=True): 155 | super(SentenceEncoder, self).__init__() 156 | 157 | wv = KeyedVectors.load(word2vec_file, mmap='r') 158 | vec = torch.from_numpy(np.copy(wv.vectors)).float() 159 | # first two index has special meaning, see load_dict() in utils.py 160 | emb = nn.Embedding(vec.shape[0]+2, vec.shape[1], padding_idx=0) 161 | emb.weight.data[2:].copy_(vec) 162 | # for p in emb.parameters(): 163 | # p.requires_grad = False 164 | self.embed_layer = emb 165 | 166 | self.rnn = nn.GRU( 167 | input_size=emb_dim, 168 | hidden_size=hid_dim, 169 | bidirectional=True, 170 | batch_first=True) 171 | 172 | self.with_attention = with_attention 173 | if with_attention: 174 | self.atten_layer = AttentionLayer(2*hid_dim, with_attention) 175 | 176 | def forward(self, sent_list, lengths): 177 | # sent_list [BS, max_len] 178 | # lengths [BS] 179 | x = self.embed_layer(sent_list) # x=[BS, max_len, emb_dim] 180 | sorted_len, sorted_idx = lengths.sort(0, descending=True) # sorted_idx=[BS], for sorting 181 | _, original_idx = sorted_idx.sort(0, descending=False) # original_idx=[BS], for unsorting 182 | index_sorted_idx = sorted_idx.view(-1,1,1).expand_as(x) # sorted_idx=[BS, max_len, emb_dim] 183 | sorted_inputs = x.gather(0, index_sorted_idx.long()) # sort by num_words 184 | packed_seq = rnn.pack_padded_sequence( 185 | sorted_inputs, sorted_len.cpu().numpy(), batch_first=True) 186 | self.rnn.flatten_parameters() 187 | # if not self.with_attention: 188 | if self.with_attention: 189 | out, _ = self.rnn(packed_seq) 190 | # pdb.set_trace() 191 | y, _ = rnn.pad_packed_sequence( 192 | out, batch_first=True, total_length=20) # y=[BS, max_len, 
2*hid_dim], currently in WRONG order! 193 | unsorted_idx = original_idx.view(-1,1,1).expand_as(y) 194 | output = y.gather(0, unsorted_idx).contiguous() # [BS, max_len, 2*hid_dim], now in correct order 195 | feat, alpha = self.atten_layer(output) # [BS, 2*hid_dim] 196 | # print('sent', feat.shape) # [BS, 2*hid_dim] 197 | return feat, alpha 198 | else: 199 | _, h = self.rnn(packed_seq) # [2, BS, hid_dim], currently in WRONG order! 200 | # pdb.set_trace() 201 | h = h.transpose(0,1) # [BS, 2, hid_dim], still in WRONG order! 202 | # unsort the output 203 | unsorted_idx = original_idx.view(-1,1,1).expand_as(h) 204 | output = h.gather(0, unsorted_idx).contiguous() # [BS, 2, hid_dim], now in correct order 205 | feat = output.view(output.size(0), output.size(1)*output.size(2)) # [BS, 2*hid_dim] 206 | return feat 207 | 208 | 209 | class DocEncoder(nn.Module): 210 | def __init__(self, sent_encoder, hid_dim, with_attention): 211 | super(DocEncoder, self).__init__() 212 | self.sent_encoder = sent_encoder 213 | self.rnn = nn.GRU( 214 | input_size=2*hid_dim, 215 | hidden_size=hid_dim, 216 | bidirectional=True, 217 | batch_first=True) 218 | if with_attention: 219 | self.atten_layer_sent = AttentionLayer(2*hid_dim, with_attention) 220 | self.with_attention = with_attention 221 | 222 | def forward(self, doc_list, n_insts, n_words_each_inst): 223 | # doc_list=[BS, max_len, max_len] 224 | # n_insts = [BS] 225 | # n_words_each_inst = [BS, 20] 226 | embs = [] 227 | attentions_words_each_inst =[] 228 | for i in range(len(n_insts)): 229 | doc = doc_list[i] 230 | ln = n_insts[i] # how many steps 231 | sent_lns = n_words_each_inst[i, :n_words_each_inst[i].nonzero(as_tuple=False).shape[0]] # len of each step 232 | 233 | if self.with_attention: 234 | emb_doc, alpha = self.sent_encoder(doc[:ln], sent_lns) # [?, 2*hid_dim] 235 | attentions_words_each_inst.append(alpha) # e.g. if lns=[5,1,14,7, ...], then attentions_words_each_inst=[[5,20], [1,20], [14,20], [7,20], ...] with length=BS 236 | else: 237 | emb_doc = self.sent_encoder(doc[:ln], sent_lns) # [?, 2*hid_dim] 238 | embs.append(emb_doc) 239 | 240 | embs = sorted(embs, key=lambda x: -x.shape[0]) # [BS, [?, 2*hid_dim]] 241 | packed_seq = rnn.pack_sequence(embs) 242 | _, sorted_idx = n_insts.sort(0, descending=True) # sorted_idx=[BS], for sorting 243 | _, original_idx = sorted_idx.sort(0, descending=False) # original_idx=[BS], for unsorting 244 | 245 | self.rnn.flatten_parameters() 246 | if self.with_attention: 247 | out, _ = self.rnn(packed_seq) 248 | y, _ = rnn.pad_packed_sequence( 249 | out, batch_first=True, total_length=20) # y=[BS, max_len, 2*hid_dim], currently in WRONG order! 250 | # pdb.set_trace() 251 | unsorted_idx = original_idx.view(-1,1,1).expand_as(y) 252 | output = y.gather(0, unsorted_idx).contiguous() # [BS, max_len, 2*hid_dim], now in correct order 253 | out, attentions_each_inst = self.atten_layer_sent(output) 254 | # print('instructions', feat.shape) 255 | return out, attentions_each_inst, attentions_words_each_inst 256 | else: 257 | _, h = self.rnn(packed_seq) # [2, BS, hid_dim], currently in WRONG order! 258 | h = h.transpose(0,1) # [BS, 2, hid_dim], still in WRONG order! 
259 | # unsort the output 260 | unsorted_idx = original_idx.view(-1,1,1).expand_as(h) 261 | output = h.gather(0, unsorted_idx).contiguous() # [BS, 2, hid_dim], now in correct order 262 | feat = output.view(output.size(0), output.size(1)*output.size(2)) # [BS, 2*hid_dim] 263 | # print('instructions', feat.shape) 264 | return feat 265 | 266 | 267 | class TextEncoder(nn.Module): 268 | def __init__( 269 | self, 270 | emb_dim, hid_dim, z_dim, 271 | word2vec_file, 272 | with_attention=0, 273 | text_info='010', 274 | ingrs_enc_type='rnn' 275 | ): 276 | super(TextEncoder, self).__init__() 277 | if ingrs_enc_type == 'rnn': 278 | self.ingrs_encoder = IngredientsEncoderRNN( 279 | emb_dim, hid_dim, z_dim, 280 | word2vec_file=word2vec_file, 281 | with_attention=with_attention) 282 | elif ingrs_enc_type == 'fc': 283 | self.ingrs_encoder = IngredientsEncoderFC( 284 | emb_dim, hid_dim, z_dim, 285 | word2vec_file=word2vec_file, 286 | with_attention=with_attention) 287 | 288 | self.sent_encoder = SentenceEncoder( 289 | emb_dim=emb_dim, hid_dim=hid_dim, z_dim=z_dim, word2vec_file=word2vec_file, 290 | with_attention=with_attention) 291 | self.doc_encoder = DocEncoder( 292 | self.sent_encoder, 293 | hid_dim, 294 | with_attention 295 | ) 296 | self.with_attention = with_attention 297 | self.text_info = text_info 298 | num_ones = text_info.count('1') 299 | self.bn = nn.BatchNorm1d(2*num_ones*hid_dim) 300 | self.fc = nn.Linear(2*num_ones*hid_dim, z_dim) 301 | 302 | def forward(self, title, title_len, ingredients, n_ingrs, instructions, n_insts, insts_lens): 303 | if self.with_attention: 304 | feat_title, alpha_title = self.sent_encoder(title, title_len) 305 | feat_ingredients, alpha_ingredients = self.ingrs_encoder(ingredients, n_ingrs) 306 | feat_instructions, alpha_instructions, alpha_words = self.doc_encoder(instructions, n_insts, insts_lens) 307 | else: 308 | feat_title = self.sent_encoder(title, title_len) 309 | feat_ingredients = self.ingrs_encoder(ingredients, n_ingrs) 310 | feat_instructions = self.doc_encoder(instructions, n_insts, insts_lens) 311 | 312 | if self.text_info == '100': 313 | feat = feat_title 314 | attentions = [alpha_title, None, None, None] if self.with_attention else None 315 | elif self.text_info == '010': 316 | feat = feat_ingredients 317 | attentions = [None, alpha_ingredients, None, None] if self.with_attention else None 318 | elif self.text_info == '001': 319 | feat = feat_instructions 320 | attentions = [None, None, alpha_instructions, alpha_words] if self.with_attention else None 321 | elif self.text_info == '111': 322 | feat = torch.cat([feat_title, feat_ingredients, feat_instructions], dim=1) 323 | attentions = [alpha_title, alpha_ingredients, alpha_instructions, alpha_words] if self.with_attention else None 324 | 325 | feat = self.fc(self.bn(feat)) 326 | feat = F.normalize(feat, p=2, dim=1) 327 | return feat, attentions 328 | 329 | 330 | # Image Encoder 331 | def clean_state_dict(state_dict): 332 | # create new OrderedDict that does not contain `module.` 333 | from collections import OrderedDict 334 | new_state_dict = OrderedDict() 335 | for k, v in state_dict.items(): 336 | name = k[7:] if k[:min(6,len(k))] == 'module' else k # remove `module.` 337 | new_state_dict[name] = v 338 | return new_state_dict 339 | 340 | class Resnet(nn.Module): 341 | def __init__(self, ckpt_path=None): 342 | super(Resnet, self).__init__() 343 | resnet = torchvision.models.resnet50(pretrained=True) 344 | num_feat = resnet.fc.in_features 345 | resnet.fc = nn.Linear(num_feat, 101) 346 | if ckpt_path: 
347 | resnet.load_state_dict(clean_state_dict(torch.load(ckpt_path))) 348 | modules = list(resnet.children())[:-1] # we do not use the last fc layer. 349 | self.encoder = nn.Sequential(*modules) 350 | 351 | def forward(self, image_list): 352 | BS = image_list.shape[0] 353 | return self.encoder(image_list).view(BS, -1) 354 | 355 | class ImageEncoder(nn.Module): 356 | def __init__(self, z_dim, ckpt_path=None): 357 | super(ImageEncoder, self).__init__() 358 | self.resnet = Resnet(ckpt_path) 359 | self.bottleneck = nn.Sequential( 360 | nn.BatchNorm1d(2048), 361 | nn.Linear(2048, z_dim), 362 | nn.Tanh() 363 | ) 364 | 365 | def forward(self, image_list): 366 | feat = self.resnet(image_list) 367 | feat = self.bottleneck(feat) 368 | # print('image', feat.shape) 369 | return F.normalize(feat, p=2, dim=1) -------------------------------------------------------------------------------- /cookgan/models_cookgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | b_condition=True 8 | text_dim = 1024 9 | z_dim = 100 10 | embedding_dim = 128 11 | r_num = 2 12 | gf_dim = 64 13 | df_dim = 64 14 | 15 | class INCEPTION_V3(nn.Module): 16 | def __init__(self): 17 | super(INCEPTION_V3, self).__init__() 18 | self.model = models.inception_v3() 19 | url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth' 20 | # print(next(model.parameters()).data) 21 | state_dict = \ 22 | model_zoo.load_url(url, map_location=lambda storage, loc: storage) 23 | self.model.load_state_dict(state_dict) 24 | for param in self.model.parameters(): 25 | param.requires_grad = False 26 | print('Load pretrained inception_v3 model from', url) 27 | # print(next(self.model.parameters()).data) 28 | # print(self.model) 29 | 30 | def forward(self, input): 31 | # [-1.0, 1.0] --> [0, 1.0] 32 | x = input * 0.5 + 0.5 33 | # mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] 34 | # --> mean = 0, std = 1 35 | x[:, 0] = (x[:, 0] - 0.485) / 0.229 36 | x[:, 1] = (x[:, 1] - 0.456) / 0.224 37 | x[:, 2] = (x[:, 2] - 0.406) / 0.225 38 | # 39 | # --> fixed-size input: batch x 3 x 299 x 299 40 | # x = nn.Upsample(size=(299, 299), mode='bilinear')(x) 41 | x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True) 42 | # 299 x 299 x 3 43 | x = self.model(x) 44 | x = nn.Softmax(dim=-1)(x) 45 | return x 46 | 47 | # Generator 48 | class G_NET(nn.Module): 49 | def __init__( 50 | self, 51 | gf_dim=64, 52 | z_dim=100, 53 | r_num=2, 54 | levels=3, 55 | text_dim=1024, 56 | embedding_dim=128, 57 | b_condition=True, 58 | ca=True): 59 | 60 | super(G_NET, self).__init__() 61 | self.gf_dim = gf_dim 62 | self.z_dim = z_dim 63 | self.r_num = r_num 64 | self.levels = levels 65 | self.gf_dim = gf_dim 66 | self.text_dim = text_dim 67 | self.embedding_dim = embedding_dim 68 | self.b_condition = b_condition 69 | self.ca = ca 70 | self.define_module() 71 | 72 | def define_module(self): 73 | if self.ca and self.b_condition: 74 | self.ca_net = CA_NET(text_dim=self.text_dim, embedding_dim=self.embedding_dim) 75 | 76 | if self.levels > 0: 77 | self.h_net1 = INIT_STAGE_G(ngf=self.gf_dim * 16, b_condition=self.b_condition) 78 | self.img_net1 = GET_IMAGE_G(self.gf_dim) 79 | if self.levels > 1: 80 | self.h_net2 = NEXT_STAGE_G(self.gf_dim) 81 | self.img_net2 = GET_IMAGE_G(self.gf_dim // 2) 82 | if self.levels > 2: 83 | self.h_net3 = NEXT_STAGE_G(self.gf_dim 
// 2)
84 |             self.img_net3 = GET_IMAGE_G(self.gf_dim // 4)
85 |         if self.levels > 3:  # Recommended structure (mainly limited by GPU memory), not tested yet
86 |             self.h_net4 = NEXT_STAGE_G(self.gf_dim // 4, num_residual=1)
87 |             self.img_net4 = GET_IMAGE_G(self.gf_dim // 8)
88 |         if self.levels > 4:
89 |             self.h_net5 = NEXT_STAGE_G(self.gf_dim // 8, num_residual=1)
90 |             self.img_net5 = GET_IMAGE_G(self.gf_dim // 16)
91 | 
92 |     def forward(self, z_code, text_embedding=None):
93 |         if self.b_condition and text_embedding is not None:
94 |             c_code, mu, logvar = self.ca_net(text_embedding)
95 |         else:
96 |             c_code, mu, logvar = z_code, None, None
97 |         fake_imgs = []
98 |         if self.levels > 0:
99 |             h_code1 = self.h_net1(z_code, c_code)
100 |             fake_img1 = self.img_net1(h_code1)
101 |             fake_imgs.append(fake_img1)
102 |         if self.levels > 1:
103 |             h_code2 = self.h_net2(h_code1, c_code)
104 |             fake_img2 = self.img_net2(h_code2)
105 |             fake_imgs.append(fake_img2)
106 |         if self.levels > 2:
107 |             h_code3 = self.h_net3(h_code2, c_code)
108 |             fake_img3 = self.img_net3(h_code3)
109 |             fake_imgs.append(fake_img3)
110 |         if self.levels > 3:
111 |             h_code4 = self.h_net4(h_code3, c_code)
112 |             fake_img4 = self.img_net4(h_code4)
113 |             fake_imgs.append(fake_img4)
114 | 
115 |         return fake_imgs, mu, logvar
116 | 
117 | 
118 | class CA_NET(nn.Module):
119 |     # some code is modified from the vae examples
120 |     # (https://github.com/pytorch/examples/blob/master/vae/main.py)
121 |     def __init__(self, text_dim=1024, embedding_dim=128):
122 |         super(CA_NET, self).__init__()
123 |         self.t_dim = text_dim
124 |         self.ef_dim = embedding_dim
125 |         self.fc = nn.Linear(self.t_dim, self.ef_dim * 4, bias=True)
126 |         self.relu = GLU()
127 | 
128 |     def encode(self, text_embedding):
129 |         x = self.relu(self.fc(text_embedding))
130 |         mu = x[:, :self.ef_dim]
131 |         logvar = x[:, self.ef_dim:]
132 |         return mu, logvar
133 | 
134 |     def reparametrize(self, mu, logvar):
135 |         std = logvar.mul(0.5).exp_()
136 |         eps = torch.FloatTensor(std.size()).normal_().to(mu.device)
137 |         return eps.mul(std).add_(mu)
138 | 
139 |     def forward(self, text_embedding):
140 |         mu, logvar = self.encode(text_embedding)
141 |         if self.training:
142 |             c_code = self.reparametrize(mu, logvar)
143 |             return c_code, mu, logvar
144 |         else:
145 |             return mu, mu, logvar
146 | 
147 | 
148 | class GLU(nn.Module):
149 |     def __init__(self):
150 |         super(GLU, self).__init__()
151 | 
152 |     def forward(self, x):
153 |         nc = x.size(1)
154 |         assert nc % 2 == 0, "the number of channels must be divisible by 2"
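        # GLU splits the input channels into two halves: the first half is gated
        # element-wise by the sigmoid of the second half, so the output has nc/2 channels.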
155 | nc = int(nc/2) 156 | return x[:, :nc] * torch.sigmoid(x[:, nc:]) 157 | 158 | 159 | class INIT_STAGE_G(nn.Module): 160 | def __init__(self, ngf, b_condition): 161 | super(INIT_STAGE_G, self).__init__() 162 | self.gf_dim = ngf 163 | self.b_condition = b_condition 164 | if self.b_condition: 165 | self.in_dim = z_dim + embedding_dim 166 | else: 167 | self.in_dim = z_dim 168 | self.define_module() 169 | 170 | def define_module(self): 171 | in_dim = self.in_dim 172 | ngf = self.gf_dim 173 | self.fc = nn.Sequential( 174 | nn.Linear(in_dim, ngf * 4 * 4 * 2, bias=False), 175 | nn.BatchNorm1d(ngf * 4 * 4 * 2), 176 | GLU()) 177 | 178 | 179 | self.upsample1 = upBlock(ngf, ngf // 2) 180 | self.upsample2 = upBlock(ngf // 2, ngf // 4) 181 | self.upsample3 = upBlock(ngf // 4, ngf // 8) 182 | self.upsample4 = upBlock(ngf // 8, ngf // 16) 183 | 184 | def forward(self, z_code, c_code=None): 185 | if self.b_condition and c_code is not None: 186 | in_code = torch.cat((c_code, z_code), 1) 187 | else: 188 | in_code = z_code 189 | # state size 16ngf x 4 x 4 190 | out_code = self.fc(in_code) 191 | out_code = out_code.view(-1, self.gf_dim, 4, 4) 192 | # state size 8ngf x 8 x 8 193 | out_code = self.upsample1(out_code) 194 | # state size 4ngf x 16 x 16 195 | out_code = self.upsample2(out_code) 196 | # state size 2ngf x 32 x 32 197 | out_code = self.upsample3(out_code) 198 | # state size ngf x 64 x 64 199 | out_code = self.upsample4(out_code) 200 | 201 | return out_code 202 | 203 | 204 | class NEXT_STAGE_G(nn.Module): 205 | def __init__(self, ngf, num_residual=r_num): 206 | super(NEXT_STAGE_G, self).__init__() 207 | self.gf_dim = ngf 208 | if b_condition: 209 | self.ef_dim = embedding_dim 210 | else: 211 | self.ef_dim = z_dim 212 | self.num_residual = num_residual 213 | self.define_module() 214 | 215 | def _make_layer(self, block, channel_num): 216 | layers = [] 217 | for i in range(self.num_residual): 218 | layers.append(block(channel_num)) 219 | return nn.Sequential(*layers) 220 | 221 | def define_module(self): 222 | ngf = self.gf_dim 223 | efg = self.ef_dim 224 | 225 | self.jointConv = Block3x3_relu(ngf + efg, ngf) 226 | self.residual = self._make_layer(ResBlock, ngf) 227 | self.upsample = upBlock(ngf, ngf // 2) 228 | 229 | def forward(self, h_code, c_code): 230 | s_size = h_code.size(2) 231 | c_code = c_code.view(-1, self.ef_dim, 1, 1) 232 | c_code = c_code.repeat(1, 1, s_size, s_size) 233 | # state size (ngf+egf) x in_size x in_size 234 | h_c_code = torch.cat((c_code, h_code), 1) 235 | # state size ngf x in_size x in_size 236 | out_code = self.jointConv(h_c_code) 237 | out_code = self.residual(out_code) 238 | # state size ngf/2 x 2in_size x 2in_size 239 | out_code = self.upsample(out_code) 240 | 241 | return out_code 242 | 243 | 244 | class GET_IMAGE_G(nn.Module): 245 | def __init__(self, ngf): 246 | super(GET_IMAGE_G, self).__init__() 247 | self.gf_dim = ngf 248 | self.img = nn.Sequential( 249 | conv3x3(ngf, 3), 250 | nn.Tanh() 251 | ) 252 | 253 | def forward(self, h_code): 254 | out_img = self.img(h_code) 255 | return out_img 256 | 257 | 258 | def conv3x3(in_planes, out_planes): 259 | "3x3 convolution with padding" 260 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, 261 | padding=1, bias=False) 262 | 263 | 264 | class ScaleUp(nn.Module): 265 | def __init__(self, scale_factor=2, mode='nearest'): 266 | super(ScaleUp, self).__init__() 267 | self.interp = nn.functional.interpolate 268 | self.scale_factor = scale_factor 269 | self.mode = mode 270 | 271 | def forward(self, x): 272 | x = 
self.interp( 273 | x, scale_factor=self.scale_factor, mode=self.mode) 274 | return x 275 | 276 | 277 | def upBlock(in_planes, out_planes): 278 | block = nn.Sequential( 279 | ScaleUp(scale_factor=2, mode='nearest'), 280 | conv3x3(in_planes, out_planes * 2), 281 | nn.BatchNorm2d(out_planes * 2), 282 | GLU() 283 | ) 284 | return block 285 | 286 | # Keep the spatial size 287 | def Block3x3_relu(in_planes, out_planes): 288 | block = nn.Sequential( 289 | conv3x3(in_planes, out_planes * 2), 290 | nn.BatchNorm2d(out_planes * 2), 291 | GLU() 292 | ) 293 | return block 294 | 295 | class ResBlock(nn.Module): 296 | def __init__(self, channel_num): 297 | super(ResBlock, self).__init__() 298 | self.block = nn.Sequential( 299 | conv3x3(channel_num, channel_num * 2), 300 | nn.BatchNorm2d(channel_num * 2), 301 | GLU(), 302 | conv3x3(channel_num, channel_num), 303 | nn.BatchNorm2d(channel_num) 304 | ) 305 | 306 | def forward(self, x): 307 | residual = x 308 | out = self.block(x) 309 | out += residual 310 | return out 311 | 312 | 313 | # ************************************************ 314 | # Discriminator 315 | # ************************************************ 316 | # For 64 x 64 images 317 | class D_NET64(nn.Module): 318 | def __init__(self): 319 | super(D_NET64, self).__init__() 320 | self.df_dim = df_dim 321 | self.ef_dim = embedding_dim 322 | self.define_module() 323 | 324 | def define_module(self): 325 | ndf = self.df_dim 326 | efg = self.ef_dim 327 | self.img_code_s16 = encode_image_by_16times(ndf) 328 | 329 | self.logits = nn.Sequential( 330 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), 331 | nn.Sigmoid()) 332 | 333 | if b_condition: 334 | self.jointConv = Block3x3_leakRelu(ndf * 8 + efg, ndf * 8) 335 | self.uncond_logits = nn.Sequential( 336 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), 337 | nn.Sigmoid()) 338 | 339 | def forward(self, x_var, c_code=None): 340 | x_code = self.img_code_s16(x_var) 341 | 342 | if b_condition and c_code is not None: 343 | c_code = c_code.view(-1, self.ef_dim, 1, 1) 344 | c_code = c_code.repeat(1, 1, 4, 4) 345 | # state size (ngf+egf) x 4 x 4 346 | h_c_code = torch.cat((c_code, x_code), 1) 347 | # state size ngf x in_size x in_size 348 | h_c_code = self.jointConv(h_c_code) 349 | else: 350 | h_c_code = x_code 351 | 352 | output = self.logits(h_c_code) 353 | if b_condition: 354 | out_uncond = self.uncond_logits(x_code) 355 | return [output.view(-1), out_uncond.view(-1)] 356 | else: 357 | return [output.view(-1)] 358 | 359 | 360 | # For 128 x 128 images 361 | class D_NET128(nn.Module): 362 | def __init__(self): 363 | super(D_NET128, self).__init__() 364 | self.df_dim = df_dim 365 | self.ef_dim = embedding_dim 366 | self.define_module() 367 | 368 | def define_module(self): 369 | ndf = self.df_dim 370 | efg = self.ef_dim 371 | self.img_code_s16 = encode_image_by_16times(ndf) 372 | self.img_code_s32 = downBlock(ndf * 8, ndf * 16) 373 | self.img_code_s32_1 = Block3x3_leakRelu(ndf * 16, ndf * 8) 374 | 375 | self.logits = nn.Sequential( 376 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), 377 | nn.Sigmoid()) 378 | 379 | if b_condition: 380 | self.jointConv = Block3x3_leakRelu(ndf * 8 + efg, ndf * 8) 381 | self.uncond_logits = nn.Sequential( 382 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), 383 | nn.Sigmoid()) 384 | 385 | def forward(self, x_var, c_code=None): 386 | x_code = self.img_code_s16(x_var) 387 | x_code = self.img_code_s32(x_code) 388 | x_code = self.img_code_s32_1(x_code) 389 | 390 | if b_condition and c_code is not None: 391 | c_code = c_code.view(-1, 
self.ef_dim, 1, 1)
392 |             c_code = c_code.repeat(1, 1, 4, 4)
393 |             # state size (ngf+egf) x 4 x 4
394 |             h_c_code = torch.cat((c_code, x_code), 1)
395 |             # state size ngf x in_size x in_size
396 |             h_c_code = self.jointConv(h_c_code)
397 |         else:
398 |             h_c_code = x_code
399 | 
400 |         output = self.logits(h_c_code)
401 |         if b_condition:
402 |             out_uncond = self.uncond_logits(x_code)
403 |             return [output.view(-1), out_uncond.view(-1)]
404 |         else:
405 |             return [output.view(-1)]
406 | 
407 | 
408 | # For 256 x 256 images
409 | class D_NET256(nn.Module):
410 |     def __init__(self):
411 |         super(D_NET256, self).__init__()
412 |         self.df_dim = df_dim
413 |         self.ef_dim = embedding_dim
414 |         self.define_module()
415 | 
416 |     def define_module(self):
417 |         ndf = self.df_dim
418 |         efg = self.ef_dim
419 |         self.img_code_s16 = encode_image_by_16times(ndf)
420 |         self.img_code_s32 = downBlock(ndf * 8, ndf * 16)
421 |         self.img_code_s64 = downBlock(ndf * 16, ndf * 32)
422 |         self.img_code_s64_1 = Block3x3_leakRelu(ndf * 32, ndf * 16)
423 |         self.img_code_s64_2 = Block3x3_leakRelu(ndf * 16, ndf * 8)
424 | 
425 |         self.logits = nn.Sequential(
426 |             nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
427 |             nn.Sigmoid())
428 | 
429 |         if b_condition:
430 |             self.jointConv = Block3x3_leakRelu(ndf * 8 + efg, ndf * 8)
431 |             self.uncond_logits = nn.Sequential(
432 |                 nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
433 |                 nn.Sigmoid())
434 | 
435 |     def forward(self, x_var, c_code=None):
436 |         x_code = self.img_code_s16(x_var)
437 |         x_code = self.img_code_s32(x_code)
438 |         x_code = self.img_code_s64(x_code)
439 |         x_code = self.img_code_s64_1(x_code)
440 |         x_code = self.img_code_s64_2(x_code)
441 | 
442 |         if b_condition and c_code is not None:
443 |             c_code = c_code.view(-1, self.ef_dim, 1, 1)
444 |             c_code = c_code.repeat(1, 1, 4, 4)
445 |             # state size (ngf+egf) x 4 x 4
446 |             h_c_code = torch.cat((c_code, x_code), 1)
447 |             # state size ngf x in_size x in_size
448 |             h_c_code = self.jointConv(h_c_code)
449 |         else:
450 |             h_c_code = x_code
451 | 
452 |         output = self.logits(h_c_code)
453 |         if b_condition:
454 |             out_uncond = self.uncond_logits(x_code)
455 |             return [output.view(-1), out_uncond.view(-1)]
456 |         else:
457 |             return [output.view(-1)]
458 | 
459 | 
460 | def Block3x3_leakRelu(in_planes, out_planes):
461 |     block = nn.Sequential(
462 |         conv3x3(in_planes, out_planes),
463 |         nn.BatchNorm2d(out_planes),
464 |         nn.LeakyReLU(0.2, inplace=True)
465 |     )
466 |     return block
467 | 
468 | 
469 | # Downscale the spatial size by a factor of 2
470 | def downBlock(in_planes, out_planes):
471 |     block = nn.Sequential(
472 |         nn.Conv2d(in_planes, out_planes, 4, 2, 1, bias=False),
473 |         nn.BatchNorm2d(out_planes),
474 |         nn.LeakyReLU(0.2, inplace=True)
475 |     )
476 |     return block
477 | 
478 | 
479 | # Downscale the spatial size by a factor of 16
480 | def encode_image_by_16times(ndf):
481 |     encode_img = nn.Sequential(
482 |         # --> state size ndf x in_size/2 x in_size/2
483 |         nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
484 |         nn.LeakyReLU(0.2, inplace=True),
485 |         # --> state size 2ndf x in_size/4 x in_size/4
486 |         nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
487 |         nn.BatchNorm2d(ndf * 2),
488 |         nn.LeakyReLU(0.2, inplace=True),
489 |         # --> state size 4ndf x in_size/8 x in_size/8
490 |         nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
491 |         nn.BatchNorm2d(ndf * 4),
492 |         nn.LeakyReLU(0.2, inplace=True),
493 |         # --> state size 8ndf x in_size/16 x in_size/16
494 |         nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
495 |         nn.BatchNorm2d(ndf * 8),
496 |         nn.LeakyReLU(0.2, inplace=True)
497 |     )
498 |     return encode_img
499 | 
500 | if __name__ == '__main__':
501 |     class Args: pass
502 |     args = Args()
503 |     args.device='cuda'
504 |     args.z_dim=100
505 |     args.input_dim=1024
506 |     args.levels=3
507 | 
508 |     bs = 2
509 |     device=args.device
510 |     G = G_NET(gf_dim=64, z_dim=args.z_dim, r_num=2, levels=args.levels, b_condition=True, ca=True).to(device)
511 |     G = nn.DataParallel(G)
512 |     Ds = []
513 |     imgs = []
514 |     Ds.append(D_NET64())
515 |     imgs.append(torch.randn(bs, 3, 64, 64).to(device))
516 |     if args.levels >= 2:
517 |         Ds.append(D_NET128())
518 |         imgs.append(torch.randn(bs, 3, 128, 128).to(device))
519 |     if args.levels >= 3:
520 |         Ds.append(D_NET256())
521 |         imgs.append(torch.randn(bs, 3, 256, 256).to(device))
522 |     for i in range(len(Ds)):
523 |         Ds[i] = nn.DataParallel(Ds[i].to(device))
524 | 
525 |     noise = torch.randn(bs, args.z_dim).to(device)
526 |     txt_feat = torch.randn(bs, args.input_dim).to(device)
527 |     fake_imgs, mu, logvar = G(noise, txt_feat)
528 |     print(len(fake_imgs))
529 |     print(fake_imgs[2].shape)
530 |     print(mu.shape, logvar.shape)
531 | 
532 |     for level in range(len(Ds)):
533 |         D = Ds[level]
534 |         img = imgs[level]
535 |         real_logits = D(img, mu.detach())
536 |         print(real_logits[0].shape, real_logits[1].shape)
--------------------------------------------------------------------------------
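When `b_condition` is enabled, each `D_NET*` above returns a pair of sigmoid outputs per image: one conditioned on the text code and one unconditional. The actual training objectives live in `train_cookgan.py` (not shown in this section); the snippet below is only a minimal sketch, under that assumption, of how a per-level discriminator loss and the conditioning-augmentation KL term could be combined. The helper names `d_loss_at_level` and `kl_loss` are illustrative, not part of the repository.

```python
import torch
from torch import nn

bce = nn.BCELoss()

def d_loss_at_level(D, real_img, fake_img, mu):
    """Sketch of the discriminator loss at one resolution level (assumes b_condition=True)."""
    bs = real_img.size(0)
    real_labels = torch.ones(bs, device=real_img.device)
    fake_labels = torch.zeros(bs, device=real_img.device)
    # D_NET* returns [conditional, unconditional] sigmoid outputs, each of shape [BS]
    real_cond, real_uncond = D(real_img, mu.detach())
    fake_cond, fake_uncond = D(fake_img.detach(), mu.detach())
    return (bce(real_cond, real_labels) + bce(real_uncond, real_labels)
            + bce(fake_cond, fake_labels) + bce(fake_uncond, fake_labels))

def kl_loss(mu, logvar):
    """Standard KL(N(mu, sigma^2) || N(0, 1)) regularizer for CA_NET's mu/logvar."""
    return torch.mean(0.5 * (mu.pow(2) + logvar.exp() - 1 - logvar))
```

In this sketch the generator images and `mu` come from `G(noise, txt_feat)` as in the `__main__` smoke test above, and one such loss would be accumulated per resolution level.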