├── das_overview.png ├── src ├── scripts │ ├── others │ │ ├── others.sh │ │ ├── plot.sh │ │ ├── user_study.sh │ │ ├── visual.sh │ │ └── data.sh │ ├── detection │ │ ├── train.sh │ │ └── eval.sh │ ├── attack │ │ ├── stage1 │ │ │ ├── eval.sh │ │ │ └── train.sh │ │ └── stage2 │ │ │ ├── train.sh │ │ │ └── eval.sh │ └── summarizer │ │ ├── filter │ │ └── train.sh │ │ ├── ssra-kmeans │ │ ├── train.sh │ │ └── eval.sh │ │ ├── build_cluster.sh │ │ ├── cluster.sh │ │ └── eval_summary.sh ├── data │ ├── build_datasets.py │ ├── preprocess │ │ ├── statistics.py │ │ ├── preprocess_twitter.py │ │ └── plot_tsne.py │ ├── graph_dataset.py │ ├── cluster_topics.py │ ├── topic_model.py │ ├── build_datasets_clustering.py │ └── build_datasets_filter.py ├── pipelines │ ├── trainer_filter.py │ ├── build_trainer.py │ ├── trainer.py │ └── builder_cluster_summary.py ├── models │ ├── modeling_outputs.py │ ├── modeling_filter.py │ ├── modeling_gcn.py │ ├── modeling_clustering.py │ ├── modeling_bert.py │ └── build_model.py ├── main.py └── others │ ├── metrics.py │ ├── evaluate_summary.py │ ├── utils.py │ └── postprocess.py └── README.md /das_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshchang0111/EMNLP2023-RumorDAS/HEAD/das_overview.png -------------------------------------------------------------------------------- /src/scripts/others/others.sh: -------------------------------------------------------------------------------- 1 | #python others/analyze.py --framing_response --fold comp 2 | 3 | #python others/postprocess.py --semeval2019_event_wise_eval 4 | python others/postprocess.py --bootstrap_accuracy -------------------------------------------------------------------------------- /src/scripts/others/plot.sh: -------------------------------------------------------------------------------- 1 | #python others/plot.py --plot_response_impact --dataset_name semeval2019 2 | #python others/plot.py 
#!/bin/sh
## Train the detector (`--task_type train_detector`, roberta-base with
## `--td_gcn` and `--bu_gcn`) on every dataset, folds 0-4.

export CUDA_VISIBLE_DEVICES=0
export WANDB_PROJECT="RumorDAS"
export WANDB_DIR=... ## need to be defined
output_dir=/mnt/1T/projects/RumorDAS ## need to be defined
batch_size=8
exp_name=bi-tgn-roberta

for dataset in re2019 twitter15 twitter16
do
    for i in $(seq 0 4)
    do
        ## Expansions are quoted so that a path containing spaces is
        ## still passed as a single argument.
        python main.py \
            --task_type train_detector \
            --model_name_or_path roberta-base \
            --td_gcn \
            --bu_gcn \
            --dataset_name "$dataset" \
            --train_file train.csv \
            --validation_file test.csv \
            --fold "$i" \
            --do_train \
            --per_device_train_batch_size "$batch_size" \
            --learning_rate 2e-5 \
            --num_train_epochs 10 \
            --exp_name "$exp_name" \
            --output_dir "$output_dir"
    done
done
#!/bin/sh

export CUDA_VISIBLE_DEVICES=0
export WANDB_PROJECT="RumorDAS"
export WANDB_DIR=... ## need to be defined
output_dir=... ## need to be defined
batch_size=8 ## not referenced by the evaluation command below
exp_name=bi-tgn/adv-stage1

##############################################
## Adv. Stage 1: train detector & generator ##
##############################################
## Evaluate stage-1 detector ##

for dataset in re2019 twitter15 twitter16
do
    for i in $(seq 0 4)
    do
        ## ==========================================================
        ## Commands below are controlled by `--td_gcn` and `--bu_gcn`

        ## Expansions are quoted so that a path containing spaces is
        ## still passed as a single argument.
        python main.py \
            --task_type train_adv_stage1 \
            --model_name_or_path facebook/bart-base \
            --td_gcn \
            --bu_gcn \
            --dataset_name "$dataset" \
            --train_file train.csv \
            --validation_file test.csv \
            --fold "$i" \
            --do_eval \
            --exp_name "$exp_name" \
            --output_dir "$output_dir"
    done
done
## need to be defined 6 | output_dir=/mnt/1T/projects/RumorDAS ## need to be defined 7 | batch_size=256 8 | lr=4e-5 9 | 10 | ############################################# 11 | ## Transformer AutoEncoder Response Filter ## 12 | ############################################# 13 | for dataset in re2019 twitter15 twitter16 14 | do 15 | for i in $(seq 0 4) 16 | do 17 | for n_layer in 2 4 6 18 | do 19 | python main.py \ 20 | --task_type train_filter \ 21 | --model_name_or_path facebook/bart-base \ 22 | --filter_layer_enc $n_layer \ 23 | --filter_layer_dec $n_layer \ 24 | --dataset_name $dataset \ 25 | --train_file train.csv \ 26 | --validation_file test.csv \ 27 | --fold $i \ 28 | --do_train \ 29 | --per_device_train_batch_size $batch_size \ 30 | --learning_rate $lr \ 31 | --num_train_epochs 50 \ 32 | --exp_name filter_$n_layer \ 33 | --output_dir $output_dir 34 | done 35 | done 36 | done 37 | -------------------------------------------------------------------------------- /src/scripts/summarizer/ssra-kmeans/train.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | export WANDB_PROJECT="RumorDAS" 3 | export WANDB_DIR=... 
#!/bin/sh

export CUDA_VISIBLE_DEVICES=0
export WANDB_PROJECT="RumorDAS"
export WANDB_DIR=... ## need to be defined
output_dir=/mnt/1T/projects/RumorDAS ## need to be defined
batch_size=8
exp_name=bi-tgn/adv-stage2

#########################################################
## Adv. Stage 2: train generator while fixing detector ##
#########################################################

for dataset in re2019 twitter15 twitter16
do
    for i in $(seq 0 4)
    do
        ## ==========================================================
        ## Commands below are controlled by `--td_gcn` and `--bu_gcn`

        ## Every expansion is quoted (previously only $output_dir was)
        ## so each value is passed as exactly one argument.
        python main.py \
            --task_type train_adv_stage2 \
            --model_name_or_path facebook/bart-base \
            --td_gcn \
            --bu_gcn \
            --dataset_name "$dataset" \
            --train_file train.csv \
            --validation_file test.csv \
            --fold "$i" \
            --do_train \
            --per_device_train_batch_size "$batch_size" \
            --learning_rate 2e-5 \
            --num_train_epochs 10 \
            --exp_name "$exp_name" \
            --output_dir "$output_dir"
    done
done
#!/bin/sh

export CUDA_VISIBLE_DEVICES=0
export WANDB_PROJECT="RumorDAS"
export WANDB_DIR=... ## need to be defined
output_dir=... ## need to be defined
batch_size=16 ## not referenced by the commands below
exp_name=bi-tgn-roberta

for dataset in re2019 twitter15 twitter16
do
    for i in $(seq 0 4)
    do
        ##############
        ## Evaluate ##
        ##############
        ## Expansions are quoted so that a path containing spaces is
        ## still passed as a single argument.
        python main.py \
            --task_type train_detector \
            --model_name_or_path roberta-base \
            --td_gcn \
            --bu_gcn \
            --dataset_name "$dataset" \
            --train_file train.csv \
            --validation_file test.csv \
            --fold "$i" \
            --do_eval \
            --exp_name "$exp_name" \
            --output_dir "$output_dir"

        ########################
        ## Obtain Predictions ##
        ########################
        #python main.py \
        #    --task_type train_detector \
        #    --model_name_or_path roberta-base \
        #    --td_gcn \
        #    --bu_gcn \
        #    --dataset_name "$dataset" \
        #    --train_file train.csv \
        #    --validation_file test.csv \
        #    --fold "$i" \
        #    --do_predict \
        #    --exp_name bi-tgn-roberta/lr2e-5 \
        #    --output_dir "$output_dir"
    done
done
## Run the t-SNE plotting script on the stage-2 adversarial setup.
## NOTE(review): the repository tree lists plot_tsne.py under
## data/preprocess/ (and plot.sh calls it there); the `others/plot_tsne.py`
## path and the `--add_gcn` / `--bi_gcn` flags below may be stale -- confirm.

export WANDB_PROJECT="RumorV2"

## Machine-specific paths and GPU selection. The hostname is read once and
## quoted so the `[ ]` test cannot break on an empty or multi-word value.
host=$(hostname)
if [ "$host" = "josh-System-Product-Name" ]; then
    export WANDB_DIR="/mnt/hdd1/projects/RumorV2"
    output_dir="/mnt/hdd1/projects/RumorV2/results"
    batch_size=8
elif [ "$host" = "yisyuan-PC2" ]; then
    export CUDA_VISIBLE_DEVICES=1
    export WANDB_DIR="/home/joshchang/project/RumorV2"
    output_dir="/home/joshchang/project/RumorV2/results"
    batch_size=8
else
    export CUDA_VISIBLE_DEVICES=1
    export WANDB_DIR="/nfs/home/joshchang/projects/RumorV2"
    output_dir="/nfs/home/joshchang/projects/RumorV2/results"
    batch_size=16
fi

for dataset in semeval2019
do
    if [ "$dataset" = "PHEME" ]; then
        ## Event-wise cross validation
        folds=$(seq 0 8)
    else
        ## Only fold 0 is processed. The original assigned `$(seq 0 4)` and
        ## immediately overwrote it with `$(seq 0 0)`; the dead assignment
        ## is dropped. Use `$(seq 0 4)` here to cover all five folds.
        folds=$(seq 0 0)
    fi

    ## $folds is left unquoted on purpose: word splitting yields the folds.
    for i in $folds
    do
        if [ "$i" = "comp" ]
        then
            ## For semeval2019 fold [comp]
            eval_file=dev.csv
            test_file=test.csv
        else
            ## For 5-fold
            eval_file=test.csv
            test_file=test.csv
        fi

        python others/plot_tsne.py \
            --task_type train_adv_stage2 \
            --attack_type untargeted \
            --model_name_or_path facebook/bart-base \
            --add_gcn \
            --bi_gcn \
            --dataset_name "$dataset" \
            --train_file train.csv \
            --validation_file "$eval_file" \
            --fold "$i" \
            --do_eval \
            --exp_name bi-tgn/adv-stage2 \
            --output_dir "$output_dir"
    done
done
def build_datasets(data_args, model_args, training_args, config, tokenizer, model):
    """Build datasets according to different tasks.

    The builder is selected from ``training_args.task_type``:

    * ``train_detector``, ``train_adv_stage1``, ``train_adv_stage2``
      -> ``build_datasets_adv``
    * ``train_filter`` -> ``build_datasets_filter``
    * ``build_cluster_summary`` -> ``build_datasets_clustering``
    * ``ssra_loo`` -> ``build_datasets_loo_abstractor``
    * ``ssra_kmeans`` -> ``build_datasets_clustering_abstractor``

    Every builder receives the same six arguments, unchanged.

    Returns:
        The ``(train_dataset, eval_dataset, test_dataset)`` triple produced
        by the selected builder.

    Raises:
        ValueError: If ``training_args.task_type`` is none of the task
            types listed above.
    """
    task_type = training_args.task_type

    if task_type in ("train_detector", "train_adv_stage1", "train_adv_stage2"):
        builder = build_datasets_adv
    elif task_type == "train_filter":
        builder = build_datasets_filter
    elif task_type == "build_cluster_summary":
        builder = build_datasets_clustering
    elif task_type == "ssra_loo":
        builder = build_datasets_loo_abstractor
    elif task_type == "ssra_kmeans":
        builder = build_datasets_clustering_abstractor
    else:
        ## Report the offending value: this branch is also reached when a
        ## task type IS given but is misspelled or unsupported.
        raise ValueError(
            "training_args.task_type not specified or not supported: {!r}".format(task_type)
        )

    ## Unpack explicitly so a builder returning anything other than three
    ## items fails here, as it did before.
    train_dataset, eval_dataset, test_dataset = builder(
        data_args, model_args, training_args,
        config, tokenizer, model
    )
    return train_dataset, eval_dataset, test_dataset
def parse_args(argv=None):
    """Parse the command-line options of the dataset statistics script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``dataset``, ``data_root``,
        ``data_root_V2`` and ``fold``.
    """
    parser = argparse.ArgumentParser(description="Rumor Detection")

    ## What to do
    #parser.add_argument("--txt2csv", action="store_true")

    ## Others
    parser.add_argument("--dataset", type=str, default="semeval2019", choices=["semeval2019", "Pheme", "twitter15", "twitter16"])
    parser.add_argument("--data_root", type=str, default="../dataset/processed")
    parser.add_argument("--data_root_V2", type=str, default="../dataset/processedV2")
    parser.add_argument("--fold", type=str, default="0,1,2,3,4", help="either use 5-fold data or train/dev/test from rumoureval2019 competition")

    return parser.parse_args(argv)


def statistics(args):
    """Print per-dataset statistics and save a histogram of tree lengths.

    Reads ``{data_root_V2}/{dataset}/data.csv`` and groups its rows by
    ``source_id``. Prints the number of groups ("claims"), the number of
    rows ("posts"), one count per ``veracity`` value (taken from the first
    row of each group) and the max/min/mean group size. The histogram of
    group sizes is written to ``{data_root_V2}/{dataset}/tree_len.png``.

    Args:
        args: Namespace providing ``dataset`` and ``data_root_V2``.
    """
    print("\n*** Statistics of {} ***\n".format(args.dataset))

    path_i = "{}/{}/data.csv".format(args.data_root_V2, args.dataset)

    data_df = pd.read_csv(path_i)

    src_group = data_df.groupby("source_id")
    ## Rows per source_id, computed once and reused below
    tree_sizes = src_group.size()

    n_posts = len(data_df)
    n_claims = len(tree_sizes)
    max_tree_len = int(tree_sizes.max())
    min_tree_len = int(tree_sizes.min())
    avg_tree_len = float(tree_sizes.mean())

    ## Count labels: start every label seen in the file at zero so it is
    ## reported even if it never appears on the first row of a group
    label_cnt = {label: 0 for label in set(data_df["veracity"])}

    ## For each source post (claim)
    for _, group in src_group:
        label_cnt[group["veracity"].tolist()[0]] += 1

    print("# claims: {:4d}".format(n_claims))
    print("# posts : {:4d}".format(n_posts))
    ## Sorted so the report order is the same on every run
    for label in sorted(label_cnt, key=str):
        print("# {:10s}: {:4d}".format(label, label_cnt[label]))

    print("Max. tree len.: {:6d}".format(max_tree_len))
    print("Min. tree len.: {:6d}".format(min_tree_len))
    print("Avg. tree len.: {:6.2f}".format(avg_tree_len))

    ## Plot histogram of tree length on a fresh figure and close it
    ## afterwards, so repeated calls do not draw on top of each other
    fig = plt.figure()
    plt.hist(tree_sizes.values, range(min_tree_len, max_tree_len + 10, 10))
    plt.title(args.dataset.capitalize())
    plt.savefig("{}/{}/tree_len.png".format(args.data_root_V2, args.dataset))
    plt.close(fig)


if __name__ == "__main__":
    args = parse_args()

    statistics(args)
