├── img ├── dataset.png └── framework.png ├── dataset └── sample_data.pkl ├── requirements.txt ├── utils ├── evaluation.py ├── data_loader.py └── loss.py ├── src ├── attention.py └── TempATT.py ├── main.py └── README.md /img/dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSAIL-SKKU/Temporal-Symptom-Aware-Multitask-Learning-KDD23/HEAD/img/dataset.png -------------------------------------------------------------------------------- /img/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSAIL-SKKU/Temporal-Symptom-Aware-Multitask-Learning-KDD23/HEAD/img/framework.png -------------------------------------------------------------------------------- /dataset/sample_data.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSAIL-SKKU/Temporal-Symptom-Aware-Multitask-Learning-KDD23/HEAD/dataset/sample_data.pkl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning==1.7.0 2 | torch==1.12.1+cu113 3 | tqdm==4.65.0 4 | transformers==4.29.2 5 | typing-extensions==4.6.3 6 | urllib3==1.24 7 | -------------------------------------------------------------------------------- /utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from datetime import datetime 4 | from pprint import pprint 5 | from pathlib import Path 6 | import pickle 7 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score,classification_report 8 | 9 | def evaluation(config, outputs, _type,y_true_col,y_pred_col,user_id_col): 10 | 11 | if _type == 'fs': 12 | if config['s_y_num'] == 4: 13 | label_names = ['su_indicator', 'su_ideation','su_behavior', 'su_attempt'] 14 | elif config['s_y_num']== 3: 15 | label_names = ['su_indicator', 'su_ideation','su_behav + att'] 16 | elif config['s_y_num'] == 2: 17 | label_names = ['su_indicator', 'su_id+beh+att'] 18 | 19 | if _type == 'bd': 20 | label_names = ['bp_no','bp_remission', 'bp_manic', 'bp_irritability', 21 | 'bp_anxiety', 'bp_depressed', 'bp_psychosis', 'bp_somatic'] 22 | 23 | y_true = [] 24 | y_pred = [] 25 | user_id = [] 26 | 27 | for i in outputs: 28 | y_true += i[y_true_col] 29 | y_pred += i[y_pred_col] 30 | user_id += i[user_id_col] 31 | 32 | y_true = np.asanyarray(y_true) 33 | y_pred = np.asanyarray(y_pred) 34 | user_id = np.asanyarray(user_id) 35 | 36 | pred_dict = {} 37 | pred_dict['user_id']= user_id 38 | pred_dict['y_true']= y_true 39 | pred_dict['y_pred']= y_pred 40 | 41 | print("-------test_report-------") 42 | metrics_dict = classification_report(y_true, y_pred,zero_division=1, 43 | target_names = label_names, 44 | output_dict=True) 45 | df_result = pd.DataFrame(metrics_dict).transpose() 46 | pprint(df_result) 47 | 48 | print("-------save test_report-------") 49 | save_time = datetime.now().__format__("%m%d_%H%M%S%Z") 50 | save_path = f"../result/" 51 | Path(f"{save_path}/pred").mkdir(parents=True, exist_ok=True) 52 | 53 | df_result.to_csv(f'../result/{save_time}_{_type}.csv') 54 | with open(f'{save_path}pred/{save_time}_{_type}_pred.pkl', "wb") as outfile: 55 | pickle.dump(pred_dict, outfile) -------------------------------------------------------------------------------- /utils/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import torch.nn as nn 4 | from datetime import datetime 5 | import numpy as np 6 | 7 | 8 | def pad_collate_reddit(batch): 9 | s_y = [item[0] for item in batch] 10 | b_y = [item[1] for item in batch] 11 | tweets = [torch.nan_to_num(item[2]) for item in batch] # 각자 잘하는 embedding이 달라서! 12 | timestamp = [item[3] for item in batch] 13 | user_id = [item[4] for item in batch] 14 | 15 | #torch.nan_to_num(a) 16 | post_num = [len(x) for x in b_y] 17 | 18 | b_y = nn.utils.rnn.pad_sequence(b_y, batch_first=True, padding_value=0) 19 | tweets = nn.utils.rnn.pad_sequence(tweets, batch_first=True, padding_value=0) 20 | timestamp = nn.utils.rnn.pad_sequence(timestamp, batch_first=True, padding_value=0) 21 | 22 | 23 | 24 | post_num = torch.tensor(post_num) 25 | s_y = torch.tensor(s_y) 26 | user_id = torch.tensor(user_id) 27 | 28 | return [s_y, b_y, post_num, tweets,timestamp,user_id] 29 | 30 | def get_timestamp(x): 31 | def change_utc(x): 32 | try: 33 | x = str(datetime.fromtimestamp(int(x)/1000)) 34 | return x 35 | except: 36 | return str(x) 37 | 38 | timestamp = [datetime.timestamp(datetime.strptime(change_utc(t),"%Y-%m-%d %H:%M:%S")) for t in x] 39 | time_interval = (timestamp[-1] - np.array(timestamp)) 40 | return time_interval 41 | 42 | 43 | class RedditDataset(Dataset): 44 | def __init__(self, s_y, b_y, tweets,timestamp,user_id, days=30): 45 | super().__init__() 46 | self.s_y = s_y 47 | self.b_y = b_y 48 | self.tweets = tweets 49 | self.timestamp = timestamp 50 | self.user_id = user_id 51 | 52 | self.days = days 53 | 54 | def __len__(self): 55 | return len(self.s_y) 56 | 57 | def __getitem__(self, item): 58 | s_y = torch.tensor(self.s_y[item], dtype=torch.long) 59 | user_id = torch.tensor(self.user_id[item], dtype=torch.long) 60 | 61 | if self.days > len(self.tweets[item]): 62 | b_y = torch.tensor(self.b_y[item], dtype=torch.long) 63 | tweets = torch.tensor(self.tweets[item], dtype=torch.float32) 64 | timestamp = get_timestamp(self.timestamp[item]) 65 | timestamp = torch.tensor(timestamp, dtype=torch.float32) 66 | 67 | else: 68 | b_y = torch.tensor(self.b_y[item][:self.days], dtype=torch.long) 69 | tweets = torch.tensor(self.tweets[item][:self.days], dtype=torch.float32) 70 | timestamp = get_timestamp(self.timestamp[item][:self.days]) 71 | timestamp = torch.tensor(timestamp, dtype=torch.float32) 72 | 73 | return [s_y, b_y, tweets,timestamp,user_id] 74 | 75 | 76 | -------------------------------------------------------------------------------- /src/attention.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from pprint import pprint 5 | from pathlib import Path 6 | from collections import Counter 7 | import pickle 8 | import random 9 | import argparse 10 | import time 11 | from datetime import datetime 12 | import math 13 | 14 | # torch: 15 | import torch 16 | from torch import nn 17 | import torch.nn.functional as F 18 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 19 | from torch.utils.data import Dataset, DataLoader 20 | from torch.optim.lr_scheduler import ExponentialLR 21 | 22 | 23 | class Attention(nn.Module): 24 | def __init__(self, device,hidden_size, batch_first=False): 25 | super(Attention, self).__init__() 26 | 27 | self.hidden_size = hidden_size 28 | self.batch_first = batch_first 29 | self.device = device 30 | self.att_weights = nn.Parameter(torch.Tensor(1, hidden_size), requires_grad=True) 31 | 32 | stdv = 1.0 / np.sqrt(self.hidden_size) 33 | for weight in self.att_weights: 34 | nn.init.uniform_(weight, -stdv, stdv) 35 | 36 | def get_mask(self): 37 | pass 38 | 39 | def forward(self, inputs, lengths): 40 | if self.batch_first: 41 | batch_size, max_len = inputs.size()[:2] 42 | else: 43 | max_len, batch_size = inputs.size()[:2] 44 | 45 | # apply attention layer 46 | weights = torch.bmm(torch.tanh(inputs), 47 | self.att_weights # (1, hidden_size) 48 | .permute(1, 0) # (hidden_size, 1) 49 | .unsqueeze(0) # (1, hidden_size, 1) 50 | .repeat(batch_size, 1, 1) # (batch_size, hidden_size, 1) 51 | ) 52 | 53 | attentions = torch.softmax(F.relu(weights.squeeze()), dim=-1) #F.relu # , dim=-1 54 | 55 | # create mask based on the sentence lengths 56 | mask = torch.ones(attentions.size(), requires_grad=True).to(f"cuda:{self.device}") 57 | for i, l in enumerate(lengths): # skip the first sentence 58 | if l < max_len: 59 | mask[i, l:] = 0 60 | 61 | # apply mask and renormalize attention scores (weights) 62 | masked = attentions * mask 63 | _sums = masked.sum(-1).unsqueeze(-1) # sums per row 64 | 65 | attentions = masked.div(_sums) 66 | 67 | if attentions.dim() == 1: 68 | attentions = attentions.unsqueeze(1) 69 | 70 | # apply attention weights 71 | weighted = torch.mul(inputs, attentions.unsqueeze(-1).expand_as(inputs)) 72 | 73 | # get the final fixed vector representations of the sentences 74 | representations = weighted.sum(1).squeeze() 75 | 76 | return representations, attentions -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import random 5 | 6 | import torch 7 | from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything 8 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 9 | from pytorch_lightning.callbacks import Callback, ModelCheckpoint 10 | from pytorch_lightning.loggers import TensorBoardLogger 11 | 12 | ## 추가 13 | import configparser 14 | import warnings 15 | warnings.filterwarnings('ignore') 16 | 17 | from src.TempATT import TempATT 18 | 19 | def th_seed_everything(seed: int = 2023): 20 | random.seed(seed) 21 | np.random.seed(seed) 22 | os.environ["PYTHONHASHSEED"] = str(seed) 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed(seed) # type: ignore 25 | torch.backends.cudnn.deterministic = True # type: ignore 26 | torch.backends.cudnn.benchmark = True # type: ignore 27 | 28 | class Arg: 29 | epochs: int = 1 # Max Epochs, BERT paper setting [3,4,5] 30 | report_cycle: int = 30 # Report (Train Metrics) Cycle 31 | cpu_workers: int = os.cpu_count() # Multi cpu workers 32 | test_mode: bool = False # Test Mode enables `fast_dev_run` 33 | optimizer: str = 'AdamW' # AdamW vs AdamP 34 | lr_scheduler: str = 'exp' # ExponentialLR vs CosineAnnealingWarmRestarts 35 | fp16: bool = False # Enable train on FP16 36 | batch_size: int = 64 37 | max_post_num = 30 38 | task_num: int = 2 39 | 40 | def main(args,config): 41 | print("Using PyTorch Ver", torch.__version__) 42 | print("Fix Seed:", config['random_seed']) 43 | seed_everything(config['random_seed']) 44 | th_seed_everything(config['random_seed']) 45 | 46 | # 일단 mood 47 | model = TempATT(args,config) 48 | model.preprocess_dataframe() 49 | 50 | early_stop_callback = EarlyStopping( 51 | monitor='train_loss', 52 | patience=10, 53 | verbose=True, 54 | mode='min' 55 | ) 56 | 57 | print(":: Start Training ::") 58 | trainer = Trainer( 59 | logger=False, 60 | callbacks=[early_stop_callback], 61 | enable_checkpointing = False, 62 | max_epochs=args.epochs, 63 | fast_dev_run=args.test_mode, 64 | num_sanity_val_steps=None if args.test_mode else 0, 65 | deterministic=True, 66 | gpus=[config['gpu']] if torch.cuda.is_available() else None, 67 | precision=16 if args.fp16 else 32 68 | ) 69 | trainer.fit(model) 70 | trainer.test(model,dataloaders=model.test_dataloader()) 71 | 72 | if __name__ == '__main__': 73 | 74 | parser = argparse.ArgumentParser("main.py", formatter_class=argparse.ArgumentDefaultsHelpFormatter) 75 | parser.add_argument("--dropout", type=float, default=0.01,help="dropout probablity") 76 | parser.add_argument("--lr", type=float, default=1e-5, help="learning rate") 77 | parser.add_argument("--gpu", type=int, default=1, help="save fname") 78 | parser.add_argument("--random_seed", type=int, default=2022) 79 | parser.add_argument("--bf", type=int, default=6) 80 | parser.add_argument("--af", type=int, default=30) 81 | parser.add_argument("--embed_type", type=str, default="sb") 82 | parser.add_argument("--hidden_dim", type=int, default=1024) 83 | parser.add_argument("--loss", type=str, default="oe") 84 | parser.add_argument("--save", type=str, default="test") 85 | parser.add_argument("--s_y_num", type=int, default=2) 86 | parser.add_argument("--b_y_num", type=int, default=8) 87 | parser.add_argument("--n_fold", type=int, default=4) 88 | 89 | config = parser.parse_args() 90 | print(config) 91 | args = Arg() 92 | 93 | main(args,config.__dict__) 94 | 95 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | def focal_loss(labels, logits, alpha, gamma): 7 | """Compute the focal loss between `logits` and the ground truth `labels`. 8 | Focal loss = -alpha_t * (1-pt)^gamma * log(pt) 9 | where pt is the probability of being classified to the true class. 10 | pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit). 11 | Args: 12 | labels: A float tensor of size [batch, num_classes]. 13 | logits: A float tensor of size [batch, num_classes]. 14 | alpha: A float tensor of size [batch_size] 15 | specifying per-example weight for balanced cross entropy. 16 | gamma: A float scalar modulating loss from hard and easy examples. 17 | Returns: 18 | focal_loss: A float32 scalar representing normalized total loss. 19 | """ 20 | BCLoss = F.binary_cross_entropy_with_logits(input=logits, target=labels, reduction="none") 21 | 22 | if gamma == 0.0: 23 | modulator = 1.0 24 | else: 25 | modulator = torch.exp(-gamma * labels * logits - gamma * torch.log(1 + torch.exp(-1.0 * logits))) 26 | 27 | loss = modulator * BCLoss 28 | 29 | weighted_loss = alpha * loss 30 | focal_loss = torch.sum(weighted_loss) 31 | 32 | focal_loss /= torch.sum(labels) 33 | return focal_loss 34 | 35 | 36 | def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma): 37 | """Compute the Class Balanced Loss between `logits` and the ground truth `labels`. 38 | Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits) 39 | where Loss is one of the standard losses used for Neural Networks. 40 | Args: 41 | labels: A int tensor of size [batch]. 42 | logits: A float tensor of size [batch, no_of_classes]. 43 | samples_per_cls: A python list of size [no_of_classes]. 44 | no_of_classes: total number of classes. int 45 | loss_type: string. One of "sigmoid", "focal", "softmax". 46 | beta: float. Hyperparameter for Class balanced loss. 47 | gamma: float. Hyperparameter for Focal loss. 48 | Returns: 49 | cb_loss: A float tensor representing class balanced loss 50 | """ 51 | effective_num = 1.0 - np.power(beta, samples_per_cls) 52 | weights = (1.0 - beta) / np.array(effective_num) 53 | weights = weights / np.sum(weights) * no_of_classes 54 | 55 | labels_one_hot = F.one_hot(labels, no_of_classes).float() 56 | weights = torch.tensor(weights, dtype=torch.float32).cuda() 57 | weights = weights.unsqueeze(0) 58 | weights = weights.repeat(labels_one_hot.shape[0], 1) * labels_one_hot 59 | weights = weights.sum(1) 60 | weights = weights.unsqueeze(1) 61 | weights = weights.repeat(1, no_of_classes) 62 | 63 | if loss_type == "focal": 64 | cb_loss = focal_loss(labels_one_hot, logits, weights, gamma) 65 | elif loss_type == "sigmoid": 66 | cb_loss = F.binary_cross_entropy_with_logits(input=logits, target=labels_one_hot, weight=weights) 67 | elif loss_type == "softmax": 68 | pred = logits.softmax(dim=1) 69 | cb_loss = F.binary_cross_entropy(input=pred, target=labels_one_hot, weight=weights) 70 | 71 | return cb_loss 72 | 73 | def true_metric_loss(true, no_of_classes, scale=1): 74 | batch_size = true.size(0) 75 | true = true.view(batch_size,1) 76 | true_labels = torch.cuda.LongTensor(true).repeat(1, no_of_classes).float() 77 | class_labels = torch.arange(no_of_classes).float().cuda() 78 | phi = (scale * torch.abs(class_labels - true_labels)).cuda() 79 | y = nn.Softmax(dim=1)(-phi) 80 | return y 81 | 82 | def loss_function(output, labels, loss_type, expt_type, scale): 83 | if loss_type == 'oe': 84 | targets = true_metric_loss(labels, expt_type, scale) 85 | return torch.sum(- targets * F.log_softmax(output, -1), -1).mean() 86 | 87 | elif loss_type == 'focal': 88 | beta = 0.9999 89 | gamma = 2.0 90 | no_of_classes = 4 91 | loss_type = "focal" 92 | 93 | sample= torch.bincount(labels,minlength=expt_type).cpu()# label 개수 94 | sample = np.where(sample==0,0.0001,sample) 95 | loss = CB_loss(labels, output, sample, expt_type, loss_type, beta, gamma) 96 | return loss 97 | 98 | else: 99 | loss_fct = nn.CrossEntropyLoss() 100 | loss = loss_fct(output, labels) #.view(-1, self.num_labels),.view(-1) 101 | return loss 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Towards Suicide Prevention from Bipolar Disorder with Temporal Symptom-Aware Multitask Learning 2 | This codebase contains the python scripts for the model for the KDD 2023. https://arxiv.org/abs/2307.00995. 3 | 4 | ## Environment & Installation Steps 5 | Python 3.8 & Pytorch 1.12 6 | ``` 7 | pip install -r requirements.txt 8 | ``` 9 | 10 | ## Run 11 | Execute the following steps in the same environment: 12 | ``` 13 | cd Temporal-Symptom-Aware-Multitask-Learning-KDD23 & python main.py 14 | ``` 15 | 16 | ## Dataset Format 17 | Processed dataset format should be a DataFrame as a .pkl file having the following columns: 18 | 1. cur_bp_y : Bipolar symptom labels 0~7 19 | 2. fu_30_su_y : suicidality levels 0~3 20 | 3. sb_1024 : list of lists consisting of 1024-dimensional encoding for each reddit post. 21 | 4. created_utc : list containing the datetime objects corresponding to each reddit post. 22 | 23 | 24 | ## Annotation Process 25 | To label the collected Reddit dataset, we recruited four researchers, who are knowledgeable in psychology and fluent in English, as annotators. With the supervision of a psychiatrist, the four trained annotators labeled 818 users and their 7,592 anonymized Reddit posts using the open-source text annotation tool Doccano. During annotations, we mainly consider two different label categories: (i) BD symptoms (e.g., manic, anxiety) and (ii) suicidality levels (e.g., ideation, attempt). We further annotate the diagnosed BD type (e.g., BD-I, BD-II) for data analysis. If there is any conflict in the annotated labels across the annotators, all the annotators discuss and reach to an agreement under the supervision of the psychiatrists. 26 | 27 | 28 | 29 | ## Ethical Concerns 30 | We carefully consider potential ethical issues in this work: (i) protecting users' privacies on Reddit and (ii) avoiding potentially harmful uses of the proposed dataset. The Reddit privacy policy explicitly authorizes third parties to copy user content through the Reddit API. We follow the widely-accepted social media research ethics policies that allow researchers to utilize user data without explicit consent if anonymity is protected (benton et al. 2017; Williams et al., 2017). Any metadata that could be used to specify the author was not collected. In addition, all content is manually scanned to remove personally identifiable information and mask all the named entities. More importantly, the BD dataset will be shared only with other researchers who have agreed to the ethical use of the dataset. This study was reviewed and approved by the Institutional Review Board (SKKU2022-11-038). 31 | 32 | ## How to Request Access 33 | While it is important to ensure that all necessary precautions are taken, we are enthusiastic about sharing this valuable resource with fellow researchers. To request access to the dataset, please contact Daeun Lee (delee12@skku.edu). Access requests should follow the format of the sample application provided below, which consists of three parts: 34 | - Part 0: Download a sample application form (https://sites.google.com/view/daeun-lee/dataset/kdd-2023) 35 | - Part 1: Applicant Information 36 | - Part 2: Dataset Access Application 37 | - Part 3: Ethical Review by Your Organization 38 | 39 | The dataset was produced at Sungkyunkwan University (SKKU) in South Korea, and the research conducted on this dataset at SKKU has been granted exemption from Institutional Review Board (IRB) evaluation by SKKU's IRB (SKKU2022-11-038). This exemption applies to the analysis of pre-existing data that is publicly accessible or involves individuals who cannot be directly identified or linked to identifiable information. Nevertheless, due to the potentially sensitive nature of this data, we require that researchers who receive the data obtain ethical approval from their respective organizations. 40 | 41 | Please submit your access request to Daeun Lee (delee12@skku.edu) and ensure that you include all the necessary information and address the points outlined in the sample application. 42 | 43 | 44 | ## Dataset Availability and Governance Plan 45 | Inspired by the data sharing system of previous research (Zirikly et al. 2019), we have decided to establish a governance process for researcher access to the dataset, following the procedure outlined below. 46 | Due to limitations in the number of available individuals, three out of the five authors will be selected to review access requests submitted in the format specified below. The outcomes of the review will result in the following responses: 47 | 48 | - Approval: If all three members give their approval, the application will be deemed approved, and Daeun will proceed to share the dataset with the researcher. 49 | - Inquiries: The authors may have questions or seek clarification, prompting further communication. 50 | - Revision and resubmission: Should the authors provide specific suggestions for revising and resubmitting the application, the researcher will have the opportunity to address them. 51 | - Rejection: In the event of unanimous disapproval from the authors, the dataset will not be shared. 52 | 53 | The authors will prioritize and promote diversity and inclusivity among the reviewers and the community of researchers utilizing the dataset. 54 | 55 | __Reference__ 56 | Zirikly, A., Resnik, P., Uzuner, O., & Hollingshead, K. (2019, June). CLPsych 2019 shared task: Predicting the degree of suicide risk in Reddit posts. In Proceedings of the sixth workshop on computational linguistics and clinical psychology (pp. 24-33)# Temporal-Symptom-Aware-Multitask-Learning-KDD23 57 | 58 | --- 59 | ### If our work was helpful in your research, please kindly cite this work: 60 | 61 | ``` 62 | @inproceedings{lee2023towards, 63 | title={Towards Suicide Prevention from Bipolar Disorder with Temporal Symptom-Aware Multitask Learning}, 64 | author={Lee, Daeun and Son, Sejung and Jeon, Hyolim and Kim, Seungbae and Han, Jinyoung}, 65 | booktitle={Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining}, 66 | pages={4357--4369}, 67 | year={2023} 68 | } 69 | ``` 70 | 71 | ### Acknowledgments 72 | This research was supported by the Ministry of Education of the Republic of Korea and the National Research Foundation of Korea (NRF-2022S1A5A8054322) and the National Research Foundation of Korea (NRF) grant funded by the Korea government (MSIT) (No. 2023R1A2C2007625). 73 | 74 | ### Our Lab Site 75 | [Data Science & Artificial Intelligence Laboratory (DSAIL) @ Sungkyunkwan University](https://sites.google.com/view/datasciencelab/home) 76 | -------------------------------------------------------------------------------- /src/TempATT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import argparse 5 | 6 | # torch: 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from torch.utils.data import Dataset, DataLoader 11 | from torch.optim.lr_scheduler import ExponentialLR 12 | 13 | from pytorch_lightning import LightningModule 14 | from transformers import AdamW 15 | 16 | from sklearn.model_selection import StratifiedGroupKFold 17 | #from imblearn.over_sampling import RandomOverSampler 18 | 19 | from utils.loss import loss_function 20 | from utils.data_loader import RedditDataset, pad_collate_reddit 21 | from utils.evaluation import * 22 | from src.attention import Attention 23 | 24 | class Arg: 25 | epochs: int = 1 # Max Epochs, BERT paper setting [3,4,5] 26 | report_cycle: int = 30 # Report (Train Metrics) Cycle 27 | cpu_workers: int = os.cpu_count() # Multi cpu workers 28 | test_mode: bool = False # Test Mode enables `fast_dev_run` 29 | optimizer: str = 'AdamW' # AdamW vs AdamP 30 | lr_scheduler: str = 'exp' # ExponentialLR vs CosineAnnealingWarmRestarts 31 | fp16: bool = False # Enable train on FP16 32 | batch_size: int = 64 33 | max_post_num = 30 34 | task_num: int = 2 35 | 36 | class TempATT(LightningModule): 37 | def __init__(self, args,config): 38 | super().__init__() 39 | # config: 40 | self.args = args 41 | self.config = config 42 | 43 | #model 44 | self.embed_type = self.config['embed_type'] + "_" + str(self.config['hidden_dim']) 45 | self.embed_layer = nn.Linear(self.config['hidden_dim'], self.config['hidden_dim']) 46 | self.lstm = nn.LSTM(input_size=self.config['hidden_dim'], 47 | hidden_size=int(self.config['hidden_dim']/2), 48 | num_layers=2, 49 | bidirectional=True) 50 | 51 | self.time_var = nn.Parameter(torch.randn((2)), requires_grad=True) 52 | self.atten = Attention(self.config['gpu'],self.config['hidden_dim'], batch_first=True) # 2 is bidrectional 53 | self.dropout = nn.Dropout(self.config['dropout']) 54 | 55 | # suicide 56 | self.fc_1 = nn.Linear(self.config['hidden_dim'], self.config['hidden_dim']) 57 | self.fc_2 = nn.Linear(self.config['hidden_dim'], self.config['s_y_num']) 58 | 59 | # aux 60 | self.b_decoder = nn.Linear(self.config['hidden_dim'], self.config['b_y_num']) 61 | 62 | # unweighted loss 63 | self.log_vars = nn.Parameter(torch.randn((self.args.task_num))) 64 | 65 | def forward(self, s_y, b_y, p_num, tweets,timestamp): 66 | #lstm 67 | x = self.dropout(tweets) 68 | 69 | # aux 70 | b_out = self.b_decoder(x) 71 | logits_b = nn.utils.rnn.pack_padded_sequence(b_out, p_num.cpu(), batch_first=True, enforce_sorted=False)[0] 72 | b_y = nn.utils.rnn.pack_padded_sequence(b_y, p_num.cpu(), batch_first=True, enforce_sorted=False)[0] 73 | b_loss = nn.MultiLabelSoftMarginLoss(weight=None,reduction='mean')(logits_b, b_y) 74 | 75 | # main 76 | x = nn.utils.rnn.pack_padded_sequence(x, p_num.cpu(), batch_first=True, enforce_sorted=False) 77 | out, (h_n, c_n) = self.lstm(x) 78 | x, lengths = nn.utils.rnn.pad_packed_sequence(out, batch_first=True) 79 | 80 | #time sensitive attention suicide 81 | timestamp = torch.exp(self.time_var[0]) *timestamp + self.time_var[0] 82 | timestamp = torch.sigmoid(timestamp+ self.time_var[1]) #.size() 83 | x = x+ x*timestamp.unsqueeze(-1) 84 | h, att_score = self.atten(x, p_num.cpu()) # skip connect 85 | 86 | #reddit model 87 | if h.dim() == 1: 88 | h = h.unsqueeze(0) 89 | 90 | logits_s = self.fc_2(self.fc_1(self.dropout(h))) 91 | #logits_s = logits_s.view(-1, self.s_y_num) 92 | s_loss = loss_function(logits_s, s_y, self.config['loss'], self.config['s_y_num'], 1.8) 93 | 94 | # multi task loss 95 | s_prec = torch.exp(-self.log_vars[0]) 96 | s_loss = s_prec*s_loss + self.log_vars[0] 97 | 98 | b_prec = torch.exp(-self.log_vars[1]) 99 | b_loss = b_prec*b_loss + self.log_vars[1] 100 | 101 | total_loss = s_loss + b_loss 102 | return total_loss, b_loss, logits_s, timestamp,att_score, b_y,logits_b 103 | 104 | 105 | def configure_optimizers(self): 106 | optimizer = AdamW(self.parameters(), lr=self.config['lr']) 107 | scheduler = ExponentialLR(optimizer, gamma=0.001) 108 | return { 109 | 'optimizer': optimizer, 110 | 'scheduler': scheduler, 111 | } 112 | 113 | def preprocess_dataframe(self): 114 | data_path = './dataset/sample_data.pkl' 115 | df = pd.read_pickle(data_path) 116 | 117 | # class split 118 | self.s_y_col = "fu_" + str(self.config['af']) + "_su_y" 119 | if self.config['s_y_num'] == 3: 120 | df[self.s_y_col] = df[self.s_y_col].apply(lambda x: 2 if x in [2,3] else x) 121 | elif self.config['s_y_num'] == 2: 122 | df[self.s_y_col] = df[self.s_y_col].apply(lambda x: 1 if x in [1,2,3] else x) 123 | 124 | cv = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=self.config['random_seed']) 125 | for i,(train_idxs, test_idxs) in enumerate(cv.split(df, df[self.s_y_col], df['author'])): 126 | if i == self.config['n_fold']: 127 | break 128 | self.df_train = df.iloc[train_idxs] 129 | self.df_test = df.iloc[test_idxs] 130 | print(f'# of train:{len(self.df_train)}, val:0, test:{len(self.df_test)}') 131 | 132 | # ros = RandomOverSampler(random_state=2023) 133 | # df_train, y_res = ros.fit_resample(df_train, df_train[self.s_y_col].tolist()) 134 | 135 | def train_dataloader(self): 136 | self.train_data = RedditDataset( 137 | self.df_train[self.s_y_col].values, 138 | self.df_train['cur_bp_y'].values, 139 | self.df_train[self.embed_type].values, 140 | self.df_train["created_utc"].values, 141 | self.df_train['user_id'].values 142 | ) 143 | return DataLoader( 144 | self.train_data, 145 | batch_size=self.args.batch_size, 146 | collate_fn=pad_collate_reddit, 147 | shuffle=True, 148 | num_workers=self.args.cpu_workers, 149 | ) 150 | 151 | def test_dataloader(self): 152 | self.test_data = RedditDataset( 153 | self.df_test[self.s_y_col].values, 154 | self.df_test['cur_bp_y'].values, 155 | self.df_test[self.embed_type].values, 156 | self.df_test["created_utc"].values, 157 | self.df_test['user_id'].values 158 | ) 159 | return DataLoader( 160 | self.test_data, 161 | batch_size=self.args.batch_size, 162 | collate_fn=pad_collate_reddit, 163 | shuffle=False, 164 | num_workers=self.args.cpu_workers, 165 | ) 166 | 167 | def training_step(self, batch, batch_idx): 168 | s_y, b_y, p_num, tweets,timestamp,user_id = batch 169 | loss, b_loss, logit,timestamp,att_score, b_true, b_pred= self(s_y, b_y, p_num, tweets,timestamp) 170 | self.log("train_loss", loss) 171 | 172 | return {'loss': loss} 173 | 174 | def test_step(self, batch, batch_idx): 175 | s_y, b_y, p_num, tweets,timestamp,user_id = batch 176 | loss, b_loss, logit,timestamp,att_score, b_true, b_pred= self(s_y, b_y, p_num, tweets,timestamp) 177 | 178 | # preds 179 | s_true = list(s_y.cpu().numpy()) 180 | s_preds = list(logit.argmax(dim=-1).cpu().numpy()) 181 | 182 | b_true = list(b_true.cpu().numpy()) 183 | b_pred = F.softmax(b_pred, dim=1) 184 | b_preds = np.array(b_pred.cpu()>0.14).astype(int) 185 | b_preds = list(b_preds) 186 | user_id = list(user_id.cpu().numpy()) 187 | 188 | return { 189 | 'loss': loss, 190 | 's_true': s_true, 191 | 's_preds': s_preds, 192 | 'b_true': b_true, 193 | 'b_preds':b_preds, 194 | 'user_id':user_id 195 | } 196 | 197 | def test_epoch_end(self, outputs): 198 | 199 | evaluation(self.config,outputs, 'fs','s_true', 's_preds','user_id') 200 | evaluation(self.config,outputs, 'bd','b_true', 'b_preds','user_id') 201 | 202 | 203 | --------------------------------------------------------------------------------