├── coop ├── __init__.py ├── models │ ├── __init__.py │ ├── util.py │ ├── base.py │ ├── bimeanvae.py │ └── optimus.py ├── reader.py ├── util.py ├── tokenizer.py ├── vae.py └── search.py ├── img └── overview.png ├── data └── sentencepiece │ ├── amzn.model │ └── yelp.model ├── requirements.txt ├── config ├── optimus │ ├── amzn.jsonnet │ └── yelp.jsonnet ├── bimeanvae │ ├── amzn.jsonnet │ └── yelp.jsonnet └── utils.libsonnet ├── setup.py ├── LICENSE ├── scripts ├── spm_train.py ├── get_summ.py └── preprocess.py ├── evaluate.py ├── .gitignore ├── train.py └── README.md /coop/__init__.py: -------------------------------------------------------------------------------- 1 | from . import util 2 | from .vae import VAE 3 | -------------------------------------------------------------------------------- /img/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megagonlabs/coop/HEAD/img/overview.png -------------------------------------------------------------------------------- /data/sentencepiece/amzn.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megagonlabs/coop/HEAD/data/sentencepiece/amzn.model -------------------------------------------------------------------------------- /data/sentencepiece/yelp.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/megagonlabs/coop/HEAD/data/sentencepiece/yelp.model -------------------------------------------------------------------------------- /coop/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Model 2 | from .bimeanvae import BiMeanVAE 3 | from .optimus import Optimus 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch~=1.7.0 2 | transformers~=3.4.0 3 | sentencepiece==0.1.94 4 | click 5 | py-rouge 6 | tqdm 7 | pandas 8 | jsonnet 9 | allennlp==1.2.0 10 | tensorboard 11 | requests 12 | nltk 13 | huggingface_hub -------------------------------------------------------------------------------- /coop/models/util.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | import torch 4 | from torch.distributions import Normal 5 | 6 | 7 | class Losses(NamedTuple): 8 | nll: torch.Tensor 9 | zkl: torch.Tensor 10 | zkl_real: torch.Tensor 11 | 12 | 13 | class VAEOut(NamedTuple): 14 | q: Normal 15 | generated: torch.Tensor = None 16 | -------------------------------------------------------------------------------- /config/optimus/amzn.jsonnet: -------------------------------------------------------------------------------- 1 | local lib = import '../utils.libsonnet'; 2 | local data_type = "amzn"; 3 | local latent_dim = 512; 4 | local free_bit = 2.0; 5 | local num_steps = 500000; 6 | local checkout_step = 20000; 7 | local batch_size = 4; 8 | local lr = 1e-5; 9 | 10 | { 11 | "data_dir": "./data/%s" % data_type, 12 | "model": lib.Optimus(latent_dim, free_bit), 13 | "trainer": lib.VAETrainer(num_steps, checkout_step, batch_size, lr) 14 | } 15 | -------------------------------------------------------------------------------- /config/optimus/yelp.jsonnet: -------------------------------------------------------------------------------- 1 | local lib = import 
'../utils.libsonnet'; 2 | local data_type = "yelp"; 3 | local latent_dim = 512; 4 | local free_bit = 2.0; 5 | local num_steps = 500000; 6 | local checkout_step = 20000; 7 | local batch_size = 4; 8 | local lr = 1e-5; 9 | 10 | { 11 | "data_dir": "./data/%s" % data_type, 12 | "model": lib.Optimus(latent_dim, free_bit), 13 | "trainer": lib.VAETrainer(num_steps, checkout_step, batch_size, lr) 14 | } 15 | -------------------------------------------------------------------------------- /config/bimeanvae/amzn.jsonnet: -------------------------------------------------------------------------------- 1 | local lib = import '../utils.libsonnet'; 2 | local data_type = "amzn"; 3 | local latent_dim = 512; 4 | local free_bit = 0.25; 5 | local num_steps = 100000; 6 | local checkout_step = 1000; 7 | local batch_size = 256; 8 | local lr = 1e-3; 9 | 10 | { 11 | "data_dir": "./data/%s" % data_type, 12 | "spm_path": "./data/sentencepiece/%s.model" % data_type, 13 | "model": lib.BiMeanVAE(latent_dim, free_bit), 14 | "trainer": lib.VAETrainer(num_steps, checkout_step, batch_size, lr) 15 | } 16 | -------------------------------------------------------------------------------- /config/bimeanvae/yelp.jsonnet: -------------------------------------------------------------------------------- 1 | local lib = import '../utils.libsonnet'; 2 | local data_type = "yelp"; 3 | local latent_dim = 512; 4 | local free_bit = 0.25; 5 | local num_steps = 100000; 6 | local checkout_step = 1000; 7 | local batch_size = 256; 8 | local lr = 1e-3; 9 | 10 | { 11 | "data_dir": "./data/%s" % data_type, 12 | "spm_path": "./data/sentencepiece/%s.model" % data_type, 13 | "model": lib.BiMeanVAE(latent_dim, free_bit), 14 | "trainer": lib.VAETrainer(num_steps, checkout_step, batch_size, lr) 15 | } 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("requirements.txt") as f: 4 | required = f.read().splitlines() 5 | 6 | setup( 7 | name="coop", 8 | version="0.0.1", 9 | description="Convex Aggregation for Opinion Summarization (Iso et al; Findings of EMNLP 2021)", 10 | long_description=open("README.md", encoding="utf-8").read(), 11 | long_description_content_type="text/markdown", 12 | author="Hayate Iso", 13 | author_email="hayate@megagon.ai", 14 | url="https://github.com/megagonlabs/coop", 15 | packages=find_packages(), 16 | license="BSD", 17 | install_requires=required, 18 | classifiers=[ 19 | "Programming Language :: Python :: 3", 20 | "License :: OSI Approved :: BSD License", 21 | ], 22 | python_requires='>=3.7', 23 | ) 24 | -------------------------------------------------------------------------------- /config/utils.libsonnet: -------------------------------------------------------------------------------- 1 | { 2 | BiMeanVAE(latent_dim, free_bit):: 3 | local embedding_dim = 256; 4 | local hidden_size = 512; 5 | local num_layers = 1; 6 | { 7 | "type": "bimeanvae", 8 | "embedding_dim": embedding_dim, 9 | "hidden_size": hidden_size, 10 | "latent_dim": latent_dim, 11 | "num_layers": num_layers, 12 | "free_bit": free_bit 13 | }, 14 | 15 | Optimus(latent_dim, free_bit):: 16 | { 17 | "type": "optimus", 18 | "latent_dim": latent_dim, 19 | "free_bit": free_bit, 20 | }, 21 | 22 | VAETrainer(num_steps, checkout_step, batch_size, lr):: 23 | { 24 | "num_steps": num_steps, 25 | "checkout_step": checkout_step, 26 | "batch_size": batch_size, 27 | "lr": lr 28 | } 29 | } 
-------------------------------------------------------------------------------- /coop/models/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Model(nn.Module): 6 | 7 | def __init__(self, 8 | hidden_size: int, 9 | latent_dim: int): 10 | super().__init__() 11 | self.hidden_size = hidden_size 12 | self.latent_dim = latent_dim 13 | 14 | def forward(self, 15 | src: torch.Tensor, 16 | tgt: torch.Tensor = None, 17 | do_generate: torch.Tensor = False, 18 | **kwargs): 19 | raise NotImplementedError() 20 | 21 | @torch.no_grad() 22 | def generate(self, 23 | z: torch.Tensor, 24 | num_beams: int = 4, 25 | max_tokens: int = 256): 26 | raise NotImplementedError() 27 | 28 | @staticmethod 29 | def klw(step: int, 30 | interval: int, 31 | r: float = 0.8, 32 | t: float = 0.0, 33 | s: int = 10000): 34 | raise NotImplementedError() 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 Megagon Labs All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 7 | Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 8 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
9 | -------------------------------------------------------------------------------- /scripts/spm_train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from multiprocessing import cpu_count 4 | from time import time 5 | 6 | import click 7 | import sentencepiece as spm 8 | from tqdm import tqdm 9 | 10 | PAD, UNK = "@pad@", "@@UNKNOWN@@" 11 | START, END = "@start@", "@end@" 12 | 13 | 14 | @click.command() 15 | @click.argument("train_file", type=click.Path(exists=True)) 16 | @click.argument("model_prefix", type=click.STRING) 17 | def spm_train(train_file, 18 | model_prefix): 19 | tmp_path = str(time()) 20 | with open(tmp_path, "w") as file: 21 | for x in tqdm(map(json.loads, open(train_file)), desc="Prep"): 22 | print(x["text"], file=file) 23 | 24 | spm.SentencePieceTrainer.Train(input=tmp_path, 25 | model_prefix=model_prefix, 26 | model_type="bpe", 27 | vocab_size=32000, 28 | max_sentence_length=8192, 29 | character_coverage=1., 30 | num_threads=cpu_count(), 31 | bos_piece=START, 32 | eos_piece=END, 33 | unk_piece=UNK, 34 | pad_piece=PAD, 35 | pad_id=0, 36 | bos_id=1, 37 | eos_id=2, 38 | unk_id=3) 39 | 40 | os.remove(tmp_path) 41 | 42 | 43 | if __name__ == '__main__': 44 | spm_train() 45 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import click 5 | import pandas as pd 6 | import rouge 7 | import torch 8 | 9 | from coop.util import load_tokenizer, load_data, build_model 10 | 11 | 12 | def evaluate(model, data, num_beams=4, debug=False): 13 | evaluator = rouge.Rouge(metrics=["rouge-n", "rouge-l"], max_n=2, limit_length=False, apply_avg=True, 14 | stemming=True, ensure_compatibility=True) 15 | hyp, ref = [], [] 16 | for x in data: 17 | out = model(x["src"], do_generate=True) 18 | summary_avg = model.generate(out.q.mean.mean(dim=0, keepdim=True), num_beams=num_beams) 19 | summary_avg = data.tokenizer.decode(summary_avg) 20 | hyp.extend(summary_avg) 21 | ref.append(x["summary"]) 22 | 23 | sums = evaluator.get_scores(hyp, ref).items() 24 | scores = {"_".join((metric, "sum", k)): v for metric, vs in sums for k, v in vs.items()} 25 | 26 | if debug: 27 | print("Generated examples:") 28 | print("\n".join(hyp[:10])) 29 | 30 | return scores 31 | 32 | 33 | @click.command() 34 | @click.argument("log_dir", type=click.Path(exists=True)) 35 | @click.option("--debug", is_flag=True) 36 | def main(log_dir, debug): 37 | log_dir = Path(log_dir) 38 | checkpoint = log_dir / "best.th" 39 | 40 | config = json.load(open(log_dir / "config.json")) 41 | src_tokenizer, tgt_tokenizer = load_tokenizer(config) 42 | _, dev, test = load_data(config, src_tokenizer, tgt_tokenizer) 43 | model = build_model(config).eval() 44 | model.load_state_dict(torch.load(checkpoint, map_location=lambda storage, loc: storage)) 45 | 46 | if torch.cuda.is_available(): 47 | model.cuda() 48 | scores = {} 49 | for data_type in ("dev", "test"): 50 | data = eval(data_type) 51 | scores[data_type] = evaluate(model, data, debug=debug) 52 | 53 | df = pd.DataFrame(scores) 54 | df.sort_index(inplace=True) 55 | print(df) 56 | 57 | 58 | if __name__ == '__main__': 59 | main() 60 | -------------------------------------------------------------------------------- /scripts/get_summ.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import io 3 | import json
4 | import os 5 | import shutil 6 | from pathlib import Path 7 | 8 | import click 9 | import requests 10 | 11 | from preprocess import strip_text 12 | 13 | 14 | def amzn(output_dir): 15 | output_dir = Path(output_dir) 16 | output_dir.mkdir(exist_ok=True) 17 | os.system("git clone https://github.com/ixlan/Copycat-abstractive-opinion-summarizer.git") 18 | 19 | for sp in ("dev", "test"): 20 | ins = [] 21 | for x in csv.DictReader(open(f"Copycat-abstractive-opinion-summarizer/gold_summs/{sp}.csv"), 22 | dialect='excel-tab'): 23 | ins.append({ 24 | "reviews": [strip_text(x[f"rev{i + 1}"]) for i in range(8)], 25 | "summary": [x[f"summ{i + 1}"] for i in range(3)], 26 | "category": x["cat"], 27 | "prod_id": x["prod_id"] 28 | }) 29 | 30 | json.dump(ins, open(output_dir / f"{sp}.json", "w")) 31 | shutil.rmtree("Copycat-abstractive-opinion-summarizer") 32 | 33 | 34 | def yelp(output_dir): 35 | output_dir = Path(output_dir) 36 | output_dir.mkdir(exist_ok=True) 37 | file = requests.get("https://s3.us-east-2.amazonaws.com/unsup-sum/summaries_0-200_cleaned.csv").content.decode() 38 | ins = [] 39 | for x in csv.DictReader(io.StringIO(file)): 40 | ins.append({ 41 | "reviews": [strip_text(x[f"Input.original_review_{i}"]) for i in range(8)], 42 | "summary": [x["Answer.summary"]], 43 | "review_ids": [x[f"Input.original_review_{i}_id"] for i in range(8)], 44 | "business_id": x["Input.business_id"] 45 | }) 46 | json.dump(ins[:100], open(output_dir / "dev.json", "w")) 47 | json.dump(ins[100:], open(output_dir / "test.json", "w")) 48 | 49 | 50 | @click.command() 51 | @click.argument("data_type", type=click.Choice(("yelp", "amzn")), ) 52 | @click.argument("data_dir", type=click.Path() ) 53 | def main(data_type, data_dir): 54 | if data_type == "yelp": 55 | yelp(data_dir) 56 | else: 57 | amzn(data_dir) 58 | 59 | 60 | if __name__ == '__main__': 61 | main() 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /scripts/preprocess.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | import re 4 | import string 5 | import unicodedata 6 | from collections import defaultdict 7 | from pathlib import Path 8 | 9 | import click 10 | from sentencepiece import SentencePieceProcessor 11 | from tqdm import tqdm 12 | 13 | PRINTABLE = set(string.printable) 14 | 15 | MIN_REV_LEN = 4 16 | MAX_REV_LEN = 128 17 | 18 | 19 | def strip_text(s: str) -> str: 20 | # https://stackoverflow.com/a/518232/2809427 21 | # https://stackoverflow.com/a/8689826 22 | return re.sub(" +", " ", "".join(c for c in unicodedata.normalize("NFD", s) 23 | if unicodedata.category(c) != "Mn" and c in PRINTABLE).replace("\n", " ")) 24 | 25 | 26 | def yelp(file_path: str, spm: SentencePieceProcessor = None): 27 | d = defaultdict(list) 28 | for ins in tqdm(map(json.loads, open(file_path)), desc="Yelp"): 29 | rating = int(ins["stars"]) 30 | text = strip_text(ins["text"]) 31 | x = {"business_id": ins["business_id"], 32 | "review_id": ins["review_id"], 33 | "rating": rating, 34 | "text": text} 35 | if spm is not None: 36 | piece = spm.Encode(text) 37 | if MIN_REV_LEN <= len(piece) <= MAX_REV_LEN: 38 | x["piece"] = piece 39 | d[ins["business_id"]].append(x) 40 | else: 41 | d[ins["business_id"]].append(x) 42 | 43 | for reviews in d.values(): 44 | if len(reviews) > 10: 45 | yield from reviews 46 | 47 | 48 | def amzn(dir_path: str, spm: SentencePieceProcessor = None): 49 | p = tqdm() 50 | obs = set() 51 | for fp in Path(dir_path).glob("*.gz"): 52 | p.set_description(desc=fp.stem) 53 | d = defaultdict(list) 54 | for ins 
in filter(lambda x: x["asin"] not in obs, map(json.loads, gzip.open(fp, "rb"))): 55 | text = strip_text(ins["reviewText"]) 56 | rating = int(float(ins["overall"])) 57 | review_id = ins["reviewerID"] 58 | x = {"business_id": ins["asin"], 59 | "review_id": review_id, 60 | "rating": rating, 61 | "text": text} 62 | if spm is not None: 63 | piece = spm.Encode(text) 64 | if MIN_REV_LEN <= len(piece) <= MAX_REV_LEN: 65 | x["piece"] = piece 66 | d[ins["asin"]].append(x) 67 | else: 68 | d[ins["asin"]].append(x) 69 | p.update() 70 | 71 | for reviews in d.values(): 72 | if len(reviews) > 10: 73 | yield from reviews 74 | obs.update(set(d)) 75 | p.close() 76 | 77 | 78 | @click.command() 79 | @click.argument("data_type", type=click.Choice(("yelp", "amzn")), ) 80 | @click.argument("raw_file", type=click.Path(exists=True)) 81 | def main(data_type, raw_file): 82 | spm_file = Path(f"./data/sentencepiece/{data_type}.model") 83 | if spm_file.exists(): 84 | spm = SentencePieceProcessor() 85 | spm.Load(str(spm_file)) 86 | else: 87 | spm = None 88 | 89 | if data_type == "yelp": 90 | parser = yelp 91 | elif data_type == "amzn": 92 | parser = amzn 93 | else: 94 | raise KeyError() 95 | 96 | for x in parser(raw_file, spm): 97 | print(json.dumps(x)) 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | -------------------------------------------------------------------------------- /coop/reader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import linecache 3 | from pathlib import Path 4 | from typing import List 5 | 6 | import torch.tensor 7 | from torch.nn.utils.rnn import pad_sequence 8 | from torch.utils.data import Dataset 9 | 10 | from .tokenizer import Tokenizer 11 | 12 | 13 | def get_length(fp): 14 | i = 0 15 | for i, _ in enumerate(open(fp), 1): 16 | pass 17 | return i 18 | 19 | 20 | class ReviewDataset(Dataset): 21 | def __init__(self, 22 | data_file, 23 | tokenizer: Tokenizer): 24 | super().__init__() 25 | assert Path(data_file).exists(), f"Data directory, {data_file}, does not exist." 26 | self.data_file = str(data_file) 27 | self.tokenizer = tokenizer 28 | 29 | self.len = get_length(data_file) 30 | 31 | def __getitem__(self, idx): 32 | ins = json.loads(linecache.getline(self.data_file, idx + 1)) 33 | x = [self.tokenizer.bos_id] + ins["piece"] + [self.tokenizer.eos_id] 34 | return torch.tensor(x), ins["text"] 35 | 36 | def __len__(self): 37 | return self.len 38 | 39 | def collate_fn(self, data: List[torch.Tensor]): 40 | tensor, reviews = zip(*data) 41 | tensor = pad_sequence(tensor, batch_first=True, padding_value=self.tokenizer.pad_id) 42 | if torch.cuda.is_available(): 43 | tensor = tensor.cuda() 44 | return {"src": tensor, "tgt": tensor, "reviews": reviews} 45 | 46 | 47 | class OptimusDataset(Dataset): 48 | def __init__(self, 49 | data_file, 50 | src_tokenizer: Tokenizer, 51 | tgt_tokenizer: Tokenizer): 52 | super().__init__() 53 | assert Path(data_file).exists(), f"Data directory, {data_file}, does not exist." 
54 | self.data_file = str(data_file) 55 | self.pad, self.bos, self.eos = '<PAD>', '<BOS>', '<EOS>' 56 | 57 | self.src_tokenizer = src_tokenizer 58 | self.tgt_tokenizer = self.tokenizer = tgt_tokenizer 59 | 60 | self.len = get_length(data_file) 61 | 62 | def __len__(self): 63 | return self.len 64 | 65 | def __getitem__(self, idx): 66 | x = json.loads(linecache.getline(self.data_file, idx + 1))["text"] 67 | return x 68 | 69 | def collate_fn(self, data: List[str]): 70 | src = self.src_tokenizer(data,) 71 | tgt = self.tgt_tokenizer(data) 72 | return {"src": src, "tgt": tgt, "reviews": data} 73 | 74 | 75 | class ReviewTest(Dataset): 76 | def __init__(self, 77 | data_file, 78 | tokenizer: Tokenizer): 79 | super().__init__() 80 | assert Path(data_file).exists(), f"Data directory, {data_file}, does not exist." 81 | self.data = json.load(open(data_file)) 82 | self.tokenizer = tokenizer 83 | 84 | self.len = len(self.data) 85 | 86 | def __getitem__(self, idx): 87 | ins = self.data[idx] 88 | reviews = ins["reviews"] 89 | summary = ins["summary"] 90 | tensor = self.tokenizer(ins["reviews"]) 91 | return {"src": tensor, "reviews": reviews, "summary": summary} 92 | 93 | def __len__(self): 94 | return self.len 95 | 96 | 97 | class OptimusTest(Dataset): 98 | def __init__(self, 99 | data_file, 100 | src_tokenizer: Tokenizer, 101 | tgt_tokenizer: Tokenizer): 102 | super().__init__() 103 | assert Path(data_file).exists(), f"Data directory, {data_file}, does not exist." 104 | self.data = json.load(open(data_file)) 105 | self.pad, self.bos, self.eos = '<PAD>', '<BOS>', '<EOS>' 106 | 107 | self.src_tokenizer = src_tokenizer 108 | self.tgt_tokenizer = self.tokenizer = tgt_tokenizer 109 | 110 | self.len = len(self.data) 111 | 112 | def __len__(self): 113 | return self.len 114 | 115 | def __getitem__(self, idx): 116 | ins = self.data[idx] 117 | src = self.src_tokenizer(ins["reviews"]) 118 | return {"src": src, "reviews": ins["reviews"], "summary": ins["summary"]} 119 | -------------------------------------------------------------------------------- /coop/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from itertools import chain, combinations 4 | from pathlib import Path 5 | 6 | import rouge 7 | 8 | from .models import BiMeanVAE, Optimus 9 | from .reader import ReviewDataset, ReviewTest, OptimusDataset, OptimusTest 10 | from .tokenizer import Tokenizer, SpmTokenizer, BERTTokenizer, GPT2Tokenizer 11 | 12 | R1 = rouge.Rouge(metrics=["rouge-n"], max_n=1, limit_length=False, apply_avg=True, stemming=True, 13 | ensure_compatibility=True) 14 | BAD_WORDS = ["I", "i", "My", "my", "Me", "me", "We", "we", "Our", "our", "us"] 15 | 16 | 17 | def powerset(size): 18 | # https://docs.python.org/3/library/itertools.html#itertools-recipes 19 | return list(map(list, chain.from_iterable(combinations(range(size), r + 1) for r in range(size)))) 20 | 21 | 22 | def get_logger(log_dir: Path): 23 | fmt = "'%(asctime)s - %(levelname)s - %(name)s - %(message)s'" 24 | datefmt = '%m/%d/%Y %H:%M:%S' 25 | logging.basicConfig(format=fmt, datefmt=datefmt, level=logging.INFO) 26 | logger = logging.getLogger(__name__) 27 | logger.setLevel(logging.INFO) 28 | 29 | stream = logging.StreamHandler(sys.stderr) 30 | stream.setLevel(logging.INFO) 31 | stream.setFormatter(logging.Formatter(fmt=fmt, datefmt=datefmt)) 32 | 33 | file = logging.FileHandler(log_dir / "logging") 34 | file.setLevel(logging.INFO) 35 | file.setFormatter(logging.Formatter(fmt=fmt, datefmt=datefmt)) 36 | 37 | logger.addHandler(file) 38 |
39 | return logger 40 | 41 | 42 | def load_tokenizer(config: dict): 43 | model_type = config["model"]["type"] 44 | if model_type == "optimus": 45 | src_tokenizer = BERTTokenizer(config.get("device")) 46 | tgt_tokenizer = GPT2Tokenizer(config.get("device")) 47 | else: 48 | src_tokenizer = tgt_tokenizer = SpmTokenizer(config["spm_path"], config.get("device")) 49 | config["model"]["vocab_size"] = src_tokenizer.vocab_size 50 | config["model"].update({"pad_id": tgt_tokenizer.pad_id, 51 | "bos_id": tgt_tokenizer.bos_id, 52 | "eos_id": tgt_tokenizer.eos_id}) 53 | return src_tokenizer, tgt_tokenizer 54 | 55 | 56 | def load_data(config: dict, src_tokenizer: Tokenizer, tgt_tokenizer: Tokenizer = None): 57 | model_type = config["model"]["type"] 58 | data_dir = Path(config["data_dir"]) 59 | if model_type == "optimus": 60 | train = OptimusDataset(data_dir / "train.jsonl", src_tokenizer=src_tokenizer, tgt_tokenizer=tgt_tokenizer) 61 | dev = OptimusTest(data_dir / "dev.json", src_tokenizer=src_tokenizer, tgt_tokenizer=tgt_tokenizer) 62 | test = OptimusTest(data_dir / "test.json", src_tokenizer=src_tokenizer, tgt_tokenizer=tgt_tokenizer) 63 | else: 64 | train = ReviewDataset(data_dir / "train.jsonl", tokenizer=src_tokenizer) 65 | dev = ReviewTest(data_dir / "dev.json", tokenizer=src_tokenizer) 66 | test = ReviewTest(data_dir / "test.json", tokenizer=src_tokenizer) 67 | 68 | return train, dev, test 69 | 70 | 71 | def build_model(config: dict): 72 | model_type = config["model"].pop("type").lower() 73 | if model_type == "bimeanvae": 74 | cls = BiMeanVAE 75 | elif model_type == "optimus": 76 | cls = Optimus 77 | else: 78 | raise ValueError(f"Model type {model_type} is not available.") 79 | return cls(**config.pop("model")) 80 | 81 | 82 | def avg(ins): 83 | return [len(x["selected"]) for x in ins] 84 | 85 | 86 | def overlap(ins): 87 | scores = [] 88 | for i in ins: 89 | s = R1.get_scores(i["predicted"], i["reviews"]) 90 | scores.append(s["rouge-1"]["f"]) 91 | return scores 92 | 93 | 94 | def oracle(ins): 95 | scores = [] 96 | for i in ins: 97 | s = R1.get_scores([i["predicted"]], [i["summary"]]) 98 | scores.append(s[f"rouge-1"]["f"]) 99 | return scores 100 | 101 | 102 | def input_output_overlap(inputs, output): 103 | r1 = rouge.Rouge(metrics=["rouge-n"], max_n=1, limit_length=False,) 104 | return r1.get_scores(output, inputs)["rouge-1"]["f"] 105 | -------------------------------------------------------------------------------- /coop/tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | from sentencepiece import SentencePieceProcessor 3 | 4 | import torch 5 | from torch.nn.utils.rnn import pad_sequence 6 | 7 | from transformers import BertTokenizerFast, GPT2TokenizerFast 8 | 9 | 10 | class Tokenizer: 11 | def __init__(self, device: str = None): 12 | if device is None: 13 | device = "cuda" if torch.cuda.is_available() else "cpu" 14 | self.device = torch.device(device) 15 | 16 | def __call__(self, reviews: Union[List[str], str]): 17 | raise NotImplementedError 18 | 19 | @property 20 | def bos_id(self): 21 | raise NotImplementedError 22 | 23 | @property 24 | def eos_id(self): 25 | raise NotImplementedError 26 | 27 | @property 28 | def pad_id(self): 29 | raise NotImplementedError 30 | 31 | @property 32 | def vocab_size(self): 33 | raise NotImplementedError 34 | 35 | def decode(self, 36 | ids: Union[List[List[int]], torch.Tensor]): 37 | raise NotImplementedError 38 | 39 | 40 | class SpmTokenizer(Tokenizer): 41 | def __init__(self, 
spm_path: str, device: str = None): 42 | super().__init__(device) 43 | self.spm = SentencePieceProcessor() 44 | self.spm.Load(spm_path) 45 | 46 | def __call__(self, reviews: Union[List[str], str]): 47 | if isinstance(reviews, str): 48 | reviews = [reviews] 49 | tensor = [torch.tensor([self.bos_id] + self.spm.Encode(r) + [self.eos_id]) for r in reviews] 50 | tensor = pad_sequence(tensor, batch_first=True, padding_value=self.pad_id) 51 | tensor = tensor.to(self.device) 52 | return tensor 53 | 54 | @property 55 | def bos_id(self): 56 | return self.spm.bos_id() 57 | 58 | @property 59 | def eos_id(self): 60 | return self.spm.eos_id() 61 | 62 | @property 63 | def pad_id(self): 64 | return self.spm.pad_id() 65 | 66 | @property 67 | def vocab_size(self): 68 | return self.spm.GetPieceSize() 69 | 70 | def get_ids(self, pieces: List[str], no_prefix: bool = False): 71 | if not no_prefix: 72 | pieces = ["▁" + p for p in pieces] 73 | return [self.spm.PieceToId(p) for p in pieces] 74 | 75 | def decode(self, 76 | ids: Union[List[List[int]], torch.Tensor]): 77 | if isinstance(ids, torch.Tensor): 78 | ids = ids.tolist() 79 | return [self.spm.DecodeIdsWithCheck(x) for x in ids] 80 | 81 | 82 | class BERTTokenizer(Tokenizer): 83 | def __init__(self, device: str = None): 84 | super().__init__(device) 85 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased") 86 | 87 | def __call__(self, reviews: Union[List[str], str]): 88 | if isinstance(reviews, str): 89 | reviews = [reviews] 90 | src = self.tokenizer(reviews, padding=True, add_special_tokens=True, truncation=True, max_length=256, 91 | return_tensors="pt", ) 92 | src = {k: v.to(self.device) for k, v in src.items()} 93 | return src 94 | 95 | @property 96 | def bos_id(self): 97 | return self.tokenizer.bos_token_id 98 | 99 | @property 100 | def eos_id(self): 101 | return self.tokenizer.eos_token_id 102 | 103 | @property 104 | def pad_id(self): 105 | return self.tokenizer.pad_token_id 106 | 107 | 108 | class GPT2Tokenizer(Tokenizer): 109 | def __init__(self, device: str = None): 110 | super().__init__(device) 111 | self.pad, self.bos, self.eos = '<PAD>', '<BOS>', '<EOS>' 112 | sp = {'pad_token': self.pad, 'bos_token': self.bos, 'eos_token': self.eos} 113 | self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") 114 | self.tokenizer.add_special_tokens(sp) 115 | 116 | def __call__(self, reviews: Union[List[str], str]): 117 | if isinstance(reviews, str): 118 | reviews = [reviews] 119 | tgt = self.tokenizer([" ".join((self.bos, x, self.eos)) for x in reviews], 120 | padding=True, truncation=True, max_length=256, return_tensors="pt") 121 | tgt["labels"] = tgt["input_ids"] 122 | del tgt["attention_mask"] 123 | tgt = {k: v.to(self.device) for k, v in tgt.items()} 124 | return tgt 125 | 126 | @property 127 | def bos_id(self): 128 | return self.tokenizer.bos_token_id 129 | 130 | @property 131 | def eos_id(self): 132 | return self.tokenizer.eos_token_id 133 | 134 | @property 135 | def pad_id(self): 136 | return self.tokenizer.pad_token_id 137 | 138 | def get_ids(self, pieces: List[str]): 139 | ids = self.tokenizer.convert_tokens_to_ids(pieces) 140 | ids += self.tokenizer.convert_tokens_to_ids(["Ġ" + w for w in pieces]) 141 | return [[w] for w in ids] 142 | 143 | def decode(self, 144 | ids: Union[List[List[int]], torch.Tensor]): 145 | if isinstance(ids, torch.Tensor): 146 | ids = ids.tolist() 147 | return [x.strip() for x in self.tokenizer.batch_decode(ids, skip_special_tokens=True,)] 148 | --------------------------------------------------------------------------------
/coop/vae.py: -------------------------------------------------------------------------------- 1 | import json 2 | import shutil 3 | import tarfile 4 | import tempfile 5 | from pathlib import Path 6 | from typing import List, Union 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from .util import load_tokenizer, build_model 12 | 13 | AVAILABLE_MODELS = {"megagonlabs/bimeanvae-yelp", 14 | "megagonlabs/bimeanvae-amzn", 15 | "megagonlabs/optimus-yelp", 16 | "megagonlabs/optimus-amzn"} 17 | 18 | 19 | class VAE(nn.Module): 20 | def __init__(self, model_name_or_path: str, device: str = None): 21 | super().__init__() 22 | if device is None: 23 | if torch.cuda.is_available(): 24 | device = "cuda:0" 25 | else: 26 | device = "cpu" 27 | self.device = torch.device(device) 28 | 29 | # Find the model path 30 | if Path(model_name_or_path).exists(): 31 | tempdir = tempfile.mkdtemp() 32 | try: 33 | # Extract archive 34 | with tarfile.open(model_name_or_path, "r:gz") as archive: 35 | 36 | import os 37 | 38 | def is_within_directory(directory, target): 39 | 40 | abs_directory = os.path.abspath(directory) 41 | abs_target = os.path.abspath(target) 42 | 43 | prefix = os.path.commonprefix([abs_directory, abs_target]) 44 | 45 | return prefix == abs_directory 46 | 47 | def safe_extract(tar, path=".", members=None, *, numeric_owner=False): 48 | 49 | for member in tar.getmembers(): 50 | member_path = os.path.join(path, member.name) 51 | if not is_within_directory(path, member_path): 52 | raise Exception("Attempted Path Traversal in Tar File") 53 | 54 | tar.extractall(path, members, numeric_owner=numeric_owner) 55 | 56 | 57 | safe_extract(archive, tempdir) 58 | model_dir = Path(tempdir) 59 | # Load model 60 | config = json.load(open(model_dir / "config.json")) 61 | config["device"] = self.device 62 | model_path = model_dir / "pytorch_model.bin" 63 | state_dict = torch.load(model_path, map_location=lambda storage, loc: storage) 64 | 65 | finally: 66 | # Clean-up 67 | shutil.rmtree(tempdir, ignore_errors=True) 68 | 69 | else: 70 | assert str(model_name_or_path) in AVAILABLE_MODELS, AVAILABLE_MODELS 71 | # Lazy import 72 | from huggingface_hub import hf_hub_url, cached_download 73 | config_url = hf_hub_url(str(model_name_or_path), filename="config.json") 74 | config = json.load(open(cached_download(url=config_url, library_name="coop"))) 75 | model_url = hf_hub_url(str(model_name_or_path), filename="pytorch_model.bin") 76 | model_path = cached_download(url=model_url, library_name="coop") 77 | state_dict = torch.load(model_path, map_location=lambda storage, loc: storage) 78 | 79 | if "bimeanvae" in str(model_name_or_path): 80 | spm_url = hf_hub_url(str(model_name_or_path), filename="spm.model") 81 | spm_path = cached_download(url=spm_url, library_name="coop") 82 | config["spm_path"] = spm_path 83 | 84 | self.src_tokenizer, self.tgt_tokenizers = load_tokenizer(config) 85 | self.model = build_model(config).eval() 86 | self.model.load_state_dict(state_dict) 87 | self.model.to(self.device) 88 | 89 | @torch.no_grad() 90 | def encode(self, 91 | reviews: Union[List[str], str], 92 | device: str = None): 93 | if isinstance(reviews, str): 94 | reviews = [reviews] 95 | 96 | if device is None: 97 | self.to(self.device) 98 | src = self.src_tokenizer(reviews) 99 | return self.model(src).q.loc 100 | 101 | @torch.no_grad() 102 | def generate(self, 103 | z: torch.Tensor, 104 | num_beams: int = 4, 105 | max_tokens: int = 256, 106 | bad_words: Union[str, List[str], List[int]] = None): 107 | if z.dim() == 1: 108 | z = z.unsqueeze(0) 
109 | 110 | if bad_words is not None: 111 | if isinstance(bad_words, str): 112 | bad_words = [bad_words] 113 | if isinstance(bad_words[0], str): 114 | bad_words_ids = self.tgt_tokenizers.get_ids(bad_words) 115 | else: 116 | bad_words_ids = bad_words 117 | else: 118 | bad_words_ids = None 119 | 120 | return self.tgt_tokenizers.decode(self.model.generate( 121 | z=z, 122 | num_beams=num_beams, 123 | max_tokens=max_tokens, 124 | bad_words_ids=bad_words_ids)) 125 | -------------------------------------------------------------------------------- /coop/search.py: -------------------------------------------------------------------------------- 1 | import json 2 | import shutil 3 | import tarfile 4 | import tempfile 5 | from multiprocessing import Pool, cpu_count 6 | from pathlib import Path 7 | from typing import List, Union 8 | 9 | import click 10 | import pandas as pd 11 | import rouge 12 | import torch 13 | 14 | from coop.models import Model 15 | from coop.reader import ReviewTest, OptimusTest 16 | from coop.tokenizer import Tokenizer 17 | from coop.util import avg, overlap, oracle, load_tokenizer, load_data, build_model, BAD_WORDS, powerset 18 | 19 | 20 | def brute_force_gen(model: Model, 21 | data: Union[ReviewTest, OptimusTest], 22 | tgt_tokenizer: Tokenizer, 23 | num_beams: int = 4, 24 | bad_words_ids: List[int] = None, 25 | split: int = 1, ): 26 | outs = [] 27 | for i, x in enumerate(data): 28 | z_raw = model(**x).q.loc 29 | idxes = powerset(z_raw.size(0)) 30 | zs = torch.stack([z_raw[idx].mean(dim=0) for idx in idxes]) 31 | gens = [] 32 | for z in torch.split(zs, len(idxes) // split): 33 | g = model.generate(z, num_beams=num_beams, bad_words_ids=bad_words_ids) 34 | gens.extend(tgt_tokenizer.decode(g)) 35 | outs.append([{"selected": [x["reviews"][i] for i in idx], 36 | "reviews": x["reviews"], 37 | "summary": x["summary"], 38 | "predicted": gen, 39 | "idx": idx} for idx, gen in zip(idxes, gens)]) 40 | return outs 41 | 42 | 43 | @click.command() 44 | @click.argument("log_dir_or_file", type=click.Path(exists=True)) 45 | @click.option("--split", type=click.INT, default=1) 46 | def main(log_dir_or_file, split): 47 | log_dir_or_file = Path(log_dir_or_file) 48 | tempdir = None 49 | if not log_dir_or_file.is_dir(): 50 | # Extract archive 51 | tempdir = tempfile.mkdtemp() 52 | with tarfile.open(log_dir_or_file, "r:gz") as archive: 53 | 54 | import os 55 | 56 | def is_within_directory(directory, target): 57 | 58 | abs_directory = os.path.abspath(directory) 59 | abs_target = os.path.abspath(target) 60 | 61 | prefix = os.path.commonprefix([abs_directory, abs_target]) 62 | 63 | return prefix == abs_directory 64 | 65 | def safe_extract(tar, path=".", members=None, *, numeric_owner=False): 66 | 67 | for member in tar.getmembers(): 68 | member_path = os.path.join(path, member.name) 69 | if not is_within_directory(path, member_path): 70 | raise Exception("Attempted Path Traversal in Tar File") 71 | 72 | tar.extractall(path, members, numeric_owner=numeric_owner) 73 | 74 | 75 | safe_extract(archive, tempdir) 76 | log_dir = Path(tempdir) 77 | else: 78 | log_dir = Path(log_dir_or_file) 79 | 80 | config = json.load(open(log_dir / "config.json")) 81 | src_tokenizer, tgt_tokenizer = load_tokenizer(config) 82 | _, dev, test = load_data(config, src_tokenizer, tgt_tokenizer) 83 | bad_words_ids = tgt_tokenizer.get_ids(BAD_WORDS) 84 | if config["model"]["type"].lower() == "bimeanvae" and config["data_dir"].endswith("amzn"): 85 | # The amzn dataset often includes the pronoun I without prefix. 
To avoid the issue, this tweak is applied. 86 | bad_words_ids.extend(tgt_tokenizer.get_ids("I", no_prefix=True)) 87 | model = build_model(config).eval() 88 | 89 | model.load_state_dict(torch.load(log_dir / "pytorch_model.bin", map_location=lambda storage, loc: storage)) 90 | if torch.cuda.is_available(): 91 | model.cuda() 92 | 93 | dev_gen = brute_force_gen(model, dev, tgt_tokenizer, bad_words_ids=bad_words_ids, split=split) 94 | test_gen = brute_force_gen(model, test, tgt_tokenizer, bad_words_ids=bad_words_ids, split=split) 95 | 96 | coop = {} 97 | evaluator = rouge.Rouge(metrics=["rouge-n", "rouge-l"], max_n=2, limit_length=False, apply_avg=True) 98 | 99 | with Pool(cpu_count()) as p: 100 | for func in (avg, overlap, oracle): 101 | name = func.__name__ 102 | coop[name] = {} 103 | for key, val in (("dev", dev_gen), ("test", test_gen)): 104 | coop_score = p.map(func, val) 105 | index = list(range(len(val[0]))) 106 | index = [max(index, key=lambda x: s[x]) for s in coop_score] 107 | selected = [v[i] for i, v in zip(index, val)] 108 | rouge_score = evaluator.get_scores( 109 | [x["predicted"] for x in selected], [x["summary"] for x in selected]) 110 | rouge_score = {"_".join((metric, k)): v for metric, vs in rouge_score.items() for k, v in 111 | vs.items()} 112 | coop[name][key] = { 113 | "coop_score": coop_score, 114 | "index": index, 115 | "rouge": rouge_score} 116 | 117 | df = pd.DataFrame({k: coop[name][k]["rouge"] for k in ("dev", "test")}) 118 | df.sort_index(inplace=True) 119 | print(name) 120 | print(df) 121 | 122 | # Clean-up 123 | if tempdir is not None: 124 | shutil.rmtree(tempdir, ignore_errors=True) 125 | else: 126 | json.dump(coop, open(log_dir / "coop.json", "w")) 127 | 128 | 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /coop/models/bimeanvae.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from allennlp.nn.beam_search import BeamSearch 7 | from torch.distributions import Normal, kl_divergence 8 | 9 | from . 
import Model 10 | from .util import Losses, VAEOut 11 | 12 | 13 | def masked_mean(vector: torch.Tensor, 14 | mask: torch.Tensor, 15 | dim: int, 16 | keepdim: bool = False, 17 | eps: float = 1e-8) -> torch.Tensor: 18 | one_minus_mask = (1.0 - mask.float()).to(dtype=torch.bool) 19 | replaced_vector = vector.masked_fill(one_minus_mask, 0.0) 20 | 21 | value_sum = torch.sum(replaced_vector, dim=dim, keepdim=keepdim) 22 | value_count = torch.sum(mask.float(), dim=dim, keepdim=keepdim) 23 | return value_sum / value_count.clamp(eps) 24 | 25 | 26 | class BiMeanVAE(Model): 27 | def __init__(self, 28 | vocab_size: int, 29 | embedding_dim: int, 30 | hidden_size: int, 31 | latent_dim: int, 32 | pad_id: int, 33 | bos_id: int, 34 | eos_id: int, 35 | num_layers: int = 1, 36 | free_bit: float = 0.05): 37 | super().__init__(hidden_size, latent_dim) 38 | self.vocab_size = vocab_size 39 | self.embedding_dim = embedding_dim 40 | self.num_layers = num_layers 41 | 42 | self.pad_id = pad_id 43 | self.bos_id = bos_id 44 | self.eos_id = eos_id 45 | 46 | self.embed = nn.Embedding(vocab_size, embedding_dim) 47 | self.encoder = nn.LSTM(embedding_dim, hidden_size // 2, num_layers, batch_first=True, bidirectional=True) 48 | 49 | self.proj_z = nn.Linear(hidden_size, 2 * latent_dim) 50 | self.proj_dec = nn.Sequential(nn.Linear(latent_dim, 2 * hidden_size), 51 | nn.Tanh()) 52 | 53 | self.decoder = nn.LSTMCell(embedding_dim + latent_dim, hidden_size) 54 | self.output_layer = nn.Sequential(nn.Linear(hidden_size, embedding_dim), 55 | nn.Linear(embedding_dim, vocab_size, bias=True)) 56 | self.beam = BeamSearch(self.eos_id, max_steps=128, beam_size=4, ) 57 | self.bad_words = set() 58 | # Tying weight 59 | self.output_layer[-1].weight = self.embed.weight 60 | 61 | self.free_bit = free_bit 62 | 63 | def forward(self, 64 | src: torch.Tensor, 65 | tgt: torch.Tensor = None, 66 | do_generate: torch.Tensor = False, 67 | num_beams: int = 4, 68 | **kwargs): 69 | embed = self.embed(src) 70 | input_mask = torch.ne(src, self.pad_id) 71 | 72 | # Encoding 73 | input_length = input_mask.sum(dim=1).cpu().tolist() 74 | packed_embed = nn.utils.rnn.pack_padded_sequence(embed, input_length, batch_first=True, enforce_sorted=False) 75 | encoded = self.encoder(packed_embed)[0] 76 | encoded = nn.utils.rnn.pad_packed_sequence(encoded, batch_first=True)[0] 77 | encoded = masked_mean(encoded, input_mask.unsqueeze(-1), dim=1) 78 | mu, log_var = torch.chunk(self.proj_z(encoded), chunks=2, dim=-1) 79 | std = torch.exp(0.5 * log_var) 80 | 81 | q = Normal(mu, std) 82 | p = Normal(0, 1) 83 | zkl_real = kl_divergence(q, p) 84 | kl_mask = torch.gt(zkl_real, self.free_bit) 85 | bz = embed.size(0) 86 | zkl = zkl_real[kl_mask].sum() / bz 87 | zkl_real = zkl_real.sum(dim=-1).mean() 88 | if self.training: 89 | z = q.rsample() 90 | assert tgt is not None 91 | # Training: 92 | nll = self._recon_loss(tgt, z) 93 | return Losses(nll=nll, zkl=zkl, zkl_real=zkl_real) 94 | else: 95 | # Inference 96 | z = mu 97 | if do_generate: 98 | generated = self.generate(z, num_beams=num_beams) 99 | else: 100 | generated = None 101 | return VAEOut(q=q, generated=generated) 102 | 103 | def _recon_loss(self, 104 | tgt_tensor: torch.Tensor, 105 | z: torch.Tensor): 106 | targets = tgt_tensor[:, 1:] 107 | embed = self.embed(tgt_tensor[:, :-1]) 108 | bz, max_len, _ = embed.size() 109 | hx, cx = torch.chunk(self.proj_dec(z), 2, dim=-1) 110 | recon_losses = [] 111 | for t in range(max_len - 1): 112 | hx, cx = self.decoder(torch.cat((embed[:, t], z), dim=-1), (hx, cx)) 113 | tgt_mask = targets[:, 
t] != self.pad_id 114 | non_masked_targets = targets[:, t].masked_select(tgt_mask) 115 | non_masked_embeddings = hx.masked_select(tgt_mask.unsqueeze(-1)).view(-1, self.hidden_size) 116 | non_masked_outs = self.output_layer(non_masked_embeddings) 117 | recon_losses.append(F.cross_entropy(non_masked_outs, non_masked_targets, reduction="sum")) 118 | 119 | recon_loss = torch.sum(torch.stack(recon_losses)) / bz 120 | return recon_loss 121 | 122 | def generate(self, 123 | z: torch.Tensor, 124 | num_beams: int = 4, 125 | max_tokens: int = 256, 126 | bad_words_ids: List[int] = None): 127 | self.eval() 128 | if bad_words_ids: 129 | self.bad_words.update(bad_words_ids) 130 | bz, device = len(z), z.device 131 | start_predictions = torch.full((bz,), fill_value=self.bos_id, dtype=torch.long, device=device) 132 | hx, cx = torch.chunk(self.proj_dec(z), 2, dim=-1) 133 | decoder_state = {"z": z, "hx": hx, "cx": cx} 134 | self.beam.beam_size = num_beams 135 | self.beam.max_steps = max_tokens 136 | all_top_k_predictions, _ = self.beam.search(start_predictions, 137 | decoder_state, 138 | self.step) 139 | self.bad_words.clear() 140 | return all_top_k_predictions[:, 0] 141 | 142 | @torch.no_grad() 143 | def step(self, 144 | last_predictions: torch.Tensor, 145 | state: Dict[str, torch.Tensor]): 146 | z = state["z"] 147 | hx, cx = self.decoder(torch.cat((self.embed(last_predictions), z), dim=-1), (state["hx"], state["cx"])) 148 | new_state = {"z": z, "hx": hx, "cx": cx} 149 | log_softmax = torch.nn.functional.log_softmax(self.output_layer(hx), dim=-1) 150 | if self.bad_words: 151 | log_softmax[:, list(self.bad_words)] = float("-inf") 152 | return log_softmax, new_state 153 | 154 | @staticmethod 155 | def klw(step: int, 156 | interval: int, 157 | r: float = 0.8, 158 | t: float = 0.0, 159 | s: int = 10000): 160 | if step < s: 161 | return 0. 
162 | else: 163 | return min((step - s) / s, 1) 164 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import shutil 3 | import tarfile 4 | from collections import defaultdict 5 | from pathlib import Path 6 | from time import time 7 | from typing import Dict, List 8 | 9 | import click 10 | import pandas as pd 11 | import torch 12 | from _jsonnet import evaluate_file 13 | from torch.nn.utils import clip_grad_norm_, clip_grad_value_ 14 | from torch.optim import Adam 15 | from torch.utils.data import Dataset, DataLoader 16 | from torch.utils.tensorboard import SummaryWriter 17 | from tqdm import tqdm 18 | from transformers import get_linear_schedule_with_warmup 19 | 20 | from coop.models import Model, BiMeanVAE, Optimus 21 | from coop.util import get_logger, load_tokenizer, load_data, build_model 22 | from evaluate import evaluate 23 | 24 | 25 | class Trainer: 26 | def __init__(self, 27 | model: Model, 28 | data: List[Dataset], 29 | log_dir: Path, 30 | num_steps: int, 31 | checkout_step: int, 32 | batch_size: int, 33 | lr: float = 1e-4, 34 | clip_value: float = 5., 35 | max_norm: float = 1., 36 | num_keep: int = 10): 37 | log_dir = Path(log_dir) 38 | if torch.cuda.is_available(): 39 | model.cuda() 40 | 41 | self.model = model 42 | self.train, self.dev, self.test = data 43 | 44 | self.opt = Adam(self.model.parameters(), lr, betas=(0.5, 0.999), eps=1e-6, ) 45 | 46 | self.scheduler = get_linear_schedule_with_warmup(self.opt, checkout_step // 10, num_steps) 47 | 48 | self.clip_value = clip_value 49 | self.max_norm = max_norm 50 | self.num_steps = num_steps 51 | self.checkout_step = checkout_step 52 | self.batch_size = batch_size 53 | self.log_dir = log_dir 54 | self.logger = get_logger(log_dir) 55 | self.losses = defaultdict(list) 56 | self.best_score = 0. 
57 | self.writer = {key: SummaryWriter(log_dir=str(log_dir / "log" / key)) for key in ("train", "dev", "test")} 58 | self.global_step = 0 59 | self.num_keep = num_keep 60 | self.model_path = [] 61 | 62 | @classmethod 63 | def from_config(cls, 64 | config: dict, 65 | log_dir: Path): 66 | json.dump(config, open(log_dir / "config.json", "w")) 67 | if "spm_path" in config: 68 | shutil.copy(config["spm_path"], log_dir / "spm.model") 69 | tokenizers = load_tokenizer(config) 70 | data = load_data(config, *tokenizers) 71 | model = build_model(config) 72 | 73 | return cls(model, data, log_dir, **config.pop("trainer")) 74 | 75 | def _fit_partial(self, 76 | batch, 77 | p: tqdm = None): 78 | self.model.train() 79 | self.model.zero_grad() 80 | losses = self.model(**batch) 81 | nll, zkl, zkl_real = losses.nll, losses.zkl, losses.zkl_real 82 | klw = self.model.klw(self.global_step, self.checkout_step) 83 | loss = nll + klw * zkl 84 | loss.backward() 85 | if isinstance(self.model, Optimus): 86 | clip_grad_norm_(self.model.parameters(), self.max_norm) 87 | else: 88 | clip_grad_value_(self.model.parameters(), self.clip_value) 89 | 90 | loss_dict = {"nll": nll.item(), "klw": klw, "zkl": zkl.item(), "zkl_real": zkl_real.item()} 91 | 92 | self.opt.step() 93 | self.scheduler.step() 94 | 95 | if p is not None: 96 | for k, v in loss_dict.items(): 97 | self.writer["train"].add_scalar(f"Loss/{k}", v, global_step=self.global_step) 98 | self.losses[k].append(v) 99 | p.set_postfix(**loss_dict) 100 | p.update() 101 | 102 | def fit(self): 103 | train = DataLoader(self.train, batch_size=self.batch_size, shuffle=True, collate_fn=self.train.collate_fn) 104 | p = tqdm(desc=f"Step {self.global_step}", total=self.checkout_step, ncols=100) 105 | 106 | while True: 107 | for batch in train: 108 | self.global_step += 1 109 | self._fit_partial(batch, p=p) 110 | if self.global_step % self.checkout_step == 0: 111 | losses = self._avg_loss(p) 112 | self._archive(losses) 113 | p.close() 114 | self._evaluate() 115 | if isinstance(self.model, BiMeanVAE) and self.global_step == 10000: 116 | self.logger.info("Reset LSTM decoder") 117 | self.model.decoder.reset_parameters() 118 | if self.global_step == self.num_steps: 119 | self._finalize() 120 | return 121 | p = tqdm(desc=f"Step {self.global_step}", total=self.checkout_step, ncols=100) 122 | 123 | def _finalize(self): 124 | archive_file = self.log_dir / "model.tar.gz" 125 | with tarfile.open(archive_file, "w:gz") as archive: 126 | archive.add(self.log_dir / "config.json", arcname="config.json") 127 | archive.add(self.log_dir / "best.th", arcname="pytorch_model.bin") 128 | if isinstance(self.model, BiMeanVAE): 129 | archive.add(self.log_dir / "spm.model", arcname="spm.model") 130 | 131 | def _evaluate(self): 132 | self.model.eval() 133 | # Summarize 134 | metrics = {} 135 | for data_type in ("dev", "test"): 136 | data = getattr(self, data_type) 137 | metrics[data_type] = evaluate(self.model, data, debug=True) 138 | for k, v in metrics[data_type].items(): 139 | metric, tgt, key = k.split("_") 140 | self.writer[data_type].add_scalar(f"Metrics/{tgt}/{metric}/{key}/", v, global_step=self.global_step) 141 | 142 | df = pd.DataFrame(metrics) 143 | df.sort_index(inplace=True) 144 | print(df) 145 | json.dump(metrics, open(self.log_dir / f"metrics-step_{self.global_step}.json", "w")) 146 | dev_scores = {f"R{i}": df["dev"][f"rouge-{i}_sum_f"] for i in "12l"} 147 | if sum(dev_scores.values()) > self.best_score: 148 | self.best_score = sum(dev_scores.values()) 149 | shutil.copy(self.log_dir / 
f"metrics-step_{self.global_step}.json", self.log_dir / "metrics.json") 150 | shutil.copy(self.log_dir / f"model-step_{self.global_step}.th", self.log_dir / "best.th") 151 | shutil.copy(self.log_dir / f"training_metrics-step_{self.global_step}.json", 152 | self.log_dir / "training_metrics.json") 153 | self.logger.info("Best scores") 154 | for k, v in dev_scores.items(): 155 | self.logger.info(f"DEV: {k}={100 * v:.2f}") 156 | test_scores = {f"R{i}": df["test"][f"rouge-{i}_sum_f"] for i in (1, 2, "l")} 157 | for k, v in test_scores.items(): 158 | self.logger.info(f"TEST: {k}={100 * v:.2f}") 159 | 160 | def _archive(self, 161 | losses: Dict[str, float]): 162 | model_path = self.log_dir / f"model-step_{self.global_step}.th" 163 | torch.save(self.model.state_dict(), model_path) 164 | json.dump(losses, open(self.log_dir / f"training_metrics-step_{self.global_step}.json", "w")) 165 | self.model_path.append(model_path) 166 | if len(self.model_path) > self.num_keep: 167 | self.model_path.pop(0).unlink() 168 | 169 | def _avg_loss(self, 170 | p: tqdm): 171 | losses = {k: sum(v) / len(v) for k, v in self.losses.items()} 172 | losses["klw"] = 1. 173 | p.set_postfix(**losses) 174 | p.update() 175 | self.losses.clear() 176 | return losses 177 | 178 | 179 | @click.command() 180 | @click.argument("config_file", type=click.Path(exists=True)) 181 | @click.option("--log_dir", "-s", type=click.Path(), default=f"/tmp/{str(int(time()))}") 182 | def main(config_file, log_dir): 183 | log_dir = Path(log_dir) 184 | log_dir.mkdir(parents=True) 185 | 186 | config = json.loads(evaluate_file(config_file)) 187 | 188 | trainer = Trainer.from_config(config, log_dir) 189 | trainer.fit() 190 | 191 | 192 | if __name__ == '__main__': 193 | main() 194 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Convex Aggregation for Opinion Summarization 2 | 3 | [![Conference](https://img.shields.io/badge/findings_of_emnlp-2021-red)](https://aclanthology.org/2021.findings-emnlp.328) 4 | [![arXiv](https://img.shields.io/badge/arxiv-2104.01371-success)](https://arxiv.org/abs/2104.01371/) 5 | [![arXiv](https://img.shields.io/badge/colab-demo-yellow)](https://colab.research.google.com/drive/1kyWw9H6TBfpuVrQH_35ofeScX1E-DSpb?usp=sharing) 6 | 7 | Code for [Convex Aggregation for Opinion Summarization](https://arxiv.org/abs/2104.01371). 8 | 9 | The codebase provides an easy-to-use framework that enables the user to train and use text VAE models with different configurations. 10 | 11 | You can also easily configure the architecture of the text VAE model without changing the code at all. You need to use a different Jsonnet file (perhaps with some modification) to train and use a model. 
12 | 13 | ![Coop](./img/overview.png) 14 | 15 | ## Citations 16 | ```bibtex 17 | @inproceedings{iso21emnlpfindings, 18 | title = {{C}onvex {A}ggregation for {O}pinion {S}ummarization}, 19 | author = {Hayate Iso and 20 | Xiaolan Wang and 21 | Yoshihiko Suhara and 22 | Stefanos Angelidis and 23 | Wang{-}Chiew Tan}, 24 | booktitle = {Findings of the Conference on Empirical Methods in Natural Language Processing (EMNLP)}, 25 | month = {November}, 26 | year = {2021} 27 | } 28 | ``` 29 | 30 | ## Installation 31 | ```bash 32 | conda create -n coop python=3.7 33 | conda activate coop 34 | conda install -c conda-forge jsonnet sentencepiece # If needed 35 | pip install git+https://github.com/megagonlabs/coop.git 36 | ``` 37 | or 38 | ``` 39 | git clone https://github.com/megagonlabs/coop.git 40 | cd coop 41 | pip install -e . # or python setup.py develop 42 | ``` 43 | 44 | ## Quick tour 45 | Our unsupervised opinion summarization model can generate a summary by decoding the aggregated latent vectors of inputs. 46 | The proposed framework, ```coop```, will find the best summary based on the input-output overlap. 47 | First, encode the input reviews, ```reviews```, into the latent vectors, ```z_raw```: 48 | ```python 49 | from typing import List 50 | import torch 51 | from coop import VAE, util 52 | 53 | model_name: str = "megagonlabs/bimeanvae-yelp" # or "megagonlabs/bimeanvae-amzn", "megagonlabs/optimus-yelp", "megagonlabs/optimus-amzn" 54 | vae = VAE(model_name) 55 | 56 | reviews: List[str] = [ 57 | "I love this ramen shop!! Highly recommended!!", 58 | "Here is one of my favorite ramen places! You must try!" 59 | ] 60 | z_raw: torch.Tensor = vae.encode(reviews) # [num_reviews * latent_size] 61 | ``` 62 | Given the latent vectors for input reviews, the model generates summaries from all combinations of latent vectors: 63 | ```python 64 | # All combinations of input reviews 65 | idxes: List[List[int]] = util.powerset(len(reviews)) 66 | # Taking averages for all combinations of latent vectors 67 | zs: torch.Tensor = torch.stack([z_raw[idx].mean(dim=0) for idx in idxes]) # [2^num_reviews - 1 * latent_size] 68 | 69 | outputs: List[str] = vae.generate(zs) 70 | outputs 71 | ``` 72 | Then, the output looks like this: 73 | ```shell 74 | ['I love this restaurant!! Highly recommended!!', 75 | 'Here is one of my favorite ramen places! You must try this place!', 76 | 'I love this place! Food is amazing!!'] 77 | ``` 78 | Finally, our framework, Coop, selects the summary based on the input-output overlap: 79 | ```python 80 | # Input-output overlap is measured by ROUGE-1 F1 score. 81 | best: str = max(outputs, key=lambda x: util.input_output_overlap(inputs=reviews, output=x)) 82 | best 83 | ``` 84 | 85 | The selected summary then looks like this: 86 | ```shell 87 | 'Here is one of my favorite ramen places! You must try this place!' 88 | ``` 89 | 90 | ## Evaluate on Dev/Test set 91 | You can easily get the generated examples and evaluate their performance with only 30 lines of code! 92 | Before doing so, you need to download the dev/test set by running the following commands. 93 | ```bash 94 | # Download dev and test set for evaluation 95 | python scripts/get_summ.py yelp data/yelp 96 | python scripts/get_summ.py amzn data/amzn 97 | ``` 98 | 99 | Then, you can get the generated examples as follows!
100 | ```python
101 | import json
102 | from typing import List
103 | import pandas as pd
104 | import torch
105 | import rouge
106 | from coop import VAE, util
107 | 
108 | task = "yelp"  # or "amzn"
109 | split = "dev"  # or "test"
110 | data: List[dict] = json.load(open(f"./data/{task}/{split}.json"))
111 | model_name: str = f"megagonlabs/bimeanvae-{task}"  # or f"megagonlabs/optimus-{task}"
112 | vae = VAE(model_name)
113 | 
114 | hypothesis = []
115 | for ins in data:
116 |     reviews: List[str] = ins["reviews"]
117 |     z_raw: torch.Tensor = vae.encode(reviews)
118 |     idxes: List[List[int]] = util.powerset(len(reviews))
119 |     zs: torch.Tensor = torch.stack([z_raw[idx].mean(dim=0) for idx in idxes])  # [2^num_reviews - 1, latent_size]
120 | 
121 |     outputs: List[str] = vae.generate(zs, bad_words=util.BAD_WORDS)  # First-person pronoun blocking
122 |     best: str = max(outputs, key=lambda x: util.input_output_overlap(inputs=reviews, output=x))
123 |     hypothesis.append(best)
124 | 
125 | reference: List[List[str]] = [ins["summary"] for ins in data]
126 | 
127 | evaluator = rouge.Rouge(metrics=["rouge-n", "rouge-l"], max_n=2, limit_length=False, apply_avg=True,
128 |                         stemming=True, ensure_compatibility=True)
129 | 
130 | scores = pd.DataFrame(evaluator.get_scores(hypothesis, reference))
131 | scores
132 | ```
133 | 
134 | ## Available models
135 | All models are hosted on the Hugging Face :hugs: model hub (https://huggingface.co/megagonlabs/).
136 | 
137 | 
138 | | Model name | Training Data | Encoder | Decoder |
139 | | :-------------------------------------------------------------- | :-------------:|:---------------------:|:-------:|
140 | | [megagonlabs/bimeanvae-yelp](https://huggingface.co/megagonlabs/bimeanvae-yelp) | Yelp | BiLSTM + Mean Pooling | LSTM |
141 | | [megagonlabs/optimus-yelp](https://huggingface.co/megagonlabs/optimus-yelp) | Yelp | bert-base-cased | gpt2 |
142 | | [megagonlabs/bimeanvae-amzn](https://huggingface.co/megagonlabs/bimeanvae-amzn) | Amazon | BiLSTM + Mean Pooling | LSTM |
143 | | [megagonlabs/optimus-amzn](https://huggingface.co/megagonlabs/optimus-amzn) | Amazon | bert-base-cased | gpt2 |
144 | 
145 | ```VAE``` automatically downloads model checkpoints from the model hub.
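Two helpers from ```coop.util``` do the heavy lifting above: ```util.powerset``` enumerates all non-empty subsets of input indices (hence the 2^num_reviews - 1 aggregated vectors), and ```util.input_output_overlap``` scores a candidate summary against the inputs. Below is a minimal sketch of both, assuming a plain whitespace-tokenized unigram F1 in place of the exact ROUGE-1 implementation; the real helpers may tokenize and aggregate differently:

```python
from collections import Counter
from itertools import chain, combinations
from typing import List


def powerset_sketch(n: int) -> List[List[int]]:
    # All non-empty subsets of {0, ..., n-1}: 2^n - 1 in total.
    idx = range(n)
    return [list(c) for c in chain.from_iterable(combinations(idx, r) for r in range(1, n + 1))]


def overlap_sketch(inputs: List[str], output: str) -> float:
    # Unigram-overlap F1 between the concatenated inputs and the candidate,
    # in the spirit of ROUGE-1 F1 (the actual util.input_output_overlap
    # implementation may differ in tokenization and aggregation).
    ref = Counter(" ".join(inputs).lower().split())
    hyp = Counter(output.lower().split())
    match = sum((ref & hyp).values())
    if match == 0:
        return 0.0
    precision = match / sum(hyp.values())
    recall = match / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


print(powerset_sketch(2))  # [[0], [1], [0, 1]]
```

Selecting the candidate with the highest overlap favors summaries that stay close to the content of the input reviews.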
146 | 
147 | ## Summarization Performance
148 | ### Yelp dataset [(Chu and Liu, 2019)](https://github.com/sosuperic/MeanSum)
149 | 
150 | | Model name | Aggregation | ROUGE-1 F1 | ROUGE-2 F1 | ROUGE-L F1 |
151 | | :-------------------------------------------------------------- |:-----------:|:----------:|:----------:|:----------:|
152 | | [megagonlabs/bimeanvae-yelp](https://huggingface.co/megagonlabs/bimeanvae-yelp) | SimpleAvg | 32.87 | 6.93 | 19.89 |
153 | | [megagonlabs/bimeanvae-yelp](https://huggingface.co/megagonlabs/bimeanvae-yelp) | Coop | **35.37** | **7.35** | **19.94** |
154 | | [megagonlabs/optimus-yelp](https://huggingface.co/megagonlabs/optimus-yelp) | SimpleAvg | 31.23 | 6.48 | 18.27 |
155 | | [megagonlabs/optimus-yelp](https://huggingface.co/megagonlabs/optimus-yelp) | Coop | 33.68 | 7.00 | 18.95 |
156 | 
157 | 
158 | ### Amazon dataset [(Bražinskas et al., 2020)](https://github.com/abrazinskas/Copycat-abstractive-opinion-summarizer)
159 | | Model name | Aggregation | ROUGE-1 F1 | ROUGE-2 F1 | ROUGE-L F1 |
160 | | :-------------------------------------------------------------- |:-----------:|:----------:|:----------:|:----------:|
161 | | [megagonlabs/bimeanvae-amzn](https://huggingface.co/megagonlabs/bimeanvae-amzn) | SimpleAvg | 33.60 | 6.64 | 20.87 |
162 | | [megagonlabs/bimeanvae-amzn](https://huggingface.co/megagonlabs/bimeanvae-amzn) | Coop | **36.57** | **7.23** | **21.24** |
163 | | [megagonlabs/optimus-amzn](https://huggingface.co/megagonlabs/optimus-amzn) | SimpleAvg | 33.54 | 6.18 | 19.34 |
164 | | [megagonlabs/optimus-amzn](https://huggingface.co/megagonlabs/optimus-amzn) | Coop | 35.32 | 6.22 | 19.84 |
165 | 
166 | 
167 | # Reproduction
168 | 
169 | ## Setup
170 | ```shell
171 | $ unzip coop.zip && cd coop
172 | $ conda create -n coop python=3.7
173 | $ conda activate coop
174 | $ conda install -c conda-forge jsonnet sentencepiece  # If needed
175 | $ pip install -r requirements.txt
176 | ```
177 | 
178 | ## Preparation
179 | 
180 | ### Yelp dataset
181 | 
182 | Download the Yelp dataset from [this link](https://www.yelp.com/dataset).
183 | You only need the review archive (`yelp_dataset.tar`).
184 | 
185 | Move the file to `data/yelp` and uncompress it. You only need `yelp_academic_dataset_review.json`.
186 | 
187 | ```bash
188 | $ tar -xvf yelp_dataset.tar
189 | $ YELP_RAW=$(pwd)/yelp_academic_dataset_review.json
190 | ```
191 | 
192 | Run the following preprocessing script. This may take several hours, depending on your machine spec.
193 | 
194 | ```bash
195 | $ mkdir -p ./data/yelp
196 | $ python scripts/preprocess.py yelp $YELP_RAW > ./data/yelp/train.jsonl
197 | ```
198 | 
199 | Additionally, you need to download the reference summaries from [this link](https://s3.us-east-2.amazonaws.com/unsup-sum/summaries_0-200_cleaned.csv) provided by [MeanSum](https://github.com/sosuperic/MeanSum).
200 | 
201 | Run the following command to download and preprocess it.
202 | This will create `dev.json` and `test.json`, which follow the dev/test splits
203 | defined in [the original MeanSum paper](https://arxiv.org/abs/1810.05739).
204 | 
205 | ```bash
206 | $ python scripts/get_summ.py yelp data/yelp
207 | $ ls data/yelp
208 | train.jsonl
209 | dev.json
210 | test.json
211 | ```
212 | 
213 | 
214 | ### Amazon dataset
215 | 
216 | Download the Amazon dataset from [this link](http://jmcauley.ucsd.edu/data/amazon/links.html).
217 | You only need the following files for 4 categories:
218 | - [reviews_Clothing_Shoes_and_Jewelry.json.gz](http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Clothing_Shoes_and_Jewelry.json.gz)
219 | - [reviews_Electronics.json.gz](http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Electronics.json.gz)
220 | - [reviews_Health_and_Personal_Care.json.gz](http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Health_and_Personal_Care.json.gz)
221 | - [reviews_Home_and_Kitchen.json.gz](http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Home_and_Kitchen.json.gz)
222 | 
223 | Run the following script to download the datasets.
224 | You **don't need to uncompress** them.
225 | 
226 | ```shell
227 | $ mkdir amzn_raw && cd amzn_raw
228 | $ wget http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Clothing_Shoes_and_Jewelry.json.gz
229 | $ wget http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Electronics.json.gz
230 | $ wget http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Health_and_Personal_Care.json.gz
231 | $ wget http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Home_and_Kitchen.json.gz
232 | $ AMZN_RAW=$(pwd)
233 | $ ls $AMZN_RAW
234 | reviews_Clothing_Shoes_and_Jewelry.json.gz
235 | reviews_Electronics.json.gz
236 | reviews_Health_and_Personal_Care.json.gz
237 | reviews_Home_and_Kitchen.json.gz
238 | $ cd -
239 | ```
240 | 
241 | Run the following preprocessing script. This may take several hours, depending on your machine spec.
242 | 
243 | ```bash
244 | $ mkdir -p ./data/amzn
245 | $ python scripts/preprocess.py amzn $AMZN_RAW > ./data/amzn/train.jsonl
246 | ```
247 | 
248 | Download the reference summaries provided by [CopyCat](https://github.com/abrazinskas/Copycat-abstractive-opinion-summarizer).
249 | 
250 | Run the following command to download and preprocess them. This will create `dev.json` and `test.json`, which follow the dev/test splits defined in the original CopyCat paper.
251 | 
252 | ```bash
253 | $ python scripts/get_summ.py amzn data/amzn
254 | $ ls data/amzn
255 | train.jsonl
256 | dev.json
257 | test.json
258 | ```
259 | 
260 | 
261 | ## Training
262 | ### Model and Training Configuration
263 | The ```config``` directory contains the configuration files used for the experiments; copy and edit one to run experiments in different settings. For example, ```config/bimeanvae/yelp.jsonnet``` looks like this:
264 | 
265 | ```jsonnet
266 | local lib = import '../utils.libsonnet';
267 | local data_type = "yelp";
268 | local latent_dim = 512;
269 | local free_bit = 0.25;
270 | local num_steps = 100000;
271 | local checkout_step = 1000;
272 | local batch_size = 256;
273 | local lr = 1e-3;
274 | 
275 | {
276 |     "data_dir": "./data/%s" % data_type,
277 |     "spm_path": "./data/sentencepiece/%s.model" % data_type,
278 |     "model": lib.BiMeanVAE(latent_dim, free_bit),
279 |     "trainer": lib.VAETrainer(num_steps, checkout_step, batch_size, lr)
280 | }
281 | 
282 | ```
283 | 
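Because configs are plain Jsonnet, you can render one to JSON to inspect exactly what the trainer will receive. Here is a minimal sketch using the Python `jsonnet` bindings already listed in `requirements.txt` (this mirrors how `train.py` loads configs via `evaluate_file`):

```python
import json

from _jsonnet import evaluate_file  # Python bindings installed via the `jsonnet` package

# Render the Jsonnet config into a plain dict for inspection.
config = json.loads(evaluate_file("config/bimeanvae/yelp.jsonnet"))
print(json.dumps(config, indent=2))  # data paths, model, and trainer hyperparameters
```

### Training a model
To train the model, you can run the following script with a ``config`` file and the directory to save checkpoints.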
286 | ```bash
287 | $ python train.py <config_file> -s <log_dir>
288 | ```
289 | 
290 | For example,
291 | 
292 | ```bash
293 | $ python train.py config/bimeanvae/yelp.jsonnet -s log/bimeanvae/yelp/ex1
294 | ```
295 | 
296 | ## Evaluation
297 | To evaluate the model with our proposed framework, ```coop```, you can simply run the following with the checkpoint directory:
298 | ```bash
299 | $ python coop/search.py <log_dir>
300 | ```
301 | 
302 | For example,
303 | ```bash
304 | $ python coop/search.py log/bimeanvae/yelp/ex1
305 | ```
306 | 
--------------------------------------------------------------------------------
/coop/models/optimus.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | 
3 | from torch.distributions import Normal, kl_divergence
4 | from transformers import BertModel
5 | from transformers.modeling_gpt2 import *
6 | 
7 | from . import Model
8 | from .util import Losses, VAEOut
9 | 
10 | PAD, BOS, EOS = '<pad>', '<bos>', '<eos>'
11 | SPECIAL = {'pad_token': PAD, 'bos_token': BOS, 'eos_token': EOS}
12 | 
13 | 
14 | class Optimus(Model):
15 |     def __init__(self,
16 |                  latent_dim: int,
17 |                  pad_id: int,
18 |                  bos_id: int,
19 |                  eos_id: int,
20 |                  free_bit: float = 0.05):
21 |         encoder = BertModel.from_pretrained("bert-base-cased", return_dict=True)
22 |         super().__init__(encoder.config.hidden_size, latent_dim)
23 |         self.encoder = encoder
24 |         self.decoder = OptimusDecoder.from_pretrained("gpt2", latent_dim=latent_dim, pad_id=pad_id, return_dict=True)
25 |         self.decoder.resize_token_embeddings(self.decoder.config.vocab_size + len(SPECIAL))
26 | 
27 |         self.latent_dim = latent_dim
28 | 
29 |         self.pad_id = pad_id
30 |         self.bos_id = bos_id
31 |         self.eos_id = eos_id
32 | 
33 |         self.proj = nn.Linear(self.encoder.config.hidden_size, 2 * latent_dim, bias=False)
34 | 
35 |         self.free_bit = free_bit
36 | 
37 |     def forward(self,
38 |                 src: Dict[str, torch.Tensor],
39 |                 tgt: Dict[str, torch.Tensor] = None,
40 |                 do_generate: bool = False,
41 |                 num_beams: int = 4,
42 |                 **kwargs):
43 |         cls_vec = self.encoder(**src).pooler_output
44 |         mu, log_var = torch.chunk(self.proj(cls_vec), chunks=2, dim=-1)
45 |         std = torch.exp(0.5 * log_var)
46 |         q = Normal(mu, std)
47 |         p = Normal(0, 1)
48 |         zkl_real = kl_divergence(q, p)
49 |         kl_mask = torch.gt(zkl_real, self.free_bit)  # Free bits: penalize only dimensions whose KL exceeds the threshold
50 |         bz = cls_vec.size(0)
51 |         zkl = zkl_real[kl_mask].sum() / bz
52 |         zkl_real = zkl_real.sum(dim=-1).mean()
53 | 
54 |         if self.training:
55 |             assert tgt is not None
56 |             z = q.rsample()
57 |             outputs = self.decoder(**tgt,
58 |                                    past_key_values=(z,),
59 |                                    latent_as_gpt_memory=True,
60 |                                    latent_as_gpt_emb=True)
61 |             return Losses(nll=outputs.loss, zkl=zkl, zkl_real=zkl_real)
62 |         else:
63 |             if do_generate:
64 |                 z = q.mean
65 |                 generated = self.generate(z, num_beams=num_beams)
66 |             else:
67 |                 generated = None
68 |             return VAEOut(q=q, generated=generated)
69 | 
70 |     @torch.no_grad()
71 |     def generate(self,
72 |                  z: torch.Tensor,
73 |                  num_beams: int = 4,
74 |                  max_tokens: int = 256,
75 |                  bad_words_ids: List[List[int]] = None):
76 |         bz, _ = z.size()
77 | 
78 |         input_ids = z.new_full((bz, 1), dtype=torch.long, fill_value=self.bos_id)
79 |         generated = self.decoder.generate(
80 |             input_ids,
81 |             max_length=max_tokens,
82 |             min_length=16,
83 |             num_beams=num_beams,
84 |             bad_words_ids=bad_words_ids,
85 |             bos_token_id=self.bos_id,
86 |             pad_token_id=self.pad_id,
87 |             eos_token_id=self.eos_id,
88 |             past_key_values=(z,),
89 |             no_repeat_ngram_size=2,
90 |             latent_as_gpt_memory=True,
91 |             latent_as_gpt_emb=True).tolist()
92 |         return generated
93 | 
94 |     def reset_decoder(self):
95 |         device = 
self.encoder.device 96 | self.decoder = OptimusDecoder.from_pretrained("gpt2", latent_dim=self.latent_dim, pad_id=self.pad_id, 97 | return_dict=True) 98 | self.decoder.resize_token_embeddings(self.decoder.config.vocab_size + len(SPECIAL)) 99 | self.decoder.to(device) 100 | self.train() 101 | 102 | @staticmethod 103 | def klw(step: int, 104 | interval: int, 105 | r: float = 0.75, 106 | t: float = 0.5, 107 | s: int = 100000): 108 | value = (step % interval) / interval 109 | klw = max(min((value - t) / (r - t), 1.), 0.) 110 | return klw 111 | 112 | 113 | class Attention(nn.Module): 114 | def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False): 115 | super().__init__() 116 | 117 | n_state = nx # in Attention: n_state=768 (nx=n_embd) 118 | # [switch nx => n_state from Block to Attention to keep identical to TF implem] 119 | assert n_state % config.n_head == 0 120 | self.register_buffer( 121 | "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx) 122 | ) 123 | self.register_buffer("masked_bias", torch.tensor(-1e4)) 124 | self.n_head = config.n_head 125 | self.split_size = n_state 126 | self.scale = scale 127 | self.is_cross_attention = is_cross_attention 128 | if self.is_cross_attention: 129 | self.c_attn = Conv1D(2 * n_state, nx) 130 | self.q_attn = Conv1D(n_state, nx) 131 | else: 132 | self.c_attn = Conv1D(3 * n_state, nx) 133 | self.c_proj = Conv1D(n_state, nx) 134 | self.attn_dropout = nn.Dropout(config.attn_pdrop) 135 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 136 | self.pruned_heads = set() 137 | 138 | def prune_heads(self, heads): 139 | if len(heads) == 0: 140 | return 141 | heads, index = find_pruneable_heads_and_indices( 142 | heads, self.n_head, self.split_size // self.n_head, self.pruned_heads 143 | ) 144 | index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) 145 | 146 | # Prune conv1d layers 147 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) 148 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) 149 | 150 | # Update hyper params 151 | self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads)) 152 | self.n_head = self.n_head - len(heads) 153 | self.pruned_heads = self.pruned_heads.union(heads) 154 | 155 | def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False): 156 | w = torch.matmul(q, k) 157 | if self.scale: 158 | w = w / (float(v.size(-1)) ** 0.5) 159 | nd, ns = w.size(-2), w.size(-1) 160 | 161 | if not self.is_cross_attention: 162 | # if only "normal" attention layer implements causal mask 163 | mask = self.bias[:, :, ns - nd: ns, :ns] 164 | w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype)) 165 | 166 | if attention_mask is not None: 167 | # Apply the attention mask 168 | w = w + attention_mask 169 | 170 | w = nn.Softmax(dim=-1)(w) 171 | w = self.attn_dropout(w) 172 | 173 | # Mask heads if we want to 174 | if head_mask is not None: 175 | w = w * head_mask 176 | 177 | outputs = [torch.matmul(w, v)] 178 | if output_attentions: 179 | outputs.append(w) 180 | return outputs 181 | 182 | def merge_heads(self, x): 183 | x = x.permute(0, 2, 1, 3).contiguous() 184 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 185 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states 186 | 187 | def split_heads(self, x, k=False): 188 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 189 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states 190 | if 
k: 191 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) 192 | else: 193 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 194 | 195 | def forward( 196 | self, 197 | hidden_states, 198 | layer_past=None, 199 | attention_mask=None, 200 | head_mask=None, 201 | encoder_hidden_states=None, 202 | encoder_attention_mask=None, 203 | use_cache=False, 204 | output_attentions=False, 205 | ): 206 | if encoder_hidden_states is not None: 207 | assert hasattr( 208 | self, "q_attn" 209 | ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`." 210 | query = self.q_attn(hidden_states) 211 | key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) 212 | attention_mask = encoder_attention_mask 213 | else: 214 | query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) 215 | 216 | query = self.split_heads(query) 217 | key = self.split_heads(key, k=True) 218 | value = self.split_heads(value) 219 | if layer_past is not None: 220 | # https://github.com/ChunyuanLI/Optimus/blob/master/code/pytorch_transformers/modeling_gpt2.py#L189-L196 221 | past_key, past_value = layer_past[0][0], layer_past[0][1] # transpose back cf below 222 | 223 | past_key = self.split_heads(past_key, k=True) 224 | past_value = self.split_heads(past_value) 225 | key = torch.cat((past_key, key), dim=-1) 226 | value = torch.cat((past_value, value), dim=-2) 227 | 228 | if use_cache is True: 229 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking 230 | else: 231 | present = (None,) 232 | 233 | attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions) 234 | a = attn_outputs[0] 235 | 236 | a = self.merge_heads(a) 237 | a = self.c_proj(a) 238 | a = self.resid_dropout(a) 239 | 240 | outputs = [a, present] + attn_outputs[1:] 241 | return outputs # a, present, (attentions) 242 | 243 | 244 | class Block(nn.Module): 245 | def __init__(self, n_ctx, config, scale=False): 246 | super().__init__() 247 | hidden_size = config.n_embd 248 | inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size 249 | self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 250 | self.attn = Attention(hidden_size, n_ctx, config, scale) 251 | self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 252 | if config.add_cross_attention: 253 | self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True) 254 | self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 255 | self.mlp = MLP(inner_dim, config) 256 | 257 | def forward( 258 | self, 259 | hidden_states, 260 | layer_past=None, 261 | attention_mask=None, 262 | head_mask=None, 263 | encoder_hidden_states=None, 264 | encoder_attention_mask=None, 265 | use_cache=False, 266 | output_attentions=False, 267 | ): 268 | attn_outputs = self.attn( 269 | self.ln_1(hidden_states), 270 | layer_past=layer_past, 271 | attention_mask=attention_mask, 272 | head_mask=head_mask, 273 | use_cache=use_cache, 274 | output_attentions=output_attentions, 275 | ) 276 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions) 277 | outputs = attn_outputs[1:] 278 | # residual connection 279 | hidden_states = attn_output + hidden_states 280 | 281 | if encoder_hidden_states is not None: 282 | # add one self-attention block for cross-attention 283 | assert hasattr( 284 | 
self, "crossattention" 285 | ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" 286 | cross_attn_outputs = self.crossattention( 287 | self.ln_cross_attn(hidden_states), 288 | attention_mask=attention_mask, 289 | head_mask=head_mask, 290 | encoder_hidden_states=encoder_hidden_states, 291 | encoder_attention_mask=encoder_attention_mask, 292 | output_attentions=output_attentions, 293 | ) 294 | attn_output = cross_attn_outputs[0] 295 | # residual connection 296 | hidden_states = hidden_states + attn_output 297 | outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights 298 | 299 | feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states)) 300 | # residual connection 301 | hidden_states = hidden_states + feed_forward_hidden_states 302 | 303 | outputs = [hidden_states] + outputs 304 | return outputs # hidden_states, present, (cross_attentions, attentions) 305 | 306 | 307 | class OptimusGPT2(GPT2PreTrainedModel): 308 | def __init__(self, config, latent_dim: int): 309 | super().__init__(config) 310 | 311 | self.wte = nn.Embedding(config.vocab_size, config.n_embd) 312 | self.wpe = nn.Embedding(config.n_positions, config.n_embd) 313 | self.drop = nn.Dropout(config.embd_pdrop) 314 | self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) 315 | self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) 316 | 317 | # https://github.com/ChunyuanLI/Optimus/blob/master/code/pytorch_transformers/modeling_gpt2.py#L364-L370 318 | self.latent_dim = latent_dim 319 | self.linear_mem = nn.Linear(self.latent_dim, config.hidden_size * config.n_layer, bias=False) 320 | self.linear_emb = nn.Linear(self.latent_dim, config.hidden_size, bias=False) 321 | 322 | self.init_weights() 323 | 324 | def get_input_embeddings(self): 325 | return self.wte 326 | 327 | def set_input_embeddings(self, new_embeddings): 328 | self.wte = new_embeddings 329 | 330 | def _prune_heads(self, heads_to_prune): 331 | """Prunes heads of the model. 332 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 333 | """ 334 | for layer, heads in heads_to_prune.items(): 335 | self.h[layer].attn.prune_heads(heads) 336 | 337 | def forward( 338 | self, 339 | input_ids=None, 340 | past_key_values=None, 341 | attention_mask=None, 342 | token_type_ids=None, 343 | position_ids=None, 344 | head_mask=None, 345 | inputs_embeds=None, 346 | encoder_hidden_states=None, 347 | encoder_attention_mask=None, 348 | use_cache=None, 349 | output_attentions=None, 350 | output_hidden_states=None, 351 | return_dict=None, 352 | latent_as_gpt_emb=False, 353 | latent_as_gpt_memory=False, 354 | **kwargs, 355 | ): 356 | if "past" in kwargs: 357 | warnings.warn( 358 | "The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", 359 | FutureWarning, 360 | ) 361 | past_key_values = kwargs.pop("past") 362 | assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." 
363 | 364 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 365 | output_hidden_states = ( 366 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 367 | ) 368 | use_cache = use_cache if use_cache is not None else self.config.use_cache 369 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 370 | 371 | if input_ids is not None and inputs_embeds is not None: 372 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 373 | elif input_ids is not None: 374 | input_shape = input_ids.size() 375 | input_ids = input_ids.view(-1, input_shape[-1]) 376 | batch_size = input_ids.shape[0] 377 | elif inputs_embeds is not None: 378 | input_shape = inputs_embeds.size()[:-1] 379 | batch_size = inputs_embeds.shape[0] 380 | else: 381 | raise ValueError("You have to specify either input_ids or inputs_embeds") 382 | 383 | if token_type_ids is not None: 384 | token_type_ids = token_type_ids.view(-1, input_shape[-1]) 385 | if position_ids is not None: 386 | position_ids = position_ids.view(-1, input_shape[-1]) 387 | 388 | if past_key_values is None: 389 | past_length = 0 390 | past_key_values = [None] * len(self.h) 391 | else: 392 | # https://github.com/ChunyuanLI/Optimus/blob/master/code/pytorch_transformers/modeling_gpt2.py#L394-L415 393 | if latent_as_gpt_emb: 394 | past_emb = self.linear_emb(past_key_values[0]) # used as embeddings to add on other three embeddings 395 | 396 | if latent_as_gpt_memory: 397 | if len(past_key_values) == 1: 398 | memory = self.linear_mem(past_key_values[0]) 399 | memory_split = torch.split(memory.unsqueeze(1), self.config.hidden_size, dim=2) 400 | past_key_values_head = list(zip(memory_split, memory_split)) 401 | past_key_values = (past_key_values_head,) + past_key_values[1:] 402 | 403 | past_length = 1 404 | else: 405 | past_length = past_key_values[-1][0].size(-2) 406 | else: 407 | past_length = 0 408 | 409 | if position_ids is None: 410 | device = input_ids.device if input_ids is not None else inputs_embeds.device 411 | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) 412 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 413 | 414 | # Attention mask. 415 | if attention_mask is not None: 416 | assert batch_size > 0, "batch_size has to be defined and > 0" 417 | attention_mask = attention_mask.view(batch_size, -1) 418 | # We create a 3D attention mask from a 2D tensor mask. 419 | # Sizes are [batch_size, 1, 1, to_seq_length] 420 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 421 | # this attention mask is more simple than the triangular masking of causal attention 422 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 423 | attention_mask = attention_mask[:, None, None, :] 424 | 425 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 426 | # masked positions, this operation will create a tensor which is 0.0 for 427 | # positions we want to attend and -10000.0 for masked positions. 428 | # Since we are adding it to the raw scores before the softmax, this is 429 | # effectively the same as removing these entirely. 
430 |         attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
431 |         attention_mask = (1.0 - attention_mask) * -10000.0
432 | 
433 |         # If a 2D or 3D attention mask is provided for the cross-attention
434 |         # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
435 |         if self.config.add_cross_attention and encoder_hidden_states is not None:
436 |             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
437 |             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
438 |             if encoder_attention_mask is None:
439 |                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
440 |             encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
441 |         else:
442 |             encoder_attention_mask = None
443 | 
444 |         # Prepare head mask if needed
445 |         # 1.0 in head_mask indicates we keep the head
446 |         # attention_probs has shape bsz x n_heads x N x N
447 |         # head_mask has shape n_layer x batch x n_heads x N x N
448 |         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
449 | 
450 |         if inputs_embeds is None:
451 |             inputs_embeds = self.wte(input_ids)
452 |         position_embeds = self.wpe(position_ids)
453 |         if token_type_ids is not None:
454 |             token_type_embeds = self.wte(token_type_ids)
455 |         else:
456 |             token_type_embeds = 0
457 |         hidden_states = inputs_embeds + position_embeds + token_type_embeds
458 |         if latent_as_gpt_emb:
459 |             hidden_states = hidden_states + past_emb.unsqueeze(1)
460 | 
461 |         hidden_states = self.drop(hidden_states)
462 | 
463 |         output_shape = input_shape + (hidden_states.size(-1),)
464 | 
465 |         presents = () if use_cache else None
466 |         all_attentions = () if output_attentions else None
467 |         all_hidden_states = () if output_hidden_states else None
468 |         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
469 |             if output_hidden_states:
470 |                 all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
471 | 
472 |             if getattr(self.config, "gradient_checkpointing", False):
473 | 
474 |                 def create_custom_forward(module):
475 |                     def custom_forward(*inputs):
476 |                         # checkpointing only works with tuple returns, not with lists
477 |                         return tuple(output for output in module(*inputs, use_cache, output_attentions))
478 | 
479 |                     return custom_forward
480 | 
481 |                 outputs = torch.utils.checkpoint.checkpoint(
482 |                     create_custom_forward(block),
483 |                     hidden_states,
484 |                     layer_past,
485 |                     attention_mask,
486 |                     head_mask[i],
487 |                     encoder_hidden_states,
488 |                     encoder_attention_mask,
489 |                 )
490 |             else:
491 |                 outputs = block(
492 |                     hidden_states,
493 |                     layer_past=layer_past,
494 |                     attention_mask=attention_mask,
495 |                     head_mask=head_mask[i],
496 |                     encoder_hidden_states=encoder_hidden_states,
497 |                     encoder_attention_mask=encoder_attention_mask,
498 |                     use_cache=use_cache,
499 |                     output_attentions=output_attentions,
500 |                 )
501 | 
502 |             hidden_states, present = outputs[:2]
503 |             if use_cache is True:
504 |                 presents = presents + (present,)
505 | 
506 |             if output_attentions:
507 |                 all_attentions = all_attentions + (outputs[2],)
508 | 
509 |         hidden_states = self.ln_f(hidden_states)
510 | 
511 |         hidden_states = hidden_states.view(*output_shape)
512 |         # Add last hidden state
513 |         if output_hidden_states:
514 |             all_hidden_states = all_hidden_states + (hidden_states,)
515 | 
516 |         if not return_dict:
517 |             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
518 | 
519 |         return BaseModelOutputWithPast(
520 | 
last_hidden_state=hidden_states, 521 | past_key_values=presents, 522 | hidden_states=all_hidden_states, 523 | attentions=all_attentions, 524 | ) 525 | 526 | 527 | class OptimusDecoder(GPT2PreTrainedModel): 528 | authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] 529 | 530 | def __init__(self, config, latent_dim, pad_id): 531 | super().__init__(config) 532 | self.transformer = OptimusGPT2(config, latent_dim=latent_dim) 533 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 534 | self.pad_id = pad_id 535 | 536 | self.init_weights() 537 | self.tie_weights() 538 | 539 | def tie_weights(self): 540 | """ Make sure we are sharing the input and output embeddings. 541 | Export to TorchScript can't handle parameter sharing so we are cloning them instead. 542 | """ 543 | self._tie_or_clone_weights(self.lm_head, 544 | self.transformer.wte) 545 | 546 | def get_output_embeddings(self): 547 | return self.lm_head 548 | 549 | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): 550 | # We don't use caching for simplicity 551 | num_total = input_ids.size(0) 552 | past_key_values = kwargs["past_key_values"][0] 553 | if past_key_values.size(0) != num_total: # For beam search 554 | batch_size, latent_dim = past_key_values.size() 555 | past_key_values = past_key_values.unsqueeze(1).expand(batch_size, num_total // batch_size, latent_dim) 556 | past_key_values = past_key_values.contiguous().view(num_total, latent_dim) 557 | return { 558 | "input_ids": input_ids, 559 | "past_key_values": (past_key_values,), 560 | "latent_as_gpt_memory": True, 561 | "latent_as_gpt_emb": True, 562 | } 563 | 564 | @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING) 565 | def forward( 566 | self, 567 | input_ids=None, 568 | past_key_values=None, 569 | attention_mask=None, 570 | token_type_ids=None, 571 | position_ids=None, 572 | head_mask=None, 573 | inputs_embeds=None, 574 | encoder_hidden_states=None, 575 | encoder_attention_mask=None, 576 | labels=None, 577 | use_cache=None, 578 | output_attentions=None, 579 | output_hidden_states=None, 580 | return_dict=None, 581 | latent_as_gpt_emb=False, 582 | latent_as_gpt_memory=False, 583 | **kwargs, 584 | ): 585 | r""" 586 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 587 | Labels for language modeling. 588 | Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids`` 589 | Indices are selected in ``[-100, 0, ..., config.vocab_size]`` 590 | All labels set to ``-100`` are ignored (masked), the loss is only 591 | computed for labels in ``[0, ..., config.vocab_size]`` 592 | """ 593 | if "past" in kwargs: 594 | warnings.warn( 595 | "The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", 596 | FutureWarning, 597 | ) 598 | past_key_values = kwargs.pop("past") 599 | assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." 
600 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 601 | 602 | transformer_outputs = self.transformer( 603 | input_ids, 604 | past_key_values=past_key_values, 605 | attention_mask=attention_mask, 606 | token_type_ids=token_type_ids, 607 | position_ids=position_ids, 608 | head_mask=head_mask, 609 | inputs_embeds=inputs_embeds, 610 | encoder_hidden_states=encoder_hidden_states, 611 | encoder_attention_mask=encoder_attention_mask, 612 | use_cache=use_cache, 613 | output_attentions=output_attentions, 614 | output_hidden_states=output_hidden_states, 615 | return_dict=return_dict, 616 | latent_as_gpt_emb=latent_as_gpt_emb, 617 | latent_as_gpt_memory=latent_as_gpt_memory, 618 | ) 619 | hidden_states = transformer_outputs[0] 620 | 621 | lm_logits = self.lm_head(hidden_states) 622 | 623 | loss = None 624 | if labels is not None: 625 | # Shift so that tokens < n predict n 626 | shift_logits = lm_logits[..., :-1, :].contiguous() 627 | shift_labels = labels[..., 1:].contiguous() 628 | # Flatten the tokens 629 | loss_fct = CrossEntropyLoss(reduction="none") 630 | losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 631 | bz = labels.size(0) 632 | loss = losses[shift_labels.view(-1) != self.pad_id].sum() / bz 633 | 634 | if not return_dict: 635 | output = (lm_logits,) + transformer_outputs[1:] 636 | return ((loss,) + output) if loss is not None else output 637 | 638 | return CausalLMOutputWithPast( 639 | loss=loss, 640 | logits=lm_logits, 641 | past_key_values=transformer_outputs.past_key_values, 642 | hidden_states=transformer_outputs.hidden_states, 643 | attentions=transformer_outputs.attentions, 644 | ) 645 | --------------------------------------------------------------------------------