--model_type ssra_kmeans_"$num_clusters" \ 47 | # --data_name "$dataset" \ 48 | # --data_root_V2 ../dataset/processedV2 49 | #done 50 | # 51 | #python others/evaluate_summary.py \ 52 | # --generate_for_factCC \ 53 | # --factCC_format "$factCC_format" \ 54 | # --model_type ra \ 55 | # --data_name "$dataset" \ 56 | # --data_root_V2 ../dataset/processedV2 57 | # 58 | #python others/evaluate_summary.py \ 59 | # --generate_for_factCC \ 60 | # --factCC_format "$factCC_format" \ 61 | # --model_type ssra_loo \ 62 | # --data_name "$dataset" \ 63 | # --data_root_V2 ../dataset/processedV2 64 | 65 | python others/evaluate_summary.py \ 66 | --generate_for_factCC \ 67 | --factCC_format $factCC_format \ 68 | --model_type chatgpt \ 69 | --data_name $dataset \ 70 | --data_root_V2 ../dataset/processedV2 71 | done 72 | 73 | #for factCC_format in cluster_wise 74 | #do 75 | # for num_clusters in $(seq 2 5) 76 | # do 77 | # python others/evaluate_summary.py \ 78 | # --generate_for_factCC \ 79 | # --factCC_format "$factCC_format" \ 80 | # --model_type ssra_kmeans_"$num_clusters" \ 81 | # --data_name "$dataset" \ 82 | # --data_root_V2 ../dataset/processedV2 83 | # done 84 | #done 85 | done -------------------------------------------------------------------------------- /src/scripts/others/data.sh: -------------------------------------------------------------------------------- 1 | ############# 2 | ## Dataset ## 3 | ############# 4 | #python data/preprocess/preprocess.py --txt2csv 5 | #python data/preprocess/preprocess.py --csv4hf 6 | #python data/preprocess/preprocess.py --csv4hf --simple 7 | #python data/preprocess/preprocess.py --csv4hf --simple --fold comp 8 | #python data/preprocess/preprocess.py --processV2_stance 9 | #python data/preprocess/preprocess.py --processV2 --dataset Pheme 10 | #python data/preprocess/preprocess.py --processV2 --dataset twitter15 11 | #python data/preprocess/preprocess.py --processV2 --dataset twitter16 12 | #python data/preprocess/preprocess.py 
--pheme_event_wise 13 | #python data/preprocess/preprocess.py --processV2_fold --dataset semeval2019 14 | #python data/preprocess/preprocess.py --processV2_fold --dataset semeval2019 --fold comp 15 | #python data/preprocess/preprocess.py --processV2_fold --dataset twitter15 16 | #python data/preprocess/preprocess.py --processV2_fold --dataset twitter16 17 | #python data/preprocess/preprocess.py --process_semeval2019_dev_set 18 | #python data/preprocess/preprocess.py --create_tree_ids_file --dataset semeval2019 19 | #python data/preprocess/preprocess.py --create_tree_ids_file --dataset twitter15 20 | #python data/preprocess/preprocess.py --create_tree_ids_file --dataset twitter16 21 | 22 | #python data/postprocess.py --wordcloud 23 | #python data/postprocess.py --get_event 24 | #python data/postprocess.py --get_event_from_pheme 25 | #python data/topic_model.py 26 | 27 | ## PHEME_veracity: raw -> processed 28 | #python data/preprocess/preprocess_pheme.py --dataset PHEME_veracity --preprocess 29 | #python data/preprocess/preprocess_pheme.py --dataset PHEME --split_5_fold 30 | #python data/preprocess/preprocess_pheme.py --dataset PHEME --split_event_wise 31 | 32 | ## twitter15, twitter16 33 | #python data/preprocess/preprocess_twitter.py --recover_4_classes --dataset twitter15 34 | #python data/preprocess/preprocess_twitter.py --recover_4_classes --dataset twitter16 35 | #python data/preprocess/preprocess_twitter.py --process_twitter16 --dataset twitter16 36 | 37 | ## Build graph dataset 38 | #python data/graph_dataset.py --dataset_name semeval2019 39 | #python data/graph_dataset.py --dataset_name twitter15 40 | #python data/graph_dataset.py --dataset_name twitter16 41 | #python data/graph_dataset.py --dataset_name PHEME 42 | 43 | ## Build topics 44 | #python data/topic_model.py --dataset_name semeval2019 45 | #python data/topic_model.py --dataset_name twitter15 46 | #python data/topic_model.py --dataset_name twitter16 47 | 48 | 49 | #for dataset in twitter15 twitter16 
semeval2019 50 | #do 51 | # ## Statistics 52 | # #python data/preprocess/statistics.py --dataset $dataset 53 | # 54 | # ## 2023/8/3 - Obtain topics via clustering 55 | # #python data/cluster_topics.py \ 56 | # # --split_data_via_cluster \ 57 | # # --select_number_of_clusters \ 58 | # # --dataset_name $dataset 59 | # #python data/cluster_topics.py \ 60 | # # --split_data_via_cluster \ 61 | # # --dataset_name $dataset 62 | #done -------------------------------------------------------------------------------- /src/data/preprocess/preprocess_twitter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import ipdb 4 | import argparse 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import tqdm 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description="Rumor Detection") 11 | 12 | ## What to do 13 | parser.add_argument("--recover_4_classes", action="store_true") 14 | parser.add_argument("--process_twitter16", action="store_true") 15 | 16 | ## Others 17 | parser.add_argument("--dataset", type=str, default="twitter15", choices=["twitter15", "twitter16"]) 18 | parser.add_argument("--data_root_raw", type=str, default="../dataset/raw") 19 | parser.add_argument("--data_root_V1", type=str, default="../dataset/processed") 20 | parser.add_argument("--data_root_V2", type=str, default="../dataset/processedV2") 21 | parser.add_argument("--fold", type=str, default="0,1,2,3,4", help="either use 5-fold data or train/dev/test from rumoureval2019 competition") 22 | 23 | args = parser.parse_args() 24 | 25 | return args 26 | 27 | def recover_4_classes(args): 28 | """Recover rumor labels (true, false, unverified) of V2 dataset from raw datasets.""" 29 | raw_labels_dict = {} 30 | raw_labels_path = "{}/rumor_detection_acl2017/{}/label.txt".format(args.data_root_raw, args.dataset) 31 | with open(raw_labels_path, "r") as f: 32 | for line in f.readlines(): 33 | line = line.strip().rstrip() 34 | label, 
def recover_4_classes(args):
    """
    Recover the rumor labels of the V2 dataset from the raw dataset.

    Reads `<data_root_raw>/rumor_detection_acl2017/<dataset>/label.txt`, whose
    lines have the form `label:source_id`, and writes the label of every row
    back into the V2 files:
      - `data.csv` gets a `veracity` column,
      - `split_<fold>/{train,test}.csv` get a `label_veracity` column.

    Args:
        args: namespace with `data_root_raw`, `data_root_V2`, `dataset` and a
            comma-separated `fold` string.

    Raises:
        KeyError: if a `source_id` of the V2 files has no entry in label.txt.
    """
    raw_labels_path = "{}/rumor_detection_acl2017/{}/label.txt".format(args.data_root_raw, args.dataset)
    raw_labels_dict = {}
    with open(raw_labels_path, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                ## Tolerate blank lines (e.g. a trailing newline) instead of
                ## failing on the tuple unpacking below
                continue
            label, src_id = line.split(":")
            raw_labels_dict[int(src_id)] = label

    def lookup_labels(source_ids):
        ## Direct indexing on purpose: an unknown source_id must raise KeyError
        return [raw_labels_dict[src_id] for src_id in source_ids]

    ## For `data.csv`
    data_path = "{}/{}/data.csv".format(args.data_root_V2, args.dataset)
    data_df = pd.read_csv(data_path)
    data_df["veracity"] = lookup_labels(data_df["source_id"])
    ## index=False like every other CSV written by this script; otherwise each
    ## run would add one more unnamed index column to data.csv
    data_df.to_csv(data_path, index=False)

    ## For each fold
    for fold in args.fold.split(","):
        fold_path = "{}/{}/split_{}".format(args.data_root_V2, args.dataset, fold)
        for train_or_test in ["train", "test"]:
            print("Fold [{}] - {:5s} set".format(fold, train_or_test))
            split_path = "{}/{}.csv".format(fold_path, train_or_test)
            fold_df = pd.read_csv(split_path)
            fold_df["label_veracity"] = lookup_labels(fold_df["source_id"])
            fold_df.to_csv(split_path, index=False)

def _remove_weird_token(text):
    """
    Drop `uD...` tokens, the two-digit tokens that follow them, and backslashes.

    A two-digit token is dropped as long as no regular token has been seen
    since the last `uD...` token, so several of them in a row are all removed.
    NOTE(review): the `uD...` tokens look like remnants of escaped UTF-16
    surrogates (emoji) -- confirm against the raw data.
    """
    kept = []
    after_uD = False
    for token in text.split():
        if after_uD and token.isdigit() and len(token) == 2:
            continue
        if token.startswith("uD"):
            after_uD = True
            continue
        after_uD = False
        kept.append(token)
    return " ".join(kept).replace("\\", "").strip()

def process_twitter16(args):
    """
    Clean the `text` column of `<data_root_V2>/<dataset>/data_ori.csv` and
    write the result to `data.csv` in the same directory.

    Args:
        args: namespace with `data_root_V2` and `dataset`.
    """
    data_df = pd.read_csv("{}/{}/data_ori.csv".format(args.data_root_V2, args.dataset))

    tqdm.pandas() ## Registers `progress_apply` on pandas objects
    data_df["text"] = data_df["text"].progress_apply(_remove_weird_token)

    ## The debugger breakpoint that used to sit here was removed: it stopped
    ## every run before the cleaned file was written
    data_df.to_csv("{}/{}/data.csv".format(args.data_root_V2, args.dataset), index=False)

if __name__ == "__main__":
    args = parse_args()

    print("Preprocessing [{}]".format(args.dataset))
    if args.recover_4_classes:
        recover_4_classes(args)
    elif args.process_twitter16:
        process_twitter16(args)
# --dataset_name $dataset \ 50 | # --train_file train.csv \ 51 | # --validation_file $test_file \ 52 | # --fold $i \ 53 | # --do_eval \ 54 | # --per_device_eval_batch_size $batch_size \ 55 | # --exp_name bi-tgn/adv-stage2 \ 56 | # --output_dir $output_dir 57 | 58 | ## Defensive Response Extractor (DRE) - Filter Only 59 | #python main.py \ 60 | # --task_type train_adv_stage2 \ 61 | # --attack_type untargeted \ 62 | # --model_name_or_path facebook/bart-base \ 63 | # --td_gcn \ 64 | # --bu_gcn \ 65 | # --num_clusters $num_clusters \ 66 | # --filter_layer_enc 4 \ 67 | # --filter_layer_dec 4 \ 68 | # --extractor_name_or_path filter,kmeans \ 69 | # --filter_ratio $extract_ratio \ 70 | # --dataset_name $dataset \ 71 | # --train_file train.csv \ 72 | # --validation_file $test_file \ 73 | # --fold $i \ 74 | # --do_eval \ 75 | # --exp_name bi-tgn/adv-stage2 \ 76 | # --output_dir $output_dir 77 | 78 | ## ==================================== 79 | 80 | ## Defend-And-Summarize (DAS) Framework 81 | python main.py \ 82 | --task_type train_adv_stage2 \ 83 | --attack_type untargeted \ 84 | --model_name_or_path facebook/bart-base \ 85 | --td_gcn \ 86 | --bu_gcn \ 87 | --num_clusters $num_clusters \ 88 | --filter_layer_enc 4 \ 89 | --filter_layer_dec 4 \ 90 | --extractor_name_or_path filter,kmeans \ 91 | --filter_ratio $extract_ratio \ 92 | --abstractor_name_or_path ssra_kmeans_$num_clusters \ 93 | --dataset_name $dataset \ 94 | --train_file train.csv \ 95 | --validation_file test.csv \ 96 | --fold $i \ 97 | --do_eval \ 98 | --exp_name bi-tgn/adv-stage2 \ 99 | --output_dir $output_dir 100 | 101 | ## w/o summarizer 102 | #python main.py \ 103 | # --task_type train_adv_stage2 \ 104 | # --attack_type untargeted \ 105 | # --model_name_or_path facebook/bart-base \ 106 | # --td_gcn \ 107 | # --bu_gcn \ 108 | # --dataset_name $dataset \ 109 | # --train_file train.csv \ 110 | # --validation_file $test_file \ 111 | # --fold $i \ 112 | # --do_eval \ 113 | # --exp_name bi-tgn/adv-stage2 \ 114 | 
class FilterTrainer:
    """
    Plain PyTorch training loop for the filter model.

    The wrapped model is called as `model(input_ids, attention_mask)` and is
    expected to return its own loss; the trainer averages that value,
    back-propagates it with Adam, and keeps the checkpoint whose summed
    evaluation loss is the lowest.
    """

    def __init__(
        self,
        model=None,
        data_args=None,
        model_args=None,
        training_args=None,
        train_dataset=None,
        eval_dataset=None,
    ):
        """
        Move the model to the device, build both dataloaders and the optimizer,
        freeze the embedding sub-modules and print the parameter counts.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.data_args = data_args
        self.model_args = model_args
        self.training_args = training_args

        self.model = model.to(self.device)

        def make_loader(dataset):
            ## NOTE(review): the evaluation loader also uses
            ## per_device_train_batch_size -- confirm this is intended
            return DataLoader(
                dataset,
                batch_size=self.training_args.per_device_train_batch_size,
                collate_fn=default_data_collator,
                num_workers=self.training_args.dataloader_num_workers,
                pin_memory=self.training_args.dataloader_pin_memory
            )

        self.train_dataloader = make_loader(train_dataset)
        self.eval_dataloader = make_loader(eval_dataset)

        ## Build optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.training_args.learning_rate)

        ## Loss function (L2); not used by `train`, which takes the loss from the model
        self.loss_fct = nn.MSELoss()

        ## Freeze every sub-module whose name contains "embed"
        for module_name, module in self.model.named_modules():
            if "embed" not in module_name:
                continue
            for weight in module.parameters():
                weight.requires_grad = False

        n_all, n_trainable = 0, 0
        for weight in self.model.parameters():
            n_all += weight.nelement()
            if weight.requires_grad:
                n_trainable += weight.nelement()
        print("All parameters: {}".format(n_all))
        print("Trainable parameters: {}".format(n_trainable))

    def _batch_loss(self, batch):
        """Run one forward pass on `batch` and return the mean of the model's loss."""
        input_ids = batch["input_ids"].to(self.device)
        attn_mask = batch["attention_mask"].to(self.device)
        return self.model(input_ids, attn_mask).mean()

    def train(self):
        """
        Train for `num_train_epochs` epochs, evaluate after each epoch and save
        the state dict whenever the summed evaluation loss improves.

        After the last epoch the best loss is appended, together with the fold
        name, to `<output_dir>/../overall_results.txt`.
        """
        print("\nStart training...")

        best_loss = float("inf")
        n_epochs = int(self.training_args.num_train_epochs)
        for epoch in range(n_epochs):
            ## Train on all batches
            train_losses = []
            self.model.train()
            for batch in tqdm(self.train_dataloader, desc="Epoch {:2d} Training".format(epoch)):
                self.model.zero_grad()
                train_loss = self._batch_loss(batch)
                train_loss.backward()
                self.optimizer.step()
                train_losses.append(train_loss.detach().cpu().numpy())

            ## Evaluation
            eval_losses = []
            self.model.eval()
            with torch.no_grad():
                for batch in tqdm(self.eval_dataloader, desc="Epoch {:2d} Evaluation".format(epoch)):
                    eval_loss = self._batch_loss(batch)
                    eval_losses.append(eval_loss.detach().cpu().numpy())

            ## Losses are summed over batches (not averaged)
            train_total = np.sum(np.array(train_losses))
            eval_total = np.sum(np.array(eval_losses))

            ## Display results
            print("Epoch {:2d} | Train Loss: {:.4f} | Eval Loss: {:.4f}".format(epoch, train_total, eval_total))

            ## Save checkpoint
            if eval_total < best_loss:
                print("Saving model with best reconstruction loss!")
                ckpt_path = "{}/autoencoder_rd_{}.pt".format(
                    self.training_args.output_dir,
                    self.model_args.target_class_ext_ae
                )
                torch.save(self.model.state_dict(), ckpt_path)

                best_loss = eval_total

        results_path = "{}/../overall_results.txt".format(self.training_args.output_dir)
        with open(results_path, "a") as fw:
            fw.write("{:4s}\t{:.4f}\n".format(self.data_args.fold, best_loss))
class GraphDataset(Dataset):
    """
    This class is used to process graphical info. for GCN.

    For every source post (`source_id`) it stores two `2 x n_edges` LongTensors:
      - `td_edges[src_id]`: top-down edges, row 0 = parent index, row 1 = child index
      - `bu_edges[src_id]`: bottom-up edges, the same edges with the rows swapped
    Indices are row positions inside the thread of that source post.
    """
    def __init__(self, data_args, model_args=None):
        """
        Load `<dataset_root>/<dataset_name>/data.csv` and obtain the edges of
        every thread, either from the `graph.pth` cache next to it or by
        rebuilding them from the `self_idx` / `parent_idx` columns.
        """
        print("\nBuilding graph dataset...")

        self.data_args = data_args
        self.model_args = model_args

        ## Edge-dropout rates, 0 disables `drop_edge`
        self.tddroprate = 0#0.2
        self.budroprate = 0#0.2

        ## Read dataset content
        self.data_df = pd.read_csv("{}/{}/data.csv".format(self.data_args.dataset_root, self.data_args.dataset_name))
        for column in ("source_id", "tweet_id", "self_idx"): ## For PHEME, twitter15, twitter16
            self.data_df[column] = self.data_df[column].astype(str)

        ## Check whether graph cache exists
        cache_path = "{}/{}/graph.pth".format(self.data_args.dataset_root, self.data_args.dataset_name)
        if os.path.exists(cache_path):
            print("Graph cache exists, directly load graph information from {}".format(cache_path))
            graph_infos = torch.load(cache_path)
            self.td_edges = graph_infos["td_edges"]
            self.bu_edges = graph_infos["bu_edges"]
            return

        ## Initialize src_id->edge_index map
        self.td_edges = {}
        self.bu_edges = {}
        ## One pass over the frame; rows keep their original order inside each thread
        for src_id, tree_df in tqdm(self.data_df.groupby("source_id", sort=False)):
            td_edge_index, bu_edge_index = self._build_tree_edges(
                tree_df["self_idx"].tolist(),
                tree_df["parent_idx"].tolist()
            )
            self.td_edges[src_id] = td_edge_index
            self.bu_edges[src_id] = bu_edge_index

    @staticmethod
    def _build_tree_edges(self_ids, parent_ids):
        """
        Build the edges of one thread.

        Args:
            self_ids: `self_idx` value of every tweet, in thread order
            parent_ids: `parent_idx` value of every tweet, in the same order

        Returns:
            (td_edge_index, bu_edge_index), both of shape `2 x n_edges`.
            Edges are sorted by child position. An edge exists wherever a
            tweet's `parent_idx` equals another tweet's `self_idx`.

        Example: src_id = "529695367680761856"
            td_edge_index = [[0, 0, 0, 0, 0, 5, 6, 0],
                             [1, 2, 3, 4, 5, 6, 7, 8]]
        """
        ## self_idx -> positions carrying it; replaces the quadratic pairwise scan
        positions = {}
        for position, self_id in enumerate(self_ids):
            positions.setdefault(self_id, []).append(position)

        row, col = [], []
        for child_position, parent_id in enumerate(parent_ids):
            for parent_position in positions.get(parent_id, ()):
                row.append(parent_position)
                col.append(child_position)

        ## Correct edge: correct parent_idx of edges that parent_idx > child_idx to root
        parent = torch.LongTensor(row)
        child = torch.LongTensor(col)
        parent[parent > child] = 0

        return torch.stack((parent, child), dim=0), torch.stack((child, parent), dim=0)

    def __getitem__(self, src_id):
        """
        Returns:
            - td_edge_index: top-down edge indices corresponding to the tree structure given src_id
            - bu_edge_index: bottom-up edge indices

        Raises:
            ValueError: if `src_id` has no entry in the edge maps.
        """
        if src_id not in self.td_edges:
            raise ValueError("source_id not in graph_dataset mapping!")

        max_tree_length = self.data_args.max_tree_length

        ## Truncate: keep the edges whose child index is <= max_tree_length.
        ## NOTE(review): `pad` pads to max_tree_length - 1 columns, so `<=` may
        ## keep more edges than `pad` accepts -- confirm the intended bound.
        td_edge_index = self.td_edges[src_id]
        td_edge_index = td_edge_index[:, td_edge_index[1] <= max_tree_length]

        ## The child index is row 0 in the bottom-up layout
        bu_edge_index = self.bu_edges[src_id]
        bu_edge_index = bu_edge_index[:, bu_edge_index[0] <= max_tree_length]

        return td_edge_index, bu_edge_index

    def pad(self, edge_index):
        """
        To enable batching for huggingface data collator,
        need to pad to data_args.max_tree_length - 1 with value -1.
        """
        ## Padding, pad value = -1
        pad_length = (self.data_args.max_tree_length - 1) - edge_index.shape[1]
        edge_index = torch.cat((edge_index, torch.full((2, pad_length), -1)), dim=1)

        return edge_index

    def drop_edge(self, edge_index, edge_type):
        """
        Randomly keep a `1 - droprate` fraction of the edges (original order is
        preserved). `edge_type` is "td" or "bu"; any other value, or a rate of
        0, returns `edge_index` unchanged.
        """
        droprates = {"td": self.tddroprate, "bu": self.budroprate}
        droprate = droprates.get(edge_type, 0)
        if droprate > 0:
            n_edges = edge_index.shape[1]
            kept_idx = sorted(random.sample(range(n_edges), int(n_edges * (1 - droprate))))
            edge_index = edge_index[:, kept_idx]
        return edge_index

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Rumor Detection")
    parser.add_argument("--dataset_name", type=str, default=None)
    parser.add_argument("--dataset_root", type=str, default="../dataset/processedV2")
    args = parser.parse_args()

    print("Dataset: {}".format(args.dataset_name))

    ## Build graph dataset
    graph_dataset = GraphDataset(data_args=args)

    graph_infos = {
        "td_edges": graph_dataset.td_edges,
        "bu_edges": graph_dataset.bu_edges
    }

    ## Write file to dataset
    torch.save(graph_infos, "{}/{}/graph.pth".format(args.dataset_root, args.dataset_name))
@dataclass
class BaseModelOutputWithEmbedding(ModelOutput):
    """
    Base model output (last hidden state, optional hidden states and
    attentions) that additionally carries the embedding tensors.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states produced by the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Returned when `output_hidden_states=True` is passed or set in the
            config. One tensor of shape `(batch_size, sequence_length, hidden_size)`
            per layer, plus one for the embedding output if the model has an
            embedding layer.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            Returned when `output_attentions=True` is passed or set in the
            config. One tensor of shape `(batch_size, num_heads, sequence_length,
            sequence_length)` per layer: the attention weights after the
            softmax, used to compute the weighted average in the
            self-attention heads.
        embed_tok (`torch.FloatTensor`, *optional*):
            Token embeddings.
        embed_pos (`torch.FloatTensor`, *optional*):
            Position embeddings.
        embed_out (`torch.FloatTensor`, *optional*):
            Total embeddings = (token embeddings) + (position embeddings).
    """

    ## Standard base-model outputs
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    ## Extra outputs: the embeddings themselves
    embed_tok: Optional[torch.FloatTensor] = None
    embed_pos: Optional[torch.FloatTensor] = None
    embed_out: Optional[torch.FloatTensor] = None
@dataclass
class Seq2SeqWithSequenceClassifierOutput(ModelOutput):
    """
    Output of a seq2seq model that also carries a sequence-classification head.

    The generation part exposes `loss`, `logits`, `past_key_values`, the
    decoder hidden states / attentions, the cross attentions and the encoder
    last hidden state / hidden states / attentions.

    The classification part exposes `loss_det` and `logits_det`.
    """

    ## Seq2Seq (generation) fields
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None

    ## Sequence-classifier fields
    loss_det: Optional[torch.FloatTensor] = None
    logits_det: torch.FloatTensor = None
@dataclass
class RumorDetectorOutput(ModelOutput):
    """
    Combined output of the rumor detector, the response generator and the
    response summarizer.

    Response generator:
        `loss`, `logits`, `past_key_values`, `decoder_hidden_states`,
        `decoder_attentions`, `cross_attentions`, `encoder_last_hidden_state`,
        `encoder_hidden_states`, `encoder_attentions`

    Rumor detector:
        `loss_det`, `logits_det`

    Response summarizer:
        - `summary_tokens`: for the abstractive summarizer (abstractor)
        - `n_ext_adv`: for the extractive summarizer (extractor)
        - `filt_ext_idxs`, `clus_ext_idxs`, `clus_ids`: presumably the indices
          extracted by the filter / clustering steps and the cluster ids --
          TODO confirm against the model code
    """

    ## Response generator fields
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None

    ## Rumor detector fields
    loss_det: Optional[torch.FloatTensor] = None
    logits_det: torch.FloatTensor = None

    ## Response summarizer fields
    summary_tokens: Optional[torch.FloatTensor] = None
    n_ext_adv: int = None
    filt_ext_idxs: Optional[torch.Tensor] = None
    clus_ext_idxs: Optional[torch.Tensor] = None
    clus_ids: Optional[torch.Tensor] = None
def main():
    """
    Entry point: parse the three argument groups, build the model, the
    datasets and the trainer, then run the routine selected by
    `training_args.task_type` together with the `do_train` / `do_eval` /
    `do_predict` flags.

    Raises:
        ValueError: if `training_args.task_type` is not one of the known tasks.
    """
    ## Parse args
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, CustomTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args, data_args, training_args = args_post_init(model_args, data_args, training_args)

    ## Settings
    last_checkpoint = setup_logging(logger, training_args, data_args, model_args)
    set_seed(training_args.seed) ## Seed must be set before the model is created

    ## Load model
    config, tokenizer, model = build_model(data_args, model_args, training_args)

    ## Load & preprocess dataset
    train_dataset, eval_dataset, test_dataset = build_datasets(
        data_args, model_args, training_args,
        config, tokenizer, model
    )

    ## Build trainer
    trainer = build_trainer(
        data_args, model_args, training_args,
        train_dataset, eval_dataset,
        model, tokenizer
    )

    ## Train / Test process
    task = training_args.task_type

    if task == "train_detector":
        if training_args.do_train:
            train_process(data_args, model_args, training_args, trainer, train_dataset, last_checkpoint)
        elif training_args.do_eval:
            ## `do_predict` switches from plain evaluation to prediction
            run_eval = predict_process if training_args.do_predict else eval_process
            run_eval(logger, data_args, model_args, training_args, trainer, test_dataset)

    elif task in ("train_adv_stage1", "train_adv_stage2"):
        if training_args.do_train:
            train_adv(data_args, model_args, training_args, trainer, train_dataset, last_checkpoint)
        elif training_args.do_eval:
            if task == "train_adv_stage1":
                eval_adv_detector(logger, data_args, model_args, training_args, trainer, test_dataset, log_flag=True)
            elif model.summarizer is None:
                eval_adv_asr(logger, data_args, model_args, training_args, tokenizer, trainer, test_dataset)
            else:
                eval_adv_asr_with_summ(logger, data_args, model_args, training_args, tokenizer, trainer, test_dataset)

    elif task in ("ssra_loo", "ssra_kmeans"):
        if training_args.do_train:
            finetune_abstractor(data_args, model_args, training_args, trainer, train_dataset, last_checkpoint)
        elif training_args.do_eval:
            eval_abstractor(data_args, model_args, training_args, tokenizer, trainer, test_dataset)

    elif task == "train_filter":
        trainer.train()

    elif task == "build_cluster_summary":
        trainer.build_cluster_summary()

    else:
        raise ValueError("training_args.task_type not correctly specified!")

if __name__ == "__main__":
    main()
def _per_class_f1(prediction, y, n_classes):
    """
    Compute the one-vs-rest F1 score of every class `0 .. n_classes - 1`.

    Args:
        prediction: predicted labels, indexable, at least as long as `y`
        y: ground-truth labels
        n_classes: number of classes to score

    Returns:
        list of `n_classes` floats. A class whose precision and recall are
        both 0 (including the case of empty input) gets an F1 of 0.0.
    """
    tp = [0] * n_classes
    fp = [0] * n_classes
    fn = [0] * n_classes
    for i, actual in enumerate(y):
        predicted = prediction[i]
        for cls in range(n_classes):
            if actual == cls:
                if predicted == cls:
                    tp[cls] += 1
                else:
                    fn[cls] += 1
            elif predicted == cls:
                fp[cls] += 1

    f1_scores = []
    for cls in range(n_classes):
        ## Guard every ratio so classes that never occur score 0 instead of dividing by zero
        precision = float(tp[cls]) / float(tp[cls] + fp[cls]) if (tp[cls] + fp[cls]) else 0.0
        recall = float(tp[cls]) / float(tp[cls] + fn[cls]) if (tp[cls] + fn[cls]) else 0.0
        if (precision + recall) == 0:
            f1_scores.append(0.0)
        else:
            f1_scores.append(2 * precision * recall / (precision + recall))
    return f1_scores

def f1_score_3_class(prediction, y):
    """
    Per-class F1 scores for labels 0, 1 and 2.

    Args:
        prediction: predicted labels
        y: ground-truth labels

    Returns:
        [F1 of class 0, F1 of class 1, F1 of class 2]; all zeros when `y` is
        empty (the unused accuracy computation that used to divide by
        `len(y)` was dropped).
    """
    return _per_class_f1(prediction, y, 3)

def f1_score_4_class(prediction, y):
    """
    Per-class F1 scores for labels 0, 1, 2 and 3.

    Args:
        prediction: predicted labels
        y: ground-truth labels

    Returns:
        [F1 of class 0, F1 of class 1, F1 of class 2, F1 of class 3]; all
        zeros when `y` is empty.
    """
    return _per_class_f1(prediction, y, 4)
def parse_args():
    """
    Parse the command-line options of the topic-clustering script.

    Returns:
        argparse.Namespace with the experiment switches, `n_clusters`, the
        dataset name / root, the comma-separated `fold` list and `result_path`.
    """
    parser = argparse.ArgumentParser(description="Split data via clustering.")

    ## Which experiment
    for experiment_flag in (
        "--split_data_via_cluster",
        "--select_number_of_clusters",
        "--collect_semeval2019_event",
    ):
        parser.add_argument(experiment_flag, action="store_true")

    ## Parameters
    parser.add_argument("--n_clusters", type=int, default=3)

    ## Others
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="twitter15",
        choices=["semeval2019", "twitter15", "twitter16"],
    )
    parser.add_argument("--dataset_root", type=str, default="../dataset/processedV2")
    parser.add_argument(
        "--fold",
        type=str,
        default="0,1,2,3,4",
        help="either use 5-fold data or train/dev/test from rumoureval2019 competition",
    )
    parser.add_argument("--result_path", type=str, default="/mnt/1T/projects/RumorV2/results")

    return parser.parse_args()

def load_data(args):
    """
    Read `<dataset_root>/<dataset_name>/data.csv` and collect one text per thread.

    The text of a thread is taken from the first row of its `source_id`
    group; emojis are converted by `emoji.demojize` and the substrings "URL" /
    "url" are removed.

    Returns:
        (data_df, src_ids, src_txts): the full frame, the source ids in
        group order, and the cleaned text of each of them.
    """
    def clean(txt):
        demojized = emoji.demojize(txt)
        return demojized.replace("URL", "").replace("url", "")

    print("\nLoad data...")
    print("Dataset: [{}]".format(args.dataset_name))

    data_df = pd.read_csv("{}/{}/data.csv".format(args.dataset_root, args.dataset_name))

    threads = list(data_df.groupby("source_id"))
    src_ids = [src_id for src_id, _ in threads]
    src_txts = [clean(thread_df.iloc[0]["text"]) for _, thread_df in threads]

    return data_df, src_ids, src_txts

