├── .idea ├── vcs.xml ├── .gitignore ├── inspectionProfiles │ ├── profiles_settings.xml │ └── Project_Default.xml ├── modules.xml └── RECAP.iml ├── requirements.txt ├── LICENSE ├── src ├── utils │ ├── retriever_utils.py │ └── generator_utils.py ├── preprocess │ ├── encode_comments.py │ ├── retrieval.py │ ├── recent.py │ └── retrieved.py ├── retrieve.py ├── models │ ├── retriever.py │ └── dialogpt.py ├── generate.py ├── train_generator.py ├── train_retriever.py └── eval.py └── README.md /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.64.1 2 | numpy==1.23.3 3 | nltk==3.8.1 4 | torch==1.12.1 5 | datasets==1.18.3 6 | transformers==4.17.0 7 | sentence-transformers==2.2.2 8 | pytorch-ignite==0.4.12 9 | bert-score==0.3.13 10 | git+https://github.com/google-research/bleurt.git -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/RECAP.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 14 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 shuailiu6626 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/utils/retriever_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from transformers import PreTrainedTokenizerBase 5 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union 6 | from transformers.file_utils import PaddingStrategy 7 | 8 | 9 | def encode_example(example, tokenizer): 10 | max_len=128 11 | 12 | encoded_example = {} 13 | encoded_srcs = tokenizer( 14 | [txt.replace("<|TITLE|> ", "").replace(" <|EOS|> ", tokenizer.eos_token) for txt in example['srcs']], 15 | max_length=max_len//2, 16 | truncation=True, 17 | padding="max_length", 18 | ) 19 | encoded_tgts = tokenizer( 20 | [txt for txt in example['tgts']], 21 | max_length=max_len//2, 22 | truncation=True, 23 | padding="max_length" 24 | ) 25 | 26 | encoded_example = { 27 | "srcs_ids": encoded_srcs.input_ids, 28 | "srcs_attention_mask": encoded_srcs.attention_mask, 29 | "tgts_ids": encoded_tgts.input_ids, 30 | "tgts_attention_mask": encoded_tgts.attention_mask, 31 | } 32 | 33 | return encoded_example 34 | 35 | 36 | @dataclass 37 | class DataCollator: 38 | 39 | def __call__(self, features, return_tensors=None): 40 | import numpy as np 41 | 42 | seq_num = len(features[0]['srcs_ids']) 43 | features = [{k if 'rep' not in k else 'labels': 44 | torch.vstack(v) if 'rep' not in k else v.reshape(seq_num, -1) 45 | for k, v in feature.items()} for feature in features] 46 | batch = {} 47 | for key in features[0].keys(): 48 | batch[key] = torch.cat([feature[key].unsqueeze(0) for feature in features], dim=0) 49 | 50 | return batch -------------------------------------------------------------------------------- /src/preprocess/encode_comments.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | import torch 7 | from sentence_transformers import SentenceTransformer 8 | 9 | 10 | if __name__ == '__main__': 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('-d', '--data_path', type=str, help="path to the raw data", default="data/raw") 14 | parser.add_argument('-o', '--output_path', type=str, help="output processed data path", default="temp/reps") 15 | args = parser.parse_args() 16 | 17 | models = {} 18 | models['style'] = SentenceTransformer('AnnaWegmann/Style-Embedding', device='cuda').half() 19 | models['semantic'] = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', device='cuda').half() 20 | 21 | for split in ["train", "valid", "test", "pandora"]: 22 | lines = [] 23 | with open(os.path.join(args.data_path, f"{split}.jsonl")) as f: 24 | for i, line in tqdm(enumerate(f), disable=True): 25 | line = json.loads(line) 26 | lines.append(line) 27 | lines = sorted(lines, key=lambda x:x[1]) 28 | srcs = [line[2].replace("<|TITLE|> ", "").replace(" <|EOS|> ", "\n") for line in lines] 29 | tgts = [line[3] for line in lines] 30 | 31 | for rep_type in ['style', 'semantic']: 32 | reps = {} 33 | reps['src'] = models[rep_type].encode(srcs, batch_size=1024, convert_to_tensor=False, normalize_embeddings=False, show_progress_bar=False) 34 | reps['tgt'] = models[rep_type].encode(tgts, batch_size=1024, convert_to_tensor=False, normalize_embeddings=False, show_progress_bar=False) 35 | reps['src'] = torch.from_numpy(reps['src']).half() 36 | reps['tgt'] = torch.from_numpy(reps['tgt']).half() 37 | torch.save(reps, os.path.join(args.output_path, f"{rep_type}/{split}_fp16.pt"), create_dir=True) -------------------------------------------------------------------------------- /src/preprocess/retrieval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | from collections import defaultdict 5 | import pandas as pd 6 | import random 7 | import json 8 | 9 | import torch 10 | from datasets import load_dataset, load_from_disk 11 | 12 | 13 | if __name__ == '__main__': 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('-d', '--data_path', type=str, help="path to the raw data", default="data/raw") 17 | parser.add_argument('-o', '--output_path', type=str, help="output processed data path", default="data/retrieval") 18 | args = parser.parse_args() 19 | 20 | for split in ["train", "valid", "test", "pandora"]: 21 | lines = [] 22 | user2lines = defaultdict(list) 23 | with open(os.path.join(args.data_path, f"{split}.jsonl")) as f: 24 | for i, line in tqdm(enumerate(f), disable=True): 25 | line = json.loads(line) 26 | lines.append(line) 27 | 28 | lines = sorted(lines, key=lambda x:x[1]) 29 | for i, line in enumerate(lines): 30 | author = line[0] 31 | user2lines[author].append(i) 32 | 33 | authors = sorted(user2lines.keys()) 34 | author2idx = {author: idx for idx, author in enumerate(authors)} 35 | 36 | author_comments = defaultdict(list) 37 | for author in tqdm(authors, disable=True): 38 | author_lines = user2lines.get(author, None) 39 | assert author_lines 40 | prev_t = 0 41 | for author_line in author_lines: 42 | line = lines[author_line] 43 | _author, timestamp, src, tgt, subreddit = line 44 | 45 | assert author == _author 46 | assert prev_t <= timestamp 47 | prev_t = timestamp 48 | 49 | author_comments[author].append((int(timestamp), src, tgt, author_line)) 50 | author_comments[author] = sorted(author_comments[author], key=lambda x:x[0]) 51 | assert len(author_comments) == len(authors) 52 | 53 | with open(os.path.join(args.output_path, f"{split}.jsonl"), 'w') as f: 54 | for author, samples in tqdm(author_comments.items(), disable=True): 55 | 56 | ex = { 57 | "srcs": [sample[1] for sample in samples], 58 | "tgts": [sample[2] for sample in samples], 59 | } 60 | 61 | json_line = json.dumps(ex) 62 | f.write(f"{json_line}\n") -------------------------------------------------------------------------------- /src/preprocess/recent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from collections import defaultdict 5 | 6 | 7 | 8 | if __name__ == "__main__": 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('-d', '--data_path', type=str, help="path to the raw data", default="data/raw") 12 | parser.add_argument('-o', '--output_path', type=str, help="output processed data path", default="data/recent") 13 | args = parser.parse_args() 14 | 15 | if not os.path.exists(args.output_path): 16 | os.makedirs(args.output_path) 17 | 18 | for split in ["train", "valid", "test", "pandora"]: 19 | lines = [] 20 | user2lines = defaultdict(list) 21 | with open(os.path.join(args.data_path, f"{split}.jsonl")) as f: 22 | for i, line in enumerate(f): 23 | line = json.loads(line) 24 | lines.append(line) 25 | 26 | lines = sorted(lines, key=lambda x:x[1]) 27 | for i, line in enumerate(lines): 28 | author = line[0] 29 | user2lines[author].append(i) 30 | 31 | authors = sorted(user2lines.keys()) 32 | author2idx = {author: idx for idx, author in enumerate(authors)} 33 | 34 | author_comments = defaultdict(list) 35 | for author in authors: 36 | author_lines = user2lines.get(author, None) 37 | assert author_lines 38 | for author_line in author_lines: 39 | line = lines[author_line] 40 | _author, timestamp, src, tgt, subreddit = line 41 | assert author == _author 42 | author_comments[author].append((int(timestamp), src, tgt)) 43 | author_comments[author] = sorted(author_comments[author], key=lambda x:x[0]) 44 | assert len(author_comments) == len(authors) 45 | 46 | global_i = 0 47 | with open(os.path.join(args.output_path, f"{split}.jsonl"), 'w') as f: 48 | for author, samples in author_comments.items(): 49 | recent_responses = [] 50 | for i, (timestamp, src, tgt) in enumerate(samples): 51 | ex = { 52 | "id": global_i, 53 | "author": author, 54 | "author_id": author2idx[author], 55 | "timestamp": timestamp, 56 | "src": src, 57 | "tgt": tgt, 58 | } 59 | ex["style_tgts"] = recent_responses 60 | ex["semantic_tgts"] = None 61 | 62 | if i >= 90: 63 | global_i += 1 64 | json_line = json.dumps(ex) 65 | f.write(f"{json_line}\n") 66 | 67 | recent_responses.insert(0, f"{tgt}") 68 | if len(recent_responses) > 10: 69 | del recent_responses[-1] -------------------------------------------------------------------------------- /src/preprocess/retrieved.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from tqdm import tqdm 5 | from collections import Counter, defaultdict 6 | 7 | import torch 8 | 9 | 10 | if __name__ == '__main__': 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('-d', '--data_path', type=str, help="path to the raw data", default="data/raw") 14 | parser.add_argument('-r', '--retrieved_path', type=str, help="path to the retriever output", default="temp/retrieved") 15 | parser.add_argument('-o', '--output_path', type=str, help="output processed data path", default="data/retrieved") 16 | args = parser.parse_args() 17 | 18 | retrieved = { 19 | "style": torch.load(os.path.join(args.retrieved_path, "style.pt")), 20 | "semantic": torch.load(os.path.join(args.retrieved_path, "semantic.pt")), 21 | } 22 | 23 | for split in ["train", "valid", "test", "pandora"]: 24 | sorted_style = retrieved['style'][split] 25 | sorted_semantic = retrieved['semantic'][split] 26 | 27 | lines = [] 28 | user2lines = defaultdict(list) 29 | with open(os.path.join(args.data_path, f"{split}.jsonl")) as f: 30 | for i, line in tqdm(enumerate(f), disable=True): 31 | line = json.loads(line) 32 | lines.append(line) 33 | 34 | lines = sorted(lines, key=lambda x:x[1]) 35 | for i, line in enumerate(lines): 36 | author = line[0] 37 | user2lines[author].append(i) 38 | 39 | authors = sorted(user2lines.keys()) 40 | author2idx = {author: idx for idx, author in enumerate(authors)} 41 | 42 | author_comments = defaultdict(list) 43 | for author in tqdm(authors, disable=True): 44 | author_lines = user2lines.get(author, None) 45 | assert author_lines 46 | for author_line in author_lines: 47 | line = lines[author_line] 48 | _author, timestamp, src, tgt, subreddit = line 49 | assert author == _author 50 | author_comments[author].append((int(timestamp), src, tgt, author_line)) 51 | author_comments[author] = sorted(author_comments[author], key=lambda x:x[0]) 52 | assert len(author_comments) == len(authors) 53 | 54 | global_i = 0 55 | with open(os.path.join(args.output_path, f"{split}.jsonl"), 'w') as f: 56 | for i, (author, samples) in enumerate(tqdm(author_comments.items(), disable=True)): 57 | most_similar_style = sorted_style[i*10:(i+1)*10] 58 | most_similar_semantic = sorted_semantic[i*10:(i+1)*10] 59 | for i, (timestamp, src, tgt, _) in enumerate(samples[-10:]): 60 | ex = { 61 | "id": global_i, 62 | "author": author, 63 | "author_id": author2idx[author], 64 | "timestamp": timestamp, 65 | "src": src, 66 | "tgt": tgt, 67 | } 68 | ex["style_tgts"] = [samples[j][2] for j in most_similar_style[i][:10]] 69 | ex["semantic_tgts"] = [samples[j][2] for j in most_similar_semantic[i][:10]] 70 | 71 | global_i += 1 72 | json_line = json.dumps(ex) 73 | f.write(f"{json_line}\n") 74 | -------------------------------------------------------------------------------- /src/retrieve.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | from tqdm import trange 5 | from dataclasses import dataclass 6 | from collections import defaultdict 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from datasets import load_from_disk 12 | from models.retriever import Retriever 13 | 14 | 15 | @dataclass 16 | class DataCollator: 17 | 18 | def __call__(self, features, return_tensors=None): 19 | import numpy as np 20 | 21 | seq_num = len(features[0]['srcs_ids']) 22 | features = [{k if 'rep' not in k else 'labels': 23 | torch.vstack(v) if 'rep' not in k else v.reshape(seq_num, -1) 24 | for k, v in feature.items()} for feature in features] 25 | batch = {} 26 | for key in features[0].keys(): 27 | batch[key] = torch.cat([feature[key].unsqueeze(0) for feature in features], dim=0) 28 | 29 | return batch 30 | 31 | if __name__ == "__main__": 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('-d', '--data_path', type=str, help="path to the dataset", default="data/retrieval") 35 | parser.add_argument('-m', '--model_path', type=str, help="path to the retrieval model", default="temp/retriever") 36 | parser.add_argument('-s', '--save_path', type=str, help="path to save the output", default="temp/retrieved") 37 | parser.add_argument( 38 | '-r', '--ref_type', 39 | type=str, default="style", 40 | choices=["style", "semantic"], 41 | ) 42 | parser.add_argument( 43 | '-b', '--batch_size', 44 | type=int, default=2, 45 | ) 46 | args = parser.parse_args() 47 | 48 | dataset = load_from_disk(os.path.join(args.data_path, "hf_dataset")) 49 | data_columns = [ 50 | 'srcs_ids', 51 | 'srcs_attention_mask', 52 | 'tgts_ids', 53 | 'tgts_attention_mask', 54 | f'{args.ref_type}_rep', 55 | ] 56 | dataset.set_format( 57 | type='torch', 58 | columns=data_columns, 59 | ) 60 | 61 | data_collator = DataCollator() 62 | 63 | model = Retriever("distilroberta-base", use_gold_tgt_rep=False, nhead=12) 64 | 65 | path = os.path.join(args.model_path, args.ref_type, "model/checkpoint-best/pytorch_model.bin") 66 | state_dict = torch.load(path) 67 | model.load_state_dict(state_dict) 68 | 69 | _ = model.cuda() 70 | _ = model.eval() 71 | 72 | bsz = args.batch_size 73 | retrieved_ids = defaultdict(list) 74 | for split in ['train', 'valid', 'test']: 75 | 76 | pbar = trange(0, len(dataset[split]), bsz, disable=False) 77 | for i in pbar: 78 | features = [dataset[split][j] for j in range(i, min(i+bsz, len(dataset[split])))] 79 | features = data_collator(features) 80 | batch = {k: v.cuda() for k, v in features.items()} 81 | 82 | preds = model(**batch).logits 83 | preds = F.normalize(preds, dim=-1) 84 | reps = F.normalize(batch['labels'], dim=-1) 85 | 86 | pred_sims = torch.matmul(preds, reps.transpose(1, 2)) 87 | pred_sims += -2. * torch.ones_like(pred_sims).triu(diagonal=90) 88 | pred_sorted_sims, pred_sorted_indices = pred_sims.sort(descending=True, dim=-1) 89 | retrieved_ids[split].append(pred_sorted_indices[:, :, :10].cpu()) 90 | 91 | retrieved_ids[split] = torch.cat(retrieved_ids[split], dim=0).reshape(1, -1, 10).squeeze(0) 92 | 93 | torch.save(retrieved_ids, os.path.join(args.save_path, f"{args.ref_type}.pt")) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RECAP 2 | 3 | This is the project that implements personal specific text generation project. 4 | It uses the gpt model as a basis and fine-tuning them to achieve the goal. 5 | Enjoy using this proj. 6 | # Installation 7 | 8 | Commends for enviroment setup with conda. 9 | ```bash 10 | conda create --name recap python=3.8 11 | conda activate recap 12 | pip install -U pip 13 | pip install -r requirements.txt 14 | 15 | ``` 16 | 17 | # Data 18 | 19 | The data is extracted from the Reddit dump from [pushshift.io](https://pushshift.io/). To preserve persona and personal writing style as much as possible, we did not filter out conversations with unethical content. You can download the raw data from the link [here](https://drive.google.com/file/d/1YC43Pqn15E7IIb90hjtauqRbwCOqAi3x/view?usp=sharing). 20 | 21 | # Pre-processing 22 | 23 | Pre-process the raw data into the format for retrieval and generation. 24 | 25 | ## Retrieval Data 26 | 27 | ### Encode text representations 28 | 29 | ```bash 30 | python src/preprocess/encode_comments.py -d -o 31 | ``` 32 | 33 | ### Retrieval 34 | 35 | ```bash 36 | python src/preprocess/retrieval.py -d -o 37 | ``` 38 | 39 | ## Generation Data 40 | 41 | ### Most recent hisotry responses 42 | 43 | ```bash 44 | python src/preprocess/recent.py -d -o 45 | ``` 46 | 47 | ### Retrieved by hierarchical transformer 48 | 49 | This requires the retriever output in `retrieved_path`. Please see section `training retriever` and `inference retrieve` for details on how to train and retrieve with the hierarchical transformer retriever. 50 | ```bash 51 | python src/preprocess/retrieved.py -d -r -o 52 | ``` 53 | 54 | # Training 55 | 56 | Train the retriever and the generator on a single GPU. The code works for multi GPUs, but the `batch_size` here is per device batch size, so please change it accordingly if you use more than one GPU. 57 | 58 | ## Retriever 59 | 60 | ```bash 61 | python src/train_retriever.py \ 62 | --data_path \ 63 | --raw_data_path \ 64 | --reps_path \ 65 | --save_path \ 66 | --ref_type