├── .idea
│   ├── vcs.xml
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── modules.xml
│   └── RECAP.iml
├── requirements.txt
├── LICENSE
├── src
│   ├── utils
│   │   ├── retriever_utils.py
│   │   └── generator_utils.py
│   ├── preprocess
│   │   ├── encode_comments.py
│   │   ├── retrieval.py
│   │   ├── recent.py
│   │   └── retrieved.py
│   ├── retrieve.py
│   ├── models
│   │   ├── retriever.py
│   │   └── dialogpt.py
│   ├── generate.py
│   ├── train_generator.py
│   ├── train_retriever.py
│   └── eval.py
└── README.md
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm==4.64.1
2 | numpy==1.23.3
3 | nltk==3.8.1
4 | torch==1.12.1
5 | datasets==1.18.3
6 | transformers==4.17.0
7 | sentence-transformers==2.2.2
8 | pytorch-ignite==0.4.12
9 | bert-score==0.3.13
10 | git+https://github.com/google-research/bleurt.git
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 shuailiu6626
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/utils/retriever_utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | 
3 | import torch
4 | from transformers import PreTrainedTokenizerBase
5 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
6 | from transformers.file_utils import PaddingStrategy
7 | 
8 | 
9 | def encode_example(example, tokenizer):
10 |     max_len = 128
11 | 
12 |     encoded_example = {}
13 |     encoded_srcs = tokenizer(
14 |         [txt.replace("<|TITLE|> ", "").replace(" <|EOS|> ", tokenizer.eos_token) for txt in example['srcs']],
15 |         max_length=max_len//2,
16 |         truncation=True,
17 |         padding="max_length",
18 |     )
19 |     encoded_tgts = tokenizer(
20 |         [txt for txt in example['tgts']],
21 |         max_length=max_len//2,
22 |         truncation=True,
23 |         padding="max_length"
24 |     )
25 | 
26 |     encoded_example = {
27 |         "srcs_ids": encoded_srcs.input_ids,
28 |         "srcs_attention_mask": encoded_srcs.attention_mask,
29 |         "tgts_ids": encoded_tgts.input_ids,
30 |         "tgts_attention_mask": encoded_tgts.attention_mask,
31 |     }
32 | 
33 |     return encoded_example
34 | 
35 | 
36 | @dataclass
37 | class DataCollator:
38 | 
39 |     def __call__(self, features, return_tensors=None):
40 |         # stack each example's token tensors; the representation column ('*_rep') becomes 'labels'
41 | 
42 |         seq_num = len(features[0]['srcs_ids'])
43 |         features = [{k if 'rep' not in k else 'labels':
44 |                          torch.vstack(v) if 'rep' not in k else v.reshape(seq_num, -1)
45 |                      for k, v in feature.items()} for feature in features]
46 |         batch = {}
47 |         for key in features[0].keys():
48 |             batch[key] = torch.cat([feature[key].unsqueeze(0) for feature in features], dim=0)
49 | 
50 |         return batch
--------------------------------------------------------------------------------
/src/preprocess/encode_comments.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from tqdm import tqdm
5 | 
6 | import torch
7 | from sentence_transformers import SentenceTransformer
8 | 
9 | 
10 | if __name__ == '__main__':
11 | 
12 |     parser = argparse.ArgumentParser()
13 |     parser.add_argument('-d', '--data_path', type=str, help="path to the raw data", default="data/raw")
14 |     parser.add_argument('-o', '--output_path', type=str, help="output processed data path", default="temp/reps")
15 |     args = parser.parse_args()
16 | 
17 |     models = {}
18 |     models['style'] = SentenceTransformer('AnnaWegmann/Style-Embedding', device='cuda').half()
19 |     models['semantic'] = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', device='cuda').half()
20 | 
21 |     for split in ["train", "valid", "test", "pandora"]:
22 |         lines = []
23 |         with open(os.path.join(args.data_path, f"{split}.jsonl")) as f:
24 |             for i, line in tqdm(enumerate(f), disable=True):
25 |                 line = json.loads(line)
26 |                 lines.append(line)
27 |         lines = sorted(lines, key=lambda x:x[1])
28 |         srcs = [line[2].replace("<|TITLE|> ", "").replace(" <|EOS|> ", "\n") for line in lines]
29 |         tgts = [line[3] for line in lines]
30 | 
31 |         for rep_type in ['style', 'semantic']:
32 |             reps = {}
33 |             reps['src'] = models[rep_type].encode(srcs, batch_size=1024, convert_to_tensor=False, normalize_embeddings=False, show_progress_bar=False)
34 |             reps['tgt'] = models[rep_type].encode(tgts, batch_size=1024, convert_to_tensor=False, normalize_embeddings=False, show_progress_bar=False)
35 |             reps['src'] = torch.from_numpy(reps['src']).half()
36 |             reps['tgt'] = torch.from_numpy(reps['tgt']).half()
37 |             # make sure the per-representation output directory exists before saving
38 |             os.makedirs(os.path.join(args.output_path, rep_type), exist_ok=True)
39 |             torch.save(reps, os.path.join(args.output_path, f"{rep_type}/{split}_fp16.pt"))
--------------------------------------------------------------------------------
/src/preprocess/retrieval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from tqdm import tqdm
4 | from collections import defaultdict
5 | import pandas as pd
6 | import random
7 | import json
8 | 
9 | import torch
10 | from datasets import load_dataset, load_from_disk
11 | 
12 | 
13 | if __name__ == '__main__':
14 | 
15 |     parser = argparse.ArgumentParser()
16 |     parser.add_argument('-d', '--data_path', type=str, help="path to the raw data", default="data/raw")
17 |     parser.add_argument('-o', '--output_path', type=str, help="output processed data path", default="data/retrieval")
18 |     args = parser.parse_args()
19 |     os.makedirs(args.output_path, exist_ok=True)
20 | 
21 |     for split in ["train", "valid", "test", "pandora"]:
22 |         lines = []
23 |         user2lines = defaultdict(list)
24 |         with open(os.path.join(args.data_path, f"{split}.jsonl")) as f:
25 |             for i, line in tqdm(enumerate(f), disable=True):
26 |                 line = json.loads(line)
27 |                 lines.append(line)
28 | 
29 |         lines = sorted(lines, key=lambda x:x[1])
30 |         for i, line in enumerate(lines):
31 |             author = line[0]
32 |             user2lines[author].append(i)
33 | 
34 |         authors = sorted(user2lines.keys())
35 |         author2idx = {author: idx for idx, author in enumerate(authors)}
36 | 
37 |         author_comments = defaultdict(list)
38 |         for author in tqdm(authors, disable=True):
39 |             author_lines = user2lines.get(author, None)
40 |             assert author_lines
41 |             prev_t = 0
42 |             for author_line in author_lines:
43 |                 line = lines[author_line]
44 |                 _author, timestamp, src, tgt, subreddit = line
45 | 
46 |                 assert author == _author
47 |                 assert prev_t <= timestamp
48 |                 prev_t = timestamp
49 | 
50 |                 author_comments[author].append((int(timestamp), src, tgt, author_line))
51 |             author_comments[author] = sorted(author_comments[author], key=lambda x:x[0])
52 |         assert len(author_comments) == len(authors)
53 | 
54 |         with open(os.path.join(args.output_path, f"{split}.jsonl"), 'w') as f:
55 |             for author, samples in tqdm(author_comments.items(), disable=True):
56 | 
57 |                 ex = {
58 |                     "srcs": [sample[1] for sample in samples],
59 |                     "tgts": [sample[2] for sample in samples],
60 |                 }
61 | 
62 |                 json_line = json.dumps(ex)
63 |                 f.write(f"{json_line}\n")
--------------------------------------------------------------------------------
/src/preprocess/recent.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from collections import defaultdict
5 | 
6 | 
7 | 
8 | if __name__ == "__main__":
9 | 
10 |     parser = argparse.ArgumentParser()
11 |     parser.add_argument('-d', '--data_path', type=str, help="path to the raw data", default="data/raw")
12 |     parser.add_argument('-o', '--output_path', type=str, help="output processed data path", default="data/recent")
13 |     args = parser.parse_args()
14 | 
15 |     if not os.path.exists(args.output_path):
16 |         os.makedirs(args.output_path)
17 | 
18 |     for split in ["train", "valid", "test", "pandora"]:
19 |         lines = []
20 |         user2lines = defaultdict(list)
21 |         with open(os.path.join(args.data_path, f"{split}.jsonl")) as f:
22 |             for i, line in enumerate(f):
23 |                 line = json.loads(line)
24 |                 lines.append(line)
25 | 
26 |         lines = sorted(lines, key=lambda x:x[1])
27 |         for i, line in enumerate(lines):
28 |             author = line[0]
29 |             user2lines[author].append(i)
30 | 
31 |         authors = sorted(user2lines.keys())
32 |         author2idx = {author: idx for idx, author in enumerate(authors)}
33 | 
34 |         author_comments = defaultdict(list)
35 |         for author in authors:
36 |             author_lines = user2lines.get(author, None)
37 |             assert author_lines
38 |             for author_line in author_lines:
39 |                 line = lines[author_line]
40 |                 _author, timestamp, src, tgt, subreddit = line
41 |                 assert author == _author
42 |                 author_comments[author].append((int(timestamp), src, tgt))
43 |             author_comments[author] = sorted(author_comments[author], key=lambda x:x[0])
44 |         assert len(author_comments) == len(authors)
45 | 
46 |         global_i = 0
47 |         with open(os.path.join(args.output_path, f"{split}.jsonl"), 'w') as f:
48 |             for author, samples in author_comments.items():
49 |                 recent_responses = []
50 |                 for i, (timestamp, src, tgt) in enumerate(samples):
51 |                     ex = {
52 |                         "id": global_i,
53 |                         "author": author,
54 |                         "author_id": author2idx[author],
55 |                         "timestamp": timestamp,
56 |                         "src": src,
57 |                         "tgt": tgt,
58 |                     }
59 |                     ex["style_tgts"] = recent_responses
60 |                     ex["semantic_tgts"] = None
61 | 
62 |                     # only comments after an author's first 90 are written out as examples;
63 |                     # the earlier ones serve purely as history
64 |                     if i >= 90:
65 |                         global_i += 1
66 |                         json_line = json.dumps(ex)
67 |                         f.write(f"{json_line}\n")
68 | 
69 |                     # keep the author's 10 most recent responses (newest first) as style references
70 |                     recent_responses.insert(0, f"{tgt}")
71 |                     if len(recent_responses) > 10:
72 |                         del recent_responses[-1]
--------------------------------------------------------------------------------
/src/preprocess/retrieved.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from tqdm import tqdm
5 | from collections import Counter, defaultdict
6 | 
7 | import torch
8 | 
9 | 
10 | if __name__ == '__main__':
11 | 
12 |     parser = argparse.ArgumentParser()
13 |     parser.add_argument('-d', '--data_path', type=str, help="path to the raw data", default="data/raw")
14 |     parser.add_argument('-r', '--retrieved_path', type=str, help="path to the retriever output", default="temp/retrieved")
15 |     parser.add_argument('-o', '--output_path', type=str, help="output processed data path", default="data/retrieved")
16 |     args = parser.parse_args()
17 |     os.makedirs(args.output_path, exist_ok=True)
18 | 
19 |     retrieved = {
20 |         "style": torch.load(os.path.join(args.retrieved_path, "style.pt")),
21 |         "semantic": torch.load(os.path.join(args.retrieved_path, "semantic.pt")),
22 |     }
23 | 
24 |     for split in ["train", "valid", "test", "pandora"]:
25 |         sorted_style = retrieved['style'][split]
26 |         sorted_semantic = retrieved['semantic'][split]
27 | 
28 |         lines = []
29 |         user2lines = defaultdict(list)
30 |         with open(os.path.join(args.data_path, f"{split}.jsonl")) as f:
31 |             for i, line in tqdm(enumerate(f), disable=True):
32 |                 line = json.loads(line)
33 |                 lines.append(line)
34 | 
35 |         lines = sorted(lines, key=lambda x:x[1])
36 |         for i, line in enumerate(lines):
37 |             author = line[0]
38 |             user2lines[author].append(i)
39 | 
40 |         authors = sorted(user2lines.keys())
41 |         author2idx = {author: idx for idx, author in enumerate(authors)}
42 | 
43 |         author_comments = defaultdict(list)
44 |         for author in tqdm(authors, disable=True):
45 |             author_lines = user2lines.get(author, None)
46 |             assert author_lines
47 |             for author_line in author_lines:
48 |                 line = lines[author_line]
49 |                 _author, timestamp, src, tgt, subreddit = line
50 |                 assert author == _author
51 |                 author_comments[author].append((int(timestamp), src, tgt, author_line))
52 |             author_comments[author] = sorted(author_comments[author], key=lambda x:x[0])
53 |         assert len(author_comments) == len(authors)
54 | 
55 |         global_i = 0
56 |         with open(os.path.join(args.output_path, f"{split}.jsonl"), 'w') as f:
57 |             for i, (author, samples) in enumerate(tqdm(author_comments.items(), disable=True)):
58 |                 # the retriever output holds 10 rows of sorted candidate indices per author,
59 |                 # one row for each of the author's last 10 comments
60 |                 most_similar_style = sorted_style[i*10:(i+1)*10]
61 |                 most_similar_semantic = sorted_semantic[i*10:(i+1)*10]
62 |                 for k, (timestamp, src, tgt, _) in enumerate(samples[-10:]):
63 |                     ex = {
64 |                         "id": global_i,
65 |                         "author": author,
66 |                         "author_id": author2idx[author],
67 |                         "timestamp": timestamp,
68 |                         "src": src,
69 |                         "tgt": tgt,
70 |                     }
71 |                     ex["style_tgts"] = [samples[j][2] for j in most_similar_style[k][:10]]
72 |                     ex["semantic_tgts"] = [samples[j][2] for j in most_similar_semantic[k][:10]]
73 | 
74 |                     global_i += 1
75 |                     json_line = json.dumps(ex)
76 |                     f.write(f"{json_line}\n")
77 | 
/src/retrieve.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | from tqdm import trange
5 | from dataclasses import dataclass
6 | from collections import defaultdict
7 | 
8 | import torch
9 | from torch import nn
10 | import torch.nn.functional as F
11 | from datasets import load_from_disk
12 | from models.retriever import Retriever
13 | 
14 | 
15 | @dataclass
16 | class DataCollator:
17 | 
18 |     def __call__(self, features, return_tensors=None):
19 |         # stack each example's token tensors; the representation column ('*_rep') becomes 'labels'
20 | 
21 |         seq_num = len(features[0]['srcs_ids'])
22 |         features = [{k if 'rep' not in k else 'labels':
23 |                          torch.vstack(v) if 'rep' not in k else v.reshape(seq_num, -1)
24 |                      for k, v in feature.items()} for feature in features]
25 |         batch = {}
26 |         for key in features[0].keys():
27 |             batch[key] = torch.cat([feature[key].unsqueeze(0) for feature in features], dim=0)
28 | 
29 |         return batch
30 | 
31 | if __name__ == "__main__":
32 | 
33 |     parser = argparse.ArgumentParser()
34 |     parser.add_argument('-d', '--data_path', type=str, help="path to the dataset", default="data/retrieval")
35 |     parser.add_argument('-m', '--model_path', type=str, help="path to the retrieval model", default="temp/retriever")
36 |     parser.add_argument('-s', '--save_path', type=str, help="path to save the output", default="temp/retrieved")
37 |     parser.add_argument(
38 |         '-r', '--ref_type',
39 |         type=str, default="style",
40 |         choices=["style", "semantic"],
41 |     )
42 |     parser.add_argument(
43 |         '-b', '--batch_size',
44 |         type=int, default=2,
45 |     )
46 |     args = parser.parse_args()
47 | 
48 |     dataset = load_from_disk(os.path.join(args.data_path, "hf_dataset"))
49 |     data_columns = [
50 |         'srcs_ids',
51 |         'srcs_attention_mask',
52 |         'tgts_ids',
53 |         'tgts_attention_mask',
54 |         f'{args.ref_type}_rep',
55 |     ]
56 |     dataset.set_format(
57 |         type='torch',
58 |         columns=data_columns,
59 |     )
60 | 
61 |     data_collator = DataCollator()
62 | 
63 |     model = Retriever("distilroberta-base", use_gold_tgt_rep=False, nhead=12)
64 | 
65 |     path = os.path.join(args.model_path, args.ref_type, "model/checkpoint-best/pytorch_model.bin")
66 |     state_dict = torch.load(path)
67 |     model.load_state_dict(state_dict)
68 | 
69 |     _ = model.cuda()
70 |     _ = model.eval()
71 | 
72 |     bsz = args.batch_size
73 |     retrieved_ids = defaultdict(list)
74 |     for split in ['train', 'valid', 'test']:
75 | 
76 |         pbar = trange(0, len(dataset[split]), bsz, disable=False)
77 |         for i in pbar:
78 |             features = [dataset[split][j] for j in range(i, min(i+bsz, len(dataset[split])))]
79 |             features = data_collator(features)
80 |             batch = {k: v.cuda() for k, v in features.items()}
81 | 
82 |             preds = model(**batch).logits
83 |             preds = F.normalize(preds, dim=-1)
84 |             reps = F.normalize(batch['labels'], dim=-1)
85 | 
86 |             # cosine similarity between predicted and reference representations;
87 |             # candidates 90 or more positions after a query are pushed out of the ranking
88 |             pred_sims = torch.matmul(preds, reps.transpose(1, 2))
89 |             pred_sims += -2. * torch.ones_like(pred_sims).triu(diagonal=90)
90 |             pred_sorted_sims, pred_sorted_indices = pred_sims.sort(descending=True, dim=-1)
91 |             retrieved_ids[split].append(pred_sorted_indices[:, :, :10].cpu())
92 | 
93 |         retrieved_ids[split] = torch.cat(retrieved_ids[split], dim=0).reshape(1, -1, 10).squeeze(0)
94 | 
95 |     os.makedirs(args.save_path, exist_ok=True)
96 |     torch.save(retrieved_ids, os.path.join(args.save_path, f"{args.ref_type}.pt"))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RECAP
2 |
3 | RECAP implements personalized response generation for Reddit comments.
4 | It fine-tunes a DialoGPT-based generator and conditions it on an author's earlier responses, selected either by recency or by a hierarchical transformer retriever that uses style and semantic representations.
5 | Enjoy using this project.
6 | # Installation
7 |
8 | Commands for environment setup with conda:
9 | ```bash
10 | conda create --name recap python=3.8
11 | conda activate recap
12 | pip install -U pip
13 | pip install -r requirements.txt
14 |
15 | ```
16 |
17 | # Data
18 |
19 | The data is extracted from the Reddit dump from [pushshift.io](https://pushshift.io/). To preserve persona and personal writing style as much as possible, we did not filter out conversations with unethical content. You can download the raw data from the link [here](https://drive.google.com/file/d/1YC43Pqn15E7IIb90hjtauqRbwCOqAi3x/view?usp=sharing).
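
For reference, the preprocessing scripts expect each line of `{train,valid,test,pandora}.jsonl` to be a JSON array of the form `[author, timestamp, src, tgt, subreddit]`, where `src` marks the post title with `<|TITLE|>` and separates turns with `<|EOS|>`. A minimal sketch of reading such a record (the values below are illustrative, not taken from the dataset):

```python
import json

# Illustrative record; the field order follows src/preprocess/*.py
line = json.dumps([
    "some_user",                                                       # author
    1609459200,                                                        # UTC timestamp
    "<|TITLE|> What are you reading? <|EOS|> I just finished Dune.",   # conversational context (src)
    "Same here, and the sequels are worth it too.",                    # the author's response (tgt)
    "books",                                                           # subreddit
])

author, timestamp, src, tgt, subreddit = json.loads(line)
context = src.replace("<|TITLE|> ", "").replace(" <|EOS|> ", "\n")  # as done in encode_comments.py
```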
20 |
21 | # Pre-processing
22 |
23 | Pre-process the raw data into the format for retrieval and generation.
24 |
25 | ## Retrieval Data
26 |
27 | ### Encode text representations
28 |
29 | ```bash
30 | python src/preprocess/encode_comments.py -d <data_path> -o <output_path>
31 | ```
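
With the default paths, this writes one file per representation type and split (e.g. `temp/reps/style/train_fp16.pt`), each holding a dict of fp16 embedding matrices for the contexts (`src`) and responses (`tgt`). A quick sanity check of the output, assuming the default output path:

```python
import torch

# assumes the default output path of encode_comments.py
reps = torch.load("temp/reps/style/train_fp16.pt")
print(reps["src"].shape, reps["src"].dtype)  # (num_comments, embedding_dim) torch.float16
print(reps["tgt"].shape)
```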
32 |
33 | ### Retrieval
34 |
35 | ```bash
36 | python src/preprocess/retrieval.py -d <data_path> -o <output_path>
37 | ```
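
This groups the comments by author and writes one JSON object per author, with that author's contexts and responses in chronological order. A sketch of one output line (values illustrative):

```python
# Illustrative line from the retrieval data, one object per author
{
    "srcs": ["<|TITLE|> ... <|EOS|> ...", "..."],  # the author's contexts, oldest first
    "tgts": ["...", "..."],                        # the author's responses, aligned with srcs
}
```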
38 |
39 | ## Generation Data
40 |
41 | ### Most recent history responses
42 |
43 | ```bash
44 | python src/preprocess/recent.py -d <data_path> -o <output_path>
45 | ```
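
Each line written by `recent.py` is a JSON object with the fields below; `style_tgts` holds up to the author's 10 most recent earlier responses (newest first), and `semantic_tgts` is left empty in this recency-based variant. A sketch of one record (values illustrative):

```python
# Illustrative line from the recency-based generation data
{
    "id": 0,                              # running example id
    "author": "some_user",
    "author_id": 42,                      # index into the sorted author list
    "timestamp": 1609459200,
    "src": "<|TITLE|> ... <|EOS|> ...",   # conversational context
    "tgt": "...",                         # the response to be generated
    "style_tgts": ["most recent earlier response", "..."],
    "semantic_tgts": None,
}
```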
46 |
47 | ### Retrieved by hierarchical transformer
48 |
49 | This step requires the retriever output in `retrieved_path`. Please see the retriever training and retrieval sections below for details on how to train the hierarchical transformer retriever and run retrieval with it.
50 | ```bash
51 | python src/preprocess/retrieved.py -d <data_path> -r <retrieved_path> -o <output_path>
52 | ```
53 |
54 | # Training
55 |
56 | Train the retriever and the generator on a single GPU. The code also works with multiple GPUs, but `batch_size` here is the per-device batch size, so adjust it accordingly if you use more than one GPU.
57 |
58 | ## Retriever
59 |
60 | ```bash
61 | python src/train_retriever.py \
62 |     --data_path <data_path> \
63 |     --raw_data_path <raw_data_path> \
64 |     --reps_path <reps_path> \
65 |     --save_path <save_path> \
66 |     --ref_type <ref_type>