def load_model(device):
    """Load the pretrained `roberta-large` tokenizer and model, and move the model to `device`."""
    print("\nLoad RoBERTa-Large & tokenizer...")
    checkpoint = "roberta-large"
    tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
    model = RobertaModel.from_pretrained(checkpoint)
    model.to(device)
    return tokenizer, model
def get_roberta_txt_feat(tokenizer, model, text, device):
    """
    Tokenize `text`, move the encoding to `device` and return the raw model outputs.

    The text is passed to the tokenizer as-is (no truncation or padding is
    requested here).
    """
    encoded_input = tokenizer(text, return_tensors="pt")
    encoded_input = encoded_input.to(device)
    outputs = model(**encoded_input)
    return outputs

def _run_kmeans_trials(features, n_clusters, n_trials):
    """
    Fit k-means `n_trials` times with `n_clusters` clusters and print the max /
    average silhouette score.

    Returns:
        (silhouette_scores, cluster_results): a numpy array of scores and the
        list of fitted KMeans objects, in trial order.
    """
    silhouette_scores, cluster_results = [], []
    for _ in tqdm(range(n_trials), desc="{} Clusters".format(n_clusters)):
        kmeans_fit = KMeans(n_clusters=n_clusters, n_init="auto").fit(features)
        silhouette_scores.append(silhouette_score(features, kmeans_fit.labels_))
        cluster_results.append(kmeans_fit)

    print("[{}] Clusters".format(n_clusters))
    silhouette_scores = np.array(silhouette_scores)
    print("Max Silhouette Score: {}".format(np.max(silhouette_scores)))
    print("Avg Silhouette Score: {}".format(np.mean(silhouette_scores)))
    return silhouette_scores, cluster_results

