├── Code ├── Igor │ ├── predict.py │ └── train.py ├── Moiz │ ├── data_prep.py │ ├── data_utils.py │ ├── model_utils.py │ ├── predict.py │ └── train.py └── Ujjwal │ ├── encode_text.py │ ├── finetune_xlm.py │ ├── inference.py │ ├── inference.sh │ ├── post_process.py │ ├── pretrain_xlm.py │ └── train.sh ├── README.md ├── SETTINGS.json ├── blend.py ├── directory_structure.txt ├── entry_points.md ├── img ├── 1.png ├── 2.png └── 3.png ├── inference.py ├── prepare_data_inference.py ├── prepare_data_train.py ├── pytorch-xla-env-setup.py ├── requirements.txt └── train.py /Code/Igor/predict.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script takes the following thing as input: 3 | 1) An Input file to be scored 4 | 2) BACKBONE used for tokenizer 5 | 3) model checkpoints prefix 6 | and produces: 7 | 1) Scored input file 8 | 9 | Parameters: 10 | - backbone: Bert based model from huggingface transformers. 'camembert/camembert-large' for example 11 | - model_file_prefix: prefix of checkpoints name of saved model 12 | - in_file: The name of (not-processed) input file to be scored. In our case it's test.csv.zip from competition dataset 13 | Output file = {model_file_prefix}.probs.csv 14 | """ 15 | 16 | import numpy as np 17 | import pandas as pd 18 | import os 19 | os.environ['XLA_USE_BF16'] = "1" 20 | from glob import glob 21 | import torch 22 | import torch.nn as nn 23 | from torch.utils.data import Dataset,DataLoader 24 | from torch.autograd import Variable 25 | from torch.utils.data.sampler import SequentialSampler, RandomSampler 26 | import sklearn 27 | import time 28 | import random 29 | from datetime import datetime 30 | from tqdm import tqdm 31 | tqdm.pandas() 32 | from transformers import AutoModel, AutoTokenizer 33 | from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule 34 | import gc 35 | import re 36 | import nltk 37 | nltk.download('punkt') 38 | from nltk import sent_tokenize 39 | from pandarallel import pandarallel 40 | import argparse 41 | pandarallel.initialize(nb_workers=4, progress_bar=False) 42 | 43 | 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('--backbone', type=str) 46 | parser.add_argument('--model_file_prefix', type=str) 47 | parser.add_argument('--in_file', type=str) 48 | parsed_args = parser.parse_args() 49 | 50 | 51 | SEED = 42 52 | MAX_LENGTH = 224 53 | FILE_NAME = parsed_args.model_file_prefix 54 | BACKBONE_PATH = parsed_args.backbone 55 | 56 | def seed_everything(seed): 57 | random.seed(seed) 58 | os.environ['PYTHONHASHSEED'] = str(seed) 59 | np.random.seed(seed) 60 | torch.manual_seed(seed) 61 | torch.cuda.manual_seed(seed) 62 | torch.backends.cudnn.deterministic = True 63 | torch.backends.cudnn.benchmark = True 64 | 65 | seed_everything(SEED) 66 | 67 | from nltk import sent_tokenize 68 | from random import shuffle 69 | import random 70 | import albumentations 71 | from albumentations.core.transforms_interface import DualTransform, BasicTransform 72 | 73 | 74 | LANGS = { 75 | 'en': 'english', 76 | 'it': 'italian', 77 | 'fr': 'french', 78 | 'es': 'spanish', 79 | 'tr': 'turkish', 80 | 'ru': 'russian', 81 | 'pt': 'portuguese' 82 | } 83 | 84 | def get_sentences(text, lang='en'): 85 | return sent_tokenize(text, LANGS.get(lang, 'english')) 86 | 87 | def exclude_duplicate_sentences(text, lang='en'): 88 | sentences = [] 89 | for sentence in get_sentences(text, lang): 90 | sentence = sentence.strip() 91 | if sentence not in sentences: 92 | sentences.append(sentence) 93 | 
return ' '.join(sentences) 94 | 95 | def clean_text(text, lang='en'): 96 | text = str(text) 97 | text = re.sub(r'[0-9"]', '', text) 98 | text = re.sub(r'#[\S]+\b', '', text) 99 | text = re.sub(r'@[\S]+\b', '', text) 100 | text = re.sub(r'https?\S+', '', text) 101 | text = re.sub(r'\s+', ' ', text) 102 | text = exclude_duplicate_sentences(text, lang) 103 | return text.strip() 104 | 105 | tokenizer = AutoTokenizer.from_pretrained(BACKBONE_PATH) 106 | class DatasetRetriever(Dataset): 107 | 108 | def __init__(self, df): 109 | self.comment_texts = df['comment_text'].values 110 | self.ids = df['id'].values 111 | 112 | 113 | def get_tokens(self, text): 114 | encoded = tokenizer.encode_plus( 115 | text, 116 | add_special_tokens=True, 117 | max_length=MAX_LENGTH, 118 | pad_to_max_length=True 119 | ) 120 | return encoded['input_ids'], encoded['attention_mask'] 121 | 122 | def __len__(self): 123 | return self.ids.shape[0] 124 | 125 | def __getitem__(self, idx): 126 | text = self.comment_texts[idx] 127 | 128 | tokens, attention_mask = self.get_tokens(text) 129 | tokens, attention_mask = torch.tensor(tokens), torch.tensor(attention_mask) 130 | 131 | return self.ids[idx], tokens, attention_mask 132 | 133 | df_test = pd.read_csv(parsed_args.in_file) 134 | df_test['comment_text'] = df_test.parallel_apply(lambda x: clean_text(x['content'], x['lang']), axis=1) 135 | df_test = df_test.drop(columns=['content']) 136 | 137 | test_dataset = DatasetRetriever(df_test) 138 | 139 | class ToxicSimpleNNModel(nn.Module): 140 | 141 | def __init__(self): 142 | super(ToxicSimpleNNModel, self).__init__() 143 | self.backbone = AutoModel.from_pretrained(BACKBONE_PATH) 144 | self.dropout = nn.Dropout(0.3) 145 | self.linear = nn.Linear( 146 | in_features=self.backbone.pooler.dense.out_features*2, 147 | out_features=2, 148 | ) 149 | 150 | def forward(self, input_ids, attention_masks): 151 | bs, seq_length = input_ids.shape 152 | seq_x, _ = self.backbone(input_ids=input_ids, attention_mask=attention_masks) 153 | apool = torch.mean(seq_x, 1) 154 | mpool, _ = torch.max(seq_x, 1) 155 | x = torch.cat((apool, mpool), 1) 156 | x = self.dropout(x) 157 | return self.linear(x) 158 | 159 | import warnings 160 | 161 | warnings.filterwarnings("ignore") 162 | 163 | import torch_xla 164 | import torch_xla.core.xla_model as xm 165 | import torch_xla.distributed.parallel_loader as pl 166 | import torch_xla.distributed.xla_multiprocessing as xmp 167 | 168 | 169 | class MultiTPUPredictor: 170 | 171 | def __init__(self, model, device): 172 | 173 | 174 | 175 | self.model = model 176 | self.device = device 177 | 178 | xm.master_print(f'Model prepared. 
Device is {self.device}') 179 | 180 | 181 | def run_inference(self, test_loader, e, verbose=True, verbose_step=50): 182 | self.model.eval() 183 | result = {'id': [], 'toxic': []} 184 | t = time.time() 185 | for step, (ids, inputs, attention_masks) in enumerate(test_loader): 186 | if verbose: 187 | if step % 50 == 0: 188 | xm.master_print(f'Prediction Step {step}, time: {(time.time() - t):.5f}') 189 | 190 | with torch.no_grad(): 191 | inputs = inputs.to(self.device, dtype=torch.long) 192 | attention_masks = attention_masks.to(self.device, dtype=torch.long) 193 | outputs = self.model(inputs, attention_masks) 194 | toxics = nn.functional.softmax(outputs, dim=1).data.cpu().numpy()[:,1] 195 | 196 | result['id'].extend(ids.numpy()) 197 | result['toxic'].extend(toxics) 198 | 199 | result = pd.DataFrame(result) 200 | node_count = len(glob('node_submissions/*.csv')) 201 | result.to_csv(f'node_submissions/submission_{e}_{node_count}_{datetime.utcnow().microsecond}.csv', index=False) 202 | 203 | def _mp_fn(rank, flags): 204 | device = xm.xla_device() 205 | model = net.to(device) 206 | 207 | test_sampler = torch.utils.data.distributed.DistributedSampler( 208 | test_dataset, 209 | num_replicas=xm.xrt_world_size(), 210 | rank=xm.get_ordinal(), 211 | shuffle=False 212 | ) 213 | test_loader = torch.utils.data.DataLoader( 214 | test_dataset, 215 | batch_size=16, 216 | sampler=test_sampler, 217 | pin_memory=False, 218 | drop_last=False, 219 | num_workers=1 220 | ) 221 | 222 | fitter = MultiTPUPredictor(model=model, device=device) 223 | fitter.run_inference(test_loader,E) 224 | 225 | if not os.path.exists('node_submissions'): 226 | os.makedirs('node_submissions') 227 | else: 228 | files = glob('node_submissions/*') 229 | for f in files: 230 | os.remove(f) 231 | net = ToxicSimpleNNModel() 232 | checkpoints = glob(f'{FILE_NAME}*bin') 233 | for n,cp in enumerate(checkpoints): 234 | E = n + 1 235 | print(f'cp number {E}') 236 | checkpoint = torch.load(cp, map_location=torch.device('cpu')) 237 | net.load_state_dict(checkpoint); 238 | checkpoint = None 239 | del checkpoint 240 | FLAGS={} 241 | xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork') 242 | 243 | submission = pd.concat([pd.read_csv(path) for path in glob('node_submissions/*.csv')]).groupby('id').mean() 244 | submission.to_csv(parsed_args.model_file_prefix + ".prob.csv") 245 | -------------------------------------------------------------------------------- /Code/Igor/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script takes the following thing as input: 3 | 1) An Input file for training 4 | 2) BACKBONE of Bert based models in huggingface format 5 | 3) validation file if it exists 6 | 4) open_subtitles file 7 | 5) LANGUAGE of file 8 | 6) use or not use validation file for tuning 9 | and produces: 10 | 1) models checkpoints 11 | 12 | Parameters: 13 | - backbone: Bert based model from huggingface transformers. 
'camembert/camembert-large' for example 14 | - train_file: The name of (not-processed) train file 15 | - val_file: The name of (not-processed) validation file 16 | - os_file: The name of open_subtitles file 17 | - model_file_prefix: Full path and prefix of model checkpoints 18 | - lang: language of file 19 | - val_tune: 0 or 1 ; 1 - use validation file for tuning, 0 - don't use 20 | """ 21 | 22 | 23 | import numpy as np 24 | import pandas as pd 25 | import os 26 | os.environ['XLA_USE_BF16'] = "1" 27 | 28 | from glob import glob 29 | 30 | import torch 31 | import torch.nn as nn 32 | from torch.utils.data import Dataset,DataLoader 33 | from torch.autograd import Variable 34 | from torch.utils.data.sampler import SequentialSampler, RandomSampler 35 | import sklearn 36 | 37 | import time 38 | import random 39 | from datetime import datetime 40 | from tqdm import tqdm 41 | tqdm.pandas() 42 | 43 | from transformers import BertModel, BertTokenizer 44 | from transformers import XLMRobertaModel, XLMRobertaTokenizer, AutoModel, AutoTokenizer 45 | from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule 46 | from catalyst.data.sampler import DistributedSamplerWrapper, BalanceClassSampler 47 | 48 | import gc 49 | import re 50 | 51 | import nltk 52 | nltk.download('punkt') 53 | 54 | from nltk import sent_tokenize 55 | 56 | from pandarallel import pandarallel 57 | 58 | pandarallel.initialize(nb_workers=4, progress_bar=False) 59 | 60 | import argparse 61 | 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--backbone', type=str) 64 | parser.add_argument('--model_file_prefix', type=str) 65 | parser.add_argument('--train_file', type=str) 66 | parser.add_argument('--val_file', type=str) 67 | parser.add_argument('--val_tune', type=int) 68 | parser.add_argument('--os_file', type=str) 69 | parser.add_argument('--lang', type=str) 70 | args = parser.parse_args() 71 | 72 | SEED = 42 73 | 74 | MAX_LENGTH = 224 75 | FILE_NAME = args.model_file_prefix 76 | 77 | BACKBONE_PATH = args.backbone 78 | 79 | 80 | def seed_everything(seed): 81 | random.seed(seed) 82 | os.environ['PYTHONHASHSEED'] = str(seed) 83 | np.random.seed(seed) 84 | torch.manual_seed(seed) 85 | torch.cuda.manual_seed(seed) 86 | torch.backends.cudnn.deterministic = True 87 | torch.backends.cudnn.benchmark = True 88 | 89 | seed_everything(SEED) 90 | 91 | from nltk import sent_tokenize 92 | from random import shuffle 93 | import random 94 | import albumentations 95 | from albumentations.core.transforms_interface import DualTransform, BasicTransform 96 | 97 | 98 | LANGS = { 99 | 'en': 'english', 100 | 'it': 'italian', 101 | 'fr': 'french', 102 | 'es': 'spanish', 103 | 'tr': 'turkish', 104 | 'ru': 'russian', 105 | 'pt': 'portuguese' 106 | } 107 | 108 | def get_sentences(text, lang='en'): 109 | return sent_tokenize(text, LANGS.get(lang, 'english')) 110 | 111 | def exclude_duplicate_sentences(text, lang='en'): 112 | sentences = [] 113 | for sentence in get_sentences(text, lang): 114 | sentence = sentence.strip() 115 | if sentence not in sentences: 116 | sentences.append(sentence) 117 | return ' '.join(sentences) 118 | 119 | def clean_text(text, lang='en'): 120 | text = str(text) 121 | text = re.sub(r'[0-9"]', '', text) 122 | text = re.sub(r'#[\S]+\b', '', text) 123 | text = re.sub(r'@[\S]+\b', '', text) 124 | text = re.sub(r'https?\S+', '', text) 125 | text = re.sub(r'\s+', ' ', text) 126 | text = exclude_duplicate_sentences(text, lang) 127 | return text.strip() 128 | 129 | 130 | class NLPTransform(BasicTransform): 
131 | """ Transform for nlp task.""" 132 | 133 | @property 134 | def targets(self): 135 | return {"data": self.apply} 136 | 137 | def update_params(self, params, **kwargs): 138 | if hasattr(self, "interpolation"): 139 | params["interpolation"] = self.interpolation 140 | if hasattr(self, "fill_value"): 141 | params["fill_value"] = self.fill_value 142 | return params 143 | 144 | def get_sentences(self, text, lang='en'): 145 | return sent_tokenize(text, LANGS.get(lang, 'english')) 146 | 147 | class ShuffleSentencesTransform(NLPTransform): 148 | """ Do shuffle by sentence """ 149 | def __init__(self, always_apply=False, p=0.5): 150 | super(ShuffleSentencesTransform, self).__init__(always_apply, p) 151 | 152 | def apply(self, data, **params): 153 | text, lang = data 154 | sentences = self.get_sentences(text, lang) 155 | random.shuffle(sentences) 156 | return ' '.join(sentences), lang 157 | 158 | class ExcludeDuplicateSentencesTransform(NLPTransform): 159 | """ Exclude equal sentences """ 160 | def __init__(self, always_apply=False, p=0.5): 161 | super(ExcludeDuplicateSentencesTransform, self).__init__(always_apply, p) 162 | 163 | def apply(self, data, **params): 164 | text, lang = data 165 | sentences = [] 166 | for sentence in self.get_sentences(text, lang): 167 | sentence = sentence.strip() 168 | if sentence not in sentences: 169 | sentences.append(sentence) 170 | return ' '.join(sentences), lang 171 | 172 | class ExcludeNumbersTransform(NLPTransform): 173 | """ exclude any numbers """ 174 | def __init__(self, always_apply=False, p=0.5): 175 | super(ExcludeNumbersTransform, self).__init__(always_apply, p) 176 | 177 | def apply(self, data, **params): 178 | text, lang = data 179 | text = re.sub(r'[0-9]', '', text) 180 | text = re.sub(r'\s+', ' ', text) 181 | return text, lang 182 | 183 | class ExcludeHashtagsTransform(NLPTransform): 184 | """ Exclude any hashtags with # """ 185 | def __init__(self, always_apply=False, p=0.5): 186 | super(ExcludeHashtagsTransform, self).__init__(always_apply, p) 187 | 188 | def apply(self, data, **params): 189 | text, lang = data 190 | text = re.sub(r'#[\S]+\b', '', text) 191 | text = re.sub(r'\s+', ' ', text) 192 | return text, lang 193 | 194 | class ExcludeUsersMentionedTransform(NLPTransform): 195 | """ Exclude @users """ 196 | def __init__(self, always_apply=False, p=0.5): 197 | super(ExcludeUsersMentionedTransform, self).__init__(always_apply, p) 198 | 199 | def apply(self, data, **params): 200 | text, lang = data 201 | text = re.sub(r'@[\S]+\b', '', text) 202 | text = re.sub(r'\s+', ' ', text) 203 | return text, lang 204 | 205 | class ExcludeUrlsTransform(NLPTransform): 206 | """ Exclude urls """ 207 | def __init__(self, always_apply=False, p=0.5): 208 | super(ExcludeUrlsTransform, self).__init__(always_apply, p) 209 | 210 | def apply(self, data, **params): 211 | text, lang = data 212 | text = re.sub(r'https?\S+', '', text) 213 | text = re.sub(r'\s+', ' ', text) 214 | return text, lang 215 | 216 | class SynthesicOpenSubtitlesTransform(NLPTransform): 217 | def __init__(self, always_apply=False, p=0.5): 218 | super(SynthesicOpenSubtitlesTransform, self).__init__(always_apply, p) 219 | df = pd.read_csv(args.os_file, index_col='id')[['comment_text', 'toxic', 'lang']] 220 | df = df[~df['comment_text'].isna()] 221 | df = df[df.lang == args.lang] 222 | df['comment_text'] = df.parallel_apply(lambda x: clean_text(x['comment_text'], x['lang']), axis=1) 223 | df = df.drop_duplicates(subset='comment_text') 224 | df['toxic'] = df['toxic'].round().astype(np.int) 225 | 
226 | self.synthesic_toxic = df[df['toxic'] == 1].comment_text.values 227 | self.synthesic_non_toxic = df[df['toxic'] == 0].comment_text.values 228 | 229 | del df 230 | gc.collect(); 231 | 232 | def generate_synthesic_sample(self, text, toxic): 233 | texts = [text] 234 | if toxic == 0: 235 | for i in range(random.randint(1,5)): 236 | texts.append(random.choice(self.synthesic_non_toxic)) 237 | else: 238 | for i in range(random.randint(0,2)): 239 | texts.append(random.choice(self.synthesic_non_toxic)) 240 | 241 | for i in range(random.randint(1,3)): 242 | texts.append(random.choice(self.synthesic_toxic)) 243 | random.shuffle(texts) 244 | return ' '.join(texts) 245 | 246 | def apply(self, data, **params): 247 | text, toxic = data 248 | text = self.generate_synthesic_sample(text, toxic) 249 | return text, toxic 250 | 251 | def get_train_transforms(): 252 | return albumentations.Compose([ 253 | ExcludeUsersMentionedTransform(p=0.95), 254 | ExcludeUrlsTransform(p=0.95), 255 | ExcludeNumbersTransform(p=0.95), 256 | ExcludeHashtagsTransform(p=0.95), 257 | ExcludeDuplicateSentencesTransform(p=0.95), 258 | ], p=1.0) 259 | 260 | def get_synthesic_transforms(): 261 | return SynthesicOpenSubtitlesTransform(p=0.5) 262 | 263 | 264 | train_transforms = get_train_transforms(); 265 | synthesic_transforms = get_synthesic_transforms() 266 | tokenizer = AutoTokenizer.from_pretrained(BACKBONE_PATH) 267 | shuffle_transforms = ShuffleSentencesTransform(always_apply=True) 268 | 269 | def onehot(size, target): 270 | vec = torch.zeros(size, dtype=torch.float32) 271 | vec[target] = 1. 272 | return vec 273 | 274 | class DatasetRetriever(Dataset): 275 | 276 | def __init__(self, labels, comment_texts, langs, ids, use_train_transforms=False, test=False): 277 | self.test = test 278 | self.labels = labels 279 | self.ids = ids 280 | self.comment_texts = comment_texts 281 | self.langs = langs 282 | self.use_train_transforms = use_train_transforms 283 | 284 | def get_tokens(self, text): 285 | encoded = tokenizer.encode_plus( 286 | text, 287 | add_special_tokens=True, 288 | max_length=MAX_LENGTH, 289 | pad_to_max_length=True 290 | ) 291 | return encoded['input_ids'], encoded['attention_mask'] 292 | 293 | def __len__(self): 294 | return self.comment_texts.shape[0] 295 | 296 | def __getitem__(self, idx): 297 | text = self.comment_texts[idx] 298 | lang = self.langs[idx] 299 | tmp_id = self.ids[idx] 300 | if self.test is False: 301 | label = self.labels[idx] 302 | target = onehot(2, label) 303 | 304 | if self.use_train_transforms: 305 | text, _ = train_transforms(data=(text, lang))['data'] 306 | tokens, attention_mask = self.get_tokens(str(text)) 307 | token_length = sum(attention_mask) 308 | if token_length > 0.8*MAX_LENGTH: 309 | text, _ = shuffle_transforms(data=(text, lang))['data'] 310 | elif token_length < 60: 311 | text, _ = synthesic_transforms(data=(text, label))['data'] 312 | else: 313 | tokens, attention_mask = torch.tensor(tokens), torch.tensor(attention_mask) 314 | return target, tokens, attention_mask, tmp_id 315 | 316 | tokens, attention_mask = self.get_tokens(str(text)) 317 | tokens, attention_mask = torch.tensor(tokens), torch.tensor(attention_mask) 318 | 319 | if self.test is False: 320 | return target, tokens, attention_mask,tmp_id 321 | return tmp_id, tokens, attention_mask 322 | 323 | def get_labels(self): 324 | return list(np.char.add(self.labels.astype(str), self.langs)) 325 | 326 | 327 | 328 | 329 | df_train = pd.read_csv(args.train_file) 330 | 331 | 332 | df_train = 
df_train[~df_train['comment_text'].isna()] 333 | 334 | if os.path.isfile(args.val_file) and args.val_tune!=1: 335 | df_add = pd.read_csv(args.val_file) 336 | print(df_add.shape) 337 | df_add = df_add[df_add.lang == args.lang] 338 | df_add['comment_text'] = df_add.parallel_apply(lambda x: clean_text(x['comment_text'], x['lang']), axis=1) 339 | df_add = df_add[['id','comment_text','toxic']] 340 | df_train = pd.concat([df_train,df_add]) 341 | 342 | df_train['lang'] = args.lang 343 | df_train = df_train[~df_train['comment_text'].isna()] 344 | df_train['comment_text'] = df_train.parallel_apply(lambda x: clean_text(x['comment_text'], x['lang']), axis=1) 345 | 346 | 347 | print(df_train.shape) 348 | 349 | train_dataset = DatasetRetriever( 350 | labels=df_train['toxic'].values, 351 | comment_texts=df_train['comment_text'].values, 352 | langs=df_train['lang'].values, 353 | ids=df_train.index.values, 354 | use_train_transforms=True, 355 | ) 356 | 357 | del df_train 358 | gc.collect(); 359 | 360 | for targets, tokens, attention_masks, ids in train_dataset: 361 | break 362 | 363 | print(targets) 364 | print(tokens.shape) 365 | print(attention_masks.shape) 366 | print(ids) 367 | 368 | if os.path.isfile(args.val_file): 369 | df_val = pd.read_csv(args.val_file, index_col='id') 370 | print(df_val.shape) 371 | df_val = df_val[df_val.lang == args.lang] 372 | validation_tune_dataset = DatasetRetriever( 373 | labels=df_val['toxic'].values, 374 | comment_texts=df_val['comment_text'].values, 375 | langs=df_val['lang'].values, 376 | ids=df_val.index.values, 377 | use_train_transforms=True, 378 | ) 379 | 380 | df_val['comment_text'] = df_val.parallel_apply(lambda x: clean_text(x['comment_text'], x['lang']), axis=1) 381 | 382 | validation_dataset = DatasetRetriever( 383 | labels=df_val['toxic'].values, 384 | comment_texts=df_val['comment_text'].values, 385 | langs=df_val['lang'].values, 386 | ids=df_val.index.values, 387 | use_train_transforms=False, 388 | ) 389 | 390 | del df_val 391 | gc.collect(); 392 | 393 | for targets, tokens, attention_masks, ids in validation_dataset: 394 | break 395 | 396 | print(targets) 397 | print(tokens.shape) 398 | print(attention_masks.shape) 399 | 400 | 401 | class RocAucMeter(object): 402 | def __init__(self): 403 | self.reset() 404 | 405 | def reset(self): 406 | self.y_true = np.array([0,1]) 407 | self.y_pred = np.array([0.5,0.5]) 408 | self.score = 0 409 | 410 | def update(self, y_true, y_pred): 411 | y_true = y_true.cpu().numpy().argmax(axis=1) 412 | y_pred = nn.functional.softmax(y_pred, dim=1).data.cpu().numpy()[:,1] 413 | self.y_true = np.hstack((self.y_true, y_true)) 414 | self.y_pred = np.hstack((self.y_pred, y_pred)) 415 | self.score = sklearn.metrics.roc_auc_score(self.y_true, self.y_pred, labels=np.array([0, 1])) 416 | 417 | @property 418 | def avg(self): 419 | return self.score 420 | 421 | class AverageMeter(object): 422 | """Computes and stores the average and current value""" 423 | def __init__(self): 424 | self.reset() 425 | 426 | def reset(self): 427 | self.val = 0 428 | self.avg = 0 429 | self.sum = 0 430 | self.count = 0 431 | 432 | def update(self, val, n=1): 433 | self.val = val 434 | self.sum += val * n 435 | self.count += n 436 | self.avg = self.sum / self.count 437 | 438 | class LabelSmoothing(nn.Module): 439 | def __init__(self, smoothing = 0.1): 440 | super(LabelSmoothing, self).__init__() 441 | self.confidence = 1.0 - smoothing 442 | self.smoothing = smoothing 443 | 444 | def forward(self, x, target): 445 | if self.training: 446 | x = x.float() 447 | 
target = target.float() 448 | logprobs = torch.nn.functional.log_softmax(x, dim = -1) 449 | 450 | nll_loss = -logprobs * target 451 | nll_loss = nll_loss.sum(-1) 452 | 453 | smooth_loss = -logprobs.mean(dim=-1) 454 | 455 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 456 | 457 | return loss.mean() 458 | else: 459 | return torch.nn.functional.cross_entropy(x, target) 460 | 461 | import warnings 462 | 463 | warnings.filterwarnings("ignore") 464 | 465 | import torch_xla 466 | import torch_xla.core.xla_model as xm 467 | import torch_xla.distributed.parallel_loader as pl 468 | import torch_xla.distributed.xla_multiprocessing as xmp 469 | 470 | from catalyst.data.sampler import DistributedSamplerWrapper, BalanceClassSampler 471 | 472 | class TPUFitter: 473 | 474 | def __init__(self, model, device, config): 475 | if not os.path.exists('node_submissions'): 476 | os.makedirs('node_submissions') 477 | else: 478 | files = glob('node_submissions/*') 479 | for f in files: 480 | os.remove(f) 481 | 482 | 483 | 484 | 485 | self.config = config 486 | self.epoch = 0 487 | 488 | self.model = model 489 | self.device = device 490 | 491 | param_optimizer = list(self.model.named_parameters()) 492 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 493 | optimizer_grouped_parameters = [ 494 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001}, 495 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 496 | ] 497 | 498 | self.optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr*xm.xrt_world_size()) 499 | self.scheduler = config.SchedulerClass(self.optimizer, **config.scheduler_params) 500 | 501 | self.criterion = config.criterion 502 | xm.master_print(f'Fitter prepared. Device is {self.device}') 503 | 504 | def fit(self, train_loader): 505 | for e in range(self.config.n_epochs): 506 | if self.config.verbose: 507 | lr = self.optimizer.param_groups[0]['lr'] 508 | timestamp = datetime.utcnow().isoformat() 509 | self.log(f'\n{timestamp}\nLR: {lr}') 510 | 511 | t = time.time() 512 | para_loader = pl.ParallelLoader(train_loader, [self.device]) 513 | save_flag = 1 514 | if e == 0 or args.val_tune == 1: 515 | save_flag = 0 516 | losses, final_scores = self.train_one_epoch(para_loader.per_device_loader(self.device),e,save_flag) 517 | 518 | self.log(f'[RESULT]: Train. 
Epoch: {self.epoch}, loss: {losses.avg:.5f}, final_score: {final_scores.avg:.5f}, time: {(time.time() - t):.5f}') 519 | 520 | t = time.time() 521 | 522 | if self.config.validation_scheduler: 523 | self.scheduler.step(metrics=final_scores.avg) 524 | 525 | 526 | self.epoch += 1 527 | 528 | def run_tuning_and_inference(self, validation_tune_loader): 529 | for e in range(2): 530 | self.optimizer.param_groups[0]['lr'] = self.config.lr*xm.xrt_world_size() 531 | para_loader = pl.ParallelLoader(validation_tune_loader, [self.device]) 532 | losses, final_scores = self.train_one_epoch(para_loader.per_device_loader(self.device),e,1) 533 | 534 | 535 | 536 | def train_one_epoch(self, train_loader,e,save_flag): 537 | self.model.train() 538 | 539 | losses = AverageMeter() 540 | final_scores = RocAucMeter() 541 | t = time.time() 542 | for step, (targets, inputs, attention_masks, ids) in enumerate(train_loader): 543 | if self.config.verbose: 544 | if step % self.config.verbose_step == 0: 545 | self.log( 546 | f'Train Step {step}, loss: ' + \ 547 | f'{losses.avg:.5f}, final_score: {final_scores.avg:.5f}, ' + \ 548 | f'time: {(time.time() - t):.5f}' 549 | ) 550 | 551 | inputs = inputs.to(self.device, dtype=torch.long) 552 | attention_masks = attention_masks.to(self.device, dtype=torch.long) 553 | targets = targets.to(self.device, dtype=torch.float) 554 | 555 | self.optimizer.zero_grad() 556 | 557 | outputs = self.model(inputs, attention_masks) 558 | loss = self.criterion(outputs, targets) 559 | 560 | batch_size = inputs.size(0) 561 | 562 | final_scores.update(targets, outputs) 563 | 564 | losses.update(loss.detach().item(), batch_size) 565 | 566 | loss.backward() 567 | xm.optimizer_step(self.optimizer) 568 | 569 | if self.config.step_scheduler: 570 | self.scheduler.step() 571 | 572 | self.model.eval() 573 | if save_flag==1: 574 | self.save(f'{FILE_NAME}_epoch_{e}.bin') 575 | return losses, final_scores 576 | 577 | def save(self, path): 578 | xm.save(self.model.state_dict(), path) 579 | 580 | def log(self, message): 581 | if self.config.verbose: 582 | xm.master_print(message) 583 | 584 | 585 | 586 | from transformers import XLMRobertaModel 587 | 588 | class ToxicSimpleNNModel(nn.Module): 589 | 590 | def __init__(self): 591 | super(ToxicSimpleNNModel, self).__init__() 592 | self.backbone = AutoModel.from_pretrained(BACKBONE_PATH) 593 | self.dropout = nn.Dropout(0.3) 594 | self.linear = nn.Linear( 595 | in_features=self.backbone.pooler.dense.out_features*2, 596 | out_features=2, 597 | ) 598 | 599 | def forward(self, input_ids, attention_masks): 600 | bs, seq_length = input_ids.shape 601 | seq_x, _ = self.backbone(input_ids=input_ids, attention_mask=attention_masks) 602 | apool = torch.mean(seq_x, 1) 603 | mpool, _ = torch.max(seq_x, 1) 604 | x = torch.cat((apool, mpool), 1) 605 | x = self.dropout(x) 606 | return self.linear(x) 607 | 608 | class TrainGlobalConfig: 609 | """ Global Config for this notebook """ 610 | num_workers = 0 611 | batch_size = 16 # bs 612 | n_epochs = 8 613 | lr = 0.5 * 1e-5 614 | fold_number = 0 615 | 616 | # ------------------- 617 | verbose = True 618 | verbose_step = 50 619 | # ------------------- 620 | 621 | # -------------------- 622 | step_scheduler = False 623 | validation_scheduler = True 624 | SchedulerClass = torch.optim.lr_scheduler.ReduceLROnPlateau 625 | scheduler_params = dict( 626 | mode='max', 627 | factor=0.7, 628 | patience=1000, 629 | verbose=False, 630 | threshold=0.0001, 631 | threshold_mode='abs', 632 | cooldown=0, 633 | min_lr=1e-8, 634 | eps=1e-08 635 | ) 636 | # 
-------------------- 637 | 638 | # ------------------- 639 | criterion = LabelSmoothing(smoothing=0.1) 640 | # ------------------- 641 | 642 | net = ToxicSimpleNNModel() 643 | 644 | def _mp_fn(rank, flags): 645 | device = xm.xla_device() 646 | net.to(device) 647 | 648 | train_sampler = DistributedSamplerWrapper( 649 | sampler=BalanceClassSampler(labels=train_dataset.get_labels(), mode="downsampling"), 650 | num_replicas=xm.xrt_world_size(), 651 | rank=xm.get_ordinal(), 652 | shuffle=True 653 | ) 654 | train_loader = torch.utils.data.DataLoader( 655 | train_dataset, 656 | batch_size=TrainGlobalConfig.batch_size, 657 | sampler=train_sampler, 658 | pin_memory=False, 659 | drop_last=True, 660 | num_workers=TrainGlobalConfig.num_workers, 661 | ) 662 | if rank == 0: 663 | time.sleep(1) 664 | 665 | fitter = TPUFitter(model=net, device=device, config=TrainGlobalConfig) 666 | fitter.fit(train_loader) 667 | if os.path.isfile(args.val_file) and args.val_tune==1: 668 | validation_sampler = torch.utils.data.distributed.DistributedSampler( 669 | validation_dataset, 670 | num_replicas=xm.xrt_world_size(), 671 | rank=xm.get_ordinal(), 672 | shuffle=False 673 | ) 674 | validation_loader = torch.utils.data.DataLoader( 675 | validation_dataset, 676 | batch_size=TrainGlobalConfig.batch_size, 677 | sampler=validation_sampler, 678 | pin_memory=False, 679 | drop_last=False, 680 | num_workers=TrainGlobalConfig.num_workers 681 | ) 682 | validation_tune_sampler = torch.utils.data.distributed.DistributedSampler( 683 | validation_tune_dataset, 684 | num_replicas=xm.xrt_world_size(), 685 | rank=xm.get_ordinal(), 686 | shuffle=True 687 | ) 688 | validation_tune_loader = torch.utils.data.DataLoader( 689 | validation_tune_dataset, 690 | batch_size=TrainGlobalConfig.batch_size, 691 | sampler=validation_tune_sampler, 692 | pin_memory=False, 693 | drop_last=False, 694 | num_workers=TrainGlobalConfig.num_workers 695 | ) 696 | fitter.run_tuning_and_inference(validation_tune_loader) 697 | 698 | FLAGS={} 699 | xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork') 700 | 701 | -------------------------------------------------------------------------------- /Code/Moiz/data_prep.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("./Code/Moiz/") 3 | 4 | import argparse 5 | from data_utils import * 6 | 7 | def data_prep(in_file, file_type, model_name, should_chunk, max_chunk, 8 | long_comment_action, out_dir, out_file_prefix): 9 | """ 10 | Takes an input file, model type as input and produces the 11 | the output file with tokenized text 12 | 13 | :return: 14 | """ 15 | inp = pd.read_csv(in_file) 16 | print(f'Preparing Input: {in_file}') 17 | 18 | if file_type == 'test': 19 | inp = inp[['id', 'lang', 'comment_text']] 20 | elif file_type == 'train': 21 | inp = inp[['id', 'lang', 'comment_text', 'toxic_float']] 22 | else: 23 | raise NotImplementedError 24 | 25 | process_file( 26 | 192, model_name, 27 | inp, file_type, out_dir, out_file_prefix, 28 | should_chunk = should_chunk, chunk_size=100000, 29 | long_comment_action=long_comment_action, max_chunk=max_chunk) 30 | 31 | 32 | if __name__ == '__main__': 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--in_file', type=str) 35 | parser.add_argument('--in_file_type', type=str) 36 | parser.add_argument('--model_name', type=str) 37 | parser.add_argument('--should_chunk', type=int) 38 | parser.add_argument('--max_chunk', type=int, default=0) 39 | parser.add_argument('--long_comment_action', 
type=str) 40 | 41 | parser.add_argument('--out_dir', type=str) 42 | parser.add_argument('--out_file_prefix', type=str) 43 | 44 | parsed_args = parser.parse_args() 45 | 46 | should_chunk = parsed_args.should_chunk != 0 47 | 48 | if parsed_args.max_chunk == 0: 49 | max_chunk = None 50 | else: 51 | max_chunk = parsed_args.max_chunk 52 | 53 | data_prep( 54 | parsed_args.in_file, parsed_args.in_file_type, parsed_args.model_name, 55 | should_chunk, max_chunk, 56 | parsed_args.long_comment_action, parsed_args.out_dir, parsed_args.out_file_prefix 57 | ) -------------------------------------------------------------------------------- /Code/Moiz/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from transformers import AutoTokenizer 4 | from tqdm.autonotebook import tqdm 5 | from sklearn.model_selection import StratifiedKFold 6 | import re 7 | import time 8 | 9 | RANDOM_STATE = 5353 10 | np.random.seed(RANDOM_STATE) 11 | # Import 12 | from pandarallel import pandarallel 13 | 14 | # Initialization 15 | pandarallel.initialize() 16 | 17 | 18 | def clean_text(text): 19 | text = str(text) 20 | text = re.sub(r'[0-9"]', '', text) 21 | text = re.sub(r'#[\S]+\b', '', text) 22 | text = re.sub(r'@[\S]+\b', '', text) 23 | text = re.sub(r'https?\S+', '', text) 24 | text = re.sub(r'\s+', ' ', text) 25 | return text 26 | 27 | def encode(texts, tokenizer, pad, max_len): 28 | enc_di = tokenizer.batch_encode_plus( 29 | texts, 30 | return_attention_masks=False, 31 | return_token_type_ids=False, 32 | pad_to_max_length=pad, 33 | max_length=max_len 34 | ) 35 | return [np.array(x) for x in enc_di['input_ids']] 36 | 37 | 38 | def split_encode(inp, window_length, max_len): 39 | """ 40 | :param inp: 41 | :param window_length: 42 | :param max_len: 43 | :return: 44 | """ 45 | out_all = [] 46 | num_unchanged_ids = 0 47 | num_split_ids = 0 48 | num_split_rows = 0 49 | input_id_loc = inp.columns.get_loc('input_ids') 50 | 51 | for row in inp.itertuples(): 52 | if len(row.input_ids) < max_len: 53 | input_ids_current = list(row.input_ids) 54 | num_unchanged_ids += 1 55 | pad_length = max_len - len(row.input_ids) 56 | if pad_length > 0: 57 | input_ids_current += [1] * pad_length 58 | row_ref = list(row)[1:] 59 | row_ref[input_id_loc] = np.array(input_ids_current) 60 | out_all.append(row_ref) 61 | else: 62 | input_ids_content = list(row.input_ids)[1:-1] 63 | num_split_ids += 1 64 | start = 0 65 | while True: 66 | if len(input_ids_content[start:]) <= max_len - 2: 67 | # This is the last row 68 | input_ids_current = [0] + input_ids_content[start:start + max_len - 2] + [2] 69 | pad_length = max_len - len(input_ids_current) 70 | if pad_length > 0: 71 | input_ids_current += [1] * pad_length 72 | row_ref = list(row)[1:] 73 | row_ref[input_id_loc] = np.array(input_ids_current) 74 | out_all.append(row_ref) 75 | num_split_rows += 1 76 | break 77 | else: 78 | input_ids_current = [0] + input_ids_content[start:start + max_len - 2] + [2] 79 | start += window_length 80 | # No padding should be needed 81 | row_ref = list(row)[1:] 82 | row_ref[input_id_loc] = np.array(input_ids_current) 83 | out_all.append(row_ref) 84 | num_split_rows += 1 85 | 86 | out_df = pd.DataFrame(out_all) 87 | out_df.columns = inp.columns 88 | print(f'SUMMARY: unchanged: {num_unchanged_ids} split: {num_split_ids} rows: {num_split_rows} out_df{out_df.shape}') 89 | return out_df 90 | 91 | 92 | def dump_chunk(train, max_len, tokenizer, out_dir, out_prefix, out_suffix, 
long_comment_action): 93 | train = train.copy().sample(frac=1, random_state=RANDOM_STATE).reset_index(drop=True) 94 | 95 | # Clean 96 | print('Started Cleaning: ', time.ctime()) 97 | # Temp - commenting off cleaning for template messages 98 | train['comment_text'] = train.parallel_apply(lambda x: clean_text(x['comment_text']), axis=1) 99 | print('End Cleaning', time.ctime()) 100 | 101 | original_shape = train.shape 102 | if 'strata' in train.columns: 103 | del train['strata'] 104 | 105 | if long_comment_action == 'split': 106 | train['input_ids'] = encode(train['comment_text'].values, tokenizer, False, 1000) 107 | del train['comment_text'] 108 | train = split_encode(train, 100, max_len) 109 | 110 | elif long_comment_action == 'drop': 111 | train['input_ids'] = encode(train['comment_text'].values, tokenizer, True, max_len) 112 | del train['comment_text'] 113 | keep_flag = train['input_ids'].map(lambda x: True if x[-1] == 1 else False) 114 | train = train[keep_flag].reset_index(drop=True) 115 | 116 | elif long_comment_action == 'ignore': 117 | train['input_ids'] = encode(train['comment_text'].values, tokenizer, True, max_len) 118 | del train['comment_text'] 119 | 120 | else: 121 | raise ValueError 122 | 123 | print(f'{out_prefix}_{out_suffix} Original: {original_shape}: Final: {train.shape}') 124 | train.to_pickle(f'{out_dir}/{out_prefix}_{out_suffix}.pkl') 125 | 126 | def process_file( 127 | max_len, 128 | model_name, 129 | train, 130 | file_type : str, 131 | out_dir, 132 | out_prefix, 133 | should_chunk: bool, # whether should break input file into smaller files 134 | chunk_size : int, 135 | long_comment_action: str, 136 | max_chunk = None 137 | ): 138 | """ 139 | 1) Read the full file 140 | 2) Shuffle 141 | 3) Dump into Chunks of 50K samples 142 | 143 | :param inp_file: 144 | :return: 145 | """ 146 | 147 | tokenizer = AutoTokenizer.from_pretrained(model_name) 148 | # train = read_file(f'{in_dir}/{inp_file}', file_type) 149 | 150 | if file_type == 'train': 151 | print(train['toxic_float'].value_counts()) 152 | print(train['toxic_float'].value_counts(normalize=True)) 153 | 154 | train = train.dropna().reset_index(drop=True) 155 | print(f'# Rows after dropping NA: {train.shape}') 156 | 157 | if should_chunk and file_type != 'test': 158 | k_fold = train.shape[0] // chunk_size 159 | if k_fold > 2: 160 | train['strata'] = train['lang'] + train['toxic_float'].astype('str') 161 | print(train['strata'].value_counts()) 162 | print(train['strata'].value_counts(normalize=True)) 163 | sss = StratifiedKFold(n_splits=k_fold, random_state=RANDOM_STATE) 164 | for i, (train_index, test_index) in tqdm(enumerate(sss.split(np.zeros(train.shape[0]), train['strata']))): 165 | if (max_chunk is not None) and (i >= max_chunk): 166 | print('Max Chunks Processed, Exiting') 167 | break 168 | dump_chunk(train.iloc[test_index], max_len, tokenizer, out_dir, out_prefix, f'p{i}', long_comment_action) 169 | 170 | else: 171 | print('Not Enough observation in inp to break into chunks') 172 | dump_chunk(train, max_len, tokenizer, out_dir, out_prefix, 'p0', long_comment_action) 173 | else: 174 | dump_chunk(train, max_len, tokenizer, out_dir, out_prefix, 'p0', long_comment_action) -------------------------------------------------------------------------------- /Code/Moiz/model_utils.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import pandas as pd 3 | import numpy as np 4 | import tensorflow as tf 5 | import random 6 | import os 7 | from transformers import 
TFAutoModel 8 | from logging import getLogger, Formatter, FileHandler, StreamHandler, INFO 9 | from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling1D, GlobalMaxPool1D, Dropout, Concatenate 10 | from tensorflow.keras.models import Model 11 | from tensorflow.keras.optimizers import Adam 12 | import transformers 13 | from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard 14 | from sklearn.metrics import roc_auc_score, accuracy_score 15 | import matplotlib.pyplot as plt 16 | import gc 17 | 18 | """ 19 | This File contains all the function to train / score a model 20 | """ 21 | 22 | def set_seed(): 23 | seed = 5353 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | tf.random.set_seed(seed) 27 | os.environ['TF_DETERMINISTIC_OPS'] = '1' 28 | 29 | 30 | def init_logger(): 31 | logger = getLogger() 32 | logger.setLevel(INFO) 33 | # if not logger.hasHandlers(): 34 | handler = StreamHandler() 35 | format_str = '%(asctime)s: %(funcName)20s() -- %(message)s' 36 | handler.setFormatter(Formatter(format_str)) 37 | logger.addHandler(handler) 38 | return logger 39 | 40 | 41 | def plotit(train_history): 42 | # summarize history for accuracy 43 | plt.plot(train_history.history['accuracy']) 44 | plt.plot(train_history.history['val_accuracy']) 45 | plt.title('model accuracy') 46 | plt.ylabel('accuracy') 47 | plt.xlabel('epoch') 48 | plt.legend(['train', 'test'], loc='upper left') 49 | plt.show() 50 | 51 | plt.plot(train_history.history['loss']) 52 | plt.plot(train_history.history['val_loss']) 53 | plt.title('model loss') 54 | plt.ylabel('loss') 55 | plt.xlabel('epoch') 56 | plt.legend(['train', 'test'], loc='upper left') 57 | plt.show() 58 | 59 | 60 | lrelu = lambda x: tf.keras.activations.relu(x, alpha=0.1) 61 | 62 | 63 | def build_model(model_name, max_len, init_lr=1e-5, label_smoothing=0): 64 | if model_name == 'jplu/tf-xlm-roberta-large': 65 | config = transformers.RobertaConfig.from_pretrained(model_name) 66 | elif model_name == 'xlm-mlm-100-1280': 67 | config = transformers.XLMConfig.from_pretrained(model_name) 68 | elif model_name == 'bert-base-multilingual-cased': 69 | config = transformers.BertConfig.from_pretrained(model_name) 70 | else: 71 | raise NotImplementedError 72 | config.output_hidden_states = True 73 | transformer_layer = TFAutoModel.from_pretrained(model_name, config=config) 74 | input_word_ids = Input(shape=(max_len,), dtype=tf.int64, name="input_word_ids") 75 | 76 | if model_name == 'xlm-mlm-100-1280': 77 | _, outputs = transformer_layer(input_word_ids) 78 | else: 79 | _, _, outputs = transformer_layer(input_word_ids) 80 | 81 | outputs_sub = Concatenate()([outputs[-1], outputs[-2], outputs[-3], outputs[-4]]) 82 | 83 | avg_pool = GlobalAveragePooling1D()(outputs_sub) 84 | max_pool = GlobalMaxPool1D()(outputs_sub) 85 | 86 | pooled_output = Concatenate()([max_pool, avg_pool]) 87 | 88 | pooled_output = Dropout(0.5)(pooled_output) 89 | out = Dense(300, activation=lrelu)(pooled_output) 90 | out = Dropout(0.5)(out) 91 | out = Dense(1, activation='sigmoid')(out) 92 | 93 | model = Model(inputs=input_word_ids, outputs=out) 94 | loss = tf.keras.losses.BinaryCrossentropy(label_smoothing=label_smoothing) 95 | model.compile(Adam(lr=init_lr), loss=loss, metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]) 96 | return model 97 | 98 | 99 | def get_strategy(): 100 | # TPU Config 101 | try: 102 | # TPU detection. No parameters necessary if TPU_NAME environment variable is 103 | # set: this is always the case on Kaggle. 
104 | tpu = tf.distribute.cluster_resolver.TPUClusterResolver() 105 | print(f'Running on TPU {tpu.master()}') 106 | except ValueError: 107 | tpu = None 108 | 109 | if tpu: 110 | tf.config.experimental_connect_to_cluster(tpu) 111 | tf.tpu.experimental.initialize_tpu_system(tpu) 112 | strategy = tf.distribute.experimental.TPUStrategy(tpu) 113 | else: 114 | # Default distribution strategy in Tensorflow. Works on CPU and single GPU. 115 | print('NOT running on TPU ') 116 | strategy = tf.distribute.get_strategy() 117 | return strategy 118 | 119 | 120 | class Solver: 121 | def __init__(self, logger, config): 122 | self.config = config 123 | self.logger = logger 124 | 125 | self.train_size = None 126 | self.max_epochs = None 127 | self.prefetch_policy = tf.data.experimental.AUTOTUNE 128 | 129 | # Check whether TPU available or need to run on GPU / CPU etc 130 | # Update the Batch size if running in a distributed fashion 131 | self.strategy = get_strategy() 132 | self.config['global_batch_size'] = self.config["batch_size"] * self.strategy.num_replicas_in_sync 133 | self.logger.info(f'Batch Size: {self.config["batch_size"]} ' 134 | f'Global Batch Size: {self.config["global_batch_size"]}') 135 | 136 | def get_data_loaders(self, train_files: list, val_file, class_balance, class_ratio, min_thresh, max_thresh): 137 | 138 | train_data = [] 139 | for index, train_file in enumerate(train_files): 140 | self.logger.info(train_file) 141 | temp = pd.read_pickle(train_file) 142 | self.logger.info(f'# Rows - raw: {temp.shape}') 143 | temp = temp[(temp['toxic_float'] < min_thresh) | (max_thresh <= temp['toxic_float'])] 144 | self.logger.info(f'# Rows - thresholding: {temp.shape}') 145 | temp['toxic'] = temp['toxic_float'].map(lambda x: 1 if x >= max_thresh else 0) 146 | del temp['toxic_float'] 147 | self.logger.info(temp['toxic'].value_counts(normalize=True)) 148 | if class_balance: 149 | num_pos = temp[temp['toxic'] == 1].shape[0] 150 | num_neg = temp[temp['toxic'] == 0].shape[0] 151 | to_sel_neg = np.minimum(num_pos * class_ratio, num_neg) 152 | temp = pd.concat( 153 | [temp[temp['toxic'] == 1], temp[temp['toxic'] == 0].sample(to_sel_neg, random_state=5353)], 154 | axis=0).sample(frac=1, random_state=5353).reset_index(drop=True) 155 | train_data.append(temp) 156 | train_data = pd.concat(train_data, axis=0).reset_index(drop=True) 157 | self.logger.info(f'# Rows Train Original (concatenated): {train_data.shape}') 158 | gc.collect() 159 | 160 | # Ensure that toxic column has just two unique values 161 | assert train_data['toxic'].nunique() == 2 162 | 163 | val_data = pd.read_pickle(val_file) 164 | val_data = val_data.rename(columns={'toxic_float': 'toxic'}) 165 | if self.config['dev_mode']: 166 | train_data = train_data.sample(np.minimum(256, train_data.shape[0])).reset_index(drop=True) 167 | val_data = val_data.sample(np.minimum(256, val_data.shape[0])).reset_index(drop=True) 168 | self.max_epochs = 1 169 | 170 | self.logger.info(f'# Rows Train: {train_data.shape}') 171 | self.logger.info(f'# Rows Val: {val_data.shape}') 172 | 173 | train_dataset = ( 174 | tf.data.Dataset 175 | .from_tensor_slices( 176 | (np.array(train_data['input_ids'].values.tolist()), train_data['toxic'].values)) 177 | .repeat() 178 | .shuffle(train_data.shape[0]) 179 | .batch(self.config['global_batch_size']) 180 | .prefetch(self.prefetch_policy) 181 | ) 182 | 183 | valid_dataset = ( 184 | tf.data.Dataset 185 | .from_tensor_slices( 186 | (np.array(val_data['input_ids'].values.tolist()), val_data['toxic'].values)) 187 | 
.batch(self.config['global_batch_size']) 188 | .cache() 189 | .prefetch(self.prefetch_policy) 190 | ) 191 | self.train_size = train_data.shape[0] 192 | return train_dataset, valid_dataset 193 | 194 | def train(self, train_files: list, val_file, model_save_file, 195 | max_epochs, patience, init_lr, 196 | class_balance, class_ratio, label_smoothing, 197 | min_thresh, max_thresh, 198 | epoch_split_factor=1, model_resume_file=None): 199 | # Load the training and validation data and create Loaders 200 | self.max_epochs = max_epochs 201 | train_dataset, valid_dataset = self.get_data_loaders( 202 | train_files, val_file, class_balance, class_ratio, min_thresh, max_thresh) 203 | # Create a callback that saves the model's weights 204 | 205 | with self.strategy.scope(): 206 | model = build_model(self.config['model_name'], max_len=self.config['max_len'], init_lr=init_lr, 207 | label_smoothing=label_smoothing) 208 | if model_resume_file is not None: 209 | # If full path then load from it else from ckpt dir 210 | self.logger.info(f'Loading model from {f"{model_resume_file}"}') 211 | model.load_weights(f'{model_resume_file}') 212 | 213 | self.logger.info(model.summary()) 214 | 215 | # Some times the model starts overfitting even before one Epoch 216 | # This is just a hack to reduce # of steps from what they actually are for an Epoch 217 | # so that best model can be saved in between 218 | if self.config['dev_mode']: 219 | n_steps = (self.train_size // self.config['global_batch_size']) 220 | else: 221 | n_steps = (self.train_size // self.config['global_batch_size']) // epoch_split_factor 222 | 223 | print(f'# Steps Train: {n_steps}') 224 | cp_callback = ModelCheckpoint( 225 | monitor='val_auc', 226 | mode='max', 227 | filepath=f'{model_save_file}', 228 | save_weights_only=True, 229 | save_best_only=True, 230 | verbose=1) 231 | 232 | es_cbk = EarlyStopping(monitor='val_auc', mode='max', verbose=1, patience=patience) 233 | reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_auc', factor=0.2, 234 | patience=1, min_lr=1e-8) 235 | 236 | train_history = model.fit( 237 | train_dataset, 238 | steps_per_epoch=n_steps, 239 | validation_data=valid_dataset, 240 | epochs=self.max_epochs, 241 | callbacks=[cp_callback, es_cbk, reduce_lr] 242 | ) 243 | return train_history 244 | 245 | def get_test_data_loader(self, test_file): 246 | test_data = pd.read_pickle(test_file) 247 | if self.config['dev_mode']: 248 | # test_data = test_data.head(10).reset_index(drop=True) 249 | test_data = test_data[test_data['id'].isin(['0', '1', '10', '100', '1000', '10000'])].reset_index(drop=True) 250 | self.logger.info(f'# Rows Train: {test_data.shape}') 251 | test_dataset = ( 252 | tf.data.Dataset 253 | .from_tensor_slices( 254 | np.array(test_data['input_ids'].values.tolist())) 255 | .batch(self.config['global_batch_size']) 256 | ) 257 | return test_data, test_dataset 258 | 259 | def score_val(self, test_file, model_ckpt, out_file): 260 | test_data, test_dataset = self.get_test_data_loader(test_file) 261 | 262 | with self.strategy.scope(): 263 | model = build_model(self.config['model_name'], max_len=self.config['max_len']) 264 | self.logger.info(f'Loading model from {f"{model_ckpt}"}') 265 | model.load_weights(f'{model_ckpt}') 266 | 267 | test_data['pred_toxic'] = model.predict(test_dataset, verbose=1) 268 | test_data.to_pickle(out_file) 269 | return test_data 270 | 271 | def score_test(self, test_file, model_ckpt, out_file): 272 | self.config['global_batch_size'] = self.config['global_batch_size'] * 4 273 | 274 | test_data, 
test_dataset = self.get_test_data_loader(test_file) 275 | 276 | with self.strategy.scope(): 277 | model = build_model(self.config['model_name'], max_len=self.config['max_len']) 278 | self.logger.info(f'Loading model from {f"{model_ckpt}"}') 279 | model.load_weights(f'{model_ckpt}') 280 | 281 | test_data['pred_toxic'] = model.predict(test_dataset, verbose=1) 282 | test_data = test_data.rename(columns={'pred_toxic': 'toxic'}) 283 | test_data[['id', 'toxic']].to_csv(out_file, index=False) 284 | 285 | def score_test_split(self, test_file, model_ckpt, out_file): 286 | self.config['global_batch_size'] = self.config['global_batch_size'] * 4 287 | 288 | test_data, test_dataset = self.get_test_data_loader(test_file) 289 | 290 | with self.strategy.scope(): 291 | model = build_model(self.config['model_name'], max_len=self.config['max_len']) 292 | self.logger.info(f'Loading model from {f"{model_ckpt}"}') 293 | model.load_weights(f'{model_ckpt}') 294 | 295 | test_data['pred_toxic'] = model.predict(test_dataset, verbose=1) 296 | test_data = test_data.rename(columns={'pred_toxic': 'toxic'}) 297 | self.logger.info(f'Pre Rollup Count: {test_data.shape}') 298 | test_data = test_data.groupby('id')['toxic'].max().reset_index() 299 | self.logger.info(f'Post Rollup Count: {test_data.shape}') 300 | test_data[['id', 'toxic']].to_csv(out_file, index=False) 301 | 302 | def score_test_extra(self, test_file, model_ckpt, out_file): 303 | self.config['global_batch_size'] = self.config['global_batch_size'] * 4 304 | 305 | test_data, test_dataset = self.get_test_data_loader(test_file) 306 | 307 | with self.strategy.scope(): 308 | model = build_model(self.config['model_name'], max_len=self.config['max_len']) 309 | self.logger.info(f'Loading model from {f"{model_ckpt}"}') 310 | model.load_weights(f'{model_ckpt}') 311 | 312 | test_data['pred_toxic'] = model.predict(test_dataset, verbose=1) 313 | test_data = test_data.rename(columns={'pred_toxic': 'toxic'}) 314 | self.logger.info(f'Pre Rollup Count: {test_data.shape}') 315 | test_data = test_data.groupby('id')['toxic'].mean().reset_index() 316 | self.logger.info(f'Post Rollup Count: {test_data.shape}') 317 | test_data[['id', 'toxic']].to_csv(out_file, index=False) -------------------------------------------------------------------------------- /Code/Moiz/predict.py: -------------------------------------------------------------------------------- 1 | # https://www.kaggle.com/moizsaifee/jigsaw-train-v10-step2-mcp-submission 2 | import sys 3 | sys.path.append("./Code/Moiz/") 4 | 5 | from model_utils import * 6 | import argparse 7 | 8 | logger = init_logger() 9 | set_seed() 10 | 11 | """ 12 | This script takes the following thing as input: 13 | 1) An Input file to be scored 14 | 2) An Input model used for scoring 15 | and produces: 16 | 1) Scored input file 17 | 18 | Parameters: 19 | - dev_mode: for debugging - for debugging purpose, if switched on, then only some specific IDs are scored 20 | - model_name: one of 'jplu/tf-xlm-roberta-large' / 'xlm-mlm-100-1280' / 'bert-base-multilingual-cased' 21 | - model_file: The name of saved model bin file, artificat of the training process 22 | - in_file: The name of (pre-processed) input file to be scored, need to run data prep to generate thi from raw 23 | - in_file_type: one of "extra" or "foreign" / "english" - slightly different ways in which prob across multiple obs 24 | of the same id are handled - Max vs avg - that's the only difference 25 | - out_file: The file (full path) in which the results would be written to 26 | """ 27 | 28 
| if __name__ == '__main__': 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--dev_mode', type=int, default=0) 31 | parser.add_argument('--model_name', type=str) 32 | parser.add_argument('--model_file', type=str) 33 | 34 | parser.add_argument('--in_file', type=str) 35 | parser.add_argument('--in_file_type', type=str) 36 | parser.add_argument('--out_file', type=str) 37 | 38 | parsed_args = parser.parse_args() 39 | 40 | config = { 41 | 'max_len': 192, 42 | 'batch_size': 16, 43 | 'model_name': parsed_args.model_name, 44 | 'dev_mode': parsed_args.dev_mode, 45 | } 46 | 47 | solver = Solver(logger, config) 48 | if parsed_args.in_file_type == 'extra': 49 | solver.score_test_extra( 50 | parsed_args.in_file, 51 | parsed_args.model_file, 52 | parsed_args.out_file 53 | ) 54 | else: 55 | solver.score_test_split( 56 | parsed_args.in_file, 57 | parsed_args.model_file, 58 | parsed_args.out_file 59 | ) 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /Code/Moiz/train.py: -------------------------------------------------------------------------------- 1 | # https://www.kaggle.com/moizsaifee/jigsaw-train-v10-step2-mcp 2 | # https://www.kaggle.com/moizsaifee/jigsaw-train-v10-step2-mcp-variant1 3 | # https://www.kaggle.com/moizsaifee/jigsaw-train-v10-step2-mcp-variant2 4 | # https://www.kaggle.com/moizsaifee/jigsaw-train-v10-step2-mcp-variant3 5 | 6 | import sys 7 | sys.path.append("./Code/Moiz/") 8 | from model_utils import * 9 | import ast 10 | 11 | import argparse 12 | logger = init_logger() 13 | set_seed() 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--dev_mode', type=int, default=0) 18 | parser.add_argument('--model_name', type=str) 19 | parser.add_argument('--train_files', type=str) 20 | parser.add_argument('--val_file', type=str) 21 | 22 | parser.add_argument('--model_save_file', type=str) 23 | parser.add_argument('--max_epochs', type=int) 24 | parser.add_argument('--patience', type=int) 25 | parser.add_argument('--init_lr', type=float) 26 | parser.add_argument('--class_balance', type=int) 27 | parser.add_argument('--class_ratio', type=int) 28 | parser.add_argument('--label_smoothing', type=float) 29 | parser.add_argument('--min_thresh', type=float) 30 | parser.add_argument('--max_thresh', type=float) 31 | parser.add_argument('--epoch_split_factor', type=int) 32 | parser.add_argument('--model_resume_file', type=str) 33 | 34 | parsed_args = parser.parse_args() 35 | print(parsed_args) 36 | config = { 37 | 'max_len': 192, 38 | 'batch_size': 16, 39 | 'dev_mode': parsed_args.dev_mode, 40 | 'model_name': parsed_args.model_name, 41 | } 42 | if parsed_args.model_resume_file == "None": 43 | model_resume_file = None 44 | else: 45 | model_resume_file = parsed_args.model_resume_file 46 | 47 | solver = Solver(logger, config) 48 | solver.train( 49 | train_files = ast.literal_eval(parsed_args.train_files), 50 | val_file = parsed_args.val_file, 51 | model_save_file = parsed_args.model_save_file, 52 | max_epochs = parsed_args.max_epochs, 53 | patience=parsed_args.patience, 54 | init_lr=parsed_args.init_lr, 55 | class_balance=parsed_args.class_balance, 56 | class_ratio=parsed_args.class_ratio, 57 | label_smoothing=parsed_args.label_smoothing, 58 | min_thresh=parsed_args.min_thresh, 59 | max_thresh=parsed_args.max_thresh, 60 | epoch_split_factor=parsed_args.epoch_split_factor, 61 | model_resume_file=model_resume_file 62 | ) 63 | 
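For reference: when long_comment_action='split', dump_chunk in Code/Moiz/data_utils.py tokenizes with a long max length (1000, no padding) and then calls split_encode, which slides a window over the token ids of each over-long comment and emits one row per window; at scoring time the per-window probabilities are rolled back up to one score per id (max in score_test_split, mean in score_test_extra). Below is a minimal, self-contained sketch of that windowing logic, operating on a single list of token ids instead of a DataFrame. The helper name window_encode is illustrative only; 0/2/1 are the <s>/</s>/<pad> ids hard-coded in split_encode (XLM-R's defaults), and the window step (100) and max length (192) match the values used above.

def window_encode(input_ids, max_len=192, step=100, bos=0, eos=2, pad=1):
    """Split one tokenized comment into overlapping windows of max_len ids."""
    if len(input_ids) < max_len:
        # Short comment: pad to max_len and keep a single row.
        return [list(input_ids) + [pad] * (max_len - len(input_ids))]
    content = list(input_ids)[1:-1]  # drop the original <s> ... </s>
    rows, start = [], 0
    while True:
        chunk = [bos] + content[start:start + max_len - 2] + [eos]
        if len(content[start:]) <= max_len - 2:
            rows.append(chunk + [pad] * (max_len - len(chunk)))  # last, possibly short, window
            break
        rows.append(chunk)  # full window; no padding needed
        start += step
    return rows

Each returned row is scored independently, which is why score_test_split / score_test_extra above group the predictions by id before writing the submission file.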
-------------------------------------------------------------------------------- /Code/Ujjwal/encode_text.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | import numpy as np 4 | from transformers import XLMRobertaTokenizer 5 | 6 | TOKENIZER = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') 7 | 8 | def regular_encode(texts): 9 | encode = TOKENIZER.batch_encode_plus( 10 | texts, 11 | pad_to_max_length=True, 12 | max_length=128 13 | ) 14 | return np.array(encode['input_ids']) 15 | 16 | def save_encodings(): 17 | data1 = pd.read_csv(PATH + '/source/source_1/train_english.csv')[['comment_text']] 18 | data2 = pd.read_csv(PATH + '/source/source_1/train_foreign.csv')[['comment_text']] 19 | data3 = pd.read_csv(PATH + '/source/source_1/subtitle.csv')[['comment_text']] 20 | data = data1.append(data2).append(data3) 21 | data = regular_encode(data['comment_text'].values.tolist()) 22 | print('train_encode.npz:', data.shape) 23 | np.savez(PATH + '/step_1/input/train_encode.npz', arr_0=data) 24 | 25 | data1 = pd.read_csv(PATH + '/source/source_1/valid_foreign.csv')[['comment_text']] 26 | data2 = pd.read_csv(PATH + '/source/source_1/valid_english.csv')[['comment_text']] 27 | data = data1.append(data2) 28 | print('valid_encode.npz:', data.shape) 29 | np.savez(PATH + '/step_1/input/valid_encode.npz', arr_0=data) 30 | 31 | data1 = pd.read_csv(PATH + '/source/source_1/test_foreign.csv')[['comment_text']] 32 | data2 = pd.read_csv(PATH + '/source/source_1/test_english.csv')[['comment_text']] 33 | data = data1.append(data2) 34 | print('test_encode.npz:', data.shape) 35 | np.savez(PATH + '/step_1/input/test_encode.npz', arr_0=data) 36 | 37 | data1 = pd.read_csv(PATH + '/source/source_1/train_english.csv') 38 | data1 = data1[data1['source'] == '2020-train'][['comment_text']] 39 | data2 = pd.read_csv(PATH + '/source/source_1/train_foreign.csv') 40 | data2 = data2[data2['source'] == '2020-train'][['comment_text']] 41 | data3 = pd.read_csv(PATH + '/source/source_1/valid_foreign.csv') 42 | data3 = data3[data3['original'] == 1][['comment_text']] 43 | data4 = pd.read_csv(PATH + '/source/source_1/test_foreign.csv') 44 | data4 = data4[data4['original'] == 1][['comment_text']] 45 | data = data1.append(data2).append(data3).append(data4) 46 | print('data_subset.npz:', data.shape) 47 | np.savez(PATH + '/step_1/input/data_subset.npz', arr_0=data) 48 | 49 | data1 = pd.read_csv(PATH + '/source/source_1/valid_foreign.csv')[['comment_text']] 50 | data2 = pd.read_csv(PATH + '/source/source_1/test_foreign.csv')[['comment_text']] 51 | data3 = pd.read_csv(PATH + '/source/source_1/valid_english.csv')[['comment_text']] 52 | data4 = pd.read_csv(PATH + '/source/source_1/test_english.csv')[['comment_text']] 53 | data = data1.append(data2).append(data3).append(data4) 54 | print('data.npz:', data.shape) 55 | np.savez(PATH + '/step_0/input/data.npz', arr_0=data) 56 | return None 57 | 58 | if __name__ == '__main__': 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument("--path", type=str) 61 | args = parser.parse_args() 62 | PATH = args.path 63 | save_encodings() 64 | -------------------------------------------------------------------------------- /Code/Ujjwal/finetune_xlm.py: -------------------------------------------------------------------------------- 1 | ## Code borrowed from https://www.kaggle.com/riblidezso/train-from-mlm-finetuned-xlm-roberta-large 2 | 3 | import os 4 | import argparse 5 | import numpy as np 6 | import pandas as pd 7 
| from sklearn.metrics import roc_auc_score 8 | from sklearn.model_selection import train_test_split 9 | import tensorflow as tf 10 | from tensorflow.keras.layers import Dense, Input, Dropout 11 | from tensorflow.keras.models import Model 12 | from tensorflow.keras.optimizers import Adam 13 | import transformers 14 | from transformers import TFRobertaModel, AutoTokenizer 15 | import logging 16 | AUTO = tf.data.experimental.AUTOTUNE 17 | 18 | def connect_to_TPU(): 19 | try: 20 | tpu = tf.distribute.cluster_resolver.TPUClusterResolver() 21 | print('Running on TPU ', tpu.master()) 22 | except ValueError: 23 | tpu = None 24 | if tpu: 25 | tf.config.experimental_connect_to_cluster(tpu) 26 | tf.tpu.experimental.initialize_tpu_system(tpu) 27 | strategy = tf.distribute.experimental.TPUStrategy(tpu) 28 | else: 29 | strategy = tf.distribute.get_strategy() 30 | global_batch_size = BATCH_SIZE * strategy.num_replicas_in_sync 31 | return tpu, strategy, global_batch_size 32 | 33 | 34 | def load_jigsaw_trans(index, langs=['tr','it','es','ru','fr','pt'], 35 | columns=['comment_text', 'toxic']): 36 | train_6langs=[] 37 | for i in range(len(langs)): 38 | fn = FILEPATH2 + 'jigsaw-toxic-comment-train-google-%s-cleaned.csv'%langs[i] 39 | fn = pd.read_csv(fn)[columns].sample(frac=1., random_state=i).reset_index(drop=True) 40 | train_6langs.append(downsample(fn, index)) 41 | return train_6langs 42 | 43 | def downsample(df, index): 44 | print(df.toxic.value_counts()) 45 | pos = df.query('toxic==1').reset_index(drop=True) 46 | neg = df.query('toxic==0').reset_index(drop=True) 47 | neg = neg[neg.index % 10 == index].reset_index(drop=True) 48 | print(pos.shape, neg.shape) 49 | ds_df = pos.append(neg).sample(frac=1.).reset_index(drop=True) 50 | print(ds_df.shape) 51 | return ds_df 52 | 53 | def regular_encode(texts, tokenizer, maxlen=512): 54 | enc_di = tokenizer.batch_encode_plus(texts, 55 | pad_to_max_length=True, 56 | max_length=maxlen) 57 | return np.array(enc_di['input_ids']) 58 | 59 | def create_dist_dataset(X, y=None, training=False): 60 | dataset = tf.data.Dataset.from_tensor_slices(X) 61 | if y is not None: 62 | dataset_y = tf.data.Dataset.from_tensor_slices(y) 63 | dataset = tf.data.Dataset.zip((dataset, dataset_y)) 64 | if training: 65 | dataset = dataset.shuffle(len(X)).repeat() 66 | dataset = dataset.batch(global_batch_size).prefetch(AUTO) 67 | dist_dataset = strategy.experimental_distribute_dataset(dataset) 68 | return dist_dataset 69 | 70 | def create_model_and_optimizer(): 71 | with strategy.scope(): 72 | transformer_layer = TFRobertaModel.from_pretrained(PRETRAINED_MODEL) 73 | model = build_model(transformer_layer) 74 | optimizer_transformer = Adam(learning_rate=LR_TRANSFORMER) 75 | optimizer_head = Adam(learning_rate=LR_HEAD) 76 | return model, optimizer_transformer, optimizer_head 77 | 78 | 79 | def build_model(transformer): 80 | inp = Input(shape=(MAX_LEN,), dtype=tf.int32, name="input_word_ids") 81 | x = transformer(inp)[0][:, 0, :] 82 | x = Dropout(DROPOUT)(x) 83 | out = Dense(1, activation='sigmoid', name='custom_head')(x) 84 | model = Model(inputs=[inp], outputs=[out]) 85 | return model 86 | 87 | def define_losses_and_metrics(): 88 | with strategy.scope(): 89 | loss_object = tf.keras.losses.BinaryCrossentropy( 90 | reduction=tf.keras.losses.Reduction.NONE, from_logits=False) 91 | def compute_loss(labels, predictions): 92 | per_example_loss = loss_object(labels, predictions) 93 | loss = tf.nn.compute_average_loss( 94 | per_example_loss, global_batch_size = global_batch_size) 95 | return loss 96 
| train_accuracy_metric = tf.keras.metrics.AUC(name='training_AUC') 97 | return compute_loss, train_accuracy_metric 98 | 99 | def train(train_dist_dataset, val_dist_dataset=None, y_val=None, 100 | total_steps=2000, validate_every=200): 101 | best_weights, history = None, [] 102 | step = 0 103 | for tensor in train_dist_dataset: 104 | distributed_train_step(tensor) 105 | step+=1 106 | if (step % validate_every == 0): 107 | train_metric = train_accuracy_metric.result().numpy() 108 | print("Step %d, train AUC: %.5f" % (step, train_metric)) 109 | if val_dist_dataset: 110 | val_metric = roc_auc_score(y_val, predict(val_dist_dataset)) 111 | print("Step %d, val AUC: %.5f" % (step,val_metric)) 112 | history.append(val_metric) 113 | if history[-1] == max(history): 114 | best_weights = model.get_weights() 115 | train_accuracy_metric.reset_states() 116 | if step == total_steps: 117 | break 118 | model.set_weights(best_weights) 119 | 120 | @tf.function 121 | def distributed_train_step(data): 122 | strategy.experimental_run_v2(train_step, args=(data,)) 123 | 124 | def train_step(inputs): 125 | features, labels = inputs 126 | transformer_trainable_variables = [ v for v in model.trainable_variables 127 | if (('pooler' not in v.name) and 128 | ('custom' not in v.name))] 129 | head_trainable_variables = [ v for v in model.trainable_variables 130 | if 'custom' in v.name] 131 | with tf.GradientTape(persistent=True) as tape: 132 | predictions = model(features, training=True) 133 | loss = compute_loss(labels, predictions) 134 | gradients_transformer = tape.gradient(loss, transformer_trainable_variables) 135 | gradients_head = tape.gradient(loss, head_trainable_variables) 136 | del tape 137 | optimizer_transformer.apply_gradients(zip(gradients_transformer, 138 | transformer_trainable_variables)) 139 | optimizer_head.apply_gradients(zip(gradients_head, 140 | head_trainable_variables)) 141 | train_accuracy_metric.update_state(labels, predictions) 142 | 143 | def predict(dataset): 144 | predictions = [] 145 | for tensor in dataset: 146 | predictions.append(distributed_prediction_step(tensor)) 147 | predictions = np.vstack(list(map(np.vstack,predictions))) 148 | return predictions 149 | 150 | @tf.function 151 | def distributed_prediction_step(data): 152 | predictions = strategy.experimental_run_v2(prediction_step, args=(data,)) 153 | return strategy.experimental_local_results(predictions) 154 | 155 | def prediction_step(inputs): 156 | features = inputs 157 | predictions = model(features, training=False) 158 | return predictions 159 | 160 | if __name__ == '__main__': 161 | 162 | parser = argparse.ArgumentParser() 163 | parser.add_argument("--path", type=str) 164 | parser.add_argument("--mode", type=str) 165 | parser.add_argument("--fold", type=int) 166 | parser.add_argument("--pseudo", type=int) 167 | parser.add_argument("--out", type=str) 168 | args = parser.parse_args() 169 | 170 | PATH = args.path 171 | PRETRAINED_TOKENIZER= 'jplu/tf-xlm-roberta-large' 172 | MAX_LEN = 192 173 | DROPOUT = 0.5 174 | BATCH_SIZE = 16 175 | TOTAL_STEPS_STAGE1 = 2000 176 | VALIDATE_EVERY_STAGE1 = 200 177 | TOTAL_STEPS_STAGE2 = 200 178 | VALIDATE_EVERY_STAGE2 = 10 179 | LR_TRANSFORMER = 5e-6 180 | LR_HEAD = 1e-3 181 | 182 | PRETRAINED_MODEL = PATH + 'step-2/' + args.mode + '/' 183 | 184 | FILEPATH1 = PATH + 'source/source_1/' 185 | FILEPATH2 = PATH + 'source/source_2/' 186 | 187 | tpu, strategy, global_batch_size = connect_to_TPU() 188 | print("REPLICAS: ", strategy.num_replicas_in_sync) 189 | train_df = 
pd.concat(load_jigsaw_trans(args.fold)).sample(frac=1., random_state=2017) 190 | train_df = train_df.reset_index(drop=True) 191 | val_df = pd.read_csv(FILEPATH1 + 'validation.csv') 192 | if args.pseudo: 193 | extra_df = pd.read_csv(FILEPATH1 + 'pseudo_label.csv')[['comment_text', 'toxic']] 194 | val_df = val_df.append(extra_df).sample(frac=1., random_state=2017).reset_index(drop=True) 195 | test_df = pd.read_csv(FILEPATH1 + 'test_foreign.csv') 196 | sub_df = test_df[['id','lang','weight']].copy() 197 | 198 | tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_TOKENIZER) 199 | X_train = regular_encode(train_df.comment_text.values, tokenizer, maxlen=MAX_LEN) 200 | X_val = regular_encode(val_df.comment_text.values, tokenizer, maxlen=MAX_LEN) 201 | X_test = regular_encode(test_df.comment_text.values, tokenizer, maxlen=MAX_LEN) 202 | y_train = train_df.toxic.values.reshape(-1,1) 203 | y_val = val_df.toxic.values.reshape(-1,1) 204 | 205 | train_dist_dataset = create_dist_dataset(X_train, y_train, True) 206 | val_dist_dataset = create_dist_dataset(X_val) 207 | test_dist_dataset = create_dist_dataset(X_test) 208 | 209 | model, optimizer_transformer, optimizer_head = create_model_and_optimizer() 210 | model.summary() 211 | compute_loss, train_accuracy_metric = define_losses_and_metrics() 212 | 213 | train(train_dist_dataset, val_dist_dataset, y_val, TOTAL_STEPS_STAGE1, VALIDATE_EVERY_STAGE1) 214 | optimizer_head.learning_rate.assign(1e-4) 215 | X_train, X_val, y_train, y_val = train_test_split(X_val, y_val, test_size = 0.1) 216 | train_dist_dataset = create_dist_dataset(X_train, y_train, training=True) 217 | val_dist_dataset = create_dist_dataset(X_val, y_val) 218 | 219 | train(train_dist_dataset, val_dist_dataset, y_val, total_steps=TOTAL_STEPS_STAGE2, validate_every=VALIDATE_EVERY_STAGE2) 220 | 221 | sub_df['toxic'] = predict(test_dist_dataset)[:,0] 222 | sub_df.to_csv('submission.csv', index=False) 223 | 224 | model.save(args.path + '/step_3/model_{}.h5'.format(args.out)) 225 | 226 | 227 | -------------------------------------------------------------------------------- /Code/Ujjwal/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import pandas as pd 5 | import tensorflow as tf 6 | from tensorflow.keras.layers import Dense, Input, Dropout 7 | from tensorflow.keras.models import Model 8 | from tensorflow.keras.optimizers import Adam 9 | import transformers 10 | from transformers import TFRobertaModel, AutoTokenizer 11 | AUTO = tf.data.experimental.AUTOTUNE 12 | 13 | def connect_to_TPU(): 14 | try: 15 | tpu = tf.distribute.cluster_resolver.TPUClusterResolver() 16 | print('Running on TPU ', tpu.master()) 17 | except ValueError: 18 | tpu = None 19 | if tpu: 20 | tf.config.experimental_connect_to_cluster(tpu) 21 | tf.tpu.experimental.initialize_tpu_system(tpu) 22 | strategy = tf.distribute.experimental.TPUStrategy(tpu) 23 | else: 24 | strategy = tf.distribute.get_strategy() 25 | global_batch_size = BATCH_SIZE * strategy.num_replicas_in_sync 26 | return tpu, strategy, global_batch_size 27 | 28 | def regular_encode(texts, tokenizer, maxlen=512): 29 | enc_di = tokenizer.batch_encode_plus(texts, 30 | return_token_type_ids=False, 31 | pad_to_max_length=True, 32 | max_length=maxlen) 33 | return np.array(enc_di['input_ids']) 34 | 35 | def create_dist_dataset(X): 36 | dataset = tf.data.Dataset.from_tensor_slices(X) 37 | dataset = dataset.batch(global_batch_size).prefetch(AUTO) 38 | dist_dataset = 
strategy.experimental_distribute_dataset(dataset) 39 | return dist_dataset 40 | 41 | def create_model_and_optimizer(): 42 | with strategy.scope(): 43 | transformer_layer = TFRobertaModel.from_pretrained(PRETRAINED_TOKENIZER) 44 | model = build_model(transformer_layer) 45 | return model 46 | 47 | def build_model(transformer): 48 | inp = Input(shape=(MAX_LEN,), dtype=tf.int32, name="input_word_ids") 49 | x = transformer(inp)[0][:, 0, :] 50 | x = Dropout(DROPOUT)(x) 51 | out = Dense(1, activation='sigmoid', name='custom_head')(x) 52 | model = Model(inputs=[inp], outputs=[out]) 53 | return model 54 | 55 | def predict(dataset): 56 | predictions = [] 57 | for tensor in dataset: 58 | predictions.append(distributed_prediction_step(tensor)) 59 | predictions = np.vstack(list(map(np.vstack,predictions))) 60 | return predictions 61 | 62 | @tf.function 63 | def distributed_prediction_step(data): 64 | predictions = strategy.experimental_run_v2(prediction_step, args=(data,)) 65 | return strategy.experimental_local_results(predictions) 66 | 67 | def prediction_step(inputs): 68 | features = inputs 69 | predictions = model(features, training=False) 70 | return predictions 71 | 72 | if __name__ == '__main__': 73 | 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument("--path", type=str) 76 | parser.add_argument("--mode", type=str) 77 | args = parser.parse_args() 78 | 79 | PATH = args.path 80 | PRETRAINED_TOKENIZER= 'jplu/tf-xlm-roberta-large' 81 | MAX_LEN = 192 82 | DROPOUT = 0.5 83 | BATCH_SIZE = 16 84 | tpu, strategy, global_batch_size = connect_to_TPU() 85 | score = pd.read_csv(PATH + 'source/source_1/test_foreign.csv') 86 | submit = score[['id','lang','original']].copy() 87 | tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_TOKENIZER) 88 | X_test = regular_encode(score.comment_text.values, tokenizer, maxlen=MAX_LEN) 89 | test_dist_dataset = create_dist_dataset(X_test) 90 | model = create_model_and_optimizer() 91 | model.summary() 92 | model.load_weights(PATH + '/step-3/model_{}.h5'.format(args.mode)) 93 | submit['score'] = predict(test_dist_dataset) 94 | print(submit) 95 | submit.to_csv(PATH + '/step-3/submission_{}.csv'.format(args.mode), index=False) 96 | 97 | -------------------------------------------------------------------------------- /Code/Ujjwal/inference.sh: -------------------------------------------------------------------------------- 1 | python3 inference.py --path '../../Input/Ujjwal/Data/' --mode 'simpl_v1_1' 2 | python3 inference.py --path '../../Input/Ujjwal/Data/' --mode 'simpl_v1_4' 3 | python3 inference.py --path '../../Input/Ujjwal/Data/' --mode 'pslbl_v1_0' 4 | python3 inference.py --path '../../Input/Ujjwal/Data/' --mode 'simpl_v2_8' 5 | python3 inference.py --path '../../Input/Ujjwal/Data/' --mode 'simpl_v2_5' 6 | python3 inference.py --path '../../Input/Ujjwal/Data/' --mode 'pslbl_v2_7' 7 | python3 inference.py --path '../../Input/Ujjwal/Data/' --mode 'simpl_v3_6' 8 | python3 inference.py --path '../../Input/Ujjwal/Data/' --mode 'simpl_v3_3' 9 | python3 inference.py --path '../../Input/Ujjwal/Data/' --mode 'pslbl_v3_8' 10 | python3 post_process.py --path '../../Input/Ujjwal/Data/' -------------------------------------------------------------------------------- /Code/Ujjwal/post_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pandas as pd 4 | import numpy as np 5 | from glob import glob 6 | from functools import reduce 7 | 8 | parser = argparse.ArgumentParser() 9 | 
parser.add_argument("--path", type=str) 10 | args = parser.parse_args() 11 | 12 | base = pd.read_csv(args.path + 'source/source_1/test_foreign.csv')['original'].tolist() 13 | 14 | datasets = [] 15 | 16 | for file in glob(args.path + '/step-3/submission*.csv'): 17 | dataset = pd.read_csv(file)[['id','score']] 18 | dataset.columns = ['id','toxic'] 19 | dataset['toxic'] = dataset['toxic'] / dataset['toxic'].mean() 20 | dataset['toxic'] = dataset['toxic'] * 0.21 21 | dataset['weight'] = base 22 | datasets.append(dataset) 23 | 24 | datasets = reduce(lambda x,y : x.append(y), datasets) 25 | 26 | no_tta = datasets.copy() 27 | with_tta = datasets.copy() 28 | 29 | no_tta['weight'] = no_tta['weight'].map(lambda x : 0. if x != 1. else 1.) 30 | no_tta['toxic'] = no_tta['toxic'] * no_tta['weight'] 31 | no_tta = no_tta.groupby('id')['toxic','weight'].sum().reset_index() 32 | no_tta['toxic'] = no_tta['toxic'] / no_tta['weight'] 33 | no_tta = no_tta[['id','toxic']] 34 | no_tta['toxic'] = no_tta['toxic'] / no_tta['toxic'].mean() 35 | no_tta['toxic'] = no_tta['toxic'] * 0.21 36 | 37 | with_tta['weight'] = with_tta['weight'].map(lambda x : 5. if x == 1. else 1.) 38 | with_tta['toxic'] = with_tta['toxic'] * with_tta['weight'] 39 | with_tta = with_tta.groupby('id')['toxic','weight'].sum().reset_index() 40 | with_tta['toxic'] = with_tta['toxic'] / with_tta['weight'] 41 | with_tta = with_tta[['id','toxic']] 42 | with_tta['toxic'] = with_tta['toxic'] / with_tta['toxic'].mean() 43 | with_tta['toxic'] = with_tta['toxic'] * 0.21 44 | 45 | no_tta.to_csv(args.path + 'no_tta.csv', index=False) 46 | with_tta.to_csv(args.path + 'with_tta.csv', index=False) -------------------------------------------------------------------------------- /Code/Ujjwal/pretrain_xlm.py: -------------------------------------------------------------------------------- 1 | ## Code borrowed from https://www.kaggle.com/riblidezso/finetune-xlm-roberta-on-jigsaw-test-data-with-mlm 2 | 3 | import argparse 4 | from ast import literal_eval 5 | import os 6 | import numpy as np 7 | import pandas as pd 8 | import tensorflow as tf 9 | from tensorflow.keras.optimizers import Adam 10 | import transformers 11 | from transformers import TFAutoModelWithLMHead, AutoTokenizer 12 | import logging 13 | logging.getLogger().setLevel(logging.NOTSET) 14 | AUTO = tf.data.experimental.AUTOTUNE 15 | 16 | def connect_to_TPU(): 17 | try: 18 | tpu = tf.distribute.cluster_resolver.TPUClusterResolver() 19 | print('Running on TPU ', tpu.master()) 20 | except ValueError: 21 | tpu = None 22 | if tpu: 23 | tf.config.experimental_connect_to_cluster(tpu) 24 | tf.tpu.experimental.initialize_tpu_system(tpu) 25 | strategy = tf.distribute.experimental.TPUStrategy(tpu) 26 | else: 27 | strategy = tf.distribute.get_strategy() 28 | global_batch_size = BATCH_SIZE * strategy.num_replicas_in_sync 29 | return tpu, strategy, global_batch_size 30 | 31 | def prepare_mlm_input_and_labels(X): 32 | inp_mask = np.random.rand(*X.shape)<0.15 33 | inp_mask[X<=2] = False 34 | labels = -1 * np.ones(X.shape, dtype=int) 35 | labels[inp_mask] = X[inp_mask] 36 | X_mlm = np.copy(X) 37 | inp_mask_2mask = inp_mask & (np.random.rand(*X.shape)<0.90) 38 | X_mlm[inp_mask_2mask] = 250001 # mask token is the last in the dict 39 | inp_mask_2random = inp_mask_2mask & (np.random.rand(*X.shape) < 1/9) 40 | X_mlm[inp_mask_2random] = np.random.randint(3, 250001, inp_mask_2random.sum()) 41 | return X_mlm, labels 42 | 43 | def create_dist_dataset(X, y=None, training=False): 44 | dataset = tf.data.Dataset.from_tensor_slices(X) 
45 | if y is not None: 46 | dataset_y = tf.data.Dataset.from_tensor_slices(y) 47 | dataset = tf.data.Dataset.zip((dataset, dataset_y)) 48 | if training: 49 | dataset = dataset.shuffle(len(X)).repeat() 50 | dataset = dataset.batch(global_batch_size).prefetch(AUTO) 51 | dist_dataset = strategy.experimental_distribute_dataset(dataset) 52 | return dist_dataset 53 | 54 | def create_mlm_model_and_optimizer(): 55 | with strategy.scope(): 56 | model = TFAutoModelWithLMHead.from_pretrained(PRETRAINED_MODEL) 57 | optimizer = tf.keras.optimizers.Adam(learning_rate=LR) 58 | return model, optimizer 59 | 60 | def define_mlm_loss_and_metrics(): 61 | with strategy.scope(): 62 | mlm_loss_object = masked_sparse_categorical_crossentropy 63 | def compute_mlm_loss(labels, predictions): 64 | per_example_loss = mlm_loss_object(labels, predictions) 65 | loss = tf.nn.compute_average_loss( 66 | per_example_loss, global_batch_size = global_batch_size) 67 | return loss 68 | train_mlm_loss_metric = tf.keras.metrics.Mean() 69 | return compute_mlm_loss, train_mlm_loss_metric 70 | 71 | def masked_sparse_categorical_crossentropy(y_true, y_pred): 72 | y_true_masked = tf.boolean_mask(y_true, tf.not_equal(y_true, -1)) 73 | y_pred_masked = tf.boolean_mask(y_pred, tf.not_equal(y_true, -1)) 74 | loss = tf.keras.losses.sparse_categorical_crossentropy(y_true_masked, y_pred_masked, from_logits=True) 75 | return loss 76 | 77 | def train_mlm(train_dist_dataset, total_steps=2000, evaluate_every=200): 78 | step = 0 79 | for tensor in train_dist_dataset: 80 | distributed_mlm_train_step(tensor) 81 | step+=1 82 | if (step % evaluate_every == 0): 83 | train_metric = train_mlm_loss_metric.result().numpy() 84 | print("Step %d, train loss: %.2f" % (step, train_metric)) 85 | train_mlm_loss_metric.reset_states() 86 | if step == total_steps: 87 | break 88 | 89 | @tf.function 90 | def distributed_mlm_train_step(data): 91 | strategy.experimental_run_v2(mlm_train_step, args=(data,)) 92 | 93 | @tf.function 94 | def mlm_train_step(inputs): 95 | features, labels = inputs 96 | with tf.GradientTape() as tape: 97 | predictions = mlm_model(features, training=True)[0] 98 | loss = compute_mlm_loss(labels, predictions) 99 | gradients = tape.gradient(loss, mlm_model.trainable_variables) 100 | optimizer.apply_gradients(zip(gradients, mlm_model.trainable_variables)) 101 | train_mlm_loss_metric.update_state(loss) 102 | 103 | 104 | if __name__ == '__main__': 105 | 106 | parser = argparse.ArgumentParser() 107 | parser.add_argument("--mode", type=str) 108 | parser.add_argument("--path", type=str) 109 | args = parser.parse_args() 110 | 111 | PATH = args.path 112 | MAX_LEN = 128 113 | BATCH_SIZE = 16 114 | TOTAL_STEPS = 10000 115 | EVALUATE_EVERY = 200 116 | LR = 1e-5 117 | PRETRAINED_MODEL = 'jplu/tf-xlm-roberta-large' 118 | 119 | np.random.seed(2017) 120 | tpu, strategy, global_batch_size = connect_to_TPU() 121 | print("REPLICAS: ", strategy.num_replicas_in_sync) 122 | 123 | if args.mode == 'version1': 124 | X_trn = np.load(PATH + 'step-1/train_encode.npz')['arr_0'] 125 | X_val = np.load(PATH + 'step-1/valid_encode.npz')['arr_0'] 126 | X_tst = np.load(PATH + 'step-1/test_encode.npz')['arr_0'] 127 | X_train_mlm = np.vstack([X_trn, X_val, X_tst]) 128 | np.random.shuffle(X_train_mlm) 129 | X_train_mlm = X_train_mlm[:1000000,:] 130 | if args.mode == 'version2': 131 | X_trn = np.load(PATH + 'step-1/data.npz')['arr_0'] 132 | X_train_mlm = X_trn 133 | np.random.shuffle(X_train_mlm) 134 | X_train_mlm = X_train_mlm[:1000000,:] 135 | if args.mode == 'version3': 136 | X_trn 
= np.load(PATH + 'step-1/data_subset.npz')['arr_0'] 137 | X_train_mlm = X_trn 138 | np.random.shuffle(X_train_mlm) 139 | 140 | X_train_mlm, y_train_mlm = prepare_mlm_input_and_labels(X_train_mlm) 141 | train_dist_dataset = create_dist_dataset(X_train_mlm, y_train_mlm, True) 142 | mlm_model, optimizer = create_mlm_model_and_optimizer() 143 | mlm_model.summary() 144 | compute_mlm_loss, train_mlm_loss_metric = define_mlm_loss_and_metrics() 145 | train_mlm(train_dist_dataset, TOTAL_STEPS, EVALUATE_EVERY) 146 | mlm_model.save_pretrained(PATH + 'step-2/' + args.mode + '/') 147 | 148 | -------------------------------------------------------------------------------- /Code/Ujjwal/train.sh: -------------------------------------------------------------------------------- 1 | python3 encode_text.py --path '../../Input/Ujjwal/Data' 2 | 3 | python3 pretrain_xlm.py --path '../../Input/Ujjwal/Data' --mode version1 4 | python3 pretrain_xlm.py --path '../../Input/Ujjwal/Data' --mode version2 5 | python3 pretrain_xlm.py --path '../../Input/Ujjwal/Data' --mode version3 6 | 7 | python3 finetune_xlm.py --path '../../Input/Ujjwal/Data' --mode version1 --fold 1 --pseudo 0 --out 'simpl_v1_1' 8 | python3 finetune_xlm.py --path '../../Input/Ujjwal/Data' --mode version1 --fold 4 --pseudo 0 --out 'simpl_v1_4' 9 | python3 finetune_xlm.py --path '../../Input/Ujjwal/Data' --mode version1 --fold 0 --pseudo 1 --out 'pslbl_v1_0' 10 | python3 finetune_xlm.py --path '../../Input/Ujjwal/Data' --mode version2 --fold 8 --pseudo 0 --out 'simpl_v2_8' 11 | python3 finetune_xlm.py --path '../../Input/Ujjwal/Data' --mode version2 --fold 5 --pseudo 0 --out 'simpl_v2_5' 12 | python3 finetune_xlm.py --path '../../Input/Ujjwal/Data' --mode version2 --fold 7 --pseudo 1 --out 'pslbl_v2_7' 13 | python3 finetune_xlm.py --path '../../Input/Ujjwal/Data' --mode version3 --fold 6 --pseudo 0 --out 'simpl_v3_6' 14 | python3 finetune_xlm.py --path '../../Input/Ujjwal/Data' --mode version3 --fold 3 --pseudo 0 --out 'simpl_v3_3' 15 | python3 finetune_xlm.py --path '../../Input/Ujjwal/Data' --mode version3 --fold 8 --pseudo 1 --out 'pslbl_v3_8' 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Hello! 
2 | 3 | Here is what you'll find in this repository: 4 | * High level overview of components of our solution which led to our final score (Solution Summary) 5 | * Summary of various components of our solution (Component Summary) 6 | * Complete data processing, model training, scoring and blending code for all the components 7 | 8 | Here is what you'll NOT find in this repository: 9 | * The input / translated data, chached models (due to Github file size limit) 10 | * Point and click reproducibility (due to absence of cached input & processed data / models) 11 | 12 | 13 | Some additional info can be found on [this](https://towardsdatascience.com/kaggle-3rd-place-solution-jigsaw-multilingual-toxic-comment-classification-e36d7d194bfb) blog 14 | 15 | 16 | # Solution Summary 17 | ![Solution Summary](https://github.com/moizsaifee/kaggle-jigsaw-multilingual-toxic-comment-classification-3rd-place-solution/blob/master/img/1.png) 18 | 19 | # Component Summary 20 | ![Solution Summary](https://github.com/moizsaifee/kaggle-jigsaw-multilingual-toxic-comment-classification-3rd-place-solution/blob/master/img/2.png) 21 | ![Solution Summary](https://github.com/moizsaifee/kaggle-jigsaw-multilingual-toxic-comment-classification-3rd-place-solution/blob/master/img/3.png) 22 | 23 | # Code / Implementation 24 | 25 | ## MODEL BUILD: There are three options to produce the solution. 26 | 1) very fast prediction 27 | a) runs in a few minutes 28 | b) uses precomputed neural network predictions 29 | 2) ordinary prediction (use models from Output/Models/Igor/,Output/Models/Moiz/,Input/Ujjwal/Data/step-[2-3]/*h5) 30 | a) expect this to run for 5-6 hours 31 | b) uses binary model files 32 | 3) retrain models 33 | a) expect this to run about 1-2 days 34 | b) trains all models from scratch 35 | c) follow this with (2) to produce entire solution from scratch 36 | 37 | command to run each build is below 38 | 1) very fast prediction (overwrites Output/Predictions/submission.csv) 39 | python ./blend.py 40 | 41 | 2) ordinary prediction (overwrites Output/Predictions/submission.csv, Output/Predictions/Moiz/*csv, Output/Predictions/Ujjwal/*csv, Output/Models/Igor/{lang}/*probs.csv ) 42 | python ./inference.py 43 | 44 | 3) retrain models (overwrites models in Output/Models/Igor_dev/,Output/Models/Moiz_dev/,Input/Ujjwal/Data/step-[2-3]/*h5) 45 | python ./train.py 46 | 47 | 48 | 49 | ## CONTENTS 50 | Input/Igor/test.csv.zip :original kaggle test data 51 | Input/Igor/validation.csv.zip : original kaggle validation data 52 | 53 | Input/Igor/train_data_{lang}_google.csv.zip : translated train data via google translater (https://www.kaggle.com/miklgr500/jigsaw-train-multilingual-coments-google-api) 54 | 55 | Input/Igor/train_data_{lang}_yandex.csv.zip : translated train data via yandex translater (https://www.kaggle.com/ma7555/jigsaw-train-translated-yandex-api) 56 | 57 | Input/Igor/633287_1126366_compressed_open-subtitles-synthesic.csv.zip : pseudolabeling open_subtitles (https://www.kaggle.com/shonenkov/open-subtitles-toxic-pseudo-labeling) 58 | 59 | 60 | Input/Common/Raw/ : original kaggle datasets + pseudolabeling open_subtitles and pseudolabeling test_data 61 | Input/Common/Processed_Ujjwal/ : translated datasets 62 | 63 | 64 | Input/Moiz/train/ : data produced by ./prepare_data_train.py 65 | Input/Moiz/test/ : data produced by ./prepare_data_inference.py 66 | 67 | 68 | Output/Models/Igor/{lang}/*bin : PyTorch checkpoints of mono-lingual models 69 | Output/Models/Moiz/*h5 : TF checkpoints of MAS models 70 | 
Input/Ujjwal/Data/step-[2-3]/*h5 : TF checkpoints of MLM models 71 | 72 | Output/Models/Igor/{lang}/*probs.csv : predictions of mono-lingual models 73 | Output/Predictions/Moiz/*csv : predictions of MAS models 74 | Input/Ujjwal/Data/*tta.csv : predictions of MLM models 75 | 76 | Output/Predictions/submission.csv : final submission file 77 | 78 | 79 | 80 | ## HARDWARE: (The following specs were used to create the original solution) 81 | v3-128 TPU - needed for TF training (all TF models were trained via Kaggle). It is important that the instance has 16Gb memory per core (128 in total) 82 | 64Gb memory - needed for PyTorch training (all PyTorch models were trained via Google Colab Pro, which has more memory than a Kaggle instance but less TPU memory (8 vs 16)) 83 | Internet access for downloading packages 84 | 85 | 86 | ## SOFTWARE (python packages are detailed separately in `requirements.txt`): 87 | Python 3.6.9 88 | 89 | 90 | WARNING! Do not install pytorch-xla-env-setup.py before starting the TF code: there is an incompatibility between using the TPU via TF and via PyTorch in the same instance runtime. The valid sequence of steps (including package installation) is in ./train.py and ./inference.py. 91 | 92 | 93 | ## DATA SETUP 94 | 95 | 96 | ## DATA PROCESSING 97 | ### The train/predict code will also call this script if it has not already been run on the relevant data. 98 | python ./prepare_data_train.py 99 | python ./prepare_data_inference.py 100 | 101 | 102 | 103 | 104 | 105 | 106 | ############## Ujjwal model description 107 | The following code repository produces the submission file for the MLM part. 108 | 109 | The code is borrowed from @riblidezso's following notebooks: 110 | 111 | - [Pre-training Roberta-XLM](https://www.kaggle.com/riblidezso/finetune-xlm-roberta-on-jigsaw-test-data-with-mlm) 112 | - [Supervised Training Roberta-XLM](https://www.kaggle.com/riblidezso/train-from-mlm-finetuned-xlm-roberta-large) 113 | 114 | This code assumes TPU access. 115 | 116 | 117 | ## Overview 118 | 119 | ### Input Data: 120 | 121 | There are two sources of input data: 122 | 123 | - source_1 124 | - train_english: the given English dataset of toxic comments 125 | - train_foreign: the train_english dataset translated to foreign languages 126 | - valid_english: validation data translated to English 127 | - valid_foreign: original validation dataset 128 | - test_english: test dataset translated to English 129 | - test_foreign: original test dataset 130 | - subtitle: open subtitle dataset 131 | - pseudo_label: the given test dataset, pseudo-labeled based on our model prediction scores 132 | 133 | - source_2: 134 | - [Public Dataset](https://www.kaggle.com/miklgr500/jigsaw-train-multilingual-coments-google-api) 135 | 136 | We used three different input pipelines to pre-train XLM models. We translated each record to various languages (en, es, fr, tr, ru, it, pt) to obtain more data for pre-training the model. 137 | 138 | - Version 1: Translated train, valid and test 139 | - Version 2: Translated train and open subtitle dataset 140 | - Version 3: Translated validation and test 141 | 142 | ### Step - 1: Data Processing 143 | 144 | Code: encode_text.py 145 | Input: source/source_1 files 146 | Output: encoded npz arrays step_1 147 | 148 | We encoded the text in the CSV files to create numpy arrays with numerical encodings. We did this to reduce the TPU runtime of the notebooks.
The encoded arrays can be found in this [Kaggle Dataset](https://www.kaggle.com/brightertiger/jigsawencode) 149 | 150 | ### Step - 2: Pre-training 151 | 152 | Code: pretrain_xlm.py 153 | Input: encoded npz arrays step_1 154 | Output: xlm-model weights step_2/version* 155 | 156 | We used the three input versions to pre-train three XLM-Roberta models using masked language modeling. 157 | 158 | The Kaggle scripts corresponding to the three versions are: 159 | - [Version 1](https://www.kaggle.com/brightertiger/finetune-xlm-roberta-on-jigsaw-test-data-with-mlm?scriptVersionId=35754034) 160 | - [Version 2](https://www.kaggle.com/brightertiger/finetune-xlm-roberta-on-jigsaw-test-data-with-mlm?scriptVersionId=35762322) 161 | - [Version 3](https://www.kaggle.com/brightertiger/finetune-xlm-roberta-on-jigsaw-test-data-with-mlm?scriptVersionId=35904862) 162 | 163 | These three versions output three XLM-Roberta models that are used for supervised training in the next step. The models are saved here: 164 | 165 | - [Version 1](https://www.kaggle.com/brightertiger/mlmv1) 166 | - [Version 2](https://www.kaggle.com/brightertiger/mlmv2) 167 | - [Version 3](https://www.kaggle.com/brightertiger/mlmv3) 168 | 169 | ### Step - 3: Fine-tuning 170 | 171 | Code: finetune_xlm.py 172 | Input: source/source_2 files, step-2/input models, fold-idx 173 | Output: step-3 model weights 174 | 175 | The three models from the previous step are fine-tuned on the task labels in this step. The models train best on a downsampled 1:1 ratio of toxic and non-toxic labels. To ensure this, each model's fine-tuning task is triggered ~10 times, each time with a different subset of the non-toxic labels. To add more diversity to the training pipeline, in half of the runs pseudo-labels (generated from our predictions) were added to the validation dataset. 176 | 177 | The Kaggle scripts for this step can be found at: 178 | 179 | - [Version 1](https://www.kaggle.com/brightertiger/mlm-v1-code) 180 | - [Version 2](https://www.kaggle.com/brightertiger/mlm-v2-code) 181 | - [Version 3](https://www.kaggle.com/brightertiger/mlm-v3-code) 182 | 183 | ### Step - 4: Inference 184 | 185 | Code: inference.py 186 | Input: step-3 model weights 187 | Output: step-3 score files 188 | 189 | This code can be used to run inference with the models trained in Step - 3. To score a new file, replace the file located at /Input/Ujjwal/Data/source/source_1/test_foreign.csv. 190 | 191 | ### Step - 5: Post-Processing 192 | 193 | Code: post_process.py 194 | Input: step-3 score files 195 | Output: with_tta.csv, no_tta.csv 196 | 197 | The final output is generated by averaging two versions, with and without test-time augmentation (TTA). 198 | 199 | - Without TTA: use only those records present in the original test set; give zero weight to everything else. 200 | - With TTA: use records present in the original file (weight=5.) and records obtained from translation (weight=1.) 201 | 202 | These files are available in the [output folder](Output/Predictions/Ujjwal). A simple average of these files scores 0.9460 and 0.9446 on the public and private leaderboards respectively. These files are then combined with the scores from my other teammates in the final blend.
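To make the with/without-TTA rollup above concrete, here is a condensed sketch of the weighting logic in post_process.py, added for illustration (the toy numbers are made up; the actual script additionally rescales every prediction file so that its mean score is 0.21 before and after this rollup):

```python
import pandas as pd

# Toy frame: one id scored on its original row (original=1) and on two translated rows.
preds = pd.DataFrame({
    'id':       [0, 0, 0],
    'toxic':    [0.80, 0.60, 0.40],
    'original': [1, 0, 0],
})

def rollup(df, weight_fn):
    out = df.copy()
    out['weight'] = out['original'].map(weight_fn)
    out['toxic'] = out['toxic'] * out['weight']
    out = out.groupby('id')[['toxic', 'weight']].sum().reset_index()
    out['toxic'] = out['toxic'] / out['weight']  # weighted average per id
    return out[['id', 'toxic']]

no_tta   = rollup(preds, lambda x: 1. if x == 1. else 0.)  # original row only -> 0.80
with_tta = rollup(preds, lambda x: 5. if x == 1. else 1.)  # weights 5 vs 1    -> (5*0.8 + 0.6 + 0.4) / 7 ≈ 0.71
```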
203 | 204 | 205 | 206 | 207 | 208 | 209 | -------------------------------------------------------------------------------- /SETTINGS.json: -------------------------------------------------------------------------------- 1 | { 2 | "mas_data_prep_inp_path": "./Input/Common/", 3 | 4 | "mas_data_prep_out_path": "./Input/Moiz/train/", 5 | "mas_train_out_model_path": "./Output/Models/Moiz_dev/", 6 | "mas_predict_inp_path": "./Input/Moiz/test/", 7 | 8 | "mas_train_dev_mode": 0, 9 | "mas_train_inp_path": "./Input/Moiz/train/", 10 | 11 | "mas_predict_dev_mode": 0, 12 | "mas_predict_out_path": "./Output/Predictions/Moiz/", 13 | "mas_predict_model_path": "./Output/Models/Moiz/" 14 | 15 | } -------------------------------------------------------------------------------- /blend.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | import re 5 | 6 | ############# MLM MODELS 7 | ujj_mlm1 = pd.read_csv('Input/Ujjwal/Data/with_tta.csv').rename(columns={'toxic':'mlm1'}) # Public MLM with TTA 8 | ujj_mlm2 = pd.read_csv('Input/Ujjwal/Data/no_tta.csv').rename(columns={'toxic':'mlm2'}) # Public MLM Variant 9 | final_mlm = pd.merge(ujj_mlm1, ujj_mlm2, on='id') 10 | final_mlm['mlm_blend'] = (final_mlm['mlm1'] + final_mlm['mlm2'])/ 2 11 | final_mlm = final_mlm[['id', 'mlm_blend']] 12 | final_mlm['mlm_blend_rank'] = final_mlm['mlm_blend'].rank(pct=True) 13 | 14 | 15 | 16 | ############### MAS MODELS 17 | mas_rank = pd.read_csv('Output/Models/Moiz/submission_MAS.csv').rename(columns={'toxic': 'mas_blend_rank'}) 18 | final_mas = mas_rank 19 | 20 | 21 | ############# MONOLINGUAL BERTS 22 | test_sub_all = pd.read_csv('Output/Models/Moiz/submission_Roberta.csv') 23 | df_test = pd.read_csv('Input/Igor/test.csv.zip') 24 | test_sub_all = test_sub_all.merge(df_test[['id','lang']],on='id',how='left') 25 | test_sub_tr = pd.read_csv('Output/Models/Igor/tr/inf_tr_google_dbmdz.prob.csv') 26 | test_sub_tr2 = pd.read_csv('Output/Models/Igor/tr/inf_tr_google_savasy.prob.csv') 27 | test_sub_it = pd.read_csv('Output/Models/Igor/it/inf_it_yandex_xxl.prob.csv') 28 | test_sub_es = pd.read_csv('Output/Models/Igor/es/inf_es_google_wwm.prob.csv') 29 | test_sub_es2 = pd.read_csv('Output/Models/Igor/es/inf_es_yandex_wwm.prob.csv') 30 | test_sub_ru = pd.read_csv('Output/Models/Igor/ru/inf_ru_google_conv.prob.csv') 31 | test_sub_ru2 = pd.read_csv('Output/Models/Igor/ru/inf_ru_yandex_conv.prob.csv') 32 | test_sub_fr = pd.read_csv('Output/Models/Igor/fr/inf_fr_google_camembert_large.prob.csv') 33 | test_sub_fr2 = pd.read_csv('Output/Models/Igor/fr/inf_fr_yandex_camembert_large.prob.csv') 34 | 35 | test_sub_tr.columns = ['id','tr_toxic'] 36 | test_sub_tr2.columns = ['id','tr_toxic2'] 37 | test_sub_it.columns = ['id','it_toxic'] 38 | test_sub_es.columns = ['id','es_toxic'] 39 | test_sub_es2.columns = ['id','es_toxic2'] 40 | test_sub_ru.columns = ['id','ru_toxic'] 41 | test_sub_ru2.columns = ['id','ru_toxic2'] 42 | test_sub_fr.columns = ['id','fr_toxic'] 43 | test_sub_fr2.columns = ['id','fr_toxic2'] 44 | 45 | test_sub_all = test_sub_all.merge(test_sub_tr[['id','tr_toxic']],on='id',how='left') 46 | test_sub_all = test_sub_all.merge(test_sub_tr2[['id','tr_toxic2']],on='id',how='left') 47 | test_sub_all = test_sub_all.merge(test_sub_it[['id','it_toxic']],on='id',how='left') 48 | test_sub_all = test_sub_all.merge(test_sub_es[['id','es_toxic']],on='id',how='left') 49 | test_sub_all = test_sub_all.merge(test_sub_es2[['id','es_toxic2']],on='id',how='left') 50 | 
test_sub_all = test_sub_all.merge(test_sub_ru[['id','ru_toxic']],on='id',how='left') 51 | test_sub_all = test_sub_all.merge(test_sub_ru2[['id','ru_toxic2']],on='id',how='left') 52 | test_sub_all = test_sub_all.merge(test_sub_fr[['id','fr_toxic']],on='id',how='left') 53 | test_sub_all = test_sub_all.merge(test_sub_fr2[['id','fr_toxic2']],on='id',how='left') 54 | 55 | ######## blend monolingual with single roberta 56 | test_sub_all['pred'] = np.where(test_sub_all.lang == 'tr',0.1*test_sub_all.toxic + 0.45*test_sub_all.tr_toxic + 0.45*test_sub_all.tr_toxic2,test_sub_all.toxic) 57 | test_sub_all['pred'] = np.where(test_sub_all.lang == 'it',0.1*test_sub_all.pred + 0.9*test_sub_all.it_toxic,test_sub_all.pred) 58 | test_sub_all['pred'] = np.where(test_sub_all.lang == 'es',0.1*test_sub_all.pred + 0.45*test_sub_all.es_toxic + 0.45*test_sub_all.es_toxic2,test_sub_all.pred) 59 | test_sub_all['pred'] = np.where(test_sub_all.lang == 'ru',0.1*test_sub_all.pred + 0.45*test_sub_all.ru_toxic + 0.45*test_sub_all.ru_toxic2 ,test_sub_all.pred) 60 | test_sub_all['pred'] = np.where(test_sub_all.lang == 'fr',0.1*test_sub_all.pred + 0.45*test_sub_all.fr_toxic + 0.45*test_sub_all.fr_toxic2,test_sub_all.pred) 61 | 62 | final_igor = test_sub_all[['id', 'lang','pred']] 63 | final_igor['igor_blend_rank'] = final_igor['pred'].rank(pct=True) 64 | 65 | ############## single Roberta from MAS pipeline 66 | single_roberta = pd.read_csv('Output/Models/Moiz/submission_Roberta.csv') 67 | single_roberta = single_roberta.rename(columns={'toxic': 'single_roberta'}) 68 | single_roberta['single_roberta_rank'] = single_roberta['single_roberta'].rank(pct=True) 69 | 70 | 71 | 72 | ############ RANK BLEND 73 | res = pd.merge(pd.merge(pd.merge(final_mlm, final_mas, on='id'), final_igor, on='id'), single_roberta, on='id') 74 | res['blend'] = 0.08 * res['single_roberta_rank'] + 0.29* res['mlm_blend_rank'] + 0.44*res['igor_blend_rank'] + 0.19*res['mas_blend_rank'] 75 | 76 | 77 | 78 | 79 | 80 | ############# POST PROCESSING Bot messages 81 | clusters = pd.read_csv('Input/Patrick/cluster200.csv')[['id', 'cluster']] 82 | test = pd.read_csv('Input/Igor/test.csv.zip') 83 | test['has_ip'] = test['content'].apply(lambda x: re.search(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', x)).map(lambda x: 0 if x is None else 1) 84 | test['has_ip'].value_counts() 85 | 86 | test['is_bot_msg'] = test['content'].str.strip().str.startswith('Bonjour,') & ( 87 | test['content'].str.strip().str.endswith('(bot de maintenance)') | 88 | test['content'].str.strip().str.endswith('Salebot (d)')) 89 | test.groupby('lang')['is_bot_msg'].value_counts() 90 | 91 | 92 | res = pd.merge(res, test[['id', 'has_ip', 'is_bot_msg']], on='id') 93 | res = pd.merge(res, clusters, on='id') 94 | 95 | res['blend_pp'] = np.where( 96 | (res['has_ip']==1) & 97 | (res['lang'].isin(['es', 'tr'])) & 98 | ((0.5 < res['blend']) & (res['blend'] < 0.9)), 1.1*res['blend'], res['blend']) 99 | res[['blend', 'blend_pp']].corr() 100 | 101 | 102 | res['blend_pp'] = np.where((res['is_bot_msg']==1) , 0.5*res['blend_pp'], res['blend_pp']) 103 | res[['blend', 'blend_pp']].corr() 104 | 105 | 106 | res['blend_pp'] = np.where((res['cluster'].isin([39, 52, 100, 14, 111, 80, 42, 196])) , 0.5*res['blend_pp'], res['blend_pp']) 107 | res[['blend', 'blend_pp']].corr() 108 | 109 | res = res.rename(columns={'blend_pp':'toxic'}) 110 | 111 | res[['id', 'toxic']].to_csv('Output/Predictions/submission.csv', index=False) 112 | -------------------------------------------------------------------------------- 
/directory_structure.txt: -------------------------------------------------------------------------------- 1 | ./Code 2 | ./Code/Moiz 3 | ./Code/Igor 4 | ./Code/Ujjwal 5 | ./Code/Patrick 6 | ./Input 7 | ./Input/Common 8 | ./Input/Moiz 9 | ./Input/Igor 10 | ./Input/Ujjwal 11 | ./Input/Patrick 12 | ./Output 13 | ./Output/Models 14 | ./Output/Models/Moiz 15 | ./Output/Models/Igor 16 | ./Output/Models/Ujjwal 17 | ./Output/Models/Patrick 18 | ./Output/Predictions 19 | ./Output/Predictions/Moiz 20 | ./Output/Predictions/Igor 21 | ./Output/Predictions/Ujjwal 22 | ./Output/Predictions/Patrick 23 | -------------------------------------------------------------------------------- /entry_points.md: -------------------------------------------------------------------------------- 1 | 1. python train.py, which would 2 | Read training data from Input directory 3 | Train models. 4 | Save your models to Output 5 | 2. python inference.py, which would 6 | Read test data from Input 7 | Load models from Output/Models/Igor/{lang}/*bin , Output/Models/Moiz/*h5 , Input/Ujjwal/Data/step-[2-3]/*h5 8 | Use models to make predictions on new samples 9 | Save your predictions to Output/Models/Igor/{lang}/*probs.csv Output/Predictions/Moiz/*csv Input/Ujjwal/Data/*tta.csv 10 | Blend predictions from all models and save end prediction file to Output/Predictions/submission.csv 11 | -------------------------------------------------------------------------------- /img/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/moizsaifee/kaggle-jigsaw-multilingual-toxic-comment-classification-3rd-place-solution/6187b7c8c7e8c982d7aa6090b66684458f7072a8/img/1.png -------------------------------------------------------------------------------- /img/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/moizsaifee/kaggle-jigsaw-multilingual-toxic-comment-classification-3rd-place-solution/6187b7c8c7e8c982d7aa6090b66684458f7072a8/img/2.png -------------------------------------------------------------------------------- /img/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/moizsaifee/kaggle-jigsaw-multilingual-toxic-comment-classification-3rd-place-solution/6187b7c8c7e8c982d7aa6090b66684458f7072a8/img/3.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import json 4 | import pandas as pd 5 | from functools import reduce 6 | 7 | with open('./SETTINGS.json') as f: 8 | global_config = json.load(f) 9 | 10 | 11 | def predict_moiz(): 12 | """ 13 | Set up the tuples of input arguments to be able to score all the desired folds of the model. This function 14 | expects the items listed under input to be present in the indicated directories 15 | 16 | Note: This function executes the python scripts through system call, so if any input / output / model is missing 17 | the code will not halt and continue to run - error messages would be displaced on screen though. 
18 | 19 | - Input: 20 | Models: (location - at SETTINGS.json) 21 | Roberta - model_roberta_base.h5, model_roberta_fold0-3 22 | Bert - model_bert_base.h5, model_bert_fold0-3 23 | XLM - model_xlm_base.h5, model_xlm_fold0-2 24 | Scoring Input: (location - at SETTINGS.json) 25 | Preprocessed {Foreign, English, Foreign -> Foreign} x {Roberta, Bert, XLM} tokenized inputs 26 | - Output: (location - at SETTINGS.json) 27 | - pred_roberta_* 28 | - pred_bert_* 29 | - pred_xlm_* 30 | :return: nothing, everything written to file 31 | """ 32 | 33 | score_tuples = [ 34 | 35 | # Roberta - English model, just 1 fold 36 | ("test_roberta_english_p0.pkl", "english", "jplu/tf-xlm-roberta-large", "model_roberta_base.h5", "pred_roberta_english.csv"), 37 | 38 | # Roberta - Foreign, 4 fold 39 | ("test_roberta_foreign_split_p0.pkl", "foreign", "jplu/tf-xlm-roberta-large", "model_roberta_fold0.h5", "pred_roberta_foreign_fold0.csv"), 40 | ("test_roberta_foreign_split_p0.pkl", "foreign", "jplu/tf-xlm-roberta-large", "model_roberta_fold1.h5", "pred_roberta_foreign_fold1.csv"), 41 | ("test_roberta_foreign_split_p0.pkl", "foreign", "jplu/tf-xlm-roberta-large", "model_roberta_fold2.h5", "pred_roberta_foreign_fold2.csv"), 42 | ("test_roberta_foreign_split_p0.pkl", "foreign", "jplu/tf-xlm-roberta-large", "model_roberta_fold3.h5", "pred_roberta_foreign_fold3.csv"), 43 | 44 | # Roberta extra (translated foreign) - 4 fold 45 | ("test_roberta_extra_p0.pkl", "extra", "jplu/tf-xlm-roberta-large", "model_roberta_fold0.h5", "pred_roberta_extra_fold0.csv"), 46 | ("test_roberta_extra_p0.pkl", "extra", "jplu/tf-xlm-roberta-large", "model_roberta_fold1.h5", "pred_roberta_extra_fold1.csv"), 47 | ("test_roberta_extra_p0.pkl", "extra", "jplu/tf-xlm-roberta-large", "model_roberta_fold2.h5", "pred_roberta_extra_fold2.csv"), 48 | ("test_roberta_extra_p0.pkl", "extra", "jplu/tf-xlm-roberta-large", "model_roberta_fold3.h5", "pred_roberta_extra_fold3.csv"), 49 | 50 | # Bert - English model, just 1 fold 51 | ("test_bert_english_p0.pkl", "english", "bert-base-multilingual-cased", "model_bert_base.h5", "pred_bert_english.csv"), 52 | 53 | # Bert - Foreign, 4 fold 54 | ("test_bert_foreign_split_p0.pkl", "foreign", "bert-base-multilingual-cased", "model_bert_fold0.h5", "pred_bert_foreign_fold0.csv"), 55 | ("test_bert_foreign_split_p0.pkl", "foreign", "bert-base-multilingual-cased", "model_bert_fold1.h5", "pred_bert_foreign_fold1.csv"), 56 | ("test_bert_foreign_split_p0.pkl", "foreign", "bert-base-multilingual-cased", "model_bert_fold2.h5", "pred_bert_foreign_fold2.csv"), 57 | ("test_bert_foreign_split_p0.pkl", "foreign", "bert-base-multilingual-cased", "model_bert_fold3.h5", "pred_bert_foreign_fold3.csv"), 58 | 59 | # Bert extra (translated foreign) - 4 fold 60 | ("test_bert_extra_p0.pkl", "extra", "bert-base-multilingual-cased", "model_bert_fold0.h5", "pred_bert_extra_fold0.csv"), 61 | ("test_bert_extra_p0.pkl", "extra", "bert-base-multilingual-cased", "model_bert_fold1.h5", "pred_bert_extra_fold1.csv"), 62 | ("test_bert_extra_p0.pkl", "extra", "bert-base-multilingual-cased", "model_bert_fold2.h5", "pred_bert_extra_fold2.csv"), 63 | ("test_bert_extra_p0.pkl", "extra", "bert-base-multilingual-cased", "model_bert_fold3.h5", "pred_bert_extra_fold3.csv"), 64 | 65 | # XLM - English model, just 1 fold 66 | ("test_xlm_english_p0.pkl", "english", "xlm-mlm-100-1280", "model_xlm_base.h5", "pred_xlm_english.csv"), 67 | 68 | # XLM - Foreign, ~/.tr 3 fold 69 | ("test_xlm_foreign_split_p0.pkl", "foreign", "xlm-mlm-100-1280", "model_xlm_fold0.h5", 
"pred_xlm_foreign_fold0.csv"), 70 | ("test_xlm_foreign_split_p0.pkl", "foreign", "xlm-mlm-100-1280", "model_xlm_fold1.h5", "pred_xlm_foreign_fold1.csv"), 71 | ("test_xlm_foreign_split_p0.pkl", "foreign", "xlm-mlm-100-1280", "model_xlm_fold2.h5", "pred_xlm_foreign_fold2.csv"), 72 | 73 | ] 74 | for score_tuple in score_tuples: 75 | in_file, in_file_type, model_name, model_file, out_file = score_tuple 76 | cmd = f'python ./Code/Moiz/predict.py ' \ 77 | f'--dev_mode={global_config["mas_predict_dev_mode"]} ' \ 78 | f'--in_file="{global_config["mas_predict_inp_path"]}/{in_file}" ' \ 79 | f'--in_file_type="{in_file_type}" ' \ 80 | f'--out_file="{global_config["mas_predict_out_path"]}/{out_file}" ' \ 81 | f'--model_name="{model_name}" ' \ 82 | f'--model_file={global_config["mas_predict_model_path"]}/{model_file} ' 83 | print(cmd) 84 | os.system(cmd) 85 | 86 | 87 | def belnd_roberta_moiz(): 88 | print('Blending Roberta') 89 | # https://www.kaggle.com/moizsaifee/jigsaw-train-v10-step2-mcp-submission 90 | wd = os.getcwd() 91 | os.chdir("./Output/Predictions/Moiz/") 92 | res1 = pd.read_csv('./pred_roberta_foreign_fold0.csv').rename( 93 | columns={'toxic': 'toxic_f1'}) 94 | res2 = pd.read_csv('./pred_roberta_english.csv').rename( 95 | columns={'toxic': 'toxic_e1'}) 96 | res3 = pd.read_csv('./pred_roberta_extra_fold0.csv').rename( 97 | columns={'toxic': 'toxic_ex1'}) 98 | 99 | res4 = pd.read_csv('./pred_roberta_foreign_fold1.csv').rename( 100 | columns={'toxic': 'toxic_f2'}) 101 | res5 = pd.read_csv('./pred_roberta_english.csv').rename( 102 | columns={'toxic': 'toxic_e2'}) 103 | res6 = pd.read_csv('./pred_roberta_extra_fold1.csv').rename( 104 | columns={'toxic': 'toxic_ex2'}) 105 | 106 | res7 = pd.read_csv('./pred_roberta_foreign_fold2.csv').rename( 107 | columns={'toxic': 'toxic_f3'}) 108 | res8 = pd.read_csv('./pred_roberta_english.csv').rename( 109 | columns={'toxic': 'toxic_e3'}) 110 | res9 = pd.read_csv('./pred_roberta_extra_fold2.csv').rename( 111 | columns={'toxic': 'toxic_ex3'}) 112 | 113 | res10 = pd.read_csv('./pred_roberta_foreign_fold3.csv').rename( 114 | columns={'toxic': 'toxic_f4'}) 115 | res11 = pd.read_csv('./pred_roberta_english.csv').rename( 116 | columns={'toxic': 'toxic_e4'}) 117 | res12 = pd.read_csv('./pred_roberta_extra_fold3.csv').rename( 118 | columns={'toxic': 'toxic_ex4'}) 119 | 120 | res = reduce(lambda x, y: pd.merge(x, y, on='id'), 121 | [res1, res2, res3, res4, res5, res6, res7, res8, res9, res10, res11, res12]) 122 | 123 | res.head() 124 | 125 | K = 1. 
/ 2 126 | res['toxic'] = (( 127 | (0.8 * res['toxic_f1'] + 0.2 * (0.5 * res['toxic_e1'] + 0.5 * res['toxic_ex1'])) ** K + 128 | (0.8 * res['toxic_f2'] + 0.2 * (0.5 * res['toxic_e2'] + 0.5 * res['toxic_ex2'])) ** K + 129 | (0.8 * res['toxic_f3'] + 0.2 * (0.5 * res['toxic_e3'] + 0.5 * res['toxic_ex3'])) ** K + 130 | (0.8 * res['toxic_f4'] + 0.2 * (0.5 * res['toxic_e4'] + 0.5 * res['toxic_ex4'])) ** K 131 | ) / 4) ** (1 / K) 132 | os.chdir(wd) 133 | res.to_csv('./Output/Predictions/Moiz/submission_Roberta.csv') 134 | return res[['id', 'toxic']] 135 | 136 | 137 | def blend_bert_moiz(): 138 | print('Blending Bert') 139 | 140 | wd = os.getcwd() 141 | os.chdir("./Output/Predictions/Moiz/") 142 | # https://www.kaggle.com/moizsaifee/bert-jigsaw-train-v10-step2-mcp-submission 143 | res1 = pd.read_csv('./pred_bert_foreign_fold0.csv').rename( 144 | columns={'toxic': 'toxic_f1'}) 145 | res2 = pd.read_csv('./pred_bert_english.csv').rename( 146 | columns={'toxic': 'toxic_e1'}) 147 | res3 = pd.read_csv('./pred_bert_extra_fold0.csv').rename( 148 | columns={'toxic': 'toxic_ex1'}) 149 | 150 | res4 = pd.read_csv('./pred_bert_foreign_fold1.csv').rename( 151 | columns={'toxic': 'toxic_f2'}) 152 | res5 = pd.read_csv('./pred_bert_english.csv').rename( 153 | columns={'toxic': 'toxic_e2'}) 154 | res6 = pd.read_csv('./pred_bert_extra_fold0.csv').rename( 155 | columns={'toxic': 'toxic_ex2'}) 156 | 157 | res7 = pd.read_csv('./pred_bert_foreign_fold2.csv').rename( 158 | columns={'toxic': 'toxic_f3'}) 159 | res8 = pd.read_csv('./pred_bert_english.csv').rename( 160 | columns={'toxic': 'toxic_e3'}) 161 | res9 = pd.read_csv('./pred_bert_extra_fold0.csv').rename( 162 | columns={'toxic': 'toxic_ex3'}) 163 | 164 | res10 = pd.read_csv('./pred_bert_foreign_fold3.csv').rename( 165 | columns={'toxic': 'toxic_f4'}) 166 | res11 = pd.read_csv('./pred_bert_english.csv').rename( 167 | columns={'toxic': 'toxic_e4'}) 168 | res12 = pd.read_csv('./pred_bert_extra_fold0.csv').rename( 169 | columns={'toxic': 'toxic_ex4'}) 170 | 171 | res = reduce(lambda x, y: pd.merge(x, y, on='id'), 172 | [res1, res2, res3, res4, res5, res6, res7, res8, res9, res10, res11, res12]) # ] 173 | 174 | K = 1. 
/ 2 175 | res['toxic'] = (( 176 | (0.8 * res['toxic_f1'] + 0.2 * (0.5 * res['toxic_e1'] + 0.5 * res['toxic_ex1'])) ** K + 177 | (0.8 * res['toxic_f2'] + 0.2 * (0.5 * res['toxic_e2'] + 0.5 * res['toxic_ex2'])) ** K + 178 | (0.8 * res['toxic_f3'] + 0.2 * (0.5 * res['toxic_e3'] + 0.5 * res['toxic_ex3'])) ** K + 179 | (0.8 * res['toxic_f4'] + 0.2 * (0.5 * res['toxic_e4'] + 0.5 * res['toxic_ex4'])) ** K 180 | ) / 4) ** (1 / K) 181 | os.chdir(wd) 182 | return res[['id', 'toxic']] 183 | 184 | 185 | def blend_xlm_moiz(): 186 | print('Blending XLM') 187 | wd = os.getcwd() 188 | os.chdir("./Output/Predictions/Moiz/") 189 | # https://www.kaggle.com/moizsaifee/xlm-jigsaw-train-v10-step2-mcp-submission 190 | res1 = pd.read_csv('./pred_xlm_foreign_fold0.csv').rename( 191 | columns={'toxic': 'toxic_f1'}) 192 | res2 = pd.read_csv('./pred_xlm_english.csv').rename( 193 | columns={'toxic': 'toxic_e1'}) 194 | 195 | res4 = pd.read_csv('./pred_xlm_foreign_fold1.csv').rename( 196 | columns={'toxic': 'toxic_f2'}) 197 | res5 = pd.read_csv('./pred_xlm_english.csv').rename( 198 | columns={'toxic': 'toxic_e2'}) 199 | 200 | res7 = pd.read_csv('./pred_xlm_foreign_fold2.csv').rename( 201 | columns={'toxic': 'toxic_f3'}) 202 | res8 = pd.read_csv('./pred_xlm_english.csv').rename( 203 | columns={'toxic': 'toxic_e3'}) 204 | res = reduce(lambda x, y: pd.merge(x, y, on='id'), [res1, res2, res4, res5, res7, res8]) 205 | 206 | K=1/2 207 | res['toxic'] = (( 208 | (0.8*res['toxic_f1'] + 0.2*(res['toxic_e1']))**K + 209 | (0.8*res['toxic_f2'] + 0.2*(res['toxic_e2']))**K + 210 | (0.8*res['toxic_f3'] + 0.2*(res['toxic_e3']))**K 211 | )/3)**(1/K) 212 | # res['toxic'] = res['toxic_x'] 213 | os.chdir(wd) 214 | return res[['id', 'toxic']] 215 | 216 | 217 | def blend_moiz(): 218 | # https://www.kaggle.com/moizsaifee/moiz-submissions-final-blender?scriptVersionId=35417364 219 | """ 220 | Assuming all the input files are generated using predict_moiz() 221 | This function generate's Moiz's blend 222 | :return: 223 | """ 224 | res_roberta = belnd_roberta_moiz().rename(columns={'toxic': 'toxic_roberta_xlm'}) 225 | res_bert = blend_bert_moiz().rename(columns={'toxic': 'toxic_bert'}) 226 | res_xlm = blend_xlm_moiz().rename(columns={'toxic': 'toxic_xlm'}) 227 | 228 | res = pd.merge(pd.merge(res_roberta, res_bert, on='id'), res_xlm, on='id') 229 | 230 | # The prob blend which scored 9456 231 | # https://www.kaggle.com/moizsaifee/moiz-submissions-final-blender?scriptVersionId=35344067 232 | K = 1 233 | res['toxic_prob'] = (( 234 | (9 * (res['toxic_roberta_xlm'] ** (1 / K))) + 235 | (2 * (res['toxic_xlm'] ** (1 / K))) + 236 | (1 * (res['toxic_bert'] ** (1 / K))) 237 | ) / 12) ** K 238 | 239 | 240 | # Gen the rank blend which scored 9457 241 | # https://www.kaggle.com/moizsaifee/moiz-submissions-final-blender?scriptVersionId=35417364 242 | res['toxic_roberta_xlm'] = res['toxic_roberta_xlm'].rank(pct=True) 243 | res['toxic_xlm'] = res['toxic_xlm'].rank(pct=True) 244 | res['toxic_bert'] = res['toxic_bert'].rank(pct=True) 245 | K = 1 246 | res['toxic'] = (( 247 | (9 * (res['toxic_roberta_xlm'] ** (1 / K))) + 248 | (2 * (res['toxic_xlm'] ** (1 / K))) + 249 | (1 * (res['toxic_bert'] ** (1 / K))) 250 | ) / 12) ** K 251 | 252 | return res[['id', 'toxic', 'toxic_prob']] 253 | 254 | 255 | def predict_ujjwal(): 256 | os.system('/bin/bash inference.sh') 257 | 258 | def predict_igor(): 259 | os.system('pip install torchvision > /dev/null') 260 | os.system('curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o 
pytorch-xla-env-setup.py') 261 | os.system('python pytorch-xla-env-setup.py --version 20200420 --apt-packages libomp5 libopenblas-dev ') 262 | os.system('pip install transformers==2.11.0 > /dev/null') 263 | os.system('pip install pandarallel > /dev/null') 264 | os.system('pip install catalyst==20.4.2 > /dev/null') 265 | 266 | models_list = [['./Output/Models/Igor/fr/inf_fr_yandex_camembert_large','camembert/camembert-large'], 267 | ['./Output/Models/Igor/fr/inf_fr_google_camembert_large','camembert/camembert-large'], 268 | ['./Output/Models/Igor/es/inf_es_google_wwm','dccuchile/bert-base-spanish-wwm-cased'], 269 | ['./Output/Models/Igor/es/inf_es_yandex_wwm','dccuchile/bert-base-spanish-wwm-cased'], 270 | ['./Output/Models/Igor/ru/inf_ru_google_conv','DeepPavlov/rubert-base-cased-conversational'], 271 | ['./Output/Models/Igor/ru/inf_ru_yandex_conv','DeepPavlov/rubert-base-cased-conversational'], 272 | ['./Output/Models/Igor/tr/inf_tr_google_dbmdz','dbmdz/bert-base-turkish-cased'], 273 | ['./Output/Models/Igor/tr/inf_tr_google_savasy','savasy/bert-turkish-text-classification'], 274 | ['./Output/Models/Igor/it/inf_it_yandex_xxl','dbmdz/bert-base-italian-xxl-uncased']] 275 | 276 | for p in models_list: 277 | model_file_prefix = p[0] 278 | backbone = p[1] 279 | cmd = f'python "./Code/Igor/predict.py" --in_file="./Input/Igor/test.csv.zip" ' \ 280 | f'--model_file_prefix="{model_file_prefix}" ' \ 281 | f'--backbone="{backbone}" ' 282 | print(cmd) 283 | os.system(cmd) 284 | 285 | 286 | if __name__ == '__main__': 287 | predict_moiz() 288 | res = blend_moiz() 289 | res.to_csv('./Output/Predictions/Moiz/submission_MAS.csv') 290 | 291 | os.chdir('Code/Ujjwal') 292 | predict_ujjwal() 293 | 294 | os.chdir('../../') 295 | predict_igor() 296 | 297 | os.system('python blend.py') 298 | 299 | -------------------------------------------------------------------------------- /prepare_data_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | 5 | with open('./SETTINGS.json') as f: 6 | global_config = json.load(f) 7 | 8 | 9 | def data_prep_moiz(): 10 | data_typles = [ 11 | ('test_foreign.csv', 'jplu/tf-xlm-roberta-large', 'split', 'test_roberta_foreign_split'), 12 | ('test_english.csv', 'jplu/tf-xlm-roberta-large', 'ignore', 'test_roberta_english'), 13 | ('test_extra.csv', 'jplu/tf-xlm-roberta-large', 'ignore', 'test_roberta_extra'), 14 | 15 | ('test_foreign.csv', 'bert-base-multilingual-cased', 'ignore', 'test_bert_foreign_split'), 16 | ('test_english.csv', 'bert-base-multilingual-cased', 'ignore', 'test_bert_english'), 17 | ('test_extra.csv', 'bert-base-multilingual-cased', 'ignore', 'test_bert_extra'), 18 | 19 | ('test_foreign.csv', 'xlm-mlm-100-1280', 'ignore', 'test_xlm_foreign_split'), 20 | ('test_english.csv', 'xlm-mlm-100-1280', 'ignore', 'test_xlm_english'), 21 | ('test_extra.csv', 'xlm-mlm-100-1280', 'ignore', 'test_xlm_extra'), 22 | 23 | ] 24 | in_file_type= 'test' 25 | should_chunk = 0 26 | max_chunk = 0 27 | 28 | for data_tuple in data_typles: 29 | in_file, model_name, long_comment_action, out_file = data_tuple 30 | cmd = f'python ./Code/Moiz/data_prep.py ' \ 31 | f'--in_file="{global_config["mas_data_prep_inp_path"]}/Processed_Ujjwal/{in_file}" ' \ 32 | f'--in_file_type="{in_file_type}" ' \ 33 | f'--model_name="{model_name}" ' \ 34 | f'--should_chunk={should_chunk} ' \ 35 | f'--max_chunk={max_chunk} ' \ 36 | f'--long_comment_action="{long_comment_action}" ' \ 37 | f'--out_dir="{global_config["mas_predict_inp_path"]}" ' \ 
38 | f'--out_file="{out_file}" ' 39 | print(cmd) 40 | os.system(cmd) 41 | 42 | 43 | if __name__ == '__main__': 44 | data_prep_moiz() 45 | -------------------------------------------------------------------------------- /prepare_data_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | import numpy as np 5 | import time 6 | 7 | with open('./SETTINGS.json') as f: 8 | global_config = json.load(f) 9 | 10 | 11 | def data_preproces_step1_moiz(): 12 | """ 13 | - This function does column rename, minor pre-procesing of the input files read from mas_data_prep_inp_path/Raw/ 14 | and saves in the mas_data_prep_out_path/processed/ path 15 | - mas_data_prep_out_path, mas_data_prep_inp_path are specified in SETTINGS.json 16 | 17 | Need the following files as input: 18 | 1) Jigsaw 2018 English Train Data 19 | 2) Jigsaw 2019 English Train Data 20 | 3) Jigsaw 2020 Validation Data 21 | 4) Jigsaw English -> Foreign Translated Data (Google / Yandex) 22 | 5) Jigsaw 2020 Test Data 23 | 6) Pseudo Labels for Test (Using Public Kernel Output, Not a Major driver of performance) 24 | 7) Subtitles Data - pre-processed and pseudo lables created using an English trained model 25 | 26 | :return: 27 | """ 28 | 29 | ## Train English 30 | print(f'{time.ctime()}: Processing Train English') 31 | inp_2018 = pd.read_csv(f'{global_config["mas_data_prep_inp_path"]}/Raw/jigsaw-toxic-comment-train.csv') 32 | inp_2018['non_toxic_label_max'] = inp_2018[ 33 | ['severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']].apply(lambda x: np.max(x), axis=1) 34 | inp_2018['toxic_float'] = np.maximum(inp_2018['toxic'], inp_2018['non_toxic_label_max']) 35 | inp_2018['lang'] = 'en' 36 | inp_2018 = inp_2018[['id', 'lang', 'comment_text', 'toxic_float']] 37 | 38 | inp_2019 = pd.read_csv(f'{global_config["mas_data_prep_inp_path"]}/Raw/jigsaw-unintended-bias-train.csv') 39 | inp_2019 = inp_2019[['id', 'comment_text', 'toxic', 'severe_toxicity', 40 | 'obscene', 'identity_attack', 'insult', 'threat']] 41 | for col in ['toxic', 'severe_toxicity', 'obscene', 'identity_attack', 'insult', 'threat']: 42 | inp_2019[col] = inp_2019[col].round(1) 43 | # If any of other label is set, count that in toxic too 44 | inp_2019['non_toxic_label_max'] = inp_2019[ 45 | ['severe_toxicity', 'obscene', 'threat', 'insult', 'identity_attack']].apply(lambda x: np.max(x), axis=1) 46 | inp_2019['toxic_float'] = np.maximum(inp_2019['toxic'], inp_2019['non_toxic_label_max']) 47 | inp_2019['lang'] = 'en' 48 | inp_2019 = inp_2019[['id', 'lang', 'comment_text', 'toxic_float']] 49 | train_english = pd.concat([inp_2018, inp_2019], axis=0).reset_index(drop=True) 50 | train_english.to_csv(f'{global_config["mas_data_prep_out_path"]}/processed/train_english.csv') 51 | 52 | print(f'{time.ctime()}: Processing Validation') 53 | ## Preprocess Valid Data 54 | valid = pd.read_csv(f'{global_config["mas_data_prep_inp_path"]}/Raw/validation.csv') 55 | valid = valid.rename(columns={'toxic': 'toxic_float'}) 56 | valid = valid[['id', 'lang', 'comment_text', 'toxic_float']] 57 | valid.to_csv(f'{global_config["mas_data_prep_out_path"]}/processed/valid_foreign.csv') 58 | 59 | print(f'{time.ctime()}: Processing Train Foreign') 60 | ## Train Foreign 61 | train_foreign_ujjwal = pd.read_csv(f'{global_config["mas_data_prep_inp_path"]}/Processed_Ujjwal/train_foreign.csv') 62 | train_foreign_ujjwal['id'] = train_foreign_ujjwal['id'].astype('str') 63 | train_label = 
pd.read_csv(f'{global_config["mas_data_prep_out_path"]}/processed/train_english.csv', low_memory=False) 64 | train_label['id'] = train_label['id'].astype('str') 65 | train_foreign_ujjwal = pd.merge(train_foreign_ujjwal, train_label[['id', 'toxic_float']], on='id', how='left') 66 | train_foreign_ujjwal = train_foreign_ujjwal[~train_foreign_ujjwal['toxic_float'].isnull()].reset_index(drop=True) 67 | train_foreign_ujjwal = train_foreign_ujjwal[['id', 'lang', 'comment_text', 'toxic_float']] 68 | train_foreign_ujjwal.to_csv(f'{global_config["mas_data_prep_out_path"]}/processed/train_foreign.csv') 69 | 70 | print(f'{time.ctime()}: Processing Test Foreign') 71 | # Test Foreign 72 | test = pd.read_csv(f'{global_config["mas_data_prep_inp_path"]}/Raw/test.csv') 73 | test = test.rename(columns={'content': 'comment_text'}) 74 | test = test[['id', 'lang', 'comment_text']] 75 | test.to_csv(f'{global_config["mas_data_prep_out_path"]}/processed/test_foreign.csv', index=False) 76 | 77 | print(f'{time.ctime()}: Processing Test Pseudo') 78 | ## Pseudo 79 | inp = pd.read_csv(f'{global_config["mas_data_prep_out_path"]}/processed/test_foreign.csv') 80 | inp_labels = pd.read_csv(f'{global_config["mas_data_prep_inp_path"]}/Raw/submission_public_lb9463.csv') 81 | inp = pd.merge(inp, inp_labels, on='id') 82 | inp['toxic_float'] = inp['toxic'].round(1) 83 | del inp['toxic'] 84 | inp.to_csv(f'{global_config["mas_data_prep_out_path"]}/processed/test_pseudo.csv', index=False) 85 | 86 | # Subtitles Data 87 | print(f'{time.ctime()}: Processing Subtitles') 88 | inp = pd.read_csv(f'{global_config["mas_data_prep_inp_path"]}/Raw/subtitle_pseudo.csv') 89 | inp = inp[['id', 'lang', 'comment_text', 'toxic_float']] 90 | inp.to_csv(f'{global_config["mas_data_prep_out_path"]}/processed/subtitle_pseudo.csv', index=False) 91 | 92 | 93 | def data_preproces_step2_moiz(): 94 | """ 95 | - This function takes as input files created in step1 and converts text into tokens and saves in appropriate model 96 | directory 97 | - Big files are chunked in files with 100K or so record so that we don't go out of memory when training 98 | 99 | :return: 100 | """ 101 | prefix_mapping = { 102 | 'jplu/tf-xlm-roberta-large': 'roberta', 103 | 'bert-base-multilingual-cased': 'bert', 104 | 'xlm-mlm-100-1280': 'xlm' 105 | } 106 | in_file_type = 'train' 107 | long_comment_action = 'ignore' 108 | 109 | for model_name in ['bert-base-multilingual-cased', 'jplu/tf-xlm-roberta-large', 'xlm-mlm-100-1280']: 110 | model_prefix = prefix_mapping[model_name] 111 | data_tuples = [ 112 | ('train_foreign.csv', 'train_foreign', 1, 6), 113 | ('subtitle_pseudo.csv', 'subtitle', 1, 4), 114 | ('train_english.csv', 'train_english', 1, 2), 115 | ('test_pseudo.csv', 'test_pseudo', 0, 0), 116 | ('valid_foreign.csv', 'valid_foreign', 0, 0) 117 | ] 118 | for data_tuple in data_tuples: 119 | in_file, out_file, should_chunk, max_chunk = data_tuple 120 | print(f'{time.ctime()} Processing {in_file}') 121 | cmd = f'python ./Code/Moiz/data_prep.py ' \ 122 | f'--in_file="{global_config["mas_data_prep_out_path"]}/processed/{in_file}" ' \ 123 | f'--in_file_type="{in_file_type}" ' \ 124 | f'--model_name="{model_name}" ' \ 125 | f'--should_chunk={should_chunk} ' \ 126 | f'--max_chunk={max_chunk} ' \ 127 | f'--long_comment_action="{long_comment_action}" ' \ 128 | f'--out_dir="{global_config["mas_data_prep_out_path"]}/{model_prefix}/" ' \ 129 | f'--out_file="{out_file}" ' 130 | print(cmd) 131 | os.system(cmd) 132 | 133 | 134 | if __name__ == '__main__': 135 | data_preproces_step1_moiz() 136 | 
data_preproces_step2_moiz() -------------------------------------------------------------------------------- /pytorch-xla-env-setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Sample usage: 3 | # python env-setup.py --version 1.5 --apt-packages libomp5 4 | import argparse 5 | import collections 6 | from datetime import datetime 7 | import os 8 | import platform 9 | import re 10 | import requests 11 | import subprocess 12 | import threading 13 | 14 | VersionConfig = collections.namedtuple('VersionConfig', 15 | ['wheels', 'tpu', 'py_version']) 16 | OLDEST_VERSION = datetime.strptime('20200318', '%Y%m%d') 17 | DIST_BUCKET = 'gs://tpu-pytorch/wheels' 18 | TORCH_WHEEL_TMPL = 'torch-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl' 19 | TORCH_XLA_WHEEL_TMPL = 'torch_xla-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl' 20 | TORCHVISION_WHEEL_TMPL = 'torchvision-{whl_version}-cp{py_version}-cp{py_version}m-linux_x86_64.whl' 21 | 22 | 23 | def update_tpu_runtime(tpu_name, version): 24 | print(f'Updating TPU runtime to {version.tpu} ...') 25 | 26 | try: 27 | import cloud_tpu_client 28 | except ImportError: 29 | subprocess.call(['pip', 'install', 'cloud-tpu-client']) 30 | import cloud_tpu_client 31 | 32 | client = cloud_tpu_client.Client(tpu_name) 33 | client.configure_tpu_version(version.tpu) 34 | print('Done updating TPU runtime') 35 | 36 | 37 | def get_py_version(): 38 | version_tuple = platform.python_version_tuple() 39 | return version_tuple[0] + version_tuple[1] # major_version + minor_version 40 | 41 | 42 | def get_version(version): 43 | if version == 'nightly': 44 | return VersionConfig('nightly', 'pytorch-nightly', get_py_version()) 45 | 46 | version_date = None 47 | try: 48 | version_date = datetime.strptime(version, '%Y%m%d') 49 | except ValueError: 50 | pass # Not a dated nightly. 
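# A value that parses as a YYYYMMDD date is handled as a dated nightly build
# below (and rejected when older than OLDEST_VERSION); anything else must look
# like a release version such as "1.5".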
51 | 52 | if version_date: 53 | if version_date < OLDEST_VERSION: 54 | raise ValueError(f'Oldest nightly version available is {OLDEST_VERSION}') 55 | return VersionConfig(f'nightly+{version}', f'pytorch-dev{version}', 56 | get_py_version()) 57 | 58 | version_regex = re.compile('^(\d+\.)+\d+$') 59 | if not version_regex.match(version): 60 | raise ValueError(f'{version} is an invalid torch_xla version pattern') 61 | return VersionConfig(version, f'pytorch-{version}', get_py_version()) 62 | 63 | 64 | def install_vm(version, apt_packages, is_root=False): 65 | torch_whl = TORCH_WHEEL_TMPL.format( 66 | whl_version=version.wheels, py_version=version.py_version) 67 | torch_whl_path = os.path.join(DIST_BUCKET, torch_whl) 68 | torch_xla_whl = TORCH_XLA_WHEEL_TMPL.format( 69 | whl_version=version.wheels, py_version=version.py_version) 70 | torch_xla_whl_path = os.path.join(DIST_BUCKET, torch_xla_whl) 71 | torchvision_whl = TORCHVISION_WHEEL_TMPL.format( 72 | whl_version=version.wheels, py_version=version.py_version) 73 | torchvision_whl_path = os.path.join(DIST_BUCKET, torchvision_whl) 74 | apt_cmd = ['apt-get', 'install', '-y'] 75 | apt_cmd.extend(apt_packages) 76 | 77 | if not is_root: 78 | # Colab/Kaggle run as root, but not GCE VMs so we need privilege 79 | apt_cmd.insert(0, 'sudo') 80 | 81 | installation_cmds = [ 82 | ['pip', 'uninstall', '-y', 'torch', 'torchvision'], 83 | ['gsutil', 'cp', torch_whl_path, '.'], 84 | ['gsutil', 'cp', torch_xla_whl_path, '.'], 85 | ['gsutil', 'cp', torchvision_whl_path, '.'], 86 | ['pip', 'install', torch_whl], 87 | ['pip', 'install', torch_xla_whl], 88 | ['pip', 'install', torchvision_whl], 89 | apt_cmd, 90 | ] 91 | for cmd in installation_cmds: 92 | subprocess.call(cmd) 93 | 94 | 95 | def run_setup(args): 96 | version = get_version(args.version) 97 | # Update TPU 98 | print('Updating TPU and VM. 
This may take around 2 minutes.') 99 | update = threading.Thread( 100 | target=update_tpu_runtime, args=( 101 | args.tpu, 102 | version, 103 | )) 104 | update.start() 105 | install_vm(version, args.apt_packages, is_root=not args.tpu) 106 | update.join() 107 | 108 | 109 | if __name__ == '__main__': 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument( 112 | '--version', 113 | type=str, 114 | default='20200515', 115 | help='Versions to install (nightly, release version, or YYYYMMDD).', 116 | ) 117 | parser.add_argument( 118 | '--apt-packages', 119 | nargs='+', 120 | default=['libomp5'], 121 | help='List of apt packages to install', 122 | ) 123 | parser.add_argument( 124 | '--tpu', 125 | type=str, 126 | help='[GCP] Name of the TPU (same zone, project as VM running script)', 127 | ) 128 | args = parser.parse_args() 129 | run_setup(args) 130 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | alabaster==0.7.12 3 | albumentations==0.1.12 4 | altair==4.1.0 5 | asgiref==3.2.10 6 | astor==0.8.1 7 | astropy==4.0.1.post1 8 | astunparse==1.6.3 9 | atari-py==0.2.6 10 | atomicwrites==1.4.0 11 | attrs==19.3.0 12 | audioread==2.1.8 13 | autograd==1.3 14 | Babel==2.8.0 15 | backcall==0.2.0 16 | beautifulsoup4==4.6.3 17 | bleach==3.1.5 18 | blis==0.4.1 19 | bokeh==1.4.0 20 | boto==2.49.0 21 | boto3==1.14.9 22 | botocore==1.17.9 23 | Bottleneck==1.3.2 24 | branca==0.4.1 25 | bs4==0.0.1 26 | CacheControl==0.12.6 27 | cachetools==4.1.0 28 | catalogue==1.0.0 29 | certifi==2020.6.20 30 | cffi==1.14.0 31 | chainer==6.5.0 32 | chardet==3.0.4 33 | click==7.1.2 34 | cloudpickle==1.3.0 35 | cmake==3.12.0 36 | cmdstanpy==0.4.0 37 | colorlover==0.3.0 38 | community==1.0.0b1 39 | contextlib2==0.5.5 40 | convertdate==2.2.1 41 | coverage==3.7.1 42 | coveralls==0.5 43 | crcmod==1.7 44 | cufflinks==0.17.3 45 | cvxopt==1.2.5 46 | cvxpy==1.0.31 47 | cycler==0.10.0 48 | cymem==2.0.3 49 | Cython==0.29.20 50 | daft==0.0.4 51 | dask==2.12.0 52 | dataclasses==0.7 53 | datascience==0.10.6 54 | decorator==4.4.2 55 | defusedxml==0.6.0 56 | descartes==1.1.0 57 | dill==0.3.2 58 | distributed==1.25.3 59 | Django==3.0.7 60 | dlib==19.18.0 61 | docopt==0.6.2 62 | docutils==0.15.2 63 | dopamine-rl==1.0.5 64 | earthengine-api==0.1.226 65 | easydict==1.9 66 | ecos==2.0.7.post1 67 | editdistance==0.5.3 68 | en-core-web-sm==2.2.5 69 | entrypoints==0.3 70 | ephem==3.7.7.1 71 | et-xmlfile==1.0.1 72 | fa2==0.3.5 73 | fancyimpute==0.4.3 74 | fastai==1.0.61 75 | fastdtw==0.3.4 76 | fastprogress==0.2.3 77 | fastrlock==0.5 78 | fbprophet==0.6 79 | feather-format==0.4.1 80 | featuretools==0.4.1 81 | filelock==3.0.12 82 | firebase-admin==4.1.0 83 | fix-yahoo-finance==0.0.22 84 | Flask==1.1.2 85 | folium==0.8.3 86 | fsspec==0.7.4 87 | future==0.16.0 88 | gast==0.3.3 89 | GDAL==2.2.2 90 | gdown==3.6.4 91 | gensim==3.6.0 92 | geographiclib==1.50 93 | geopy==1.17.0 94 | gin-config==0.3.0 95 | glob2==0.7 96 | google==2.0.3 97 | google-api-core==1.16.0 98 | google-api-python-client==1.7.12 99 | google-auth==1.17.2 100 | google-auth-httplib2==0.0.3 101 | google-auth-oauthlib==0.4.1 102 | google-cloud-bigquery==1.21.0 103 | google-cloud-core==1.0.3 104 | google-cloud-datastore==1.8.0 105 | google-cloud-firestore==1.7.0 106 | google-cloud-language==1.2.0 107 | google-cloud-storage==1.18.1 108 | google-cloud-translate==1.5.0 109 | google-colab==1.0.0 110 | google-pasta==0.2.0 111 | 
google-resumable-media==0.4.1 112 | googleapis-common-protos==1.52.0 113 | googledrivedownloader==0.4 114 | graphviz==0.10.1 115 | grpcio==1.30.0 116 | gspread==3.0.1 117 | gspread-dataframe==3.0.7 118 | gym==0.17.2 119 | h5py==2.10.0 120 | HeapDict==1.0.1 121 | holidays==0.9.12 122 | html5lib==1.0.1 123 | httpimport==0.5.18 124 | httplib2==0.17.4 125 | httplib2shim==0.0.3 126 | humanize==0.5.1 127 | hyperopt==0.1.2 128 | ideep4py==2.0.0.post3 129 | idna==2.9 130 | image==1.5.32 131 | imageio==2.4.1 132 | imagesize==1.2.0 133 | imbalanced-learn==0.4.3 134 | imblearn==0.0 135 | imgaug==0.2.9 136 | importlib-metadata==1.6.1 137 | imutils==0.5.3 138 | inflect==2.1.0 139 | intel-openmp==2020.0.133 140 | intervaltree==2.1.0 141 | ipykernel==4.10.1 142 | ipython==5.5.0 143 | ipython-genutils==0.2.0 144 | ipython-sql==0.3.9 145 | ipywidgets==7.5.1 146 | itsdangerous==1.1.0 147 | jax==0.1.69 148 | jaxlib==0.1.47 149 | jdcal==1.4.1 150 | jedi==0.17.1 151 | jieba==0.42.1 152 | Jinja2==2.11.2 153 | jmespath==0.10.0 154 | joblib==0.15.1 155 | jpeg4py==0.1.4 156 | jsonschema==2.6.0 157 | jupyter==1.0.0 158 | jupyter-client==5.3.4 159 | jupyter-console==5.2.0 160 | jupyter-core==4.6.3 161 | kaggle==1.5.6 162 | kapre==0.1.3.1 163 | Keras==2.3.1 164 | Keras-Applications==1.0.8 165 | Keras-Preprocessing==1.1.2 166 | keras-vis==0.4.1 167 | kiwisolver==1.2.0 168 | knnimpute==0.1.0 169 | librosa==0.6.3 170 | lightgbm==2.2.3 171 | llvmlite==0.31.0 172 | lmdb==0.98 173 | lucid==0.3.8 174 | LunarCalendar==0.0.9 175 | lxml==4.2.6 176 | Markdown==3.2.2 177 | MarkupSafe==1.1.1 178 | matplotlib==3.2.2 179 | matplotlib-venn==0.11.5 180 | missingno==0.4.2 181 | mistune==0.8.4 182 | mizani==0.6.0 183 | mkl==2019.0 184 | mlxtend==0.14.0 185 | more-itertools==8.4.0 186 | moviepy==0.2.3.5 187 | mpmath==1.1.0 188 | msgpack==1.0.0 189 | multiprocess==0.70.10 190 | multitasking==0.0.9 191 | murmurhash==1.0.2 192 | music21==5.5.0 193 | natsort==5.5.0 194 | nbconvert==5.6.1 195 | nbformat==5.0.7 196 | networkx==2.4 197 | nibabel==3.0.2 198 | nltk==3.2.5 199 | notebook==5.2.2 200 | np-utils==0.5.12.1 201 | numba==0.48.0 202 | numexpr==2.7.1 203 | numpy==1.18.5 204 | nvidia-ml-py3==7.352.0 205 | oauth2client==4.1.3 206 | oauthlib==3.1.0 207 | okgrade==0.4.3 208 | opencv-contrib-python==4.1.2.30 209 | opencv-python==4.1.2.30 210 | openpyxl==2.5.9 211 | opt-einsum==3.2.1 212 | osqp==0.6.1 213 | packaging==20.4 214 | palettable==3.3.0 215 | pandas==1.0.5 216 | pandas-datareader==0.8.1 217 | pandas-gbq==0.11.0 218 | pandas-profiling==1.4.1 219 | pandocfilters==1.4.2 220 | parso==0.7.0 221 | pathlib==1.0.1 222 | patsy==0.5.1 223 | pexpect==4.8.0 224 | pickleshare==0.7.5 225 | Pillow==7.0.0 226 | pip-tools==4.5.1 227 | plac==1.1.3 228 | plotly==4.4.1 229 | plotnine==0.6.0 230 | pluggy==0.7.1 231 | portpicker==1.3.1 232 | prefetch-generator==1.0.1 233 | preshed==3.0.2 234 | prettytable==0.7.2 235 | progressbar2==3.38.0 236 | prometheus-client==0.8.0 237 | promise==2.3 238 | prompt-toolkit==1.0.18 239 | protobuf==3.10.0 240 | psutil==5.4.8 241 | psycopg2==2.7.6.1 242 | ptyprocess==0.6.0 243 | py==1.8.2 244 | pyarrow==0.14.1 245 | pyasn1==0.4.8 246 | pyasn1-modules==0.2.8 247 | pycocotools==2.0.1 248 | pycparser==2.20 249 | pydata-google-auth==1.1.0 250 | pydot==1.3.0 251 | pydot-ng==2.0.0 252 | pydotplus==2.0.2 253 | PyDrive==1.3.1 254 | pyemd==0.5.1 255 | pyglet==1.5.0 256 | Pygments==2.1.3 257 | pygobject==3.26.1 258 | pymc3==3.7 259 | PyMeeus==0.3.7 260 | pymongo==3.10.1 261 | pymystem3==0.2.0 262 | PyOpenGL==3.1.5 263 | 
pyparsing==2.4.7 264 | pyrsistent==0.16.0 265 | pysndfile==1.3.8 266 | PySocks==1.7.1 267 | pystan==2.19.1.1 268 | pytest==3.6.4 269 | python-apt==1.6.5+ubuntu0.3 270 | python-chess==0.23.11 271 | python-dateutil==2.8.1 272 | python-louvain==0.14 273 | python-slugify==4.0.0 274 | python-utils==2.4.0 275 | pytz==2018.9 276 | PyWavelets==1.1.1 277 | PyYAML==3.13 278 | pyzmq==19.0.1 279 | qtconsole==4.7.5 280 | QtPy==1.9.0 281 | regex==2019.12.20 282 | requests==2.23.0 283 | requests-oauthlib==1.3.0 284 | resampy==0.2.2 285 | retrying==1.3.3 286 | rpy2==3.2.7 287 | rsa==4.6 288 | s3fs==0.4.2 289 | s3transfer==0.3.3 290 | sacremoses==0.0.43 291 | scikit-image==0.16.2 292 | scikit-learn==0.22.2.post1 293 | scipy==1.4.1 294 | screen-resolution-extra==0.0.0 295 | scs==2.1.2 296 | seaborn==0.10.1 297 | Send2Trash==1.5.0 298 | sentencepiece==0.1.91 299 | setuptools-git==1.2 300 | Shapely==1.7.0 301 | simplegeneric==0.8.1 302 | six==1.12.0 303 | sklearn==0.0 304 | sklearn-pandas==1.8.0 305 | smart-open==2.0.0 306 | snowballstemmer==2.0.0 307 | sortedcontainers==2.2.2 308 | spacy==2.2.4 309 | Sphinx==1.8.5 310 | sphinxcontrib-websupport==1.2.2 311 | SQLAlchemy==1.3.17 312 | sqlparse==0.3.1 313 | srsly==1.0.2 314 | statsmodels==0.10.2 315 | sympy==1.1.1 316 | tables==3.4.4 317 | tabulate==0.8.7 318 | tbb==2020.0.133 319 | tblib==1.6.0 320 | tensorboard==2.2.2 321 | tensorboard-plugin-wit==1.6.0.post3 322 | tensorboardcolab==0.0.22 323 | tensorflow==2.2.0 324 | tensorflow-addons==0.8.3 325 | tensorflow-datasets==2.1.0 326 | tensorflow-estimator==2.2.0 327 | tensorflow-gcs-config==2.2.0 328 | tensorflow-hub==0.8.0 329 | tensorflow-metadata==0.22.2 330 | tensorflow-privacy==0.2.2 331 | tensorflow-probability==0.10.0 332 | termcolor==1.1.0 333 | terminado==0.8.3 334 | testpath==0.4.4 335 | text-unidecode==1.3 336 | textblob==0.15.3 337 | textgenrnn==1.4.1 338 | Theano==1.0.4 339 | thinc==7.4.0 340 | tifffile==2020.6.3 341 | tokenizers==0.7.0 342 | toolz==0.10.0 343 | torch==1.5.1+cu101 344 | torchsummary==1.5.1 345 | torchtext==0.3.1 346 | torchvision==0.6.1+cu101 347 | tornado==4.5.3 348 | tqdm==4.41.1 349 | traitlets==4.3.3 350 | transformers==2.11.0 351 | tweepy==3.6.0 352 | typeguard==2.7.1 353 | typing==3.6.6 354 | typing-extensions==3.6.6 355 | tzlocal==1.5.1 356 | umap-learn==0.4.4 357 | uritemplate==3.0.1 358 | urllib3==1.24.3 359 | vega-datasets==0.8.0 360 | wasabi==0.7.0 361 | wcwidth==0.2.5 362 | webencodings==0.5.1 363 | Werkzeug==1.0.1 364 | widgetsnbextension==3.5.1 365 | wordcloud==1.5.0 366 | wrapt==1.12.1 367 | xarray==0.15.1 368 | xgboost==0.90 369 | xkit==0.0.0 370 | xlrd==1.1.0 371 | xlwt==1.3.0 372 | yellowbrick==0.9.1 373 | zict==2.0.0 374 | zipp==3.1.0 375 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os 3 | import json 4 | import pandas as pd 5 | from functools import reduce 6 | 7 | with open('./SETTINGS.json') as f: 8 | global_config = json.load(f) 9 | 10 | 11 | def train_moiz(): 12 | prefix_mapping = { 13 | 'jplu/tf-xlm-roberta-large': 'roberta', 14 | 'bert-base-multilingual-cased': 'bert', 15 | 'xlm-mlm-100-1280': 'xlm' 16 | } 17 | for model_name in ['jplu/tf-xlm-roberta-large', 'xlm-mlm-100-1280', 'bert-base-multilingual-cased']: 18 | # Train the Base Model 19 | train_prefix = prefix_mapping[model_name] 20 | train_inp_dir = f'{global_config["mas_train_inp_path"]}/{train_prefix}' 21 | train_files= 
[f'{train_inp_dir}/{"train_foreign"}_p{x}.pkl' for x in list(range(0,6))] + \ 22 | [f'{train_inp_dir}/{"train_english"}_p{x}.pkl' for x in list(range(0,2))] 23 | val_file = f'{train_inp_dir}/valid_foreign_p0.pkl' 24 | model_resume_file = "None" 25 | model_save_file = f'{global_config["mas_train_out_model_path"]}/model_{train_prefix}_base.h5' 26 | max_epochs = 10 27 | patience = 3 28 | init_lr = 5e-6 29 | class_balance = 1 30 | class_ratio = 2 31 | label_smoothing = 0 32 | min_thresh = 0.2 33 | max_thresh = 0.7 34 | epoch_split_factor = 5 35 | 36 | cmd = f'python ./Code/Moiz/train.py ' \ 37 | f'--dev_mode={global_config["mas_train_dev_mode"]} ' \ 38 | f'--model_name="{model_name}" ' \ 39 | f'--train_files="{train_files}" ' \ 40 | f'--val_file="{val_file}" ' \ 41 | f'--model_save_file="{model_save_file}" ' \ 42 | f'--max_epochs={max_epochs} ' \ 43 | f'--patience={patience} ' \ 44 | f'--init_lr={init_lr} ' \ 45 | f'--class_balance={class_balance} ' \ 46 | f'--class_ratio={class_ratio} ' \ 47 | f'--label_smoothing={label_smoothing} ' \ 48 | f'--min_thresh={min_thresh} ' \ 49 | f'--max_thresh={max_thresh} ' \ 50 | f'--epoch_split_factor={epoch_split_factor} ' \ 51 | f'--model_resume_file="{model_resume_file}" ' 52 | print(cmd) 53 | os.system(cmd) 54 | # Train the FOLD models 55 | for fold_id in range(4): 56 | # Train the Step #2 57 | train_prefix = prefix_mapping[model_name] 58 | train_inp_dir = f'{global_config["mas_train_inp_path"]}/{train_prefix}' 59 | train_parts = list(range(fold_id, fold_id + 1)) 60 | train_files = [f'{train_inp_dir}/test_pseudo_p0.pkl'] + \ 61 | [f'{train_inp_dir}/{"subtitle"}_p{x}.pkl' for x in train_parts] + \ 62 | [f'{train_inp_dir}/{"train_foreign"}_p{x}.pkl' for x in train_parts] 63 | val_file = f'{train_inp_dir}/valid_foreign_p0.pkl' 64 | model_resume_file = f'{global_config["mas_train_out_model_path"]}/model_{train_prefix}_base.h5' 65 | model_save_file = f'{global_config["mas_train_out_model_path"]}/model_{train_prefix}_fold{fold_id}.h5' 66 | max_epochs = 5 67 | patience = 2 68 | init_lr = 5e-6 69 | class_balance = 1 70 | class_ratio = 3 71 | label_smoothing = 0.1 72 | min_thresh = 0.3 73 | max_thresh = 0.6 74 | epoch_split_factor = 2 75 | 76 | cmd = f'python ./Code/Moiz/train.py ' \ 77 | f'--dev_mode={global_config["mas_train_dev_mode"]} ' \ 78 | f'--model_name="{model_name}" ' \ 79 | f'--train_files="{train_files}" ' \ 80 | f'--val_file="{val_file}" ' \ 81 | f'--model_save_file="{model_save_file}" ' \ 82 | f'--max_epochs={max_epochs} ' \ 83 | f'--patience={patience} ' \ 84 | f'--init_lr={init_lr} ' \ 85 | f'--class_balance={class_balance} ' \ 86 | f'--class_ratio={class_ratio} ' \ 87 | f'--label_smoothing={label_smoothing} ' \ 88 | f'--min_thresh={min_thresh} ' \ 89 | f'--max_thresh={max_thresh} ' \ 90 | f'--epoch_split_factor={epoch_split_factor} ' \ 91 | f'--model_resume_file="{model_resume_file}" ' 92 | 93 | print(cmd) 94 | os.system(cmd) 95 | 96 | # Final Step - Fine Tuning on Validation 97 | train_prefix = prefix_mapping[model_name] 98 | train_inp_dir = f'{global_config["mas_train_inp_path"]}/{train_prefix}' 99 | train_files = [f'{train_inp_dir}/valid_foreign_p0.pkl'] 100 | val_file = f'{train_inp_dir}/valid_foreign_p0.pkl' 101 | model_resume_file = f'{global_config["mas_train_out_model_path"]}/model_{train_prefix}_fold{fold_id}.h5' 102 | model_save_file = f'{global_config["mas_train_out_model_path"]}/model_{train_prefix}_fold{fold_id}.h5' 103 | max_epochs = 1 104 | patience = 1 105 | init_lr = 5e-6 106 | class_balance = 0 107 | class_ratio = 3 108 | 
label_smoothing = 0 109 | min_thresh = 0.3 110 | max_thresh = 0.6 111 | epoch_split_factor = 1 112 | 113 | cmd = f'python ./Code/Moiz/train.py ' \ 114 | f'--dev_mode={global_config["mas_train_dev_mode"]} ' \ 115 | f'--model_name="{model_name}" ' \ 116 | f'--train_files="{train_files}" ' \ 117 | f'--val_file="{val_file}" ' \ 118 | f'--model_save_file="{model_save_file}" ' \ 119 | f'--max_epochs={max_epochs} ' \ 120 | f'--patience={patience} ' \ 121 | f'--init_lr={init_lr} ' \ 122 | f'--class_balance={class_balance} ' \ 123 | f'--class_ratio={class_ratio} ' \ 124 | f'--label_smoothing={label_smoothing} ' \ 125 | f'--min_thresh={min_thresh} ' \ 126 | f'--max_thresh={max_thresh} ' \ 127 | f'--epoch_split_factor={epoch_split_factor} ' \ 128 | f'--model_resume_file="{model_resume_file}" ' 129 | 130 | print(cmd) 131 | os.system(cmd) 132 | 133 | def train_igor(): 134 | os.system('pip install torchvision > /dev/null') 135 | os.system('curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py') 136 | os.system('python pytorch-xla-env-setup.py --version 20200420 --apt-packages libomp5 libopenblas-dev ') 137 | os.system('pip install transformers==2.11.0 > /dev/null') 138 | os.system('pip install pandarallel > /dev/null') 139 | os.system('pip install catalyst==20.4.2 > /dev/null') 140 | train_list=[ 141 | ['it','./Input/Igor/train_data_it_yandex.csv.zip','dbmdz/bert-base-italian-xxl-uncased','./Output/Models/Igor_dev/it/debug_yandex','./Input/Igor/validation.csv.zip','1'], 142 | ['es','./Input/Igor/train_data_es_google.csv.zip','dccuchile/bert-base-spanish-wwm-cased','./Output/Models/Igor_dev/es/debug_google','./Input/Igor/validation.csv.zip','0'], 143 | ['es','./Input/Igor/train_data_es_yandex.csv.zip','dccuchile/bert-base-spanish-wwm-cased','./Output/Models/Igor_dev/es/debug_yandex','./Input/Igor/validation.csv.zip','0'], 144 | ['fr','./Input/Igor/train_data_fr_google.csv.zip','camembert/camembert-large','./Output/Models/Igor_dev/fr/debug_google','no_val','0'], 145 | ['fr','./Input/Igor/train_data_fr_yandex.csv.zip','camembert/camembert-large','./Output/Models/Igor_dev/fr/debug_yandex','no_val','0'], 146 | ['ru','./Input/Igor/train_data_ru_google.csv.zip','DeepPavlov/rubert-base-cased-conversational','./Output/Models/Igor_dev/ru/debug_google','no_val','0'], 147 | ['ru','./Input/Igor/train_data_ru_yandex.csv.zip','DeepPavlov/rubert-base-cased-conversational','./Output/Models/Igor_dev/ru/debug_yandex','no_val','0'], 148 | ['tr','./Input/Igor/train_data_tr_google.csv.zip','dbmdz/bert-base-turkish-cased','./Output/Models/Igor_dev/tr/debug_dbmdz','./Input/Igor/validation.csv.zip','0'], 149 | ['tr','./Input/Igor/train_data_tr_google.csv.zip','savasy/bert-turkish-text-classification','./Output/Models/Igor_dev/tr/debug_savasy','./Input/Igor/validation.csv.zip','0'] 150 | 151 | ] 152 | for r in train_list: 153 | lang = r[0] 154 | input_file = r[1] 155 | backbone = r[2] 156 | model_file_prefix = r[3] 157 | val_file = r[4] 158 | val_tune = r[5] 159 | cmd = f'python ./Code/Igor/train.py --backbone="{backbone}" --model_file_prefix="{model_file_prefix}" --train_file="{input_file}" --val_file="{val_file}" --val_tune={val_tune} --os_file="./Input/Igor/633287_1126366_compressed_open-subtitles-synthesic.csv.zip" --lang="{lang}"' 160 | print(cmd) 161 | os.system(cmd) 162 | 163 | def train_ujjwal(): 164 | os.system('/bin/bash train.sh') 165 | 166 | 167 | if __name__ == '__main__': 168 | train_moiz() 169 | os.chdir('Code/Ujjwal') 170 | train_ujjwal() 171 | 
os.chdir('../../') 172 | train_igor() 173 | --------------------------------------------------------------------------------
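Note on the blending scheme in blend.py: each per-model blend_*_moiz() function combines its per-fold scores with a generalized power mean of exponent K = 1/2, where each per-fold score is a 0.8/0.2 weighted average of the foreign prediction and the english (plus, for Bert and Roberta, the extra) prediction, while blend_moiz() uses K = 1, i.e. a plain weighted average of the three model blends (weights 9, 2, 1), computed both on the raw probabilities and on their percentile ranks. The snippet below is only an illustrative sketch of that power-mean step under those assumptions; the helper name and column names are placeholders, not part of the repository.

import numpy as np
import pandas as pd

def power_mean(df, cols, K=0.5):
    # Generalized power mean over the given prediction columns:
    # K = 1 is the plain average, and smaller K values move the blend
    # towards the geometric mean, as with K = 1/2 in the per-model blends.
    return np.power(df[cols], K).mean(axis=1) ** (1.0 / K)

# Hypothetical usage with two already-weighted fold columns:
# preds = pd.DataFrame({'fold0': [0.10, 0.90], 'fold1': [0.20, 0.80]})
# preds['toxic'] = power_mean(preds, ['fold0', 'fold1'], K=0.5)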