├── .gitignore ├── EmotionFlow.pdf ├── README.md ├── config.py ├── crf.py ├── friends_transcript.json ├── model.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | __pycache__/ 3 | .vscode/ 4 | *.log 5 | *.pkl 6 | models/ 7 | MELD/ 8 | *.gz 9 | *.zip 10 | ijcnlp_dailydialog/ 11 | roberta-base/ 12 | roberta-base 13 | distilroberta-base/ 14 | distilroberta-base 15 | distilbert-base-uncased/ 16 | distilbert-base-uncased 17 | IEMOCAP_full_release/ 18 | .DS_Store 19 | ## Core latex/pdflatex auxiliary files: 20 | *.aux 21 | *.lof 22 | *.log 23 | *.lot 24 | *.fls 25 | *.out 26 | *.toc 27 | *.fmt 28 | *.fot 29 | *.cb 30 | *.cb2 31 | .*.lb 32 | 33 | ## Intermediate documents: 34 | *.dvi 35 | *.xdv 36 | *-converted-to.* 37 | # these rules might exclude image files for figures etc. 38 | # *.ps 39 | # *.eps 40 | # *.pdf 41 | 42 | ## Generated if empty string is given at "Please type another file name for output:" 43 | .pdf 44 | 45 | ## Bibliography auxiliary files (bibtex/biblatex/biber): 46 | *.bbl 47 | *.bcf 48 | *.blg 49 | *-blx.aux 50 | *-blx.bib 51 | *.run.xml 52 | 53 | ## Build tool auxiliary files: 54 | *.fdb_latexmk 55 | *.synctex 56 | *.synctex(busy) 57 | *.synctex.gz 58 | *.synctex.gz(busy) 59 | *.pdfsync 60 | 61 | ## Build tool directories for auxiliary files 62 | # latexrun 63 | latex.out/ 64 | 65 | ## Auxiliary and intermediate files from other packages: 66 | # algorithms 67 | *.alg 68 | *.loa 69 | 70 | # achemso 71 | acs-*.bib 72 | 73 | # amsthm 74 | *.thm 75 | 76 | # beamer 77 | *.nav 78 | *.pre 79 | *.snm 80 | *.vrb 81 | 82 | # changes 83 | *.soc 84 | 85 | # comment 86 | *.cut 87 | 88 | # cprotect 89 | *.cpt 90 | 91 | # elsarticle (documentclass of Elsevier journals) 92 | *.spl 93 | 94 | # endnotes 95 | *.ent 96 | 97 | # fixme 98 | *.lox 99 | 100 | # feynmf/feynmp 101 | *.mf 102 | *.mp 103 | *.t[1-9] 104 | *.t[1-9][0-9] 105 | *.tfm 106 | 107 | #(r)(e)ledmac/(r)(e)ledpar 108 | *.end 109 | *.?end 110 | *.[1-9] 111 | *.[1-9][0-9] 112 | *.[1-9][0-9][0-9] 113 | *.[1-9]R 114 | *.[1-9][0-9]R 115 | *.[1-9][0-9][0-9]R 116 | *.eledsec[1-9] 117 | *.eledsec[1-9]R 118 | *.eledsec[1-9][0-9] 119 | *.eledsec[1-9][0-9]R 120 | *.eledsec[1-9][0-9][0-9] 121 | *.eledsec[1-9][0-9][0-9]R 122 | 123 | # glossaries 124 | *.acn 125 | *.acr 126 | *.glg 127 | *.glo 128 | *.gls 129 | *.glsdefs 130 | *.lzo 131 | *.lzs 132 | 133 | # uncomment this for glossaries-extra (will ignore makeindex's style files!) 
134 | # *.ist 135 | 136 | # gnuplottex 137 | *-gnuplottex-* 138 | 139 | # gregoriotex 140 | *.gaux 141 | *.gtex 142 | 143 | # htlatex 144 | *.4ct 145 | *.4tc 146 | *.idv 147 | *.lg 148 | *.trc 149 | *.xref 150 | 151 | # hyperref 152 | *.brf 153 | 154 | # knitr 155 | *-concordance.tex 156 | # TODO Comment the next line if you want to keep your tikz graphics files 157 | *.tikz 158 | *-tikzDictionary 159 | 160 | # listings 161 | *.lol 162 | 163 | # luatexja-ruby 164 | *.ltjruby 165 | 166 | # makeidx 167 | *.idx 168 | *.ilg 169 | *.ind 170 | 171 | # minitoc 172 | *.maf 173 | *.mlf 174 | *.mlt 175 | *.mtc[0-9]* 176 | *.slf[0-9]* 177 | *.slt[0-9]* 178 | *.stc[0-9]* 179 | 180 | # minted 181 | _minted* 182 | *.pyg 183 | 184 | # morewrites 185 | *.mw 186 | 187 | # nomencl 188 | *.nlg 189 | *.nlo 190 | *.nls 191 | 192 | # pax 193 | *.pax 194 | 195 | # pdfpcnotes 196 | *.pdfpc 197 | 198 | # sagetex 199 | *.sagetex.sage 200 | *.sagetex.py 201 | *.sagetex.scmd 202 | 203 | # scrwfile 204 | *.wrt 205 | 206 | # sympy 207 | *.sout 208 | *.sympy 209 | sympy-plots-for-*.tex/ 210 | 211 | # pdfcomment 212 | *.upa 213 | *.upb 214 | 215 | # pythontex 216 | *.pytxcode 217 | pythontex-files-*/ 218 | 219 | # tcolorbox 220 | *.listing 221 | 222 | # thmtools 223 | *.loe 224 | 225 | # TikZ & PGF 226 | *.dpth 227 | *.md5 228 | *.auxlock 229 | 230 | # todonotes 231 | *.tdo 232 | 233 | # vhistory 234 | *.hst 235 | *.ver 236 | 237 | # easy-todo 238 | *.lod 239 | 240 | # xcolor 241 | *.xcp 242 | 243 | # xmpincl 244 | *.xmpi 245 | 246 | # xindy 247 | *.xdy 248 | 249 | # xypic precompiled matrices and outlines 250 | *.xyc 251 | *.xyd 252 | 253 | # endfloat 254 | *.ttt 255 | *.fff 256 | 257 | # Latexian 258 | TSWLatexianTemp* 259 | 260 | ## Editors: 261 | # WinEdt 262 | *.bak 263 | *.sav 264 | 265 | # Texpad 266 | .texpadtmp 267 | 268 | # LyX 269 | *.lyx~ 270 | 271 | # Kile 272 | *.backup 273 | 274 | # gummi 275 | .*.swp 276 | 277 | # KBibTeX 278 | *~[0-9]* 279 | 280 | # TeXnicCenter 281 | *.tps 282 | 283 | # auto folder when using emacs and auctex 284 | ./auto/* 285 | *.el 286 | 287 | # expex forward references with \gathertags 288 | *-tags.tex 289 | 290 | # standalone packages 291 | *.sta 292 | 293 | # Makeindex log files 294 | *.lpz 295 | -------------------------------------------------------------------------------- /EmotionFlow.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fpcsong/emotionflow/dd0ac8a7e8f6a578627b02b13344a76088daff11/EmotionFlow.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## EmotionFlow 2 | ------ 3 | Source code for ICASSP2022 paper: EmotionFlow: Capture the Dialogue Level Emotion Transitions 4 | 5 | ### Required Packages: 6 | ------ 7 | transformers=4.14.1 8 | 9 | torch=1.8 10 | 11 | vocab=0.0.5 12 | 13 | numpy 14 | 15 | tqdm 16 | 17 | sklearn 18 | 19 | pickle 20 | 21 | pandas 22 | 23 | 24 | ### Quick start: 25 | ------ 26 | download MELD dataset from https://github.com/declare-lab/MELD/ and save to ./MELD 27 | 28 | #### Training 29 | ------ 30 | ``` 31 | python train.py -tr -wp 0 -bsz 1 -acc_step 8 -lr 1e-4 -ptmlr 1e-5 -dpt 0.3 -bert_path roberta-[base, large] -epochs [20, 5] 32 | ``` 33 | 34 | #### Evaluation 35 | ------ 36 | ``` 37 | python train.py -te -ft -bsz 1 -dpt 0.3 -bert_path roberta-[base, large] 38 | ``` 39 | 40 | #### Results 41 | ------ 42 | 43 | | model | weighted-F1 | 
Checkpoint | 44 | | ------------------------- | ----------- | ------------------------------------------------------------ | 45 | | EmotionFlow-roberta-base | 65.05 | [roberta-base-meld.pkl](https://drive.google.com/file/d/13tTwxFbfO2ZaNJfic3F2AGATzU6ilA5C/view?usp=sharing) | 46 | | EmotionFlow-roberta-large | 66.50 | [roberta-large-meld.pkl](https://drive.google.com/file/d/1zdS4SEvAzR5aVJ852zyaW4IzStQG6fvU/view?usp=sharing) | 47 | 48 | Checkpoints are produced on a single V100 GPU. 49 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModel 2 | import vocab 3 | import os 4 | import torch 5 | import logging 6 | import numpy as np 7 | import json 8 | import functools 9 | import random 10 | import operator 11 | import multiprocessing 12 | from sklearn.metrics import f1_score 13 | import copy 14 | from tqdm import tqdm 15 | from torch.utils.data import ( 16 | DataLoader, 17 | Dataset, 18 | RandomSampler, 19 | SequentialSampler, 20 | TensorDataset 21 | ) 22 | from crf import * 23 | from collections import OrderedDict as odict 24 | import pickle 25 | import pandas as pd 26 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 27 | logging.basicConfig(level=logging.INFO) 28 | 29 | CONFIG = { 30 | 'bert_path': 'roberta-base', 31 | 'epochs' : 20, 32 | 'lr' : 1e-4, 33 | 'ptmlr' : 5e-6, 34 | 'batch_size' : 1, 35 | 'max_len' : 256, 36 | 'max_value_list' : 16, 37 | 'bert_dim' : 1024, 38 | 'pad_value' : 1, 39 | 'shift' : 1024, 40 | 'dropout' : 0.3, 41 | 'p_unk': 0.1, 42 | 'data_splits' : 20, 43 | 'num_classes' : 7, 44 | 'wp' : 1, 45 | 'wp_pretrain' : 5, 46 | 'data_path' : './MELD/data/MELD/', 47 | 'accumulation_steps' : 8, 48 | 'rnn_layers' : 2, 49 | 'tf_rate': 0.8, 50 | 'aux_loss_weight': 0.3, 51 | 'ngpus' : torch.cuda.device_count(), 52 | 'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 53 | } 54 | 55 | tokenizer = AutoTokenizer.from_pretrained(CONFIG['bert_path']) 56 | _special_tokens_ids = tokenizer('')['input_ids'] 57 | CLS = _special_tokens_ids[0] 58 | SEP = _special_tokens_ids[1] 59 | CONFIG['CLS'] = CLS 60 | CONFIG['SEP'] = SEP 61 | -------------------------------------------------------------------------------- /crf.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.7.2' 2 | ''' 3 | modified from https://github.com/kmkurn/pytorch-crf 4 | ''' 5 | from typing import List, Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class CRF(nn.Module): 12 | """Conditional random field. 13 | 14 | This module implements a conditional random field [LMP01]_. The forward computation 15 | of this class computes the log likelihood of the given sequence of tags and 16 | emission score tensor. This class also has `~CRF.decode` method which finds 17 | the best tag sequence given an emission score tensor using `Viterbi algorithm`_. 18 | 19 | Args: 20 | num_tags: Number of tags. 21 | batch_first: Whether the first dimension corresponds to the size of a minibatch. 22 | 23 | Attributes: 24 | start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size 25 | ``(num_tags,)``. 26 | end_transitions (`~torch.nn.Parameter`): End transition score tensor of size 27 | ``(num_tags,)``. 28 | transitions (`~torch.nn.Parameter`): Transition score tensor of size 29 | ``(num_tags, num_tags)``. 30 | 31 | 32 | .. 
[LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001). 33 | "Conditional random fields: Probabilistic models for segmenting and 34 | labeling sequence data". *Proc. 18th International Conf. on Machine 35 | Learning*. Morgan Kaufmann. pp. 282–289. 36 | 37 | .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm 38 | """ 39 | 40 | def __init__(self, num_tags: int, batch_first: bool = False) -> None: 41 | if num_tags <= 0: 42 | raise ValueError(f'invalid number of tags: {num_tags}') 43 | super().__init__() 44 | self.num_tags = num_tags 45 | self.batch_first = batch_first 46 | self.start_transitions = nn.Parameter(torch.empty(num_tags)) 47 | self.end_transitions = nn.Parameter(torch.empty(num_tags)) 48 | self.global_transitions = nn.Parameter(torch.empty(num_tags, num_tags)) 49 | self.reset_parameters() 50 | 51 | def reset_parameters(self) -> None: 52 | """Initialize the transition parameters. 53 | 54 | The parameters will be initialized randomly from a uniform distribution 55 | between -0.1 and 0.1. 56 | """ 57 | nn.init.uniform_(self.start_transitions, -0.1, 0.1) 58 | nn.init.uniform_(self.end_transitions, -0.1, 0.1) 59 | nn.init.uniform_(self.global_transitions, -0.1, 0.1) 60 | 61 | 62 | def __repr__(self) -> str: 63 | return f'{self.__class__.__name__}(num_tags={self.num_tags})' 64 | 65 | def forward( 66 | self, 67 | emissions: torch.Tensor, 68 | tags: torch.LongTensor, 69 | mask: Optional[torch.ByteTensor] = None, 70 | reduction: str = 'sum', 71 | ) -> torch.Tensor: 72 | """Compute the conditional log likelihood of a sequence of tags given emission scores. 73 | 74 | Args: 75 | emissions (`~torch.Tensor`): Emission score tensor of size 76 | ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, 77 | ``(batch_size, seq_length, num_tags)`` otherwise. 78 | tags (`~torch.LongTensor`): Sequence of tags tensor of size 79 | ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, 80 | ``(batch_size, seq_length)`` otherwise. 81 | mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` 82 | if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. 83 | reduction: Specifies the reduction to apply to the output: 84 | ``none|sum|mean|token_mean``. ``none``: no reduction will be applied. 85 | ``sum``: the output will be summed over batches. ``mean``: the output will be 86 | averaged over batches. ``token_mean``: the output will be averaged over tokens. 87 | 88 | Returns: 89 | `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if 90 | reduction is ``none``, ``()`` otherwise. 
91 | """ 92 | self._validate(emissions, tags=tags, mask=mask) 93 | if reduction not in ('none', 'sum', 'mean', 'token_mean'): 94 | raise ValueError(f'invalid reduction: {reduction}') 95 | if mask is None: 96 | mask = torch.ones_like(tags, dtype=torch.uint8) 97 | 98 | if self.batch_first: 99 | emissions = emissions.transpose(0, 1) 100 | tags = tags.transpose(0, 1) 101 | mask = mask.transpose(0, 1) 102 | 103 | # shape: (batch_size,) 104 | numerator = self._compute_score(emissions, tags, mask) 105 | # shape: (batch_size,) 106 | denominator = self._compute_normalizer(emissions, mask) 107 | # shape: (batch_size,) 108 | llh = numerator - denominator 109 | 110 | if reduction == 'none': 111 | return llh 112 | if reduction == 'sum': 113 | return llh.sum() 114 | if reduction == 'mean': 115 | return llh.mean() 116 | assert reduction == 'token_mean' 117 | return llh.sum() / mask.float().sum() 118 | 119 | def decode( 120 | self, emissions: torch.Tensor, 121 | mask: Optional[torch.ByteTensor] = None) -> List[List[int]]: 122 | """Find the most likely tag sequence using Viterbi algorithm. 123 | 124 | Args: 125 | emissions (`~torch.Tensor`): Emission score tensor of size 126 | ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, 127 | ``(batch_size, seq_length, num_tags)`` otherwise. 128 | mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` 129 | if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. 130 | 131 | Returns: 132 | List of list containing the best tag sequence for each batch. 133 | """ 134 | self._validate(emissions, mask=mask) 135 | if mask is None: 136 | mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8) 137 | 138 | if self.batch_first: 139 | emissions = emissions.transpose(0, 1) 140 | mask = mask.transpose(0, 1) 141 | 142 | return self._viterbi_decode(emissions, mask) 143 | 144 | def _validate( 145 | self, 146 | emissions: torch.Tensor, 147 | tags: Optional[torch.LongTensor] = None, 148 | mask: Optional[torch.ByteTensor] = None) -> None: 149 | if emissions.dim() != 3: 150 | raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}') 151 | if emissions.size(2) != self.num_tags: 152 | raise ValueError( 153 | f'expected last dimension of emissions is {self.num_tags}, ' 154 | f'got {emissions.size(2)}') 155 | 156 | if tags is not None: 157 | if emissions.shape[:2] != tags.shape: 158 | raise ValueError( 159 | 'the first two dimensions of emissions and tags must match, ' 160 | f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}') 161 | 162 | if mask is not None: 163 | if emissions.shape[:2] != mask.shape: 164 | raise ValueError( 165 | 'the first two dimensions of emissions and mask must match, ' 166 | f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}') 167 | no_empty_seq = not self.batch_first and mask[0].all() 168 | no_empty_seq_bf = self.batch_first and mask[:, 0].all() 169 | if not no_empty_seq and not no_empty_seq_bf: 170 | raise ValueError('mask of the first timestep must all be on') 171 | 172 | def _compute_score( 173 | self, 174 | emissions: torch.Tensor, 175 | tags: torch.LongTensor, 176 | mask: torch.ByteTensor) -> torch.Tensor: 177 | # emissions: (seq_length, batch_size, num_tags) 178 | # tags: (seq_length, batch_size) 179 | # speakers : (seq_length, batch_size) 180 | # last_turns: (seq_length, batch_size) last turn for the current speaker 181 | # mask: (seq_length, batch_size) 182 | assert emissions.dim() == 3 and tags.dim() == 2 183 | assert emissions.shape[:2] == tags.shape 
184 | assert emissions.size(2) == self.num_tags 185 | assert mask.shape == tags.shape 186 | assert mask[0].all() 187 | 188 | seq_length, batch_size = tags.shape 189 | mask = mask.float() 190 | 191 | # Start transition score and first emission 192 | # shape: (batch_size,) 193 | # st_transitions = torch.softmax(self.start_transitions, -1) 194 | # ed_transitions = torch.softmax(self.end_transitions, -1) 195 | # transitions = torch.softmax(self.transitions, -1) 196 | # emissions = torch.softmax(emissions, -1) 197 | # personal_transitions = torch.softmax(self.personal_transitions, -1) 198 | st_transitions = self.start_transitions 199 | ed_transitions = self.end_transitions 200 | score = st_transitions[tags[0]] 201 | score += emissions[0, torch.arange(batch_size), tags[0]] 202 | for i in range(1, seq_length): 203 | # Transition score to next tag, only added if next timestep is valid (mask == 1) 204 | # shape: (batch_size,) 205 | global_transitions = self.global_transitions[tags[i - 1], tags[i]] 206 | score += global_transitions * mask[i] 207 | 208 | # Emission score for next tag, only added if next timestep is valid (mask == 1) 209 | # shape: (batch_size,) 210 | score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i] 211 | 212 | # End transition score 213 | # shape: (batch_size,) 214 | seq_ends = mask.long().sum(dim=0) - 1 215 | # shape: (batch_size,) 216 | last_tags = tags[seq_ends, torch.arange(batch_size)] 217 | # shape: (batch_size,) 218 | score += ed_transitions[last_tags] 219 | 220 | return score 221 | 222 | def _compute_normalizer( 223 | self, emissions: torch.Tensor, 224 | mask: torch.ByteTensor) -> torch.Tensor: 225 | # emissions: (seq_length, batch_size, num_tags) 226 | # mask: (seq_length, batch_size) 227 | assert emissions.dim() == 3 and mask.dim() == 2 228 | assert emissions.shape[:2] == mask.shape 229 | assert emissions.size(2) == self.num_tags 230 | assert mask[0].all() 231 | 232 | seq_length = emissions.size(0) 233 | batch_size = emissions.size(1) 234 | 235 | st_transitions = self.start_transitions 236 | ed_transitions = self.end_transitions 237 | score = st_transitions + emissions[0] 238 | scores = [] 239 | scores.append(score) 240 | for i in range(1, seq_length): 241 | # Broadcast score for every possible next tag 242 | # shape: (batch_size, num_tags, 1) 243 | broadcast_score = score.unsqueeze(2) 244 | 245 | # Broadcast emission score for every possible current tag 246 | # shape: (batch_size, 1, num_tags) 247 | broadcast_emissions = emissions[i].unsqueeze(1) 248 | 249 | global_transitions = self.global_transitions 250 | # shape: (batch_size, num_tags, num_tags) 251 | next_score = broadcast_score + global_transitions + broadcast_emissions 252 | 253 | next_score = torch.logsumexp(next_score, dim=1) 254 | 255 | score = torch.where(mask[i].unsqueeze(1), next_score, score) 256 | 257 | scores.append(score) 258 | 259 | # End transition score 260 | # shape: (batch_size, num_tags) 261 | score += ed_transitions 262 | 263 | # Sum (log-sum-exp) over all possible tags 264 | # shape: (batch_size,) 265 | return torch.logsumexp(score, dim=1) 266 | 267 | def _viterbi_decode( 268 | self, emissions: torch.FloatTensor, 269 | mask: torch.ByteTensor) -> List[List[int]]: 270 | # emissions: (seq_length, batch_size, num_tags) 271 | # mask: (seq_length, batch_size) 272 | assert emissions.dim() == 3 and mask.dim() == 2 273 | assert emissions.shape[:2] == mask.shape 274 | assert emissions.size(2) == self.num_tags 275 | assert mask[0].all() 276 | 277 | seq_length, batch_size = mask.shape 278 | 
279 | # Start transition and first emission 280 | # shape: (batch_size, num_tags) 281 | st_transitions = self.start_transitions 282 | ed_transitions = self.end_transitions 283 | score = st_transitions + emissions[0] 284 | history = [] 285 | 286 | # score is a tensor of size (batch_size, num_tags) where for every batch, 287 | # value at column j stores the score of the best tag sequence so far that ends 288 | # with tag j 289 | # history saves where the best tags candidate transitioned from; this is used 290 | # when we trace back the best tag sequence 291 | 292 | # Viterbi algorithm recursive case: we compute the score of the best tag sequence 293 | # for every possible next tag 294 | scores = [] 295 | scores.append(score) 296 | for i in range(1, seq_length): 297 | # Broadcast viterbi score for every possible next tag 298 | # shape: (batch_size, num_tags, 1) 299 | broadcast_score = score.unsqueeze(2) 300 | 301 | # Broadcast emission score for every possible current tag 302 | # shape: (batch_size, 1, num_tags) 303 | broadcast_emissions = emissions[i].unsqueeze(1) 304 | 305 | # Compute the score tensor of size (batch_size, num_tags, num_tags) where 306 | # for each sample, entry at row i and column j stores the score of the best 307 | # tag sequence so far that ends with transitioning from tag i to tag j and emitting 308 | # shape: (batch_size, num_tags, num_tags) 309 | global_transitions = self.global_transitions 310 | 311 | next_score = broadcast_score + global_transitions + broadcast_emissions 312 | 313 | # Find the maximum score over all possible current tag 314 | # shape: (batch_size, num_tags) 315 | next_score, indices = next_score.max(dim=1) 316 | 317 | # Set score to the next score if this timestep is valid (mask == 1) 318 | # and save the index that produces the next score 319 | # shape: (batch_size, num_tags) 320 | score = torch.where(mask[i].unsqueeze(1), next_score, score) 321 | scores.append(score) 322 | history.append(indices) 323 | 324 | # End transition score 325 | # shape: (batch_size, num_tags) 326 | score += ed_transitions 327 | 328 | # Now, compute the best path for each sample 329 | 330 | # shape: (batch_size,) 331 | seq_ends = mask.long().sum(dim=0) - 1 332 | best_tags_list = [] 333 | 334 | for idx in range(batch_size): 335 | # Find the tag which maximizes the score at the last timestep; this is our best tag 336 | # for the last timestep 337 | _, best_last_tag = score[idx].max(dim=0) 338 | best_tags = [best_last_tag.item()] 339 | 340 | # We trace back where the best last tag comes from, append that to our best tag 341 | # sequence, and trace it back again, and so on 342 | for hist in reversed(history[:seq_ends[idx]]): 343 | best_last_tag = hist[idx][best_tags[-1]] 344 | best_tags.append(best_last_tag.item()) 345 | 346 | # Reverse the order because we start from the last timestep 347 | best_tags.reverse() 348 | best_tags_list.append(best_tags) 349 | 350 | return best_tags_list -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModel 3 | from crf import * 4 | 5 | 6 | class CRFModel(nn.Module): 7 | def __init__(self, config): 8 | super().__init__() 9 | self.config = config 10 | self.dropout = config['dropout'] 11 | self.num_classes = config['num_classes'] 12 | self.pad_value = config['pad_value'] 13 | self.CLS = config['CLS'] 14 | self.context_encoder = AutoModel.from_pretrained( 15 | 
config['bert_path']) 16 | self.dim = self.context_encoder.embeddings.word_embeddings.weight.data.shape[-1] 17 | self.spk_embeddings = nn.Embedding(300, self.dim) 18 | self.crf_layer = CRF(self.num_classes) 19 | self.emission = nn.Linear(self.dim, self.num_classes) 20 | self.loss_func = torch.nn.CrossEntropyLoss(ignore_index=-1) 21 | def device(self): 22 | return self.context_encoder.device 23 | def forward(self, sentences, sentences_mask, speaker_ids, last_turns, emotion_idxs=None): 24 | ''' 25 | sentences: batch * max_turns * max_length 26 | speaker_ids: batch * max_turns 27 | emotion[optional] : batch * max_turns 28 | ''' 29 | batch_size = sentences.shape[0] 30 | max_turns = sentences.shape[1] 31 | max_len = sentences.shape[-1] 32 | speaker_ids = speaker_ids.reshape(batch_size * max_turns, -1) 33 | sentences = sentences.reshape(batch_size * max_turns, -1) 34 | cls_id = torch.ones_like(speaker_ids) * self.CLS 35 | input_ids = torch.cat((cls_id, sentences), 1) 36 | mask = 1 - (input_ids == (self.pad_value)).long() 37 | # with torch.no_grad(): 38 | utterance_encoded = self.context_encoder( 39 | input_ids=input_ids, 40 | attention_mask=mask, 41 | output_hidden_states=True, 42 | return_dict=True 43 | )['last_hidden_state'] 44 | mask_pos = mask.sum(1)-2 45 | features = utterance_encoded[torch.arange(mask_pos.shape[0]), mask_pos, :] 46 | emissions = self.emission(features) 47 | crf_emissions = emissions.reshape(batch_size, max_turns, -1) 48 | crf_emissions = crf_emissions.transpose(0, 1) 49 | sentences_mask = sentences_mask.transpose(0, 1) 50 | speaker_ids = speaker_ids.reshape(batch_size, max_turns).transpose(0, 1) 51 | last_turns = last_turns.transpose(0, 1) 52 | # train 53 | if emotion_idxs is not None: 54 | emotion_idxs = emotion_idxs.transpose(0, 1) 55 | loss1 = -self.crf_layer(crf_emissions, emotion_idxs, mask=sentences_mask) 56 | # 接上分类loss让CRF专注序列信息 57 | loss2 = self.loss_func(emissions.view(-1, self.num_classes), emotion_idxs.view(-1)) 58 | loss = loss1 + loss2 59 | return loss 60 | # test 61 | else: 62 | return self.crf_layer.decode(crf_emissions, mask=sentences_mask) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from config import * 2 | from model import CRFModel 3 | 4 | speaker_vocab_dict_path = 'vocabs/speaker_vocab.pkl' 5 | emotion_vocab_dict_path = 'vocabs/emotion_vocab.pkl' 6 | sentiment_vocab_dict_path = 'vocabs/sentiment_vocab.pkl' 7 | 8 | 9 | def pad_to_len(list_data, max_len, pad_value): 10 | list_data = list_data[-max_len:] 11 | len_to_pad = max_len-len(list_data) 12 | pads = [pad_value] * len_to_pad 13 | list_data.extend(pads) 14 | return list_data 15 | 16 | 17 | def get_vocabs(file_paths, addi_file_path): 18 | speaker_vocab = vocab.UnkVocab() 19 | emotion_vocab = vocab.Vocab() 20 | sentiment_vocab = vocab.Vocab() 21 | # 保证neutral 在第0类 22 | emotion_vocab.word2index('neutral', train=True) 23 | # global speaker_vocab, emotion_vocab 24 | for file_path in file_paths: 25 | data = pd.read_csv(file_path) 26 | for row in tqdm(data.iterrows(), desc='get vocab from {}'.format(file_path)): 27 | meta = row[1] 28 | emotion = meta['Emotion'].lower() 29 | emotion_vocab.word2index(emotion, train=True) 30 | additional_data = json.load(open(addi_file_path, 'r')) 31 | for episode_id in additional_data: 32 | for scene in additional_data.get(episode_id): 33 | for utterance in scene['utterances']: 34 | speaker = utterance['speakers'][0].lower() 35 | 
speaker_vocab.word2index(speaker, train=True) 36 | speaker_vocab = speaker_vocab.prune_by_count(1000) 37 | speakers = list(speaker_vocab.counts.keys()) 38 | speaker_vocab = vocab.UnkVocab() 39 | for speaker in speakers: 40 | speaker_vocab.word2index(speaker, train=True) 41 | 42 | logging.info('total {} speakers'.format(len(speaker_vocab.counts.keys()))) 43 | torch.save(emotion_vocab.to_dict(), emotion_vocab_dict_path) 44 | torch.save(speaker_vocab.to_dict(), speaker_vocab_dict_path) 45 | torch.save(sentiment_vocab.to_dict(), sentiment_vocab_dict_path) 46 | 47 | def load_emorynlp_and_builddataset(file_path, train=False): 48 | speaker_vocab = vocab.UnkVocab.from_dict(torch.load( 49 | speaker_vocab_dict_path 50 | )) 51 | emotion_vocab = vocab.Vocab.from_dict(torch.load( 52 | emotion_vocab_dict_path 53 | )) 54 | data = pd.read_csv(file_path) 55 | ret_utterances = [] 56 | ret_speaker_ids = [] 57 | ret_emotion_idxs = [] 58 | utterances = [] 59 | full_contexts = [] 60 | speaker_ids = [] 61 | emotion_idxs = [] 62 | sentiment_idxs = [] 63 | pre_dial_id = -1 64 | max_turns = 0 65 | for row in tqdm(data.iterrows(), desc='processing file {}'.format(file_path)): 66 | meta = row[1] 67 | utterance = meta['Utterance'].lower().replace( 68 | '’', '\'').replace("\"", '') 69 | speaker = meta['Speaker'].lower() 70 | utterance = speaker + ' says:, ' + utterance 71 | emotion = meta['Emotion'].lower() 72 | dialogue_id = meta['Scene_ID'] 73 | utterance_id = meta['Utterance_ID'] 74 | if pre_dial_id == -1: 75 | pre_dial_id = dialogue_id 76 | if dialogue_id != pre_dial_id: 77 | ret_utterances.append(full_contexts) 78 | ret_speaker_ids.append(speaker_ids) 79 | ret_emotion_idxs.append(emotion_idxs) 80 | max_turns = max(max_turns, len(utterances)) 81 | utterances = [] 82 | full_contexts = [] 83 | speaker_ids = [] 84 | emotion_idxs = [] 85 | pre_dial_id = dialogue_id 86 | speaker_id = speaker_vocab.word2index(speaker) 87 | emotion_idx = emotion_vocab.word2index(emotion) 88 | token_ids = tokenizer(utterance, add_special_tokens=False)[ 89 | 'input_ids'] + [CONFIG['SEP']] 90 | full_context = [] 91 | if len(utterances) > 0: 92 | context = utterances[-3:] 93 | for pre_uttr in context: 94 | full_context += pre_uttr 95 | full_context += token_ids 96 | # query 97 | query = speaker + ' feels ' 98 | query_ids = [CONFIG['SEP']] + tokenizer(query, add_special_tokens=False)['input_ids'] + [CONFIG['SEP']] 99 | full_context += query_ids 100 | 101 | full_context = pad_to_len( 102 | full_context, CONFIG['max_len'], CONFIG['pad_value']) 103 | # + CONFIG['shift'] 104 | utterances.append(token_ids) 105 | full_contexts.append(full_context) 106 | speaker_ids.append(speaker_id) 107 | emotion_idxs.append(emotion_idx) 108 | 109 | pad_utterance = [CONFIG['SEP']] + tokenizer( 110 | "1", 111 | add_special_tokens=False 112 | )['input_ids'] + [CONFIG['SEP']] 113 | pad_utterance = pad_to_len( 114 | pad_utterance, CONFIG['max_len'], CONFIG['pad_value']) 115 | # for CRF 116 | ret_mask = [] 117 | ret_last_turns = [] 118 | for dial_id, utterances in tqdm(enumerate(ret_utterances), desc='build dataset'): 119 | mask = [1] * len(utterances) 120 | while len(utterances) < max_turns: 121 | utterances.append(pad_utterance) 122 | ret_emotion_idxs[dial_id].append(-1) 123 | ret_speaker_ids[dial_id].append(0) 124 | mask.append(0) 125 | ret_mask.append(mask) 126 | ret_utterances[dial_id] = utterances 127 | 128 | last_turns = [-1] * max_turns 129 | for turn_id in range(max_turns): 130 | curr_spk = ret_speaker_ids[dial_id][turn_id] 131 | if curr_spk == 0: 132 | break 
133 | for idx in range(0, turn_id): 134 | if curr_spk == ret_speaker_ids[dial_id][idx]: 135 | last_turns[turn_id] = idx 136 | ret_last_turns.append(last_turns) 137 | dataset = TensorDataset( 138 | torch.LongTensor(ret_utterances), 139 | torch.LongTensor(ret_speaker_ids), 140 | torch.LongTensor(ret_emotion_idxs), 141 | torch.ByteTensor(ret_mask), 142 | torch.LongTensor(ret_last_turns) 143 | ) 144 | return dataset 145 | 146 | 147 | def load_meld_and_builddataset(file_path, train=False): 148 | speaker_vocab = vocab.UnkVocab.from_dict(torch.load( 149 | speaker_vocab_dict_path 150 | )) 151 | emotion_vocab = vocab.Vocab.from_dict(torch.load( 152 | emotion_vocab_dict_path 153 | )) 154 | 155 | data = pd.read_csv(file_path) 156 | ret_utterances = [] 157 | ret_speaker_ids = [] 158 | ret_emotion_idxs = [] 159 | utterances = [] 160 | full_contexts = [] 161 | speaker_ids = [] 162 | emotion_idxs = [] 163 | pre_dial_id = -1 164 | max_turns = 0 165 | for row in tqdm(data.iterrows(), desc='processing file {}'.format(file_path)): 166 | meta = row[1] 167 | utterance = meta['Utterance'].replace( 168 | '’', '\'').replace("\"", '') 169 | speaker = meta['Speaker'] 170 | utterance = speaker + ' says:, ' + utterance 171 | emotion = meta['Emotion'].lower() 172 | dialogue_id = meta['Dialogue_ID'] 173 | utterance_id = meta['Utterance_ID'] 174 | if pre_dial_id == -1: 175 | pre_dial_id = dialogue_id 176 | if dialogue_id != pre_dial_id: 177 | ret_utterances.append(full_contexts) 178 | ret_speaker_ids.append(speaker_ids) 179 | ret_emotion_idxs.append(emotion_idxs) 180 | max_turns = max(max_turns, len(utterances)) 181 | utterances = [] 182 | full_contexts = [] 183 | speaker_ids = [] 184 | emotion_idxs = [] 185 | pre_dial_id = dialogue_id 186 | speaker_id = speaker_vocab.word2index(speaker) 187 | emotion_idx = emotion_vocab.word2index(emotion) 188 | token_ids = tokenizer(utterance, add_special_tokens=False)[ 189 | 'input_ids'] + [CONFIG['SEP']] 190 | full_context = [] 191 | if len(utterances) > 0: 192 | context = utterances[-3:] 193 | for pre_uttr in context: 194 | full_context += pre_uttr 195 | full_context += token_ids 196 | # query 197 | query = 'Now ' + speaker + ' feels ' 198 | query_ids = tokenizer(query, add_special_tokens=False)['input_ids'] + [CONFIG['SEP']] 199 | full_context += query_ids 200 | 201 | full_context = pad_to_len( 202 | full_context, CONFIG['max_len'], CONFIG['pad_value']) 203 | # + CONFIG['shift'] 204 | utterances.append(token_ids) 205 | full_contexts.append(full_context) 206 | speaker_ids.append(speaker_id) 207 | emotion_idxs.append(emotion_idx) 208 | 209 | pad_utterance = [CONFIG['SEP']] + tokenizer( 210 | "1", 211 | add_special_tokens=False 212 | )['input_ids'] + [CONFIG['SEP']] 213 | pad_utterance = pad_to_len( 214 | pad_utterance, CONFIG['max_len'], CONFIG['pad_value']) 215 | # for CRF 216 | ret_mask = [] 217 | ret_last_turns = [] 218 | for dial_id, utterances in tqdm(enumerate(ret_utterances), desc='build dataset'): 219 | mask = [1] * len(utterances) 220 | while len(utterances) < max_turns: 221 | utterances.append(pad_utterance) 222 | ret_emotion_idxs[dial_id].append(-1) 223 | ret_speaker_ids[dial_id].append(0) 224 | mask.append(0) 225 | ret_mask.append(mask) 226 | ret_utterances[dial_id] = utterances 227 | 228 | last_turns = [-1] * max_turns 229 | for turn_id in range(max_turns): 230 | curr_spk = ret_speaker_ids[dial_id][turn_id] 231 | if curr_spk == 0: 232 | break 233 | for idx in range(0, turn_id): 234 | if curr_spk == ret_speaker_ids[dial_id][idx]: 235 | last_turns[turn_id] = idx 236 | 
ret_last_turns.append(last_turns) 237 | dataset = TensorDataset( 238 | torch.LongTensor(ret_utterances), 239 | torch.LongTensor(ret_speaker_ids), 240 | torch.LongTensor(ret_emotion_idxs), 241 | torch.ByteTensor(ret_mask), 242 | torch.LongTensor(ret_last_turns) 243 | ) 244 | return dataset 245 | 246 | def get_paramsgroup(model, warmup=False): 247 | no_decay = ['bias', 'LayerNorm.weight'] 248 | pre_train_lr = CONFIG['ptmlr'] 249 | ''' 250 | frozen_params = [] 251 | frozen_layers = [3,4,5,6,7,8] 252 | for layer_idx in frozen_layers: 253 | frozen_params.extend( 254 | list(map(id, model.context_encoder.encoder.layer[layer_idx].parameters())) 255 | ) 256 | ''' 257 | bert_params = list(map(id, model.context_encoder.parameters())) 258 | crf_params = list(map(id, model.crf_layer.parameters())) 259 | params = [] 260 | warmup_params = [] 261 | for name, param in model.named_parameters(): 262 | # if id(param) in frozen_params: 263 | # continue 264 | lr = CONFIG['lr'] 265 | weight_decay = 0 266 | if id(param) in bert_params: 267 | lr = pre_train_lr 268 | if id(param) in crf_params: 269 | lr = CONFIG['lr'] * 10 270 | if not any(nd in name for nd in no_decay): 271 | weight_decay = 0 272 | params.append( 273 | { 274 | 'params': param, 275 | 'lr': lr, 276 | 'weight_decay': weight_decay 277 | } 278 | ) 279 | # warmup的时候不考虑bert 280 | warmup_params.append( 281 | { 282 | 'params': param, 283 | 'lr': 0 if id(param) in bert_params else lr, 284 | 'weight_decay': weight_decay 285 | } 286 | ) 287 | if warmup: 288 | return warmup_params 289 | params = sorted(params, key=lambda x: x['lr'], reverse=True) 290 | return params 291 | 292 | 293 | def train_epoch(model, optimizer, data, epoch_num=0, max_step=-1): 294 | 295 | loss_func = torch.nn.CrossEntropyLoss(ignore_index=-1) 296 | sampler = RandomSampler(data) 297 | dataloader = DataLoader( 298 | data, 299 | batch_size=CONFIG['batch_size'], 300 | sampler=sampler, 301 | num_workers=0 # multiprocessing.cpu_count() 302 | ) 303 | tq_train = tqdm(total=len(dataloader), position=1) 304 | accumulation_steps = CONFIG['accumulation_steps'] 305 | 306 | for batch_id, batch_data in enumerate(dataloader): 307 | batch_data = [x.to(model.device()) for x in batch_data] 308 | sentences = batch_data[0] 309 | speaker_ids = batch_data[1] 310 | emotion_idxs = batch_data[2] 311 | mask = batch_data[3] 312 | last_turns = batch_data[4] 313 | outputs = model(sentences, mask, speaker_ids, last_turns, emotion_idxs) 314 | loss = outputs 315 | # loss += loss_func(outputs[3], sentiment_idxs) 316 | tq_train.set_description('loss is {:.2f}'.format(loss.item())) 317 | tq_train.update() 318 | loss = loss / accumulation_steps 319 | loss.backward() 320 | if batch_id % accumulation_steps == 0: 321 | optimizer.step() 322 | optimizer.zero_grad() 323 | # torch.cuda.empty_cache() 324 | tq_train.close() 325 | 326 | 327 | def test(model, data): 328 | 329 | pred_list = [] 330 | hidden_pred_list = [] 331 | selection_list = [] 332 | y_true_list = [] 333 | model.eval() 334 | sampler = SequentialSampler(data) 335 | dataloader = DataLoader( 336 | data, 337 | batch_size=CONFIG['batch_size'], 338 | sampler=sampler, 339 | num_workers=0, # multiprocessing.cpu_count() 340 | ) 341 | tq_test = tqdm(total=len(dataloader), desc="testing", position=2) 342 | for batch_id, batch_data in enumerate(dataloader): 343 | batch_data = [x.to(model.device()) for x in batch_data] 344 | sentences = batch_data[0] 345 | speaker_ids = batch_data[1] 346 | emotion_idxs = batch_data[2].cpu().numpy().tolist() 347 | mask = batch_data[3] 348 | 
last_turns = batch_data[4] 349 | outputs = model(sentences, mask, speaker_ids, last_turns) 350 | for batch_idx in range(mask.shape[0]): 351 | for seq_idx in range(mask.shape[1]): 352 | if mask[batch_idx][seq_idx]: 353 | pred_list.append(outputs[batch_idx][seq_idx]) 354 | y_true_list.append(emotion_idxs[batch_idx][seq_idx]) 355 | tq_test.update() 356 | F1 = f1_score(y_true=y_true_list, y_pred=pred_list, average='weighted') 357 | model.train() 358 | return F1 359 | 360 | 361 | def train(model, train_data_path, dev_data_path, test_data_path): 362 | if CONFIG['task_name'] == 'meld': 363 | devset = load_meld_and_builddataset(dev_data_path) 364 | testset = load_meld_and_builddataset(test_data_path) 365 | trainset = load_meld_and_builddataset(train_data_path) 366 | else: 367 | devset = load_emorynlp_and_builddataset(dev_data_path) 368 | testset = load_emorynlp_and_builddataset(test_data_path) 369 | trainset = load_emorynlp_and_builddataset(train_data_path) 370 | 371 | 372 | # warmup 373 | optimizer = torch.optim.AdamW(get_paramsgroup(model, warmup=True)) 374 | for epoch in range(CONFIG['wp']): 375 | train_epoch(model, optimizer, trainset, epoch_num=epoch) 376 | torch.cuda.empty_cache() 377 | f1 = test(model, devset) 378 | torch.cuda.empty_cache() 379 | print('f1 on dev @ warmup epoch {} is {:.4f}'.format( 380 | epoch, f1), flush=True) 381 | # train 382 | optimizer = torch.optim.AdamW(get_paramsgroup(model)) 383 | lr_scheduler = torch.optim.lr_scheduler.StepLR( 384 | optimizer, step_size=1, gamma=0.9) 385 | best_f1 = -1 386 | tq_epoch = tqdm(total=CONFIG['epochs'], position=0) 387 | for epoch in range(CONFIG['epochs']): 388 | tq_epoch.set_description('training on epoch {}'.format(epoch)) 389 | tq_epoch.update() 390 | train_epoch(model, optimizer, trainset, epoch_num=epoch) 391 | torch.cuda.empty_cache() 392 | f1 = test(model, devset) 393 | torch.cuda.empty_cache() 394 | print('f1 on dev @ epoch {} is {:.4f}'.format(epoch, f1), flush=True) 395 | # ''' 396 | if f1 > best_f1: 397 | best_f1 = f1 398 | torch.save(model, 399 | 'models/f1_{:.4f}_@epoch{}.pkl' 400 | .format(best_f1, epoch)) 401 | if lr_scheduler.get_last_lr()[0] > 1e-5: 402 | lr_scheduler.step() 403 | f1 = test(model, testset) 404 | print('f1 on test @ epoch {} is {:.4f}'.format(epoch, f1), flush=True) 405 | # f1 = test(model, test_on_trainset) 406 | # print('f1 on train @ epoch {} is {:.4f}'.format(epoch, f1), flush=True) 407 | # ''' 408 | tq_epoch.close() 409 | lst = os.listdir('./models') 410 | lst = list(filter(lambda item: item.endswith('.pkl'), lst)) 411 | lst.sort(key=lambda x: os.path.getmtime(os.path.join('models', x))) 412 | model = torch.load(os.path.join('models', lst[-1])) 413 | f1 = test(model, testset) 414 | print('best f1 on test is {:.4f}'.format(f1), flush=True) 415 | 416 | 417 | if __name__ == '__main__': 418 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 419 | parser.add_argument('-te', '--test', action='store_true', 420 | help='run test', default=False) 421 | parser.add_argument('-tr', '--train', action='store_true', 422 | help='run train', default=False) 423 | parser.add_argument('-ft', '--finetune', action='store_true', 424 | help='fine tune base the best model', default=False) 425 | parser.add_argument('-pr', '--print_error', action='store_true', 426 | help='print error case', default=False) 427 | parser.add_argument('-bsz', '--batch', help='Batch_size', 428 | required=False, default=CONFIG['batch_size'], type=int) 429 | parser.add_argument('-epochs', '--epochs', help='epochs', 430 | 
required=False, default=CONFIG['epochs'], type=int) 431 | parser.add_argument('-lr', '--lr', help='learning rate', 432 | required=False, default=CONFIG['lr'], type=float) 433 | parser.add_argument('-p_unk', '--p_unk', help='prob to generate unk speaker', 434 | required=False, default=CONFIG['p_unk'], type=float) 435 | parser.add_argument('-ptmlr', '--ptm_lr', help='ptm learning rate', 436 | required=False, default=CONFIG['ptmlr'], type=float) 437 | parser.add_argument('-tsk', '--task_name', default='meld', type=str) 438 | parser.add_argument('-fp16', '--fp_16', action='store_true', 439 | help='use fp 16', default=False) 440 | parser.add_argument('-wp', '--warm_up', default=CONFIG['wp'], 441 | type=int, required=False) 442 | parser.add_argument('-dpt', '--dropout', default=CONFIG['dropout'], 443 | type=float, required=False) 444 | parser.add_argument('-e_stop', '--eval_stop', 445 | default=500, type=int, required=False) 446 | parser.add_argument('-bert_path', '--bert_path', 447 | default=CONFIG['bert_path'], type=str, required=False) 448 | parser.add_argument('-data_path', '--data_path', 449 | default=CONFIG['data_path'], type=str, required=False) 450 | parser.add_argument('-acc_step', '--accumulation_steps', 451 | default=CONFIG['accumulation_steps'], type=int, required=False) 452 | 453 | args = parser.parse_args() 454 | CONFIG['data_path'] = args.data_path 455 | CONFIG['lr'] = args.lr 456 | CONFIG['ptmlr'] = args.ptm_lr 457 | CONFIG['epochs'] = args.epochs 458 | CONFIG['bert_path'] = args.bert_path 459 | CONFIG['batch_size'] = args.batch 460 | CONFIG['dropout'] = args.dropout 461 | CONFIG['wp'] = args.warm_up 462 | CONFIG['p_unk'] = args.p_unk 463 | CONFIG['accumulation_steps'] = args.accumulation_steps 464 | CONFIG['task_name'] = args.task_name 465 | train_data_path = os.path.join(CONFIG['data_path'], 'train_sent_emo.csv') 466 | test_data_path = os.path.join(CONFIG['data_path'], 'test_sent_emo.csv') 467 | dev_data_path = os.path.join(CONFIG['data_path'], 'dev_sent_emo.csv') 468 | if args.task_name =='emorynlp': 469 | train_data_path = os.path.join(CONFIG['data_path'], 'emorynlp_train_final.csv') 470 | test_data_path = os.path.join(CONFIG['data_path'], 'emorynlp_test_final.csv') 471 | dev_data_path = os.path.join(CONFIG['data_path'], 'emorynlp_dev_final.csv') 472 | os.makedirs('vocabs', exist_ok=True) 473 | os.makedirs('models', exist_ok=True) 474 | seed = 1024 475 | torch.manual_seed(seed) 476 | np.random.seed(seed) 477 | random.seed(seed) 478 | torch.backends.cudnn.benchmark = True 479 | # torch.autograd.set_detect_anomaly(True) 480 | get_vocabs([train_data_path, dev_data_path, test_data_path], 481 | 'friends_transcript.json') 482 | # model = PortraitModel(CONFIG) 483 | model = CRFModel(CONFIG) 484 | device = CONFIG['device'] 485 | model.to(device) 486 | print('---config---') 487 | for k, v in CONFIG.items(): 488 | print(k, '\t\t\t', v, flush=True) 489 | if args.finetune: 490 | lst = os.listdir('./models') 491 | lst = list(filter(lambda item: item.endswith('.pkl'), lst)) 492 | lst.sort(key=lambda x: os.path.getmtime(os.path.join('models', x))) 493 | model = torch.load(os.path.join('models', lst[-1])) 494 | print('checkpoint {} is loaded'.format( 495 | os.path.join('models', lst[-1])), flush=True) 496 | if args.train: 497 | train(model, train_data_path, dev_data_path, test_data_path) 498 | if args.test: 499 | # testset = load_meld_and_builddataset(dev_data_path) 500 | if args.task_name =='emorynlp': 501 | testset = load_emorynlp_and_builddataset(test_data_path) 502 | if args.task_name == 
'meld': 503 | testset = load_meld_and_builddataset(test_data_path) 504 | best_f1 = test(model, testset) 505 | print(best_f1) 506 | 507 | 508 | # python train.py -tr -wp 0 -bsz 1 -acc_step 8 -lr 1e-4 -ptmlr 1e-5 -dpt 0.3 >> output.log 0.6505 --------------------------------------------------------------------------------
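
Note: the repository does not ship a standalone usage example for the modified CRF layer in crf.py, so below is a minimal, self-contained sketch of its interface. Shapes, tensors, and the printed output are invented for illustration (seven tags to mirror CONFIG['num_classes']); the only assumptions are that crf.py is importable from the working directory and that PyTorch is installed.

```python
# Minimal sketch of the CRF layer's forward/decode interface (illustrative values only).
import torch
from crf import CRF

num_tags, seq_len, batch_size = 7, 5, 2        # 7 emotion classes, as in CONFIG['num_classes']
crf = CRF(num_tags)                            # batch_first=False: tensors are (seq_len, batch, ...)

emissions = torch.randn(seq_len, batch_size, num_tags)
tags = torch.randint(num_tags, (seq_len, batch_size))
mask = torch.ones(seq_len, batch_size, dtype=torch.bool)
mask[-1, 1] = False                            # second dialogue is one turn shorter

nll = -crf(emissions, tags, mask=mask)         # forward() returns the summed log likelihood
nll.backward()                                 # transition scores are trainable nn.Parameters

best_paths = crf.decode(emissions, mask=mask)  # Viterbi decoding, one tag list per dialogue
print(best_paths)                              # e.g. [[3, 0, 0, 6, 2], [1, 1, 4, 5]]
```

CRFModel in model.py drives the layer the same way: it builds per-utterance emission scores with the RoBERTa encoder plus a linear head, uses the negated CRF forward pass as the sequence-level loss during training, and calls decode() at test time.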