def split_data_via_cluster(args):
    """
    Cluster the source posts of a dataset by their RoBERTa features.

    With `args.select_number_of_clusters` the function only reports silhouette
    scores for 2..10 clusters (100 k-means runs each) and returns. Otherwise
    it clusters once with `args.n_clusters` clusters and writes, under
    `<dataset_root>/<dataset_name>/cluster_topics/`:
      - `cluster_ids.txt` with one `source_id: cluster_id` line per source post
      - `cloud_<cluster_id>.png`, a word cloud of each cluster
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device: {}".format(device))

    data_df, src_ids, src_txts = load_data(args)
    tokenizer, model = load_model(device)

    print("\nCollect text features of each source post...")
    features = []
    with torch.no_grad():
        for src_txt in tqdm(src_txts):
            outputs = get_roberta_txt_feat(tokenizer, model, src_txt, device)
            feature = outputs["last_hidden_state"][0][0] ## First token of first data in a batch
            features.append(feature.cpu().numpy())

    print("\nPerform k-means clustering on features of source posts...")
    if args.select_number_of_clusters:
        for n_clusters in range(2, 11):
            _run_kmeans_trials(features, n_clusters, n_trials=100)
        ## Report only. The fits left over from the last candidate (10 clusters)
        ## do not match args.n_clusters, so writing them out would produce a
        ## truncated, inconsistent split.
        return

    silhouette_scores, cluster_results = _run_kmeans_trials(features, args.n_clusters, n_trials=1)
    best_fit = cluster_results[np.argmax(silhouette_scores)]
    labels_ = best_fit.labels_

    src_ids, src_txts = np.array(src_ids), np.array(src_txts)

    output_dir = "{}/{}/cluster_topics".format(args.dataset_root, args.dataset_name)
    os.makedirs(output_dir, exist_ok=True)

    clusters = {}
    ## `with` closes the file even if the word-cloud generation fails
    with open("{}/cluster_ids.txt".format(output_dir), "w") as fw:
        for cid in range(args.n_clusters):
            in_cluster = labels_ == cid
            print("# cluster [{}]: {}".format(cid, in_cluster.sum()))

            clusters[cid] = {
                "src_ids": src_ids[in_cluster],
                "src_txts": src_txts[in_cluster],
            }

            full_txt = " ".join(list(clusters[cid]["src_txts"]))
            wordcloud = WordCloud(width=1000, height=500).generate(full_txt)
            wordcloud.to_file("{}/cloud_{}.png".format(output_dir, cid))

            for src_id in clusters[cid]["src_ids"]:
                fw.write("{}: {}\n".format(src_id, cid))

def collect_semeval2019_event(args):
    """
    Iterate over the threads of `<dataset_root>/semeval2019/data.csv`.

    NOTE(review): this looks like an unfinished stub -- it reads the text of
    every source post but neither stores nor returns anything.
    """
    data_df = pd.read_csv("{}/semeval2019/data.csv".format(args.dataset_root))
    group_src = data_df.groupby("source_id")

    for src_id, thread_df in group_src:
        src_txt = thread_df.iloc[0]["text"]

def main(args):
    """Run the experiment selected by the command-line switches (at most one of them)."""
    if args.split_data_via_cluster:
        split_data_via_cluster(args)
    elif args.collect_semeval2019_event:
        collect_semeval2019_event(args)

if __name__ == "__main__":
    args = parse_args()
    main(args)
class TopicModel:
    """Fit an LDA topic model on a collection of texts (one conversation thread)."""

    def __init__(self, n_topics=3):
        """
        Args:
            n_topics: number of LDA topics to fit.
        """
        self.n_topics = n_topics
        ## Characters deleted from the texts before tokenization
        self.punctuation = "!\"#$%&()*+,./:;<=>?@[\\]^`{|}~"
        self.punct_table = str.maketrans(dict.fromkeys(self.punctuation))

    def sent_to_words(self, sentences):
        """Yield the token list of every sentence."""
        for sentence in sentences:
            # deacc=True removes accent marks from the tokens
            yield simple_preprocess(str(sentence), deacc=True)

    def remove_stopwords(self, texts):
        """Return the token lists of `texts` without English stop words.

        Each element of `texts` is converted with `str()` and tokenized again
        before filtering.
        """
        ## A set gives constant-time membership tests inside the comprehension
        stop_words = set(stopwords.words('english'))
        stop_words.update(['from', 'subject', 're', 'edu', 'use', 'would'])
        return [[word for word in simple_preprocess(str(doc))
                 if word not in stop_words] for doc in texts]

    def find_topics(self, text_df):
        """
        Input:
            text_df: dataframe of text, each row represents the textual content of a response
        Output:
            topics: dictionary of topics (key: topic_id, value: word distribution);
                    empty when no word survives the preprocessing
        """
        text_df = text_df.map(lambda x: x.translate(self.punct_table))
        text_df = text_df.map(lambda x: x.lower())

        text_words = list(self.sent_to_words(text_df.tolist()))
        text_words = self.remove_stopwords(text_words)

        id2word = corpora.Dictionary(text_words)
        vocab_size = len(id2word)
        if vocab_size == 0:
            ## LDA cannot be fit without any term; report "no topics" instead of failing
            return {}
        corpus = [id2word.doc2bow(words) for words in text_words]

        ## LDA
        ## NOTE(review): float64 is presumably chosen so that the probabilities stay
        ## JSON-serializable (numpy float64 subclasses Python float) - confirm.
        lda_model = LdaModel(
            corpus=corpus,
            id2word=id2word,
            num_topics=self.n_topics,
            dtype=np.float64
        )
        ## Ask for every topic explicitly (the default would cap the result at 10 topics)
        ## and for the distribution over the whole vocabulary.
        shown = lda_model.show_topics(num_topics=self.n_topics, num_words=vocab_size, formatted=False)
        return {topic_id: word_dist for topic_id, word_dist in shown}


def build_topics(args):
    """Fit one topic model per thread and write all of them to `topics.json`.

    Args:
        args: namespace providing `dataset_root`, `dataset_name` and `n_topics`.

    The output file `{dataset_root}/{dataset_name}/topics.json` maps each
    source id (as a string) to the topics returned by `TopicModel.find_topics`.
    """
    ## Load dataset
    data_df = pd.read_csv("{}/{}/data.csv".format(args.dataset_root, args.dataset_name))

    ## Build topic model
    topic_model = TopicModel(n_topics=args.n_topics)

    ## For each thread
    topics_threads = {}
    for src_id, group in tqdm(data_df.groupby("source_id")):
        ## JSON object keys end up as strings anyway; converting here also avoids
        ## a TypeError when the id is a numpy integer.
        topics_threads[str(src_id)] = topic_model.find_topics(group["text"])

    ## Open the file only once everything is computed, and close it reliably
    with open("{}/{}/topics.json".format(args.dataset_root, args.dataset_name), "w") as fout:
        fout.write(json.dumps(topics_threads))
competition") 110 | 111 | args = parser.parse_args() 112 | 113 | return args 114 | 115 | def sent_to_words(sentences): 116 | for sentence in sentences: 117 | # deacc=True removes punctuations 118 | yield(simple_preprocess(str(sentence), deacc=True)) 119 | 120 | def remove_stopwords(texts): 121 | stop_words = stopwords.words('english') 122 | stop_words.extend(['from', 'subject', 're', 'edu', 'use', 'would']) 123 | return [[word for word in simple_preprocess(str(doc)) 124 | if word not in stop_words] for doc in texts] 125 | 126 | def main(args): 127 | data_path = "{}/{}/data.csv".format(args.data_root, args.dataset) 128 | data_df = pd.read_csv(data_path) 129 | 130 | ## Get source tweet content 131 | src_df = data_df[data_df["source_id"] == data_df["tweet_id"]] 132 | src_df = src_df.drop(columns=["tweet_id", "parent_idx", "self_idx", "num_parent", "max_seq_len", "veracity", "stance"]) 133 | 134 | ## Preprocess text 135 | src_df["text_processed"] = src_df["text"].map(lambda x: re.sub("[,\\.!?]", "", x)) 136 | src_df["text_processed"] = src_df["text_processed"].map(lambda x: x.lower().replace("", "").replace("url", "")) 137 | 138 | ## Prepare data for LDA analysis 139 | data = src_df["text_processed"].tolist() 140 | data_words = list(sent_to_words(data)) 141 | data_words = remove_stopwords(data_words) 142 | 143 | id2word = corpora.Dictionary(data_words) 144 | texts = data_words 145 | corpus = [id2word.doc2bow(text) for text in texts] 146 | 147 | ## LDA 148 | num_topics = 8 149 | lda_model = gensim.models.LdaMulticore(corpus=corpus, id2word=id2word, num_topics=num_topics) 150 | for topic in lda_model.print_topics(): 151 | print(topic) 152 | 153 | ## Visualization 154 | #pyLDAvis.enable_notebook() 155 | LDAvis_data_filepath = "{}/{}/lda/ldavis_prepared_{}".format(args.data_root, args.dataset, num_topics) 156 | if True: 157 | LDAvis_prepared = pyLDAvis.gensim_models.prepare(lda_model, corpus, id2word) 158 | with open(LDAvis_data_filepath, 'wb') as f: 159 | 
def build_datasets_clustering(data_args, model_args, training_args, config, tokenizer, model):
    """Build datasets for building cluster summary dataset.

    Loads `data_tree_ids.csv` as a single "train" split, looks up the text of
    every tweet id in `data.csv`, tokenizes all tweets and regroups the token
    ids / attention masks tree by tree.

    Args:
        data_args: provides dataset_name, dataset_root, fold, max_target_length,
            pad_to_max_length, max_seq_length, max_tree_length, max_tweet_length
            and overwrite_cache.
        model_args: provides cache_dir.
        training_args: provides the `main_process_first` context manager.
        config: not used in this function.
        tokenizer: tokenizer applied to the tweet texts.
        model: not used in this function.

    Returns:
        (train_dataset, eval_dataset, test_dataset); the last two are always None.
    """
    ##################
    ## Load Dataset ##
    ##################
    print("\nLoading dataset...")
    print("[{}]: fold [{}]".format(data_args.dataset_name, data_args.fold))

    tree_ids_csv = "{}/{}/data_tree_ids.csv".format(data_args.dataset_root, data_args.dataset_name)
    raw_datasets = load_dataset("csv", data_files={"train": tree_ids_csv}, cache_dir=model_args.cache_dir)

    ## Read tweet contents; both id columns are handled as strings
    data_df = pd.read_csv("{}/{}/data.csv".format(data_args.dataset_root, data_args.dataset_name))
    for id_column in ("source_id", "tweet_id"):
        data_df[id_column] = data_df[id_column].astype(str)
    dataset_content = data_df.set_index("tweet_id").T.to_dict()  ## tweet_id -> all columns of that tweet

    ################
    ## Preprocess ##
    ################
    print("\nProcessing dataset...")

    ## Padding strategy
    max_target_length = data_args.max_target_length  ## For generator
    padding = "max_length" if data_args.pad_to_max_length else False

    tree_seq_length = data_args.max_seq_length * data_args.max_tree_length
    assert tree_seq_length <= tokenizer.model_max_length, \
        "Max length of tree sequence ({}) larger than the max input length for the model ({})!".format(
            tree_seq_length, tokenizer.model_max_length
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    def preprocess_function(examples):
        """
        Input:
            - examples: keys = ["source_id", "tweet_ids"]
        Output:
            - tokenizer output extended with "source_id", "tweet_ids" and "tree_lens";
              "input_ids" / "attention_mask" hold one entry per tree.
        """
        def parse_trees_from_str(input_trees):
            """Split every comma-separated string of tweet ids into a list."""
            output_trees = []
            for tweet_ids_str in input_trees:
                output_trees.append(tweet_ids_str.split(","))
            assert len(output_trees) == len(input_trees)
            return output_trees

        def id2text(trees):
            """
            Flatten all trees and look up the text of every tweet id.
            Return
                - tweet_ids: list of tweet id that maps `texts`
                - texts    : list of all tweet texts
            """
            tweet_ids = []
            for tree in trees:
                tweet_ids.extend(tree)
            texts = []
            for tweet_id in tweet_ids:
                texts.append(dataset_content[tweet_id]["text"])
            return tweet_ids, texts

        def formulate_one_tree(model_inputs, start_idx, end_idx):
            """Take rows [start_idx, end_idx) of the tokenizer output, without truncating the tree."""
            tree_rows = slice(start_idx, end_idx)
            input_ids = model_inputs["input_ids"][tree_rows]
            attn_mask = model_inputs["attention_mask"][tree_rows]

            ## Append rows filled with -1 until the tree has 300 nodes
            pad_row = [-1] * data_args.max_tweet_length
            input_ids += [pad_row] * (300 - len(input_ids))
            attn_mask += [pad_row] * (300 - len(attn_mask))
            return input_ids, attn_mask

        def extract_topic_words_probs(tree_topic):
            """
            Input:
                - tree_topic: dict, all topics of a tree
            Note: this helper is not called anywhere in this function.
            """
            top_k_topic_words = 10
            corpus, topic_words, topic_probs = [], [], []
            for topic_id in range(len(tree_topic)):
                ## Keep the k leading (word, probability) pairs of the topic
                top_pairs = tree_topic[str(topic_id)][:top_k_topic_words]
                words = [pair[0] for pair in top_pairs]
                probs = [pair[1] for pair in top_pairs]
                corpus.extend(words)
                topic_words.append(words)
                topic_probs.append(probs)
            return topic_words, topic_probs, list(set(corpus))

        def add_special_tokens_and_pad(topic_ids, topic_probs, max_topic_seq_len):
            """Wrap every topic in bos/eos and pad ids, probabilities and mask to a common length.

            Note: this helper is not called anywhere in this function.
            """
            new_topic_ids, new_topic_probs, topic_msk = [], [], []
            max_topic_seq_len = max_topic_seq_len + 2  ## bos & eos
            for tree_idx, tree_topic_ids in enumerate(topic_ids):  ## For each tree
                tree_ids, tree_probs, tree_msks = [], [], []
                for topic_idx, ids in enumerate(tree_topic_ids):  ## For each topic
                    ## Add bos & eos token (their probability is 0)
                    ids = [tokenizer.bos_token_id] + ids + [tokenizer.eos_token_id]
                    prb = [0] + topic_probs[tree_idx][topic_idx] + [0]
                    msk = [1] * len(ids)

                    ## Padding
                    num_padding = max_topic_seq_len - len(ids)
                    ids.extend([tokenizer.pad_token_id] * num_padding)
                    prb.extend([0] * num_padding)
                    msk.extend([0] * num_padding)

                    tree_ids.append(ids)
                    tree_probs.append(prb)
                    tree_msks.append(msk)

                new_topic_ids.append(tree_ids)
                new_topic_probs.append(tree_probs)
                topic_msk.append(tree_msks)
            return new_topic_ids, new_topic_probs, topic_msk

        ## ------------------------------------------------------------------------------------------------------

        ## Take all responses into consideration
        src_ids, trees = examples["source_id"], examples["tweet_ids"]
        src_ids = [str(src_id) for src_id in src_ids]  ## Convert src_id to strings
        trees = parse_trees_from_str(trees)
        tree_lens = [len(tree) for tree in trees]
        tweet_ids, texts = id2text(trees)

        model_inputs = tokenizer(texts, padding=padding, max_length=data_args.max_tweet_length, truncation=True)

        ## Regroup the flat tokenizer output: one (padded) list of rows per tree
        input_ids, attn_mask = [], []
        offset = 0
        for tree in trees:
            next_offset = offset + len(tree)
            tree_input_ids, tree_attn_mask = formulate_one_tree(model_inputs, offset, next_offset)
            input_ids.append(tree_input_ids)
            attn_mask.append(tree_attn_mask)
            offset = next_offset

        ## Final assignment
        model_inputs["source_id"] = src_ids
        model_inputs["tweet_ids"] = [",".join(tree) for tree in trees]  ## tweet ids of a tree, comma-joined
        model_inputs["tree_lens"] = tree_lens
        model_inputs["input_ids"] = input_ids
        model_inputs["attention_mask"] = attn_mask

        return model_inputs

    with training_args.main_process_first(desc="dataset map pre-processing"):
        for split_name in raw_datasets.keys():  ## Separately process each sets of raw_datasets
            raw_datasets[split_name] = raw_datasets[split_name].map(
                preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                remove_columns=raw_datasets[split_name].column_names,  ## Enable the function to return more samples than input
                desc="Running tokenizer on {} dataset".format(split_name)
            )

    ####################
    ## Build each set ##
    ####################
    ## Only a train split exists for this task
    return raw_datasets["train"], None, None
[[Paper](https://aclanthology.org/2023.emnlp-main.707/)] [[Datasets](https://drive.google.com/file/d/1gkK3oNstw_pWehLD9W7ZdvfTxmPUlREz/view?usp=sharing)] 6 | 7 | ## Introduction 8 | Original PyTorch implementation for the EMNLP 2023 paper "Beyond Detection: A Defend-and-Summarize Strategy for Robust and Interpretable Rumor Analysis on Social Media" by Yi-Ting Chang, Yun-Zhu Song, Yi-Syuan Chen, and Hong-Han Shuai. 9 | 10 | This project is organized in the following structure. 11 | ``` 12 | |__ src 13 | |__ main.py -> organize main flow for the codes 14 | |__ data -> data preprocessing/build datasets 15 | |__ models -> different model classes 16 | |__ others 17 | |__ pipelines -> different trainers 18 | |__ scripts -> scripts for train/test of each task 19 | |__ dataset -> put the processed datasets here 20 | |__ re2019 21 | |__ twitter15 22 | |__ twitter16 23 | |__ {$output_dir} -> store the experimental results with the `exp_id` in each script as the folder name 24 | |__ re2019 25 | |__ twitter15 26 | |__ twitter16 27 | ``` 28 | 29 | ## Environmental Setup 30 | This code is developed under **Ubuntu 20.04.3 LTS** and **Python 3.8.10**. Run the script `build_env.sh` first to install necessary packages through pip. 31 | 32 | ### Install PyTorch as follows. 33 | ```bash 34 | $ pip install torch==1.11.0+cu102 --extra-index-url https://download.pytorch.org/whl/cu102 35 | $ pip install torch-scatter==2.0.8 -f https://data.pyg.org/whl/torch-1.11.0+cu102.html 36 | ``` 37 | [[Reference Solution for Installing PyTorch Geometric](https://stackoverflow.com/questions/70008715/pytorch-and-torch-scatter-were-compiled-with-different-cuda-versions-on-google-c)]: For installing torch-scatter/torch-cluster/torch-sparse, you should first obtain the **reference url** by your desired *PyTorch* and *CUDA* version according to your computer. Next, you need to specify the latest version of torch-scatter provided by the link (which is 2.0.8 in this case) when installing through pip. 
38 | 39 | ### Install kmeans_pytorch from source. 40 | ```bash 41 | $ git clone https://github.com/subhadarship/kmeans_pytorch 42 | $ cd kmeans_pytorch 43 | $ pip install --editable . 44 | $ pip install numba 45 | ``` 46 | 47 | ## Dataset 48 | The datasets should be placed at the folder `dataset` on the same layer as `src`. Each dataset should contain several files organized in the following structure. 49 | ``` 50 | dataset 51 | |__ {DATASET_NAME_0} 52 | |__ data.csv -> all data information 53 | |__ data_tree_ids.csv -> for `build_cluster.sh` 54 | |__ graph.pth -> cache for graph data, created after training the detection model once. 55 | |__ split_{$FOLD_N} 56 | |__ train.csv 57 | |__ test.csv 58 | |__ cluster_summary/train -> store the cluster information for training SSRA. 59 | |__ kmeans-{$N_CLUSTERS}.csv 60 | |__ ... 61 | |__ ... 62 | |__ ... 63 | ``` 64 | The file `data.csv` consists of 8 columns of data as follows. 65 | | Column | Description | 66 | |------------|-------------| 67 | |source_id |tweet id of the source tweet for each conversation thread| 68 | |tweet_id |tweet id for each tweet| 69 | |parent_idx |index of each tweet's parent node, set to `None` if the tweet is source| 70 | |self_idx |index of each tweet in a conversation thread, arranged in chronological order| 71 | |num_parent |number of parent nodes in each conversation thread| 72 | |max_seq_len |maximal sequence length for each conversation thread| 73 | |text |textual content for each tweet| 74 | |veracity |veracity label for the source post of each conversation thread| 75 | 76 | Our processed datasets are available at [[Datasets](https://drive.google.com/file/d/1gkK3oNstw_pWehLD9W7ZdvfTxmPUlREz/view?usp=sharing)]. 77 | 78 | ## Run the Codes 79 | We provide the training and evaluation scripts for each component of our framework in the folder `src/scripts`. Notice that each script requires an output root directory (`$output_dir`) and an experiment name (`--exp_name`). 
After executing each script, the experimental results including the model checkpoints will be automatically stored in the following structure: 80 | ``` 81 | {$output_dir} 82 | |__ {$DATASET_NAME_0} 83 | |__ {$EXP_NAME_0} 84 | |__ {$EXP_NAME_2} 85 | |__ ... 86 | |__ {$DATASET_NAME_1} 87 | |__ {$EXP_NAME_1} 88 | |__ ... 89 | ``` 90 | 91 | ### 1. Train BiTGN (RoBERTa) 92 | Train the model. 93 | ```bash 94 | $ sh scripts/detection/train.sh 95 | ``` 96 | Evaluate trained models. 97 | ```bash 98 | $ sh scripts/detection/eval.sh 99 | ``` 100 | 101 | ### 2. Train BiTGN (BART) + ARG 102 | #### 2.1 Adversarial Training Stage 1 103 | Train the detector along with the generator. 104 | ```bash 105 | $ sh scripts/attack/stage1/train.sh 106 | ``` 107 | Evaluate the trained detector. 108 | ```bash 109 | $ sh scripts/attack/stage1/eval.sh 110 | ``` 111 | #### 2.2 Adversarial Training Stage 2 112 | Train the generator to attack the detector while fixing the detector. Note that this training stage should be executed after stage 1 has finished, or once at least one checkpoint from stage 1 exists. 113 | ```bash 114 | $ sh scripts/attack/stage2/train.sh 115 | ``` 116 | 117 | ### 3. Train Response Extractor (AutoEncoder) 118 | This stage obtains the embedding from the pre-trained detector, so please make sure you have at least one checkpoint from the previous step before you run the following script. 119 | ```bash 120 | $ sh scripts/summarizer/filter/train.sh 121 | ``` 122 | 123 | ### 4. Train Response Abstractor (SSRA) 124 | #### 4.1 Build Clusters 125 | In order to train the **S**elf-**S**upervised **R**esponse **A**bstractor (SSRA) with $k$-means settings, you need to build the cluster information from the dataset first by running the following script. Note that this step also requires a checkpoint from step 2.2, so make sure the settings are correct. 
126 | ```bash 127 | $ sh scripts/summarizer/build_cluster.sh 128 | ``` 129 | #### 4.2 Start Training 130 | To train the response abstractor with $k$-means settings, check that you have already built the cluster information as documented in the dataset description. 131 | ```bash 132 | $ sh scripts/summarizer/ssra-kmeans/train.sh 133 | ``` 134 | Evaluate the trained abstractors. 135 | ```bash 136 | $ sh scripts/summarizer/ssra-kmeans/eval.sh 137 | ``` 138 | 139 | ### 5. Evaluate the BiTGN with DAS 140 | Evaluate DAS with the following hyper-parameters: 141 | - extract ratio $\rho$ in the range $\\{0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 0.75, 0.9\\}$ 142 | - number of clusters $k$ in the range $\\{1, 2, 3, 4, 5\\}$ 143 | ```bash 144 | $ sh scripts/attack/stage2/eval.sh 145 | ``` 146 | The script performs the following evaluations for each fold and hyper-parameter set: 147 | 1. Evaluate stage-2 detector *without* adversarial attack. 148 | 2. Evaluate stage-2 detector *under* adversarial attack *without* summarizer. 149 | 3. Evaluate stage-2 detector *under* adversarial attack *with* summarizer. 
def build_datasets_filter(data_args, model_args, training_args, config, tokenizer, model):
    """Build datasets for training autoencoder filter.

    Loads the train / validation / test CSV files of one fold, optionally keeps
    only threads of one target class, and tokenizes every tweet of the
    remaining threads (one sample per tweet).

    Args:
        data_args: provides dataset_name, dataset_root, fold, train_file,
            validation_file, num_labels, pad_to_max_length, max_seq_length,
            max_tree_length, max_tweet_length, overwrite_cache,
            max_train_samples and max_eval_samples.
        model_args: provides cache_dir and target_class_ext_ae.
        training_args: provides do_train, do_eval and `main_process_first`.
        config: model config; `label2id` / `id2label` are set on it.
        tokenizer: tokenizer applied to the tweet texts.
        model: not used in this function.

    Returns:
        (train_dataset, eval_dataset, test_dataset); train/eval are None without
        `do_train`, test is None without `do_eval`.

    Raises:
        ValueError: if a split required by `do_train` / `do_eval` is missing.
    """
    ##################
    ## Load Dataset ##
    ##################
    print("\nLoading dataset...")
    print("[{}]: fold [{}]".format(data_args.dataset_name, data_args.fold))

    ## Loading a dataset from my local files.
    ## PHEME uses event-wise cross-validation folders (`event_{fold}`), the others `split_{fold}`.
    fold_prefix = "event" if data_args.dataset_name == "PHEME" else "split"
    fold_dir = "{}/{}/{}_{}".format(data_args.dataset_root, data_args.dataset_name, fold_prefix, data_args.fold)
    data_files = {
        "train"     : "{}/{}".format(fold_dir, data_args.train_file),
        "validation": "{}/{}".format(fold_dir, data_args.validation_file),
        "test"      : "{}/{}".format(fold_dir, data_args.validation_file)  ## same file as "validation"
    }
    raw_datasets = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir)

    ## Read tweet contents
    data_df = pd.read_csv("{}/{}/data.csv".format(data_args.dataset_root, data_args.dataset_name))
    data_df["source_id"] = data_df["source_id"].astype(str)  ## For PHEME, twitter15, twitter16
    data_df["tweet_id"] = data_df["tweet_id"].astype(str)
    dataset_content = data_df.set_index("tweet_id").T.to_dict()  ## Each tweet_id maps to all information

    ## Labels for detector
    label_list = raw_datasets["train"].unique("label_veracity")
    label_list.sort()  ## sort for determinism
    num_labels = len(label_list)

    assert num_labels == data_args.num_labels, "num_labels specified in data arguments doesn't match your actual dataset!"

    ################
    ## Preprocess ##
    ################
    print("\nProcessing dataset...")

    ## Padding strategy
    padding = "max_length" if data_args.pad_to_max_length else False

    ## Some models have set the order of the labels to use, so let's make sure we do use it.
    config.label2id = {label: idx for idx, label in enumerate(label_list)}
    config.id2label = {idx: label for label, idx in config.label2id.items()}

    assert data_args.max_seq_length * data_args.max_tree_length <= tokenizer.model_max_length, \
        "Max length of tree sequence ({}) larger than the max input length for the model ({})!".format(
            data_args.max_seq_length * data_args.max_tree_length, tokenizer.model_max_length
        )

    def preprocess_function(examples):
        """
        Input:
            - examples: keys = ["source_id", "tweet_ids", "label_veracity"]
        Output:
            - tokenizer output with one sample per tweet of the kept threads
              (threads are filtered by `model_args.target_class_ext_ae` first)
        """
        def filter_data_with_target_class(src_ids, trees):
            """Train the AE to reconstruct a specific class"""
            stance = ["support", "comment", "query", "deny"]
            verity = ["true", "false", "unverified"]

            target = model_args.target_class_ext_ae.lower()
            if target in stance:
                key = "stance"
            elif target in verity:
                key = "veracity"
            else:  ## `all`
                return src_ids, trees

            print("The model will be trained to reconstruct responses with label `{}`".format(model_args.target_class_ext_ae))

            ## Keep the threads whose source tweet carries the target label
            filter_src_ids, filter_trees = [], []
            for src_id, tree in zip(src_ids, trees):
                if dataset_content[src_id][key] == target:
                    filter_src_ids.append(src_id)
                    filter_trees.append(tree)

            return filter_src_ids, filter_trees

        def parse_trees_from_str(input_trees):
            """Parse each tree from string of tweet_ids to list"""
            output_trees = [tweet_ids_str.split(",") for tweet_ids_str in input_trees]
            assert len(output_trees) == len(input_trees)
            return output_trees

        def id2text(trees):
            """
            Convert all tweet ids to corresponding text
            Return
                - tweet_ids: list of tweet id that maps `texts`
                - texts    : list of all tweet texts
            """
            tweet_ids = [tweet_id for tree in trees for tweet_id in tree]
            texts = [dataset_content[tweet_id]["text"] for tweet_id in tweet_ids]
            return tweet_ids, texts

        ## ------------------------------------------------------------------------------------------------------
        src_ids, trees = examples["source_id"], examples["tweet_ids"]
        src_ids = [str(src_id) for src_id in src_ids]  ## Convert to strings
        src_ids, trees = filter_data_with_target_class(src_ids, trees)
        trees = parse_trees_from_str(trees)
        _, texts = id2text(trees)

        model_inputs = tokenizer(texts, padding=padding, max_length=data_args.max_tweet_length, truncation=True)

        return model_inputs

    with training_args.main_process_first(desc="dataset map pre-processing"):
        for key_data in raw_datasets.keys():  ## Separately process each sets of raw_datasets
            raw_datasets[key_data] = raw_datasets[key_data].map(
                preprocess_function,
                batched=True,
                load_from_cache_file=not data_args.overwrite_cache,
                remove_columns=raw_datasets[key_data].column_names,  ## Enable the function to return more samples than input
                desc="Running tokenizer on {} dataset".format(key_data)
            )

    ####################
    ## Build each set ##
    ####################
    def limit_samples(dataset, max_samples):
        """Keep at most `max_samples` leading samples (no limit when it is None)."""
        if max_samples is None:
            return dataset
        ## Never request more rows than exist, otherwise `select` fails
        return dataset.select(range(min(len(dataset), max_samples)))

    ## Make train & evaluation dataset
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = limit_samples(raw_datasets["train"], data_args.max_train_samples)

        ## Shuffle train dataset
        train_dataset = train_dataset.shuffle()

        if "validation" not in raw_datasets:
            raise ValueError("--do_train requires a validation dataset")
        eval_dataset = limit_samples(raw_datasets["validation"], data_args.max_eval_samples)

        ## Log a few random samples from the training set (fewer if the set is tiny)
        for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
    else:
        train_dataset = None
        eval_dataset = None

    ## Make test dataset
    if training_args.do_eval:
        if "test" not in raw_datasets:
            raise ValueError("--do_eval requires a test dataset")
        test_dataset = limit_samples(raw_datasets["test"], data_args.max_eval_samples)
    else:
        test_dataset = None

    return train_dataset, eval_dataset, test_dataset
class TransformerAutoEncoder(nn.Module):
    """Transformer autoencoder that scores inputs by their reconstruction error.

    Embeddings go through a stack of `nn.TransformerEncoderLayer`s, are projected
    down to `d_z` dimensions, projected back to `d_model` and passed through a
    second stack of the same layer type. `forward` returns the unreduced squared
    error between the mean-pooled reconstruction and the mean-pooled input.

    Args:
        d_z: size of the bottleneck produced by `proj_enc`.
        nhead: number of attention heads in every transformer layer.
        d_model: hidden size of the embeddings and of the transformer layers.
        dropout: dropout probability of every transformer layer.
        model_emb: optional model whose embedding layer is reused when `forward`
            is called with `input_ids`. Only classes whose name contains
            "roberta" or "bart" are recognised.
        num_layers_enc: number of layers in the encoder stack.
        num_layers_dec: number of layers in the decoder stack.
    """

    def __init__(
        self,
        d_z=100,
        nhead=12,
        d_model=768,
        dropout=0.1,
        model_emb=None,
        num_layers_enc=2,
        num_layers_dec=2
    ):
        super(TransformerAutoEncoder, self).__init__()
        self.d_z = d_z
        self.nhead = nhead
        self.d_model = d_model
        self.num_layers_enc = num_layers_enc
        self.num_layers_dec = num_layers_dec

        print("n_layer_enc: {}".format(self.num_layers_enc))
        print("n_layer_dec: {}".format(self.num_layers_dec))

        ## Stays `None` unless a supported embedding model is provided, so that
        ## `forward` can report a clear error instead of an AttributeError.
        self.embedding_model = None
        if model_emb is not None:
            emb_class_name = model_emb.__class__.__name__.lower()
            if "roberta" in emb_class_name:
                self.embedding_model = "roberta"
                self.embeddings = model_emb.roberta.embeddings
            elif "bart" in emb_class_name:
                ## Only the (scaled) token embeddings of BART are reused
                self.embedding_model = "bart"
                self.embed_scale = model_emb.model.encoder.embed_scale
                self.embed_tokens = model_emb.model.encoder.embed_tokens

        ## `dropout` used to be accepted but ignored; its default (0.1) equals the
        ## PyTorch default, so existing callers get the same layers as before.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=self.d_model, nhead=self.nhead, dropout=dropout)
        self.decoder_layer = nn.TransformerEncoderLayer(d_model=self.d_model, nhead=self.nhead, dropout=dropout)

        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=self.num_layers_enc)
        self.decoder = nn.TransformerEncoder(self.decoder_layer, num_layers=self.num_layers_dec) ## Same as encoder

        ## Bottleneck projections (tanh is applied to both in `forward`)
        self.proj_enc = nn.Linear(d_model, d_z)
        self.proj_dec = nn.Linear(d_z, d_model)

    def forward(self, input_ids=None, attn_mask=None, inputs_embeds=None):
        """Return the unreduced reconstruction error of the given inputs.

        Args:
            input_ids: token ids, only used when `inputs_embeds` is None.
            attn_mask: optional mask; its transpose is passed to both transformer
                stacks as `src_key_padding_mask`.
            inputs_embeds: pre-computed embeddings; takes precedence over
                `input_ids`.

        Returns:
            Element-wise squared error between the decoder output averaged over
            dim 1 and the input embeddings averaged over dim 1.

        Raises:
            ValueError: if `inputs_embeds` is None and no supported embedding
                model was given at construction time.
        """
        ## L2 loss, kept per element so that callers can aggregate per response
        loss_fct = nn.MSELoss(reduction="none")

        ## Get embeddings
        if inputs_embeds is not None:
            embedding_output = inputs_embeds
        elif self.embedding_model == "roberta":
            embedding_output = self.embeddings(input_ids=input_ids)
        elif self.embedding_model == "bart":
            ## Use token embedding only
            embedding_output = self.embed_tokens(input_ids) * self.embed_scale
        else:
            raise ValueError(
                "`inputs_embeds` is required when no supported `model_emb` (roberta / bart) was provided"
            )

        ## NOTE(review): the layers are built with the default `batch_first=False`
        ## and the mask is transposed to match; PyTorch treats set entries of
        ## `src_key_padding_mask` as positions to ignore -- confirm that
        ## `attn_mask` follows that polarity.
        padding_mask = attn_mask.T if attn_mask is not None else None

        encoder_outputs = self.encoder(embedding_output, src_key_padding_mask=padding_mask)
        z_enc = torch.tanh(self.proj_enc(encoder_outputs))
        z_dec = torch.tanh(self.proj_dec(z_enc))
        decoder_outputs = self.decoder(z_dec, src_key_padding_mask=padding_mask)

        ## Mean pooling over dim 1
        source_ = torch.mean(decoder_outputs, dim=1)
        target_ = torch.mean(embedding_output, dim=1)

        return loss_fct(source_, target_)


class ResponseFilter(nn.Module):
    """Keep, for every conversational thread, the responses with the lowest anomaly score.

    A pre-trained `TransformerAutoEncoder` scores each response by its
    reconstruction error; the source post plus a `filter_ratio` share of the
    lowest-scoring responses (at least one) is kept.
    """

    def __init__(self, tokenizer, data_args, model_args, training_args):
        super(ResponseFilter, self).__init__()

        self.tokenizer = tokenizer
        self.data_args = data_args
        self.model_args = model_args
        self.training_args = training_args
        self.num_layers = model_args.filter_layer_enc

        ## The checkpoint directory is keyed by a single layer count
        assert (model_args.filter_layer_enc == model_args.filter_layer_dec)

        self.model = TransformerAutoEncoder(
            num_layers_enc=model_args.filter_layer_enc,
            num_layers_dec=model_args.filter_layer_dec
        )
        ckpt_path = "{}/{}/filter_{}/{}/autoencoder_rd_all.pt".format(
            self.training_args.output_root,
            self.data_args.dataset_name,
            self.num_layers,
            self.data_args.fold
        )
        ## `map_location` lets a checkpoint saved on GPU be restored on any machine;
        ## `load_state_dict` copies the values into the existing parameters.
        self.model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
        self.filter_ratio = 0.05 if self.model_args.filter_ratio is None else self.model_args.filter_ratio

    def forward(
        self,
        tree_lens,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None
    ):
        """Select the least anomalous responses of every thread in the batch.

        Args:
            tree_lens: number of nodes (source post + responses) of each thread.
            inputs_embeds: node embeddings, indexed as `[thread][node]`
                (assumes (batch, max_tree_length, max_tweet_length, hidden) -- TODO confirm).
            attention_mask: node masks, indexed like `inputs_embeds`.

        Returns:
            Tuple `(ext_idxs, ext_mask, n_ext_adv, new_tree_lens,
            new_inputs_embeds, new_attention_mask)`:
            - ext_idxs: 0/1 indicator per node slot marking the source post and
              the kept responses.
            - ext_mask: `ext_idxs` repeated `max_tweet_length` times per node.
            - n_ext_adv: always 0 (counting of extracted attacks is disabled).
            - new_tree_lens: number of kept nodes per thread.
            - new_inputs_embeds / new_attention_mask: source post followed by
              the kept responses, padded to `max_tree_length` nodes.
        """
        batch_size = inputs_embeds.shape[0]
        device = inputs_embeds.device

        ## Iterate through each conversational thread and collect all response features
        batch_reply, batch_masks = [], []
        batch_src, batch_src_msk = [], []
        for batch_idx in range(batch_size):
            tree_len = tree_lens[batch_idx]
            nodes = inputs_embeds[batch_idx][:tree_len]
            masks = attention_mask[batch_idx][:tree_len]

            batch_reply.append(nodes[1:])
            batch_masks.append(masks[1:])

            batch_src.append(nodes[0])
            batch_src_msk.append(masks[0])

        batch_reply = torch.cat(batch_reply, dim=0)
        batch_masks = torch.cat(batch_masks, dim=0)

        batch_src = torch.stack(batch_src)
        batch_src_msk = torch.stack(batch_src_msk)

        ## Forward through response filter
        ## Obtain anomaly score of each response
        anomaly_scores = self.model(
            attn_mask=batch_masks,
            inputs_embeds=batch_reply
        )
        anomaly_scores = anomaly_scores.sum(dim=1) ## Sum up the loss of each response
        score_idx_sort = torch.argsort(anomaly_scores) ## Ascending: least anomalous first

        ## Separately process each conversational thread
        n_ext_adv, res_accum_idx = 0, 0
        ext_idxs = []
        new_tree_lens, new_inputs_embeds, new_attention_mask = [], [], []
        max_tree_length = int(self.tokenizer.model_max_length / self.data_args.max_tweet_length)
        for batch_idx in range(batch_size):
            n_response = tree_lens[batch_idx] - 1

            ## Ranking of THIS thread's responses: its rows occupy the global range
            ## [res_accum_idx, res_accum_idx + n_response) of `batch_reply`.
            ## (Selecting `score_idx_sort < n_response` would always pick the rows
            ## at the start of the batch, i.e. the ranking of another thread.)
            in_thread = (score_idx_sort >= res_accum_idx) & (score_idx_sort < res_accum_idx + n_response)
            rank_idx = score_idx_sort[in_thread] - res_accum_idx ## Make each sample start from 0

            ## Keep `filter_ratio` of the responses, but at least one
            n_ext = max(int(torch.floor(n_response * self.filter_ratio)), 1)
            ext_idx = rank_idx[:n_ext] + 1 ## Consider source post

            one_hot = torch.zeros(max_tree_length, dtype=torch.long, device=device)
            one_hot[0] = 1 ## source post
            one_hot[ext_idx] = 1
            ext_idxs.append(one_hot)

            ## Get new (after extraction) `inputs_embeds` & `attention_mask`! Ex. src+res0+res1+res2+gen
            src_emb = batch_src[batch_idx:batch_idx + 1]
            src_msk = batch_src_msk[batch_idx:batch_idx + 1]
            res_emb = batch_reply[res_accum_idx:res_accum_idx + n_response][ext_idx - 1] ## Extract
            res_msk = batch_masks[res_accum_idx:res_accum_idx + n_response][ext_idx - 1] ## Extract
            res_accum_idx = res_accum_idx + n_response

            ## Concatenate source post with extracted responses
            tree_emb = torch.cat((src_emb, res_emb), dim=0)
            tree_msk = torch.cat((src_msk, res_msk), dim=0)

            new_tree_lens.append(tree_emb.shape[0])

            ## Pad each `inputs_embeds` & `attention_mask` to max_tree_length.
            ## Pad nodes reuse the source embedding but get an all-zero mask.
            n_pad = max_tree_length - tree_emb.shape[0]
            pad_msk = torch.zeros(src_msk.shape, device=tree_emb.device)
            tree_emb = torch.cat((tree_emb, src_emb.repeat(n_pad, 1, 1)), dim=0)
            tree_msk = torch.cat((tree_msk, pad_msk.repeat(n_pad, 1)), dim=0) ## Note that transformers won't attend to pad nodes

            new_inputs_embeds.append(tree_emb)
            new_attention_mask.append(tree_msk)

        ext_idxs = torch.stack(ext_idxs).to(device)
        ext_mask = ext_idxs.view(batch_size, max_tree_length, 1).repeat(1, 1, self.data_args.max_tweet_length).view(batch_size, -1)

        new_tree_lens = torch.LongTensor(new_tree_lens).to(tree_lens.device)
        new_inputs_embeds = torch.stack(new_inputs_embeds)
        new_attention_mask = torch.stack(new_attention_mask)

        return ext_idxs, ext_mask, n_ext_adv, new_tree_lens, new_inputs_embeds, new_attention_mask
ext_idxs = torch.stack(ext_idxs).to(inputs_embeds.device) 197 | ext_mask = ext_idxs.view(batch_size, max_tree_length, 1).repeat(1, 1, self.data_args.max_tweet_length).view(batch_size, -1) 198 | 199 | new_tree_lens = torch.LongTensor(new_tree_lens).to(tree_lens.device) 200 | new_inputs_embeds = torch.stack(new_inputs_embeds) 201 | new_attention_mask = torch.stack(new_attention_mask) 202 | 203 | return ext_idxs, ext_mask, n_ext_adv, new_tree_lens, new_inputs_embeds, new_attention_mask 204 | 205 | 206 | -------------------------------------------------------------------------------- /src/models/modeling_gcn.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from torch_scatter import scatter_mean 8 | from torch_geometric.nn import GCNConv 9 | 10 | ## Self-defined 11 | from others.utils import mean_pooling 12 | 13 | class GCNPooler(nn.Module): 14 | """Stack of GCN layers, can be top-down / bottom-up""" 15 | 16 | def __init__(self, data_args, model_args, training_args, hidden_size, gcn_type=None): 17 | super(GCNPooler, self).__init__() 18 | 19 | print("GCN Type: {}".format(gcn_type)) 20 | 21 | self.n_layers = 2 22 | self.gcn_type = gcn_type ## `td` / `bu` 23 | if self.gcn_type is not None: 24 | self.child_idx = 1 if self.gcn_type == "td" else 0 25 | self.parent_idx = 0 if self.gcn_type == "td" else 1 26 | 27 | self.data_args = data_args 28 | self.model_args = model_args 29 | self.training_args = training_args 30 | 31 | ## Edge Filter 32 | self.filter = None 33 | if self.model_args.edge_filter: 34 | self.filter = nn.Sequential( 35 | nn.Linear(2 * hidden_size, 1), 36 | nn.Sigmoid() 37 | ) 38 | 39 | ## GCN Layers 40 | self.gcn = nn.ModuleList() 41 | for gcn_idx in range(self.n_layers): 42 | self.gcn.append(GCNConv(hidden_size, hidden_size)) 43 | self.fc = nn.Linear(hidden_size, hidden_size) 44 | self.act = nn.Tanh() 45 | 46 | def 
class GCNPooler(nn.Module):
    """Stack of GCN layers, can be top-down / bottom-up"""

    def __init__(self, data_args, model_args, training_args, hidden_size, gcn_type=None):
        super(GCNPooler, self).__init__()

        print("GCN Type: {}".format(gcn_type))

        self.n_layers = 2
        self.gcn_type = gcn_type ## `td` / `bu`
        if self.gcn_type is not None:
            ## Row of the edge index holding the child / parent of each edge
            self.child_idx = 1 if self.gcn_type == "td" else 0
            self.parent_idx = 0 if self.gcn_type == "td" else 1

        self.data_args = data_args
        self.model_args = model_args
        self.training_args = training_args

        ## Edge Filter: maps a (child, parent) feature pair to a weight in (0, 1)
        self.filter = None
        if self.model_args.edge_filter:
            self.filter = nn.Sequential(
                nn.Linear(2 * hidden_size, 1),
                nn.Sigmoid()
            )

        ## GCN Layers
        self.gcn = nn.ModuleList()
        for gcn_idx in range(self.n_layers):
            self.gcn.append(GCNConv(hidden_size, hidden_size))
        self.fc = nn.Linear(hidden_size, hidden_size)
        self.act = nn.Tanh()

    def prepare_batch(
        self,
        hidden_states,
        attention_msk,
        tree_lens,
        edge_index
    ):
        """
        Prepare batch data for GCNConv.
        Ex. [[0, 0], [1, 2]] + [[0, 1], [1, 2]] -> [[0, 0, 3, 4], [1, 2, 5, 6]]

        Return:
            - nodes: node features of all trees (graphs) in a batch
            - edges: edges for graph formed by all trees (graphs) in a batch
            - index: indicate each node belong to which tree (graph)

        Raises:
            ValueError: if an edge refers to a node index that does not exist.
        """
        root_idx = 0
        nodes, edges, index = [], [], []
        for batch_idx in range(edge_index.shape[0]): ## Iterate through each tree
            edge = edge_index[batch_idx]
            edge = edge[:, edge[0] != -1] ## Remove padding (-1)

            n_tweets = tree_lens[batch_idx]
            node_states = hidden_states[batch_idx][:n_tweets] ## shape: (n_tweets, 32, 768)
            node_masks = attention_msk[batch_idx][:n_tweets] ## shape: (n_tweets, 32)

            ## Get each node's feature by mean pooling
            node_feat = mean_pooling(node_states, node_masks)

            ## Collect all trees (graphs) in a batch into a single graph:
            ## shift this tree's node ids by the number of nodes seen so far
            edges.append(edge + root_idx)
            nodes.append(node_feat)
            index.append(torch.full((int(n_tweets),), batch_idx, dtype=torch.long))

            ## Set root index for next graph
            root_idx = root_idx + n_tweets

        nodes = torch.cat(nodes, dim=0)
        edges = torch.cat(edges, dim=1)
        index = torch.cat(index, dim=0).to(hidden_states.device)

        ## Check
        if edges.nelement() != 0: ## If edges is not empty
            if edges.max().item() > (nodes.shape[0] - 1):
                raise ValueError("Edge index more than number of nodes!")

        return nodes, edges, index

    def forward(
        self,
        hidden_states,
        attention_msk,
        tree_lens,
        edge_index=None
    ):
        """Aggregate graphs by GCN layers

        Args:
            hidden_states: token-level hidden states of each thread.
            attention_msk: token-level mask of each thread.
            tree_lens: number of nodes of each thread (unused when `edge_index` is None).
            edge_index: edges of each thread; when None no GCN is applied and the
                first token of the first node represents the thread.

        Returns:
            One pooled vector per thread, after the final linear layer and tanh.
        """

        ## Split hidden states of each sequence (conversational thread) into nodes
        ## hidden_states.shape = (bs, max_len, hidden_size) -> (bs, max_tree_length, max_tweet_length, hidden_size)
        ## Ex. (8, 512, 768) -> (8, 16, 32, 768)
        hidden_states = torch.stack(torch.split(hidden_states, self.data_args.max_tweet_length, dim=1), dim=1)
        attention_msk = torch.stack(torch.split(attention_msk, self.data_args.max_tweet_length, dim=1), dim=1)

        if edge_index is None:
            ## Transformer only, take the first token as tree representation
            tree_embeds = hidden_states[:, 0, 0, :]
        else:
            ## Transformer + GCN
            nodes, edges, index = self.prepare_batch(
                hidden_states=hidden_states,
                attention_msk=attention_msk,
                tree_lens=tree_lens,
                edge_index=edge_index
            )

            ## Edge Filter
            edge_weights = None
            if self.filter is not None:
                child_nodes = nodes[edges[self.child_idx]]
                parent_nodes = nodes[edges[self.parent_idx]]
                edge_weights = self.filter(torch.cat((child_nodes, parent_nodes), dim=1))
                edge_weights = edge_weights.view(-1)

            ## GCN Layers
            ## Errors raised by a layer propagate to the caller; the former bare
            ## `except` opened an interactive debugger, which hides the error and
            ## hangs non-interactive runs.
            for gcn_idx, conv_i in enumerate(self.gcn):
                nodes = conv_i(nodes, edges, edge_weights)
                nodes = F.relu(nodes)
                ## Dropout is applied after the first layer only
                nodes = F.dropout(nodes, p=0.1, training=self.training) if gcn_idx == 0 else nodes

            ## Node aggregation: average the nodes that belong to the same tree
            tree_embeds = scatter_mean(nodes, index, dim=0)

        pooled_output = self.fc(tree_embeds)
        pooled_output = self.act(pooled_output)
        return pooled_output


class GCNForClassification(nn.Module):
    """Detector head with GCN classifier"""

    def __init__(
        self,
        data_args,
        model_args,
        training_args,
        hidden_size,
        num_labels,
        hidden_dropout_prob=0.1
    ):
        super(GCNForClassification, self).__init__()

        self.data_args = data_args
        self.model_args = model_args
        self.training_args = training_args

        ## Assigned at `CustomSeq2SeqTrainer.__init__()` from `trainer_adv.py`
        self.loss_weight = None

        ## Decide the GCN structure to use:
        ## `pooler` (no GCN) is built only when neither `td_gcn` nor `bu_gcn` is set
        self.pooler = GCNPooler(self.data_args, self.model_args, self.training_args, hidden_size) if not (self.model_args.td_gcn or self.model_args.bu_gcn) else None
        self.td_pooler = GCNPooler(self.data_args, self.model_args, self.training_args, hidden_size, gcn_type="td") if model_args.td_gcn else None
        self.bu_pooler = GCNPooler(self.data_args, self.model_args, self.training_args, hidden_size, gcn_type="bu") if model_args.bu_gcn else None

        self.dropout = nn.Dropout(hidden_dropout_prob)

        ## Top-down and bottom-up outputs are concatenated when both are enabled
        input_dim = hidden_size
        input_dim = 2 * input_dim if (self.model_args.td_gcn and self.model_args.bu_gcn) else input_dim
        self.classifier = nn.Linear(input_dim, num_labels)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise linear / embedding weights from N(0, 0.02) and reset LayerNorm and biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02) ## std: initializer_range
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(
        self,
        hidden_states,
        attention_msk,
        tree_lens,
        td_edges=None,
        bu_edges=None,
        labels=None
    ):
        """Classify each thread from its pooled representation.

        Args:
            hidden_states, attention_msk, tree_lens: forwarded to the pooler(s).
            td_edges: edges given to the top-down pooler.
            bu_edges: edges given to the bottom-up pooler.
            labels: optional targets; when given, a cross-entropy loss weighted by
                `self.loss_weight` is computed.

        Returns:
            Tuple `(logits, loss)`; `loss` is None when `labels` is None.

        Raises:
            ValueError: if no pooler is available.
        """
        if self.pooler is not None:
            ## Transformer only
            pooled_output = self.pooler(
                hidden_states=hidden_states,
                attention_msk=attention_msk,
                tree_lens=tree_lens
            )
        else:
            ## With GCN
            if self.td_pooler is not None:
                td_pooled_output = self.td_pooler(
                    hidden_states=hidden_states,
                    attention_msk=attention_msk,
                    tree_lens=tree_lens,
                    edge_index=td_edges
                )
            if self.bu_pooler is not None:
                bu_pooled_output = self.bu_pooler(
                    hidden_states=hidden_states,
                    attention_msk=attention_msk,
                    tree_lens=tree_lens,
                    edge_index=bu_edges
                )

            if self.td_pooler is not None and self.bu_pooler is not None: ## BiTGN
                pooled_output = torch.cat((td_pooled_output, bu_pooled_output), dim=1)
            elif self.td_pooler is not None: ## TDTGN
                pooled_output = td_pooled_output
            elif self.bu_pooler is not None: ## BUTGN
                pooled_output = bu_pooled_output
            else:
                raise ValueError("Wrong argument specification!")

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(weight=self.loss_weight)
            loss = loss_fct(logits, labels)

        return (logits, loss)
hidden_states=hidden_states, 218 | attention_msk=attention_msk, 219 | tree_lens=tree_lens, 220 | edge_index=bu_edges 221 | ) 222 | 223 | if self.td_pooler is not None and self.bu_pooler is not None: ## BiTGN 224 | pooled_output = torch.cat((td_pooled_output, bu_pooled_output), dim=1) 225 | elif self.td_pooler is not None: ## TDTGN 226 | pooled_output = td_pooled_output 227 | elif self.bu_pooler is not None: ## BUTGN 228 | pooled_output = bu_pooled_output 229 | else: 230 | raise ValueError("Wrong argument specification!") 231 | 232 | pooled_output = self.dropout(pooled_output) 233 | logits = self.classifier(pooled_output) 234 | 235 | loss = None 236 | if labels is not None: 237 | loss_fct = nn.CrossEntropyLoss(weight=self.loss_weight) 238 | loss = loss_fct(logits, labels) 239 | 240 | return (logits, loss) 241 | -------------------------------------------------------------------------------- /src/models/modeling_clustering.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | import math 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from functools import partial 10 | from kmeans_pytorch.soft_dtw_cuda import SoftDTW 11 | from kmeans_pytorch import initialize, pairwise_distance 12 | 13 | def kmeans( 14 | X, 15 | num_clusters, 16 | distance='euclidean', 17 | cluster_centers=[], 18 | tol=1e-4, 19 | tqdm_flag=True, 20 | iter_limit=0, 21 | device=torch.device('cpu'), 22 | gamma_for_soft_dtw=0.001, 23 | seed=None, 24 | ): 25 | """ 26 | NOTE that this function is copied and modified from `kmeans_pytorch`. 
def kmeans(
        X,
        num_clusters,
        distance='euclidean',
        cluster_centers=None,
        tol=1e-4,
        tqdm_flag=True,
        iter_limit=0,
        device=torch.device('cpu'),
        gamma_for_soft_dtw=0.001,
        seed=None,
):
    """
    NOTE that this function is copied and modified from `kmeans_pytorch`.
    Modification:
        - enable clustering when `num_clusters == 1`

    perform kmeans
    :param X: (torch.tensor) matrix
    :param num_clusters: (int) number of clusters
    :param distance: (str) distance [options: 'euclidean', 'cosine', 'soft_dtw'] [default: 'euclidean']
    :param cluster_centers: (torch.tensor) centers to resume from; `None` or a list
        (the former default was `[]`) starts from a fresh initialization
    :param seed: (int) seed for kmeans
    :param tol: (float) threshold [default: 0.0001]
    :param device: (torch.device) device [default: cpu]
    :param tqdm_flag: Allows to turn logs on and off
    :param iter_limit: hard limit for max number of iterations
    :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0
    :return: (torch.tensor, torch.tensor) cluster ids, cluster centers
    :raises NotImplementedError: for an unknown `distance`
    """
    if tqdm_flag:
        print(f'running k-means on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        ## Imported here because the module header only imports `pairwise_distance`
        from kmeans_pytorch import pairwise_cosine
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    elif distance == 'soft_dtw':
        from kmeans_pytorch import pairwise_soft_dtw
        sdtw = SoftDTW(use_cuda=device.type == 'cuda', gamma=gamma_for_soft_dtw)
        pairwise_distance_function = partial(pairwise_soft_dtw, sdtw=sdtw, device=device)
    else:
        raise NotImplementedError

    # convert to float
    X = X.float()

    # transfer to device
    X = X.to(device)

    # initialize
    if cluster_centers is None or isinstance(cluster_centers, list):
        initial_state = initialize(X, num_clusters, seed=seed)
    else:
        if tqdm_flag:
            print('resuming')
        # find data point closest to the initial cluster center
        initial_state = cluster_centers
        dis = pairwise_distance_function(X, initial_state)
        choice_points = torch.argmin(dis, dim=0)
        initial_state = X[choice_points]
        initial_state = initial_state.to(device)

    iteration = 0
    tqdm_meter = None
    if tqdm_flag:
        ## Imported lazily: the progress bar is only needed when logging is on
        from tqdm import tqdm
        tqdm_meter = tqdm(desc='[running kmeans]')
    while True:

        dis = pairwise_distance_function(X, initial_state)

        # a single cluster yields a 1-D distance vector
        if len(dis.shape) == 1:
            dis = dis.view(-1, 1)

        choice_cluster = torch.argmin(dis, dim=1)

        initial_state_pre = initial_state.clone()

        for index in range(num_clusters):
            # squeeze(1) keeps the index 1-D even when a cluster has one member
            selected = torch.nonzero(choice_cluster == index).squeeze(1).to(device)

            selected = torch.index_select(X, 0, selected)

            # https://github.com/subhadarship/kmeans_pytorch/issues/16
            # an empty cluster is re-seeded with a random data point
            if selected.shape[0] == 0:
                selected = X[torch.randint(len(X), (1,))]

            initial_state[index] = selected.mean(dim=0)

        center_shift = torch.sum(
            torch.sqrt(
                torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
            ))

        # increment iteration
        iteration = iteration + 1

        # update tqdm meter
        if tqdm_meter is not None:
            tqdm_meter.set_postfix(
                iteration=f'{iteration}',
                center_shift=f'{center_shift ** 2:0.6f}',
                tol=f'{tol:0.6f}'
            )
            tqdm_meter.update()
        if center_shift ** 2 < tol:
            break
        if iter_limit != 0 and iteration >= iter_limit:
            break

    if tqdm_meter is not None:
        tqdm_meter.close()

    return choice_cluster.cpu(), initial_state.cpu()


class ClusterModel(nn.Module):
    """Group the responses of a thread either by k-means or by topic features."""

    def __init__(self, cluster_type="kmeans", num_clusters=3, extract_ratio=None):
        super(ClusterModel, self).__init__()

        self.num_clusters = num_clusters
        self.cluster_type = cluster_type
        self.extract_ratio = extract_ratio

        print("Cluster Model: {}".format(self.cluster_type))
        print("Num. clusters: {}".format(self.num_clusters))

    def forward(
        self,
        node_feat=None,
        topic_feat=None,
        topic_probs=None,
        mode="train",
        device=None
    ):
        """
        Input:
            - node_feat: node features of all response in a tree
            - topic_feat : for cluster_by_topic
            - topic_probs: for cluster_by_topic
            - mode: for cluster_by_kmeans, either "train" or "test", whether to random sample a centroid when multiple points are closest to center
            - device: for cluster_by_kmeans, the device to run on
        Output:
            - clusters: the return value of `cluster_by_kmeans` or
              `cluster_by_topics`; `None` for any other `cluster_type`
        """
        if self.cluster_type == "kmeans":
            return self.cluster_by_kmeans(node_feat=node_feat, mode=mode, device=device)
        elif self.cluster_type == "topics":
            return self.cluster_by_topics(node_feat, topic_feat, topic_probs)

    def cluster_by_kmeans(self, node_feat, mode="train", device=None):
        """Cluster responses with k-means and mark the response closest to each center.

        Returns:
            Tuple `(cluster_ids, cluster_centers, dist, is_centroid)` where `dist`
            holds the distance of every response to every center and
            `is_centroid` is a 0/1 vector marking the selected centroids.
        """
        ## Adjust number of clusters
        if self.extract_ratio is None:
            ## Halve until there are fewer clusters than responses (or only one)
            num_clusters = self.num_clusters
            while (num_clusters >= node_feat.shape[0]) and (num_clusters != 1):
                num_clusters = math.ceil(num_clusters / 2)
        else:
            ## A fixed share of the responses, but at least one cluster
            num_nodes = torch.tensor(node_feat.shape[0])
            num_clusters = int(torch.floor(num_nodes * self.extract_ratio))
            num_clusters = num_clusters + 1 if num_clusters == 0 else num_clusters

        ## Cluster ##
        if node_feat.shape[0] > 1:
            cluster_ids, cluster_centers = kmeans(
                X=node_feat,
                num_clusters=num_clusters,
                tol=1e-6,
                distance="euclidean",
                device=device,
                tqdm_flag=False
            )
        else: ## When there is only one response
            cluster_ids = torch.LongTensor([0])
            cluster_centers = node_feat

        cluster_ids = cluster_ids.to(device)
        cluster_centers = cluster_centers.to(device)

        ## Calculate distance of each node to each cluster
        dist = pairwise_distance(node_feat, cluster_centers, device=device, tqdm_flag=False)
        if len(dist.shape) <= 1: dist = dist.view(-1, 1)

        ## Find the centroid of each cluster
        is_centroid = torch.zeros(dist.shape[0], dtype=torch.long).to(device)
        for cluster_i in range(dist.shape[1]):
            ## Ignore this cluster if no response belongs to this cluster
            if cluster_i not in cluster_ids:
                continue

            ## Get distance of responses closest to cluster center and set them as centroid
            dist_2_i = dist[:, cluster_i]
            min_dist = dist_2_i[cluster_ids == cluster_i].min()
            min_idxs = (dist_2_i == min_dist).nonzero().flatten()
            if mode == "test": ## testing: random sample if more than one centroid
                min_idxs = min_idxs[torch.randint(low=0, high=len(min_idxs), size=(1,))]
            is_centroid[min_idxs] = 1

        return cluster_ids, cluster_centers, dist, is_centroid

    def cluster_by_topics(self, node_feat, topic_feat, topic_probs):
        """Assign responses to topics and pick one centroid response per topic.

        Returns:
            Tuple `(cluster_ridx, centroid_ridx)`: the response indices selected
            for every topic (one column per topic) and, per topic, the selected
            response closest to the topic feature. `(None, None)` when there is
            no response.
        """
        def masked_softmax(x, mask, temp=0.1, dim=1):
            ## Softmax with temperature that ignores entries whose mask is 0
            x_masked = x.clone()
            x_masked[mask == 0] = -float("inf")
            return F.softmax(x_masked / temp, dim=dim)

        ## Topic feature = probability-weighted sum over dim 1 of `topic_feat`
        mask = (topic_probs != 0).long()
        topic_probs = masked_softmax(topic_probs, mask, temp=0.01, dim=1)
        topic_feat = (topic_feat * topic_probs.unsqueeze(-1)).sum(dim=1)

        n_topics = topic_feat.shape[0]
        n_sample = node_feat.shape[0]

        l2_dist = torch.cdist(node_feat, topic_feat)
        cluster_prob = F.softmax(-1 * l2_dist, dim=1)

        ## NOTE(review): `argsort` is ascending, so the rows kept below are the
        ## responses with the LOWEST probability for each topic -- confirm that
        ## this is intended.
        prob_sort = cluster_prob.argsort(dim=0)

        cluster_ridx = prob_sort[:math.ceil(n_sample / n_topics), :]
        if cluster_ridx.numel() == 0: ## No response
            return None, None

        ## Per topic, the selected response with the smallest distance to it
        centroid_ridx = []
        for topic_i in range(n_topics):
            ridx = cluster_ridx[:, topic_i]
            cent_ridx = l2_dist[ridx, topic_i].argsort()[0]
            centroid_ridx.append(ridx[cent_ridx])
        centroid_ridx = torch.stack(centroid_ridx)

        return cluster_ridx, centroid_ridx
centroid_ridx.append(ridx[cent_ridx]) 235 | centroid_ridx = torch.stack(centroid_ridx) 236 | 237 | return cluster_ridx, centroid_ridx -------------------------------------------------------------------------------- /src/data/preprocess/plot_tsne.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ipdb 3 | import argparse 4 | import numpy as np 5 | import pandas as pd 6 | import seaborn as sns 7 | import preprocessor as pre 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | 11 | from tqdm import tqdm 12 | 13 | import torch 14 | from sklearn import svm 15 | from sklearn.manifold import TSNE 16 | from transformers import RobertaTokenizer, RobertaModel 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description="Generate summary by ChatGPT") 20 | 21 | ## Others 22 | parser.add_argument("--dataset_name", type=str, default="semeval2019", choices=["semeval2019", "twitter15", "twitter16"]) 23 | parser.add_argument("--dataset_root", type=str, default="../dataset/processedV2") 24 | parser.add_argument("--results_root", type=str, default="/mnt/1T/projects/RumorV2/results") 25 | parser.add_argument("--fold", type=str, default="0,1,2,3,4", help="either use 5-fold data or train/dev/test from rumoureval2019 competition") 26 | 27 | args = parser.parse_args() 28 | 29 | return args 30 | 31 | def make_meshgrid(x, y, h=.02): 32 | """Create a mesh of points to plot in 33 | Parameters 34 | ---------- 35 | x: data to base x-axis meshgrid on 36 | y: data to base y-axis meshgrid on 37 | h: stepsize for meshgrid, optional 38 | Returns 39 | ------- 40 | xx, yy : ndarray 41 | """ 42 | x_min, x_max = x.min() - 1, x.max() + 1 43 | y_min, y_max = y.min() - 1, y.max() + 1 44 | xx, yy = np.meshgrid(np.arange(x_min, x_max, h), 45 | np.arange(y_min, y_max, h)) 46 | return xx, yy 47 | 48 | def plot_contours(ax, clf, xx, yy, **params): 49 | """Plot the decision boundaries for a classifier. 
def plot_contours(ax, clf, xx, yy, **params):
    """Plot the decision boundaries for a classifier.

    Parameters
    ----------
    ax: matplotlib axes object
    clf: a classifier
    xx: meshgrid ndarray
    yy: meshgrid ndarray
    params: dictionary of params to pass to contourf, optional
    """
    ## Predict every mesh point, then restore the grid shape for `contourf`
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    out = ax.contourf(xx, yy, Z, **params)
    return out

