├── deephumor ├── __init__.py ├── imaging │ ├── __init__.py │ └── caption.py ├── crawlers │ ├── __init__.py │ ├── utils.py │ └── crawlers.py ├── experiments │ ├── __init__.py │ ├── metrics.py │ ├── inference.py │ └── trainer.py ├── data │ ├── __init__.py │ ├── dataloaders.py │ ├── tokenizers.py │ ├── utils.py │ ├── vocab.py │ └── datasets.py └── models │ ├── __init__.py │ ├── beam.py │ ├── encoders.py │ ├── rnn_models.py │ ├── caption_models.py │ └── transformers.py ├── assets ├── lstm.png ├── lstm-labels.png ├── transformer.png ├── base-transformer.png └── deep-learning-meme.jpg ├── fonts └── impact.ttf ├── requirements.txt ├── load_data.sh ├── split_data.py ├── crawl_data.py ├── .gitignore └── README.md /deephumor/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/lstm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ilya16/deephumor/HEAD/assets/lstm.png -------------------------------------------------------------------------------- /fonts/impact.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ilya16/deephumor/HEAD/fonts/impact.ttf -------------------------------------------------------------------------------- /assets/lstm-labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ilya16/deephumor/HEAD/assets/lstm-labels.png -------------------------------------------------------------------------------- /assets/transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ilya16/deephumor/HEAD/assets/transformer.png -------------------------------------------------------------------------------- /assets/base-transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ilya16/deephumor/HEAD/assets/base-transformer.png -------------------------------------------------------------------------------- /assets/deep-learning-meme.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ilya16/deephumor/HEAD/assets/deep-learning-meme.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | torchvision 4 | requests 5 | lxml 6 | langdetect 7 | python-Levenshtein 8 | pillow -------------------------------------------------------------------------------- /deephumor/imaging/__init__.py: -------------------------------------------------------------------------------- 1 | from .caption import memeify_image 2 | 3 | __all__ = [ 4 | 'memeify_image' 5 | ] 6 | -------------------------------------------------------------------------------- /deephumor/crawlers/__init__.py: -------------------------------------------------------------------------------- 1 | from .crawlers import MemeGeneratorCrawler 2 | 3 | __all__ = [ 4 | 'MemeGeneratorCrawler' 5 | ] -------------------------------------------------------------------------------- /deephumor/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference import * 2 | from .metrics import * 3 | from 
.trainer import Trainer 4 | 5 | __all__ = [ 6 | 'text_to_seq', 7 | 'seq_to_text', 8 | 'split_caption', 9 | 'perplexity', 10 | 'Trainer' 11 | ] 12 | -------------------------------------------------------------------------------- /deephumor/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataloaders import pad_collate 2 | from .datasets import MemeDataset 3 | from .tokenizers import * 4 | from .vocab import * 5 | 6 | __all__ = [ 7 | 'SPECIAL_TOKENS', 'Vocab', 'build_vocab', 'build_vocab_from_file', 8 | 'Tokenizer', 'WordPunctTokenizer', 'CharTokenizer', 9 | 'MemeDataset', 'pad_collate' 10 | ] 11 | -------------------------------------------------------------------------------- /load_data.sh: -------------------------------------------------------------------------------- 1 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1j6YG3skamxA1-mdogC1kRjugFuOkHt_A' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1j6YG3skamxA1-mdogC1kRjugFuOkHt_A" -O memes.zip && rm -rf /tmp/cookies.txt 2 | unzip memes.zip -------------------------------------------------------------------------------- /deephumor/experiments/metrics.py: -------------------------------------------------------------------------------- 1 | """Evaluation metrics.""" 2 | 3 | 4 | def perplexity(logits, targets, lengths, pad_index=0): 5 | log_values = logits.log_softmax(-1).gather(-1, targets.unsqueeze(-1)).squeeze() 6 | log_values /= lengths.unsqueeze(1) # divide by lengths 7 | log_values[targets == pad_index] = 0. # remove padded indices 8 | pp_seq = (-log_values.sum(dim=-1)).exp() # compute per-sequence perplexity 9 | return pp_seq.mean() 10 | -------------------------------------------------------------------------------- /deephumor/data/dataloaders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pad_sequence 3 | 4 | 5 | def pad_collate(batch): 6 | """Batch collate with padding for Dataloader.""" 7 | # unpack batch 8 | labels, captions, images = zip(*batch) 9 | 10 | # pad sequences 11 | labels = pad_sequence(labels, batch_first=True, padding_value=0) 12 | captions = pad_sequence(captions, batch_first=True, padding_value=0) 13 | images = torch.stack(images, dim=0) 14 | 15 | return labels, captions, images 16 | -------------------------------------------------------------------------------- /deephumor/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoders import ( 2 | ImageEncoder, 3 | ImageLabelEncoder 4 | ) 5 | from .rnn_models import LSTMDecoder 6 | from .transformers import ( 7 | TransformerEncoder, 8 | TransformerDecoder, 9 | ) 10 | from .caption_models import ( 11 | CaptioningLSTM, 12 | CaptioningLSTMWithLabels, 13 | CaptioningTransformerBase, 14 | CaptioningTransformer 15 | ) 16 | 17 | __all__ = [ 18 | 'ImageEncoder', 19 | 'ImageLabelEncoder', 20 | 'LSTMDecoder', 21 | 'TransformerEncoder', 22 | 'TransformerDecoder', 23 | 'CaptioningTransformerBase', 24 | 'CaptioningTransformer', 25 | ] 26 | -------------------------------------------------------------------------------- /deephumor/data/tokenizers.py: -------------------------------------------------------------------------------- 1 | """Text Tokenizers.""" 2 | import abc 3 | import re 
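# Example behaviour of the tokenizers defined below. This is an illustrative sketch: the
# special-token strings ("<bos>", "<eos>", "<sep>") are placeholders chosen to match the
# `<\w+>`-style patterns, not necessarily the exact token strings used elsewhere.
#
#   WordPunctTokenizer().tokenize("<bos> you shall not pass! <eos>")
#   # -> ['<bos>', 'you', 'shall', 'not', 'pass', '!', '<eos>']
#
#   CharTokenizer().tokenize("<sep>hi!")
#   # -> ['<sep>', 'h', 'i', '!']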
4 | 5 | 6 | class Tokenizer: 7 | """Abstract tokenizer.""" 8 | 9 | @abc.abstractmethod 10 | def tokenize(self, text): 11 | pass 12 | 13 | 14 | class WordPunctTokenizer: 15 | """WordPunctuation tokenizer.""" 16 | 17 | token_pattern = re.compile(r"[<\w'>]+|[^\w\s]+") 18 | 19 | def tokenize(self, text): 20 | return self.token_pattern.findall(text) 21 | 22 | 23 | class CharTokenizer: 24 | """Character-level tokenizer that preserves special tokens in `<>`.""" 25 | 26 | token_pattern = re.compile(r"<\w+>|.") 27 | 28 | def tokenize(self, text): 29 | return self.token_pattern.findall(text) 30 | -------------------------------------------------------------------------------- /deephumor/crawlers/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import requests 5 | 6 | 7 | def time_to_str(time): 8 | """Converts time in seconds into pretty-looking string.""" 9 | return f'{int(time / 60.):3d}:{(time % 60.):05.2f}' 10 | 11 | 12 | def load_image(image_url, save_dir='.'): 13 | """Loads image by url. 14 | 15 | Args: 16 | image_url (str): image URL 17 | save_dir (str): directory for saving the image 18 | 19 | Returns: 20 | str: name of the file 21 | """ 22 | r = requests.get(image_url, stream=True) 23 | file_name = image_url.split('/')[-1] 24 | image_path = os.path.join(save_dir, file_name) 25 | 26 | with open(image_path, 'wb') as out: 27 | shutil.copyfileobj(r.raw, out) 28 | 29 | return file_name 30 | -------------------------------------------------------------------------------- /deephumor/data/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from langdetect import detect_langs 4 | 5 | TOKEN_PATTERN = re.compile(r"[<\w'>]+|[!#$%&\()*+,\-./:;=?@\\^{|}~]+") 6 | PUNCT_PATTERN_0 = re.compile(r"([<>|\\])+") 7 | PUNCT_PATTERN_1 = re.compile(r"([%&\()*+,\-/:;=@^{}~\"])+") 8 | PUNCT_PATTERN_23 = re.compile(r"([\.?!$#_]){4,}") 9 | 10 | 11 | def clean_text(text): 12 | """Cleans text from unnecessary punctuation repetitions""" 13 | text = text if text else '' 14 | 15 | if text: 16 | text = PUNCT_PATTERN_0.sub('', text) 17 | text = PUNCT_PATTERN_1.sub(r'\g<1>', text) 18 | text = PUNCT_PATTERN_23.sub(r'\g<1>\g<1>\g<1>', text) 19 | 20 | return " ".join(text.split()) 21 | 22 | 23 | def check_text(text, min_len=10, max_len=100, max_tokens=32): 24 | """Checks characters and length of the text.""" 25 | # check non-english characters 26 | try: 27 | text.encode('ascii') 28 | except UnicodeEncodeError: 29 | return False 30 | 31 | # filter long texts 32 | if len(text) < min_len or len(text) > max_len: 33 | return False 34 | 35 | # filter texts with many tokens 36 | if len(TOKEN_PATTERN.findall(text)) > max_tokens: 37 | return False 38 | 39 | return True 40 | 41 | 42 | def english_prob(text): 43 | """Returns the probability of the text to be english text.""" 44 | langs = detect_langs(text) 45 | for lang in langs: 46 | if lang.lang == 'en': 47 | return lang.prob 48 | return 0. 
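# The helpers above are chained by the crawler (see deephumor/crawlers/crawlers.py) roughly as
# sketched in this illustrative, hedged example; the thresholds mirror the crawler defaults, and
# `raw_top`, `raw_bottom`, `all_template_text` are placeholder names:
#
#   top, bottom = clean_text(raw_top), clean_text(raw_bottom)
#   text = (top + ' ' + bottom).lower()
#   keep_caption = check_text(text, min_len=10, max_len=96, max_tokens=31)
#
#   # language detection is applied per template, to the concatenation of its captions
#   drop_template = english_prob(all_template_text) < 0.9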
49 | -------------------------------------------------------------------------------- /split_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser('Meme dataset split') 9 | 10 | parser.add_argument('--data-dir', '-d', required=True, type=str, 11 | help='directory with the dataset') 12 | parser.add_argument('--splits', type=int, default=[2500, 250, 250], nargs=3, 13 | help='sizes of train/val/test splits for each template') 14 | parser.add_argument('--random-state', type=int, default=0, 15 | help='random seed for the data shuffling') 16 | 17 | args = parser.parse_args() 18 | 19 | np.random.seed(args.random_state) 20 | start_ids = np.cumsum([0] + args.splits) 21 | end_ids = start_ids[1:] 22 | 23 | labels, captions = defaultdict(bool), defaultdict(list) 24 | with open(os.path.join(args.data_dir, 'captions.txt'), 'r') as f: 25 | for line in f: 26 | label, _, _ = line.strip().split('\t') 27 | captions[label].append(line) 28 | labels[label] = True 29 | 30 | splits = ['train', 'val', 'test'] 31 | f_splits = [ 32 | open(os.path.join(args.data_dir, f'captions_{split}.txt'), 'w') 33 | for split in splits 34 | ] 35 | 36 | for label in labels.keys(): 37 | indices = np.arange(len(captions[label])) 38 | np.random.shuffle(indices) 39 | 40 | for i, f in enumerate(f_splits): 41 | for idx in sorted(indices[start_ids[i]:end_ids[i]]): 42 | f.write(captions[label][idx]) 43 | 44 | for f in f_splits: 45 | f.close() 46 | -------------------------------------------------------------------------------- /crawl_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from deephumor.crawlers import MemeGeneratorCrawler 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser('Meme dataset crawler') 7 | 8 | parser.add_argument('--source', '-s', type=str, default='memegenerator.net', 9 | help='data source') 10 | parser.add_argument('--save-dir', '-d', required=True, type=str, 11 | help='directory where the dataset should be stored') 12 | 13 | # crawling arguments 14 | parser.add_argument('--poolsize', '-p', type=int, default=25, 15 | help='size of the multiprocessing Pool') 16 | parser.add_argument('--num-templates', '-t', type=int, default=300, 17 | help='number of templates to crawl') 18 | parser.add_argument('--num-captions', '-c', type=int, default=1000, 19 | help='number of captions per template') 20 | 21 | parser.add_argument('--detect-english', action='store_true', 22 | help='filter out templates whose captions are mostly non-English') 23 | parser.add_argument('--detect-duplicates', action='store_true', 24 | help='(slow) filter out duplicate captions') 25 | 26 | parser.add_argument('--min-len', type=int, default=10, 27 | help='minimum length of the caption text') 28 | parser.add_argument('--max-len', type=int, default=96, 29 | help='maximum length of the caption text') 30 | parser.add_argument('--max-tokens', type=int, default=31, 31 | help='maximum number of tokens in the caption text') 32 | 33 | args = parser.parse_args() 34 | assert args.source == 'memegenerator.net', 'Only memegenerator.net is supported' 35 | 36 | crawler = MemeGeneratorCrawler( 37 | poolsize=args.poolsize, 38 | min_len=args.min_len, max_len=args.max_len, max_tokens=args.max_tokens, 39 | detect_english=args.detect_english, detect_duplicates=args.detect_duplicates 40 | ) 41 | 42 | 
crawler.crawl_dataset( 43 | num_templates=args.num_templates, 44 | num_captions=args.num_captions, 45 | save_dir=args.save_dir 46 | ) 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /deephumor/experiments/inference.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import torch 4 | 5 | from deephumor.data import SPECIAL_TOKENS 6 | 7 | 8 | PUNCT_PATTERN = re.compile(r"( )([!#$%&\()*+,\-.\/:;<=>?@\\^{|}~]+)") 9 | 10 | 11 | def text_to_seq(text, vocab, tokenizer): 12 | """Transforms string text into a tensor of tokens. 
13 | 14 | Args: 15 | text (str): text input 16 | vocab (Vocab): token vocabulary 17 | tokenizer (Tokenizer): text tokenizer 18 | 19 | Returns: 20 | Torch.tensor: sequence of tokens of size (1, seq_len) 21 | """ 22 | 23 | # tokenize 24 | tokens = tokenizer.tokenize(text.lower()) 25 | 26 | # replace with `UNK` 27 | tokens = [tok if tok in vocab.stoi else SPECIAL_TOKENS['UNK'] for tok in tokens] 28 | 29 | # convert to ids 30 | tokens = [vocab.stoi[tok] for tok in tokens] 31 | 32 | return torch.tensor(tokens).unsqueeze(0) 33 | 34 | 35 | def seq_to_text(seq, vocab, delimiter=' '): 36 | """Transforms torch tensor of tokens into a text. 37 | 38 | Args: 39 | seq (Torch.tensor): sequence of tokens of size (1, seq_len) 40 | vocab (Vocab): token vocabulary 41 | delimiter (str): delimiter between text tokens 42 | 43 | Returns: 44 | str: transformed text 45 | """ 46 | 47 | # find the end the sequence 48 | eos_ids = torch.where(seq == vocab.stoi[SPECIAL_TOKENS['EOS']])[0] 49 | if len(eos_ids) > 0: 50 | seq = seq[:eos_ids[0]] 51 | 52 | # convert tokens indices into text tokens 53 | tokens = list(map(lambda x: vocab.itos[x], seq.cpu().numpy())) 54 | 55 | # join text tokens 56 | text = delimiter.join(tokens) 57 | 58 | return text 59 | 60 | 61 | def split_caption(text, num_blocks=None): 62 | """Splits text caption into blocks according to the special tokens. 63 | 64 | Args: 65 | text (str): input caption text 66 | num_blocks (int): number of blocks to return (`None` for keeping all) 67 | 68 | Returns: 69 | List[str]: a list of text blocks 70 | """ 71 | 72 | def _clean_text_block(text_block): 73 | text_block = re.sub(r'<\w+>', '', text_block) 74 | text_block = re.sub(r'^\s+', '', text_block) 75 | text_block = re.sub(r'\s+$', '', text_block) 76 | text_block = PUNCT_PATTERN.sub('\\2', text_block) 77 | return text_block 78 | 79 | text_blocks = text.split(SPECIAL_TOKENS['SEP']) 80 | 81 | # clean blocks from any special tokens and padding spaces 82 | text_blocks = [_clean_text_block(t) for t in text_blocks] 83 | 84 | if num_blocks is None: 85 | num_blocks = len(text_blocks) 86 | elif len(text_blocks) < num_blocks: 87 | text_blocks += [''] * (num_blocks - len(text_blocks)) 88 | 89 | return text_blocks[:num_blocks] 90 | -------------------------------------------------------------------------------- /deephumor/data/vocab.py: -------------------------------------------------------------------------------- 1 | """Vocabulary tools.""" 2 | 3 | from collections import Counter 4 | 5 | SPECIAL_TOKENS = { 6 | 'PAD': '', 7 | 'UNK': '', 8 | 'BOS': '', 9 | 'EOS': '', 10 | 'SEP': '', 11 | 'EMPTY': '', 12 | } 13 | 14 | 15 | class Vocab: 16 | """Token vocabulary.""" 17 | 18 | def __init__(self, tokens, special_tokens=tuple(SPECIAL_TOKENS.values())): 19 | tokens = list(sorted(filter(lambda x: x not in special_tokens, tokens))) 20 | self.tokens = list(special_tokens) + tokens 21 | self.stoi = {self.tokens[idx]: idx for idx in range(len(self.tokens))} 22 | self.itos = {idx: self.tokens[idx] for idx in range(len(self.tokens))} 23 | 24 | def __iter__(self): 25 | return iter(self.tokens) 26 | 27 | def __len__(self): 28 | return len(self.tokens) 29 | 30 | def save(self, filepath): 31 | with open(filepath, 'w') as f: 32 | for token in self.tokens: 33 | f.write(f'{token}\n') 34 | 35 | @staticmethod 36 | def load(filepath): 37 | tokens = [] 38 | with open(filepath, 'r') as f: 39 | for line in f: 40 | token = line.strip('\n') 41 | tokens.append(token) 42 | return Vocab(tokens) 43 | 44 | 45 | def build_vocab(documents, tokenizer, min_df=7): 
46 | """Builds vocabulary of tokens from a collection of documents. 47 | 48 | Args: 49 | documents (list[str]): collection of documents 50 | tokenizer (Tokenizer): Tokenizer object 51 | min_df (int): minimum document frequency for tokens 52 | 53 | Returns: 54 | Vocab: vocabulary of tokens 55 | """ 56 | token_counts = Counter() 57 | 58 | # tokenize and count unique tokens 59 | for text in documents: 60 | tokens = set(tokenizer.tokenize(text.lower())) 61 | token_counts.update(tokens) 62 | 63 | # filter by minimum document frequency 64 | tokens = [token for token, count in token_counts.items() if count >= min_df] 65 | 66 | # build vocabulary 67 | vocab = Vocab(tokens) 68 | 69 | return vocab 70 | 71 | 72 | def build_vocab_from_file(captions_file, tokenizer, min_df=7): 73 | """Builds vocabulary from captions file. 74 | 75 | Args: 76 | captions_file (str): path to the file with captions 77 | tokenizer (Tokenizer): Tokenizer object 78 | min_df (int): minimum document frequency for tokens 79 | 80 | Returns: 81 | Vocab: vocabulary of tokens 82 | """ 83 | 84 | captions = [] 85 | with open(captions_file) as f: 86 | for line in f: 87 | _, _, caption = line.strip().split('\t') 88 | captions.append(caption) 89 | 90 | return build_vocab(captions, tokenizer, min_df=min_df) 91 | -------------------------------------------------------------------------------- /deephumor/data/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | from .vocab import SPECIAL_TOKENS 8 | from .tokenizers import WordPunctTokenizer 9 | 10 | 11 | class MemeDataset(Dataset): 12 | """MemeGenerator dataset class.""" 13 | 14 | def __init__(self, root, vocab, tokenizer=WordPunctTokenizer(), 15 | split='train', num_classes=300, image_transform=None, 16 | preload_images=True): 17 | assert split in ('train', 'val', 'test'), 'Incorrect data split' 18 | 19 | self.root = root 20 | self.split = split 21 | self.tokenizer = tokenizer 22 | self.vocab = vocab 23 | self.image_transform = image_transform 24 | self.preload_images = preload_images 25 | 26 | self.num_classes = num_classes 27 | self._load_dataset() 28 | 29 | def _load_dataset(self): 30 | # load templates information 31 | fn_temp = os.path.join(self.root, 'templates.txt') 32 | assert os.path.exists(fn_temp), \ 33 | f'Templates file {fn_temp} is not found' 34 | 35 | dir_imgs = os.path.join(self.root, 'images') 36 | assert os.path.isdir(dir_imgs), \ 37 | f'Images directory {dir_imgs} is not found' 38 | 39 | self.templates = {} 40 | self.images = {} 41 | with open(fn_temp, 'r') as f: 42 | for line in f: 43 | label, _, url = line.strip().split('\t') 44 | filename = url.split('/')[-1] 45 | self.templates[label] = os.path.join(dir_imgs, filename) 46 | 47 | # preaload images and apply transforms 48 | if self.preload_images: 49 | img = Image.open(self.templates[label]) 50 | if self.image_transform is not None: 51 | img = self.image_transform(img) 52 | self.images[label] = img 53 | else: 54 | self.images[label] = self.templates[label] 55 | 56 | if len(self.templates) == self.num_classes: 57 | break 58 | 59 | # load captions 60 | fn_capt = os.path.join(self.root, f'captions_{self.split}.txt') 61 | assert os.path.exists(fn_capt), \ 62 | f'Captions file {fn_capt} is not found' 63 | 64 | self.captions = [] 65 | with open(fn_capt, 'r') as f: 66 | for i, line in enumerate(f): 67 | label, _, caption = line.strip().split('\t') 68 | if label in 
self.templates: 69 | self.captions.append((label, caption)) 70 | 71 | def _preprocess_text(self, text): 72 | # tokenize 73 | tokens = self.tokenizer.tokenize(text.lower()) 74 | 75 | # replace with `UNK` 76 | tokens = [tok if tok in self.vocab.stoi else SPECIAL_TOKENS['UNK'] for tok in tokens] 77 | 78 | # add `EOS` 79 | tokens += [SPECIAL_TOKENS['EOS']] 80 | 81 | # convert to ids 82 | tokens = [self.vocab.stoi[tok] for tok in tokens] 83 | 84 | return tokens 85 | 86 | def __getitem__(self, idx): 87 | label, caption = self.captions[idx] 88 | img = self.images[label] 89 | 90 | # label and caption tokens 91 | label = torch.tensor(self._preprocess_text(label)).long() 92 | caption = torch.tensor(self._preprocess_text(caption)).long() 93 | 94 | # image transform 95 | if not self.preload_images: 96 | img = Image.open(img) 97 | if self.image_transform is not None: 98 | img = self.image_transform(img) 99 | 100 | return label, caption, img 101 | 102 | def __len__(self): 103 | return len(self.captions) 104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepHumor: Image-based Meme Generation using Deep Learning 2 | 3 | > Final project of the "Deep Learning" course at Skoltech, 2020. 4 | > Authors: [Ilya Borovik](https://github.com/ilya16), [Bulat Khabibullin](https://github.com/Bulichek), [Vladislav Kniazev](https://github.com/Vladoskn), [Oluwafemi Olaleke](https://github.com/6861) and [Zakhar Pichugin](https://github.com/zakharpichugin) 5 | > 6 | >[![Open in YouTube](https://img.shields.io/badge/_-Presentation-red.svg?logo=youtube&labelColor=5c5c5c)](https://youtu.be/gf-HcRwsSfI) 7 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ilya16/deephumor/blob/master/deephumor_demo.ipynb) 8 | 9 | ![Deep Learning meme](assets/deep-learning-meme.jpg) 10 | 11 | ## Description 12 | 13 | The repository presents multiple meme generation models (see illustrations [below](#models)): 14 | 15 | - Captioning LSTM with Image-only Encoder 16 | - Captioning LSTM with Image-label Encoder 17 | - Base Captioning Transformer with Global image embedding 18 | - Captioning Transformer with Spatial image features 19 | 20 | **See the models in action in the demo notebook:** 21 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ilya16/deephumor/blob/master/deephumor_demo.ipynb) 22 | [![Open in GitHub](https://img.shields.io/badge/_-Open_in_GitHub-blue.svg?logo=Jupyter&labelColor=5c5c5c)](deephumor_demo.ipynb) 23 | 24 | All pretrained models are automatically downloaded and built in the Colab runtime. 25 | 26 | In addition to the models, we collect and release a large-scale dataset of 900,000 meme captions crawled from the [MemeGenerator](https://memegenerator.net) website. 27 | The dataset is uploaded to [Google Drive](https://drive.google.com/file/d/1j6YG3skamxA1-mdogC1kRjugFuOkHt_A) and described in the corresponding [section](#dataset). 28 | 29 | *Note: the repository state at the end of the "Deep Learning" course project is recorded in the branch* [`skoltech-dl-project`](https://github.com/ilya16/deephumor/tree/skoltech-dl-project). 30 | 31 | ## Training code 32 | 33 | Example code for training the models is provided in a [Colab notebook](https://colab.research.google.com/drive/1ayyWPuOw8ET2SRZ5KD-r4dwMH4jBn-B8). It contains the training progress and TensorBoard logs for all experiments described in the project report.
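The sketch below shows how the pieces of the package fit together for training. It is a minimal, hedged example rather than the exact notebook code: the `CaptioningLSTM` hyperparameters, the image transform, the optimizer settings and the `memes900k` dataset path are assumptions, while `build_vocab_from_file`, `MemeDataset`, `pad_collate` and `Trainer` are used with the signatures defined in this repository.

```python
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from deephumor.data import (MemeDataset, WordPunctTokenizer,
                            build_vocab_from_file, pad_collate)
from deephumor.experiments import Trainer
from deephumor.models import CaptioningLSTM

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# tokenizer and vocabulary built from the training captions
tokenizer = WordPunctTokenizer()
vocab = build_vocab_from_file('memes900k/captions_train.txt', tokenizer, min_df=7)

# standard ResNet-style preprocessing (an assumption, not taken from the notebook)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# train/val dataloaders with the padding collate function
dataloaders = {
    split: DataLoader(
        MemeDataset('memes900k', vocab, tokenizer=tokenizer, split=split,
                    image_transform=transform),
        batch_size=64, shuffle=(split == 'train'), collate_fn=pad_collate,
    )
    for split in ('train', 'val')
}

# model hyperparameters below are illustrative assumptions
model = CaptioningLSTM(num_tokens=len(vocab), emb_dim=256, hidden_size=512).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()

trainer = Trainer('captioning-lstm', log_dir='./logs', device=device)
trainer.train_model(model, dataloaders, optimizer, criterion, n_epochs=10)
trainer.close()
```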
34 | 35 | ## Dataset 36 | 37 | We crawl and preprocess a large-scale meme dataset consisting of 900,000 meme captions for 300 meme template images collected from the [MemeGenerator](https://memegenerator.net) website. 38 | During data collection, we clean the data by removing evident duplicates, overly long captions, non-ASCII symbols and non-English templates. 39 | 40 | ### Download dataset 41 | The crawled dataset of 300 meme templates with 3,000 captions per template can be downloaded 42 | using the [`load_data.sh`](load_data.sh) script or directly from [Google Drive](https://drive.google.com/file/d/1j6YG3skamxA1-mdogC1kRjugFuOkHt_A). The data is split into `train/val/test` with 2500/250/250 captions per split for each template. We provide the data splits so that new models can be compared with our work. 43 | 44 | The dataset archive has the following structure: 45 | 46 | ``` 47 | ├── memes900k 48 | | ├── images -- template images 49 | | ├── cool-dog.jpg 50 | | ├── dogeee.jpg 51 | | ├── ... 52 | | ├── templates.txt -- template labels and image urls 53 | | ├── captions.txt -- all captions 54 | | ├── captions_train.txt -- training split 55 | | ├── captions_val.txt -- validation split 56 | | ├── captions_test.txt -- test split 57 | ``` 58 | 59 | ### Crawl dataset 60 | To crawl your own dataset, run the following script: 61 | ```shell script 62 | python crawl_data.py --source memegenerator.net --save-dir ../memes \ 63 | --poolsize 25 --num-templates 300 --num-captions 3000 \ 64 | --detect-english --detect-duplicates \ 65 | --min-len 10 --max-len 96 --max-tokens 31 66 | ``` 67 | 68 | Then, split the data into `train/val/test` using: 69 | ```shell script 70 | python split_data.py --data-dir ../memes --splits 2500 250 250 71 | ``` 72 | 73 | ## Models 74 | 75 | ### Captioning LSTM 76 | ![Captioning LSTM](assets/lstm.png) 77 | 78 | ### Captioning LSTM with labels 79 | ![Captioning LSTM with labels](assets/lstm-labels.png) 80 | 81 | ### Captioning Base Transformer 82 | ![Captioning Base Transformer](assets/base-transformer.png) 83 | 84 | ### Captioning Transformer 85 | ![Captioning Transformer](assets/transformer.png) 86 | -------------------------------------------------------------------------------- /deephumor/models/beam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BeamSearchHelper: 5 | """Helper class with common functions for beam search sampling.""" 6 | 7 | def __init__(self, temperature=1.0, beam_size=10, top_k=50, 8 | unk_index=1, eos_index=3, device='cuda'): 9 | assert beam_size <= top_k, '`beam_size` should not exceed `top_k`' 10 | 11 | self.temperature = temperature 12 | self.beam_size = beam_size 13 | self.top_k = top_k 14 | self.unk_index = unk_index 15 | self.eos_index = eos_index 16 | self.device = device 17 | self._build_has_ended_variables() 18 | 19 | def _build_has_ended_variables(self): 20 | """Returns flags and masks for monitoring if generation has ended.""" 21 | # flags showing if sequence has ended 22 | self.has_ended = torch.tensor([False] * self.beam_size).to(self.device) 23 | 24 | # masks for filtering out predictions for ended/not_ended sequences 25 | self._n_copies_has_ended = torch.tensor([[self.beam_size], [1]]).to(self.device) 26 | self._mask_has_ended = torch.stack( 27 | [torch.tensor([True] * self.beam_size), 28 | torch.tensor([True] + [False] * (self.beam_size - 1))], 29 | dim=0 30 | ).to(self.device) 31 | 32 | def filter_top_k(self, logits): 33 | """Filters `top_k` logit values by zeroing 
out others.""" 34 | filter_ind = logits < torch.topk(logits, self.top_k, dim=-1).values[:, -1].unsqueeze(-1) 35 | filter_ind[:, self.unk_index] = True # zero out unk token 36 | logits[filter_ind] = float('-inf') 37 | return logits 38 | 39 | def sample_k_indices(self, logits, k=None): 40 | """Samples `beam_size` indices for each sequence in the batch.""" 41 | # compute probabilities 42 | p_next = torch.softmax(logits / self.temperature, dim=-1) 43 | 44 | # sample values 45 | k = self.beam_size if k is None else k 46 | sample_ind = torch.multinomial(p_next, k) 47 | 48 | return sample_ind 49 | 50 | @staticmethod 51 | def filter_by_indices(values, indices): 52 | sample_val = torch.gather(values, 1, indices) 53 | return sample_val 54 | 55 | def process_logits(self, logits, sample_seq, sample_val): 56 | """Main logic of beam search sampling step. 57 | 58 | Steps: 59 | - filter `top_k` logit scores 60 | - filter out predictions for already ended sequences 61 | - check if new predictions end sequence 62 | - update `has_ended` indices 63 | 64 | Args: 65 | logits (torch.Tensor): logit predictions, outputs of the classifier layer 66 | sample_seq (torch.Tensor): `beam_size` sequences from the previous sampling step 67 | sample_val (torch.Tensor): scores for the sequences from the previous sampling step 68 | 69 | Returns: 70 | (prev_seqs, prev_vals), (new_ind, new_val): 71 | expanded sequences and their scores from the previous sampling step 72 | + new candidate predictions and their scores 73 | """ 74 | # filter `top_k` values 75 | logits = self.filter_top_k(logits) 76 | 77 | # sample `beam` sequences for each branch 78 | new_ind = self.sample_k_indices(logits, k=self.beam_size) 79 | new_val = self.filter_by_indices(logits, new_ind).log_softmax(-1) 80 | new_ind, new_val = new_ind.flatten(), new_val.flatten() 81 | 82 | # numbers of repeat_interleave copies (if ended, only a single copy) 83 | n_copies = self._n_copies_has_ended[self.has_ended.long(), :].flatten() 84 | 85 | # mask for unique rows 86 | unique_rows = self._mask_has_ended[self.has_ended.long(), :].flatten() 87 | 88 | # filter values 89 | new_ind = new_ind[unique_rows] 90 | new_val = new_val[unique_rows] 91 | 92 | # check if the sequences already ended 93 | # (no need to predict and evaluate new scores) 94 | self.has_ended = torch.repeat_interleave(self.has_ended, n_copies, dim=0) 95 | new_ind[self.has_ended], new_val[self.has_ended] = 0, 0. 96 | 97 | # update `had_ended` based on new predictions 98 | self.has_ended = self.has_ended | (new_ind == self.eos_index) 99 | 100 | # repeat current sampled sequences 101 | prev_seqs = torch.repeat_interleave(sample_seq, n_copies, dim=0) 102 | prev_vals = torch.repeat_interleave(sample_val, n_copies, dim=0) 103 | 104 | if len(prev_seqs.size()) == 1: 105 | prev_seqs = prev_seqs.unsqueeze(0) 106 | prev_vals = prev_vals.unsqueeze(0) 107 | 108 | return (prev_seqs, prev_vals), (new_ind, new_val) 109 | 110 | def all_ended(self): 111 | """Returns bool indicating if all sequences have ended.""" 112 | return torch.all(self.has_ended) 113 | -------------------------------------------------------------------------------- /deephumor/models/encoders.py: -------------------------------------------------------------------------------- 1 | """Image and Text Encoder models.""" 2 | import torch 3 | from torch import nn 4 | from torchvision import models 5 | 6 | 7 | class ImageEncoder(nn.Module): 8 | """ResNet-based [1] image encoder. 9 | 10 | Encodes an image into a `emb_size` vector. 
11 | 12 | If `spatial_features=True`, encoder also builds spatial features 13 | of the image based on the output of the last block of ResNet. 14 | The shape of spatial features is `[k x k, emb_size]` 15 | 16 | Note: `nn.Linear` layer is shared for global and spatial encodings. 17 | 18 | References: 19 | [1]: "Deep Residual Learning for Image Recognition", https://arxiv.org/abs/1512.03385 20 | """ 21 | 22 | def __init__(self, emb_dim=256, dropout=0.2, spatial_features=False): 23 | """Initializes ImageEncoder. 24 | 25 | Args: 26 | emb_dim (int): dimensions of the output embedding 27 | dropout (float): dropout for the encoded features 28 | spatial_features (bool): whether compute spatial features or not 29 | """ 30 | super().__init__() 31 | 32 | self.spatial_features = spatial_features 33 | 34 | resnet = models.resnet50(pretrained=True) 35 | for p in resnet.parameters(): 36 | p.requires_grad = False 37 | modules = list(resnet.children())[:-2] 38 | self.resnet = nn.Sequential(*modules) 39 | self.avgpool = resnet.avgpool 40 | 41 | # embedding layer 42 | self.linear = nn.Linear(resnet.fc.in_features, emb_dim) 43 | self.bn = nn.BatchNorm1d(emb_dim) 44 | self.dropout = nn.Dropout(dropout) 45 | 46 | def forward(self, images): 47 | """ 48 | Args: 49 | images (torch.Tensor): input images of shape `[bs, width, height]` 50 | 51 | Returns: 52 | torch.Tensor: global image embedding of shape `[bs, emb_dim]` if `self.spatial_features=False`, 53 | (`self.spatial_features=True`) spatial image embeddings of shape `[bs, k_w x k_h, emb_dim]` 54 | """ 55 | # ResNet features 56 | features = self.resnet(images) 57 | bs, dim = features.shape[:2] 58 | 59 | # global image embedding 60 | x = self.avgpool(features).reshape(bs, -1) 61 | emb = self.dropout(self.bn(self.linear(x))) 62 | 63 | # spatial features 64 | if self.spatial_features: 65 | x = features.reshape(bs, dim, -1) 66 | x = x.transpose(2, 1) # (B, D, N) -> (B, N, D) 67 | spatial_emb = self.dropout(self.linear(x)) 68 | return emb, spatial_emb 69 | 70 | return emb 71 | 72 | 73 | class LabelEncoder(nn.Module): 74 | """Label encoder. 75 | 76 | Encodes text labels into a single embedding of size `emb_dim`. 77 | 78 | Label Encoder 2 from [1]. 79 | 80 | References: 81 | [1]: "Dank Learning: Generating Memes Using Deep Neural Networks", https://arxiv.org/abs/1806.04510 82 | """ 83 | 84 | def __init__(self, num_tokens, emb_dim=256, dropout=0.2): 85 | """Initializes LabelEncoder. 86 | 87 | Args: 88 | num_tokens: number of tokens in the vocabulary 89 | emb_dim (int): dimensions of the output embedding 90 | dropout (float): dropout for the encoded features 91 | """ 92 | super().__init__() 93 | self.embedding = nn.Embedding(num_tokens, emb_dim) 94 | self.dropout = nn.Dropout(dropout) 95 | 96 | def forward(self, labels): 97 | """ 98 | Args: 99 | labels (torch.Tensor): input text labels of shape `[bs, seq_len]` 100 | 101 | Returns: 102 | torch.Tensor: average label embedding of shape `[bs, emb_dim]` 103 | """ 104 | emb = self.embedding(labels).mean(dim=1) 105 | emb = self.dropout(emb) 106 | return emb 107 | 108 | 109 | class ImageLabelEncoder(nn.Module): 110 | """ImageLabel encoder. 111 | 112 | Encodes images and text labels into a single embedding of size `emb_dim`. 113 | """ 114 | 115 | def __init__(self, num_tokens, emb_dim=256, dropout=0.2): 116 | """Initializes LabelEncoder. 
117 | 118 | Args: 119 | num_tokens: number of tokens in the vocabulary 120 | emb_dim (int): dimensions of the output embedding 121 | dropout (float): dropout for the encoded features 122 | """ 123 | super().__init__() 124 | self.image_encoder = ImageEncoder(emb_dim, dropout) 125 | self.label_encoder = LabelEncoder(num_tokens, emb_dim, dropout) 126 | self.linear = nn.Linear(2 * emb_dim, emb_dim) 127 | self.dropout = nn.Dropout(dropout) 128 | 129 | def forward(self, images, labels): 130 | """ 131 | Args: 132 | images (torch.Tensor): input images of shape `[bs, width, height]` 133 | labels (torch.Tensor): input text labels of shape `[bs, seq_len]` 134 | 135 | Returns: 136 | torch.Tensor: combined image-label embedding of shape `[bs, emb_dim]` 137 | """ 138 | image_emb = self.image_encoder(images) 139 | label_emb = self.label_encoder(labels) 140 | 141 | emb = torch.cat([image_emb, label_emb], dim=1) 142 | emb = self.dropout(self.linear(emb)) 143 | 144 | return emb 145 | -------------------------------------------------------------------------------- /deephumor/models/rnn_models.py: -------------------------------------------------------------------------------- 1 | """RNN-based models.""" 2 | import torch 3 | from torch import nn 4 | 5 | from deephumor.models.beam import BeamSearchHelper 6 | 7 | 8 | class LSTMDecoder(nn.Module): 9 | """LSTM-based decoder.""" 10 | 11 | def __init__(self, num_tokens, emb_dim=256, hidden_size=512, 12 | num_layers=3, dropout=0.1, embedding=None): 13 | 14 | super(LSTMDecoder, self).__init__() 15 | 16 | self.num_tokens = num_tokens 17 | 18 | if embedding is not None: 19 | self.embedding = embedding 20 | else: 21 | self.embedding = nn.Embedding(num_tokens, emb_dim) 22 | 23 | self.lstm = nn.LSTM(emb_dim, hidden_size, num_layers, batch_first=True, 24 | dropout=(0 if num_layers == 1 else dropout)) 25 | 26 | self.classifier = nn.Linear(hidden_size, num_tokens) 27 | 28 | def forward(self, image_emb, captions, lengths=None): 29 | # caption tokens embeddings 30 | token_emb = self.embedding(captions) 31 | 32 | # image embedding + token embeddings 33 | x = torch.cat((image_emb.unsqueeze(1), token_emb), dim=1) 34 | 35 | if lengths is None: 36 | lengths = torch.tensor(x.size(1)).repeat(x.size(0)) 37 | 38 | # LSTM ouputs 39 | packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False) 40 | outputs, _ = self.lstm(packed) 41 | outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) 42 | 43 | # mapping into `num_tokens` 44 | outputs = self.classifier(outputs) 45 | 46 | return outputs 47 | 48 | def generate(self, image_emb, caption=None, max_len=25, 49 | temperature=1.0, beam_size=10, top_k=50, eos_index=3): 50 | """Generates text tokens based on the image embedding. 
51 | 52 | Args: 53 | image_emb (torch.Tensor): image embedding of shape `[1, emb_dim]` 54 | caption (torch.Tensor, optional): beginning tokens of the caption of shape `[1, seq_len]` 55 | max_len (int): maximum length of the caption 56 | temperature (float): temperature for softmax over logits 57 | beam_size (int): number of maintained branches at each step 58 | top_k (int): number of the most probable tokens to consider during sampling 59 | eos_index (int): index of the EOS (end-of-sequence) token 60 | 61 | Returns: 62 | torch.Tensor: generated caption tokens of shape `[1, min(output_len, max_len)]` 63 | """ 64 | 65 | # beam search sampling helper 66 | helper = BeamSearchHelper( 67 | temperature=temperature, beam_size=beam_size, 68 | top_k=top_k, eos_index=eos_index, 69 | device=image_emb.device 70 | ) 71 | 72 | # process caption tokens if present 73 | if caption is None: 74 | inputs = image_emb 75 | else: 76 | token_emb = self.embedding(caption) 77 | inputs = torch.cat([image_emb, token_emb], dim=1) 78 | 79 | # run LSTM over the inputs and predict the next token 80 | outputs, (h, c) = self.lstm(inputs) 81 | logits = self.classifier(outputs[:, -1, :]) 82 | 83 | # repeat hidden state `beam` times 84 | h, c = h.repeat((1, beam_size, 1)), c.repeat((1, beam_size, 1)) 85 | 86 | # filter `top_k` values 87 | logits = helper.filter_top_k(logits) 88 | 89 | # compute probabilities and sample k values 90 | sample_ind = helper.sample_k_indices(logits, k=beam_size) 91 | sample_val = helper.filter_by_indices(logits, sample_ind).log_softmax(-1) 92 | sample_ind, sample_val = sample_ind.T, sample_val.T 93 | 94 | # define total prediction sequences 95 | sample_seq = sample_ind.clone().detach() 96 | if caption is not None: 97 | sample_seq = torch.cat([caption.repeat(beam_size, 1), sample_seq], dim=1) 98 | 99 | # reusable parameters 100 | beam_copies = torch.tensor([beam_size] * beam_size).to(outputs.device) 101 | 102 | # update `has_ended` index 103 | helper.has_ended = (sample_ind == eos_index).view(-1) 104 | 105 | for i in range(sample_seq.size(1), max_len): 106 | # predict the next time step 107 | inputs = self.embedding(sample_ind) 108 | outputs, (h, c) = self.lstm(inputs, (h, c)) 109 | logits = self.classifier(outputs[:, -1, :]) 110 | 111 | (prev_seqs, prev_vals), (new_ind, new_val) = helper.process_logits( 112 | logits, sample_seq, sample_val 113 | ) 114 | 115 | # create candidate sequences and compute their probabilities 116 | cand_seq = torch.cat((prev_seqs, new_ind.unsqueeze(0).T), -1) 117 | cand_val = prev_vals.flatten() + new_val 118 | 119 | # sample `beam` sequences 120 | filter_ind = helper.sample_k_indices(cand_val, k=beam_size) 121 | 122 | # update total sequences and their scores 123 | sample_val = cand_val[filter_ind] 124 | sample_seq = cand_seq[filter_ind] 125 | sample_ind = sample_seq[:, -1].unsqueeze(-1) 126 | 127 | # filter `has_ended` flags 128 | helper.has_ended = helper.has_ended[filter_ind] 129 | 130 | # check if every branch has ended 131 | if helper.all_ended(): 132 | break 133 | 134 | # repeat hidden state `beam` times and filter by sampled indices 135 | h = torch.repeat_interleave(h, beam_copies, dim=1) 136 | c = torch.repeat_interleave(c, beam_copies, dim=1) 137 | h, c = h[:, filter_ind, :], c[:, filter_ind, :] 138 | 139 | # sample output sequence 140 | ind = helper.sample_k_indices(sample_val, k=1) 141 | output_seq = sample_seq[ind, :].squeeze() 142 | 143 | return output_seq 144 | -------------------------------------------------------------------------------- 
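# A schematic sketch of caption generation with the decoder above. This is an illustrative,
# hedged example: the vocabulary file path and model sizes are assumptions, and the loading of
# pretrained weights is omitted. The image embedding is unsqueezed to add a sequence dimension,
# matching how `generate` concatenates it with token embeddings.

import torch

from deephumor.data import Vocab
from deephumor.experiments import seq_to_text, split_caption
from deephumor.models import ImageEncoder, LSTMDecoder

vocab = Vocab.load('memes900k/vocab.txt')  # hypothetical path to a saved vocabulary

encoder = ImageEncoder(emb_dim=256).eval()
decoder = LSTMDecoder(num_tokens=len(vocab), emb_dim=256, hidden_size=512).eval()

with torch.no_grad():
    image = torch.randn(1, 3, 224, 224)      # stands in for a preprocessed template image
    image_emb = encoder(image).unsqueeze(1)  # [1, 1, emb_dim]
    seq = decoder.generate(image_emb, max_len=32, beam_size=10, top_k=50)

text = seq_to_text(seq, vocab)               # e.g. "top text <sep> bottom text"
top, bottom = split_caption(text, num_blocks=2)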
/deephumor/experiments/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from time import time 4 | 5 | import torch 6 | from torch.utils.tensorboard import SummaryWriter 7 | 8 | from deephumor.experiments.metrics import perplexity 9 | 10 | 11 | class Trainer: 12 | """An ultimate class for running the models.""" 13 | def __init__(self, experiment_title, log_dir='./logs', text_labels=False, 14 | phases=('train', 'val'), clip_norm=3., log_grad_norm=False, 15 | unk_index=0, pad_index=0, device='cuda'): 16 | self.experiment_data = self._setup_experiment(experiment_title, log_dir) 17 | 18 | self.text_labels = text_labels 19 | self.phases = phases 20 | self.clip_norm = clip_norm 21 | self.log_grad_norm = log_grad_norm 22 | 23 | self.unk_index = unk_index 24 | self.pad_index = pad_index 25 | self.device = device 26 | 27 | self.writers = self._setup_writers() 28 | 29 | @staticmethod 30 | def _setup_experiment(title, log_dir='./logs'): 31 | experiment_name = "{}@{}".format(title, datetime.now().strftime("%d.%m.%Y-%H:%M:%S")) 32 | experiment_dir = os.path.join(log_dir, experiment_name) 33 | best_model_path = f"{title}.best.pth" 34 | 35 | experiment_data = { 36 | 'model_name': title, 37 | 'name': experiment_name, 38 | 'dir': experiment_dir, 39 | 'best_model_path': best_model_path, 40 | 'epochs': 0, 41 | 'iterations': 0, 42 | } 43 | 44 | return experiment_data 45 | 46 | def _setup_writers(self): 47 | return { 48 | phase: SummaryWriter(log_dir=os.path.join(self.experiment_data['dir'], phase)) 49 | for phase in self.phases 50 | } 51 | 52 | def run_epoch(self, model, dataloader, optimizer, criterion, phase='train'): 53 | is_train = (phase == 'train') 54 | model.train() if is_train else model.eval() 55 | 56 | epoch = self.experiment_data['epochs'] 57 | iterations = self.experiment_data['iterations'] 58 | epoch_loss, epoch_pp = 0., 0. 
59 | 60 | with torch.set_grad_enabled(is_train): 61 | for batch in dataloader: 62 | # unpack batch 63 | labels, captions, images = batch 64 | bs, max_len = captions.size() 65 | 66 | captions, images = captions.to(self.device), images.to(self.device) 67 | lengths = captions.size(1) - (captions == self.pad_index).sum(dim=1) 68 | 69 | if self.text_labels: 70 | labels = labels.to(self.device) 71 | pred = model(images, captions[:, :-1], lengths, labels) 72 | else: 73 | pred = model(images, captions[:, :-1], lengths) 74 | 75 | pred = pred[:, :max_len, :] 76 | 77 | mask = captions != self.pad_index 78 | loss = criterion(pred[mask], captions[mask]) 79 | 80 | with torch.no_grad(): 81 | pp = perplexity(pred, captions, lengths, self.pad_index) 82 | 83 | if self.writers is not None and phase in self.writers and is_train: 84 | # make optimization step 85 | optimizer.zero_grad() 86 | loss.backward() 87 | 88 | if self.log_grad_norm: 89 | self.writers[phase].add_scalar(f"train/grad_norm", gradient_norm(model).item(), iterations) 90 | torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_norm) 91 | 92 | optimizer.step() 93 | 94 | if is_train: 95 | iterations += 1 96 | 97 | epoch_loss += loss.item() * len(captions) 98 | epoch_pp += pp.item() * len(captions) 99 | 100 | # dump batch metrics to tensorboard 101 | if self.writers is not None and phase in self.writers and is_train: 102 | self.writers[phase].add_scalar(f"train/batch_loss", loss.item(), iterations) 103 | self.writers[phase].add_scalar(f"train/batch_perplexity", pp.item(), iterations) 104 | 105 | epoch_loss = epoch_loss / len(dataloader.dataset) 106 | epoch_pp = epoch_pp / len(dataloader.dataset) 107 | 108 | # dump epoch metrics to tensorboard 109 | if self.writers is not None and phase in self.writers: 110 | self.writers[phase].add_scalar(f"eval/loss", epoch_loss, epoch) 111 | self.writers[phase].add_scalar(f"eval/perplexity", epoch_pp, epoch) 112 | 113 | if is_train: 114 | self.experiment_data['iterations'] = iterations 115 | 116 | return epoch_loss, epoch_pp 117 | 118 | def train_model(self, model, dataloaders, optimizer, criterion, scheduler=None, n_epochs=50): 119 | 120 | best_epoch, best_val_loss = 0, float('+inf') 121 | past_epochs = self.experiment_data['epochs'] 122 | iterations = self.experiment_data['iterations'] 123 | 124 | if self.writers is None: 125 | self._setup_writers() 126 | 127 | for epoch in range(past_epochs + 1, past_epochs + n_epochs + 1): 128 | self.experiment_data['epochs'] = epoch 129 | print(f'Epoch {epoch:02d}/{past_epochs + n_epochs:02d}') 130 | 131 | st = time() 132 | for phase in self.phases: 133 | epoch_loss, epoch_pp = self.run_epoch( 134 | model, dataloaders[phase], optimizer, criterion, phase=phase 135 | ) 136 | 137 | print(f' {phase:5s} loss: {epoch_loss:.5f}, perplexity: {epoch_pp:.3f}') 138 | 139 | if phase == 'val' and epoch_loss < best_val_loss: 140 | best_epoch, best_val_loss = epoch, epoch_loss 141 | model.save(self.experiment_data['best_model_path']) 142 | 143 | model.save(f"{self.experiment_data['model_name']}.e{epoch}.pth") 144 | 145 | if phase == 'train' and scheduler is not None: 146 | scheduler.step() 147 | 148 | et = time() - st 149 | print(f' epoch time: {et:.2f}s') 150 | 151 | print(f'Best val_loss: {best_val_loss} (epoch: {best_epoch})') 152 | 153 | self.experiment_data['epochs'] = epoch 154 | self.experiment_data['iterations'] = iterations 155 | 156 | return self.experiment_data 157 | 158 | def close(self): 159 | for writer in self.writers.values(): 160 | writer.close() 161 | self.writers 
= None 162 | 163 | 164 | def gradient_norm(model, norm_type=2.): 165 | total_norm = torch.norm( 166 | torch.stack([torch.norm(p.grad.detach(), norm_type) 167 | for p in model.parameters() if p.grad is not None]), 168 | norm_type 169 | ) 170 | return total_norm 171 | -------------------------------------------------------------------------------- /deephumor/imaging/caption.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import ImageFont, ImageDraw 3 | from copy import deepcopy 4 | 5 | 6 | MEME_FONT_PATH = '../../fonts/impact.ttf' 7 | 8 | 9 | def memeify_image(img, top='', bottom='', font_path=MEME_FONT_PATH): 10 | """Adds top and bottom captions to an image. 11 | 12 | Args: 13 | img (PIL.Image): input image 14 | top (str): top caption text 15 | bottom (str): top caption text 16 | font_path (str): path to font 17 | 18 | Returns: 19 | PIL.Image: captioned image 20 | """ 21 | # do not change existing image 22 | img = deepcopy(img) 23 | 24 | # initial font 25 | font = _get_initial_font(img, texts=[top, bottom], font_path=font_path) 26 | 27 | # split texts into lines 28 | top_lines = split_to_lines(img, top, font) 29 | bottom_lines = split_to_lines(img, bottom, font) 30 | 31 | # adjust the font 32 | font = _get_final_font(img, [top_lines, bottom_lines], font_path=font_path) 33 | 34 | # caption image with both texts 35 | img = caption_image(img, top_lines, font, 'top') 36 | img = caption_image(img, bottom_lines, font, 'bottom') 37 | 38 | return img 39 | 40 | 41 | def get_maximal_font(img, text, font_size=64, text_width=0.94, font_path=MEME_FONT_PATH): 42 | """Computes the font of maximal size that fits the text. 43 | 44 | Args: 45 | img (PIL.Image): input image 46 | text (str): text to fit into image 47 | font_size (int): initial font size 48 | text_width (float): text width ratio with respect to image width 49 | font_path (str): path to font 50 | 51 | Returns: 52 | PIL.ImageFont: optimal font 53 | """ 54 | font = ImageFont.truetype(font_path, font_size) 55 | w, h = font.getsize(text) 56 | 57 | # find the biggest font size that works 58 | while w > img.width * text_width: 59 | font_size = font_size - 1 60 | font = ImageFont.truetype(font_path, font_size) 61 | w, h = font.getsize(text) 62 | 63 | return font 64 | 65 | 66 | def _get_initial_font(img, texts, max_chars=20, font_path=MEME_FONT_PATH): 67 | """Compute initial font of maximal size based of texts. 68 | 69 | Args: 70 | img (PIL.Image): input image 71 | texts (List[str]): list of texts 72 | max_chars (int): maximum number of characters in a line 73 | font_path (str): path to font 74 | 75 | Returns: 76 | PIL.ImageFont: optimal font 77 | """ 78 | # compute the maximum number of characters in a line 79 | max_len = max(map(len, texts)) 80 | max_len = max_len if max_len < max_chars else max_chars 81 | longest_text = 'G' * max_len 82 | 83 | # get initial font size from image dimensions 84 | font_size = int(img.height / 5.4) 85 | 86 | # get maximal font for the initial text 87 | font = get_maximal_font(img, longest_text, font_size, font_path=font_path) 88 | 89 | return font 90 | 91 | 92 | def _get_final_font(img, text_lines, font_path=MEME_FONT_PATH): 93 | """Compute final font of maximal size based of texts split into lines. 
94 | 95 | Args: 96 | img (PIL.Image): input image 97 | text_lines (List[List[str]]): list of list of text lines 98 | font_path (str): path to font 99 | 100 | Returns: 101 | PIL.ImageFont: optimal font 102 | """ 103 | # initial font size 104 | font_size = int(img.height / 5.4) // max(map(len, text_lines)) 105 | font = ImageFont.truetype(font_path, font_size) 106 | 107 | # find the text with the highest occupied width 108 | text_lines = [text for lines in text_lines for text in lines] 109 | lengths = list(map(lambda x: font.getsize(x)[0], text_lines)) 110 | longest_text = text_lines[np.argmax(lengths)] 111 | 112 | # get maximal font for the text with highest width 113 | font = get_maximal_font(img, longest_text, font_size, font_path=font_path) 114 | 115 | return font 116 | 117 | 118 | def split_to_lines(img, text, font): 119 | """Splits text into lines to fit the image with a given font. 120 | 121 | Args: 122 | img (PIL.Image): input image 123 | text (str): input text 124 | font (PIL.ImageFont): text font 125 | 126 | Returns: 127 | List[str]: list of text lines 128 | """ 129 | draw = ImageDraw.Draw(img) 130 | text = text.replace('', '').upper() 131 | w, h = draw.textsize(text, font) # measure the size the text will take 132 | 133 | # compute the number of lines 134 | line_count = 1 135 | if w > img.width: 136 | line_count = w // img.width + 1 137 | 138 | lines = [] 139 | if line_count > 1: 140 | # cut text into lines preserving words 141 | 142 | last_cut = 0 143 | is_last = False 144 | 145 | for i in range(0, line_count): 146 | cut = (len(text) // line_count) * i if last_cut == 0 else last_cut 147 | 148 | if i < line_count - 1: 149 | next_cut = (len(text) // line_count) * (i + 1) 150 | else: 151 | next_cut = len(text) 152 | is_last = True 153 | 154 | # make sure we don't cut words in half 155 | if not (next_cut == len(text) or text[next_cut] == " "): 156 | while text[next_cut] != " ": 157 | next_cut += 1 158 | 159 | line = text[cut:next_cut].strip() 160 | 161 | # does line still fit? 162 | w, h = draw.textsize(line, font) 163 | if not is_last and w > img.width * 0.95: 164 | next_cut -= 1 165 | while text[next_cut] != " ": 166 | next_cut -= 1 167 | 168 | last_cut = next_cut 169 | lines.append(text[cut:next_cut].strip()) 170 | else: 171 | lines.append(text) 172 | 173 | return lines 174 | 175 | 176 | def caption_image(img, text_lines, font, pos='top'): 177 | """Captions the image with text. 
178 | 179 | Args: 180 | img (PIL.Image): input image 181 | text_lines (List[str]): list of text lines 182 | font (PIL.ImageFont): text font 183 | pos (str): position of text (`top` or `bottom`) 184 | 185 | Returns: 186 | PIL.Image: captioned image 187 | """ 188 | draw = ImageDraw.Draw(img) 189 | w, h = draw.textsize(text_lines[0], font) # measure the size the text will take 190 | 191 | # text border size 192 | border_size = font.size // 18 193 | 194 | # compute the position of text on y-axis 195 | last_y = -h 196 | if pos == 'bottom': 197 | last_y = img.height * 0.987 - h * (len(text_lines) + 1) - border_size 198 | 199 | # draw text lines 200 | for line in text_lines: 201 | w, h = draw.textsize(line, font) 202 | x = img.width / 2 - w / 2 203 | y = last_y + h 204 | 205 | # add borders of black color 206 | for xx in range(-border_size, border_size + 1): 207 | for yy in range(-border_size, border_size + 1): 208 | draw.text((x + xx, y + yy), line, (0, 0, 0), font=font) 209 | 210 | # add text in white 211 | draw.text((x, y), line, (255, 255, 255), font=font) 212 | 213 | last_y = y 214 | 215 | return img 216 | -------------------------------------------------------------------------------- /deephumor/crawlers/crawlers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from multiprocessing import Pool 5 | 6 | import numpy as np 7 | import requests 8 | from Levenshtein import ratio as sim_ratio 9 | from lxml import html 10 | 11 | from .utils import time_to_str, load_image 12 | from deephumor.data import SPECIAL_TOKENS 13 | from deephumor.data.utils import clean_text, check_text, english_prob 14 | 15 | 16 | def crawl_templates(page=1): 17 | """Crawls templates from All-time page. 18 | 19 | Args: 20 | page (int): page number 21 | """ 22 | 23 | meme_templates = [] 24 | url = f'https://memegenerator.net/memes/popular/alltime/page/{page}' 25 | 26 | try: 27 | r = requests.get(url) 28 | tree = html.fromstring(r.content) 29 | 30 | divs = tree.xpath('//div[@class="char-img"]/a') 31 | 32 | for div in divs: 33 | link = div.get('href') 34 | img = div.find('img') 35 | label = img.get('alt') 36 | src = img.get('src') 37 | 38 | meme_templates.append({'label': label, 'link': link, 'src': src}) 39 | except ConnectionError as e: 40 | print(e) 41 | 42 | return meme_templates 43 | 44 | 45 | def crawl_template_page(template_link, page=1, num_retries=10): 46 | """Crawls data from the template page. 
47 | 48 | Args: 49 | template_link (str): link identifier of the template 50 | page (int): page number 51 | num_retries (int): number of retries 52 | """ 53 | 54 | url = f'https://memegenerator.net{template_link}/images/popular/alltime/page/{page}' 55 | score_pattern = re.compile(r'(-?\d+(,\d*)?)') 56 | 57 | num_errors = 0 58 | try: 59 | while True: 60 | r = requests.get(url) 61 | if r.status_code == 200: 62 | break 63 | else: 64 | num_errors += 1 65 | if num_errors > num_retries: 66 | print('Failed to load ' + url) 67 | return None, None, None 68 | except ConnectionError as e: 69 | print(e) 70 | return None, None, None 71 | 72 | tree = html.fromstring(r.content) 73 | 74 | label = tree.xpath('//h1/a/text()')[0] 75 | divs = tree.xpath('//div[@class="char-img"]') 76 | 77 | memes = [] 78 | 79 | for div in divs: 80 | score = div.xpath('.//div[contains(@class, "score")]/text()')[0] 81 | score = int(score_pattern.findall(score)[0][0].replace(',', '')) 82 | text0 = div.xpath('a//div[@class="optimized-instance-text0"]/text()') 83 | text1 = div.xpath('a//div[@class="optimized-instance-text1"]/text()') 84 | text0 = text0[0] if text0 else '' 85 | text1 = text1[0] if text1 else '' 86 | 87 | memes.append((score, text0, text1)) 88 | 89 | return label, memes, template_link 90 | 91 | 92 | class MemeGeneratorCrawler: 93 | """MemeGenerator.net website crawler.""" 94 | 95 | # characteristics of the website 96 | temp_pp = 15 # templates per page 97 | capt_pp = 15 # captions per page 98 | 99 | def __init__(self, poolsize=2, 100 | min_len=10, max_len=96, max_tokens=31, 101 | detect_english=False, detect_duplicates=False): 102 | """Initializes crawler and multiprocessing Pool. 103 | 104 | Args: 105 | poolsize (int): size of the multiprocessing pool 106 | min_len (int): minimum length of the caption text 107 | max_len (int): maximum length of the caption text 108 | max_tokens (int): maximum number of tokens in the caption text 109 | detect_english (bool): (non-stable) globally filter non-english templates 110 | detect_duplicates (bool): (slow) check for the similarity of captions and filter duplicates 111 | """ 112 | 113 | self.poolsize = poolsize 114 | self.pool = Pool(poolsize) 115 | 116 | # text preprocessing parameters 117 | self.min_len = min_len 118 | self.max_len = max_len 119 | self.max_tokens = max_tokens 120 | self.detect_english = detect_english 121 | self.detect_duplicates = detect_duplicates 122 | 123 | # containers shared across threads 124 | self.captions = {} 125 | self.num_visited = {} 126 | self.total_texts = {} 127 | 128 | def template_page_callback(self, result): 129 | """Processes the results from the template page.""" 130 | _, memes, link = result 131 | 132 | # check and clear memes 133 | memes_filtered = [] 134 | 135 | for meme in memes: 136 | (score, top, bottom) = meme 137 | top, bottom = clean_text(top), clean_text(bottom) 138 | text = (top + ' ' + bottom).lower() 139 | 140 | if check_text(text, min_len=self.min_len, max_len=self.max_len, max_tokens=self.max_tokens): 141 | memes_filtered.append((score, top, bottom)) 142 | self.total_texts[link] += text + ' ' 143 | 144 | self.captions[link] += memes_filtered 145 | self.num_visited[link] += 1 146 | 147 | def crawl_dataset(self, num_templates=300, num_captions=3000, save_dir='memes'): 148 | """Crawls dataset from memegenerator.net website. 
149 | 150 | Args: 151 | num_templates (int): number of meme templates to crawl 152 | num_captions (int): number of captions per template 153 | save_dir (str): directory for saving the data 154 | """ 155 | # approximate number of caption pages needed 156 | num_capt_pages = int(num_captions / self.capt_pp) 157 | num_capt_pages += (10 - num_capt_pages % 10) 158 | 159 | # directories and files 160 | images_dir = os.path.join(save_dir, "images/") 161 | if not os.path.exists(images_dir): 162 | os.makedirs(images_dir) 163 | templates_file = open(os.path.join(save_dir, "templates.txt"), 'a') 164 | captions_file = open(os.path.join(save_dir, "captions.txt"), 'a') 165 | 166 | # counters 167 | temp_page = 1 168 | total_captions, total_templates = 0, 0 169 | 170 | # start crawling until enough templates are loaded 171 | start_time = time.time() 172 | while total_templates < num_templates: 173 | # parse page with templates 174 | templates = crawl_templates(page=temp_page) 175 | print(f'{time_to_str(time.time() - start_time)}, ' 176 | f'{100 * float(total_captions) / num_templates / num_captions:5.2f}%: ' 177 | f'Crawling page {temp_page} with {len(templates)} templates') 178 | 179 | # load captions in async mode 180 | for temp in templates: 181 | link = temp['link'] 182 | self.captions[link] = [] 183 | self.num_visited[link] = 0 184 | self.total_texts[link] = '' 185 | 186 | for i in range(1, num_capt_pages + 1): 187 | self.pool.apply_async(crawl_template_page, [link, i], 188 | callback=self.template_page_callback) 189 | time.sleep(0.3) 190 | 191 | total_page_templates, total_page_captions = 0, 0 192 | for temp in templates: 193 | label, link, src = temp['label'], temp['link'], temp['src'] 194 | 195 | # wait until all initial pages for the template are loaded 196 | for n_retry in range(100): 197 | if self.num_visited[link] >= num_capt_pages: 198 | break 199 | time.sleep(0.5) 200 | 201 | if self.detect_english: 202 | # check captions language 203 | prob_en = np.mean([english_prob(self.total_texts[link]) for _ in range(5)]) 204 | if prob_en < 0.9: 205 | # non-english, stop processing 206 | print(f'{time_to_str(time.time() - start_time)}, ' 207 | f'{100 * float(total_captions) / num_templates / num_captions:5.2f}%: ' 208 | f' NON_ENGLISH {label} - {len(self.captions[link])} captions (eng:{prob_en:.3f})') 209 | continue 210 | else: 211 | prob_en = None 212 | 213 | page = num_capt_pages 214 | if self.detect_duplicates: 215 | # check duplicates and keep collecting to get `n_captions_per_template` 216 | 217 | unique_captions = [] 218 | while True: 219 | for n_retry in range(100): 220 | if self.num_visited[link] >= page: 221 | break 222 | time.sleep(0.5) 223 | 224 | if not self.captions[link]: 225 | # no new captions for the template 226 | break 227 | 228 | # process crawled captions for duplicates (slow..) 
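# (dedup criterion: a caption is kept only if the Levenshtein similarity ratio between its
#  lowercased "top bottom" text and every caption already accepted for this template is <= 0.9)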
229 | for (score, top, bottom) in self.captions[link]: 230 | is_unique = True 231 | text = (top + ' ' + bottom).lower() 232 | 233 | for (_, other_top, other_bottom) in unique_captions: 234 | other_text = (other_top + ' ' + other_bottom).lower() 235 | if sim_ratio(text, other_text) > 0.9: 236 | is_unique = False 237 | break 238 | 239 | if is_unique: 240 | unique_captions.append((score, top, bottom)) 241 | 242 | self.captions[link] = [] 243 | if len(unique_captions) >= num_captions: 244 | break 245 | 246 | # load five more pages 247 | for i in range(page + 1, page + 10): 248 | self.pool.apply_async(crawl_template_page, [link, i], 249 | callback=self.template_page_callback) 250 | page = i 251 | else: 252 | unique_captions = self.captions[link] 253 | 254 | # total captions 255 | if len(unique_captions) < num_captions: 256 | # skip template 257 | print(f'{time_to_str(time.time() - start_time)}, ' 258 | f'{100 * float(total_captions) / num_templates / num_captions:5.2f}%: ' 259 | f' NOT_ENOUGH {label} - {len(unique_captions)} captions (eng:{prob_en:.3f})') 260 | continue 261 | 262 | # take top captions by their score 263 | captions = list(sorted(unique_captions, key=lambda x: -x[0])) 264 | captions = captions[:num_captions] 265 | 266 | # save template information and load image 267 | templates_file.write(f'{label}\t{link}\t{src}\n') 268 | self.pool.apply_async(load_image, [src, images_dir]) 269 | total_templates += 1 270 | total_page_templates += 1 271 | 272 | # save captions 273 | for (score, top, bot) in captions: 274 | top = top if top else SPECIAL_TOKENS['EMPTY'] 275 | bot = bot if bot else SPECIAL_TOKENS['EMPTY'] 276 | text = top + ' ' + SPECIAL_TOKENS['SEP'] + ' ' + bot 277 | captions_file.write(f'{label}\t{score}\t{text}\n') 278 | 279 | total_captions += len(captions) 280 | total_page_captions += len(captions) 281 | 282 | # delete data from memory 283 | del self.captions[link] 284 | del self.num_visited[link] 285 | del self.total_texts[link] 286 | 287 | print(f'{time_to_str(time.time() - start_time)}, ' 288 | f'{100 * float(total_captions) / num_templates / num_captions:5.2f}%: ' 289 | f' {label} - {len(captions)} captions ({total_captions}) (pid:{page}, en:{prob_en:.3f})') 290 | 291 | if total_templates == num_templates: 292 | # crawled enough templates, skip others if any 293 | break 294 | 295 | print(f'{time_to_str(time.time() - start_time)}, ' 296 | f'{100 * float(total_captions) / num_templates / num_captions:5.2f}%: ' 297 | f'Crawled page {temp_page} with {total_page_templates} templates ' 298 | f'and {total_page_captions} captions ({total_templates}/{total_captions})') 299 | 300 | time.sleep(0.5) 301 | temp_page += 1 302 | 303 | print(f'{time_to_str(time.time() - start_time)}: ' 304 | f'Finished: crawled {total_templates} templates and ' 305 | f'{total_captions} captions') 306 | 307 | templates_file.close() 308 | captions_file.close() 309 | -------------------------------------------------------------------------------- /deephumor/models/caption_models.py: -------------------------------------------------------------------------------- 1 | """Image captioning models.""" 2 | import torch 3 | from torch import nn 4 | 5 | from . import ImageEncoder, TransformerDecoder, LSTMDecoder, ImageLabelEncoder 6 | from .transformers import SelfAttentionTransformerDecoder 7 | 8 | 9 | class CaptioningLSTM(nn.Module): 10 | """LSTM-based image captioning model. 
11 | 12 | Encodes input images into a embeddings of size `emb_dim` 13 | and passes them as the first token to the caption generation decoder. 14 | """ 15 | def __init__(self, num_tokens, emb_dim=256, hidden_size=512, num_layers=2, 16 | enc_dropout=0.3, dec_dropout=0.1): 17 | super(CaptioningLSTM, self).__init__() 18 | 19 | self.encoder = ImageEncoder( 20 | emb_dim=emb_dim, 21 | dropout=enc_dropout 22 | ) 23 | 24 | self.decoder = LSTMDecoder( 25 | num_tokens=num_tokens, 26 | emb_dim=emb_dim, 27 | hidden_size=hidden_size, 28 | num_layers=num_layers, 29 | dropout=dec_dropout, 30 | ) 31 | 32 | # hyperparameters dictionary 33 | self._hp = { 34 | 'num_tokens': num_tokens, 35 | 'emb_dim': emb_dim, 36 | 'hidden_size': hidden_size, 37 | 'num_layers': num_layers, 38 | 'enc_dropout': enc_dropout, 39 | 'dec_dropout': dec_dropout, 40 | } 41 | 42 | def forward(self, images, captions, lengths=None): 43 | emb = self.encoder(images) 44 | out = self.decoder(emb, captions, lengths) 45 | 46 | return out 47 | 48 | def generate(self, image, caption=None, max_len=25, 49 | temperature=1.0, beam_size=10, top_k=50, eos_index=3): 50 | """Generates caption for an image. 51 | 52 | Args: 53 | image (torch.Tensor): input image of shape `[1, width, height]` 54 | caption (torch.Tensor, optional): beginning tokens of the caption of shape `[1, seq_len]` 55 | max_len (int): maximum length of the caption 56 | temperature (float): temperature for softmax over logits 57 | beam_size (int): number of maintained branches at each step 58 | top_k (int): number of the most probable tokens to consider during sampling 59 | eos_index (int): index of the EOS (end-of-sequence) token 60 | 61 | Returns: 62 | torch.Tensor: generated caption tokens of shape `[1, min(output_len, max_len)]` 63 | """ 64 | 65 | # get image embedding 66 | image_emb = self.encoder(image).unsqueeze(1) 67 | 68 | sampled_ids = self.decoder.generate( 69 | image_emb, caption=caption, 70 | max_len=max_len, temperature=temperature, 71 | beam_size=beam_size, top_k=top_k, eos_index=eos_index 72 | ) 73 | 74 | return sampled_ids 75 | 76 | def save(self, ckpt_path): 77 | """Saves the model's state and hyperparameters.""" 78 | torch.save( 79 | {'model': self.state_dict(), 'hp': self._hp}, 80 | ckpt_path 81 | ) 82 | 83 | @staticmethod 84 | def from_pretrained(ckpt_path): 85 | """Loads and builds the model from the checkpoint file.""" 86 | ckpt = torch.load(ckpt_path, map_location='cpu') 87 | hp = ckpt['hp'] 88 | 89 | model = CaptioningLSTM( 90 | num_tokens=hp['num_tokens'], 91 | emb_dim=hp['emb_dim'], 92 | hidden_size=hp['hidden_size'], 93 | num_layers=hp['num_layers'], 94 | enc_dropout=hp['enc_dropout'], 95 | dec_dropout=hp['dec_dropout'], 96 | ) 97 | model.load_state_dict(ckpt['model']) 98 | return model 99 | 100 | 101 | class CaptioningLSTMWithLabels(nn.Module): 102 | """LSTM-based image captioning model with label inputs. 103 | 104 | Uses image and text label to condition the decoder. 105 | 106 | Encoder build combined embeddings of size `emb_dim` for input images and text labels 107 | and passes them as the first token to the caption generation decoder. 
108 | """ 109 | def __init__(self, num_tokens, emb_dim=256, hidden_size=512, num_layers=2, 110 | enc_dropout=0.3, dec_dropout=0.1): 111 | super(CaptioningLSTMWithLabels, self).__init__() 112 | 113 | self.encoder = ImageLabelEncoder( 114 | num_tokens=num_tokens, 115 | emb_dim=emb_dim, 116 | dropout=enc_dropout 117 | ) 118 | 119 | self.decoder = LSTMDecoder( 120 | num_tokens=num_tokens, 121 | emb_dim=emb_dim, 122 | hidden_size=hidden_size, 123 | num_layers=num_layers, 124 | dropout=dec_dropout, 125 | embedding=self.encoder.label_encoder.embedding 126 | ) 127 | 128 | # hyperparameters dictionary 129 | self._hp = { 130 | 'num_tokens': num_tokens, 131 | 'emb_dim': emb_dim, 132 | 'hidden_size': hidden_size, 133 | 'num_layers': num_layers, 134 | 'enc_dropout': enc_dropout, 135 | 'dec_dropout': dec_dropout, 136 | } 137 | 138 | def forward(self, images, captions, lengths, labels): 139 | emb = self.encoder(images=images, labels=labels) 140 | out = self.decoder(emb, captions, lengths) 141 | 142 | return out 143 | 144 | def generate(self, image, label, caption=None, max_len=25, 145 | temperature=1.0, beam_size=10, top_k=50, eos_index=3): 146 | """Generates caption for an image based on the text label. 147 | 148 | Args: 149 | image (torch.Tensor): input image of shape `[1, width, height]` 150 | label: (torch.Tensor): text label for the image `[1, label_len]` 151 | caption (torch.Tensor, optional): beginning tokens of the caption of shape `[1, seq_len]` 152 | max_len (int): maximum length of the caption 153 | temperature (float): temperature for softmax over logits 154 | beam_size (int): number of maintained branches at each step 155 | top_k (int): number of the most probable tokens to consider during sampling 156 | eos_index (int): index of the EOS (end-of-sequence) token 157 | 158 | Returns: 159 | torch.Tensor: generated caption tokens of shape `[1, min(output_len, max_len)]` 160 | """ 161 | 162 | # get image embedding 163 | image_emb = self.encoder(image, label).unsqueeze(1) 164 | 165 | sampled_ids = self.decoder.generate( 166 | image_emb, caption=caption, 167 | max_len=max_len, temperature=temperature, 168 | beam_size=beam_size, top_k=top_k, eos_index=eos_index 169 | ) 170 | 171 | return sampled_ids 172 | 173 | def save(self, ckpt_path): 174 | """Saves the model's state and hyperparameters.""" 175 | torch.save( 176 | {'model': self.state_dict(), 'hp': self._hp}, 177 | ckpt_path 178 | ) 179 | 180 | @staticmethod 181 | def from_pretrained(ckpt_path): 182 | """Loads and builds the model from the checkpoint file.""" 183 | ckpt = torch.load(ckpt_path, map_location='cpu') 184 | hp = ckpt['hp'] 185 | 186 | model = CaptioningLSTMWithLabels( 187 | num_tokens=hp['num_tokens'], 188 | emb_dim=hp['emb_dim'], 189 | hidden_size=hp['hidden_size'], 190 | num_layers=hp['num_layers'], 191 | enc_dropout=hp['enc_dropout'], 192 | dec_dropout=hp['dec_dropout'], 193 | ) 194 | model.load_state_dict(ckpt['model']) 195 | return model 196 | 197 | 198 | class CaptioningTransformerBase(nn.Module): 199 | """Simple Transformer-based image captioning model without Encoder-Attention Decoder blocks. 200 | 201 | - ResNet-based [1] ImageEncoder for getting global and spatial image embeddings. 202 | - Vanilla Transformer Decoder without Encoder-Attention [2]. 203 | 204 | Global image embedding is prepended to the token embedding of decoder input sequences. 
205 | 206 | References: 207 | [1]: "Deep Residual Learning for Image Recognition", https://arxiv.org/abs/1512.03385 208 | [2]: "Attention Is All You Need", https://arxiv.org/abs/1706.03762 209 | """ 210 | 211 | def __init__(self, num_tokens, hid_dim=512, n_layers=6, n_heads=8, pf_dim=2048, 212 | enc_dropout=0.3, dec_dropout=0.1, pad_index=0, max_len=128): 213 | """Initializes CaptioningTransformer. 214 | 215 | Args: 216 | num_tokens (int): number of tokens in caption sequences 217 | hid_dim (int): hidden dimension and embedding sizes 218 | n_layers (int): number of Decoder layers 219 | n_heads (int): number of attention heads 220 | pf_dim (int): dimensions of the position-wise layer 221 | enc_dropout (float): image embeddings dropout 222 | dec_dropout (float): attention and position-wise layer dropouts of the Decoder 223 | pad_index (int): index used for padding values in input sequences 224 | max_len (int): maximum lengths of input sequences. 225 | """ 226 | 227 | super().__init__() 228 | 229 | self.encoder = ImageEncoder( 230 | emb_dim=hid_dim, 231 | dropout=enc_dropout, 232 | spatial_features=False 233 | ) 234 | 235 | self.decoder = SelfAttentionTransformerDecoder( 236 | num_tokens=num_tokens, 237 | hid_dim=hid_dim, 238 | n_layers=n_layers, 239 | n_heads=n_heads, 240 | pf_dim=pf_dim, 241 | dropout=dec_dropout, 242 | pad_index=pad_index, 243 | max_len=max_len 244 | ) 245 | 246 | # hyperparameters dictionary 247 | self._hp = { 248 | 'num_tokens': num_tokens, 249 | 'hid_dim': hid_dim, 250 | 'n_layers': n_layers, 251 | 'n_heads': n_heads, 252 | 'pf_dim': pf_dim, 253 | 'enc_dropout': enc_dropout, 254 | 'dec_dropout': dec_dropout, 255 | 'pad_index': pad_index, 256 | 'max_len': max_len 257 | } 258 | 259 | def forward(self, images, captions, lengths=None): 260 | """ 261 | Args: 262 | images (torch.Tensor): input images of shape `[bs, width, height]` 263 | captions (torch.Tensor): text captions of shape `[bs, seq_len]` 264 | lengths (torch.Tensor): lengths of the input sequences of shape `[bs,]` 265 | 266 | Returns: 267 | torch.Tensor: decoded scores for caption sequence tokens of shape `[bs, seq_len, num_tokens]` 268 | """ 269 | image_emb = self.encoder(images) 270 | out = self.decoder(captions, start_emb=image_emb) 271 | 272 | return out 273 | 274 | def generate(self, image, caption=None, max_len=25, 275 | temperature=1.0, beam_size=10, top_k=50, eos_index=3): 276 | """Generates caption for an image. 
277 | 278 | Args: 279 | image (torch.Tensor): input image of shape `[1, width, height]` 280 | caption (torch.Tensor, optional): beginning tokens of the caption of shape `[1, seq_len]` 281 | max_len (int): maximum length of the caption 282 | temperature (float): temperature for softmax over logits 283 | beam_size (int): number of maintained branches at each step 284 | top_k (int): number of the most probable tokens to consider during sampling 285 | eos_index (int): index of the EOS (end-of-sequence) token 286 | 287 | Returns: 288 | torch.Tensor: generated caption tokens of shape `[1, min(output_len, max_len)]` 289 | """ 290 | 291 | # get image embeddings 292 | image_emb = self.encoder(image) 293 | 294 | sampled_ids = self.decoder.generate( 295 | image_emb, caption=caption, 296 | max_len=max_len, temperature=temperature, 297 | beam_size=beam_size, top_k=top_k, eos_index=eos_index 298 | ) 299 | 300 | return sampled_ids 301 | 302 | def save(self, ckpt_path): 303 | """Saves the model's state and hyperparameters.""" 304 | torch.save( 305 | {'model': self.state_dict(), 'hp': self._hp}, 306 | ckpt_path 307 | ) 308 | 309 | @staticmethod 310 | def from_pretrained(ckpt_path): 311 | """Loads and builds the model from the checkpoint file.""" 312 | ckpt = torch.load(ckpt_path, map_location='cpu') 313 | hp = ckpt['hp'] 314 | 315 | model = CaptioningTransformerBase( 316 | num_tokens=hp['num_tokens'], 317 | hid_dim=hp['hid_dim'], 318 | n_layers=hp['n_layers'], 319 | n_heads=hp['n_heads'], 320 | pf_dim=hp['pf_dim'], 321 | enc_dropout=hp['enc_dropout'], 322 | dec_dropout=hp['dec_dropout'], 323 | pad_index=hp['pad_index'], 324 | max_len=hp['max_len'] 325 | ) 326 | model.load_state_dict(ckpt['model']) 327 | return model 328 | 329 | 330 | class CaptioningTransformer(nn.Module): 331 | """Transformer-based image captioning model. 332 | 333 | - ResNet-based [1] ImageEncoder for getting global and spatial image embeddings. 334 | - Vanilla Transformer Decoder [2]. 335 | 336 | Global image embedding is prepended to the token embedding of decoder input sequences. 337 | Spatial image embeddings are used as encoder outputs in the encoder-attention block 338 | of the Decoder layers. 339 | 340 | References: 341 | [1]: "Deep Residual Learning for Image Recognition", https://arxiv.org/abs/1512.03385 342 | [2]: "Attention Is All You Need", https://arxiv.org/abs/1706.03762 343 | """ 344 | 345 | def __init__(self, num_tokens, hid_dim=512, n_layers=6, n_heads=8, pf_dim=2048, 346 | enc_dropout=0.3, dec_dropout=0.1, pad_index=0, max_len=128): 347 | """Initializes CaptioningTransformer. 348 | 349 | Args: 350 | num_tokens (int): number of tokens in caption sequences 351 | hid_dim (int): hidden dimension and embedding sizes 352 | n_layers (int): number of Decoder layers 353 | n_heads (int): number of attention heads 354 | pf_dim (int): dimensions of the position-wise layer 355 | enc_dropout (float): image embeddings dropout 356 | dec_dropout (float): attention and position-wise layer dropouts of the Decoder 357 | pad_index (int): index used for padding values in input sequences 358 | max_len (int): maximum lengths of input sequences. 
359 | """ 360 | 361 | super().__init__() 362 | 363 | self.encoder = ImageEncoder( 364 | emb_dim=hid_dim, 365 | dropout=enc_dropout, 366 | spatial_features=True 367 | ) 368 | 369 | self.decoder = TransformerDecoder( 370 | num_tokens=num_tokens, 371 | hid_dim=hid_dim, 372 | n_layers=n_layers, 373 | n_heads=n_heads, 374 | pf_dim=pf_dim, 375 | dropout=dec_dropout, 376 | pad_index=pad_index, 377 | max_len=max_len 378 | ) 379 | 380 | # hyperparameters dictionary 381 | self._hp = { 382 | 'num_tokens': num_tokens, 383 | 'hid_dim': hid_dim, 384 | 'n_layers': n_layers, 385 | 'n_heads': n_heads, 386 | 'pf_dim': pf_dim, 387 | 'enc_dropout': enc_dropout, 388 | 'dec_dropout': dec_dropout, 389 | 'pad_index': pad_index, 390 | 'max_len': max_len 391 | } 392 | 393 | def forward(self, images, captions, lengths=None): 394 | """ 395 | Args: 396 | images (torch.Tensor): input images of shape `[bs, width, height]` 397 | captions (torch.Tensor): text captions of shape `[bs, seq_len]` 398 | lengths (torch.Tensor): lengths of the input sequences of shape `[bs,]` 399 | 400 | Returns: 401 | torch.Tensor: decoded scores for caption sequence tokens of shape `[bs, seq_len, num_tokens]` 402 | """ 403 | image_emb, image_spatial_emb = self.encoder(images) 404 | out = self.decoder(captions, enc_out=image_spatial_emb, start_emb=image_emb) 405 | 406 | return out 407 | 408 | def generate(self, image, caption=None, max_len=25, 409 | temperature=1.0, beam_size=10, top_k=50, eos_index=3): 410 | """Generates caption for an image. 411 | 412 | Args: 413 | image (torch.Tensor): input image of shape `[1, width, height]` 414 | caption (torch.Tensor, optional): beginning tokens of the caption of shape `[1, seq_len]` 415 | max_len (int): maximum length of the caption 416 | temperature (float): temperature for softmax over logits 417 | beam_size (int): number of maintained branches at each step 418 | top_k (int): number of the most probable tokens to consider during sampling 419 | eos_index (int): index of the EOS (end-of-sequence) token 420 | 421 | Returns: 422 | torch.Tensor: generated caption tokens of shape `[1, min(output_len, max_len)]` 423 | """ 424 | 425 | # get image embeddings 426 | image_emb, image_spatial_emb = self.encoder(image) 427 | 428 | sampled_ids = self.decoder.generate( 429 | image_emb, image_spatial_emb, caption=caption, 430 | max_len=max_len, temperature=temperature, 431 | beam_size=beam_size, top_k=top_k, eos_index=eos_index 432 | ) 433 | 434 | return sampled_ids 435 | 436 | def save(self, ckpt_path): 437 | """Saves the model's state and hyperparameters.""" 438 | torch.save( 439 | {'model': self.state_dict(), 'hp': self._hp}, 440 | ckpt_path 441 | ) 442 | 443 | @staticmethod 444 | def from_pretrained(ckpt_path): 445 | """Loads and builds the model from the checkpoint file.""" 446 | ckpt = torch.load(ckpt_path, map_location='cpu') 447 | hp = ckpt['hp'] 448 | 449 | model = CaptioningTransformer( 450 | num_tokens=hp['num_tokens'], 451 | hid_dim=hp['hid_dim'], 452 | n_layers=hp['n_layers'], 453 | n_heads=hp['n_heads'], 454 | pf_dim=hp['pf_dim'], 455 | enc_dropout=hp['enc_dropout'], 456 | dec_dropout=hp['dec_dropout'], 457 | pad_index=hp['pad_index'], 458 | max_len=hp['max_len'] 459 | ) 460 | model.load_state_dict(ckpt['model']) 461 | return model 462 | -------------------------------------------------------------------------------- /deephumor/models/transformers.py: -------------------------------------------------------------------------------- 1 | """Transformer modules. 
2 | 3 | References: 4 | [1]: "Attention Is All You Need", https://arxiv.org/abs/1706.03762 5 | """ 6 | import torch 7 | from torch import nn 8 | 9 | from deephumor.models.beam import BeamSearchHelper 10 | 11 | 12 | def get_pad_mask(query, key, pad_index=0): 13 | """Computes padding mask from the Query and Key sequences. 14 | 15 | Args: 16 | query (torch.Tensor): query sequences of shape `[bs, query_len]` 17 | key (torch.Tensor): key sequences of shape `[bs, key_len]` 18 | pad_index (int): index used for padding the values 19 | 20 | Returns: 21 | torch.Tensor: boolean padding mask of shape `[bs, query_len, key_len]` 22 | """ 23 | bs, seq_len_q = query.shape[:2] 24 | bs, seq_len_k = key.shape[:2] 25 | pad_mask = (key == pad_index).unsqueeze(1) 26 | return pad_mask.expand(bs, seq_len_q, seq_len_k).to(query.device) 27 | 28 | 29 | def get_autoregressive_mask(seq): 30 | """Returns autoregressive mask for the decoder inputs. 31 | 32 | Args: 33 | seq (torch.Tensor): input sequences of shape `[bs, seq_len]` 34 | 35 | Returns: 36 | torch.bool: boolean mask of shape `[bs, seq_len, seq_len]` 37 | """ 38 | bs, seq_len = seq.shape[:2] 39 | autoregressive_mask = torch.triu(torch.ones([bs, seq_len, seq_len]), 1) 40 | return autoregressive_mask.bool().to(seq.device) 41 | 42 | 43 | class MultiHeadAttentionLayer(nn.Module): 44 | """MultiHeadAttentionLayer from "Attention Is All You Need".""" 45 | 46 | def __init__(self, hid_dim=512, n_heads=8, dropout=0.): 47 | """Initializes MultiHeadAttentionLayer. 48 | 49 | Dimension of one head is `hid_dim` // `n_heads` 50 | 51 | Args: 52 | hid_dim (int): hidden dimension size 53 | n_heads (int): number of attention heads 54 | dropout (float): attention dropout 55 | """ 56 | 57 | super().__init__() 58 | 59 | assert hid_dim % n_heads == 0, "hid_dim must be divisible by n_heads" 60 | 61 | self.hid_dim = hid_dim 62 | self.n_heads = n_heads 63 | self.head_dim = hid_dim // n_heads 64 | 65 | # query, key and value linear networks 66 | self.fc_q = nn.Linear(hid_dim, hid_dim) 67 | self.fc_k = nn.Linear(hid_dim, hid_dim) 68 | self.fc_v = nn.Linear(hid_dim, hid_dim) 69 | 70 | # output linear networks 71 | self.fc_o = nn.Linear(hid_dim, hid_dim) 72 | 73 | # attention dropout 74 | self.dropout = nn.Dropout(dropout) 75 | 76 | # scale parameter 77 | self.scale = torch.nn.Parameter( 78 | torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)), 79 | requires_grad=False 80 | ) 81 | 82 | def forward(self, query, key, value, mask=None): 83 | """ 84 | Args: 85 | query (torch.Tensor): queries of shape `[bs, seq_len, hid_dim]` 86 | key (torch.Tensor): keys of shape `[bs, seq_len, hid_dim]` 87 | value (torch.Tensor): values of shape `[bs, seq_len, hid_dim]` 88 | mask (torch.Tensor): boolean mask for padded elements of shape `[bs, seq_len, seq_len]` 89 | 90 | Returns: 91 | torch.Tensor: multi-head attention tensor of shape `[bs, seq_len, hid_dim]` 92 | """ 93 | 94 | bs, seq_len = query.shape[:2] 95 | 96 | # calculate Q, K, V using corresponding linear networks 97 | q, k, v = self.fc_q(query), self.fc_k(key), self.fc_v(value) # shape is [bs, seq_len, hid_dim] 98 | 99 | # prepare Q, K, V for .matmul() or `@` operator 100 | # shape is [bs, n_heads, seq_len, head_dim] 101 | q = q.view(bs, seq_len, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 102 | k = k.view(bs, seq_len, self.n_heads, self.head_dim).permute(0, 2, 3, 1) 103 | v = v.view(bs, seq_len, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 104 | 105 | # compute energy 106 | energy = (q @ k) / self.scale # shape is [bs, n_heads, 
seq_q_len, seq_k_len] 107 | 108 | if mask is not None: 109 | # apply mask 110 | mask = mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) 111 | energy = energy.masked_fill(mask, -1e8) 112 | 113 | # apply softmax along the last dim of energy and get the attention weights 114 | # shape is [bs, n_heads, seq_len, seq_len] 115 | attention = torch.softmax(energy, dim=-1) 116 | attention = self.dropout(attention) 117 | 118 | # weight values with calculated attention 119 | # shape is [bs, n_heads, seq_len, head_dim] 120 | x = attention @ v 121 | 122 | # squash 1 and 4 dims back 123 | x = x.permute(0, 2, 1, 3).contiguous() 124 | x = x.view(bs, -1, self.hid_dim) # shape is [bs, seq_len, hid_dim] 125 | 126 | # apply output linear layer 127 | x = self.fc_o(x) 128 | 129 | return x 130 | 131 | 132 | class PositionwiseFeedforwardLayer(nn.Module): 133 | """Position-wise Feedforward Layer from "Attention Is All You Need".""" 134 | 135 | def __init__(self, hid_dim=512, pf_dim=2048, dropout=0.): 136 | """Initializes PositionwiseFeedforwardLayer. 137 | 138 | Args: 139 | hid_dim (int): hidden dimension size 140 | pf_dim (int): dimensions of the position-wise layer 141 | dropout (float): position-wise layer dropout 142 | """ 143 | 144 | super().__init__() 145 | 146 | # linear layers 147 | self.fc_1 = nn.Linear(hid_dim, pf_dim) 148 | self.fc_2 = nn.Linear(pf_dim, hid_dim) 149 | 150 | # dropout is applied after the first layer 151 | self.dropout = nn.Dropout(dropout) 152 | 153 | def forward(self, x): 154 | """ 155 | Args: 156 | x (torch.Tensor): sequences of shape `[bs, seq_len, hid_dim]` 157 | 158 | Returns: 159 | torch.Tensor: processed sequences of shape `[bs, seq_len, hid_dim]` 160 | """ 161 | # apply linear layers + dropout 162 | x = self.dropout(torch.relu(self.fc_1(x))) 163 | x = self.fc_2(x) 164 | 165 | return x 166 | 167 | 168 | class EncoderLayer(nn.Module): 169 | """Encoder Layer of the Vanilla Transformer.""" 170 | 171 | def __init__(self, hid_dim=512, n_heads=8, pf_dim=2048, dropout=0.): 172 | """Initializes EncoderLayer. 
173 | 174 | Args: 175 | hid_dim (int): hidden dimension size 176 | n_heads (int): number of attention heads 177 | pf_dim (int): dimensions of the position-wise layer 178 | dropout (float): attention and position-wise layer dropouts 179 | """ 180 | 181 | super().__init__() 182 | 183 | # self-attention + layer normalization 184 | self.self_attn = MultiHeadAttentionLayer(hid_dim, n_heads, dropout) 185 | self.self_attn_ln = nn.LayerNorm(hid_dim) 186 | 187 | # positionwise feedforward layer + layer normalization 188 | self.pf = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) 189 | self.pf_ln = nn.LayerNorm(hid_dim) 190 | 191 | # dropout to the outputs of the attention and position-wise feedforward layers 192 | self.dropout = nn.Dropout(dropout) 193 | 194 | def forward(self, x, input_mask=None): 195 | """ 196 | Args: 197 | x (torch.Tensor): input sequences of shape `[bs, seq_len, hid_dim]` 198 | input_mask (torch.Tensor): boolean mask for padded elements of shape `[bs, seq_len, seq_len]` 199 | 200 | Returns: 201 | torch.Tensor: processed sequences of shape `[bs, seq_len, hid_dim]` 202 | """ 203 | ### block 1 204 | # calculate self-attention + dropout 205 | attn_out = self.self_attn(x, x, x, mask=input_mask) 206 | attn_out = self.dropout(attn_out) 207 | 208 | # residual (attention) + attention layer norm 209 | x = self.self_attn_ln(x + attn_out) 210 | 211 | ### block 2 212 | # calculate position-wise feedforward + dropout 213 | ff_out = self.dropout(self.pf(x)) 214 | 215 | # residual (position-wise feedforward) + position-wise feedforward layer norm 216 | x = self.pf_ln(x + ff_out) 217 | 218 | return x 219 | 220 | 221 | class TransformerEncoder(nn.Module): 222 | """Multi-layer Transformer Encoder. 223 | 224 | Follows the architecture of Vanilla Transformer Encoder 225 | from "Attention Is All You Need". 226 | 227 | Modifications: 228 | - Learned positional embeddings instead of the sinusoidal positional encoding. 229 | """ 230 | 231 | def __init__(self, num_tokens, hid_dim=512, n_layers=6, n_heads=8, 232 | pf_dim=2048, dropout=0., pad_index=None, max_len=128): 233 | """Initializes TransformerEncoder. 234 | 235 | Args: 236 | num_tokens (int): number of tokens in input sequences 237 | hid_dim (int): hidden dimension size 238 | n_layers (int): number of Encoder layers 239 | n_heads (int): number of attention heads 240 | pf_dim (int): dimensions of the position-wise layer 241 | dropout (float): attention and position-wise layer dropouts 242 | pad_index (int): index used for padding values 243 | max_len (int): maximum lengths of input sequences. 
244 | """ 245 | 246 | super().__init__() 247 | 248 | self.pad_index = pad_index # if None, don't use masking 249 | 250 | # embeddings 251 | self.tok_embedding = nn.Embedding(num_tokens, hid_dim) 252 | self.pos_embedding = nn.Embedding(max_len, hid_dim) 253 | self.dropout = nn.Dropout(dropout) 254 | 255 | # encoder layers (implemented below) 256 | self.layers = nn.ModuleList([ 257 | EncoderLayer(hid_dim, n_heads, pf_dim, dropout) 258 | for _ in range(n_layers) 259 | ]) 260 | 261 | # scale parameter 262 | self.scale = torch.nn.Parameter( 263 | torch.sqrt(torch.tensor(hid_dim, dtype=torch.float32)), 264 | requires_grad=False 265 | ) 266 | 267 | # custom weight initialization 268 | self.init_weights() 269 | 270 | def init_weights(self): 271 | for m in self.modules(): 272 | if hasattr(m, 'weight') and m.weight.dim() > 1: 273 | nn.init.xavier_uniform_(m.weight.data) 274 | 275 | def forward(self, x): 276 | """ 277 | Args: 278 | x (torch.Tensor): token sequences of shape `[bs, seq_len]` 279 | 280 | Returns: 281 | torch.Tensor: encoded sequences of shape `[bs, seq_len, hid_dim]` 282 | """ 283 | bs, seq_len = x.shape[:2] 284 | 285 | # get token embeddings and scale with self.scale parameter 286 | tok_emb = self.tok_embedding(x) / self.scale 287 | 288 | # get pos embeddings 289 | indices = torch.arange(seq_len).repeat(bs, 1).to(x.device) 290 | pos_emb = self.pos_embedding(indices) 291 | 292 | # sum up token and positional embeddings and apply dropout 293 | emb = tok_emb + pos_emb 294 | emb = self.dropout(emb) 295 | 296 | # compute padding mask 297 | mask = None 298 | if self.padding_index is not None: 299 | mask = get_pad_mask(x, x, pad_index=self.pad_index) 300 | 301 | # apply encoder layers one by one; input shape is [bs, seq_len, hid dim] 302 | x = emb 303 | for layer in self.layers: 304 | x = layer(x, input_mask=mask) 305 | 306 | return x 307 | 308 | 309 | class DecoderLayer(nn.Module): 310 | """Decoder Layer of the Vanilla Transformer.""" 311 | 312 | def __init__(self, 313 | hid_dim=512, 314 | n_heads=8, 315 | pf_dim=2048, 316 | dropout=0.): 317 | """Initializes DecoderLayer. 
318 | 319 | Args: 320 | hid_dim (int): hidden dimension size 321 | n_heads (int): number of attention heads 322 | pf_dim (int): dimensions of the position-wise layer 323 | dropout (float): attention and position-wise layer dropouts 324 | """ 325 | 326 | super().__init__() 327 | 328 | # masked self-attention + layer normalization 329 | self.self_attn = MultiHeadAttentionLayer(hid_dim, n_heads, dropout) 330 | self.self_attn_ln = nn.LayerNorm(hid_dim) 331 | 332 | # encoder-attention + layer normalization 333 | self.enc_attn = MultiHeadAttentionLayer(hid_dim, n_heads, dropout) 334 | self.enc_attn_ln = nn.LayerNorm(hid_dim) 335 | 336 | # position-wise feedforward layer + layer normalization 337 | self.pf = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) 338 | self.pf_ln = nn.LayerNorm(hid_dim) 339 | 340 | # attention and position-wise feedforward layer dropouts 341 | self.dropout = nn.Dropout(dropout) 342 | 343 | def forward(self, x, enc_out, input_mask=None, enc_mask=None): 344 | """ 345 | Args: 346 | x (torch.Tensor): input sequences of shape `[bs, seq_len, hid_dim]` 347 | enc_out (torch.Tensor): encoder outputs of shape `[bs, seq_len, hid_dim]` 348 | input_mask (torch.Tensor): masked self-attention + padding mask of shape `[bs, seq_len, seq_len]` 349 | enc_mask (torch.Tensor): encoder outputs padding mask of shape `[bs, seq_len, seq_len]` 350 | 351 | Returns: 352 | torch.Tensor: processed sequences of shape `[bs, seq_len, hid_dim]` 353 | """ 354 | ### block 1 355 | # self-attention + dropout 356 | attn_out = self.self_attn(x, x, x, mask=input_mask) 357 | attn_out = self.dropout(attn_out) 358 | 359 | # residual (attention) + attention layer norm 360 | x = self.self_attn_ln(x + attn_out) 361 | 362 | ### block 2 363 | # encoder-attention + dropout 364 | attn_out = self.enc_attn(x, enc_out, enc_out, mask=enc_mask) 365 | attn_out = self.dropout(attn_out) 366 | 367 | # residual (attention) + attention layer norm 368 | x = self.enc_attn_ln(x + attn_out) 369 | 370 | ### block 2 371 | # positionwise feedforward + dropout 372 | ff_out = self.dropout(self.pf(x)) 373 | 374 | # residual (positionwise feedforward) + positionwise feedforward layer norm 375 | x = self.pf_ln(x + ff_out) 376 | 377 | return x 378 | 379 | 380 | class TransformerDecoder(nn.Module): 381 | """Multi-layer Transformer Decoder. 382 | 383 | Follows the architecture of Vanilla Transformer Decoder from "Attention Is All You Need". 384 | 385 | Outputs scores for tokens in the target sequence. 386 | 387 | Modifications: 388 | - Learned positional embeddings instead of the sinusoidal positional encoding. 389 | - Allows passing as input image embedding vector which is prepended to 390 | the token embeddings. 391 | """ 392 | 393 | def __init__(self, num_tokens, hid_dim=512, n_layers=6, n_heads=8, 394 | pf_dim=2048, dropout=0., pad_index=None, max_len=128): 395 | """Initializes TransformerDecoder. 396 | 397 | Args: 398 | num_tokens (int): number of tokens in input sequences 399 | hid_dim (int): hidden dimension size 400 | n_layers (int): number of Decoder layers 401 | n_heads (int): number of attention heads 402 | pf_dim (int): dimensions of the position-wise layer 403 | dropout (float): attention and position-wise layer dropouts 404 | pad_index (int): index used for padding values in input sequences 405 | max_len (int): maximum lengths of input sequences. 
406 | """ 407 | 408 | super().__init__() 409 | 410 | self.pad_index = pad_index # if None, don't use masking 411 | 412 | # embeddings 413 | self.tok_embedding = nn.Embedding(num_tokens, hid_dim) 414 | self.pos_embedding = nn.Embedding(max_len, hid_dim) 415 | self.dropout = nn.Dropout(dropout) 416 | 417 | # decoder layers (implemented below) 418 | self.layers = nn.ModuleList([ 419 | DecoderLayer(hid_dim, n_heads, pf_dim, dropout) 420 | for _ in range(n_layers) 421 | ]) 422 | 423 | # scale parameter 424 | self.scale = torch.nn.Parameter( 425 | torch.sqrt(torch.tensor(hid_dim, dtype=torch.float32)), 426 | requires_grad=False 427 | ) 428 | 429 | # output layer 430 | self.classifier = nn.Linear(hid_dim, num_tokens) 431 | 432 | def forward(self, x, enc_out, start_emb=None): 433 | """ 434 | Args: 435 | x (torch.Tensor): token sequences of shape `[bs, seq_len]` 436 | enc_out (torch.Tensor): encoder outputs of shape `[bs, seq_len, hid_dim]` 437 | start_emb (torch.Tensor, optional): starting position embedding of shape `[bs, hid_dim]` 438 | 439 | Returns: 440 | torch.Tensor: decoded sequences of shape `[bs, seq_len, num_tokens]` 441 | """ 442 | device = x.device 443 | bs, dec_seq_len = x.shape[:2] 444 | enc_seq_len, hid_dim = enc_out.shape[1:3] 445 | 446 | if start_emb is not None: 447 | dec_seq_len += 1 448 | 449 | # pad input and encoder outputs to the same seq_len 450 | seq_len = max(dec_seq_len, enc_seq_len) 451 | x = torch.cat([x, self.pad_index * torch.ones(bs, seq_len - dec_seq_len).long().to(device)], dim=1) 452 | enc_out = torch.cat([enc_out, torch.zeros(bs, seq_len - enc_seq_len, hid_dim).to(device)], dim=1) 453 | 454 | # get token embeddings 455 | tok_emb = self.tok_embedding(x) 456 | 457 | # add image embedding: 458 | if start_emb is not None: 459 | tok_emb = torch.cat((start_emb.unsqueeze(1), tok_emb), 1) 460 | 461 | # scale token embeddings with self.scale parameter 462 | tok_emb = tok_emb / self.scale 463 | 464 | # get pos embeddings 465 | indices = torch.arange(seq_len).repeat(bs, 1).to(device) 466 | pos_emb = self.pos_embedding(indices) 467 | 468 | # sum up token and positional embeddings and apply dropout 469 | emb = tok_emb + pos_emb 470 | emb = self.dropout(emb) 471 | 472 | # compute decoder input mask 473 | if start_emb is not None: 474 | x = torch.cat([torch.ones(bs, 1).long().to(device), x], dim=1) 475 | pad_mask = get_pad_mask(x, x, pad_index=self.pad_index) 476 | autoregr_mask = get_autoregressive_mask(x) 477 | input_mask = pad_mask | autoregr_mask 478 | 479 | # compute encoder output mask 480 | enc_inp_mask = (enc_out != 0.).all(dim=-1).long() 481 | enc_mask = get_pad_mask(x, enc_inp_mask, pad_index=self.pad_index) 482 | 483 | # apply encoder layers one by one; input shape is [bs, seq_len, hid dim] 484 | x = emb 485 | for layer in self.layers: 486 | x = layer(x, enc_out, input_mask=input_mask, enc_mask=enc_mask) 487 | 488 | out = self.classifier(x) 489 | 490 | return out 491 | 492 | def generate(self, start_emb, enc_out, caption=None, max_len=25, 493 | temperature=1.0, beam_size=10, top_k=50, eos_index=3): 494 | """Generates text tokens based on the image embedding. 
495 | 496 | Args: 497 | start_emb (torch.Tensor): starting position embedding of shape `[1, hid_dim]` 498 | enc_out (torch.Tensor): encoder outputs of shape `[bs, seq_len, hid_dim]` 499 | caption (torch.Tensor, optional): beginning tokens of the caption of shape `[1, seq_len]` 500 | max_len (int): maximum length of the caption 501 | temperature (float): temperature for softmax over logits 502 | beam_size (int): number of maintained branches at each step 503 | top_k (int): number of the most probable tokens to consider during sampling 504 | eos_index (int): index of the EOS (end-of-sequence) token 505 | 506 | Returns: 507 | torch.Tensor: generated caption tokens of shape `[1, min(output_len, max_len)]` 508 | """ 509 | 510 | # beam search sampling helper 511 | helper = BeamSearchHelper( 512 | temperature=temperature, beam_size=beam_size, 513 | top_k=top_k, eos_index=eos_index, 514 | device=start_emb.device 515 | ) 516 | 517 | sample_seq = self.pad_index * torch.ones((1, max_len)) 518 | sample_seq = sample_seq.long().to(start_emb.device) 519 | 520 | # process caption tokens if present 521 | if caption is None: 522 | pos = 0 523 | else: 524 | pos = caption.size(1) 525 | sample_seq[:, :pos] = caption 526 | 527 | # run TransformerDecoder over the inputs and predict the next token 528 | outputs = self(sample_seq, enc_out, start_emb) 529 | logits = outputs[:, pos, :] 530 | 531 | # filter `top_k` values 532 | logits = helper.filter_top_k(logits) 533 | 534 | # compute probabilities and sample k values 535 | sample_ind = helper.sample_k_indices(logits, k=beam_size) 536 | sample_val = helper.filter_by_indices(logits, sample_ind).log_softmax(-1) 537 | sample_ind, sample_val = sample_ind.T, sample_val.T 538 | 539 | # update total prediction sequences 540 | sample_seq = sample_seq.repeat(beam_size, 1) 541 | sample_seq[:, pos:pos + 1] = sample_ind 542 | 543 | # repeat `image_emb` and `enc_out` 544 | enc_out = enc_out.repeat(beam_size, 1, 1) 545 | start_emb = start_emb.repeat(beam_size, 1) 546 | 547 | for i in range(pos + 1, max_len + 1): 548 | # predict the next time step 549 | outputs = self(sample_seq, enc_out, start_emb) 550 | logits = outputs[:, i, :] 551 | 552 | (prev_seqs, prev_vals), (new_ind, new_val) = helper.process_logits( 553 | logits, sample_seq, sample_val 554 | ) 555 | 556 | # create candidate sequences and compute their probabilities 557 | prev_seqs[:, i:i + 1] = new_ind.unsqueeze(0).T 558 | cand_seq = prev_seqs 559 | cand_val = prev_vals.flatten() + new_val 560 | 561 | # sample `beam` sequences 562 | filter_ind = helper.sample_k_indices(cand_val, k=beam_size) 563 | 564 | # update total sequences and their scores 565 | sample_val = cand_val[filter_ind] 566 | sample_seq = cand_seq[filter_ind] 567 | 568 | # filter `has_ended` flags 569 | helper.has_ended = helper.has_ended[filter_ind] 570 | 571 | # check if every branch has ended 572 | if helper.all_ended(): 573 | break 574 | 575 | # sample output sequence 576 | ind = helper.sample_k_indices(sample_val, k=1) 577 | output_seq = sample_seq[ind, :i].squeeze() 578 | 579 | return output_seq 580 | 581 | 582 | class SelfAttentionDecoderLayer(nn.Module): 583 | """Self-Attention Decoder Layer without Encoder-Attention.""" 584 | 585 | def __init__(self, 586 | hid_dim=512, 587 | n_heads=8, 588 | pf_dim=2048, 589 | dropout=0.): 590 | """Initializes SelfAttentionDecoderLayer. 
591 | 592 | Args: 593 | hid_dim (int): hidden dimension size 594 | n_heads (int): number of attention heads 595 | pf_dim (int): dimensions of the position-wise layer 596 | dropout (float): attention and position-wise layer dropouts 597 | """ 598 | 599 | super().__init__() 600 | 601 | # masked self-attention + layer normalization 602 | self.self_attn = MultiHeadAttentionLayer(hid_dim, n_heads, dropout) 603 | self.self_attn_ln = nn.LayerNorm(hid_dim) 604 | 605 | # position-wise feedforward layer + layer normalization 606 | self.pf = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout) 607 | self.pf_ln = nn.LayerNorm(hid_dim) 608 | 609 | # attention and position-wise feedforward layer dropouts 610 | self.dropout = nn.Dropout(dropout) 611 | 612 | def forward(self, x, input_mask=None): 613 | """ 614 | Args: 615 | x (torch.Tensor): input sequences of shape `[bs, seq_len, hid_dim]` 616 | input_mask (torch.Tensor): masked self-attention + padding mask of shape `[bs, seq_len, seq_len]` 617 | 618 | Returns: 619 | torch.Tensor: processed sequences of shape `[bs, seq_len, hid_dim]` 620 | """ 621 | ### block 1 622 | # self-attention + dropout 623 | attn_out = self.self_attn(x, x, x, mask=input_mask) 624 | attn_out = self.dropout(attn_out) 625 | 626 | # residual (attention) + attention layer norm 627 | x = self.self_attn_ln(x + attn_out) 628 | 629 | ### block 2 630 | # positionwise feedforward + dropout 631 | ff_out = self.dropout(self.pf(x)) 632 | 633 | # residual (positionwise feedforward) + positionwise feedforward layer norm 634 | x = self.pf_ln(x + ff_out) 635 | 636 | return x 637 | 638 | 639 | class SelfAttentionTransformerDecoder(nn.Module): 640 | """Multi-layer Transformer Decoder without Encoder-Attention blocks. 641 | 642 | Modifies the architecture of Vanilla Transformer Decoder from "Attention Is All You Need" 643 | by taking as an input only a single encoder embedding vector without a sequence of encoded features. 644 | 645 | Requires an embedding for the starting token position. 646 | 647 | Outputs scores for tokens in the target sequence. 648 | 649 | Modifications: 650 | - No encoder outputs as inputs as in a classical Transformer Decoder. 651 | - Learned positional embeddings instead of the sinusoidal positional encoding. 652 | - Prepends image embedding vector to the token embeddings. 653 | """ 654 | 655 | def __init__(self, num_tokens, hid_dim=512, n_layers=6, n_heads=8, 656 | pf_dim=2048, dropout=0., pad_index=None, max_len=128): 657 | """Initializes TransformerImageDecoder. 658 | 659 | Args: 660 | num_tokens (int): number of tokens in input sequences 661 | hid_dim (int): hidden dimension size 662 | n_layers (int): number of Decoder layers 663 | n_heads (int): number of attention heads 664 | pf_dim (int): dimensions of the position-wise layer 665 | dropout (float): attention and position-wise layer dropouts 666 | pad_index (int): index used for padding values in input sequences 667 | max_len (int): maximum lengths of input sequences. 
668 | """ 669 | 670 | super().__init__() 671 | 672 | self.pad_index = pad_index # if None, don't use masking 673 | 674 | # embeddings 675 | self.tok_embedding = nn.Embedding(num_tokens, hid_dim) 676 | self.pos_embedding = nn.Embedding(max_len, hid_dim) 677 | self.dropout = nn.Dropout(dropout) 678 | 679 | # decoder layers (implemented below) 680 | self.layers = nn.ModuleList([ 681 | SelfAttentionDecoderLayer(hid_dim, n_heads, pf_dim, dropout) 682 | for _ in range(n_layers) 683 | ]) 684 | 685 | # scale parameter 686 | self.scale = torch.nn.Parameter( 687 | torch.sqrt(torch.tensor(hid_dim, dtype=torch.float32)), 688 | requires_grad=False 689 | ) 690 | 691 | # output layer 692 | self.classifier = nn.Linear(hid_dim, num_tokens) 693 | 694 | def forward(self, x, start_emb): 695 | """ 696 | Args: 697 | x (torch.Tensor): token sequences of shape `[bs, seq_len]` 698 | start_emb (torch.Tensor, optional): starting position embedding of shape `[bs, hid_dim]` 699 | 700 | Returns: 701 | torch.Tensor: decoded sequences of shape `[bs, seq_len, num_tokens]` 702 | """ 703 | device = x.device 704 | 705 | # get token embeddings 706 | tok_emb = self.tok_embedding(x) 707 | 708 | # add start position embedding: 709 | if start_emb is not None: 710 | tok_emb = torch.cat((start_emb.unsqueeze(1), tok_emb), 1) 711 | 712 | # scale token embeddings with self.scale parameter 713 | tok_emb = tok_emb / self.scale 714 | bs, seq_len = tok_emb.shape[:2] 715 | 716 | # get pos embeddings 717 | indices = torch.arange(seq_len).repeat(bs, 1).to(device) 718 | pos_emb = self.pos_embedding(indices) 719 | 720 | # sum up token and positional embeddings and apply dropout 721 | emb = tok_emb + pos_emb 722 | emb = self.dropout(emb) 723 | 724 | # compute decoder input mask 725 | if start_emb is not None: 726 | x = torch.cat([torch.ones(bs, 1).long().to(device), x], dim=1) 727 | pad_mask = get_pad_mask(x, x, pad_index=self.pad_index) 728 | autoregr_mask = get_autoregressive_mask(x) 729 | input_mask = pad_mask | autoregr_mask 730 | 731 | # apply encoder layers one by one; input shape is [bs, seq_len, hid dim] 732 | x = emb 733 | for layer in self.layers: 734 | x = layer(x, input_mask=input_mask) 735 | 736 | out = self.classifier(x) 737 | 738 | return out 739 | 740 | def generate(self, start_emb, caption=None, max_len=25, 741 | temperature=1.0, beam_size=10, top_k=50, eos_index=3): 742 | """Generates text tokens based on the image embedding. 
743 | 744 | Args: 745 | start_emb (torch.Tensor): starting position embedding of shape `[1, hid_dim]` 746 | caption (torch.Tensor, optional): beginning tokens of the caption of shape `[1, seq_len]` 747 | max_len (int): maximum length of the caption 748 | temperature (float): temperature for softmax over logits 749 | beam_size (int): number of maintained branches at each step 750 | top_k (int): number of the most probable tokens to consider during sampling 751 | eos_index (int): index of the EOS (end-of-sequence) token 752 | 753 | Returns: 754 | torch.Tensor: generated caption tokens of shape `[1, min(output_len, max_len)]` 755 | """ 756 | 757 | # beam search sampling helper 758 | helper = BeamSearchHelper( 759 | temperature=temperature, beam_size=beam_size, 760 | top_k=top_k, eos_index=eos_index, 761 | device=start_emb.device 762 | ) 763 | 764 | sample_seq = self.pad_index * torch.ones((1, max_len)) 765 | sample_seq = sample_seq.long().to(start_emb.device) 766 | 767 | # process caption tokens if present 768 | if caption is None: 769 | pos = 0 770 | else: 771 | pos = caption.size(1) 772 | sample_seq[:, :pos] = caption 773 | 774 | # run TransformerDecoder over the inputs and predict the next token 775 | outputs = self(sample_seq, start_emb) 776 | logits = outputs[:, pos, :] 777 | 778 | # filter `top_k` values 779 | logits = helper.filter_top_k(logits) 780 | 781 | # compute probabilities and sample k values 782 | sample_ind = helper.sample_k_indices(logits, k=beam_size) 783 | sample_val = helper.filter_by_indices(logits, sample_ind).log_softmax(-1) 784 | sample_ind, sample_val = sample_ind.T, sample_val.T 785 | 786 | # update total prediction sequences 787 | sample_seq = sample_seq.repeat(beam_size, 1) 788 | sample_seq[:, pos:pos + 1] = sample_ind 789 | 790 | # repeat `image_emb` and `enc_out` 791 | start_emb = start_emb.repeat(beam_size, 1) 792 | 793 | for i in range(pos + 1, max_len + 1): 794 | # predict the next time step 795 | outputs = self(sample_seq, start_emb) 796 | logits = outputs[:, i, :] 797 | 798 | (prev_seqs, prev_vals), (new_ind, new_val) = helper.process_logits( 799 | logits, sample_seq, sample_val 800 | ) 801 | 802 | # create candidate sequences and compute their probabilities 803 | prev_seqs[:, i:i + 1] = new_ind.unsqueeze(0).T 804 | cand_seq = prev_seqs 805 | cand_val = prev_vals.flatten() + new_val 806 | 807 | # sample `beam` sequences 808 | filter_ind = helper.sample_k_indices(cand_val, k=beam_size) 809 | 810 | # update total sequences and their scores 811 | sample_val = cand_val[filter_ind] 812 | sample_seq = cand_seq[filter_ind] 813 | 814 | # filter `has_ended` flags 815 | helper.has_ended = helper.has_ended[filter_ind] 816 | 817 | # check if every branch has ended 818 | if helper.all_ended(): 819 | break 820 | 821 | # sample output sequence 822 | ind = helper.sample_k_indices(sample_val, k=1) 823 | output_seq = sample_seq[ind, :i].squeeze() 824 | 825 | return output_seq 826 | --------------------------------------------------------------------------------
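Taken together, the captioning models above expose a small, uniform inference API: `save` / `from_pretrained` round-trip a checkpoint together with its hyperparameters, and `generate` runs beam-search decoding, optionally continuing a partial caption. The sketch below shows how that API might be wired up for a single image. It is a minimal illustration, not code from the repository: the checkpoint path, the image file, the preprocessing transform, and the assumption that the ResNet-based encoder expects a normalized `[1, 3, H, W]` tensor are all assumptions, and mapping the returned token indices back to text still requires the project's `Vocab` / tokenizer utilities, which are omitted here.

import torch
from PIL import Image
from torchvision import transforms

from deephumor.models import CaptioningTransformer

# rebuild the model from a checkpoint written by CaptioningTransformer.save()
# (the path is hypothetical)
model = CaptioningTransformer.from_pretrained('checkpoints/transformer.pt')
model.eval()

# ImageNet-style preprocessing for the ResNet-based ImageEncoder
# (assumed here; the transform actually used during training is defined elsewhere)
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

image = preprocess(Image.open('memes/images/template.jpg').convert('RGB'))
image = image.unsqueeze(0)  # add batch dimension -> [1, 3, 256, 256]

with torch.no_grad():
    token_ids = model.generate(
        image,
        max_len=25,       # maximum caption length
        temperature=1.0,  # softmax temperature over logits
        beam_size=10,     # branches maintained at each beam-search step
        top_k=50,         # most probable tokens considered per step
        eos_index=3,      # index of the end-of-sequence token
    )

# `token_ids` is a tensor of vocabulary indices for the generated caption;
# decoding it to text is done with the project's Vocab / tokenizer utilities.
print(token_ids)

The same pattern applies to CaptioningTransformerBase, CaptioningLSTM, and (with an additional label tensor passed to `generate`) CaptioningLSTMWithLabels, since all four models share the save / from_pretrained / generate interface.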