├── Illustration of Code and Dataset.pdf
├── T5_tuning_liftTL.py
├── README.md
├── transfer_learning_CW.py
├── transfer_learning.py
├── auto_detect_span.py
├── Seq2seq_lifted_all.py
├── dataset_creation_GPT3
│   ├── framework1.py
│   └── framework2.py
└── Run.ipynb
--------------------------------------------------------------------------------
/Illustration of Code and Dataset.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yongchao98/NL2TL/HEAD/Illustration of Code and Dataset.pdf
--------------------------------------------------------------------------------
/T5_tuning_liftTL.py:
--------------------------------------------------------------------------------
1 | import nltk
2 | nltk.download('punkt')
3 | import transformers
4 | model_checkpoint = "t5-large" # can choose t5-large or t5-base
5 | print(model_checkpoint)
6 | print('*'*20)
7 | print('\n')
8 | 
9 | # Directly from json file to the dataset_total
10 | from IPython.core import error
11 | import json
12 | from fnmatch import fnmatchcase as match
13 | import random
14 | import os
15 | import pandas as pd
16 | import datasets
17 | from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
18 | import nltk
19 | import numpy as np
20 | 
21 | # when input data is the real training dataset
22 | home_path = 'supple_data/Data_word_inorder_21867/'
23 | # for word predict
24 | original_list = ['combine_dev_seq2tree_idea4.jsonl',
25 |                  'combine_train_seq2tree_idea4.jsonl',
26 |                  'combine_test_seq2tree_idea4.jsonl'
27 |                  ]
28 | 
29 | dataset_total = []
30 | for file in original_list:
31 |     for line in open(home_path + file, 'r', encoding='utf-8'): # input data###########################
32 |         dataset_total.append(json.loads(line))
33 | random.shuffle(dataset_total)
34 | 
35 | len_train = int(0.7*len(dataset_total)); len_dev = int(0.85*len(dataset_total));
36 | 
37 | import csv
38 | f = open(home_path+'/total_data.csv','w')
39 | csv_writer = csv.writer(f)
40 | csv_writer.writerow(['id', 'logic_ltl', 'logic_sentence', 'ltl', 'sentence'])
41 | for i in range(len(dataset_total)):
42 |     csv_writer.writerow([i, ' '.join(dataset_total[i]['logic_ltl']), ' '.join(dataset_total[i]['logic_sentence']), ' '.join(dataset_total[i]['ltl']), ' '.join(dataset_total[i]['sentence'])])
43 | f.close()
44 | 
45 | dataset = load_dataset('csv', data_files=home_path + '/total_data.csv')
46 | 
47 | train_dataset, test_dataset= dataset['train'].train_test_split(test_size=0.1).values()
48 | #dev_dataset, test_dataset = validation_dataset.train_test_split(test_size=0.5).values()
49 | 
50 | from transformers import AutoTokenizer
51 | tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
52 | 
53 | if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
54 |     prefix = "Transform the following sentence into Signal Temporal logic: "
55 | else:
56 |     prefix = ""
57 | 
58 | max_input_length = 1024
59 | max_target_length = 128
60 | def preprocess_function(examples):
61 |     inputs = [prefix + doc for doc in examples["logic_sentence"]]
62 |     model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
63 | 
64 |     # Setup the tokenizer for targets
65 |     with tokenizer.as_target_tokenizer():
66 |         labels = tokenizer(examples["logic_ltl"], max_length=max_target_length, truncation=True)
67 | 
68 |     model_inputs["labels"] = labels["input_ids"]
69 |     model_inputs["logic_sentence"] = examples["logic_sentence"]
70 |     model_inputs["logic_ltl"] = examples["logic_ltl"]
71 |     model_inputs["id"] = examples["id"]
72 |     return model_inputs
73 | 
74 | tokenized_train_dataset =
train_dataset.map(preprocess_function, batched=True) 75 | tokenized_test_dataset = test_dataset.map(preprocess_function, batched=True) 76 | #tokenized_dev_dataset = dev_dataset.map(preprocess_function, batched=True) 77 | 78 | from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer 79 | model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint) 80 | batch_size = 16 81 | model_name = model_checkpoint.split("/")[-1]+"-epoch20-infix-word" 82 | output_dir = '../trained_models/' 83 | model_dir = output_dir+model_name 84 | 85 | args = Seq2SeqTrainingArguments( 86 | model_dir, 87 | model_name, 88 | evaluation_strategy = "epoch", 89 | learning_rate=2e-5, 90 | per_device_train_batch_size=batch_size, 91 | per_device_eval_batch_size=batch_size, 92 | weight_decay=0.01, 93 | seed=1203, 94 | save_total_limit=1, 95 | num_train_epochs=20, 96 | predict_with_generate=True, 97 | fp16=False, 98 | #push_to_hub=True, 99 | #report_to="tensorboard", 100 | #load_best_model_at_end=True, 101 | #save_strategy = "no" 102 | ) 103 | 104 | data_collator = DataCollatorForSeq2Seq(tokenizer, model=model) 105 | 106 | def compute_metrics(eval_pred): 107 | predictions, labels = eval_pred 108 | # print(predictions) 109 | # print(labels) 110 | # Replace -100 in the labels as we can't decode them. 111 | predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id) 112 | decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) 113 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 114 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 115 | count = 0 116 | for i in range(len(decoded_preds)): 117 | pred = nltk.sent_tokenize(decoded_preds[i].strip()) 118 | label = nltk.sent_tokenize(decoded_labels[i].strip()) 119 | if pred == label: 120 | count += 1 121 | return {'top-1 accuracy': round(count / len(decoded_preds), 6)} 122 | 123 | trainer = Seq2SeqTrainer( 124 | model, 125 | args, 126 | train_dataset=tokenized_train_dataset , 127 | eval_dataset=tokenized_test_dataset , 128 | data_collator=data_collator, 129 | tokenizer=tokenizer, 130 | compute_metrics=compute_metrics 131 | ) 132 | trainer.train() 133 | 134 | import torch 135 | from transformers import AutoModelForSeq2SeqLM 136 | 137 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 138 | count_correct = 0 139 | for i in range(min(len(tokenized_test_dataset),1000)): 140 | inputs = [prefix + tokenized_test_dataset[i]['logic_sentence']] 141 | inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, return_tensors="pt").to(device) 142 | output = model.generate(**inputs, num_beams=8, do_sample=True, min_length=10, max_length=64) 143 | decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0] 144 | predicted_title = decoded_output.strip() 145 | if predicted_title == tokenized_test_dataset[i]['logic_ltl']: 146 | count_correct += 1 147 | else: 148 | print(predicted_title) 149 | print(tokenized_test_dataset[i]['logic_ltl']) 150 | print('\n') 151 | print('Accuracy: ', count_correct/(i + 1)) 152 | print('\n'*2) 153 | 154 | with open(output_dir +'result_'+model_name+'.txt', 'w') as f_result: 155 | f_result.write(model_name+' accuracy is: ' + str(count_correct / (i + 1)) + '\n') 156 | f_result.close() 157 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NL2TL (EMNLP 
2023)
2 | Webpage: https://yongchao98.github.io/MIT-realm-NL2TL/
3 | 
4 | Demo Website: http://realm-02.mit.edu:8444
5 | 
6 | Paper Link: https://arxiv.org/pdf/2305.07766.pdf
7 | 
8 | Dataset Link: https://drive.google.com/drive/folders/10F-qyOhpqEi83o9ZojymqRPUtwzSOcfq?usp=sharing
9 | 
10 | Model Link: [https://drive.google.com/drive/folders/1vSaKOunMPA3uiOdx6IDbe-gmfREXQ9uO?usp=share_link](https://drive.google.com/drive/folders/1ZfZoYovWoy5z247VXZWZBniNrCOONX4N?usp=share_link)
11 | 
12 | To access the Demo Website, please send an email to ycchen98@mit.edu or yongchaochen@fas.harvard.edu for the **password**
13 | 
14 | This project transforms human natural language into Signal Temporal Logic (STL). To enhance generalizability, each specific atomic proposition (AP) in the natural language is represented as prop_1, prop_2, etc., so the trained model is easier to transfer into various specific domains. The APs refer to concrete specifications such as grabbing the apple or going to the room.
15 | 
16 | Co-reference is not considered in the current work. Therefore, **each prop_i should appear only once in each sentence**. One inference example is as follows:
17 | 
18 | Input natural language:
19 | 
20 | ```
21 | If ( prop_2 ) happens and continues to happen until at some point during the 176 to 415 time units that ( prop_1 ) , and also if ( prop_3 ) , then the scenario is equivalent to ( prop_4 ) .
22 | ```
23 | 
24 | Output Signal temporal logic:
25 | 
26 | ```
27 | ( ( ( prop_2 until [176,415] prop_1 ) and prop_3 ) equal prop_4 )
28 | ```
29 | 
30 | The operations we use are U (until), F (finally), G (globally), | (or), & (and), -> (imply), <-> (equal), and negation. We also allow time intervals on the operators, like U[0,5], F[12,100], and G[30,150]. Time bounds are currently constrained to integers, and infinite can be used to express all future time, as in [5,infinite]. The following are illustrations; more NL-TL pair examples are at https://drive.google.com/file/d/1f-wQ8AKInlTpXTYKwICRC0eZ-JKjAefh/view?usp=sharing
31 | ```
32 | prop_1 U[0,5] prop_2 : There exists a time point t between 0 and 5 timesteps from now such that prop_1 continues to happen until t, and prop_2 happens at t.
33 | ```
34 | ```
35 | F[12,100] prop_2 : There exists a time point t between 12 and 100 timesteps from now such that prop_2 happens at t.
36 | ```
37 | ```
38 | G[30,150] prop_2 : At every time point between 30 and 150 timesteps from now, prop_2 happens.
39 | ```
40 | ```
41 | prop_1 -> prop_2 : If prop_1 happens, then prop_2 also happens.
42 | ```
43 | ```
44 | prop_1 <-> prop_2 : prop_1 happens if and only if prop_2 happens.
45 | ```
46 | 
47 | ## Description
48 | 
49 | Signal Temporal Logic (STL) is a formal language for specifying properties of signals. It is used to specify properties of continuous-time signals, such as signals from sensors or control systems, in a way that is precise and easy to understand.
50 | 
51 | STL has a syntax that is similar to the temporal logic used in computer science, but it is specialized for continuous-time signals. It includes operators for describing the values of a signal, as well as operators for combining and modifying those descriptions.
52 | 
53 | For example, the STL formula F[0, 2] (x > 0.5) specifies that the signal x is greater than 0.5 at some time point between 0 and 2 seconds. This formula can be read as "the signal x eventually becomes greater than 0.5 within 2 seconds".
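
To make the bounded operators concrete, here is a minimal sketch (illustrative code, not a file from this repository; the signal values and helper names are ours) that evaluates F[a,b] and G[a,b] over a Boolean signal sampled at integer timesteps:

```
from typing import Sequence

def eventually(sig: Sequence[bool], a: int, b: int) -> bool:
    # F[a,b]: true iff the predicate holds at SOME timestep t with a <= t <= b
    return any(sig[a:b + 1])

def globally(sig: Sequence[bool], a: int, b: int) -> bool:
    # G[a,b]: true iff the predicate holds at EVERY timestep t with a <= t <= b
    return all(sig[a:b + 1])

x = [0.2, 0.4, 0.7, 0.9, 0.3, 0.1]   # x sampled at timesteps 0..5
sig = [v > 0.5 for v in x]           # the atomic predicate x > 0.5
print(eventually(sig, 0, 2))         # True: x > 0.5 at t = 2
print(globally(sig, 0, 2))           # False: x <= 0.5 at t = 0 and t = 1
```

U[a,b] works the same way, except that it additionally requires the first proposition to hold at every step before the witnessing timestep.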
54 | 
55 | STL can be used to verify that a signal satisfies a given property, or to synthesize a controller that ensures that a signal satisfies a given property. It is a powerful tool for reasoning about the behavior of continuous-time systems.
56 | 
57 | While STL is quite powerful, humans are more familiar with defining tasks via natural language. Here we try to bridge this gap by fine-tuning large language models.
58 | 
59 | ## Getting Started
60 | 
61 | ### Dependencies
62 | 
63 | * The inference model should run on a GPU; you can run the notebook file Run.ipynb on Google Colab, or run_trained_model.py in your own GPU environment.
64 | * As for setting up the environment, we install ours via Miniconda. You can first set it up via https://docs.conda.io/projects/conda/en/latest/user-guide/install/linux.html
65 | * Then install the packages:
66 | ```
67 | conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia
68 | conda install pip
69 | conda install python
70 | conda install numpy
71 | conda install pandas
72 | pip install transformers
73 | pip install SentencePiece
74 | ```
75 | 
76 | ### Installing
77 | 
78 | * First download the whole directory with the command
79 | ```
80 | git clone git@github.com:yongchao98/NL2TL.git
81 | ```
82 | * Then download the trained weights (e.g., checkpoint-62500) of our model at [https://drive.google.com/file/d/19uiB_2XnnnVmDInaLbQeoZq25ghUdg4D/view](https://drive.google.com/drive/folders/1ZfZoYovWoy5z247VXZWZBniNrCOONX4N?usp=sharing)
83 | * After downloading both the code and the model weights, put the model weights checkpoint-62500 into your self-defined directory (a minimal usage sketch is given in the Quick Inference Example section below).
84 | 
85 | ### Other codes and datasets
86 | 
87 | * As for the other codes and datasets published on GitHub, please read **Illustration of Code and Dataset.pdf** for a specific explanation of their utilities.
88 | 
89 | ## Authors
90 | 
91 | Contributors' names and contact info
92 | 
93 | Yongchao Chen (Harvard University, Massachusetts Institute of Technology, Laboratory of Information and Decision Systems)
94 | 
95 | yongchaochen@fas.harvard.edu or ycchen98@mit.edu
96 | 
97 | ## Citation for BibTeX
98 | 
99 | @article{chen2023nl2tl,
100 |   title={NL2TL: Transforming Natural Languages to Temporal Logics using Large Language Models},
101 |   author={Chen, Yongchao and Gandhi, Rujul and Zhang, Yang and Fan, Chuchu},
102 |   journal={arXiv preprint arXiv:2305.07766},
103 |   year={2023}
104 | }
105 | 
106 | 
107 | ## Version History
108 | 
109 | * 0.1
110 |     * Initial Release on May 12, 2023
111 | 
112 | ## License
113 | 
114 | The corresponding paper of this project will be attached here in the coming months. This project can only be used commercially with our permission.
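
## Quick Inference Example

As a quick check that the downloaded weights load correctly, here is a minimal sketch (illustrative, not a script in this repository; `path/to/checkpoint-62500` is a placeholder for your self-defined directory), which reuses the task prefix and generation settings from the training scripts:

```
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_dir = "path/to/checkpoint-62500"  # placeholder: your self-defined directory
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# If the tokenizer files are not stored inside the checkpoint folder,
# load the tokenizer from the base model (e.g., "t5-large") instead.
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir).to(device)

# The same task prefix used throughout the training scripts in this repository
prefix = "Transform the following sentence into Signal Temporal logic: "
sentence = ("If ( prop_2 ) happens and continues to happen until at some point "
            "during the 176 to 415 time units that ( prop_1 ) , and also if "
            "( prop_3 ) , then the scenario is equivalent to ( prop_4 ) .")

inputs = tokenizer(prefix + sentence, return_tensors="pt").to(device)
output = model.generate(**inputs, num_beams=8, min_length=10, max_length=64)
print(tokenizer.batch_decode(output, skip_special_tokens=True)[0])
```

On the example sentence above, the expected output is ( ( ( prop_2 until [176,415] prop_1 ) and prop_3 ) equal prop_4 ).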
115 | 116 | ## Recommended Work 117 | 118 | [AutoTAMP: Autoregressive Task and Motion Planning with LLMs as Translators and Checkers](https://arxiv.org/pdf/2306.06531.pdf) 119 | 120 | [Scalable Multi-Robot Collaboration with Large Language Models: Centralized or Decentralized Systems?](https://yongchao98.github.io/MIT-REALM-Multi-Robot/) 121 | 122 | -------------------------------------------------------------------------------- /transfer_learning_CW.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | nltk.download('punkt') 3 | import transformers 4 | 5 | # Directly from json file to the dataset_total 6 | from IPython.core import error 7 | import json 8 | from fnmatch import fnmatchcase as match 9 | import random 10 | import os 11 | import pandas as pd 12 | import datasets 13 | from datasets import Dataset, DatasetDict, load_dataset, load_from_disk 14 | import nltk 15 | import numpy as np 16 | from argparse import ArgumentParser 17 | 18 | parser = ArgumentParser() 19 | parser.add_argument('-seed', '--seed', type=int, default=1203) # input random seed 20 | parser.add_argument('-init_weight', '--init_weight', default='with_pre-train') # The initial weight is pre-trained by us or pure from T5 21 | parser.add_argument('-model_checkpoint', '--model_checkpoint', default='t5-base') # The model type t5-base and t5-large 22 | args = parser.parse_args() 23 | 24 | int_seed = args.seed 25 | init_weight = args.init_weight 26 | model_checkpoint = args.model_checkpoint 27 | 28 | print(model_checkpoint) 29 | print('*'*20) 30 | print('\n') 31 | 32 | # when input data is the real training dataset 33 | home_path = 'Data_transfer_domain/' 34 | # for word predict 35 | original_list = [ 36 | 'CW_total_3382_for_transfer_word_midfix.jsonl' 37 | ] 38 | 39 | dataset_total = [] 40 | for file in original_list: 41 | for line in open(home_path + file, 'r', encoding='utf-8'): # input data########################### 42 | dataset_total.append(json.loads(line)) 43 | random.shuffle(dataset_total) 44 | 45 | def preprocess_function(examples): 46 | inputs = [prefix + doc for doc in examples["sentence"]] 47 | model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True) 48 | 49 | # Setup the tokenizer for targets 50 | with tokenizer.as_target_tokenizer(): 51 | labels = tokenizer(examples["ltl"], max_length=max_target_length, truncation=True) 52 | 53 | model_inputs["labels"] = labels["input_ids"] 54 | model_inputs["sentence"] = examples["sentence"] 55 | model_inputs["ltl"] = examples["ltl"] 56 | model_inputs["id"] = examples["id"] 57 | return model_inputs 58 | 59 | def compute_metrics(eval_pred): 60 | predictions, labels = eval_pred 61 | # print(predictions) 62 | # print(labels) 63 | # Replace -100 in the labels as we can't decode them. 
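    # DataCollatorForSeq2Seq pads label ids with -100 so that the cross-entropy
    # loss ignores padding positions; they must be mapped back to
    # tokenizer.pad_token_id before batch_decode, as the next lines do.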
64 |     predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
65 |     decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
66 |     labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
67 |     decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
68 |     count = 0
69 |     for i in range(len(decoded_preds)):
70 |         pred = nltk.sent_tokenize(decoded_preds[i].strip())
71 |         label = nltk.sent_tokenize(decoded_labels[i].strip())
72 |         if pred == label:
73 |             count += 1
74 |     return {'top-1 accuracy': round(count / len(decoded_preds), 6)}
75 | 
76 | from transformers import AutoTokenizer
77 | tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
78 | 
79 | if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
80 |     prefix = "Transform the following sentence into Signal Temporal logic: "
81 | else:
82 |     prefix = ""
83 | 
84 | max_input_length = 1024
85 | max_target_length = 128
86 | 
87 | unique_ltl = []
88 | for i in range(len(dataset_total)):
89 |     if dataset_total[i]['ltl'] not in unique_ltl:
90 |         unique_ltl.append(dataset_total[i]['ltl'])
91 | 
92 | from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
93 | 
94 | input_model_dir = 'dir_to_save_the_weights_of_pre-trained_T5_on_lifted_NL_TL/'
95 | if init_weight == 'with_pre-train':
96 |     model = AutoModelForSeq2SeqLM.from_pretrained(
97 |         input_model_dir + model_checkpoint + "-finetuned-epoch20/checkpoint-13000")
98 | elif init_weight == 'without_pre-train':
99 |     model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
100 | else:
101 |     print('Initial model weights error!')
102 | batch_size = 16
103 | 
104 | dataset_name = 'CW'
105 | output_dir_total = '../trained_models/' + dataset_name + '/'
106 | if not os.path.exists(output_dir_total):
107 |     os.mkdir(output_dir_total)
108 | 
109 | output_dir = output_dir_total+dataset_name+'_varied_dataset_size_seed'+str(int_seed)+'_'+init_weight+'_'+model_checkpoint+'/'
110 | if not os.path.exists(output_dir):
111 |     os.mkdir(output_dir)
112 | 
113 | with open(output_dir +'result.txt', 'w') as f_result:
114 |     for a in range(4,40,4):
115 |         random.shuffle(unique_ltl)
116 |         dataset_train = []; dataset_test = []
117 |         for i in range(len(dataset_total)):
118 |             if dataset_total[i]['ltl'] in unique_ltl[0:a]:
119 |                 dataset_train.append(dataset_total[i])
120 |             else:
121 |                 dataset_test.append(dataset_total[i])
122 |         print('The num of training class types is: ', a)
123 |         print('The num of testing class types is: ', 39-a)
124 |         print('The num of training dataset is: ', len(dataset_train))
125 |         print('The num of testing dataset is: ',len(dataset_test))
126 |         print('\n'*2)
127 | 
128 |         import csv
129 |         f = open(home_path+'/train_data.csv','w')
130 |         csv_writer = csv.writer(f)
131 |         csv_writer.writerow(['id', 'ltl', 'sentence'])
132 |         for i in range(len(dataset_train)):
133 |             csv_writer.writerow([i, ' '.join(dataset_train[i]['ltl']), ' '.join(dataset_train[i]['sentence'])])
134 |         f.close()
135 | 
136 |         f = open(home_path+'/test_data.csv','w')
137 |         csv_writer = csv.writer(f)
138 |         csv_writer.writerow(['id', 'ltl', 'sentence'])
139 |         for i in range(len(dataset_test)):
140 |             csv_writer.writerow([i, ' '.join(dataset_test[i]['ltl']), ' '.join(dataset_test[i]['sentence'])])
141 |         f.close()
142 | 
143 |         data_files = {"train": home_path + '/train_data.csv', "test": home_path + '/test_data.csv'}
144 |         dataset = load_dataset("csv", data_files=data_files)
145 | 
146 |         train_dataset,
test_dataset= dataset.values() 147 | 148 | tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True) 149 | tokenized_test_dataset = test_dataset.map(preprocess_function, batched=True) 150 | 151 | model_name = model_checkpoint.split("/")[-1]+"-CW-epoch20-"+'train_typenum'+str(a) 152 | model_dir = output_dir+model_name 153 | 154 | args = Seq2SeqTrainingArguments( 155 | model_dir, 156 | model_name, 157 | evaluation_strategy = "epoch", 158 | learning_rate=2e-5, 159 | per_device_train_batch_size=batch_size, 160 | per_device_eval_batch_size=batch_size, 161 | weight_decay=0.01, 162 | seed=int_seed, 163 | save_total_limit=1, 164 | num_train_epochs=20, 165 | predict_with_generate=True, 166 | fp16=False 167 | ) 168 | data_collator = DataCollatorForSeq2Seq(tokenizer, model=model) 169 | 170 | trainer = Seq2SeqTrainer( 171 | model, 172 | args, 173 | train_dataset=tokenized_train_dataset , 174 | eval_dataset=tokenized_test_dataset , 175 | data_collator=data_collator, 176 | tokenizer=tokenizer, 177 | compute_metrics=compute_metrics 178 | ) 179 | trainer.train() 180 | 181 | import torch 182 | from transformers import AutoModelForSeq2SeqLM 183 | 184 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 185 | 186 | count_correct = 0 187 | for i in range(min(len(tokenized_test_dataset),1000)): 188 | inputs = [prefix + tokenized_test_dataset[i]['sentence']] 189 | inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, return_tensors="pt").to(device) 190 | output = model.generate(**inputs, num_beams=8, do_sample=True, min_length=10, max_length=64) 191 | decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0] 192 | predicted_title = decoded_output.strip() 193 | if predicted_title == tokenized_test_dataset[i]['ltl']: 194 | count_correct += 1 195 | else: 196 | print(predicted_title) 197 | print(tokenized_test_dataset[i]['ltl']) 198 | print('\n') 199 | print('The training type number is: ', a) 200 | print('Accuracy: ', count_correct/(i + 1)) 201 | print('\n'*2) 202 | f_result.write(str(a) + ' ' + str(count_correct / (i + 1)) + '\n') 203 | f_result.close() 204 | #trainer.save_model() 205 | -------------------------------------------------------------------------------- /transfer_learning.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | nltk.download('punkt') 3 | import transformers 4 | 5 | # Directly from json file to the dataset_total 6 | from IPython.core import error 7 | import json 8 | from fnmatch import fnmatchcase as match 9 | import random 10 | import os 11 | import pandas as pd 12 | import datasets 13 | from datasets import Dataset, DatasetDict, load_dataset, load_from_disk 14 | import numpy as np 15 | from argparse import ArgumentParser 16 | 17 | parser = ArgumentParser() 18 | parser.add_argument('-seed', '--seed', type=int, default=1203) # input random seed 19 | parser.add_argument('-name', '--name', default='GLTL') # The dataset name to train GLTL, circuit, navi 20 | parser.add_argument('-init_weight', '--init_weight', default='with_pre-train') # The initial weight is pre-trained by us or pure from T5 21 | parser.add_argument('-data_size', '--data_size', default='0.1-0.9') # The dataset size range '0.1-0.9' or '0.01-0.09' 22 | parser.add_argument('-model_checkpoint', '--model_checkpoint', default='t5-base') # The model type t5-base and t5-large 23 | args = parser.parse_args() 24 | int_seed = args.seed 25 | dataset_name = args.name 26 | init_weight = args.init_weight 27 | 
data_size = args.data_size 28 | model_checkpoint = args.model_checkpoint 29 | 30 | print(model_checkpoint) 31 | print('*'*20) 32 | print('\n') 33 | 34 | # when input data is the ground full NL-TL training dataset 35 | home_path = 'Data_transfer_domain/' 36 | # for word predict 37 | if dataset_name == 'GLTL': 38 | original_list = [ 39 | 'GLTL_train_8923_for_transfer_word_midfix.jsonl', 40 | 'GLTL_test_2232_for_transfer_word_midfix.jsonl' 41 | ] 42 | elif dataset_name == 'navi': 43 | original_list = [ 44 | 'navi_total_refined.jsonl' 45 | ] 46 | elif dataset_name == 'circuit': 47 | original_list = [ 48 | 'circuit_total_refined.jsonl' 49 | ] 50 | else: print('dataset error!') 51 | 52 | dataset_total = [] 53 | for file in original_list: 54 | for line in open(home_path + file, 'r', encoding='utf-8'): # input data########################### 55 | dataset_total.append(json.loads(line)) 56 | random.shuffle(dataset_total) 57 | 58 | import csv 59 | f = open(home_path+'/total_data.csv','w') 60 | csv_writer = csv.writer(f) 61 | csv_writer.writerow(['id', 'ltl', 'sentence']) 62 | for i in range(len(dataset_total)): 63 | csv_writer.writerow([i, ' '.join(dataset_total[i]['ltl']), ' '.join(dataset_total[i]['sentence'])]) 64 | f.close() 65 | 66 | dataset = load_dataset('csv', data_files=home_path + '/total_data.csv') 67 | 68 | from transformers import AutoTokenizer 69 | tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) 70 | 71 | if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]: 72 | prefix = "Transform the following sentence into Signal Temporal logic: " 73 | else: 74 | prefix = "" 75 | 76 | max_input_length = 1024 77 | max_target_length = 128 78 | def preprocess_function(examples): 79 | inputs = [prefix + doc for doc in examples["sentence"]] 80 | model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True) 81 | 82 | # Setup the tokenizer for targets 83 | with tokenizer.as_target_tokenizer(): 84 | labels = tokenizer(examples["ltl"], max_length=max_target_length, truncation=True) 85 | 86 | model_inputs["labels"] = labels["input_ids"] 87 | model_inputs["sentence"] = examples["sentence"] 88 | model_inputs["ltl"] = examples["ltl"] 89 | model_inputs["id"] = examples["id"] 90 | return model_inputs 91 | 92 | def compute_metrics(eval_pred): 93 | predictions, labels = eval_pred 94 | # print(predictions) 95 | # print(labels) 96 | # Replace -100 in the labels as we can't decode them. 
97 | predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id) 98 | decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) 99 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 100 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 101 | count = 0 102 | for i in range(len(decoded_preds)): 103 | pred = nltk.sent_tokenize(decoded_preds[i].strip()) 104 | label = nltk.sent_tokenize(decoded_labels[i].strip()) 105 | if pred == label: 106 | count += 1 107 | return {'top-1 accuracy': round(count / len(decoded_preds), 6)} 108 | 109 | def correct_parenthe(input_str): 110 | count = 0 111 | original_list = input_str.split(' ') 112 | for index, item in enumerate(original_list): 113 | if len(item) >2: 114 | if item[-1] == '.': 115 | original_list[index] = original_list[index][:-1] 116 | if item == '(': 117 | count += 1 118 | elif item == ')': 119 | count -= 1 120 | if count >0: 121 | for i in range(count): 122 | original_list.append(')') 123 | if count <0: 124 | for i in range(-count): 125 | original_list.pop(-1) 126 | return ' '.join(original_list) 127 | 128 | from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer 129 | output_dir_total = '../trained_models/' + dataset_name + '/' 130 | if not os.path.exists(output_dir_total): 131 | os.mkdir(output_dir_total) 132 | 133 | if data_size == '0.1-0.9': 134 | output_dir = output_dir_total+dataset_name+'_varied_dataset_size_seed'+str(int_seed)+'_one'+'_'+init_weight+'_'+model_checkpoint+'/' 135 | elif data_size == '0.01-0.09': 136 | output_dir = output_dir_total+dataset_name+'_varied_dataset_size_seed'+str(int_seed)+'_pointone'+'_'+init_weight+'_'+model_checkpoint+'/' 137 | if not os.path.exists(output_dir): 138 | os.mkdir(output_dir) 139 | 140 | with open(output_dir +'result.txt', 'w') as f_result: 141 | for i in range(9): 142 | if data_size == '0.1-0.9': 143 | train_dataset, test_dataset = dataset['train'].train_test_split(test_size=0.9-0.1*i).values() 144 | elif data_size == '0.01-0.09': 145 | train_dataset, test_dataset = dataset['train'].train_test_split(test_size=0.99 - 0.01 * i).values() 146 | else: print('Datasize error!') 147 | tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True) 148 | tokenized_test_dataset = test_dataset.map(preprocess_function, batched=True) 149 | 150 | input_model_dir = 'dir_to_save_the_weights_of_pre-trained_T5_on_lifted_NL_TL/' 151 | if init_weight == 'with_pre-train': 152 | model = AutoModelForSeq2SeqLM.from_pretrained(input_model_dir+model_checkpoint+"-finetuned-epoch20/checkpoint-13000") 153 | elif init_weight == 'without_pre-train': 154 | model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint) 155 | else: print('Initial model weights error!') 156 | batch_size = 16 157 | model_name = model_checkpoint.split("/")[-1]+'-'+dataset_name+"-epoch20-trainpoint"+str(i+1) 158 | model_dir = output_dir+model_name 159 | 160 | args = Seq2SeqTrainingArguments( 161 | model_dir, 162 | model_name, 163 | evaluation_strategy = "epoch", 164 | learning_rate=2e-5, 165 | per_device_train_batch_size=batch_size, 166 | per_device_eval_batch_size=batch_size, 167 | weight_decay=0.01, 168 | seed=int_seed, 169 | save_total_limit=1, 170 | num_train_epochs=20, 171 | predict_with_generate=True, 172 | fp16=False, 173 | #push_to_hub=True, 174 | #report_to="tensorboard", 175 | #load_best_model_at_end=True, 176 | #save_strategy = "no" 177 | ) 178 | 179 | data_collator = 
DataCollatorForSeq2Seq(tokenizer, model=model) 180 | 181 | trainer = Seq2SeqTrainer( 182 | model, 183 | args, 184 | train_dataset=tokenized_train_dataset , 185 | eval_dataset=tokenized_test_dataset , 186 | data_collator=data_collator, 187 | tokenizer=tokenizer, 188 | compute_metrics=compute_metrics 189 | ) 190 | trainer.train() 191 | 192 | import torch 193 | from transformers import AutoModelForSeq2SeqLM 194 | 195 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 196 | #model = AutoModelForSeq2SeqLM.from_pretrained(output_dir+model_name+'/checkpoint-13000').to(device) 197 | count_correct_w_parenthe = 0 198 | count_correct_wo_parenthe = 0 199 | for j in range(min(len(tokenized_test_dataset),1000)): 200 | inputs = [prefix + tokenized_test_dataset[j]['sentence']] 201 | 202 | inputs = tokenizer(inputs, max_length=max_input_length, truncation=True, return_tensors="pt").to(device) 203 | output = model.generate(**inputs, num_beams=8, do_sample=True, min_length=10, max_length=64) 204 | decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0] 205 | #predicted_title = nltk.sent_tokenize(decoded_output.strip())[0] 206 | predicted_title = decoded_output.strip() 207 | if correct_parenthe(predicted_title) == tokenized_test_dataset[j]['ltl']: 208 | count_correct_w_parenthe += 1 209 | else: 210 | print(correct_parenthe(predicted_title)) 211 | print(tokenized_test_dataset[j]['ltl']) 212 | print('\n') 213 | if predicted_title == tokenized_test_dataset[j]['ltl']: 214 | count_correct_wo_parenthe += 1 215 | if data_size == '0.1-0.9': 216 | print('The training data size is: ', (i+1)*0.1) 217 | f_result.write(str((i + 1) * 0.1) + ' ' + str(count_correct_w_parenthe / (j+1)) + ' ' + str(count_correct_wo_parenthe / (j+1)) + '\n') 218 | elif data_size == '0.01-0.09': 219 | print('The training data size is: ', (i+1) * 0.01) 220 | f_result.write(str((i + 1) * 0.01) + ' ' + str(count_correct_w_parenthe / (j+1)) + ' ' + str(count_correct_wo_parenthe / (j+1)) + '\n') 221 | print('The training data number is: ',len(tokenized_train_dataset)) 222 | print('Accuracy with parentheses correction: ', count_correct_w_parenthe/(j+1)) 223 | print('Accuracy without parentheses correction: ', count_correct_wo_parenthe / (j + 1)) 224 | f_result.close() 225 | #trainer.save_model() 226 | -------------------------------------------------------------------------------- /auto_detect_span.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import spacy_transformers 3 | import benepar, spacy 4 | from nltk.tree import Tree 5 | from tqdm import tqdm 6 | import re 7 | import copy 8 | import json 9 | from spacy.language import Language 10 | import random 11 | 12 | benepar.download('benepar_en3') 13 | data_portion = 'total' 14 | #data_size = 1500 15 | 16 | 17 | nlp = spacy.load('en_core_web_trf') 18 | nlp.add_pipe('benepar', config={'model': 'benepar_en3'}) 19 | 20 | # --------------------------- customize functions -------------------------- # 21 | @Language.component("prevent_sbd") 22 | def prevent_sbd(doc): 23 | for token in doc: 24 | # This will entirely disable spaCy's sentence detection 25 | token.is_sent_start = False 26 | return doc 27 | nlp.add_pipe('prevent_sbd', before='parser') 28 | # -------------------------------------------------------------------------- # 29 | 30 | logic_set = set(['G', 'F', '(' , ')', '&', '|', 'U', '[', ']', ':', 'hist|ically', 'U', '->']) 31 | operation_set = set(['>=', '<=', '>', '<', '->', '==', 'rise', 'fall', 
'=', '-']) 32 | 33 | def get_proposition(ltl): 34 | """ 35 | this function finds the proposition (action object) and automaton rise action object 36 | """ 37 | ltl = re.sub('\[ [0-9]+ : [0-9]+ \]', '', ltl) 38 | while ' ' in ltl: 39 | ltl = ltl.replace(' ',' ') 40 | ltl = ltl.split(' ') 41 | # print(ltl) 42 | propositions = [] 43 | proposition = [] 44 | autos = [] 45 | auto = [] 46 | for op in ltl: 47 | # if op == 'rise' or op == 'fall': 48 | # auto.append(op) 49 | if not op in logic_set: 50 | auto.append(op) 51 | if not op in operation_set: 52 | proposition.append(op) 53 | if len(proposition) == 2: 54 | propositions.append(tuple(proposition)) 55 | proposition = [] 56 | autos.append(auto) 57 | auto = [] 58 | return propositions, autos 59 | 60 | def get_prop_idx(propositions, tree): 61 | """ 62 | Find the index of corresponding proposition in original sentence 63 | """ 64 | prop2idx = [] 65 | tokens = tree.leaves() 66 | for prop in propositions: 67 | act_pos_list = [] 68 | obj_pos_list = [] 69 | for pos, t in enumerate(tokens): 70 | # if not pos in used_index: 71 | if t == prop[0]: 72 | act_pos_list.append(pos) 73 | elif t == prop[1]: 74 | obj_pos_list.append(pos) 75 | assert len(act_pos_list) > 0 and len(obj_pos_list) > 0 76 | min_dis = 10000000 77 | 78 | for p1 in act_pos_list: 79 | for p2 in obj_pos_list: 80 | if abs(p1-p2) < min_dis: 81 | min_dis = abs(p1-p2) 82 | act_pos = p1 83 | obj_pos = p2 84 | # used_index.add(act_pos) 85 | # used_index.add(obj_pos) 86 | prop2idx.append({'prop':prop, 'pos':[act_pos, obj_pos] if act_pos < obj_pos else [obj_pos, act_pos]}) 87 | 88 | return prop2idx 89 | 90 | def find_prop_spans(prop2idx, tree): 91 | """ 92 | find the span of proposition in sentence 93 | """ 94 | indexes = [] 95 | for prop in prop2idx: 96 | pos = prop['pos'] 97 | tree_pos_s = tree.leaf_treeposition(pos[0]) 98 | tree_pos_e = tree.leaf_treeposition(pos[1]) 99 | parent = [] 100 | branch = [] 101 | for s, e in zip(tree_pos_s, tree_pos_e): 102 | if s == e: 103 | parent.append(s) 104 | else: 105 | branch.append(s) 106 | branch.append(e) 107 | break 108 | token_seq = [] 109 | for i in range(branch[0], branch[1]+1): 110 | t_index = copy.copy(parent) 111 | t_index.append(i) 112 | subtree = tree[t_index] 113 | token_seq.extend(subtree.leaves()) 114 | for i in range(len(tokens)): 115 | if tokens[i:i+len(token_seq)] == token_seq: 116 | indexes.append((i, i+len(token_seq))) 117 | return indexes 118 | 119 | def get_new_ltl(auto, ltl, prop_id): 120 | ltl = ltl.split(' ') 121 | start_idxs = [] 122 | end_idxs = [] 123 | for i, op in enumerate(ltl): 124 | if op == auto[0]: 125 | start_idxs.append(i) 126 | if op == auto[-1]: 127 | end_idxs.append(i+1) 128 | dist = 100000000 129 | s_idx = 0 130 | e_idx = len(ltl) 131 | for s in start_idxs: 132 | for e in end_idxs: 133 | if e - s < dist and e - s > 0: 134 | s_idx = s 135 | e_idx = e 136 | dist = e - s 137 | sub_ltl = ' '.join(ltl[s_idx:e_idx]) 138 | ltl = ' '.join(ltl) 139 | num_left = sub_ltl.count('(') 140 | num_right = sub_ltl.count(')') 141 | assert num_left >= num_right and num_left - num_right <= 1 142 | if num_left > num_right: 143 | sub_ltl += ' )' 144 | ltl = ltl.replace(sub_ltl, prop_id) 145 | while ' ' in ltl: 146 | ltl = ltl.replace(' ', ' ') 147 | return ltl, sub_ltl 148 | 149 | 150 | input_path = '/raw_data/circuit_{}.jsonl'.format(data_portion) 151 | corpus = [] 152 | sentences = [] 153 | ltls = [] 154 | with open(input_path,'r') as fin: 155 | for line in fin: 156 | line = json.loads(line) 157 | sentences.append(' 
'.join(line['sentence']).replace('hist|ically','histically')) 158 | ltls.append(' '.join(line['ltl'])) 159 | # random.shuffle(sentences) 160 | # sentences = ['When signal_1_n is greater than or equal to 48.9 and the transition action that the value of signal_2_n decreases below 55.2 occurs, then for each moment within the following 23 to 50 time units the value of signal signal_3_n should be 20.4 ultimately at a certain moment in the future before the execution ends.'] 161 | 162 | # ltls = 'G ( signal_1_n >= 48.9 & rise ( signal_2_n < 55.2 ) -> G [ 23 : 50 ] ( F ( signal_3_n == 20.4 ) ) )' 163 | 164 | 165 | ltl_idx = 0 166 | errors = 0 167 | corpus = [] 168 | for doc in nlp.pipe(sentences, disable=["tok2vec", "tagger", "attribute_ruler", "lemmatizer"]): 169 | # print(ltl_idx) 170 | # doc = nlp(sentence) 171 | ltl = ltls[ltl_idx] 172 | ltl_idx += 1 173 | # sentence = sentences[ltl_idx] 174 | sent = list(doc.sents)[0] 175 | tree = Tree.fromstring(sent._.parse_string) 176 | tokens = tree.leaves() 177 | sentence = ' '.join(tokens) 178 | ori_ltl = copy.deepcopy(ltl) 179 | ori_sent = copy.deepcopy(sentence) 180 | propositions, autos = get_proposition(ltl) 181 | try: 182 | prop2idx = get_prop_idx(propositions, tree) 183 | except: 184 | # print('-------------------- finding index error ---------------------------') 185 | # print(ltl, tokens) 186 | # print(propositions) 187 | # print(sentences[ltl_idx]) 188 | errors += 1 189 | continue 190 | indexes = find_prop_spans(prop2idx, tree) 191 | 192 | # print('# --------------------------------- #') 193 | # print(sentence) 194 | # print(ltl) 195 | sorted_auto = [] 196 | for i, (auto, index) in enumerate(zip(autos, indexes)): 197 | sorted_auto.append([auto, index, index[1]-index[0]]) 198 | sorted_auto = sorted(sorted_auto, key=lambda x: x[2], reverse=True) 199 | # print(sorted_auto) 200 | index2auto = {} 201 | for auto in sorted_auto: 202 | added = False 203 | for index in index2auto: 204 | if auto[1][0] >= index[0] and auto[1][1] <= index[1]: # this means the longer span contain the shorter span 205 | index2auto[index].append(auto) 206 | added = True 207 | break 208 | if not added: 209 | index2auto[auto[1]] = [auto] 210 | # print('----------------------------------') 211 | # print(sentence) 212 | # print(ltl) 213 | # print(index2auto) 214 | prop2sub_ltl = {} 215 | for i, index in enumerate(sorted(index2auto.keys())): 216 | prop_idx =' prop_{} '.format(str(i+1)) 217 | _autos = index2auto[index] 218 | # print('( prop_{} )'.format(str(i+1)), 'auto',autos, 'tokens', tokens[index[0]:index[1]]) 219 | sentence = sentence.replace(' '.join(tokens[index[0]:index[1]]), '({})'.format(prop_idx)) 220 | prop2sub_ltl[prop_idx]={'span':[index[0],index[1]]} 221 | 222 | for auto in _autos: 223 | auto = auto[0] 224 | # print(auto) 225 | try: 226 | ltl, sub_ltl = get_new_ltl(auto, ltl, prop_idx) 227 | if 'prop' in prop2sub_ltl[prop_idx]: 228 | prop2sub_ltl[prop_idx]['prop'].append(sub_ltl.split(' ')) 229 | else: 230 | prop2sub_ltl[prop_idx]['prop'] = [sub_ltl.split(' ')] 231 | except: 232 | # print('# --------------- bracket error ----------------- #') 233 | # print(auto) 234 | # print(sentence) 235 | # print(ltl) 236 | continue 237 | if not ltl.count('(') == ltl.count(')'): 238 | # print('# --------------- bracket error ----------------- #') 239 | # print(sentence) 240 | # print(ltl) 241 | continue 242 | if 'signal_' in sentence: 243 | # print('# --------------- auto error ----------------- #') 244 | # print(sentence) 245 | # print(ltl) 246 | continue 247 | if 'signal_' in 
ltl:
248 |         # print('# --------------- ltl error ----------------- #')
249 |         # print(propositions, autos)
250 |         # print(sentence)
251 |         # print(ltl)
252 |         continue
253 |     # print('#----------------- good instance -------------- # ')
254 |     # print('ori sent',ori_sent)
255 |     # print('ori ltl', ori_ltl)
256 |     # print(propositions, autos)
257 |     # print(sentence)
258 |     # print(ltl)
259 |     # print(prop2sub_ltl)
260 |     corpus.append({'id':ltl_idx, 'sentence':ori_sent.split(' '), 'ltl': ori_ltl.split(' '),'logic_sentence':sentence.split(' '), 'logic_ltl':ltl.split(), 'propositions': prop2sub_ltl})
261 |     print(len(corpus))
262 |     #if len(corpus) > data_size:
263 |     #    break
264 | 
265 | with open('/raw_data/circuit_{}_span_2.jsonl'.format(data_portion),'w') as fout:
266 |     for line in corpus:
267 |         fout.write(json.dumps(line)+'\n')
268 | 
269 | 
270 | 
271 | 
--------------------------------------------------------------------------------
/Seq2seq_lifted_all.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals, print_function, division
2 | from io import open
3 | import unicodedata
4 | import string
5 | import re
6 | import random
7 | 
8 | import os
9 | import torch
10 | import torch.nn as nn
11 | from torch import optim
12 | import torch.nn.functional as F
13 | from argparse import ArgumentParser
14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15 | MAX_LENGTH = 40
16 | 
17 | parser = ArgumentParser()
18 | parser.add_argument('-seed', '--seed', type=int, default=1203) # input random seed
19 | parser.add_argument('-data_size', '--data_size', default='0.01-0.09') # The dataset size range '0.1-0.9' or '0.01-0.09'
20 | args = parser.parse_args()
21 | int_seed = args.seed
22 | data_size = args.data_size
23 | 
24 | SOS_token = 0
25 | EOS_token = 1
26 | 
27 | class Lang:
28 |     def __init__(self, name):
29 |         self.name = name
30 |         self.word2index = {}
31 |         self.word2count = {}
32 |         self.index2word = {0: "SOS", 1: "EOS"}
33 |         self.n_words = 2  # Count SOS and EOS
34 | 
35 |     def addSentence(self, sentence):
36 |         for word in sentence.split(' '):
37 |             self.addWord(word)
38 | 
39 |     def addWord(self, word):
40 |         if word not in self.word2index:
41 |             self.word2index[word] = self.n_words
42 |             self.word2count[word] = 1
43 |             self.index2word[self.n_words] = word
44 |             self.n_words += 1
45 |         else:
46 |             self.word2count[word] += 1
47 | 
48 | def readLangs(lang1, lang2, reverse=False, train_ratio=0.1):
49 |     print("Reading lines...")
50 | 
51 |     # Read the file and split into lines
52 |     home_path = 'Data_transfer_domain/Seq2Seq_baseline/Seq2seq_lifted_dataset_all_txt/'
53 | 
54 |     my_file = open(home_path + 'src.txt', "r")
55 |     data_test_sentence = my_file.read()
56 |     test_sentence = data_test_sentence.split("\n\n\n")
57 |     my_file.close()
58 | 
59 |     my_file = open(home_path + 'tar.txt', "r")
60 |     data_test_LTL = my_file.read()
61 |     test_LTL = data_test_LTL.split("\n\n\n")
62 |     my_file.close()
63 |     print('test sentence length is: ', len(test_sentence))
64 |     print('test LTL length is: ', len(test_LTL))
65 | 
66 |     lines_train = []; lines_test = []
67 |     for i in range(len(test_sentence)):
68 |         if i < int(len(test_sentence)*train_ratio):
69 |             lines_train.append(test_sentence[i]+'\t'+test_LTL[i])
70 |         else:
71 |             lines_test.append(test_sentence[i]+'\t'+test_LTL[i])
72 |     print('The training ratio is: ', train_ratio)
73 |     print('The num of training dataset is: ', len(lines_train))
74 |     print('The num of testing dataset is: ', len(lines_test))
75 |     print('\n'*2)
76 | 
77
| # Split every line into pairs and normalize 78 | pairs_train = [[s for s in l.split('\t')] for l in lines_train] 79 | pairs_test = [[s for s in l.split('\t')] for l in lines_test] 80 | 81 | # Reverse pairs, make Lang instances 82 | if reverse: 83 | pairs_train = [list(reversed(p)) for p in pairs_train] 84 | input_lang_total = Lang(lang2) 85 | output_lang_total = Lang(lang1) 86 | pairs_test = [list(reversed(p)) for p in pairs_test] 87 | else: 88 | input_lang_total = Lang(lang1) 89 | output_lang_total = Lang(lang2) 90 | return input_lang_total, output_lang_total, pairs_train, pairs_test 91 | 92 | def filterPair(p): 93 | return len(p[0].split(' ')) < MAX_LENGTH and \ 94 | len(p[1].split(' ')) < MAX_LENGTH 95 | 96 | def filterPairs(pairs): 97 | return [pair for pair in pairs if filterPair(pair)] 98 | 99 | def prepareData(lang1, lang2, reverse=False ,train_ratio=0.1): 100 | input_lang_total, output_lang_total, pairs_train, pairs_test = readLangs(lang1, lang2, reverse, train_ratio) 101 | print("Read %s train sentence pairs" % len(pairs_train)) 102 | pairs_train = filterPairs(pairs_train) 103 | print("Trimmed to %s sentence pairs" % len(pairs_train)) 104 | 105 | print("Read %s test sentence pairs" % len(pairs_test)) 106 | pairs_test = filterPairs(pairs_test) 107 | print("Trimmed to %s sentence pairs" % len(pairs_test)) 108 | print("Counting words...") 109 | for pair in pairs_train: 110 | input_lang_total.addSentence(pair[0]) 111 | output_lang_total.addSentence(pair[1]) 112 | for pair in pairs_test: 113 | input_lang_total.addSentence(pair[0]) 114 | output_lang_total.addSentence(pair[1]) 115 | 116 | print("Counted words:") 117 | print(input_lang_total.name, input_lang_total.n_words) 118 | print(output_lang_total.name, output_lang_total.n_words) 119 | return input_lang_total, output_lang_total, pairs_train, pairs_test 120 | 121 | class EncoderRNN(nn.Module): 122 | def __init__(self, input_size, hidden_size): 123 | super(EncoderRNN, self).__init__() 124 | self.hidden_size = hidden_size 125 | 126 | self.embedding = nn.Embedding(input_size, hidden_size) 127 | self.gru = nn.GRU(hidden_size, hidden_size) 128 | 129 | def forward(self, input, hidden): 130 | embedded = self.embedding(input).view(1, 1, -1) 131 | output = embedded 132 | output, hidden = self.gru(output, hidden) 133 | return output, hidden 134 | 135 | def initHidden(self): 136 | return torch.zeros(1, 1, self.hidden_size, device=device) 137 | 138 | class DecoderRNN(nn.Module): 139 | def __init__(self, hidden_size, output_size): 140 | super(DecoderRNN, self).__init__() 141 | self.hidden_size = hidden_size 142 | 143 | self.embedding = nn.Embedding(output_size, hidden_size) 144 | self.gru = nn.GRU(hidden_size, hidden_size) 145 | self.out = nn.Linear(hidden_size, output_size) 146 | self.softmax = nn.LogSoftmax(dim=1) 147 | 148 | def forward(self, input, hidden): 149 | output = self.embedding(input).view(1, 1, -1) 150 | output = F.relu(output) 151 | output, hidden = self.gru(output, hidden) 152 | output = self.softmax(self.out(output[0])) 153 | return output, hidden 154 | 155 | def initHidden(self): 156 | return torch.zeros(1, 1, self.hidden_size, device=device) 157 | 158 | class AttnDecoderRNN(nn.Module): 159 | def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH): 160 | super(AttnDecoderRNN, self).__init__() 161 | self.hidden_size = hidden_size 162 | self.output_size = output_size 163 | self.dropout_p = dropout_p 164 | self.max_length = max_length 165 | 166 | self.embedding = nn.Embedding(self.output_size, 
self.hidden_size) 167 | self.attn = nn.Linear(self.hidden_size * 2, self.max_length) 168 | self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size) 169 | self.dropout = nn.Dropout(self.dropout_p) 170 | self.gru = nn.GRU(self.hidden_size, self.hidden_size) 171 | self.out = nn.Linear(self.hidden_size, self.output_size) 172 | 173 | def forward(self, input, hidden, encoder_outputs): 174 | embedded = self.embedding(input).view(1, 1, -1) 175 | embedded = self.dropout(embedded) 176 | 177 | attn_weights = F.softmax( 178 | self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1) 179 | attn_applied = torch.bmm(attn_weights.unsqueeze(0), 180 | encoder_outputs.unsqueeze(0)) 181 | 182 | output = torch.cat((embedded[0], attn_applied[0]), 1) 183 | output = self.attn_combine(output).unsqueeze(0) 184 | 185 | output = F.relu(output) 186 | output, hidden = self.gru(output, hidden) 187 | 188 | output = F.log_softmax(self.out(output[0]), dim=1) 189 | return output, hidden, attn_weights 190 | 191 | def initHidden(self): 192 | return torch.zeros(1, 1, self.hidden_size, device=device) 193 | 194 | def indexesFromSentence(lang, sentence): 195 | return [lang.word2index[word] for word in sentence.split(' ')] 196 | 197 | 198 | def tensorFromSentence(lang, sentence): 199 | indexes = indexesFromSentence(lang, sentence) 200 | indexes.append(EOS_token) 201 | return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1) 202 | 203 | 204 | def tensorsFromPair(input_lang, output_lang, pair): 205 | input_tensor = tensorFromSentence(input_lang, pair[0]) 206 | target_tensor = tensorFromSentence(output_lang, pair[1]) 207 | return (input_tensor, target_tensor) 208 | 209 | teacher_forcing_ratio = 0.5 210 | 211 | 212 | def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length = MAX_LENGTH): 213 | encoder_hidden = encoder.initHidden() 214 | 215 | encoder_optimizer.zero_grad() 216 | decoder_optimizer.zero_grad() 217 | 218 | input_length = input_tensor.size(0) 219 | target_length = target_tensor.size(0) 220 | 221 | encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device) 222 | 223 | loss = 0 224 | 225 | for ei in range(input_length): 226 | encoder_output, encoder_hidden = encoder( 227 | input_tensor[ei], encoder_hidden) 228 | encoder_outputs[ei] = encoder_output[0, 0] 229 | 230 | decoder_input = torch.tensor([[SOS_token]], device=device) 231 | 232 | decoder_hidden = encoder_hidden 233 | 234 | use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False 235 | 236 | if use_teacher_forcing: 237 | # Teacher forcing: Feed the target as the next input 238 | for di in range(target_length): 239 | decoder_output, decoder_hidden, decoder_attention = decoder( 240 | decoder_input, decoder_hidden, encoder_outputs) 241 | loss += criterion(decoder_output, target_tensor[di]) 242 | decoder_input = target_tensor[di] # Teacher forcing 243 | 244 | else: 245 | # Without teacher forcing: use its own predictions as the next input 246 | for di in range(target_length): 247 | decoder_output, decoder_hidden, decoder_attention = decoder( 248 | decoder_input, decoder_hidden, encoder_outputs) 249 | topv, topi = decoder_output.topk(1) 250 | decoder_input = topi.squeeze().detach() # detach from history as input 251 | 252 | loss += criterion(decoder_output, target_tensor[di]) 253 | if decoder_input.item() == EOS_token: 254 | break 255 | 256 | loss.backward() 257 | 258 | encoder_optimizer.step() 259 | decoder_optimizer.step() 260 | 261 | return 
loss.item() / target_length 262 | 263 | import time 264 | import math 265 | def asMinutes(s): 266 | m = math.floor(s / 60) 267 | s -= m * 60 268 | return '%dm %ds' % (m, s) 269 | 270 | 271 | def timeSince(since, percent): 272 | now = time.time() 273 | s = now - since 274 | es = s / (percent) 275 | rs = es - s 276 | return '%s (- %s)' % (asMinutes(s), asMinutes(rs)) 277 | 278 | import matplotlib.pyplot as plt 279 | plt.switch_backend('agg') 280 | import matplotlib.ticker as ticker 281 | import numpy as np 282 | 283 | def showPlot(points): 284 | plt.figure() 285 | fig, ax = plt.subplots() 286 | # this locator puts ticks at regular intervals 287 | loc = ticker.MultipleLocator(base=0.2) 288 | ax.yaxis.set_major_locator(loc) 289 | plt.plot(points) 290 | 291 | def trainIters(encoder, decoder, n_iters, pairs_train, print_every=1000, plot_every=100, learning_rate=0.01): 292 | start = time.time() 293 | plot_losses = [] 294 | print_loss_total = 0 # Reset every print_every 295 | plot_loss_total = 0 # Reset every plot_every 296 | 297 | encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate) 298 | decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate) 299 | training_pairs = [tensorsFromPair(input_lang_total, output_lang_total,random.choice(pairs_train)) 300 | for i in range(n_iters)] 301 | criterion = nn.NLLLoss() 302 | 303 | for iter in range(1, n_iters + 1): 304 | training_pair = training_pairs[iter - 1] 305 | input_tensor = training_pair[0] 306 | target_tensor = training_pair[1] 307 | 308 | loss = train(input_tensor, target_tensor, encoder, 309 | decoder, encoder_optimizer, decoder_optimizer, criterion) 310 | print_loss_total += loss 311 | plot_loss_total += loss 312 | 313 | if iter % print_every == 0: 314 | print_loss_avg = print_loss_total / print_every 315 | print_loss_total = 0 316 | print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters), 317 | iter, iter / n_iters * 100, print_loss_avg)) 318 | 319 | if iter % plot_every == 0: 320 | plot_loss_avg = plot_loss_total / plot_every 321 | plot_losses.append(plot_loss_avg) 322 | plot_loss_total = 0 323 | 324 | showPlot(plot_losses) 325 | 326 | def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH): 327 | with torch.no_grad(): 328 | input_tensor = tensorFromSentence(input_lang_total, sentence) 329 | input_length = input_tensor.size()[0] 330 | encoder_hidden = encoder.initHidden() 331 | 332 | encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device) 333 | 334 | for ei in range(input_length): 335 | encoder_output, encoder_hidden = encoder(input_tensor[ei], 336 | encoder_hidden) 337 | encoder_outputs[ei] += encoder_output[0, 0] 338 | 339 | decoder_input = torch.tensor([[SOS_token]], device=device) # SOS 340 | 341 | decoder_hidden = encoder_hidden 342 | 343 | decoded_words = [] 344 | decoder_attentions = torch.zeros(max_length, max_length) 345 | 346 | for di in range(max_length): 347 | decoder_output, decoder_hidden, decoder_attention = decoder( 348 | decoder_input, decoder_hidden, encoder_outputs) 349 | decoder_attentions[di] = decoder_attention.data 350 | topv, topi = decoder_output.data.topk(1) 351 | if topi.item() == EOS_token: 352 | decoded_words.append('') 353 | break 354 | else: 355 | decoded_words.append(output_lang_total.index2word[topi.item()]) 356 | 357 | decoder_input = topi.squeeze().detach() 358 | 359 | return decoded_words, decoder_attentions[:di + 1] 360 | 361 | dataset_name = 'dataset_all_seq2seq_baseline' 362 | output_dir_total = 'trained_models/' + dataset_name + '/' 363 | 
if not os.path.exists(output_dir_total): 364 | os.mkdir(output_dir_total) 365 | 366 | if data_size == '0.1-0.9': 367 | output_dir = output_dir_total+dataset_name+'_seed'+str(int_seed)+'_one'+'/' 368 | elif data_size == '0.01-0.09': 369 | output_dir = output_dir_total + dataset_name + '_seed' + str(int_seed) + '_pointone' + '/' 370 | if not os.path.exists(output_dir): 371 | os.mkdir(output_dir) 372 | 373 | with open(output_dir +'result.txt', 'w') as f_result: 374 | for i in range(9): 375 | if data_size == '0.1-0.9': 376 | ratio = 0.1*(i+1) 377 | elif data_size == '0.01-0.09': 378 | ratio = 0.01*(i+1) 379 | input_lang_total, output_lang_total, pairs_train, pairs_test = prepareData('sentence', 'stl', False, train_ratio=ratio) 380 | hidden_size = 256 381 | encoder1 = EncoderRNN(input_lang_total.n_words, hidden_size).to(device) 382 | attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang_total.n_words, dropout_p=0.1).to(device) 383 | trainIters(encoder1, attn_decoder1, 75000, pairs_train, print_every=5000) 384 | 385 | count = 0 386 | for j in range(min(len(pairs_test),1000)): 387 | pair = pairs_test[j] 388 | output_words, attentions = evaluate(encoder1, attn_decoder1, pair[0]) 389 | output_sentence = ' '.join(output_words) 390 | if pair[1].split(' ') == output_sentence.split(' ')[:-1]: 391 | count += 1 392 | else: 393 | print('>', pair[0]) 394 | print('=', pair[1]) 395 | print('<', output_sentence) 396 | print('\n') 397 | print('Training dataset ratio is: ', ratio) 398 | print('Accuracy is: ', count /len(pairs_test)) 399 | print('\n'*2) 400 | f_result.write(str(ratio) + ' ' + str(count / (j + 1)) + '\n') 401 | f_result.close() 402 | -------------------------------------------------------------------------------- /dataset_creation_GPT3/framework1.py: -------------------------------------------------------------------------------- 1 | # Generate the STL-NL pairs 2 | import copy 3 | import random 4 | from fnmatch import fnmatchcase as match 5 | import numpy as np 6 | import json 7 | import os 8 | import openai 9 | from tqdm import tqdm 10 | import csv 11 | import pandas as pd 12 | from argparse import ArgumentParser 13 | 14 | parser = ArgumentParser() 15 | parser.add_argument('-n', '--number', type=int, default=10) 16 | parser.add_argument('-f', '--filename', type=int, default=0) 17 | args = parser.parse_args() 18 | init_num = args.number 19 | filename = args.filename 20 | 21 | time_upper_limit = 600 22 | 23 | operation_list = ['negation', '&', '->', '<->', '|', 'G', 'F', 'U', 'G[*,**', 'F[*,**', 'U[*,**'] 24 | two_leaves = ['&', '->', '<->', '|', 'U', 'U[*,**'] 25 | one_leaf = ['negation', 'G', 'F', 'G[*,**', 'F[*,**'] 26 | no_leaf = ['prop**'] 27 | prop_with_time = ['G[*,**', 'F[*,**', 'U[*,**'] 28 | 29 | def operation2word(pre_order_list): 30 | dict_operation = {} 31 | dict_operation['&'] = 'and'; dict_operation['->'] = 'imply'; dict_operation['<->'] = 'equal'; dict_operation['|'] = 'or' 32 | dict_operation['G'] = 'globally'; dict_operation['F'] = 'finally'; dict_operation['U'] = 'until'; dict_operation['negation'] = 'negation' 33 | post_list = [] 34 | for item in pre_order_list: 35 | if match(item, 'G[*,**'): 36 | post_list.append('globally'+item[1:]) 37 | elif match(item, 'F[*,**'): 38 | post_list.append('finally'+item[1:]) 39 | elif match(item, 'U[*,**'): 40 | post_list.append('until'+item[1:]) 41 | elif match(item, 'prop**'): 42 | post_list.append(item) 43 | else: 44 | post_list.append(dict_operation[item]) 45 | return post_list 46 | 47 | def judge_leaf_num(item): 48 | two_leaves = ['&', 
'->', '<->', '|', 'U', 'U[*,**'] 49 | one_leaf = ['negation', 'G', 'F', 'G[*,**', 'F[*,**'] 50 | if item in two_leaves or item[0]=='U': 51 | return 2 52 | elif match(item, 'prop**'): 53 | return 0 54 | else: 55 | return 1 56 | 57 | def pre_to_mid_exp(item_list_original): 58 | item_list = copy.deepcopy(item_list_original) 59 | item_list.reverse() 60 | word_list = operation2word(item_list) 61 | stack_list = []; stack_list_2 = [] 62 | mark = 0 63 | for i,item in enumerate(item_list): 64 | if judge_leaf_num(item) == 0: 65 | stack_list.append(word_list[i]) 66 | stack_list_2.append(item_list[i]) 67 | elif judge_leaf_num(item) == 1: 68 | if mark: 69 | candidate = word_list[i] + ' '+ '(' + stack_list[-1] + ')' 70 | candidate_2 = item_list[i] + ' '+ '(' + stack_list_2[-1] + ')' 71 | else: 72 | candidate = word_list[i] + ' '+ stack_list[-1] 73 | candidate_2 = item_list[i] + ' '+ stack_list_2[-1] 74 | stack_list.pop(-1) 75 | stack_list.append(candidate) 76 | stack_list_2.pop(-1) 77 | stack_list_2.append(candidate_2) 78 | mark = 1 79 | elif judge_leaf_num(item) == 2: 80 | candidate = '('+ stack_list[-1] + ' ' + word_list[i] + ' ' + stack_list[-2] + ')' 81 | candidate_2 = '('+ stack_list_2[-1] + ' ' + item_list[i] + ' ' + stack_list_2[-2] + ')' 82 | stack_list.pop(-1) 83 | stack_list.pop(-1) 84 | stack_list.append(candidate) 85 | stack_list_2.pop(-1) 86 | stack_list_2.pop(-1) 87 | stack_list_2.append(candidate_2) 88 | mark=0 89 | return candidate, candidate_2 90 | 91 | def generate_ltl_from_list(raw_list): 92 | base_list = copy.deepcopy(raw_list) 93 | if len(base_list)==1: 94 | random_num_2 = random.randint(0, len(one_leaf)-1) 95 | if one_leaf[random_num_2] in prop_with_time: 96 | time1 = random.randint(1, 500); time2 = random.randint(1, 500) + time1 97 | if time2 > time_upper_limit: 98 | base_list.insert(0, one_leaf[random_num_2][0:2]+str(time1)+','+'infinite'+']') 99 | else: 100 | base_list.insert(0, one_leaf[random_num_2][0:2]+str(time1)+','+str(time2)+']') 101 | else: 102 | base_list.insert(0, one_leaf[random_num_2]) 103 | elif len(base_list)>1: 104 | count = len(base_list) 105 | while count>1: 106 | random_num_3 = random.randint(0, len(operation_list)-1) 107 | if operation_list[random_num_3] in prop_with_time: 108 | time1 = random.randint(1, 500); time2 = random.randint(1, 500) + time1 109 | base_list.insert(0, operation_list[random_num_3][0:2]+str(time1)+','+str(time2)+']') 110 | else: 111 | base_list.insert(0, operation_list[random_num_3]) 112 | count -= judge_leaf_num(operation_list[random_num_3]) 113 | count += 1 114 | return base_list 115 | 116 | def generate_ltl_v2(prop_num = 6): 117 | random_num = random.randint(2, prop_num) 118 | num_prop = [i for i in range(1,random_num)] 119 | random.shuffle(num_prop) 120 | father_list = ['prop_'+str(item) for item in num_prop] 121 | if len(father_list) > 1: 122 | divide_index = random.randint(1, len(father_list)-1) 123 | son_list1 = generate_ltl_from_list(father_list[0 : divide_index]) 124 | son_list2 = generate_ltl_from_list(father_list[divide_index : len(father_list)]) 125 | random_num_1 = random.randint(0, len(two_leaves)-1) 126 | 127 | mid_two_leaves_item = [] 128 | if two_leaves[random_num_1] in prop_with_time: 129 | time1 = random.randint(1, 500); time2 = random.randint(1, 500) + time1 130 | if time2 > time_upper_limit: 131 | mid_two_leaves_item.append(two_leaves[random_num_1][0:2]+str(time1)+','+'infinite'+']') 132 | else: 133 | mid_two_leaves_item.append(two_leaves[random_num_1][0:2]+str(time1)+','+str(time2)+']') 134 | else: 135 | 
mid_two_leaves_item.append(two_leaves[random_num_1]) 136 | 137 | total_list = mid_two_leaves_item + son_list1 + son_list2 138 | return total_list 139 | 140 | # generate STLs 141 | list_ltl = [] 142 | print('-'*40) 143 | print('Generating rule based STL: ') 144 | while len(list_ltl) < init_num: 145 | ltl_can = generate_ltl_v2(9) 146 | mark = 1 147 | if ltl_can is not None and len(ltl_can) <8 and len(ltl_can) >5: 148 | #if len(ltl_can) <=5 and len(ltl_can) >=3: 149 | count_time = 0; count_FGU = 0 150 | count_negation = 0; count_u = 0 151 | for i,item in enumerate(ltl_can): 152 | if item[0] == 'U': 153 | count_u += 1 154 | if match(item, 'G[*,**') or match(item, 'F[*,**') or match(item, 'U[*,**'): 155 | if i']: 164 | if i', '<->', '|', 'G', 'F', 'U', 'G[*,**', 'F[*,**', 'U[*,**'] 25 | two_leaves = ['&', '->', '<->', '|', 'U', 'U[*,**'] 26 | one_leaf = ['negation', 'G', 'F', 'G[*,**', 'F[*,**'] 27 | no_leaf = ['prop**'] 28 | prop_with_time = ['G[*,**', 'F[*,**', 'U[*,**'] 29 | 30 | def operation2word(pre_order_list): 31 | dict_operation = {} 32 | dict_operation['&'] = 'and'; dict_operation['->'] = 'imply'; dict_operation['<->'] = 'equal'; dict_operation['|'] = 'or' 33 | dict_operation['G'] = 'globally'; dict_operation['F'] = 'finally'; dict_operation['U'] = 'until'; dict_operation['negation'] = 'negation' 34 | post_list = [] 35 | for item in pre_order_list: 36 | if match(item, 'G[*,**'): 37 | post_list.append('globally'+item[1:]) 38 | elif match(item, 'F[*,**'): 39 | post_list.append('finally'+item[1:]) 40 | elif match(item, 'U[*,**'): 41 | post_list.append('until'+item[1:]) 42 | elif match(item, 'prop**'): 43 | post_list.append(item) 44 | else: 45 | post_list.append(dict_operation[item]) 46 | return post_list 47 | 48 | def judge_leaf_num(item): 49 | two_leaves = ['&', '->', '<->', '|', 'U', 'U[*,**'] 50 | one_leaf = ['negation', 'G', 'F', 'G[*,**', 'F[*,**'] 51 | if item in two_leaves or item[0]=='U': 52 | return 2 53 | elif match(item, 'prop**'): 54 | return 0 55 | else: 56 | return 1 57 | 58 | def pre_to_mid_exp(item_list_original): 59 | item_list = copy.deepcopy(item_list_original) 60 | item_list.reverse() 61 | word_list = operation2word(item_list) 62 | stack_list = []; stack_list_2 = [] 63 | mark = 0 64 | for i,item in enumerate(item_list): 65 | if judge_leaf_num(item) == 0: 66 | stack_list.append(word_list[i]) 67 | stack_list_2.append(item_list[i]) 68 | elif judge_leaf_num(item) == 1: 69 | if mark: 70 | candidate = word_list[i] + ' '+ '(' + stack_list[-1] + ')' 71 | candidate_2 = item_list[i] + ' '+ '(' + stack_list_2[-1] + ')' 72 | else: 73 | candidate = word_list[i] + ' '+ stack_list[-1] 74 | candidate_2 = item_list[i] + ' '+ stack_list_2[-1] 75 | stack_list.pop(-1) 76 | stack_list.append(candidate) 77 | stack_list_2.pop(-1) 78 | stack_list_2.append(candidate_2) 79 | mark = 1 80 | elif judge_leaf_num(item) == 2: 81 | candidate = '('+ stack_list[-1] + ' ' + word_list[i] + ' ' + stack_list[-2] + ')' 82 | candidate_2 = '('+ stack_list_2[-1] + ' ' + item_list[i] + ' ' + stack_list_2[-2] + ')' 83 | stack_list.pop(-1) 84 | stack_list.pop(-1) 85 | stack_list.append(candidate) 86 | stack_list_2.pop(-1) 87 | stack_list_2.pop(-1) 88 | stack_list_2.append(candidate_2) 89 | mark=0 90 | return candidate, candidate_2 91 | 92 | def generate_ltl_from_list(raw_list): 93 | base_list = copy.deepcopy(raw_list) 94 | if len(base_list)==1: 95 | random_num_2 = random.randint(0, len(one_leaf)-1) 96 | if one_leaf[random_num_2] in prop_with_time: 97 | time1 = random.randint(1, 500); time2 = random.randint(1, 
500) + time1 98 | base_list.insert(0, one_leaf[random_num_2][0:2]+str(time1)+','+str(time2)+']') 99 | else: 100 | base_list.insert(0, one_leaf[random_num_2]) 101 | elif len(base_list)>1: 102 | count = len(base_list) 103 | while count>1: 104 | random_num_3 = random.randint(0, len(operation_list)-1) 105 | if operation_list[random_num_3] in prop_with_time: 106 | time1 = random.randint(1, 500); time2 = random.randint(1, 500) + time1 107 | base_list.insert(0, operation_list[random_num_3][0:2]+str(time1)+','+str(time2)+']') 108 | else: 109 | base_list.insert(0, operation_list[random_num_3]) 110 | count -= judge_leaf_num(operation_list[random_num_3]) 111 | count += 1 112 | return base_list 113 | 114 | def generate_ltl(): 115 | two_leaves = ['&', '->', '<->', '|', 'U', 'U[*,**'] 116 | one_leaf = ['negation', 'G', 'F', 'G[*,**', 'F[*,**'] 117 | random_num = random.randint(2, 6) 118 | num_prop = [i for i in range(1,random_num)] 119 | random.shuffle(num_prop) 120 | base_list = ['prop_'+str(item) for item in num_prop] 121 | if len(base_list)==1: 122 | random_num_2 = random.randint(0, len(one_leaf)-1) 123 | if one_leaf[random_num_2] in prop_with_time: 124 | time1 = random.randint(1, 500); time2 = random.randint(1, 500) + time1 125 | base_list.insert(0, one_leaf[random_num_2][0:2]+str(time1)+','+str(time2)+']') 126 | else: 127 | base_list.insert(0, one_leaf[random_num_2]) 128 | elif len(base_list)>1: 129 | count = len(base_list) 130 | while count>1: 131 | random_num_3 = random.randint(0, len(operation_list)-1) 132 | if operation_list[random_num_3] in prop_with_time: 133 | time1 = random.randint(1, 500); time2 = random.randint(1, 500) + time1 134 | base_list.insert(0, operation_list[random_num_3][0:2]+str(time1)+','+str(time2)+']') 135 | else: 136 | base_list.insert(0, operation_list[random_num_3]) 137 | count -= judge_leaf_num(operation_list[random_num_3]) 138 | count += 1 139 | return base_list 140 | 141 | def generate_ltl_v2(prop_num = 6): 142 | random_num = random.randint(2, prop_num) 143 | num_prop = [i for i in range(1,random_num)] 144 | random.shuffle(num_prop) 145 | father_list = ['prop_'+str(item) for item in num_prop] 146 | if len(father_list) > 1: 147 | divide_index = random.randint(1, len(father_list)-1) 148 | son_list1 = generate_ltl_from_list(father_list[0 : divide_index]) 149 | son_list2 = generate_ltl_from_list(father_list[divide_index : len(father_list)]) 150 | random_num_1 = random.randint(0, len(two_leaves)-1) 151 | 152 | mid_two_leaves_item = [] 153 | if two_leaves[random_num_1] in prop_with_time: 154 | time1 = random.randint(1, 500); time2 = random.randint(1, 500) + time1 155 | mid_two_leaves_item.append(two_leaves[random_num_1][0:2]+str(time1)+','+str(time2)+']') 156 | else: 157 | mid_two_leaves_item.append(two_leaves[random_num_1]) 158 | 159 | total_list = mid_two_leaves_item + son_list1 + son_list2 160 | return total_list 161 | 162 | # generate STLs 163 | list_ltl = [] 164 | print('-'*40) 165 | print('Generating rule based STL: ') 166 | while len(list_ltl) < init_num: 167 | ltl_can = generate_ltl_v2(6) 168 | mark = 1 169 | if ltl_can is not None and len(ltl_can) <8 and len(ltl_can) >5: 170 | #if len(ltl_can) <=5 and len(ltl_can) >=3: 171 | count_time = 0; count_FGU = 0 172 | count_negation = 0; count_u = 0 173 | for i,item in enumerate(ltl_can): 174 | if item[0] == 'U': 175 | count_u += 1 176 | if match(item, 'G[*,**') or match(item, 'F[*,**') or match(item, 'U[*,**'): 177 | if i']: 186 | if i', 'prop_1', 'prop_4', 'prop_2', 'prop_3'] 297 | #prompt="Try to transform the following 
natural languages into signal temporal logics, the operators in the signal temporal logic are: negation, imply, and, equal, until, globally, finally, or .\nThe signal temporal logics are prefix expressions. The examples are as following:\nnatural language: It is required that for every moment during the interval 489 to 663 either the event that ( prop_1 ) is detected and in response ( prop_3 ) should happen , or ( prop_2 ) should be true .\nSTL: ['or', 'globally [489,663]', 'imply', 'prop_1', 'prop_3', 'prop_2']\n\nnatural language: It should be the case that if ( prop_4 ) or ( prop_2 ) then ( prop_3 ), and ( prop_1 ) .\nSTL: ['and', 'imply', 'or', 'prop_4', 'prop_2', 'prop_3', 'prop_1']\n\nnatural language: It is always the case that if it is not the case that ( prop_2 ) then ( prop_3 ), and ( prop_1 ) .\nSTL: ['and', 'globally', 'imply', 'negation', 'prop_2', 'prop_3', 'prop_1']\n\nnatural language: ( prop_3 ) should happen until at some point during the 483 to 907 time units , then ( prop_1 ) should happen, or else ( prop_2 ) , or else ( prop_4 ) .\nSTL: ['or', 'or', 'until [483,907]', 'prop_3', 'prop_1', 'prop_2', 'prop_4']\n\nnatural language: It is true that if the scenario in which ( prop_4 ) leads to ( prop_3 ) happens and continues until ( prop_1 ) happens , then ( prop_2 ) should be observed . And it is also true that if ( prop_2 ) is observed , then ( prop_4 ) should have led to ( prop_3 ) and this condition continues until ( prop_1 ) happens .\nSTL: ['equal', 'until', 'imply', 'prop_4', 'prop_3', 'prop_1', 'prop_2']\n\nnatural language: Before a certain time point within the next 15 to 196 time units ( prop_2 ) leads to ( prop_4 ) and ( prop_3 ) is true , then starting from this time point ( prop_1 ) .\nSTL: ['until [15,196]', 'and', 'imply', 'prop_2', 'prop_4', 'prop_3', 'prop_1']\n\nnatural language: If ( prop_4 ) then implies ( prop_2 ), and in the same time ( prop_1 ) , or else ( prop_3 ) .\nSTL: ['or', 'and', 'imply', 'prop_4', 'prop_2', 'prop_1', 'prop_3']\n\nnatural language: It is always the case that if within the next 139 to 563 time units , the scenario that ( prop_2 ) is detected then as a response ( prop_1 ) , and ( prop_3 ) .\nSTL: ['and', 'globally [139,563]', 'imply', 'prop_2', 'prop_1', 'prop_3']\n\nnatural language: If it is the case that ( prop_2 ) and ( prop_4 ) are equivalent and continue to happen until the scenario that ( prop_1 ) is detected then in response ( prop_3 ) should happen .\nSTL: ['imply', 'until', 'equal', 'prop_2', 'prop_4', 'prop_1', 'prop_3']\n\nnatural language: If ( prop_3 ) then implies ( prop_4 ), this condition should continue to happen until at some point within the next 450 to 942 time units , after that ( prop_2 ) , or ( prop_1 ) .\nSTL: ['or', 'until [450,942]', 'imply', 'prop_3', 'prop_4', 'prop_2', 'prop_1']\n\nnatural language: " 298 | prompt="Try to transform the following natural languages into signal temporal logics, the operators in the signal temporal logic are: negation, imply, and, equal, until, globally, finally, or .\nThe signal temporal logics are prefix expressions. 
The examples are as following:\nnatural language: It is required that for every moment during the interval 489 to 663 either the event that ( prop_1 ) is detected and in response ( prop_3 ) should happen , or ( prop_2 ) should be true .\nSTL: ['or', 'globally [489,663]', 'imply', 'prop_1', 'prop_3', 'prop_2']\n\nnatural language: It should be the case that if ( prop_4 ) or ( prop_2 ) then ( prop_3 ), and ( prop_1 ) .\nSTL: ['and', 'imply', 'or', 'prop_4', 'prop_2', 'prop_3', 'prop_1']\n\nnatural language: It is always the case that if it is not the case that ( prop_2 ) then ( prop_3 ), and ( prop_1 ) .\nSTL: ['and', 'globally', 'imply', 'negation', 'prop_2', 'prop_3', 'prop_1']\n\nnatural language: ( prop_3 ) should happen until at some point during the 483 to 907 time units , then ( prop_1 ) should happen, or else ( prop_2 ) , or else ( prop_4 ) .\nSTL: ['or', 'or', 'until [483,907]', 'prop_3', 'prop_1', 'prop_2', 'prop_4']\n\nnatural language: It is true that if the scenario in which ( prop_4 ) leads to ( prop_3 ) happens and continues until ( prop_1 ) happens , then ( prop_2 ) should be observed . And it is also true that if ( prop_2 ) is observed , then ( prop_4 ) should have led to ( prop_3 ) and this condition continues until ( prop_1 ) happens .\nSTL: ['equal', 'until', 'imply', 'prop_4', 'prop_3', 'prop_1', 'prop_2']\n\nnatural language: Before a certain time point within the next 15 to 196 time units ( prop_2 ) leads to ( prop_4 ) and ( prop_3 ) is true , then starting from this time point ( prop_1 ) .\nSTL: ['until [15,196]', 'and', 'imply', 'prop_2', 'prop_4', 'prop_3', 'prop_1']\n\nnatural language: If ( prop_4 ) then implies ( prop_2 ), and in the same time ( prop_1 ) , or else ( prop_3 ) .\nSTL: ['or', 'and', 'imply', 'prop_4', 'prop_2', 'prop_1', 'prop_3']\n\nnatural language: It is always the case that if within the next 139 to 563 time units , the scenario that ( prop_2 ) is detected then as a response ( prop_1 ) , and ( prop_3 ) .\nSTL: ['and', 'globally [139,563]', 'imply', 'prop_2', 'prop_1', 'prop_3']\n\nnatural language: If it is the case that ( prop_2 ) and ( prop_4 ) are equivalent and continue to happen until the scenario that ( prop_1 ) is detected then in response ( prop_3 ) should happen .\nSTL: ['imply', 'until', 'equal', 'prop_2', 'prop_4', 'prop_1', 'prop_3']\n\nnatural language: If ( prop_3 ) then implies ( prop_4 ), this condition should continue to happen until at some point within the next 450 to 942 time units , after that ( prop_2 ) , or ( prop_1 ) .\nSTL: ['or', 'until [450,942]', 'imply', 'prop_3', 'prop_4', 'prop_2', 'prop_1']\n\nnatural language: ( prop_3 ) happens until a time in the next 5 to 12 units that ( prop_4 ) does not happen .\nSTL: ['until [5,12]', 'prop_3', 'negation', 'prop_4']\n\nnatural language: The time that ( prop_3 ) happens is when ( prop_1 ) happens , and vice versa .\nSTL: ['equal', 'prop_3', 'prop_1']\n\nnatural language: It is required that both ( prop_2 ) and ( prop_4 ) happen at the same time, or else ( prop_3 ) happens and continues until ( prop_1 ) does not happen.\nSTL: ['or', 'and', 'prop_2', 'prop_4', 'until', 'prop_3', 'negation', 'prop_1']\n\nnatural language: ( prop_3 ) happens and continues until at some point during the 500 to 903 time units ( prop_1 ) happens , and in the same time ( prop_2 ) does not happen .\nSTL: ['and', 'until [500,903]', 'prop_3', 'prop_1', 'negation', 'prop_2']\n\nnatural language: For each time instant in the next 107 to 513 time units ( prop_1 ) is true , or else ( prop_3 ) happens and ( 
prop_2 ) happens at the same time.\nSTL: ['or', 'globally [107,513]', 'prop_1', 'and', 'prop_3', 'prop_2']\n(globally [107,513] prop_1 or (prop_3 and prop_2))\n\nnatural language: ( prop_1 ) or ( prop_2 ) happens and continues until at some point during the 142 to 365 time units ( prop_4 ) happens and ( prop_3 ) happens at the same time .\nSTL: ['until [142,365]', 'or', 'prop_1', 'prop_2', 'and', 'prop_4', 'prop_3']\n\nnatural language: For each time instant in the next 91 to 471 time units ( prop_2 ) happens , and ( prop_1 ) or ( prop_3 ) also happens .\nSTL: ['and', 'globally [91,471]', 'prop_2', 'or', 'prop_1', 'prop_3']\n\nnatural language: If the case ( prop_1 ) does not happen is equivalent to the case ( prop_2 ) happens , then for each time instant in the next 483 to 715 time units ( prop_3 ) is true .\nSTL: ['imply', 'equal', 'negation', 'prop_1', 'prop_2', 'globally [483,715]', 'prop_3']\n\nnatural language: It is required that either ( prop_1 ) or ( prop_3 ) happens , and in the same time ( prop_2 ) does not happen .\nSTL: ['and', 'or', 'prop_1', 'prop_3', 'negation', 'prop_2']\n\nnatural language: For each time instant in the next 320 to 493 time units ( prop_2 ) happens , is equivalent to the case that if ( prop_3 ) then ( prop_1 ) .\nSTL: ['equal', 'globally [320,493]', 'prop_2', 'imply', 'prop_3', 'prop_1']\n\nnatural language: ( prop_1 ) or ( prop_2 ) happens and continues until at some point during the 152 to 154 time units that ( prop_3 ) does not happen .\nSTL: ['until [152,154]', 'or', 'prop_1', 'prop_2', 'negation', 'prop_3']\n\nnatural language: ( prop_1 ) should not happen and ( prop_2 ) should happen at the same time , and the above scenario is equivalent to the case that at some point during the 230 to 280 time units ( prop_3 ) happens .\nSTL: ['equal', 'and', 'negation', 'prop_1', 'prop_2', 'finally [230,280]', 'prop_3']\n\nnatural language: If ( prop_2 ) then ( prop_3 ) happens , and at some point during the 7 to 283 time units ( prop_1 ) happens .\nSTL: ['and', 'imply', 'prop_2', 'prop_3', 'finally [7,283]', 'prop_1']\n\nnatural language: ( prop_3 ) and ( prop_2 ) should happen at the same time , or else ( prop_4 ) happens and continues until at some point during the 469 to 961 time units ( prop_1 ) happens .\nSTL: ['or', 'and', 'prop_3', 'prop_2', 'until [469,961]', 'prop_4', 'prop_1']\n\nnatural language: ( prop_1 ) implies ( prop_3 ) , and ( prop_4 ) happens if and only if ( prop_2 ) .\nSTL: ['and', 'equal', 'imply', 'prop_1', 'prop_3', 'prop_4', 'prop_2']\n\nnatural language: In the following 10 time steps , the ( prop_1 ) should always happen , and in the meantime , ( prop_2 ) should happen at least once .\nSTL: ['and', 'globally [0,10]', 'prop_1', 'finally', 'prop_2']\n\nnatural language: ( prop_1 ) should not happen if ( prop_2 ) does not happen , and ( prop_3 ) should also be true all the time .\nSTL: ['and', 'imply', 'negation', 'prop_2', 'negation', 'prop_1', 'globally', 'prop_3']\n\nnatural language: If ( prop_1 ) and ( prop_2 ), then ( prop_3 ) until ( prop_4 ) does not happen , and ( prop_5 ) until ( prop_6 ) does not happen .\nSTL: ['and', 'imply', 'and', 'prop_1', 'prop_2', 'until', 'prop_3', 'negation', 'prop_4', 'until', 'prop_5', 'negation', 'prop_6']\n\nnatural language: For each time instant in the next 0 to 120 units, do ( prop_1 ) if ( prop_2 ) , and if possible, ( prop_4 ) .\nSTL: ['and', 'globally [0,120]', 'imply', 'prop_2', 'prop_1', 'prop_4']\n\nnatural language: In the next 0 to 5 time units , do the ( prop_1 ) , but in the next 3 to 4 
time units , ( prop_2 ) should not happen .\nSTL: ['and', 'globally [0,5]', 'prop_1', 'globally [3,4]', 'negation', 'prop_2']\n\nnatural language: "
299 | + original_sent + '\nLTL:',
300 | 
301 | temperature=0.6,
302 | max_tokens=256,
303 | top_p=1,
304 | frequency_penalty=0,
305 | presence_penalty=0
306 | )
307 |     return response['choices'][0]['text'][1:]
308 | 
309 | def verify_correct(test_list):  # 1 iff test_list is a well-formed prefix STL whose propositions all come after the operators
310 |     two_leaves = ['and', 'imply', 'equal', 'or', 'until']
311 |     one_leaf = ['globally', 'finally', 'negation']
312 |     count = 1; mark = 1; time_first_appear_index = -1
313 |     for i, item in enumerate(test_list):
314 |         if item in two_leaves or item[0:5] == 'until':
315 |             count += 1
316 |         elif item in one_leaf or item[0:8] == 'globally' or item[0:7] == 'finally':
317 |             count += 0
318 |         elif match(item, 'prop**'):
319 |             count -= 1
320 |             if time_first_appear_index == -1:
321 |                 time_first_appear_index = i
322 |         else: mark = 0
323 |     if count != 0:
324 |         mark = 0
325 |     for i in range(time_first_appear_index, len(test_list)):
326 |         if not match(test_list[i], 'prop**'):
327 |             mark = 0
328 |     return mark
329 | 
330 | def str2list_ltl_2(original_ltl):  # parse GPT-3's string reply "['op', ...]" back into a token list
331 |     original_ltl = original_ltl[1:len(original_ltl)-1].split(', ')
332 |     logic_ltl = []
333 |     for item in original_ltl:
334 |         if item[-1] == ',':
335 |             logic_ltl.append(item[1:len(item)-2])
336 |         else:
337 |             logic_ltl.append(item[1:len(item)-1])
338 |     return logic_ltl
339 | 
340 | print('-'*40)
341 | print('From NL-1 to LTL-2')
342 | LTL_2 = []
343 | for i in range(len(dataset_nl_1)):
344 |     mark = 0; total_try = 0
345 |     while mark == 0:
346 |         paraphrased_logic_ltl_raw = paraphrase_GPT3_to_ltl(dataset_nl_1[i])
347 |         for index in range(len(paraphrased_logic_ltl_raw)):
348 |             if paraphrased_logic_ltl_raw[index] != '\n' and paraphrased_logic_ltl_raw[index] != ' ':
349 |                 break
350 |         paraphrased_logic_ltl = paraphrased_logic_ltl_raw[index:]
351 |         mark = verify_correct(str2list_ltl_2(paraphrased_logic_ltl))
352 |         total_try += 1
353 |         if total_try > 1:  # give GPT-3 at most two attempts per sentence
354 |             break
355 |     if mark:
356 |         print(str2list_ltl_2(paraphrased_logic_ltl))
357 |         word_express, operation_express = pre_to_mid_exp_2(str2list_ltl_2(paraphrased_logic_ltl))
358 |         print(word_express)
359 |         print('\n')
360 |         LTL_2.append(str2list_ltl_2(paraphrased_logic_ltl))
361 | 
362 | # From LTL-2 to NL-2 and generate the Excel file
363 | path = '../Raw_data'
364 | f = open(path+'/test1.csv','w')
365 | csv_writer = csv.writer(f)
366 | csv_writer.writerow(['paraphrased_logic_sentence','logic_ltl_true_natural_order','original_logic_sentence', 'logic_ltl', 'Mark', 'Comments'])
367 | 
368 | for i in range(len(LTL_2)):
369 |     paraphrased_logic_sentence_raw = paraphrase_GPT3(pre_to_mid_exp_2(LTL_2[i])[0])
370 |     for index in range(len(paraphrased_logic_sentence_raw)):
371 |         if paraphrased_logic_sentence_raw[index] != '\n' and paraphrased_logic_sentence_raw[index] != ' ':
372 |             break
373 |     paraphrased_logic_sentence = paraphrased_logic_sentence_raw[index:]
374 |     csv_writer.writerow([paraphrased_logic_sentence, pre_to_mid_exp_2(LTL_2[i])[0]
375 |                          ,' ',LTL_2[i]
376 |                          , str(1), ''])
377 | f.close()
378 | 
379 | df = pd.read_csv(path+'/test1.csv')
380 | df.to_excel(path+'/output_davinci_loop_'+str(filename)+'.xlsx', 'Sheet1')
381 | 
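# --- Worked example (a sketch; not in the original file) ---
# A GPT-3 reply such as "['or', 'globally [489,663]', 'imply', 'prop_1', 'prop_3', 'prop_2']"
# is first parsed by str2list_ltl_2() into a token list and then screened by
# verify_correct(), which returns 1 only for a well-formed prefix STL:
#   tokens = str2list_ltl_2("['or', 'globally [489,663]', 'imply', 'prop_1', 'prop_3', 'prop_2']")
#   # -> ['or', 'globally [489,663]', 'imply', 'prop_1', 'prop_3', 'prop_2']
#   verify_correct(tokens)   # -> 1 (arities balance, and all prop_* tokens come last)
--------------------------------------------------------------------------------
/Run.ipynb:
--------------------------------------------------------------------------------
1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python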
3"},"language_info":{"name":"python"},"widgets":{"application/vnd.jupyter.widget-state+json":{"0d6fa6d663574e6888336392536e4897":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_2123b6617db84b8092e43bb8fe5dd6e9","IPY_MODEL_fc6c4a1c018c4156aa00a07ada65269b","IPY_MODEL_d8e2c00be1ab4350852ffd5e69a5e459"],"layout":"IPY_MODEL_48c78f9d82dd415389548c69fd977aac"}},"2123b6617db84b8092e43bb8fe5dd6e9":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_a455a02a61dc403e9d004acb75cec3a5","placeholder":"​","style":"IPY_MODEL_b0a5506b39ff4839a0ad79b293bf1afe","value":"Downloading (…)lve/main/config.json: 100%"}},"fc6c4a1c018c4156aa00a07ada65269b":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_9a4b06f22c62474bb19b0bb3e213dd74","max":1208,"min":0,"orientation":"horizontal","style":"IPY_MODEL_d06acbf2fe3a4ef1a586ba8cbc379231","value":1208}},"d8e2c00be1ab4350852ffd5e69a5e459":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_4c6e1ced7b914f0caa8a2d549fcf8b7f","placeholder":"​","style":"IPY_MODEL_c5febe95edb846e3831f059292cf07a6","value":" 1.21k/1.21k [00:00<00:00, 
62.3kB/s]"}},"48c78f9d82dd415389548c69fd977aac":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"a455a02a61dc403e9d004acb75cec3a5":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"b0a5506b39ff4839a0ad79b293bf1afe":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"9a4b06f22c62474bb19b0bb3e213dd74":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"d06acbf2fe3a4ef1a5
86ba8cbc379231":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"4c6e1ced7b914f0caa8a2d549fcf8b7f":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"c5febe95edb846e3831f059292cf07a6":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"7c4ba98f462d47d887855ac9e945770d":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_f0679738f4b940879dd68e94fe8d0d53","IPY_MODEL_e892d45ee3534c2490b599e782ecb204","IPY_MODEL_6138d12388654b67a0c5b3e514ef8c8c"],"layout":"IPY_MODEL_685a0a9aea8d46e09352a0534f892779"}},"f0679738f4b940879dd68e94fe8d0d53":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_9eb0819b08b94c708c1e62001c556550","placeholder":"​","style":"IPY_MODEL_ad012aad9541413197c0daf8c4a49fba","value":"Downloading (…)ve/main/spiece.model: 
100%"}},"e892d45ee3534c2490b599e782ecb204":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_cd3e1bd7c1724e738331bb663b4dfb3e","max":791656,"min":0,"orientation":"horizontal","style":"IPY_MODEL_126af1b857114b0cb010d16cbe8e5ee8","value":791656}},"6138d12388654b67a0c5b3e514ef8c8c":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_91dc34d8904442b38e1c340c231b2dfc","placeholder":"​","style":"IPY_MODEL_223ea5741ce4461784c2216600a1b791","value":" 792k/792k [00:00<00:00, 929kB/s]"}},"685a0a9aea8d46e09352a0534f892779":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"9eb0819b08b94c708c1e62001c556550":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"ad012aad9541413197c0daf8c4a49fba":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_
model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"cd3e1bd7c1724e738331bb663b4dfb3e":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"126af1b857114b0cb010d16cbe8e5ee8":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"91dc34d8904442b38e1c340c231b2dfc":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"223ea5741ce4461784c2216600a1b791":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"a9692ceb9afc4ad088e8fb0ef8f00b3e":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_ac5aebc3402f436ba7ad6790639616b8","IPY_MODEL_eb42d4a40f7
a400788e65c7212bef645","IPY_MODEL_04f0c4b7b02d447b9c8617577b8cdad8"],"layout":"IPY_MODEL_2cedb4307c9843e8ae6bc672a3accef3"}},"ac5aebc3402f436ba7ad6790639616b8":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_5c81bf8a7d674d07aa1f21672a8619a8","placeholder":"​","style":"IPY_MODEL_f80da51a42b346d5b7e3d00fc1e179f2","value":"Downloading (…)/main/tokenizer.json: 100%"}},"eb42d4a40f7a400788e65c7212bef645":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_5667e8c228ea4ddebbb520381aa6c180","max":1389353,"min":0,"orientation":"horizontal","style":"IPY_MODEL_684bc651706d40a4b348ef9e7e0f8a12","value":1389353}},"04f0c4b7b02d447b9c8617577b8cdad8":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_671b57dd2c194e858262aaf0720ac21a","placeholder":"​","style":"IPY_MODEL_ded425fe9b9f455a9a39356ba55059dc","value":" 1.39M/1.39M [00:00<00:00, 
6.60MB/s]"}},"2cedb4307c9843e8ae6bc672a3accef3":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"5c81bf8a7d674d07aa1f21672a8619a8":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"f80da51a42b346d5b7e3d00fc1e179f2":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"5667e8c228ea4ddebbb520381aa6c180":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"684bc651706d40a4b3
48ef9e7e0f8a12":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"671b57dd2c194e858262aaf0720ac21a":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"ded425fe9b9f455a9a39356ba55059dc":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}},"accelerator":"GPU","gpuClass":"standard"},"cells":[{"cell_type":"code","source":["! 
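{"cell_type":"markdown","source":["**Run the fine-tuned NL2TL translation model** (overview of the cells below): install the dependencies, mount Google Drive, point `output_dir` at the downloaded model weights (e.g. `checkpoint-62500`), then load the tokenizer and model for inference."]},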
{"cell_type":"code","source":["!pip install datasets transformers nltk"],"metadata":{"id":"r15AGrLU9qtg"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Successfully installed aiohttp-3.8.4 aiosignal-1.3.1 async-timeout-4.0.2 datasets-2.12.0 dill-0.3.6 frozenlist-1.3.3 huggingface-hub-0.14.1 multidict-6.0.4 multiprocess-0.70.14 responses-0.18.0 tokenizers-0.13.3 transformers-4.29.1 xxhash-3.2.0 yarl-1.9.2\n"]}]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"MD6CrnUo9BVo"},"outputs":[],"source":["from transformers import (AutoModelForSeq2SeqLM, \n","                          AutoTokenizer, \n","                          T5Tokenizer)\n","import torch\n","import pandas as pd\n","from datasets import Dataset, DatasetDict, load_dataset, load_from_disk\n","from tqdm import tqdm\n","\n","#import subprocess\n","import sys\n","import os\n","import argparse\n","from IPython.core import error\n","import random\n","import numpy as np\n","import nltk\n","import json\n","import csv"]},
{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive', force_remount=True)"],"metadata":{"id":"KlkSinYBKytq"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n"]}]},
{"cell_type":"code","source":["output_dir = 'path_to_your_model_weight/'\n","# Here you need to link this path in your Google Drive to the place preserving your model weights, e.g., checkpoint-62500\n","# You can download it on the github page\n","\n","model_checkpoint = \"t5-base\"\n","prefix = \"Transform the following sentence into Signal Temporal logic: \"\n","\n","max_input_length = 1024\n","max_target_length = 128\n","tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=max_input_length)\n","\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","tl_model = AutoModelForSeq2SeqLM.from_pretrained(output_dir+\"checkpoint-62500\").to(device)"],"metadata":{"id":"rsRrkdFh8Qwm"},"execution_count":null,"outputs":[]},