def plot_results(tsne_df, output_file, svm_clf=None):
    """Scatter-plot t-SNE points coloured by stance over SVM decision regions.

    Args:
        tsne_df: data frame with the columns `y`, `stance`, `tsne_1` and `tsne_2`.
        output_file: path of the figure to write.
        svm_clf: classifier to reuse; when None a new SVM is fitted on `tsne_df`.

    Returns:
        The classifier that was used for the decision regions.
    """
    def run_svm(data_df, svm_clf=None, kernel="rbf", C=1.0, degree=3):
        print("\nRunnnig SVM Classifier on TSNE data points...")
        label = data_df["y"]
        z = data_df[["tsne_1", "tsne_2"]].values

        if svm_clf is None:
            ## Honour the requested hyper-parameters (they used to be ignored in
            ## favour of hard-coded C=1.0 / degree=3)
            svm_clf = svm.SVC(kernel=kernel, C=C, degree=degree)
            svm_clf.fit(z, label)

        preds = svm_clf.predict(z)
        acc = (preds == label).sum() / len(preds)
        print("Accuracy: {}".format(acc))

        ## `plt` stands in for the axes object: it also provides `contourf`
        xx, yy = make_meshgrid(z[:, 0], z[:, 1], h=.1)
        plot_contours(plt, svm_clf, xx, yy, cmap=matplotlib.colormaps["coolwarm"], alpha=0.2)

        return svm_clf

    svm_clf = run_svm(tsne_df, svm_clf=svm_clf, kernel="rbf", degree=3, C=0.5)
    sns.scatterplot(
        x="tsne_1",
        y="tsne_2",
        hue=tsne_df["stance"].tolist(),
        palette=[sns.color_palette("Reds")[3], sns.color_palette("Blues")[3]],
        data=tsne_df
    )
    plt.xlabel("")
    plt.ylabel("")
    plt.title("t-SNE for tweets with different stances on RE2019")
    plt.tight_layout()
    plt.savefig(output_file, dpi=300)
    plt.clf()

    return svm_clf

