├── checkpoint
│   ├── en
│   │   └── .gitignore
│   └── jp
│       └── .gitignore
├── model
│   ├── English
│   │   └── .gitignore
│   └── Japanese
│       └── .gitignore
├── config.ini
├── src
│   ├── ModelLoader.py
│   ├── utils
│   │   └── utils.py
│   ├── Translator.py
│   ├── models
│   │   ├── distributed.py
│   │   ├── model_builder.py
│   │   ├── reporter.py
│   │   ├── stats.py
│   │   ├── rnn.py
│   │   ├── trainloader.py
│   │   ├── trainer.py
│   │   ├── encoder.py
│   │   ├── preprocess.py
│   │   ├── optimizers.py
│   │   └── neural.py
│   ├── Summarizer.py
│   ├── TestLoader.py
│   └── LangFactory.py
├── testjp.txt
├── youyakuman.py
├── README.md
├── youyakumanJPN_train.py
└── test.txt

/checkpoint/en/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | */
3 | !.gitignore
--------------------------------------------------------------------------------
/checkpoint/jp/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | */
3 | !.gitignore
--------------------------------------------------------------------------------
/model/English/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | */
3 | !.gitignore
--------------------------------------------------------------------------------
/model/Japanese/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | */
3 | !.gitignore
--------------------------------------------------------------------------------
/config.ini:
--------------------------------------------------------------------------------
1 | [DEFAULT]
2 | vocab_path = model/Japanese/juman_vocab.txt
3 | 
4 | [Juman]
5 | command = /Tools/jumanpp-2.0.0/bld/src/jumandic/RelWithDebInfo/jumanpp_v2
6 | option = --model=/Tools/jumanpp-2.0.0/model/jumandic.jppmdl --config=/Tools/jumanpp-2.0.0/model/jumandic.conf
7 | 
--------------------------------------------------------------------------------
/src/ModelLoader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from src.utils.utils import DictX
4 | from src.models.model_builder import Summarizer
5 | 
6 | 
7 | class ModelLoader(Summarizer):
8 |     def __init__(self, cp, opt, bert_model):
9 |         cp_statedict = torch.load(cp, map_location=lambda storage, loc: storage)
10 |         opt = DictX(torch.load(opt))
11 |         super(ModelLoader, self).__init__(opt, bert_model)
12 |         self.load_cp(cp_statedict)
13 |         self.eval()
14 | 
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import io
3 | import sys
4 | 
5 | 
6 | @contextlib.contextmanager
7 | def nostdout():
8 |     save_stdout = sys.stdout
9 |     sys.stdout = io.StringIO()  # text buffer: suppressed code writes str, not bytes
10 |     yield
11 |     sys.stdout = save_stdout
12 | 
13 | 
14 | class DictX(dict):
15 |     def __getattr__(self, key):
16 |         try:
17 |             return self[key]
18 |         except KeyError as k:
19 |             raise AttributeError(k)
20 | 
21 |     def __setattr__(self, key, value):
22 |         self[key] = value
23 | 
24 |     def __delattr__(self, key):
25 |         try:
26 |             del self[key]
27 |         except KeyError as k:
28 |             raise AttributeError(k)
--------------------------------------------------------------------------------
/testjp.txt:
--------------------------------------------------------------------------------
1 | 
2 | きょう12日は令和元年最後の満月。実は満月には月ごとに名前がついていて、12月はコールドムーンと呼ばれています。夜はぐっと冷え込みますので、暖かくして空を眺めてみてはいかがでしょうか。
3 | 
4 | 令和元年最後の満月
5 | 
きょう12日は令和元年最後の満月が楽しめます。満月には英語圏で様々な呼び名があり、4月はピンクムーン、6月はストロベリームーンなどと月ごとに呼び方が変わります。12月は「寒月」(コールドムーン)と呼ばれます。この時期は冬の寒さが強まり、夜が長くなる頃。寒空に輝く満月を見ると、心がほっとしますね。ちなみに次回の満月、つまり2020年(令和二年)の最初の満月は1月11日です。また2020年最大の満月は4月8日、最小の満月は10月31日です。 6 | 7 | 8 | 今夜の天気 太平洋側ほどきれいに見える 9 | 今夜、満月がきれいに見える所は、北海道と東北の太平洋側、関東から九州の太平洋側が中心でしょう。月の出の時間は、日の入り後まもなくで、東京では16時36分です。夜が長いので、満月をゆっくり楽しめそうです。 10 | 11 | 寒さ対策をしっかり 12 | 満月を鑑賞する際は、寒さ対策が必要です。きょうの日中は関東から西ではコートがなくても過ごせるほど暖かかった所もありましたが、今夜から放射冷却が強まり、東京都心などお帰りの際には気温が10度を下回る所もあり、ぐっと冷え込むでしょう。あす13日朝の最低気温は北海道や東北は、ほとんどの所で氷点下、関東から九州でも5度を下回る所が多くなりそうです。冬のコートやダウンに、ストールや手袋など、防寒対策をしっかりして令和元年最後の満月を楽しんでください。 -------------------------------------------------------------------------------- /src/Translator.py: -------------------------------------------------------------------------------- 1 | from googletrans import Translator 2 | import sys 3 | 4 | 5 | class TranslatorY(Translator): 6 | def __init__(self): 7 | super(TranslatorY, self).__init__() 8 | self.input_lang = 'en' 9 | self.check_lang = True 10 | 11 | def input(self, text): 12 | if self.check_lang: 13 | self.input_lang = self.detect(text).lang 14 | sys.stdout.write('\n'.format(self.input_lang.upper())) 15 | self.check_lang = False 16 | trans = self._translation(text, self.input_lang, 'en') 17 | return trans.text 18 | 19 | def output(self, texts_list): 20 | transed_text = [] 21 | for text in texts_list: 22 | trans = self._translation(text, 'en', self.input_lang) 23 | transed_text.append(trans.text) 24 | return transed_text 25 | 26 | def _translation(self, text, input_lang, output_lang): 27 | return self.translate(text, src=input_lang, dest=output_lang) 28 | -------------------------------------------------------------------------------- /src/models/distributed.py: -------------------------------------------------------------------------------- 1 | """ Pytorch Distributed utils 2 | This piece of code was heavily inspired by the equivalent of Fairseq-py 3 | https://github.com/pytorch/fairseq 4 | """ 5 | 6 | 7 | from __future__ import print_function 8 | 9 | import math 10 | import pickle 11 | 12 | import torch.distributed 13 | 14 | 15 | def all_gather_list(data, max_size=4096): 16 | """Gathers arbitrary data from all nodes into a list.""" 17 | world_size = torch.distributed.get_world_size() 18 | if not hasattr(all_gather_list, '_in_buffer') or \ 19 | max_size != all_gather_list._in_buffer.size(): 20 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) 21 | all_gather_list._out_buffers = [ 22 | torch.cuda.ByteTensor(max_size) 23 | for i in range(world_size) 24 | ] 25 | in_buffer = all_gather_list._in_buffer 26 | out_buffers = all_gather_list._out_buffers 27 | 28 | enc = pickle.dumps(data) 29 | enc_size = len(enc) 30 | if enc_size + 2 > max_size: 31 | raise ValueError( 32 | 'encoded data exceeds max_size: {}'.format(enc_size + 2)) 33 | assert max_size < 255*256 34 | in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k 35 | in_buffer[1] = enc_size % 255 36 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc)) 37 | 38 | torch.distributed.all_gather(out_buffers, in_buffer.cuda()) 39 | 40 | results = [] 41 | for i in range(world_size): 42 | out_buffer = out_buffers[i] 43 | size = (255 * out_buffer[0].item()) + out_buffer[1].item() 44 | 45 | bytes_list = bytes(out_buffer[2:size+2].tolist()) 46 | result = pickle.loads(bytes_list) 47 | results.append(result) 48 | return results 49 | -------------------------------------------------------------------------------- /youyakuman.py: 
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from argparse import RawTextHelpFormatter
4 | 
5 | from src.TestLoader import TestLoader
6 | from src.ModelLoader import ModelLoader
7 | from src.Summarizer import SummarizerIO
8 | from src.Translator import TranslatorY
9 | from src.LangFactory import LangFactory
10 | 
11 | os.chdir('./')
12 | 
13 | if __name__ == '__main__':
14 |     parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter,
15 |                                      description="""
16 | Intro: This is a one-touch extractive summarization machine.
17 | Using BertSum as the summarization model, it extracts the top N most important sentences.
18 | 
19 | Note: Since BERT only accepts inputs up to 512 tokens, this summarizer crops articles longer than 512 tokens.
20 |       If the --super_long option is used, the summarizer automatically splits the article into several
21 |       512-token inputs and summarizes each one. The number of extracted sentences may vary slightly when --super_long is used.
22 | 
23 | Example: youyakuman.py -txt_file YOUR_FILE -n 3
24 | """)
25 | 
26 |     parser.add_argument("-txt_file", default='test.txt',
27 |                         help='Text file for summarization (encoding:"utf-8_sig")')
28 |     parser.add_argument("-n", default=3, type=int,
29 |                         help='Number of sentences to extract as the summary')
30 |     parser.add_argument("-lang", default='en', type=str,
31 |                         help='If the language of the article isn\'t English, it will be translated automatically by Google')
32 |     parser.add_argument("--super_long", action='store_true',
33 |                         help='If the article is longer than 512 tokens, this option is needed')
34 | 
35 |     args = parser.parse_args()
36 | 
37 |     # if args.super_long:
38 |     #     sys.stdout.write('\n\n')
39 | 
40 |     # Language initiator
41 |     lf = LangFactory(args.lang)
42 |     translator = None if args.lang in lf.support_lang else TranslatorY()
43 | 
44 |     data = TestLoader(args.txt_file, args.super_long, args.lang, translator).data
45 |     model = ModelLoader(lf.toolkit.cp, lf.toolkit.opt, lf.toolkit.bert_model)
46 |     summarizer = SummarizerIO(data, model, args.n, translator)
47 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # YouyakuMan
2 | 
3 | [![Unstable](https://poser.pugx.org/ali-irawan/xtra/v/unstable.svg)](https://poser.pugx.org/ali-irawan/xtra/v/unstable.svg) [![License](https://poser.pugx.org/ali-irawan/xtra/license.svg)](https://poser.pugx.org/ali-irawan/xtra/license.svg)
4 | 
5 | ### Introduction
6 | 
7 | This is a one-touch extractive summarization machine.
8 | 
9 | Using BertSum as the summarization model, it extracts the top N most important sentences.
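The command-line entry point `youyakuman.py` wires the components together roughly like this (a condensed sketch of that script running with its default arguments, not a separate API):

```
from src.LangFactory import LangFactory
from src.TestLoader import TestLoader
from src.ModelLoader import ModelLoader
from src.Summarizer import SummarizerIO

lf = LangFactory('en')                                 # language toolkit: tokenizer, model/checkpoint paths
data = TestLoader('test.txt', False, 'en', None).data  # article -> 512-token BERT-ready inputs
model = ModelLoader(lf.toolkit.cp, lf.toolkit.opt, lf.toolkit.bert_model)
SummarizerIO(data, model, 3)                           # scores sentences and prints the top 3
```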
10 | 
11 | ![img](https://cdn-images-1.medium.com/max/800/1*NRamBWCtYuS8U6pqpnDiJQ.png)
12 | 
13 | ---
14 | 
15 | ### Prerequisites
16 | 
17 | #### General requirement
18 | 
19 | ```
20 | pip install torch
21 | pip install transformers
22 | pip install googletrans
23 | ```
24 | 
25 | #### Japanese specific requirement
26 | 
27 | - [BERT日本語Pretrainedモデル — KUROHASHI-KAWAHARA LAB](http://nlp.ist.i.kyoto-u.ac.jp/index.php?BERT日本語Pretrainedモデル)
28 | - [Juman++ V2 (development version)](https://github.com/ku-nlp/jumanpp)[ — KUROHASHI-KAWAHARA LAB](http://nlp.ist.i.kyoto-u.ac.jp/index.php?BERT日本語Pretrainedモデル)
29 | 
30 | ---
31 | 
32 | ### Pretrained Model
33 | 
34 | English: [Here](https://drive.google.com/open?id=1wxf6zTTrhYGmUTVHVMxGpl_GLaZAC1ye)
35 | 
36 | Japanese: [Here](https://drive.google.com/file/d/1BBjg0LI8VAgpKT6QN1ah1S49mlUhbM1h/view?usp=sharing)
37 | 
38 | * Japanese model updated: trained with 35k articles and 120k iterations
39 | 
40 | Download the model and put it under the directory `checkpoint/en` or `checkpoint/jp`.
41 | 
42 | ---
43 | 
44 | ### Example
45 | 
46 | ```
47 | $ python youyakuman.py -txt_file YOUR_FILE -lang LANG -n 3 --super_long
48 | ```
49 | 
50 | #### Note
51 | 
52 | Since BERT only accepts inputs up to 512 tokens, this summarizer crops articles longer than 512 tokens.
53 | 
54 | If the --super_long option is used, the summarizer automatically splits the article into several 512-token inputs and summarizes each one. The number of extracted sentences may vary slightly when --super_long is used.
55 | 
56 | ---
57 | 
58 | ### Train Example
59 | 
60 | ```
61 | $ python youyakumanJPN_train.py -data_folder [training_txt_path] -save_path [model_saving_path] -train_from [pretrained_model_file]
62 | """
63 | -data_folder : path to the training data folder, structured as below:
64 |     training_txt_path
65 |     ├─ article1.pickle
66 |     ├─ article2.pickle
67 |     ..
68 | """
69 | ```
70 | 
71 | ### Train Data Preparation
72 | 
73 | Training data should be one dictionary per article, saved with `pickle`; the expected dictionary format is shown in the last block below.
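For reference, a minimal sketch of producing one such file (the folder and file names are placeholders; the summary string joins the reference summary sentences with the separator expected by `src/models/preprocess.py`):

```
import pickle

# one article per pickle file
sample = {
    'body': 'TEXT_BODY',        # full article text
    'summary': 'SUMMARY_TEXT',  # reference summary sentences, joined with the expected separator
}
with open('training_txt_path/article1.pickle', 'wb') as f:
    pickle.dump(sample, f)
```

Note that `TrainLoader` skips samples in which either `body` or `summary` is empty.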
74 | 75 | ``` 76 | {'body': 'TEXT_BODY', 'summary': 'SUMMARY_1SUMMARY_2SUMMARY3'} 77 | ``` 78 | 79 | --- 80 | ### Version Log: 81 | 82 | 2020-08-03 Updated to `transformer` package, remove redudndancy, model saving format while training 83 | 84 | 2020-02-10 Training part added 85 | 86 | 2019-11-14 Add multiple language support 87 | 88 | 2019-10-29 Add auto parse function, available for long article as input 89 | -------------------------------------------------------------------------------- /youyakumanJPN_train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import os 4 | 5 | import argparse 6 | 7 | from src.models.trainloader import TrainLoader 8 | from src.models.model_builder import Summarizer, build_optim 9 | from src.models.trainer import build_trainer 10 | 11 | os.chdir('./') 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("-data_folder", default='', type=str) 16 | parser.add_argument("-batch_size", default=5, type=int) 17 | parser.add_argument("-train_from", default='', type=str) 18 | parser.add_argument("-save_path", default='', type=str) 19 | 20 | parser.add_argument("-train_steps", default=1200000, type=int) 21 | parser.add_argument("-report_every", default=10, type=int) 22 | parser.add_argument("-save_checkpoint_steps", default=1000, type=int) 23 | parser.add_argument("-accum_count", default=2, type=int) 24 | 25 | parser.add_argument("-optim", default='adam', type=str) 26 | parser.add_argument("-learning_rate", default=5e-5, type=float) 27 | parser.add_argument("-beta1", default=0.9, type=float) 28 | parser.add_argument("-beta2", default=0.999, type=float) 29 | parser.add_argument("-decay_method", default='no', type=str) 30 | parser.add_argument("-warmup_steps", default=10000, type=int) 31 | 32 | parser.add_argument("-ff_size", default=2048, type=int) 33 | parser.add_argument("-heads", default=8, type=int) 34 | parser.add_argument("-dropout", default=0.1, type=float) 35 | parser.add_argument("-param_init", default=0.0, type=float) 36 | parser.add_argument("-param_init_glorot", default=True, type=bool) 37 | parser.add_argument("-max_grad_norm", default=0, type=float) 38 | parser.add_argument("-inter_layers", default=2, type=int) 39 | 40 | parser.add_argument("-seed", default='') 41 | 42 | args = parser.parse_args() 43 | 44 | model_flags = ['hidden_size', 'ff_size', 'heads', 'inter_layers', 'encoder', 'ff_actv', 'use_interval', 'rnn_size'] 45 | 46 | device = "cuda" 47 | device_id = -1 48 | 49 | if args.seed: 50 | torch.manual_seed(args.seed) 51 | random.seed(args.seed) 52 | 53 | def train_loader_fct(): 54 | return TrainLoader(args.data_folder, 512, args.batch_size, device=device, shuffle=True) 55 | 56 | model = Summarizer(args, './model/Japanese/', device, train=True) 57 | if args.train_from != '': 58 | print('Loading checkpoint from %s' % args.train_from) 59 | checkpoint = torch.load(args.train_from, 60 | map_location=lambda storage, loc: storage) 61 | opt = dict(checkpoint['opt']) 62 | for k in opt.keys(): 63 | if k in model_flags: 64 | setattr(args, k, opt[k]) 65 | model.load_cp(checkpoint['model']) 66 | optim = build_optim(args, model, checkpoint) 67 | else: 68 | optim = build_optim(args, model, None) 69 | 70 | trainer = build_trainer(args, model, optim) 71 | trainer.train(train_loader_fct, args.train_steps) 72 | 73 | -------------------------------------------------------------------------------- /src/models/model_builder.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import BertModel 4 | from torch.nn.init import xavier_uniform_ 5 | 6 | from src.models.encoder import TransformerInterEncoder 7 | from src.models.optimizers import Optimizer 8 | 9 | 10 | def build_optim(args, model, checkpoint): 11 | saved_optimizer_state_dict = None 12 | 13 | if args.train_from != '': 14 | optim = checkpoint['optim'] 15 | saved_optimizer_state_dict = optim.optimizer.state_dict() 16 | else: 17 | optim = Optimizer( 18 | args.optim, args.learning_rate, args.max_grad_norm, 19 | beta1=args.beta1, beta2=args.beta2, 20 | decay_method=args.decay_method, 21 | warmup_steps=args.warmup_steps) 22 | 23 | optim.set_parameters(list(model.named_parameters())) 24 | 25 | if args.train_from != '': 26 | optim.optimizer.load_state_dict(saved_optimizer_state_dict) 27 | optim.learning_rate = args.learning_rate 28 | for param_group in optim.optimizer.param_groups: 29 | param_group['lr'] = args.learning_rate 30 | 31 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1): 32 | raise RuntimeError( 33 | "Error: loaded Adam optimizer from existing model" + 34 | " but optimizer state is empty") 35 | 36 | return optim 37 | 38 | 39 | class Bert(nn.Module): 40 | def __init__(self, bert_model): 41 | super(Bert, self).__init__() 42 | self.model = BertModel.from_pretrained(bert_model) 43 | 44 | def forward(self, x, segs, mask): 45 | top_vec, _ = self.model(x, token_type_ids=segs, attention_mask=mask) 46 | return top_vec 47 | 48 | 49 | class Summarizer(nn.Module): 50 | def __init__(self, args, bert_model, device='cpu', train=False): 51 | super(Summarizer, self).__init__() 52 | self.device = device 53 | self.bert = Bert(bert_model) 54 | self.encoder = TransformerInterEncoder(self.bert.model.config.hidden_size, 55 | args.ff_size, args.heads, 56 | args.dropout, args.inter_layers) 57 | 58 | if train: 59 | if args.param_init: 60 | for p in self.encoder.parameters(): 61 | p.data.uniform_(-args.param_init, args.param_init) 62 | 63 | if args.param_init_glorot: 64 | for p in self.encoder.parameters(): 65 | if p.dim() > 1: 66 | xavier_uniform_(p) 67 | self.to(device) 68 | 69 | def load_cp(self, pt): 70 | self.load_state_dict(pt, strict=True) 71 | 72 | def forward(self, x, segs, clss, mask, mask_cls, sentence_range=None): 73 | 74 | top_vec = self.bert(x, segs, mask) 75 | sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss] 76 | sents_vec = sents_vec * mask_cls[:, :, None].float() 77 | sent_scores = self.encoder(sents_vec, mask_cls).squeeze(-1) 78 | return sent_scores, mask_cls 79 | -------------------------------------------------------------------------------- /src/Summarizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | 4 | sys.stdout = open(sys.stdout.fileno(), mode='w', encoding='utf8', buffering=1) 5 | 6 | 7 | class SummarizerIO: 8 | def __init__(self, test_data, model, n, translator=None): 9 | self.data = test_data 10 | self.model = model 11 | self.translator = translator 12 | n = self.n_distribution(n) 13 | start_n = 0 14 | for i, data in enumerate(self.data): 15 | self._evaluate(data) 16 | self._extract_n(n[i], start_n) 17 | start_n += n[i] 18 | 19 | def n_distribution(self, n): 20 | if len(self.data) == 1: 21 | return [n] 22 | else: 23 | last_ratio = sum([x > 0 for x in self.data[-1]['src']])/512 24 | article_len = len(self.data) - 1 + last_ratio 25 | n_sub = 
max([n/article_len, 0.51]) # At least 1 summary per data input 26 | n_extract = [round(n_sub)]*len(self.data) 27 | n_extract[-1] = round(n_sub*last_ratio) 28 | return n_extract 29 | 30 | def _extract_n(self, n, start_n): 31 | def _get_ngrams(n, text): 32 | ngram_set = set() 33 | text_length = len(text) 34 | max_index_ngram_start = text_length - n 35 | for i in range(max_index_ngram_start + 1): 36 | ngram_set.add(tuple(text[i:i + n])) 37 | return ngram_set 38 | 39 | def _block_tri(c, p): 40 | tri_c = _get_ngrams(3, c.split()) 41 | for s in p: 42 | tri_s = _get_ngrams(3, s.split()) 43 | if len(tri_c.intersection(tri_s)) > 0: 44 | return True 45 | return False 46 | 47 | _pred = [] 48 | for j in self.selected_ids[:self.str_len]: 49 | candidate = self.src_str[j].strip() 50 | if not _block_tri(candidate, _pred): 51 | _pred.append(candidate) 52 | if (len(_pred) == n) or (n==0): 53 | break 54 | 55 | # Translate Summaries to other language 56 | if self.translator: 57 | _pred = self.translator.output(_pred) 58 | 59 | # Print result 60 | for i, pred in enumerate(_pred): 61 | sys.stdout.write("[要約文%s] %s \n" % (start_n+i+1, pred)) 62 | # print("[Summary %s] %s" % (start_n+i+1, pred)) 63 | 64 | def _evaluate(self, test_data): 65 | self.model.eval() 66 | with torch.no_grad(): 67 | src = torch.tensor([test_data['src']]) 68 | segs = torch.tensor([test_data['segs']]) 69 | clss = torch.tensor([test_data['clss']]) 70 | mask = torch.tensor([test_data['mask']]) 71 | mask_cls = torch.tensor([test_data['mask_cls']]) 72 | sent_scores, mask = self.model(src, segs, clss, mask, mask_cls) 73 | 74 | sent_scores = sent_scores + mask.float() 75 | selected_ids = torch.argsort(-sent_scores, 1) 76 | selected_ids = selected_ids.cpu().data.numpy() 77 | self.selected_ids = selected_ids[0] 78 | self.src_str = test_data['src_str'] 79 | self.str_len = len(test_data['src_str']) 80 | self.fname = test_data['fname'] 81 | 82 | # Archieve so far 83 | def diet(self, percent): 84 | _pred = [] 85 | diet_ids = self.selected_ids[:int(self.str_len * percent)] 86 | diet_text = [x for i, x in enumerate(self.src_str) if i in diet_ids] 87 | diet_text = '. \n'.join(diet_text) + '. ' 88 | return diet_text 89 | -------------------------------------------------------------------------------- /src/models/reporter.py: -------------------------------------------------------------------------------- 1 | """ Report manager utility """ 2 | from __future__ import print_function 3 | 4 | import time 5 | from datetime import datetime 6 | 7 | from src.models.stats import Statistics 8 | 9 | 10 | def build_report_manager(opt): 11 | if opt.tensorboard: 12 | from tensorboardX import SummaryWriter 13 | tensorboard_log_dir = opt.tensorboard_log_dir 14 | 15 | if not opt.train_from: 16 | tensorboard_log_dir += datetime.now().strftime("/%b-%d_%H-%M-%S") 17 | 18 | writer = SummaryWriter(tensorboard_log_dir, 19 | comment="Unmt") 20 | else: 21 | writer = None 22 | 23 | report_mgr = ReportMgr(opt.report_every, start_time=-1, 24 | tensorboard_writer=writer) 25 | return report_mgr 26 | 27 | 28 | class ReportMgrBase(object): 29 | """ 30 | Report Manager Base class 31 | Inherited classes should override: 32 | * `_report_training` 33 | """ 34 | 35 | def __init__(self, report_every, start_time=-1.): 36 | """ 37 | Args: 38 | report_every(int): Report status every this many sentences 39 | start_time(float): manually set report start time. 
Negative values 40 | means that you will need to set it later or use `start()` 41 | """ 42 | self.report_every = report_every 43 | self.progress_step = 0 44 | self.start_time = start_time 45 | 46 | def start(self): 47 | self.start_time = time.time() 48 | 49 | def report_training(self, step, num_steps, learning_rate, 50 | report_stats): 51 | """ 52 | This is the user-defined batch-level traing progress 53 | report function. 54 | 55 | Args: 56 | step(int): current step count. 57 | num_steps(int): total number of batches. 58 | learning_rate(float): current learning rate. 59 | report_stats(Statistics): old Statistics instance. 60 | Returns: 61 | report_stats(Statistics): updated Statistics instance. 62 | """ 63 | if self.start_time < 0: 64 | raise ValueError("""ReportMgr needs to be started 65 | (set 'start_time' or use 'start()'""") 66 | 67 | if step % self.report_every == 0: 68 | self._report_training( 69 | step, num_steps, learning_rate, report_stats) 70 | self.progress_step += 1 71 | return Statistics() 72 | else: 73 | return report_stats 74 | 75 | def _report_training(self, *args, **kwargs): 76 | """ To be overridden """ 77 | raise NotImplementedError() 78 | 79 | 80 | class ReportMgr(ReportMgrBase): 81 | def __init__(self, report_every, start_time=-1., tensorboard_writer=None): 82 | """ 83 | A report manager that writes statistics on standard output as well as 84 | (optionally) TensorBoard 85 | 86 | Args: 87 | report_every(int): Report status every this many sentences 88 | tensorboard_writer(:obj:`tensorboard.SummaryWriter`): 89 | The TensorBoard Summary writer to use or None 90 | """ 91 | super(ReportMgr, self).__init__(report_every, start_time) 92 | self.tensorboard_writer = tensorboard_writer 93 | 94 | def _report_training(self, step, num_steps, learning_rate, 95 | report_stats): 96 | """ 97 | See base class method `ReportMgrBase.report_training`. 
98 | """ 99 | report_stats.output(step, num_steps, 100 | learning_rate, self.start_time) 101 | 102 | report_stats = Statistics() 103 | 104 | return report_stats 105 | -------------------------------------------------------------------------------- /src/TestLoader.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer 2 | 3 | from src.LangFactory import LangFactory 4 | 5 | 6 | class TestLoader: 7 | def __init__(self, path, super_long, lang, translator=None): 8 | self.path = path 9 | self.data = [] 10 | self.super_long = super_long 11 | self.langfac = LangFactory(lang) 12 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) 13 | self._load_data() 14 | # If rawdata isnt modelized, use google translation to translate to English 15 | if self.langfac.stat is 'Invalid': 16 | self.translator = translator 17 | self._translate() 18 | # Outsource suitable line splitter 19 | self.texts = self.langfac.toolkit.linesplit(self.rawtexts) 20 | # Outsource suitable tokenizer 21 | self.token, self.token_id = self.langfac.toolkit.tokenizer(self.texts) 22 | self._generate_results() 23 | 24 | def _generate_results(self): 25 | 26 | if not self.super_long: 27 | _, _ = self._add_result(self.fname, self.token_id) 28 | else: 29 | # Initialize indexes for while loop 30 | src_start, token_start, src_end = 0, 0, 1 31 | while src_end != 0: 32 | token_end, src_end = self._add_result(self.fname, self.token_id, 33 | src_start, token_start) 34 | token_start = token_end 35 | src_start = src_end 36 | 37 | def _add_result(self, fname, token_all, src_start=0, token_start=0): 38 | results, (token_end, src_end) = self._all_tofixlen(token_all, src_start, token_start) 39 | token, clss, segs, labels, mask, mask_cls, src = results 40 | self.data.append({'fname': fname, 41 | 'src': token, 42 | 'labels': labels, 43 | 'segs': segs, 44 | 'mask': mask, 45 | 'mask_cls': mask_cls, 46 | 'clss': clss, 47 | 'src_str': src}) 48 | return token_end, src_end 49 | 50 | def _load_data(self): 51 | self.fname = self.path.split('/')[-1].split('.')[0] 52 | with open(self.path, 'r', encoding='utf-8_sig', errors='ignore') as f: 53 | self.rawtexts = f.readlines() 54 | self.rawtexts = ' '.join(self.rawtexts) 55 | 56 | def _translate(self): 57 | texts = self.rawtexts 58 | self.texts = self.translator.input(texts) 59 | 60 | def _all_tofixlen(self, token, src_start, token_start): 61 | # Tune All shit into 512 length 62 | token_end = 0 63 | src_end = 0 64 | token = token[token_start:] 65 | src = self.texts[src_start:] 66 | clss = [i for i, x in enumerate(token) if x == self.langfac.toolkit.cls_id] 67 | if len(token) > 512: 68 | clss, token, token_stop, src, src_stop = self._length512(src, token, clss) 69 | token_end = token_start + token_stop 70 | src_end = src_start + src_stop 71 | labels = [0] * len(clss) 72 | mask = ([True] * len(token)) + ([False] * (512 - len(token))) 73 | mask_cls = [True] * len(clss) 74 | token = token + ([self.langfac.toolkit.mask_id] * (512 - len(token))) 75 | segs = [] 76 | flag = 1 77 | for idx in token: 78 | if idx == self.langfac.toolkit.cls_id: 79 | flag = not flag 80 | segs.append(int(flag)) 81 | return (token, clss, segs, labels, mask, mask_cls, src), (token_end, src_end) 82 | 83 | @staticmethod 84 | def _length512(src, token, clss): 85 | if max(clss) > 512: 86 | src_stop = [x > 512 for x in clss].index(True) - 1 87 | else: 88 | src_stop = len(clss) - 1 89 | token_stop = clss[src_stop] 90 | clss = clss[:src_stop] 91 | src 
= src[:src_stop] 92 | token = token[:token_stop] 93 | return clss, token, token_stop, src, src_stop 94 | -------------------------------------------------------------------------------- /src/models/stats.py: -------------------------------------------------------------------------------- 1 | """ Statistics calculation utility """ 2 | from __future__ import division 3 | 4 | import sys 5 | import time 6 | 7 | 8 | class Statistics(object): 9 | """ 10 | Accumulator for loss statistics. 11 | Currently calculates: 12 | 13 | * accuracy 14 | * perplexity 15 | * elapsed time 16 | """ 17 | 18 | def __init__(self, loss=0, n_docs=0, n_correct=0): 19 | self.loss = loss 20 | self.n_docs = n_docs 21 | self.start_time = time.time() 22 | 23 | @staticmethod 24 | def all_gather_stats(stat, max_size=4096): 25 | """ 26 | Gather a `Statistics` object accross multiple process/nodes 27 | 28 | Args: 29 | stat(:obj:Statistics): the statistics object to gather 30 | accross all processes/nodes 31 | max_size(int): max buffer size to use 32 | 33 | Returns: 34 | `Statistics`, the update stats object 35 | """ 36 | stats = Statistics.all_gather_stats_list([stat], max_size=max_size) 37 | return stats[0] 38 | 39 | @staticmethod 40 | def all_gather_stats_list(stat_list, max_size=4096): 41 | """ 42 | Gather a `Statistics` list accross all processes/nodes 43 | 44 | Args: 45 | stat_list(list([`Statistics`])): list of statistics objects to 46 | gather accross all processes/nodes 47 | max_size(int): max buffer size to use 48 | 49 | Returns: 50 | our_stats(list([`Statistics`])): list of updated stats 51 | """ 52 | from torch.distributed import get_rank 53 | from src.models.distributed import all_gather_list 54 | 55 | # Get a list of world_size lists with len(stat_list) Statistics objects 56 | all_stats = all_gather_list(stat_list, max_size=max_size) 57 | 58 | our_rank = get_rank() 59 | our_stats = all_stats[our_rank] 60 | for other_rank, stats in enumerate(all_stats): 61 | if other_rank == our_rank: 62 | continue 63 | for i, stat in enumerate(stats): 64 | our_stats[i].update(stat, update_n_src_words=True) 65 | return our_stats 66 | 67 | def update(self, stat, update_n_src_words=False): 68 | """ 69 | Update statistics by suming values with another `Statistics` object 70 | 71 | Args: 72 | stat: another statistic object 73 | update_n_src_words(bool): whether to update (sum) `n_src_words` 74 | or not 75 | 76 | """ 77 | self.loss += stat.loss 78 | 79 | self.n_docs += stat.n_docs 80 | 81 | def xent(self): 82 | """ compute cross entropy """ 83 | if self.n_docs == 0: 84 | return 0 85 | return self.loss / self.n_docs 86 | 87 | def elapsed_time(self): 88 | """ compute elapsed time """ 89 | return time.time() - self.start_time 90 | 91 | def output(self, step, num_steps, learning_rate, start): 92 | """Write out statistics to stdout. 93 | 94 | Args: 95 | step (int): current step 96 | n_batch (int): total batches 97 | start (int): start time of step. 
98 | """ 99 | t = self.elapsed_time() 100 | step_fmt = "%2d" % step 101 | if num_steps > 0: 102 | step_fmt = "%s/%5d" % (step_fmt, num_steps) 103 | print( 104 | ("Step %s; xent: %4.2f; " + 105 | "lr: %7.7f; %3.0f docs/s; %6.0f sec") 106 | % (step_fmt, 107 | self.xent(), 108 | learning_rate, 109 | self.n_docs / (t + 1e-5), 110 | time.time() - start)) 111 | sys.stdout.flush() 112 | 113 | def log_tensorboard(self, prefix, writer, learning_rate, step): 114 | """ display statistics to tensorboard """ 115 | t = self.elapsed_time() 116 | writer.add_scalar(prefix + "/xent", self.xent(), step) 117 | writer.add_scalar(prefix + "/lr", learning_rate, step) 118 | -------------------------------------------------------------------------------- /src/models/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class LayerNormLSTMCell(nn.LSTMCell): 7 | 8 | def __init__(self, input_size, hidden_size, bias=True): 9 | super().__init__(input_size, hidden_size, bias) 10 | 11 | self.ln_ih = nn.LayerNorm(4 * hidden_size) 12 | self.ln_hh = nn.LayerNorm(4 * hidden_size) 13 | self.ln_ho = nn.LayerNorm(hidden_size) 14 | 15 | def forward(self, input, hidden=None): 16 | self.check_forward_input(input) 17 | if hidden is None: 18 | hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False) 19 | cx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False) 20 | else: 21 | hx, cx = hidden 22 | self.check_forward_hidden(input, hx, '[0]') 23 | self.check_forward_hidden(input, cx, '[1]') 24 | 25 | gates = self.ln_ih(F.linear(input, self.weight_ih, self.bias_ih)) \ 26 | + self.ln_hh(F.linear(hx, self.weight_hh, self.bias_hh)) 27 | i, f, o = gates[:, :(3 * self.hidden_size)].sigmoid().chunk(3, 1) 28 | g = gates[:, (3 * self.hidden_size):].tanh() 29 | 30 | cy = (f * cx) + (i * g) 31 | hy = o * self.ln_ho(cy).tanh() 32 | return hy, cy 33 | 34 | 35 | class LayerNormLSTM(nn.Module): 36 | 37 | def __init__(self, input_size, hidden_size, num_layers=1, bias=True, bidirectional=False): 38 | super().__init__() 39 | self.input_size = input_size 40 | self.hidden_size = hidden_size 41 | self.num_layers = num_layers 42 | self.bidirectional = bidirectional 43 | 44 | num_directions = 2 if bidirectional else 1 45 | self.hidden0 = nn.ModuleList([ 46 | LayerNormLSTMCell(input_size=(input_size if layer == 0 else hidden_size * num_directions), 47 | hidden_size=hidden_size, bias=bias) 48 | for layer in range(num_layers) 49 | ]) 50 | 51 | if self.bidirectional: 52 | self.hidden1 = nn.ModuleList([ 53 | LayerNormLSTMCell(input_size=(input_size if layer == 0 else hidden_size * num_directions), 54 | hidden_size=hidden_size, bias=bias) 55 | for layer in range(num_layers) 56 | ]) 57 | 58 | def forward(self, input, hidden=None): 59 | seq_len, batch_size, hidden_size = input.size() # supports TxNxH only 60 | num_directions = 2 if self.bidirectional else 1 61 | if hidden is None: 62 | hx = input.new_zeros(self.num_layers * num_directions, batch_size, self.hidden_size, requires_grad=False) 63 | cx = input.new_zeros(self.num_layers * num_directions, batch_size, self.hidden_size, requires_grad=False) 64 | else: 65 | hx, cx = hidden 66 | 67 | ht = [[None, ] * (self.num_layers * num_directions)] * seq_len 68 | ct = [[None, ] * (self.num_layers * num_directions)] * seq_len 69 | 70 | if self.bidirectional: 71 | xs = input 72 | for l, (layer0, layer1) in enumerate(zip(self.hidden0, self.hidden1)): 73 | l0, l1 = 2 * 
l, 2 * l + 1 74 | h0, c0, h1, c1 = hx[l0], cx[l0], hx[l1], cx[l1] 75 | for t, (x0, x1) in enumerate(zip(xs, reversed(xs))): 76 | ht[t][l0], ct[t][l0] = layer0(x0, (h0, c0)) 77 | h0, c0 = ht[t][l0], ct[t][l0] 78 | t = seq_len - 1 - t 79 | ht[t][l1], ct[t][l1] = layer1(x1, (h1, c1)) 80 | h1, c1 = ht[t][l1], ct[t][l1] 81 | xs = [torch.cat((h[l0], h[l1]), dim=1) for h in ht] 82 | y = torch.stack(xs) 83 | hy = torch.stack(ht[-1]) 84 | cy = torch.stack(ct[-1]) 85 | else: 86 | h, c = hx, cx 87 | for t, x in enumerate(input): 88 | for l, layer in enumerate(self.hidden0): 89 | ht[t][l], ct[t][l] = layer(x, (h[l], c[l])) 90 | x = ht[t][l] 91 | h, c = ht[t], ct[t] 92 | y = torch.stack([h[-1] for h in ht]) 93 | hy = torch.stack(ht[-1]) 94 | cy = torch.stack(ct[-1]) 95 | 96 | return y, (hy, cy) 97 | -------------------------------------------------------------------------------- /src/models/trainloader.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import random 4 | import torch 5 | 6 | from src.models.preprocess import Preprocess 7 | 8 | 9 | def lazy_dataset(data_path, shuffle): 10 | data_path = data_path if data_path[-1] is '/' else data_path+'/' 11 | 12 | def _lazy_load(path): 13 | with open(path, 'rb') as f: 14 | return pickle.load(f) 15 | 16 | pts = [data_path + x for x in os.listdir(data_path) if '.pickle' in x] 17 | if pts: 18 | if shuffle: 19 | random.shuffle(pts) 20 | 21 | for pt in pts: 22 | yield _lazy_load(pt) 23 | else: 24 | raise IOError('Data not exist in {}'.format(data_path)) 25 | 26 | 27 | class Batch(object): 28 | def __init__(self, dataset=None, device=None,): 29 | """Create a Batch from a list of examples.""" 30 | if dataset is not None: 31 | self.batch_size = len(dataset) 32 | pre_src = [x['src'] for x in dataset] 33 | pre_labels = [x['labels'] for x in dataset] 34 | pre_segs = [x['segs'] for x in dataset] 35 | pre_clss = [x['clss'] for x in dataset] 36 | pre_srctxt = [x['src_str'] for x in dataset] 37 | 38 | src = torch.tensor(pre_src) 39 | segs = torch.tensor(pre_segs) 40 | mask = ~(src == 0) 41 | 42 | labels = torch.tensor(self._pad(pre_labels, -1)) 43 | clss = torch.tensor(self._pad(pre_clss, -1)) 44 | mask_cls = ~(clss == -1) 45 | clss[clss == -1] = 0 46 | 47 | setattr(self, 'clss', clss.to(device)) 48 | setattr(self, 'mask_cls', mask_cls.to(device)) 49 | setattr(self, 'src', src.to(device)) 50 | setattr(self, 'labels', labels.to(device)) 51 | setattr(self, 'segs', segs.to(device)) 52 | setattr(self, 'mask', mask.to(device)) 53 | setattr(self, 'src_str', pre_srctxt) 54 | 55 | def __len__(self): 56 | return self.batch_size 57 | 58 | @staticmethod 59 | def _pad(data, pad_id, width=-1): 60 | if width == -1: 61 | width = max(len(d) for d in data) 62 | rtn_data = [d + [pad_id] * (width - len(d)) for d in data] 63 | return rtn_data 64 | 65 | 66 | class TrainLoader: 67 | def __init__(self, data_path, input_length, batch_size, 68 | device='cuda', shuffle=True): 69 | self.data_path = data_path if data_path[-1] is '/' else data_path+'/' 70 | self.shuffle = shuffle 71 | self.preprocessor = Preprocess() 72 | self.input_len = input_length 73 | self.batch_size = batch_size 74 | self.device = device 75 | self.iterer = self._lazy_loader() 76 | 77 | def __next__(self): 78 | return self._next_batch() 79 | 80 | def _lazy_loader(self): 81 | def _lazy_load(path): 82 | with open(path, 'rb') as f: 83 | return pickle.load(f) 84 | # shuffle data in folders and load one per yield 85 | pts = [self.data_path + x for x in 
os.listdir(self.data_path) if '.pickle' in x] 86 | if pts: 87 | if self.shuffle: 88 | random.shuffle(pts) 89 | for pt in pts: 90 | data_dic = _lazy_load(pt) 91 | if (data_dic['summary'] != '')&(data_dic['body'] != ''): 92 | yield _lazy_load(pt) 93 | else: 94 | raise IOError('Data not exist in {}'.format(self.data_path)) 95 | 96 | def _next_batch(self): 97 | while True: 98 | try: 99 | dataset = [] 100 | try: 101 | # Drop the current dataset for decreasing memory 102 | for i in range(self.batch_size): 103 | rawdata = next(self.iterer) 104 | data = self.preprocessor(rawdata, self.input_len) 105 | dataset.append(data) 106 | batched_dataset = Batch(dataset=dataset, device=self.device) 107 | return batched_dataset 108 | except Exception as e: 109 | print('DataWarning with data has error {}'.format(e)) 110 | except StopIteration: 111 | self.iterer = self._lazy_loader() 112 | return None 113 | -------------------------------------------------------------------------------- /src/LangFactory.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pyknp import Juman 3 | from configparser import ConfigParser 4 | from transformers import BertTokenizer 5 | 6 | config = ConfigParser() 7 | config.read('./config.ini') 8 | 9 | 10 | class JumanTokenizer: 11 | def __init__(self): 12 | self.juman = Juman(command=config['Juman']['command'], 13 | option=config['Juman']['option']) 14 | 15 | def __call__(self, text): 16 | result = self.juman.analysis(text) 17 | return [mrph.midasi for mrph in result.mrph_list()] 18 | 19 | 20 | class LangFactory: 21 | def __init__(self, lang): 22 | self.support_lang = ['en', 'jp'] 23 | self.lang = lang 24 | self.stat = 'valid' 25 | if self.lang not in self.support_lang: 26 | print('Language not supported, will activate Translation.') 27 | self.stat = 'Invalid' 28 | self._toolchooser() 29 | 30 | def _toolchooser(self): 31 | if self.lang == 'jp': 32 | self.toolkit = JapaneseWorker() 33 | elif self.lang == 'en': 34 | self.toolkit = EnglishWorker() 35 | else: 36 | self.toolkit = EnglishWorker() 37 | 38 | 39 | class JapaneseWorker: 40 | def __init__(self): 41 | self.juman_tokenizer = JumanTokenizer() 42 | self.bert_tokenizer = BertTokenizer(config['DEFAULT']['vocab_path'], 43 | do_basic_tokenize=False) 44 | self.cls_id = self.bert_tokenizer.vocab['[CLS]'] 45 | self.mask_id = self.bert_tokenizer.vocab['[MASK]'] 46 | self.bert_model = 'model/Japanese/' 47 | 48 | self.cp = 'checkpoint/jp/cp_step_1200000.pt' 49 | self.opt = 'checkpoint/jp/opt_step_1200000.pt' 50 | 51 | @staticmethod 52 | def linesplit(src): 53 | """ 54 | :param src: type str, String type article 55 | :return: type list, punctuation seperated sentences 56 | """ 57 | def remove_newline(x): 58 | x = x.replace('\n', '') 59 | return x 60 | 61 | def remove_blank(x): 62 | x = x.replace(' ', '') 63 | return x 64 | 65 | def remove_unknown(x): 66 | unknown = ['\u3000'] 67 | for h in unknown: 68 | x = x.replace(h, '') 69 | return x 70 | src = remove_blank(src) 71 | src = remove_newline(src) 72 | src = remove_unknown(src) 73 | src_line = re.split('。(? 
0 33 | if model: 34 | self.model.train() 35 | 36 | def train(self, train_iter_fct, train_steps): 37 | step = self.optim._step + 1 38 | true_batchs = [] 39 | accum = 0 40 | normalization = 0 41 | train_iter = train_iter_fct() 42 | 43 | total_stats = Statistics() 44 | report_stats = Statistics() 45 | self._start_report_manager(start_time=total_stats.start_time) 46 | 47 | while step <= train_steps: 48 | reduce_counter = 0 49 | batch = next(train_iter) 50 | 51 | true_batchs.append(batch) 52 | normalization += batch.batch_size 53 | accum += 1 54 | if accum == self.grad_accum_count: 55 | reduce_counter += 1 56 | 57 | self._gradient_accumulation( 58 | true_batchs, normalization, total_stats, 59 | report_stats) 60 | 61 | report_stats = self._report_training( 62 | step, train_steps, 63 | self.optim.learning_rate, 64 | report_stats) 65 | 66 | true_batchs = [] 67 | accum = 0 68 | normalization = 0 69 | if step % self.save_checkpoint_steps == 0 and self.gpu_rank == 0: 70 | self._save(step) 71 | 72 | step += 1 73 | if step > train_steps: 74 | break 75 | train_iter = train_iter_fct() 76 | 77 | return total_stats 78 | 79 | def _gradient_accumulation(self, true_batchs, normalization, total_stats, 80 | report_stats): 81 | if self.grad_accum_count > 1: 82 | self.model.zero_grad() 83 | 84 | for batch in true_batchs: 85 | if self.grad_accum_count == 1: 86 | self.model.zero_grad() 87 | 88 | src = batch.src 89 | labels = batch.labels 90 | segs = batch.segs 91 | clss = batch.clss 92 | mask = batch.mask 93 | mask_cls = batch.mask_cls 94 | 95 | sent_scores, mask = self.model(src, segs, clss, mask, mask_cls) 96 | 97 | loss = self.loss(sent_scores, labels.float()) 98 | loss = (loss * mask.float()).sum() 99 | (loss / loss.numel()).backward() 100 | 101 | batch_stats = Statistics(float(loss.cpu().data.numpy()), normalization) 102 | 103 | total_stats.update(batch_stats) 104 | report_stats.update(batch_stats) 105 | 106 | if self.grad_accum_count == 1: 107 | self.optim.step() 108 | 109 | if self.grad_accum_count > 1: 110 | self.optim.step() 111 | 112 | def _save(self, step): 113 | real_model = self.model 114 | 115 | model_state_dict = real_model.state_dict() 116 | checkpoint = { 117 | 'model': model_state_dict, 118 | 'opt': self.args, 119 | 'optim': self.optim, 120 | } 121 | print(f'Saving checkpoint {self.args.save_path}/model_step_{step}.pt') 122 | if not os.path.exists(self.args.save_path): 123 | os.mkdir(self.args.save_path) 124 | torch.save(checkpoint, f'{self.args.save_path}/model_step_{step}.pt') 125 | torch.save(model_state_dict, f'{self.args.save_path}/cp_step_{step}.pt') 126 | torch.save(self.args, f'{self.args.save_path}/opt_{step}.pt') 127 | 128 | def _start_report_manager(self, start_time=None): 129 | if self.report_manager is not None: 130 | if start_time is None: 131 | self.report_manager.start() 132 | else: 133 | self.report_manager.start_time = start_time 134 | 135 | def _report_training(self, step, num_steps, learning_rate, 136 | report_stats): 137 | if self.report_manager is not None: 138 | return self.report_manager.report_training( 139 | step, num_steps, learning_rate, report_stats) 140 | -------------------------------------------------------------------------------- /src/models/encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from src.models.neural import MultiHeadedAttention, PositionwiseFeedForward 7 | from src.models.rnn import LayerNormLSTM 8 | 9 | 10 | class Classifier(nn.Module): 
11 | def __init__(self, hidden_size): 12 | super(Classifier, self).__init__() 13 | self.linear1 = nn.Linear(hidden_size, 1) 14 | self.sigmoid = nn.Sigmoid() 15 | 16 | def forward(self, x, mask_cls): 17 | h = self.linear1(x).squeeze(-1) 18 | sent_scores = self.sigmoid(h) * mask_cls.float() 19 | return sent_scores 20 | 21 | 22 | class PositionalEncoding(nn.Module): 23 | 24 | def __init__(self, dropout, dim, max_len=5000): 25 | pe = torch.zeros(max_len, dim) 26 | position = torch.arange(0, max_len).unsqueeze(1) 27 | div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * 28 | -(math.log(10000.0) / dim))) 29 | pe[:, 0::2] = torch.sin(position.float() * div_term) 30 | pe[:, 1::2] = torch.cos(position.float() * div_term) 31 | pe = pe.unsqueeze(0) 32 | super(PositionalEncoding, self).__init__() 33 | self.register_buffer('pe', pe) 34 | self.dropout = nn.Dropout(p=dropout) 35 | self.dim = dim 36 | 37 | def forward(self, emb, step=None): 38 | emb = emb * math.sqrt(self.dim) 39 | if (step): 40 | emb = emb + self.pe[:, step][:, None, :] 41 | 42 | else: 43 | emb = emb + self.pe[:, :emb.size(1)] 44 | emb = self.dropout(emb) 45 | return emb 46 | 47 | def get_emb(self, emb): 48 | return self.pe[:, :emb.size(1)] 49 | 50 | 51 | class TransformerEncoderLayer(nn.Module): 52 | def __init__(self, d_model, heads, d_ff, dropout): 53 | super(TransformerEncoderLayer, self).__init__() 54 | 55 | self.self_attn = MultiHeadedAttention( 56 | heads, d_model, dropout=dropout) 57 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 58 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 59 | self.dropout = nn.Dropout(dropout) 60 | 61 | def forward(self, iter, query, inputs, mask): 62 | if (iter != 0): 63 | input_norm = self.layer_norm(inputs) 64 | else: 65 | input_norm = inputs 66 | 67 | mask = mask.unsqueeze(1) 68 | context = self.self_attn(input_norm, input_norm, input_norm, 69 | mask=mask) 70 | out = self.dropout(context) + inputs 71 | return self.feed_forward(out) 72 | 73 | 74 | class TransformerInterEncoder(nn.Module): 75 | def __init__(self, d_model, d_ff, heads, dropout, num_inter_layers=0): 76 | super(TransformerInterEncoder, self).__init__() 77 | self.d_model = d_model 78 | self.num_inter_layers = num_inter_layers 79 | self.pos_emb = PositionalEncoding(dropout, d_model) 80 | self.transformer_inter = nn.ModuleList( 81 | [TransformerEncoderLayer(d_model, heads, d_ff, dropout) 82 | for _ in range(num_inter_layers)]) 83 | self.dropout = nn.Dropout(dropout) 84 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 85 | self.wo = nn.Linear(d_model, 1, bias=True) 86 | self.sigmoid = nn.Sigmoid() 87 | 88 | def forward(self, top_vecs, mask): 89 | """ See :obj:`EncoderBase.forward()`""" 90 | 91 | batch_size, n_sents = top_vecs.size(0), top_vecs.size(1) 92 | pos_emb = self.pos_emb.pe[:, :n_sents] 93 | x = top_vecs * mask[:, :, None].float() 94 | x = x + pos_emb 95 | 96 | for i in range(self.num_inter_layers): 97 | x = self.transformer_inter[i](i, x, x, ~mask) # all_sents * max_tokens * dim 98 | 99 | x = self.layer_norm(x) 100 | sent_scores = self.sigmoid(self.wo(x)) 101 | sent_scores = sent_scores.squeeze(-1) * mask.float() 102 | 103 | return sent_scores 104 | 105 | 106 | class RNNEncoder(nn.Module): 107 | 108 | def __init__(self, bidirectional, num_layers, input_size, 109 | hidden_size, dropout=0.0): 110 | super(RNNEncoder, self).__init__() 111 | num_directions = 2 if bidirectional else 1 112 | assert hidden_size % num_directions == 0 113 | hidden_size = hidden_size // num_directions 114 | 115 | 
self.rnn = LayerNormLSTM( 116 | input_size=input_size, 117 | hidden_size=hidden_size, 118 | num_layers=num_layers, 119 | bidirectional=bidirectional) 120 | 121 | self.wo = nn.Linear(num_directions * hidden_size, 1, bias=True) 122 | self.dropout = nn.Dropout(dropout) 123 | self.sigmoid = nn.Sigmoid() 124 | 125 | def forward(self, x, mask): 126 | """See :func:`EncoderBase.forward()`""" 127 | x = torch.transpose(x, 1, 0) 128 | memory_bank, _ = self.rnn(x) 129 | memory_bank = self.dropout(memory_bank) + x 130 | memory_bank = torch.transpose(memory_bank, 1, 0) 131 | 132 | sent_scores = self.sigmoid(self.wo(memory_bank)) 133 | sent_scores = sent_scores.squeeze(-1) * mask.float() 134 | return sent_scores 135 | -------------------------------------------------------------------------------- /src/models/preprocess.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from pyknp import Juman 4 | from sumeval.metrics.rouge import RougeCalculator 5 | from configparser import ConfigParser 6 | from transformers import BertTokenizer 7 | 8 | config = ConfigParser() 9 | config.read('config.ini') 10 | 11 | 12 | class JumanTokenizer: 13 | def __init__(self): 14 | self.juman = Juman(command=config['Juman']['command'], 15 | option=config['Juman']['option']) 16 | 17 | def __call__(self, text): 18 | result = self.juman.analysis(text) 19 | return [mrph.midasi for mrph in result.mrph_list()] 20 | 21 | 22 | class RougeNCalc: 23 | def __init__(self): 24 | self.rouge = RougeCalculator(stopwords=True, lang="ja") 25 | 26 | def __call__(self, summary, reference): 27 | score = self.rouge.rouge_n(summary, reference, n=1) 28 | return score 29 | 30 | 31 | class Preprocess: 32 | def __init__(self): 33 | self.juman_tokenizer = JumanTokenizer() 34 | self.rouge_calculator = RougeNCalc() 35 | self.bert_tokenizer = BertTokenizer(config['DEFAULT']['vocab_path'], 36 | do_lower_case=False, do_basic_tokenize=False) 37 | self.trim_input = 0 38 | self.trim_clss = 0 39 | 40 | def __call__(self, data_dic, length): 41 | self.src_body = data_dic['body'] 42 | self.src_summary = data_dic['summary'].split('') 43 | self._init_data() 44 | 45 | if self.src_body is '': 46 | raise ValueError('Empty data') 47 | 48 | # step 1. article to lines 49 | self._split_line() 50 | # step 2. pick extractive summary by rouge 51 | self._rougematch() 52 | # step 3. tokenize 53 | self._tokenize() 54 | # step 4. clss process 55 | self._prep_clss() 56 | # step 5. segs process 57 | self._prep_segs() 58 | # step 6. trim length for input 59 | self._set_length(length) 60 | 61 | return {'src': self.tokenid, 62 | 'labels': self.label, 63 | 'segs': self.segs, 64 | 'mask': self.mask, 65 | 'mask_cls': self.mask_cls, 66 | 'clss': self.clss, 67 | 'src_str': self.src_line} 68 | 69 | def _init_data(self): 70 | self.src_line = [] 71 | self.label = [] 72 | self.tokenid = [] 73 | self.token = [] 74 | self.clss = [] 75 | self.segs = [] 76 | self.mask = [] 77 | self.mask_cls = [] 78 | 79 | # step 1. 80 | def _split_line(self): 81 | # regex note: (?!...) Negative Lookahead 82 | # e.g. /foo(?!bar)/ for "foobar foobaz" get "foobaz" only 83 | self.src_line = re.split('。(? 
n: 125 | # If last sentence starts after 512 126 | if self.clss[-1] > 512: 127 | for i, idx in enumerate(self.clss): 128 | if idx > n: 129 | # Index of last [SEP] in length=n 130 | self.trim_input = self.clss[i-1] - 1 131 | # Index of last [CLS] index in clss 132 | self.trim_clss = i - 2 133 | break 134 | # If src longer than 512 but last sentence start < 512 135 | else: 136 | self.trim_input = self.clss[len(self.clss) - 1] - 1 137 | self.trim_clss = len(self.clss) - 2 138 | # Do nothing if length < n 139 | if self.trim_clss*self.trim_input == 0: 140 | return 141 | self.tokenid = self.tokenid[:(self.trim_input+1)] 142 | self.segs = self.segs[:(self.trim_input+1)] 143 | self.clss = self.clss[:(self.trim_clss+1)] 144 | self.label = self.label[:(self.trim_clss+1)] 145 | self.src_line = self.src_line[:(self.trim_clss+1)] 146 | 147 | def __add_mask(self, n): 148 | # from index to len: +1 149 | pad_len = (n - len(self.tokenid)) 150 | self.tokenid = self.tokenid + ([self.bert_tokenizer.vocab['[MASK]']] * pad_len) 151 | self.segs = self.segs + ([int(not self.segs[-1])] * pad_len) 152 | -------------------------------------------------------------------------------- /src/models/optimizers.py: -------------------------------------------------------------------------------- 1 | """ Optimizers class """ 2 | import torch.optim as optim 3 | from torch.nn.utils import clip_grad_norm_ 4 | 5 | 6 | # from onmt.utils import use_gpu 7 | 8 | 9 | def use_gpu(opt): 10 | """ 11 | Creates a boolean if gpu used 12 | """ 13 | return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \ 14 | (hasattr(opt, 'gpu') and opt.gpu > -1) 15 | 16 | 17 | class MultipleOptimizer(object): 18 | """ Implement multiple optimizers needed for sparse adam """ 19 | 20 | def __init__(self, op): 21 | """ ? """ 22 | self.optimizers = op 23 | 24 | def zero_grad(self): 25 | """ ? """ 26 | for op in self.optimizers: 27 | op.zero_grad() 28 | 29 | def step(self): 30 | """ ? """ 31 | for op in self.optimizers: 32 | op.step() 33 | 34 | @property 35 | def state(self): 36 | """ ? """ 37 | return {k: v for op in self.optimizers for k, v in op.state.items()} 38 | 39 | def state_dict(self): 40 | """ ? """ 41 | return [op.state_dict() for op in self.optimizers] 42 | 43 | def load_state_dict(self, state_dicts): 44 | """ ? """ 45 | assert len(state_dicts) == len(self.optimizers) 46 | for i in range(len(state_dicts)): 47 | self.optimizers[i].load_state_dict(state_dicts[i]) 48 | 49 | 50 | class Optimizer(object): 51 | """ 52 | Controller class for optimization. Mostly a thin 53 | wrapper for `optim`, but also useful for implementing 54 | rate scheduling beyond what is currently available. 55 | Also implements necessary methods for training RNNs such 56 | as grad manipulations. 57 | 58 | Args: 59 | method (:obj:`str`): one of [sgd, adagrad, adadelta, adam] 60 | lr (float): learning rate 61 | lr_decay (float, optional): learning rate decay multiplier 62 | start_decay_steps (int, optional): step to start learning rate decay 63 | beta1, beta2 (float, optional): parameters for adam 64 | adagrad_accum (float, optional): initialization parameter for adagrad 65 | decay_method (str, option): custom decay options 66 | warmup_steps (int, option): parameter for `noam` decay 67 | 68 | We use the default parameters for Adam that are suggested by 69 | the original paper https://arxiv.org/pdf/1412.6980.pdf 70 | These values are also used by other established implementations, 71 | e.g. 
https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer 72 | https://keras.io/optimizers/ 73 | Recently there are slightly different values used in the paper 74 | "Attention is all you need" 75 | https://arxiv.org/pdf/1706.03762.pdf, particularly the value beta2=0.98 76 | was used there however, beta2=0.999 is still arguably the more 77 | established value, so we use that here as well 78 | """ 79 | 80 | def __init__(self, method, learning_rate, max_grad_norm, 81 | lr_decay=1, start_decay_steps=None, decay_steps=None, 82 | beta1=0.9, beta2=0.999, 83 | adagrad_accum=0.0, 84 | decay_method=None, 85 | warmup_steps=4000 86 | ): 87 | self.last_ppl = None 88 | self.learning_rate = learning_rate 89 | print(f'LR is: {self.learning_rate}') 90 | self.original_lr = learning_rate 91 | self.max_grad_norm = max_grad_norm 92 | self.method = method 93 | self.lr_decay = lr_decay 94 | self.start_decay_steps = start_decay_steps 95 | self.decay_steps = decay_steps 96 | self.start_decay = False 97 | self._step = 0 98 | self.betas = [beta1, beta2] 99 | self.adagrad_accum = adagrad_accum 100 | self.decay_method = decay_method 101 | self.warmup_steps = warmup_steps 102 | 103 | def set_parameters(self, params): 104 | self.params = [] 105 | self.sparse_params = [] 106 | for k, p in params: 107 | if p.requires_grad: 108 | if self.method != 'sparseadam' or "embed" not in k: 109 | self.params.append(p) 110 | else: 111 | self.sparse_params.append(p) 112 | self.optimizer = optim.Adam(self.params, lr=self.learning_rate, 113 | betas=self.betas, eps=1e-9) 114 | 115 | def _set_rate(self, learning_rate): 116 | self.learning_rate = learning_rate 117 | if self.method != 'sparseadam': 118 | self.optimizer.param_groups[0]['lr'] = self.learning_rate 119 | else: 120 | for op in self.optimizer.optimizers: 121 | op.param_groups[0]['lr'] = self.learning_rate 122 | 123 | def step(self): 124 | """Update the model parameters based on current gradients. 125 | 126 | Optionally, will employ gradient modification or update learning 127 | rate. 128 | """ 129 | self._step += 1 130 | 131 | # Decay method used in tensor2tensor. 132 | if self.decay_method == "noam": 133 | self._set_rate( 134 | self.original_lr * 135 | min(self._step ** (-0.5), 136 | self._step * self.warmup_steps**(-1.5))) 137 | elif self.decay_method == "no": 138 | self.learning_rate = self.learning_rate 139 | # Decay based on start_decay_steps every decay_steps 140 | else: 141 | if ((self.start_decay_steps is not None) and ( 142 | self._step >= self.start_decay_steps)): 143 | self.start_decay = True 144 | if self.start_decay: 145 | if ((self._step - self.start_decay_steps) 146 | % self.decay_steps == 0): 147 | self.learning_rate = self.learning_rate * self.lr_decay 148 | 149 | if self.method != 'sparseadam': 150 | self.optimizer.param_groups[0]['lr'] = self.learning_rate 151 | 152 | if self.max_grad_norm: 153 | clip_grad_norm_(self.params, self.max_grad_norm) 154 | self.optimizer.step() 155 | -------------------------------------------------------------------------------- /src/models/neural.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def gelu(x): 8 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 9 | 10 | 11 | class PositionwiseFeedForward(nn.Module): 12 | """ A two-layer Feed-Forward-Network with residual layer norm. 13 | 14 | Args: 15 | d_model (int): the size of input for the first-layer of the FFN. 
16 |         d_ff (int): the hidden layer size of the second-layer
17 |             of the FFN.
18 |         dropout (float): dropout probability in :math:`[0, 1)`.
19 |     """
20 | 
21 |     def __init__(self, d_model, d_ff, dropout=0.1):
22 |         super(PositionwiseFeedForward, self).__init__()
23 |         self.w_1 = nn.Linear(d_model, d_ff)
24 |         self.w_2 = nn.Linear(d_ff, d_model)
25 |         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
26 |         self.actv = gelu
27 |         self.dropout_1 = nn.Dropout(dropout)
28 |         self.dropout_2 = nn.Dropout(dropout)
29 | 
30 |     def forward(self, x):
31 |         inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x))))
32 |         output = self.dropout_2(self.w_2(inter))
33 |         return output + x
34 | 
35 | 
36 | class MultiHeadedAttention(nn.Module):
37 |     """
38 |     Multi-Head Attention module from
39 |     "Attention is All You Need"
40 |     :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.
41 | 
42 |     Similar to standard `dot` attention but uses
43 |     multiple attention distributions simultaneously
44 |     to select relevant items.
45 | 
46 |     .. mermaid::
47 | 
48 |        graph BT
49 |           A[key]
50 |           B[value]
51 |           C[query]
52 |           O[output]
53 |           subgraph Attn
54 |             D[Attn 1]
55 |             E[Attn 2]
56 |             F[Attn N]
57 |           end
58 |           A --> D
59 |           C --> D
60 |           A --> E
61 |           C --> E
62 |           A --> F
63 |           C --> F
64 |           D --> O
65 |           E --> O
66 |           F --> O
67 |           B --> O
68 | 
69 |     Also includes several additional tricks.
70 | 
71 |     Args:
72 |        head_count (int): number of parallel heads
73 |        model_dim (int): the dimension of keys/values/queries,
74 |            must be divisible by head_count
75 |        dropout (float): dropout parameter
76 |     """
77 | 
78 |     def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True):
79 |         assert model_dim % head_count == 0
80 |         self.dim_per_head = model_dim // head_count
81 |         self.model_dim = model_dim
82 | 
83 |         super(MultiHeadedAttention, self).__init__()
84 |         self.head_count = head_count
85 | 
86 |         self.linear_keys = nn.Linear(model_dim,
87 |                                      head_count * self.dim_per_head)
88 |         self.linear_values = nn.Linear(model_dim,
89 |                                        head_count * self.dim_per_head)
90 |         self.linear_query = nn.Linear(model_dim,
91 |                                       head_count * self.dim_per_head)
92 |         self.softmax = nn.Softmax(dim=-1)
93 |         self.dropout = nn.Dropout(dropout)
94 |         self.use_final_linear = use_final_linear
95 |         if (self.use_final_linear):
96 |             self.final_linear = nn.Linear(model_dim, model_dim)
97 | 
98 |     def forward(self, key, value, query, mask=None,
99 |                 layer_cache=None, type=None, predefined_graph_1=None):
100 |         """
101 |         Compute the context vector and the attention vectors.
102 | 
103 |         Args:
104 |            key (`FloatTensor`): set of `key_len`
105 |                key vectors `[batch, key_len, dim]`
106 |            value (`FloatTensor`): set of `key_len`
107 |                value vectors `[batch, key_len, dim]`
108 |            query (`FloatTensor`): set of `query_len`
109 |                query vectors `[batch, query_len, dim]`
110 |            mask: binary mask indicating which keys have
111 |                non-zero attention `[batch, query_len, key_len]`
112 |         Returns:
113 |            (`FloatTensor`, `FloatTensor`) :
114 | 
115 |            * output context vectors `[batch, query_len, dim]`
116 |            * one of the attention vectors `[batch, query_len, key_len]`
117 |         """
118 | 
119 |         # CHECKS
120 |         # batch, k_len, d = key.size()
121 |         # batch_, k_len_, d_ = value.size()
122 |         # aeq(batch, batch_)
123 |         # aeq(k_len, k_len_)
124 |         # aeq(d, d_)
125 |         # batch_, q_len, d_ = query.size()
126 |         # aeq(batch, batch_)
127 |         # aeq(d, d_)
128 |         # aeq(self.model_dim % 8, 0)
129 |         # if mask is not None:
130 |         #     batch_, q_len_, k_len_ = mask.size()
131 |         #     aeq(batch_, batch)
132 |         #     aeq(k_len_, k_len)
133 |         #     aeq(q_len_ == q_len)
134 |         # END CHECKS
135 | 
136 |         batch_size = key.size(0)
137 |         dim_per_head = self.dim_per_head
138 |         head_count = self.head_count
139 |         key_len = key.size(1)
140 |         query_len = query.size(1)
141 | 
142 |         def shape(x):
143 |             """ projection """
144 |             return x.view(batch_size, -1, head_count, dim_per_head) \
145 |                 .transpose(1, 2)
146 | 
147 |         def unshape(x):
148 |             """ compute context """
149 |             return x.transpose(1, 2).contiguous() \
150 |                 .view(batch_size, -1, head_count * dim_per_head)
151 | 
152 |         # 1) Project key, value, and query.
153 |         if layer_cache is not None:
154 |             if type == "self":
155 |                 query, key, value = self.linear_query(query), \
156 |                                     self.linear_keys(query), \
157 |                                     self.linear_values(query)
158 | 
159 |                 key = shape(key)
160 |                 value = shape(value)
161 | 
162 |                 if layer_cache is not None:
163 |                     device = key.device
164 |                     if layer_cache["self_keys"] is not None:
165 |                         key = torch.cat(
166 |                             (layer_cache["self_keys"].to(device), key),
167 |                             dim=2)
168 |                     if layer_cache["self_values"] is not None:
169 |                         value = torch.cat(
170 |                             (layer_cache["self_values"].to(device), value),
171 |                             dim=2)
172 |                     layer_cache["self_keys"] = key
173 |                     layer_cache["self_values"] = value
174 |             elif type == "context":
175 |                 query = self.linear_query(query)
176 |                 if layer_cache is not None:
177 |                     if layer_cache["memory_keys"] is None:
178 |                         key, value = self.linear_keys(key), \
179 |                                      self.linear_values(value)
180 |                         key = shape(key)
181 |                         value = shape(value)
182 |                     else:
183 |                         key, value = layer_cache["memory_keys"], \
184 |                                      layer_cache["memory_values"]
185 |                     layer_cache["memory_keys"] = key
186 |                     layer_cache["memory_values"] = value
187 |                 else:
188 |                     key, value = self.linear_keys(key), \
189 |                                  self.linear_values(value)
190 |                     key = shape(key)
191 |                     value = shape(value)
192 |         else:
193 |             key = self.linear_keys(key)
194 |             value = self.linear_values(value)
195 |             query = self.linear_query(query)
196 |             key = shape(key)
197 |             value = shape(value)
198 | 
199 |         query = shape(query)
200 | 
201 |         key_len = key.size(2)
202 |         query_len = query.size(2)
203 | 
204 |         # 2) Calculate and scale scores.
205 |         query = query / math.sqrt(dim_per_head)
206 |         scores = torch.matmul(query, key.transpose(2, 3))
207 | 
208 |         if mask is not None:
209 |             mask = mask.unsqueeze(1).expand_as(scores)
210 |             scores = scores.masked_fill(mask, -1e18)
211 | 
212 |         # 3) Apply attention dropout and compute context vectors.
213 | 214 | attn = self.softmax(scores) 215 | 216 | if (not predefined_graph_1 is None): 217 | attn_masked = attn[:, -1] * predefined_graph_1 218 | attn_masked = attn_masked / (torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9) 219 | 220 | attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1) 221 | 222 | drop_attn = self.dropout(attn) 223 | if (self.use_final_linear): 224 | context = unshape(torch.matmul(drop_attn, value)) 225 | output = self.final_linear(context) 226 | return output 227 | else: 228 | context = torch.matmul(drop_attn, value) 229 | return context 230 | 231 | # CHECK 232 | # batch_, q_len_, d_ = output.size() 233 | # aeq(q_len, q_len_) 234 | # aeq(batch, batch_) 235 | # aeq(d, d_) 236 | 237 | # Return one attn 238 | 239 | -------------------------------------------------------------------------------- /test.txt: -------------------------------------------------------------------------------- 1 | I am honored to be with you today at your commencement from one of the finest universities in the world. I never graduated from college. Truth be told, this is the closest I’ve ever gotten to a college graduation. Today I want to tell you three stories from my life. That’s it. No big deal. Just three stories. 2 | 3 | Related to this story 4 | 2005 Stanford Commencement coverage 5 | The first story is about connecting the dots. 6 | 7 | I dropped out of Reed College after the first 6 months, but then stayed around as a drop-in for another 18 months or so before I really quit. So why did I drop out? 8 | 9 | It started before I was born. My biological mother was a young, unwed college graduate student, and she decided to put me up for adoption. She felt very strongly that I should be adopted by college graduates, so everything was all set for me to be adopted at birth by a lawyer and his wife. Except that when I popped out they decided at the last minute that they really wanted a girl. So my parents, who were on a waiting list, got a call in the middle of the night asking: “We have an unexpected baby boy; do you want him?” They said: “Of course.” My biological mother later found out that my mother had never graduated from college and that my father had never graduated from high school. She refused to sign the final adoption papers. She only relented a few months later when my parents promised that I would someday go to college. 10 | 11 | And 17 years later I did go to college. But I naively chose a college that was almost as expensive as Stanford, and all of my working-class parents’ savings were being spent on my college tuition. After six months, I couldn’t see the value in it. I had no idea what I wanted to do with my life and no idea how college was going to help me figure it out. And here I was spending all of the money my parents had saved their entire life. So I decided to drop out and trust that it would all work out OK. It was pretty scary at the time, but looking back it was one of the best decisions I ever made. The minute I dropped out I could stop taking the required classes that didn’t interest me, and begin dropping in on the ones that looked interesting. 12 | 13 | It wasn’t all romantic. I didn’t have a dorm room, so I slept on the floor in friends’ rooms, I returned Coke bottles for the 5¢ deposits to buy food with, and I would walk the 7 miles across town every Sunday night to get one good meal a week at the Hare Krishna temple. I loved it. And much of what I stumbled into by following my curiosity and intuition turned out to be priceless later on. 
Let me give you one example: 14 | 15 | Reed College at that time offered perhaps the best calligraphy instruction in the country. Throughout the campus every poster, every label on every drawer, was beautifully hand calligraphed. Because I had dropped out and didn’t have to take the normal classes, I decided to take a calligraphy class to learn how to do this. I learned about serif and sans serif typefaces, about varying the amount of space between different letter combinations, about what makes great typography great. It was beautiful, historical, artistically subtle in a way that science can’t capture, and I found it fascinating. 16 | 17 | None of this had even a hope of any practical application in my life. But 10 years later, when we were designing the first Macintosh computer, it all came back to me. And we designed it all into the Mac. It was the first computer with beautiful typography. If I had never dropped in on that single course in college, the Mac would have never had multiple typefaces or proportionally spaced fonts. And since Windows just copied the Mac, it’s likely that no personal computer would have them. If I had never dropped out, I would have never dropped in on this calligraphy class, and personal computers might not have the wonderful typography that they do. Of course it was impossible to connect the dots looking forward when I was in college. But it was very, very clear looking backward 10 years later. 18 | 19 | Again, you can’t connect the dots looking forward; you can only connect them looking backward. So you have to trust that the dots will somehow connect in your future. You have to trust in something — your gut, destiny, life, karma, whatever. This approach has never let me down, and it has made all the difference in my life. 20 | 21 | My second story is about love and loss. 22 | 23 | I was lucky — I found what I loved to do early in life. Woz and I started Apple in my parents’ garage when I was 20. We worked hard, and in 10 years Apple had grown from just the two of us in a garage into a $2 billion company with over 4,000 employees. We had just released our finest creation — the Macintosh — a year earlier, and I had just turned 30. And then I got fired. How can you get fired from a company you started? Well, as Apple grew we hired someone who I thought was very talented to run the company with me, and for the first year or so things went well. But then our visions of the future began to diverge and eventually we had a falling out. When we did, our Board of Directors sided with him. So at 30 I was out. And very publicly out. What had been the focus of my entire adult life was gone, and it was devastating. 24 | 25 | I really didn’t know what to do for a few months. I felt that I had let the previous generation of entrepreneurs down — that I had dropped the baton as it was being passed to me. I met with David Packard and Bob Noyce and tried to apologize for screwing up so badly. I was a very public failure, and I even thought about running away from the valley. But something slowly began to dawn on me — I still loved what I did. The turn of events at Apple had not changed that one bit. I had been rejected, but I was still in love. And so I decided to start over. 26 | 27 | I didn’t see it then, but it turned out that getting fired from Apple was the best thing that could have ever happened to me. The heaviness of being successful was replaced by the lightness of being a beginner again, less sure about everything. 
It freed me to enter one of the most creative periods of my life. 28 | 29 | During the next five years, I started a company named NeXT, another company named Pixar, and fell in love with an amazing woman who would become my wife. Pixar went on to create the world’s first computer animated feature film, Toy Story, and is now the most successful animation studio in the world. In a remarkable turn of events, Apple bought NeXT, I returned to Apple, and the technology we developed at NeXT is at the heart of Apple’s current renaissance. And Laurene and I have a wonderful family together. 30 | 31 | I’m pretty sure none of this would have happened if I hadn’t been fired from Apple. It was awful tasting medicine, but I guess the patient needed it. Sometimes life hits you in the head with a brick. Don’t lose faith. I’m convinced that the only thing that kept me going was that I loved what I did. You’ve got to find what you love. And that is as true for your work as it is for your lovers. Your work is going to fill a large part of your life, and the only way to be truly satisfied is to do what you believe is great work. And the only way to do great work is to love what you do. If you haven’t found it yet, keep looking. Don’t settle. As with all matters of the heart, you’ll know when you find it. And, like any great relationship, it just gets better and better as the years roll on. So keep looking until you find it. Don’t settle. 32 | 33 | My third story is about death. 34 | 35 | When I was 17, I read a quote that went something like: “If you live each day as if it was your last, someday you’ll most certainly be right.” It made an impression on me, and since then, for the past 33 years, I have looked in the mirror every morning and asked myself: “If today were the last day of my life, would I want to do what I am about to do today?” And whenever the answer has been “No” for too many days in a row, I know I need to change something. 36 | 37 | Remembering that I’ll be dead soon is the most important tool I’ve ever encountered to help me make the big choices in life. Because almost everything — all external expectations, all pride, all fear of embarrassment or failure — these things just fall away in the face of death, leaving only what is truly important. Remembering that you are going to die is the best way I know to avoid the trap of thinking you have something to lose. You are already naked. There is no reason not to follow your heart. 38 | 39 | About a year ago I was diagnosed with cancer. I had a scan at 7:30 in the morning, and it clearly showed a tumor on my pancreas. I didn’t even know what a pancreas was. The doctors told me this was almost certainly a type of cancer that is incurable, and that I should expect to live no longer than three to six months. My doctor advised me to go home and get my affairs in order, which is doctor’s code for prepare to die. It means to try to tell your kids everything you thought you’d have the next 10 years to tell them in just a few months. It means to make sure everything is buttoned up so that it will be as easy as possible for your family. It means to say your goodbyes. 40 | 41 | I lived with that diagnosis all day. Later that evening I had a biopsy, where they stuck an endoscope down my throat, through my stomach and into my intestines, put a needle into my pancreas and got a few cells from the tumor. 
I was sedated, but my wife, who was there, told me that when they viewed the cells under a microscope the doctors started crying because it turned out to be a very rare form of pancreatic cancer that is curable with surgery. I had the surgery and I’m fine now. 42 | 43 | This was the closest I’ve been to facing death, and I hope it’s the closest I get for a few more decades. Having lived through it, I can now say this to you with a bit more certainty than when death was a useful but purely intellectual concept: 44 | 45 | No one wants to die. Even people who want to go to heaven don’t want to die to get there. And yet death is the destination we all share. No one has ever escaped it. And that is as it should be, because Death is very likely the single best invention of Life. It is Life’s change agent. It clears out the old to make way for the new. Right now the new is you, but someday not too long from now, you will gradually become the old and be cleared away. Sorry to be so dramatic, but it is quite true. 46 | 47 | Your time is limited, so don’t waste it living someone else’s life. Don’t be trapped by dogma — which is living with the results of other people’s thinking. Don’t let the noise of others’ opinions drown out your own inner voice. And most important, have the courage to follow your heart and intuition. They somehow already know what you truly want to become. Everything else is secondary. 48 | 49 | When I was young, there was an amazing publication called The Whole Earth Catalog, which was one of the bibles of my generation. It was created by a fellow named Stewart Brand not far from here in Menlo Park, and he brought it to life with his poetic touch. This was in the late 1960s, before personal computers and desktop publishing, so it was all made with typewriters, scissors and Polaroid cameras. It was sort of like Google in paperback form, 35 years before Google came along: It was idealistic, and overflowing with neat tools and great notions. 50 | 51 | Stewart and his team put out several issues of The Whole Earth Catalog, and then when it had run its course, they put out a final issue. It was the mid-1970s, and I was your age. On the back cover of their final issue was a photograph of an early morning country road, the kind you might find yourself hitchhiking on if you were so adventurous. Beneath it were the words: “Stay Hungry. Stay Foolish.” It was their farewell message as they signed off. Stay Hungry. Stay Foolish. And I have always wished that for myself. And now, as you graduate to begin anew, I wish that for you. 52 | 53 | Stay Hungry. Stay Foolish. --------------------------------------------------------------------------------
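A quick illustration of the `noam` decay applied by `Optimizer.step()` in src/models/optimizers.py when decay_method == "noam": the base rate is scaled by min(step**-0.5, step * warmup_steps**-1.5), i.e. a linear warm-up for `warmup_steps` updates followed by inverse-square-root decay. The sketch below is a standalone illustration rather than a file from this repository; the helper name `noam_lr` and the example base rate are illustrative only.

def noam_lr(original_lr, step, warmup_steps=4000):
    # Same scaling as Optimizer.step() under decay_method == "noam":
    # linear warm-up up to `warmup_steps`, then roughly step**-0.5 decay.
    return original_lr * min(step ** (-0.5), step * warmup_steps ** (-1.5))


# Example: with the default warmup_steps=4000 the rate peaks at step 4000.
for step in (100, 1000, 4000, 10000, 40000):
    print(step, noam_lr(2e-3, step))

Note that under this schedule the start_decay_steps, decay_steps, and lr_decay options are not consulted; they only affect the step-based decay branch of Optimizer.step().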