def main(args):
    """Plot, for every source post, a t-SNE map of its tweets coloured by stance.

    Reads `<dataset_root>/<dataset_name>/data.csv`, embeds every tweet with the
    pooled output of `roberta-large` and writes one figure per thread to
    `tsne/thread/<source_id>_<n_tweets>.png`.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ## Load data
    print("\nLoad data...")
    data_df = pd.read_csv("{}/{}/data.csv".format(args.dataset_root, args.dataset_name))
    group_src = data_df.groupby("source_id")

    ## Load model
    print("\nLoad model...")
    tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
    model = RobertaModel.from_pretrained("roberta-large")
    model.to(device)

    ## `savefig` does not create missing directories
    output_dir = "tsne/thread"
    os.makedirs(output_dir, exist_ok=True)

    two_stance_palette = [sns.color_palette("Reds")[3], sns.color_palette("Blues")[3]]

    for src_id, src_df in group_src:
        texts = src_df["text"].tolist()
        stance = src_df["stance"].tolist()

        ## t-SNE needs a perplexity in (0, n_samples): skip threads with a single tweet
        if len(texts) < 2:
            continue

        ## Convert stance to label id
        label_set = list(set(stance))
        label2id = {stance_name: label_id for label_id, stance_name in enumerate(label_set)}
        label = np.array([label2id[sta] for sta in stance])

        feats = []
        with torch.no_grad():
            for text in tqdm(texts, desc="Getting text features of {}".format(src_id)):
                ## `max_length` only takes effect together with `truncation=True`
                encoded_input = tokenizer(
                    text,
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt"
                )
                encoded_input["input_ids"] = encoded_input["input_ids"].to(device)
                encoded_input["attention_mask"] = encoded_input["attention_mask"].to(device)
                outputs = model(**encoded_input)

                feats.append(outputs["pooler_output"].cpu())
        feats = torch.cat(feats, dim=0)

        ppl = 30 if len(feats) > 30 else len(feats) - 1
        print("\nTSNE with Perplexity {}".format(ppl))
        tsne = TSNE(n_components=2, perplexity=ppl, verbose=0, random_state=123)
        z = tsne.fit_transform(feats.numpy())

        tsne_df = pd.DataFrame()
        tsne_df["y"] = label
        tsne_df["stance"] = stance
        tsne_df["tsne_1"] = z[:, 0]
        tsne_df["tsne_2"] = z[:, 1]

        ## The red / blue pair only fits two stances; otherwise use one hue per stance
        n_stances = len(label_set)
        palette = two_stance_palette if n_stances == 2 else sns.color_palette("hls", n_stances)

        sns.scatterplot(
            x="tsne_1",
            y="tsne_2",
            hue=tsne_df["stance"].tolist(),
            palette=palette,
            data=tsne_df
        )
        plt.xlabel("")
        plt.ylabel("")
        plt.title("t-SNE for tweets with different stances on RE2019")
        plt.tight_layout()
        plt.savefig("{}/{}_{}.png".format(output_dir, src_id, len(feats)), dpi=300)
        plt.clf()

    ## NOTE: a disabled (string-literal) whole-dataset variant of this experiment,
    ## which cached the t-SNE points and drew SVM regions with `plot_results`,
    ## used to follow here and has been removed as dead code.
230 | print("\nTSNE with Perplexity {}".format(ppl)) 231 | tsne = TSNE(n_components=2, perplexity=ppl, verbose=0, random_state=123) 232 | z = tsne.fit_transform(feats) 233 | 234 | ## Save TSNE points 235 | with open("tsne/tsne_points.npy", "wb") as f: 236 | np.save(f, z) 237 | else: 238 | print("\nTSNE points cache exists! Loading...") 239 | with open("tsne/tsne_points.npy", "rb") as f: 240 | z = np.load(f) 241 | 242 | print("\nPlot results...") 243 | tsne_df = pd.DataFrame() 244 | tsne_df["y"] = label 245 | tsne_df["stance"] = stance 246 | tsne_df["tsne_1"] = z[:, 0] 247 | tsne_df["tsne_2"] = z[:, 1] 248 | 249 | ## Filter points 250 | tsne_df_filt = tsne_df.loc[~((tsne_df["tsne_2"] < 0) & (tsne_df["stance"] == "support"))] 251 | #tsne_df_filt = tsne_df_filt.loc[~((tsne_df["tsne_2"] < 10) & (tsne_df["tsne_1"] < -30) & (tsne_df["stance"] == "support"))] 252 | tsne_df_filt = tsne_df_filt.loc[~((tsne_df["tsne_2"] < 10) & (tsne_df["tsne_1"] < 25) & (tsne_df["stance"] == "support"))] 253 | tsne_df_filt = tsne_df_filt.loc[~((tsne_df["tsne_2"] < 15) & (tsne_df["tsne_1"] > 20) & (tsne_df["stance"] == "support"))] 254 | 255 | tsne_df_filt = tsne_df_filt.loc[~((tsne_df["tsne_2"] > 10) & (tsne_df["tsne_1"] > -30) & (tsne_df["tsne_1"] < 25) & (tsne_df["stance"] == "deny"))] 256 | 257 | svm_clf = plot_results(tsne_df_filt, output_file="tsne/stance_filt.png") 258 | svm_clf = plot_results(tsne_df, output_file="tsne/stance.png", svm_clf=svm_clf) 259 | 260 | ipdb.set_trace() 261 | """ 262 | 263 | if __name__ == "__main__": 264 | args = parse_args() 265 | main(args) -------------------------------------------------------------------------------- /src/others/evaluate_summary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ipdb 3 | import json 4 | import math 5 | import random 6 | import shutil 7 | import argparse 8 | import itertools 9 | import pandas as pd 10 | from tqdm import tqdm 11 | 12 | import transformers 13 | from 
def parse_args():
    """Parse the command-line options of the summary evaluation script.

    Returns:
        argparse.Namespace: the parsed options. `--eval_ppl` and
        `--generate_for_factCC` select the action; `--fold` is a
        comma-separated list of fold names.
    """
    parser = argparse.ArgumentParser(description="Rumor Detection")

    ## What to do
    parser.add_argument("--eval_ppl", action="store_true")
    parser.add_argument("--generate_for_factCC", action="store_true")

    parser.add_argument("--model_type", type=str, default="kmeans") ## kmeans, loo
    parser.add_argument("--num_clusters", type=int, default=1) ## 1, 2, 3, 4, 5
    ## NOTE: `generate_for_factCC` additionally understands "cluster_wise"
    parser.add_argument("--factCC_format", type=str, default=None, help="response_wise, all_responses")

    ## Others
    parser.add_argument("--data_name", type=str, default="semeval2019", choices=["semeval2019", "Pheme", "twitter15", "twitter16"])
    parser.add_argument("--data_root", type=str, default="../../dataset/processed")
    parser.add_argument("--data_root_V2", type=str, default="../../dataset/processedV2")
    parser.add_argument("--fold", type=str, default="0,1,2,3,4", help="either use 5-fold data or train/dev/test from rumoureval2019 competition")
    parser.add_argument("--result_path", type=str, default="/mnt/1T/projects/RumorV2/results")

    return parser.parse_args()


def _clean_summary(summary):
    """Return `summary` without "$" and surrounding whitespace, or "." if nothing is left.

    Non-string values (e.g. the NaN pandas yields for an empty cell) also map to ".".
    """
    if not isinstance(summary, str):
        return "."
    cleaned = summary.replace("$", "").strip()
    return cleaned if cleaned else "."


def _combine_cluster_summaries(group):
    """Build every thread-level summary obtainable by picking one summary per cluster."""
    per_cluster = [
        ## Non-string cells become "" so that the join below cannot fail on NaN
        [summ if isinstance(summ, str) else "" for summ in cluster["summary"].tolist()]
        for _, cluster in group.groupby("cluster_id")
    ]
    return [" ".join(combo) for combo in itertools.product(*per_cluster)]


def _factcc_record(record_id, text, claim):
    """Create one factCC example; the label is a dummy value."""
    return {"label": "CORRECT", "id": record_id, "text": text, "claim": claim}


def _cluster_wise_records(dataset_df, summary_df, cluster_df):
    """Pair each cluster summary (claim) with the responses of that cluster (text)."""
    cluster_df = cluster_df.copy()
    ## Keep only the part of the cluster id that precedes the first "_"
    cluster_df["cluster_id"] = cluster_df["cluster_id"].apply(lambda x: x.split("_")[0])

    records = []
    for src_id, group in summary_df.groupby("source_id"): ## For each thread
        ## Only one summary for each cluster
        summary_of = {
            cluster_id: cluster["summary"].tolist()[0]
            for cluster_id, cluster in group.groupby("cluster_id")
        }

        tree_df = dataset_df[dataset_df["source_id"] == src_id]
        clus_df = cluster_df[cluster_df["source_id"] == src_id]
        resp_df = tree_df[tree_df["tweet_id"] != src_id]

        for cluster_id, cluster in clus_df.groupby("cluster_id"):
            claim = _clean_summary(summary_of[int(cluster_id)])
            ## Take the responses belonging to this cluster as inputs
            members = resp_df[resp_df["tweet_id"].isin(cluster["tweet_id"].tolist())]
            records.append(_factcc_record(
                "{}_{}".format(src_id, cluster_id),
                " ".join(members["text"].tolist()),
                claim
            ))
    return records


def _thread_records(dataset_df, summary_df, model_type, factCC_format):
    """Pair each thread summary (claim) with all responses or with every single response."""
    if "kmeans" in model_type:
        source_ids, summaries = [], []
        for src_id, group in summary_df.groupby("source_id"): ## For each thread
            combos = _combine_cluster_summaries(group)
            summaries.extend(combos)
            ## One id per combination, otherwise ids and summaries get out of step
            source_ids.extend([src_id] * len(combos))
    else:
        source_ids = summary_df["source_id"].tolist()
        summaries = summary_df["summary"].tolist()

    records = []
    for src_id, summary in zip(source_ids, summaries):
        claim = _clean_summary(summary)
        tree_df = dataset_df[dataset_df["source_id"] == src_id]
        resp_df = tree_df[tree_df["tweet_id"] != src_id]

        if factCC_format == "all_responses":
            records.append(_factcc_record(src_id, " ".join(resp_df["text"].tolist()), claim))
        elif factCC_format == "response_wise":
            records.extend(
                _factcc_record(src_id, r_txt, claim)
                for r_txt in resp_df["text"].tolist()
            )
    return records


def eval_ppl(args):
    """Compute the GPT-2 perplexity of the generated summaries of every fold.

    Reads `<result_path>/<data_name>/<model_type>/<fold>/summary.csv` for each
    fold in `args.fold` and writes one line per fold to
    `<result_path>/<data_name>/<model_type>/ppl.txt`.

    Args:
        args: namespace providing `model_type`, `fold`, `data_name` and `result_path`.
    """
    ## Imported here so that the factCC export does not require `evaluate`
    from evaluate import load

    print("Evaluating perplexity of generated summary...")
    print("Model Type: {}".format(args.model_type))

    metric = load("perplexity", module_type="metric")

    folds = args.fold.split(",")
    ppls = []
    for fold in folds:
        print("{} Fold [{}]".format(args.data_name, fold))

        summary_df = pd.read_csv("{}/{}/{}/{}/summary.csv".format(args.result_path, args.data_name, args.model_type, fold))
        if "kmeans" in args.model_type:
            summaries = []
            for _, group in summary_df.groupby("source_id"): ## For each thread
                summaries.extend(_combine_cluster_summaries(group))
        else:
            summaries = summary_df["summary"].tolist()

        ## Score the cleaned text (the raw text used to be scored by mistake)
        summaries = [_clean_summary(summary) for summary in summaries]

        ppl = metric.compute(
            model_id="gpt2",
            predictions=summaries,
            add_start_token=True,
            device="cuda"
        )
        ppls.append(ppl)

    with open("{}/{}/{}/ppl.txt".format(args.result_path, args.data_name, args.model_type), "w") as fw:
        fw.write("{}\t{}\n".format("Fold", "Perplexity"))
        ## Label each row with the real fold name, not its position in the list
        for fold, ppl in zip(folds, ppls):
            fw.write("{:>4}\t{}\n".format(fold, ppl["mean_perplexity"]))


def generate_for_factCC(args):
    """Export (responses, summary) pairs as factCC input files, one file per fold.

    For every fold the examples are written as JSON lines to
    `<result_path>/<data_name>/<model_type>/<fold>/factCC_<factCC_format>/data-dev.jsonl`.
    With any format other than "cluster_wise", "all_responses" or
    "response_wise" the file is created but left empty.

    Args:
        args: namespace providing `factCC_format`, `model_type`, `fold`,
            `data_name`, `data_root_V2` and `result_path`.

    Raises:
        ValueError: if the format is "cluster_wise" and `model_type` does not
            contain "kmeans".
    """
    print("Generating data files for factCC...")
    print("factCC_format: {}".format(args.factCC_format))
    print("Model Type: {}".format(args.model_type))

    cluster_wise = args.factCC_format == "cluster_wise"
    if cluster_wise and "kmeans" not in args.model_type:
        raise ValueError("Wrong model type specified.")

    ## The dataset file does not depend on the fold, so it is read only once
    dataset_df = pd.read_csv("{}/{}/data.csv".format(args.data_root_V2, args.data_name))

    for fold in args.fold.split(","):
        print("{} Fold [{}]".format(args.data_name, fold))
        summary_root = "{}/{}/{}/{}".format(args.result_path, args.data_name, args.model_type, fold)
        summary_df = pd.read_csv("{}/summary.csv".format(summary_root))

        if cluster_wise:
            cluster_path = "{}/{}/split_{}/cluster_summary/train/kmeans-{}.csv".format(args.data_root_V2, args.data_name, fold, args.model_type.split("_")[-1])
            cluster_df = pd.read_csv(cluster_path)
            records = _cluster_wise_records(dataset_df, summary_df, cluster_df)
        else:
            records = _thread_records(dataset_df, summary_df, args.model_type, args.factCC_format)

        os.makedirs("{}/factCC_{}".format(summary_root, args.factCC_format), exist_ok=True)
        with open("{}/factCC_{}/data-dev.jsonl".format(summary_root, args.factCC_format), "w") as fw:
            for obj in records:
                fw.write("{}\n".format(json.dumps(obj)))


if __name__ == "__main__":
    args = parse_args()

    if args.eval_ppl:
        eval_ppl(args)
    elif args.generate_for_factCC:
        generate_for_factCC(args)
def _postprocess_text(preds, labels):
    """Strip the texts and put every sentence on its own line (rougeLSum expects that)."""
    preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in labels]
    return preds, labels


def _generation_metrics(rouge, tokenizer, data_args, preds, labels):
    """Return the ROUGE F-measures (x100) and the mean generated length, rounded to 4 digits."""
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if data_args.ignore_pad_token_for_loss:
        ## Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    ## Some simple post-processing
    decoded_preds, decoded_labels = _postprocess_text(decoded_preds, decoded_labels)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    ## Extract a few results from ROUGE
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}


def _seq2seq_collator(data_args, training_args, model, tokenizer):
    """Create the padding collator shared by the sequence-to-sequence tasks."""
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    return DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=8 if training_args.fp16 else None,
    )


def build_trainer(
    data_args, model_args, training_args,
    train_dataset, eval_dataset,
    model, tokenizer
):
    """Build the trainer that matches `training_args.task_type`.

    Supported task types: "train_detector", "train_adv_stage1",
    "train_adv_stage2", "ssra_loo", "ssra_kmeans", "train_filter" and
    "build_cluster_summary".

    Returns:
        The trainer (or builder) object of the selected task.

    Raises:
        ValueError: if the task type is not one of the above, or if an
            adversarial task is evaluated with `data_args.num_labels` other
            than 3 or 4.
    """
    print("\nBuilding trainer...")

    if training_args.task_type == "train_detector":
        from .trainer import CustomTrainer

        ## Get the metric function
        metric = {
            "accuracy": load_metric("accuracy"),
            "f1": load_metric("f1")
        }

        ## You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
        ## predictions and label_ids field) and has to return a dictionary string to float.
        def compute_metrics(pred: EvalPrediction):
            preds = pred.predictions[0] if isinstance(pred.predictions, tuple) else pred.predictions
            preds = np.argmax(preds, axis=1)

            result = {
                "accuracy": metric["accuracy"].compute(predictions=preds, references=pred.label_ids)["accuracy"],
                "f1_macro": metric["f1"].compute(predictions=preds, references=pred.label_ids, average="macro")["f1"]
            }
            ## Per-class F1 is the same for every label, so compute it once
            f1_per_class = metric["f1"].compute(predictions=preds, references=pred.label_ids, average=None)["f1"]
            for label_i in range(data_args.num_labels):
                result["f1_{}".format(label_i)] = f1_per_class[label_i]
            return result

        ## Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if
        ## we already did the padding.
        if data_args.pad_to_max_length:
            data_collator = default_data_collator
        elif training_args.fp16:
            data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
        else:
            data_collator = None

        trainer = CustomTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset if training_args.do_train else None,
            eval_dataset=eval_dataset if training_args.do_eval else None,
            compute_metrics=compute_metrics,
            tokenizer=tokenizer,
            data_collator=data_collator,
        )

    elif training_args.task_type in ("train_adv_stage1", "train_adv_stage2"):
        from others.metrics import f1_score_3_class, f1_score_4_class

        ## Metric
        metric = {
            "accuracy": load_metric("accuracy"), ## For detector
            "f1"      : load_metric("f1"),       ## For detector
            "rouge"   : load_metric("rouge")     ## For generator
        }

        def compute_metrics(eval_preds):
            """This function is called in `trainer.evaluation_loop` and `trainer.prediction_loop`"""
            preds, labels = eval_preds
            preds_det, labels_det = preds[0], labels[0] ## For detector
            preds_gen, labels_gen = preds[1], labels[1] ## For generator

            ####################
            ## For generation ##
            ####################
            ## `gen_len` is measured on the generated ids only, not on the
            ## (detector, generator) tuple
            result = _generation_metrics(metric["rouge"], tokenizer, data_args, preds_gen, labels_gen)

            ########################
            ## For classification ##
            ########################
            preds_det = np.argmax(preds_det, axis=1)
            result["accuracy"] = metric["accuracy"].compute(predictions=preds_det, references=labels_det)["accuracy"]

            num_labels = data_args.num_labels
            if num_labels == 3:
                F1_all = f1_score_3_class(preds_det, labels_det)
            elif num_labels == 4:
                F1_all = f1_score_4_class(preds_det, labels_det)
            else:
                raise ValueError("F1 scores are only implemented for 3 or 4 labels, got {}".format(num_labels))
            F1_all = np.array(F1_all)

            ## Only take classes that exist in `labels_det`
            indices = np.array(list(set(labels_det)))
            result["f1_macro"] = np.mean(F1_all[indices])

            for label_i in range(num_labels):
                ## -1: Labels that do not exist in test set, so ignore it
                result["f1_{}".format(label_i)] = F1_all[label_i] if label_i in indices else -1.

            return result

        ## Initialize our Trainer
        from .trainer_adv import CustomSeq2SeqTrainer
        trainer = CustomSeq2SeqTrainer(
            model=model,
            args=training_args,
            model_args=model_args,
            data_args=data_args,
            train_dataset=train_dataset if training_args.do_train else None,
            eval_dataset=eval_dataset if training_args.do_eval else None,
            tokenizer=tokenizer,
            data_collator=_seq2seq_collator(data_args, training_args, model, tokenizer),
            compute_metrics=compute_metrics if training_args.predict_with_generate else None,
        )

    elif training_args.task_type in ("ssra_loo", "ssra_kmeans"):
        ## Metric
        metric = {
            "rouge": load_metric("rouge"),
            "perplexity": load_metric("perplexity")
        }

        def compute_metrics(eval_preds):
            preds, labels = eval_preds
            if isinstance(preds, tuple):
                preds = preds[0]
            ## The perplexity metric is loaded above but not reported here
            return _generation_metrics(metric["rouge"], tokenizer, data_args, preds, labels)

        ## Initialize our Trainer
        from .trainer_abstractor import CustomSeq2SeqTrainer
        trainer = CustomSeq2SeqTrainer(
            model=model,
            args=training_args,
            model_args=model_args,
            data_args=data_args,
            train_dataset=train_dataset if training_args.do_train else None,
            eval_dataset=eval_dataset if training_args.do_eval else None,
            tokenizer=tokenizer,
            data_collator=_seq2seq_collator(data_args, training_args, model, tokenizer),
            compute_metrics=compute_metrics if training_args.predict_with_generate else None,
        )

    elif training_args.task_type == "train_filter":
        from .trainer_filter import FilterTrainer
        trainer = FilterTrainer(
            model=model,
            data_args=data_args,
            model_args=model_args,
            training_args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )

    elif training_args.task_type == "build_cluster_summary":
        from .builder_cluster_summary import ClusterSummaryBuilder
        trainer = ClusterSummaryBuilder(
            model=model,
            data_args=data_args,
            model_args=model_args,
            training_args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset
        )

    else:
        raise ValueError("training_args.task_type not specified!")
    return trainer
class CustomTrainer(Trainer):
    """Trainer for the rumor detector.

    On top of `transformers.Trainer` it
    - freezes the parameters of the "summarizer" sub-modules when training the detector,
    - derives class weights for the classification loss from the training labels,
    - remembers the evaluation loss / macro-F1 of every checkpoint and removes
      the worst checkpoints when `save_total_limit` is exceeded,
    - stores `data_args` and `model_args` next to the training arguments.
    """
    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        model_args: ModelArguments = None,
        data_args: DataTrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
    ):
        """Set up the base trainer, the checkpoint bookkeeping and the loss weights."""
        Trainer.__init__(
            self, model, args, data_collator, train_dataset,
            eval_dataset, tokenizer, model_init, compute_metrics,
            callbacks, optimizers, preprocess_logits_for_metrics
        )

        self.data_args = data_args
        self.model_args = model_args

        ## Override
        if self.args.task_type == "train_detector":
            self.label_names = ["labels_det"]

        ## For best model saving
        self._ckpt_eval_loss = {}
        if self.args.save_model_accord_to_metric:
            self._ckpt_eval_metric = {}
        self.best_metrics = None
        self.best_checkpoint_path = None

        ## Make specified parameters trainable
        freeze_only = "summarizer" if self.args.task_type == "train_detector" else None
        self._freeze_specified_params(self.model, freeze_only=freeze_only)

        ## Show number of parameters
        all_param_num = sum(p.nelement() for p in self.model.parameters())
        trainable_param_num = sum(
            p.nelement()
            for p in self.model.parameters()
            if p.requires_grad
        )
        print("All parameters: {}".format(all_param_num))
        print("Trainable parameters: {}".format(trainable_param_num))

        ## Setup loss weight for classification
        if train_dataset is not None:
            ## Counts ordered by label id; weight = largest count / class count,
            ## so the most frequent class gets weight 1.
            ## NOTE(review): a label missing from the training set gets no entry,
            ## which would shift the later weights -- confirm all labels occur.
            class_counts = sorted(Counter(train_dataset["labels_det"]).items())
            class_counts = torch.LongTensor([count for _, count in class_counts])
            loss_weight = class_counts.max() / class_counts
            loss_weight = loss_weight.to(self.model.device)

            self.model.loss_weight = loss_weight
            if self.model.detector_head.__class__.__name__ == "GCNForClassification":
                self.model.detector_head.loss_weight = loss_weight

    def _freeze_specified_params(self, model, freeze_only=None):
        """Disable gradients of every sub-module whose name starts with a given prefix.

        Args:
            model: module whose `named_modules()` are inspected.
            freeze_only: whitespace-separated name prefixes, or None to freeze nothing.
        """
        if freeze_only is None:
            return

        prefixes = tuple(freeze_only.split())
        for name, sub_module in model.named_modules():
            if name.startswith(prefixes):
                for param in sub_module.parameters():
                    param.requires_grad = False

    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
        """
        Modification:
            - record current eval loss / metric for best model saving

        The evaluation metrics must contain "eval_loss", "eval_accuracy" and
        "eval_f1_macro", which are printed and recorded here.
        """
        if self.control.should_log:
            logs: Dict[str, float] = {}

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

            # reset tr_loss to zero
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            logs["learning_rate"] = self._get_learning_rate()

            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
            self._report_to_hp_search(trial, epoch, metrics)

            ## NEW: display metrics properly
            print("\n***** Epoch [{:.4f}] Evaluation Results *****".format(metrics["epoch"]))
            print("Loss D  : {:7.4f}".format(metrics["eval_loss"]))
            print("Accuracy: {:7.4f}, F1-Macro: {:7.4f}".format(metrics["eval_accuracy"], metrics["eval_f1_macro"]))

            ## NEW: record metric
            if self.args.save_model_accord_to_metric:
                self._cur_eval_metric = metrics["eval_f1_macro"]
            self._cur_eval_loss = metrics["eval_loss"]

            best_f1_macro = 0 if self.best_metrics is None else self.best_metrics["eval_f1_macro"]
            if metrics["eval_f1_macro"] > best_f1_macro:
                self.best_metrics = metrics

        if self.control.should_save:
            self._save_checkpoint(model, trial, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
        """
        Modification:
            - record eval loss / metric and maintain best model

        NOTE:
            to make this function work properly,
            save_steps must equal eval_steps when saving by steps

        Raises:
            Exception: if `save_strategy` is "steps" and `eval_steps != save_steps`.
        """
        ## NEW
        if self.args.save_strategy == "steps":
            if self.args.eval_steps != self.args.save_steps:
                raise Exception(
                    "To properly store best models, please make sure eval_steps equals to save_steps."
                )

        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)

        ## NEW: record the eval metric for the last checkpoint
        latest_checkpoint = checkpoints_sorted[-1]
        self._ckpt_eval_loss[latest_checkpoint] = self._cur_eval_loss
        if self.args.save_model_accord_to_metric:
            self._ckpt_eval_metric[latest_checkpoint] = self._cur_eval_metric

        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)

        ## NEW: rank checkpoints from worst to best
        if self.args.save_model_accord_to_metric:
            ## sort according to metric (ascending for metric)
            ranked = sorted(self._ckpt_eval_metric, key=self._ckpt_eval_metric.get)
        else:
            ## sort according to loss (descending for loss)
            ranked = sorted(self._ckpt_eval_loss, key=self._ckpt_eval_loss.get, reverse=True)

        ## Selected after both branches: the loss-ranked branch used to leave
        ## this list undefined and crashed with a NameError
        checkpoints_to_be_deleted = ranked[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
            shutil.rmtree(checkpoint)

            ## NEW: remove the deleted ckpt
            del self._ckpt_eval_loss[checkpoint]
            if self.args.save_model_accord_to_metric:
                del self._ckpt_eval_metric[checkpoint]

        self.best_checkpoint_path = ranked[-1]

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """
        Modification:
            - Also record model_args and data_args
        """
        ## These two names are not imported at the top of this file
        from transformers.modeling_utils import unwrap_model
        from transformers.utils import WEIGHTS_NAME

        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            if isinstance(unwrap_model(self.model), PreTrainedModel):
                if state_dict is None:
                    state_dict = self.model.state_dict()
                unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                if state_dict is None:
                    state_dict = self.model.state_dict()
                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            self.model.save_pretrained(output_dir, state_dict=state_dict)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
        ## New
        torch.save(self.data_args , os.path.join(output_dir, "data_args.bin"))
        torch.save(self.model_args, os.path.join(output_dir, "model_args.bin"))
-------------------------------------------------------------------------------- 1 | from transformers.models.bert.modeling_bert import BertModel 2 | 3 | ## Self-defined 4 | from .generation_utils import generate_with_grad 5 | 6 | class BertModelWithResponseSummarization(BertModel): 7 | """ 8 | A subclass of `BertModel`, used in `BertForRumorDetection`. 9 | Modification: 10 | - Override forward method for response summarization. 11 | - Add two new methods `response_summarization` & `get_gen_hidden_states_from_tuple`. 12 | """ 13 | def __init__(self, config, add_pooling_layer=True, summarizer=None): 14 | super().__init__(config, add_pooling_layer) 15 | 16 | def init_args_modules(self, data_args, model_args, training_args, summarizer=None): 17 | self.data_args = data_args 18 | self.model_args = model_args 19 | self.training_args = training_args 20 | self.summarizer = summarizer 21 | if self.summarizer is not None: 22 | bound_method = generate_with_grad.__get__(self.summarizer, self.summarizer.__class__) 23 | setattr(self.summarizer, "generate_with_grad", bound_method) 24 | 25 | def forward( 26 | self, 27 | input_ids: Optional[torch.Tensor] = None, 28 | attention_mask: Optional[torch.Tensor] = None, 29 | token_type_ids: Optional[torch.Tensor] = None, 30 | position_ids: Optional[torch.Tensor] = None, 31 | head_mask: Optional[torch.Tensor] = None, 32 | inputs_embeds: Optional[torch.Tensor] = None, 33 | encoder_hidden_states: Optional[torch.Tensor] = None, 34 | encoder_attention_mask: Optional[torch.Tensor] = None, 35 | past_key_values: Optional[List[torch.FloatTensor]] = None, 36 | use_cache: Optional[bool] = None, 37 | output_attentions: Optional[bool] = None, 38 | output_hidden_states: Optional[bool] = None, 39 | return_dict: Optional[bool] = None, 40 | ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]: 41 | r""" 42 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 43 | Sequence of 
hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 44 | the model is configured as a decoder. 45 | encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): 46 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 47 | the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: 48 | 49 | - 1 for tokens that are **not masked**, 50 | - 0 for tokens that are **masked**. 51 | past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 52 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 53 | 54 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 55 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 56 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 57 | use_cache (`bool`, *optional*): 58 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 59 | `past_key_values`). 
60 | """ 61 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 62 | output_hidden_states = ( 63 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 64 | ) 65 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 66 | 67 | if self.config.is_decoder: 68 | use_cache = use_cache if use_cache is not None else self.config.use_cache 69 | else: 70 | use_cache = False 71 | 72 | if input_ids is not None and inputs_embeds is not None: 73 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 74 | elif input_ids is not None: 75 | input_shape = input_ids.size() 76 | elif inputs_embeds is not None: 77 | input_shape = inputs_embeds.size()[:-1] 78 | else: 79 | raise ValueError("You have to specify either input_ids or inputs_embeds") 80 | 81 | batch_size, seq_length = input_shape 82 | device = input_ids.device if input_ids is not None else inputs_embeds.device 83 | 84 | # past_key_values_length 85 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 86 | 87 | if attention_mask is None: 88 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) 89 | 90 | if token_type_ids is None: 91 | if hasattr(self.embeddings, "token_type_ids"): 92 | buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] 93 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) 94 | token_type_ids = buffered_token_type_ids_expanded 95 | else: 96 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 97 | 98 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 99 | # ourselves in which case we just need to make it broadcastable to all heads. 
100 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) 101 | 102 | # If a 2D or 3D attention mask is provided for the cross-attention 103 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 104 | if self.config.is_decoder and encoder_hidden_states is not None: 105 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 106 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 107 | if encoder_attention_mask is None: 108 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 109 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 110 | else: 111 | encoder_extended_attention_mask = None 112 | 113 | # Prepare head mask if needed 114 | # 1.0 in head_mask indicate we keep the head 115 | # attention_probs has shape bsz x n_heads x N x N 116 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 117 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 118 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 119 | 120 | embedding_output = self.embeddings( 121 | input_ids=input_ids, 122 | position_ids=position_ids, 123 | token_type_ids=token_type_ids, 124 | inputs_embeds=inputs_embeds, 125 | past_key_values_length=past_key_values_length, 126 | ) 127 | 128 | ## Use summarizer or not 129 | if self.summarizer is not None: 130 | 131 | 132 | encoder_outputs = self.encoder( 133 | embedding_output, 134 | attention_mask=extended_attention_mask, 135 | head_mask=head_mask, 136 | encoder_hidden_states=encoder_hidden_states, 137 | encoder_attention_mask=encoder_extended_attention_mask, 138 | past_key_values=past_key_values, 139 | use_cache=use_cache, 140 | output_attentions=output_attentions, 141 | output_hidden_states=output_hidden_states, 142 | return_dict=return_dict, 143 | ) 144 | sequence_output = 
def response_summarization(
    self,
    attention_mask,
    inputs_embeds
):
    """
    Summarize the responses of a thread with `self.summarizer`.

    The first `data_args.max_tweet_length` positions (the source post) are
    dropped; only the remaining positions are passed to the summarizer.

    Args:
        attention_mask: mask indexed as `[batch, position]`.
        inputs_embeds: embeddings indexed as `[batch, position, hidden]`.

    Returns:
        (summary_tokens, summ_hidden_states): the generated token ids padded on
        the right with the pad token id up to `data_args.max_tweet_length`, and
        the tensor built by `get_gen_hidden_states_from_tuple`.
    """
    max_len = self.data_args.max_tweet_length

    ## Ignore source post, only take responses as input
    response_inputs_embeds = inputs_embeds[:, max_len:, :]
    response_attention_mask = attention_mask[:, max_len:]

    summ_kwargs = {
        "max_length": max_len,
        "num_beams": 1,
        "output_hidden_states": True,
        "return_dict_in_generate": True
    }

    ## Enable back-propagation to generator only while training adversarial stage 2
    needs_grad = self.training_args.task_type == "train_adv_stage2" and self.training_args.do_train
    generate_fn = self.summarizer.generate_with_grad if needs_grad else self.summarizer.generate
    summarizer_outputs = generate_fn(
        attention_mask=response_attention_mask,
        inputs_embeds=response_inputs_embeds,
        **summ_kwargs
    )

    ## Pad summary_tokens to data_args.max_tweet_length!
    ## `new_full` keeps the dtype and device of the generated ids, so the
    ## concatenation stays an integer tensor of exactly `max_len` columns.
    ## NOTE(review): `self.model` is not assigned anywhere in the visible part
    ## of this class -- confirm which object is meant to provide the pad id.
    summary_tokens = summarizer_outputs["sequences"]
    pad_len = max(max_len - summary_tokens.shape[1], 0)
    pad_batch = summary_tokens.new_full(
        (summary_tokens.shape[0], pad_len), self.model.config.pad_token_id
    )
    summary_tokens = torch.cat((summary_tokens, pad_batch), dim=1)

    summ_hidden_states = self.get_gen_hidden_states_from_tuple(summarizer_outputs["decoder_hidden_states"])
    return summary_tokens, summ_hidden_states


def get_gen_hidden_states_from_tuple(self, decoder_hidden_states):
    """
    Convert decoder_hidden_states of type tuple into torch tensor.

    Args:
        decoder_hidden_states: one entry per generated token; each entry holds
            the hidden states of all layers for that token and only the last
            layer is used.

    Returns:
        The BOS embedding, the last-layer state of every generated token and pad
        embeddings, concatenated along dim 1. Pad embeddings are added only
        while fewer than `data_args.max_tweet_length` entries have been collected.
    """
    batch_size = decoder_hidden_states[0][0].shape[0]
    embedding_table = self.model.shared.weight
    bos_embedding = embedding_table[self.config.bos_token_id].reshape(1, 1, -1).repeat(batch_size, 1, 1)
    pad_embedding = embedding_table[self.config.pad_token_id].reshape(1, 1, -1).repeat(batch_size, 1, 1)

    ## Start token first, then the last hidden state of each generated token
    gen_hidden_states = [bos_embedding]
    gen_hidden_states.extend(token_hidden_states[-1] for token_hidden_states in decoder_hidden_states)

    ## Padding (a non-positive count adds nothing)
    n_pad = self.data_args.max_tweet_length - len(gen_hidden_states)
    gen_hidden_states.extend([pad_embedding] * n_pad)

    return torch.cat(gen_hidden_states, dim=1)
def build_by_kmeans(self):
    """
    Cluster the responses of every tree and write the assignment to CSV.

    For each tree the node features are the mean of the encoder's `embed_tok`
    output over dim 1. The first node is skipped and the remaining nodes are
    passed to `self.model_cluster`. The table is written to
    `<dataset_root>/<dataset_name>/split_<fold>/cluster_summary/<cluster_mode>/kmeans-<num_clusters>.csv`
    with the columns `source_id`, `tweet_id`, `cluster_id`, `is_centroid`.
    """
    output_dict = {
        "source_id": [],
        "tweet_id": [],
        "cluster_id": [],
        "is_centroid": []
    }

    ## Read the id columns once instead of once per tree
    all_source_ids = self.dataset["source_id"]
    all_tweet_ids = self.dataset["tweet_ids"]

    tree_idx = 0
    for batch in tqdm(self.dataloader):
        tree_lens = batch["tree_lens"].to(self.device)
        input_ids = batch["input_ids"].to(self.device)
        attn_mask = batch["attention_mask"].to(self.device)

        ## Iterate through each sample (tree) in a batch, actually one tree per batch
        for batch_idx in range(input_ids.shape[0]):
            tree_len = int(tree_lens[batch_idx])
            tree_source_id = [all_source_ids[tree_idx]] * tree_len
            tree_tweet_ids = all_tweet_ids[tree_idx].split(",")
            tree_idx = tree_idx + 1

            tree_input_ids = input_ids[batch_idx]
            tree_attn_mask = attn_mask[batch_idx]
            tree_input_ids = tree_input_ids[tree_input_ids[:, 0] != -1] ## Remove padding nodes
            tree_attn_mask = tree_attn_mask[tree_attn_mask[:, 0] != -1] ## Remove padding nodes

            ## Obtain hidden representations
            with torch.no_grad():
                ## TODO: make sure to use embeddings or hidden representation?
                encoder_outputs = self.model.encoder(
                    input_ids=tree_input_ids,
                    attention_mask=tree_attn_mask,
                    return_dict=True
                )

                ## Get node features by mean pooling on embeddings
                node_feat = torch.mean(encoder_outputs["embed_tok"], dim=1)

            ## Clustering: the first node is left out, only the rest is clustered
            response_feat = node_feat[1:]
            if response_feat.shape[0] == 0: ## No response
                print("Ignore this tree since it has no response.")
                continue

            cluster_ids, cluster_centers, dist, is_centroid = \
                self.model_cluster(
                    node_feat=response_feat,
                    device=self.device,
                    mode=self.cluster_mode
                )

            output_dict["source_id"].extend(tree_source_id[1:])
            output_dict["tweet_id"].extend(tree_tweet_ids[1:])
            output_dict["cluster_id"].extend(cluster_ids.tolist())
            output_dict["is_centroid"].extend(is_centroid.tolist())

    update_df = self._split_multi_centroid_clusters(pd.DataFrame(data=output_dict))

    ## Output
    output_dir = "{}/{}/split_{}/cluster_summary/{}".format(
        self.data_args.dataset_root,
        self.data_args.dataset_name,
        self.data_args.fold,
        self.cluster_mode
    )
    os.makedirs(output_dir, exist_ok=True)
    update_df.to_csv(
        "{}/kmeans-{}.csv".format(output_dir, self.model_cluster.num_clusters),
        index=False
    )


def _split_multi_centroid_clusters(self, output_df):
    """
    Make every cluster carry at most one centroid.

    ** NOTE **
    - each cluster may have more than one centroid since some closest nodes
      have the same distance from the center
    - these nodes take turns being the target summary: such a cluster is
      repeated once per centroid under the id "<cluster_id>_<k>", each copy
      keeping a single centroid flag

    Args:
        output_df: frame with `source_id`, `tweet_id`, `cluster_id`, `is_centroid`.

    Returns:
        A new frame whose `cluster_id` values are strings; an empty frame with
        the same columns when `output_df` has no rows.
    """
    frames = []
    for source_id, tweets_df in output_df.groupby("source_id"):
        for cluster_id, cluster_df in tweets_df.groupby("cluster_id"):
            centroid_tids = cluster_df[cluster_df["is_centroid"] == 1]["tweet_id"]
            if len(centroid_tids) > 1: ## More than one centroid!
                for sub_idx, cent_tid in enumerate(centroid_tids): ## Each centroid takes turn being the target summary
                    subcluster_df = cluster_df.copy()
                    subcluster_df.loc[subcluster_df["tweet_id"] != cent_tid, ["is_centroid"]] = 0
                    subcluster_df["cluster_id"] = "{}_{}".format(cluster_id, sub_idx)
                    frames.append(subcluster_df)
            else: ## At most one centroid
                ## `assign` returns a new frame, so the groupby slice is not written to
                frames.append(cluster_df.assign(cluster_id=cluster_df["cluster_id"].astype(str)))

    if not frames:
        ## pd.concat refuses an empty list; keep the header so the CSV is still valid
        return pd.DataFrame(columns=output_df.columns)
    return pd.concat(frames)
self.dataset["tweet_ids"][tree_idx].split(",") 224 | 225 | tree_input_ids = input_ids[batch_idx] 226 | tree_attn_mask = attn_mask[batch_idx] 227 | tree_input_ids = tree_input_ids[tree_input_ids[:, 0] != -1] ## Remove padding nodes 228 | tree_attn_mask = tree_attn_mask[tree_attn_mask[:, 0] != -1] ## Remove padding nodes 229 | tree_topic_ids = topic_ids[batch_idx] 230 | tree_topic_msk = topic_msk[batch_idx] 231 | tree_topic_probs = topic_probs[batch_idx] 232 | 233 | ## Obtain hidden representations 234 | with torch.no_grad(): 235 | ## Get node embeddings 236 | encoder_outputs = self.model.encoder( 237 | input_ids=tree_input_ids, 238 | attention_mask=tree_attn_mask, 239 | return_dict=True 240 | ) 241 | 242 | ## Get node features by mean pooling on embedding 243 | pooling_mask = tree_attn_mask.clone() 244 | seq_lens = pooling_mask.sum(dim=1) 245 | for i, seq_len in enumerate(seq_lens): 246 | pooling_mask[i][seq_len - 1] = 0 247 | pooling_mask[:, 0] = 0 248 | node_emb = mean_pooling(encoder_outputs["embed_tok"], pooling_mask) 249 | 250 | ## Get topic embeddings 251 | topic_outputs = self.model.encoder( 252 | input_ids=tree_topic_ids, 253 | attention_mask=tree_topic_msk, 254 | return_dict=True 255 | ) 256 | topic_emb = topic_outputs["embed_tok"] 257 | 258 | response_emb = node_emb[1:] 259 | cluster_ridx, centroid_ridx = self.model_clustering.cluster(node_feat=response_emb, topic_feat=topic_emb, topic_probs=tree_topic_probs) 260 | 261 | response_tids = tree_tweet_ids[1:] 262 | 263 | tree_idx = tree_idx + 1 264 | if cluster_ridx is None: 265 | continue 266 | 267 | n_cluster = cluster_ridx.shape[1] 268 | output_dict["source_id"].extend([tree_source_id[0]] * 3) 269 | output_dict["cluster_id"].extend(list(range(n_cluster))) 270 | for cluster_id in range(n_cluster): 271 | cluster_tids = [response_tids[ridx] for ridx in cluster_ridx[:, cluster_id]] 272 | output_dict["tweet_ids"].append(",".join(cluster_tids)) 273 | output_dict["centroid"].extend([response_tids[cent_ridx] for 
def mean_pooling(hidden_states, attn_mask, dim=1):
    """
    Takes a batch of hidden states and attention masks as inputs and averages
    the hidden states over one axis, weighting each position by its mask value.

    Inputs:
        - hidden_states: (bs, seq_len, hidden_dim)
        - attn_mask    : (bs, seq_len)
        - dim          : axis of `hidden_states` to pool over (default 1, the
                         sequence axis); `attn_mask` must have that axis too
    Returns:
        Tensor shaped like `hidden_states` without axis `dim`. A row whose mask
        sums to zero pools to zeros rather than NaN.
    """
    ## Resolve negative axes against hidden_states so that the same index is
    ## valid for the mask, which lacks the trailing hidden dimension
    pool_dim = dim % hidden_states.dim()

    sum_ = (hidden_states * attn_mask.unsqueeze(dim=-1)).sum(dim=pool_dim)
    count = attn_mask.sum(dim=pool_dim).unsqueeze(dim=-1)
    ## Only exact zeros are replaced, every other denominator is untouched
    count = count.masked_fill(count == 0, 1)
    return sum_ / count
def find_ckpt_dir(root_dir):
    """
    Find checkpoint directory name given root directory.

    Only sub-directories whose name starts with "checkpoint" qualify. When
    several exist, the first one in sorted name order is returned, so the
    result does not depend on the order in which the OS lists the entries.

    Raises:
        ValueError: if `root_dir` holds no such directory.
    """
    candidates = sorted(
        entry for entry in os.listdir(root_dir)
        if entry.startswith("checkpoint") and os.path.isdir(os.path.join(root_dir, entry))
    )
    if not candidates:
        raise ValueError("Pre-trained checkpoint directory doesn't exists!")
    return candidates[0]


def calculate_asr(labels, clean_predictions, dirty_predictions):
    """
    Evaluate the attack success rate given clean predictions & dirty predictions.

    The rate is the share of samples predicted correctly on clean input
    that are predicted wrongly on attacked input.

    Returns:
        The ratio, or 0.0 when no clean prediction is correct (nothing could be
        attacked, and the plain division would be 0 / 0).
    """
    ori_correct = (clean_predictions == labels)
    now_wrong = (dirty_predictions != labels)
    as_vec = ori_correct & now_wrong ## correct -> wrong

    n_correct = ori_correct.sum()
    if n_correct == 0:
        return 0.0
    return as_vec.sum() / n_correct
102 | 103 | if training_args.task_type == "train_detector": 104 | sum_flag = "-" if not sum_flag else "v" 105 | 106 | metrics2report = [ 107 | "{:20s}".format(model_args.model_name[:20]), ## Model 108 | "{:10s}".format(sum_flag), ## sum_flag 109 | "{:8s}".format(mode), ## Mode 110 | "{:20s}".format(data_args.fold[:20]), ## Fold 111 | "{:<10.4f}".format(metrics["eval_accuracy"]), 112 | "{:<10.4f}".format(metrics["eval_f1_macro"]) 113 | ] 114 | 115 | f1s = ["{:<10.4f}".format(metrics["eval_f1_{}".format(label_i)]) for label_i in range(data_args.num_labels)] 116 | metrics2report.extend(f1s) 117 | 118 | with open(training_args.overall_results_path, "a") as fw: 119 | fw.write("{}\n".format("\t".join(metrics2report))) 120 | 121 | elif training_args.task_type == "train_adv_stage1" or training_args.task_type == "train_adv_stage2": 122 | gen_flag = "v" if gen_flag else "-" 123 | sum_flag = "v" if sum_flag else "-" 124 | 125 | if sum_flag == "v": 126 | if "filter" in model_args.extractor_name_or_path or "kmeans" in model_args.extractor_name_or_path: 127 | sum_flag = "DRE" 128 | 129 | if model_args.abstractor_name_or_path: 130 | sum_flag = "DAS" 131 | 132 | metrics2report = [ 133 | "{:8s}".format(mode), ## Mode 134 | "{:10s}".format(gen_flag), 135 | "{:10s}".format(sum_flag), 136 | "{:20s}".format(data_args.fold[:20]), ## Fold 137 | "{:<10.4f}".format(metrics["eval_accuracy"]), 138 | "{:<10.4f}".format(metrics["eval_f1_macro"]) 139 | ] 140 | 141 | f1s = ["{:<10.4f}".format(metrics["eval_f1_{}".format(label_i)]) for label_i in range(data_args.num_labels)] 142 | metrics2report.extend(f1s) 143 | 144 | if "eval_rouge1" in metrics: 145 | metrics2report.append("{:<10.4f}".format(metrics["eval_rouge1"])) 146 | metrics2report.append("{:<10.4f}".format(metrics["eval_rouge2"])) 147 | metrics2report.append("{:<10.4f}".format(metrics["eval_rougeL"])) 148 | 149 | with open(training_args.overall_results_path, "a") as fw: 150 | fw.write("{}".format("\t".join(metrics2report))) 151 | if 
def _format_source_id(source_id):
    """Render a source id in a 20-character column: integers (Python or numpy) with `d`, anything else with `s`."""
    spec = "{:20d}" if isinstance(source_id, (int, np.integer)) else "{:20s}"
    return spec.format(source_id)


def write_generated_responses(
    training_args, trainer, tokenizer, eval_dataset,
    labels_det, clean_pred_det, dirty_pred_det, dirty_predictions, as_vec
):
    """
    Write generated adversarial response to file.

    One line per sample goes to `<output_dir>/generated_response.txt`:
    source id, gold label, clean -> attacked prediction, status, response.
    The status is "success" / "failure" according to `as_vec`, or "-------"
    when the prediction did not change or was wrong both before and after.
    Nothing is written unless `trainer.is_world_process_zero()` is true.
    """
    ## NOTE(review): the flattened source does not show how far the rank check
    ## extends; decoding and writing are both treated as rank-zero-only here.
    if not trainer.is_world_process_zero():
        return

    pred_gen = tokenizer.batch_decode(
        dirty_predictions[1], ## same as clean_predictions[1]
        skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    pred_gen = [pred.strip() for pred in pred_gen]

    adv_response_path = "{}/generated_response.txt".format(training_args.output_dir)
    print("\nWriting generated adversarial responses to {}".format(adv_response_path))

    ## Look these up once instead of once per row
    source_ids = eval_dataset["source_id"]
    id2label = trainer.model.config.id2label

    ## Explicit encoding: tweets are not guaranteed to fit the platform default
    with open(adv_response_path, "w", encoding="utf-8") as fw:
        for pred_idx, response in enumerate(pred_gen):
            label = labels_det[pred_idx]
            clean = clean_pred_det[pred_idx]
            dirty = dirty_pred_det[pred_idx]

            success_or_fail = "success" if as_vec[pred_idx] else "failure"
            if clean == dirty or (clean != label and dirty != label):
                success_or_fail = "-------"

            fw.write(
                "{}\t{:10s}[{:10s}->{:10s}]\t{}\t{}\n".format(
                    _format_source_id(source_ids[pred_idx]),
                    id2label[label],
                    id2label[clean],
                    id2label[dirty],
                    success_or_fail, response
                )
            )


def write_response_summary(training_args, trainer, tokenizer, eval_dataset, summ_predictions):
    """
    Write response summary to file.

    Decodes `summ_predictions[2]` and writes one "<source id>\\t<summary>" line
    per sample to `<output_dir>/response_summary.txt`. Nothing is written
    unless `trainer.is_world_process_zero()` is true.
    """
    if not trainer.is_world_process_zero():
        return

    summaries = tokenizer.batch_decode(
        summ_predictions[2], skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    summaries = [pred.strip() for pred in summaries]

    response_summary_path = "{}/response_summary.txt".format(training_args.output_dir)
    print("\nWriting response summary to {}".format(response_summary_path))

    source_ids = eval_dataset["source_id"] ## read the column once
    with open(response_summary_path, "w", encoding="utf-8") as fw:
        for pred_idx, summary in enumerate(summaries):
            fw.write("{}\t{}\n".format(_format_source_id(source_ids[pred_idx]), summary))
data_.to_csv("{}/{}.csv".format(output_dir, data_args.fold), index=False) 241 | 242 | def calculate_PHEME_results(data_args, training_args): 243 | """Calculate Accuracy / macro-averaged F1 / weighted-averaged F1 for PHEME from predictions of all folds.""" 244 | n_folds = 9 245 | metric = { 246 | "accuracy": load_metric("accuracy"), 247 | "f1" : load_metric("f1") 248 | } 249 | input_dir = "{}/{}/{}/cls_predictions".format(training_args.output_root, data_args.dataset_name, training_args.exp_name) 250 | 251 | preds_all, label_all = [], [] 252 | for fold in range(n_folds): 253 | result_csv = pd.read_csv("{}/{}.csv".format(input_dir, fold)) 254 | preds_all.append(result_csv["preds"].to_numpy()) 255 | label_all.append(result_csv["label"].to_numpy()) 256 | 257 | preds_all = np.concatenate(preds_all, axis=0) 258 | label_all = np.concatenate(label_all, axis=0) 259 | 260 | accuracy = metric["accuracy"].compute(predictions=preds_all, references=label_all)["accuracy"] 261 | f1_macro = metric["f1"].compute(predictions=preds_all, references=label_all, average="macro")["f1"] 262 | f1_weighted = metric["f1"].compute(predictions=preds_all, references=label_all, average="weighted")["f1"] 263 | 264 | with open(training_args.overall_results_path, "a") as fw: 265 | fw.write("{:10s}\t{:10s}\t{:10s}\n".format("Acc", "mF1", "wF1")) 266 | fw.write("{:<10.4f}\t{:10.4f}\t{:10.4f}\n".format(accuracy, f1_macro, f1_weighted)) 267 | 268 | def post_process_generative_model(data_args, model_args, model): 269 | """Post processing for generative models.""" 270 | if model.config.decoder_start_token_id is None: 271 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 272 | 273 | if ( 274 | hasattr(model.config, "max_position_embeddings") 275 | and model.config.max_position_embeddings < data_args.max_source_length 276 | ): 277 | if model_args.resize_position_embeddings is None: 278 | logger.warning( 279 | f"Increasing the model's number of position embedding vectors 
def parse_args():
    """Parse the command-line options of the post-processing script.

    Returns the `argparse.Namespace` holding one boolean switch per routine
    plus the dataset / results locations and fold settings.
    """
    parser = argparse.ArgumentParser(description="Rumor Detection")

    ## What to do: one store_true switch per post-processing routine
    for switch in (
        "--wordcloud",
        "--get_semeval2019_event",
        "--get_semeval2019_event_from_pheme",
        "--semeval2019_event_wise_eval",
        "--bootstrap_accuracy",
    ):
        parser.add_argument(switch, action="store_true")

    ## Others
    dataset_choices = ["semeval2019", "Pheme", "twitter15", "twitter16"]
    parser.add_argument("--dataset_name", type=str, default="semeval2019", choices=dataset_choices)
    parser.add_argument("--dataset_root", type=str, default="../dataset/processedV2")
    parser.add_argument("--results_root", type=str, default="/mnt/1T/projects/RumorV2/results")
    parser.add_argument("--n_fold", type=int, default=5)
    parser.add_argument(
        "--fold",
        type=str,
        default="0,1,2,3,4,comp",
        help="either use 5-fold data or train/dev/test from rumoureval2019 competition",
    )

    return parser.parse_args()
def get_semeval2019_event_from_pheme(args, write=True):
    """Assign each RumourEval2019 source tweet to a PHEME event.

    Thread IDs are collected per event from the folder names of three raw
    PHEME dumps located under `<dataset_root>/../raw`. Every source tweet of
    `<dataset_root>/semeval2019/data.csv` (rows whose `source_id` equals
    `tweet_id`) whose ID appears among an event's thread IDs is assigned to
    that event. Per-event counts are printed along the way.

    Args:
        args: namespace providing `dataset_root`.
        write: if True, dump the mapping to
            `<dataset_root>/semeval2019/event_map.json` and return None;
            otherwise return the mapping.

    Returns:
        dict mapping event name -> list of source tweet IDs (as strings) when
        `write` is False, else None.

    Raises:
        KeyError: if a raw event folder does not correspond to one of the
            nine known events.
    """
    events = [
        "charliehebdo", "ferguson", "gurlitt", "prince-toronto", "sydneysiege",
        "ebola-essien", "germanwings-crash", "ottawashooting", "putinmissing",
    ]

    def list_visible(path):
        """List `path`, skipping hidden entries and anything with README in its name."""
        return [name for name in os.listdir(path) if not name.startswith(".") and "README" not in name]

    path_semeval = "{}/semeval2019/data.csv".format(args.dataset_root)

    path_pheme_0 = "{}/../raw/pheme-rumour-scheme-dataset/threads/en".format(args.dataset_root)
    path_pheme_1 = "{}/../raw/PHEME_veracity/all-rnr-annotated-threads".format(args.dataset_root)
    path_pheme_2 = "{}/../raw/pheme-rnr-dataset".format(args.dataset_root)

    ## Sets give de-duplication across the three dumps and O(1) membership tests
    pheme_ids = {event: set() for event in events}
    semeval_ids = {event: [] for event in events}

    print("\n[PHEME]")
    ## pheme-rumour-scheme-dataset: <event>/<thread id>
    for event_dir in list_visible(path_pheme_0):
        thread_ids = list_visible("{}/{}".format(path_pheme_0, event_dir))
        pheme_ids[event_dir].update(thread_ids)

    labels = ["non-rumours", "rumours"]

    ## PHEME_veracity: <event>-all-rnr-threads/<label>/<thread id>
    for event_dir in list_visible(path_pheme_1):
        event_name = event_dir.replace("-all-rnr-threads", "")
        for label in labels:
            thread_ids = list_visible("{}/{}/{}".format(path_pheme_1, event_dir, label))
            pheme_ids[event_name].update(thread_ids)

    ## pheme-rnr-dataset: <event>/<label>/<thread id>
    for event_dir in list_visible(path_pheme_2):
        for label in labels:
            thread_ids = list_visible("{}/{}/{}".format(path_pheme_2, event_dir, label))
            pheme_ids[event_dir].update(thread_ids)

    for event in events:
        print("{:20s}: {}".format(event, len(pheme_ids[event])))

    print("\nRead & gather source tweet ID from [SemEval2019]...")
    data_df = pd.read_csv(path_semeval)
    ## Source tweets are the rows that are their own source. The IDs are
    ## stringified column-wise (folder names are strings) instead of through
    ## `iterrows`, which does not preserve column dtypes across a row.
    is_source = data_df["source_id"] == data_df["tweet_id"]
    src_ids = data_df.loc[is_source, "source_id"].astype(str).tolist()

    print("\n[SemEval2019]")
    total = 0
    total_ids = []
    for event in events:
        known_ids = pheme_ids[event]
        semeval_ids[event].extend(src_id for src_id in src_ids if src_id in known_ids)
        print("{:20s}: {}".format(event, len(semeval_ids[event])))
        total = total + len(semeval_ids[event])
        total_ids.extend(semeval_ids[event])
    ## The second total counts unique IDs; it differs from the first when a
    ## tweet is matched to more than one event.
    print("{:20s}: {}".format("Total", total))
    print("{:20s}: {}".format("Total", len(set(total_ids))))

    if not write:
        return semeval_ids

    with open("{}/semeval2019/event_map.json".format(args.dataset_root), "w") as fw:
        fw.write(json.dumps(semeval_ids, indent=4))
    return None
def semeval2019_event_wise_eval(args):
    """Print per-event detection scores for the 5-fold RumourEval2019 predictions.

    For every fold in `range(args.n_fold)` the file
    `<results_root>/semeval2019/bi-tgn-roberta/lr2e-5/<fold>/predictions.csv`
    is read and its rows are grouped by the PHEME event of their `source_id`.
    For each event the pooled accuracy, macro-F1 (averaged over the classes
    present in the ground truth), micro-F1, the three per-class F1 scores and
    the label counts are printed as one tab-separated row.

    A header-only `event_wise.txt` is created next to the fold folders if it
    does not exist yet; the rows themselves are only printed.

    Args:
        args: namespace providing `dataset_root`, `results_root` and `n_fold`.
    """
    event_id_map = get_semeval2019_event_from_pheme(args, write=False)

    ## Read 5-Fold Predictions
    path_in = "{}/semeval2019/bi-tgn-roberta/lr2e-5".format(args.results_root)
    event_preds = {event: [] for event in event_id_map}
    event_label = {event: [] for event in event_id_map}
    for fold_i in range(args.n_fold):
        preds_df = pd.read_csv("{}/{}/predictions.csv".format(path_in, fold_i))

        ## The event map stores IDs as strings, so compare as strings no matter
        ## which dtype pandas inferred for the CSV column.
        source_ids = preds_df["source_id"].astype(str)
        for event, event_ids in event_id_map.items():
            event_df = preds_df.loc[source_ids.isin(event_ids)]
            if len(event_df) == 0:
                continue

            event_preds[event].append(event_df["hard_pred"].values)
            event_label[event].append(event_df["gt_label"].values)

    path_out = "{}/event_wise.txt".format(path_in)
    header = "{:20s}\t{:9s}\t{:6s}\t{:8s}\t{:8s}\t{:6s}\t{:6s}\t{:6s}\t{:10s}\t{:10s}\t{:10s}".format(
        "Event Name", "# Samples", "Acc", "macro-F1", "micro-F1", "F1-0", "F1-1", "F1-2", "# Label-0", "# Label-1", "# Label-2"
    )
    ## One row layout for all events so every row lines up with the header
    row_fmt = "{:20s}\t{:<9d}\t{:.4f}\t{:<8.4f}\t{:<8.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:<10d}\t{:<10d}\t{:<10d}"

    print("=" * 25)
    print(header)
    if not os.path.isfile(path_out):
        with open(path_out, "w") as fw:
            fw.write("{}\n".format(header))

    for event in event_preds:
        if len(event_preds[event]) == 0:
            ## No prediction belongs to this event: emit an all-zero row
            print(row_fmt.format(event, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0))
            continue

        preds = np.concatenate(event_preds[event])
        label = np.concatenate(event_label[event])

        F1_all = np.array(f1_score_3_class(preds, label))

        ## Only take classes that exist in `label`
        F1_filt = F1_all[np.unique(label)]

        ## sklearn expects (y_true, y_pred); micro-F1 is symmetric, so the value is unchanged
        micro_f1 = f1_score(label, preds, average="micro")
        macro_f1 = sum(F1_filt) / len(F1_filt)

        print(row_fmt.format(
            event, len(label), (preds == label).sum() / len(label), macro_f1, micro_f1,
            F1_all[0], F1_all[1], F1_all[2], (label == 0).sum(), (label == 1).sum(), (label == 2).sum()
        ))
import os
import ipdb

import torch

from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)

## Self-defined
from .detector import (
    RobertaForRumorDetection,
    BertForRumorDetection,
    BartEncoderForRumorDetection
)
from .detector_generator import BartForRumorDetectionAndResponseGeneration
from .modeling_filter import TransformerAutoEncoder
from .modeling_abstractor import BartForAbstractiveResponseSummarization#, RobertaForExtractiveResponseSummarization, ResponseExtractor
from others.utils import find_ckpt_dir, post_process_generative_model

def build_model(data_args, model_args, training_args):
    """Build models according to different tasks.

    Loads the pre-trained config / tokenizer / model for
    `training_args.task_type` via `load_model`, lets the model initialise its
    extra modules when it defines `init_args_modules`, and finally restores
    fine-tuned weights where the task requires it (detector evaluation and
    the two adversarial training stages).

    Returns:
        Tuple `(config, tokenizer, model)`.
    """

    ## Load pre-trained model checkpoint
    config, tokenizer, model = load_model(data_args, model_args, training_args)

    ## Initialize other modules
    if hasattr(model, "init_args_modules"):
        model.init_args_modules(data_args, model_args, training_args, tokenizer=tokenizer)

    task_type = training_args.task_type

    ## Load trained model
    if task_type == "train_detector":
        if not training_args.do_train and training_args.do_eval:
            model_args.model_name_or_path = "{}/{}".format(training_args.output_dir, find_ckpt_dir(training_args.output_dir))
            ## map_location keeps loading independent of the device the checkpoint was saved from
            model.load_state_dict(torch.load("{}/pytorch_model.bin".format(model_args.model_name_or_path), map_location="cpu"))
            print("Detector checkpoint: {}".format(model_args.model_name_or_path))

    ## Build model for adversarial training
    if task_type == "train_adv_stage1" or task_type == "train_adv_stage2":

        ## Decide which trained checkpoint (if any) has to be restored
        ckpt_path = None
        if task_type == "train_adv_stage1":
            if training_args.do_eval and not training_args.do_train:
                ckpt_path = "{}/{}/{}/{}".format(
                    training_args.output_root, data_args.dataset_name, training_args.exp_name, data_args.fold)
                print("\nLoading detector checkpoint from adversarial training stage 1...")

        elif task_type == "train_adv_stage2":
            if training_args.do_train:
                ## Stage 2 training starts from the stage-1 result stored under the
                ## first component of `exp_name`
                print("\nLoading model checkpoint from adversarial training stage 1...")
                ckpt_path = "{}/{}/{}/adv-stage1/{}".format(
                    training_args.output_root, data_args.dataset_name, training_args.exp_name.split("/")[0], data_args.fold)
            elif training_args.do_eval:
                print("\nLoading detector & attacker from adversarial training stage 2...")
                ckpt_path = "{}/{}/{}/{}".format(
                    training_args.output_root, data_args.dataset_name, training_args.exp_name, data_args.fold)

        if ckpt_path is not None:
            ckpt_dir = find_ckpt_dir(ckpt_path)
            ckpt_file = "{}/{}/pytorch_model.bin".format(ckpt_path, ckpt_dir)
            print("Checkpoint path: {}".format(ckpt_file))

            ## Partially load the model checkpoint (ignore summarizer)
            ckpt_state_dict = torch.load(ckpt_file, map_location="cpu")
            ckpt_state_dict = {
                k: v
                for k, v in ckpt_state_dict.items()
                if not k.startswith("summarizer")
            }
            model.load_state_dict(ckpt_state_dict, strict=False)

    print("Model name: {}".format(model.__class__.__name__))

    return config, tokenizer, model

def _hub_kwargs(model_args):
    """Keyword arguments shared by every `from_pretrained` call in this module."""
    return {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }

def _load_config(model_args, source=None, **config_kwargs):
    """Load an `AutoConfig` from `source` (default: `config_name`, else `model_name_or_path`)."""
    if source is None:
        source = model_args.config_name if model_args.config_name else model_args.model_name_or_path
    return AutoConfig.from_pretrained(source, **config_kwargs, **_hub_kwargs(model_args))

def _load_tokenizer(model_args, source=None):
    """Load an `AutoTokenizer` from `source` (default: `tokenizer_name`, else `model_name_or_path`)."""
    if source is None:
        source = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
    return AutoTokenizer.from_pretrained(source, use_fast=model_args.use_fast_tokenizer, **_hub_kwargs(model_args))

def _load_pretrained(model_cls, model_args, config):
    """Instantiate `model_cls` from `model_args.model_name_or_path` with the given config."""
    return model_cls.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        **_hub_kwargs(model_args)
    )

def load_model(data_args, model_args, training_args):
    """Load pre-trained models according to different tasks.

    Dispatches on `training_args.task_type`. In evaluation-only runs
    (`do_eval` without `do_train`) `model_args.model_name_or_path` is
    redirected to the latest checkpoint directory found by `find_ckpt_dir`
    for the tasks that support it; the `train_filter` branch additionally
    forces `model_args.td_gcn` and `model_args.bu_gcn` to True.

    Returns:
        Tuple `(config, tokenizer, model)`.

    Raises:
        ValueError: if the task type is unknown, or if the detector backbone
            cannot be inferred from `model_args.model_name_or_path`.
    """
    print("\nLoading pre-trained model & tokenizer...")

    task_type = training_args.task_type
    eval_only = training_args.do_eval and not training_args.do_train

    if task_type == "train_detector":
        ## "roberta" must be tested before "bert" since the former contains the latter
        backbone = model_args.model_name_or_path
        if "roberta" in backbone:
            detector = RobertaForRumorDetection
        elif "bert" in backbone:
            detector = BertForRumorDetection
        elif "bart" in backbone:
            detector = BartEncoderForRumorDetection
        else:
            raise ValueError(
                "Cannot infer detector class from model_name_or_path '{}': "
                "expected it to contain 'roberta', 'bert' or 'bart'.".format(backbone)
            )

        ## Load trained checkpoint if doing evaluation
        if eval_only:
            ckpt_dir = find_ckpt_dir(training_args.output_dir)
            model_args.model_name_or_path = "{}/{}".format(training_args.output_dir, ckpt_dir)

        print("Detector checkpoint: {}".format(model_args.model_name_or_path))

        config = _load_config(model_args, num_labels=data_args.num_labels, finetuning_task=data_args.task_name)
        tokenizer = _load_tokenizer(model_args)
        model = _load_pretrained(detector, model_args, config)

    elif task_type in ("predict_summary", "train_adv_stage1", "train_adv_stage2"):
        ## Plain seq2seq model for summary prediction, joint detector/generator for adversarial training
        if task_type == "predict_summary":
            model_cls = AutoModelForSeq2SeqLM
        else:
            model_cls = BartForRumorDetectionAndResponseGeneration

        print("Pre-trained encoder-decoder: {}".format(model_args.model_name_or_path))
        config = _load_config(model_args)
        tokenizer = _load_tokenizer(model_args)
        model = _load_pretrained(model_cls, model_args, config)
        model = post_process_generative_model(data_args, model_args, model)

    elif task_type == "ssra_loo" or task_type == "ssra_kmeans":

        ## Load model for evaluation
        if eval_only:
            ckpt_path = None
            if task_type == "ssra_loo" and model_args.model_name_or_path == "ssra_loo":
                print("Load model from SSRA-LOO...")
                ckpt_path = "{}/{}/ssra_loo/{}".format(training_args.output_root, data_args.dataset_name, data_args.fold)
            elif task_type == "ssra_kmeans":
                print("Load model from SSRA-KMeans...")
                ckpt_path = "{}/{}/{}/{}".format(training_args.output_root, data_args.dataset_name, training_args.exp_name, data_args.fold)

            if ckpt_path is not None:
                model_args.model_name_or_path = "{}/{}".format(ckpt_path, find_ckpt_dir(ckpt_path))

        print("Pre-trained summarizer checkpoint: {}".format(model_args.model_name_or_path))
        config = _load_config(model_args)
        tokenizer = _load_tokenizer(model_args)
        model = _load_pretrained(BartForAbstractiveResponseSummarization, model_args, config)
        model = post_process_generative_model(data_args, model_args, model)

    elif task_type == "train_filter":
        print("\nLoading rumor detector from adversarial training stage 2 for embedding layer...")
        ckpt_path = "{}/{}/bi-tgn/adv-stage2/{}".format(training_args.output_root, data_args.dataset_name, data_args.fold)
        ckpt_path = "{}/{}".format(ckpt_path, find_ckpt_dir(ckpt_path))
        print(ckpt_path)

        ## Config and tokenizer come from the stage-2 checkpoint, not from the CLI names
        config = _load_config(model_args, source=ckpt_path, num_labels=data_args.num_labels, finetuning_task=data_args.task_name)
        tokenizer = _load_tokenizer(model_args, source=ckpt_path)
        model_emb = _load_pretrained(BartForRumorDetectionAndResponseGeneration, model_args, config)

        ## Both GCN directions are switched on before the extra modules are built and the weights restored
        model_args.td_gcn = True
        model_args.bu_gcn = True
        model_emb.init_args_modules(data_args, model_args, training_args, tokenizer=tokenizer)
        model_emb.load_state_dict(torch.load("{}/pytorch_model.bin".format(ckpt_path), map_location="cpu"))
        model = TransformerAutoEncoder(
            model_emb=model_emb,
            num_layers_enc=model_args.filter_layer_enc,
            num_layers_dec=model_args.filter_layer_dec
        )

    elif task_type == "build_cluster_summary":
        config = _load_config(model_args, num_labels=data_args.num_labels, finetuning_task=data_args.task_name)
        tokenizer = _load_tokenizer(model_args)

        print("\nLoading rumor detector from adversarial training stage 2...")
        ckpt_path = "{}/{}/bi-tgn/adv-stage2/{}".format(training_args.output_root, data_args.dataset_name, data_args.fold)
        ckpt_dir = find_ckpt_dir(ckpt_path)

        model = _load_pretrained(BartForRumorDetectionAndResponseGeneration, model_args, config)
        ## strict=False: keys that do not line up between checkpoint and model are skipped
        model.load_state_dict(torch.load("{}/{}/pytorch_model.bin".format(ckpt_path, ckpt_dir), map_location="cpu"), strict=False)

    else:
        raise ValueError("training_args.task_type not specified!")

    return config, tokenizer, model