├── LICENSE ├── LCW-Using-ChemBERTa-2-For-Property_Prediction.py ├── LCW-Fine-Tuning-ChemBERTa-2.py ├── LCW-Using-ChemBERTa-2-For-Property_Prediction.ipynb └── LCW-Fine-Tuning-ChemBERTa-2.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 jwoerner42 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LCW-Using-ChemBERTa-2-For-Property_Prediction.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Released under MIT License 5 | # 6 | # Copyright (c) 2023 Andrew SID Lang, Oral Roberts University, U.S.A. 7 | # 8 | # Copyright (c) 2023 Jan HR Woerner, Oral Roberts University, U.S.A. 9 | # 10 | # Copyright (c) 2023 Wei-Khiong (Wyatt) Chong, Advent Polytech Co., Ltd, Taiwan. 
11 | # 12 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 13 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 14 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to 15 | # permit persons to whom the Software is furnished to do so, subject to the following conditions: 16 | # 17 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the 18 | # Software. 19 | # 20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 21 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS 22 | # OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 23 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
import torch
from torch.utils.data import Dataset
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import pandas as pd
from transformers import pipeline
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from scipy.stats import spearmanr
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings("ignore", message="Was asked to gather along dimension 0, but all input tensors were scalars")
output_directory = "./output"
# max length of SMILES over both sets
max_length = 195


class Input(Dataset):
    """Torch dataset that tokenizes one row of a SMILES / solubility frame per item.

    Args:
        i_data: frame with the columns "Standardized_SMILES" and "median_WS".
        i_tokenizer: tokenizer called on each SMILES string.
        i_max_length: length every sequence is padded / truncated to.
    """

    def __init__(self, i_data, i_tokenizer, i_max_length):
        self.data = i_data
        self.tokenizer = i_tokenizer
        self.max_length = i_max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tokenized SMILES of row ``idx`` with its ``median_WS`` value under "labels"."""
        row = self.data.iloc[idx]
        i_inputs = self.tokenizer(row["Standardized_SMILES"], return_tensors="pt", padding='max_length',
                                  truncation=True, max_length=self.max_length)
        # return_tensors="pt" yields tensors with a leading axis for the single input string;
        # squeeze(0) drops it (when it has size 1) so each item is a plain per-sample tensor.
        i_inputs["input_ids"] = i_inputs["input_ids"].squeeze(0)
        i_inputs["attention_mask"] = i_inputs["attention_mask"].squeeze(0)
        if "token_type_ids" in i_inputs:
            i_inputs["token_type_ids"] = i_inputs["token_type_ids"].squeeze(0)
        # The label is stored as a one-element float tensor.
        i_inputs["labels"] = torch.tensor(row["median_WS"], dtype=torch.float).unsqueeze(0)
        return i_inputs


# retrieve the device to move the model to
def get_device():
    """Return the CUDA device when torch reports one, otherwise the CPU device."""
    # The MPS backend is deliberately not selected: it repeatedly led to
    # "RuntimeError: Placeholder storage has not been allocated on MPS device!".
    if torch.cuda.is_available():
        print("Using NV GPU.")
        return torch.device("cuda")
    print("No GPU available, using the CPU instead.")
    return torch.device("cpu")


def _targets_for(smiles):
    """Return the ``median_WS`` column of the loaded frame whose SMILES column equals ``smiles``.

    This lets callers that only pass a SMILES column taken from ``test`` or ``data`` be scored
    against the matching measured values instead of always against the training set.

    Raises:
        ValueError: if ``smiles`` matches neither loaded frame.
    """
    wanted = list(smiles)
    for frame in (test, data):
        if len(frame) == len(wanted) and frame["Standardized_SMILES"].tolist() == wanted:
            return frame["median_WS"]
    raise ValueError("cannot infer the measured values for these SMILES; pass them via 'actual'")


# Predict properties for new SMILES strings
def predict_smiles(u_smiles, dev, actual=None):
    """Predict a value for every SMILES string and score the predictions.

    Args:
        u_smiles: iterable of SMILES strings.
        dev: torch device the tokenized inputs are moved to (the model must already be on it).
        actual: measured values aligned with ``u_smiles``. When omitted they are looked up from
            the loaded frame that ``u_smiles`` was taken from.

    Returns:
        Tuple ``(rmse, r2, mae, predictions, spearman_correlation, p_value)``.
    """
    if actual is None:
        actual = _targets_for(u_smiles)
    # torch.no_grad() does not switch off dropout; eval mode does, and the model is still in
    # training mode after Trainer.train() unless something put it into eval mode.
    model.eval()
    preds = []
    for i_smiles in u_smiles:
        inputs = tokenizer(i_smiles, return_tensors="pt", padding='max_length', truncation=True,
                           max_length=max_length).to(dev)
        with torch.no_grad():
            outputs = model(**inputs)
        preds.append(outputs.logits.squeeze().item())
    # Root of the MSE instead of the `squared=False` flag, which newer scikit-learn releases dropped.
    r_mse = mean_squared_error(actual, preds) ** 0.5
    r2 = r2_score(actual, preds)
    mae = mean_absolute_error(actual, preds)
    correlation, p_value = spearmanr(actual, preds)
    return r_mse, r2, mae, preds, correlation, p_value


# display the results
def display_results(dataset_type, in_r_mse, in_r2, in_mae, preds, correlation, p_val, actual=None):
    """Print the metrics for one data set and show a scatter plot of measured vs predicted values.

    Args:
        dataset_type: label used in the printout and the plot title.
        in_r_mse, in_r2, in_mae, correlation, p_val: metrics to print.
        preds: predicted values.
        actual: measured values aligned with ``preds``; defaults to the training targets
            ``data["median_WS"]`` (the previous, fixed behaviour) when omitted.
    """
    if actual is None:
        actual = data["median_WS"]
    print(dataset_type)
    print("N:", len(actual))
    print("R2:", in_r2)
    print("Root Mean Square Error:", in_r_mse)
    print("Mean Absolute Error:", in_mae)
    print("Spearman correlation:", correlation)
    print("p-value:", p_val)

    plt.scatter(actual, preds)
    plt.xlabel("actual median_WS")
    plt.ylabel("predictions")
    plt.title("Scatter Plot of " + dataset_type + " ['median_WS'] vs Predictions")
    plt.show()


# run prediction, reporting and (optionally) saving for one set of SMILES
def run_prediction(prep_smiles, set_type, dev, is_saved, actual=None):
    """Predict, report and optionally save the results for one set of SMILES strings.

    Args:
        prep_smiles: SMILES strings to predict.
        set_type: label for the printout and plot title.
        dev: torch device used for inference.
        is_saved: when true, write measured and predicted values to "testset_results.csv".
        actual: measured values aligned with ``prep_smiles``. When omitted they are looked up
            from the loaded frame that ``prep_smiles`` was taken from.
    """
    if actual is None:
        actual = _targets_for(prep_smiles)
    out_r_mse, out_r2, out_mae, predictions, correlation, p_value = predict_smiles(prep_smiles, dev, actual)

    display_results(set_type, out_r_mse, out_r2, out_mae, predictions, correlation, p_value, actual)
    if is_saved:
        # Plain lists so the two columns are paired by position, not by pandas index.
        results_df = pd.DataFrame({"actual_WS": list(actual), "predicted_WS": predictions})
        results_df.to_csv("testset_results.csv", index=False)


#
# LCW Using ChemBERTa-2 For Property Prediction main program
#

# Read in solubility data
train_data = pd.read_csv('aqua_train.csv')
test_data = pd.read_csv('aqua_test.csv')

# pick out columns
data = train_data[['Standardized_SMILES', 'median_WS']]
test = test_data[['Standardized_SMILES', 'median_WS']]

# Load a pretrained transformer model and tokenizer
model_name = "DeepChem/ChemBERTa-77M-MTR"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE(review): this modified config is never passed to from_pretrained() below, so the extra
# hidden layer is not applied to the loaded model. Confirm whether it was meant to be used.
config = AutoConfig.from_pretrained(model_name)
config.num_hidden_layers += 1
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)

# move model to the device
device = get_device()
model.to(device)

# Prepare the dataset for training
train_dataset = Input(data, tokenizer, max_length)

# Set up training arguments
training_args = TrainingArguments(
    output_dir=output_directory,
    num_train_epochs=100,  # Number of training epochs
    per_device_train_batch_size=86,  # Batch size
    logging_steps=100,  # Log training metrics every 100 steps
    optim="adamw_torch",  # switch optimizer to avoid warning
    seed=123,  # Set a random seed for reproducibility
)

# Train the model
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, )
trainer.train()
trainer.save_model(output_directory)  # save model to output folder

# Create a prediction pipeline
# NOTE(review): the predictions in this script call the model directly; `predictor` looks unused.
predictor = pipeline("text-classification", model=model, tokenizer=tokenizer)

# Prepare new SMILES strings for prediction and run the model for test data
test_smiles = test['Standardized_SMILES']
run_prediction(test_smiles, "TEST SET", device, False) 170 | 171 | # Prepare new SMILES strings for prediction and run the model for training data 172 | train_smiles = data['Standardized_SMILES'] 173 | run_prediction(train_smiles, "TEST SET", device, False) 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /LCW-Fine-Tuning-ChemBERTa-2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Released under MIT License 5 | # 6 | # Copyright (c) 2023 Andrew SID Lang, Oral Roberts University, U.S.A. 7 | # 8 | # Copyright (c) 2023 Jan HR Woerner, Oral Roberts University, U.S.A. 9 | # 10 | # Copyright (c) 2023 Wei-Khiong (Wyatt) Chong, Advent Polytech Co., Ltd, Taiwan. 11 | # 12 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 13 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 14 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to 15 | # permit persons to whom the Software is furnished to do so, subject to the following conditions: 16 | # 17 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the 18 | # Software. 19 | # 20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 21 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS 22 | # OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 23 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
24 | 25 | 26 | import torch 27 | from torch.utils.data import Dataset 28 | from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification 29 | from transformers import Trainer, TrainingArguments, TrainerCallback, pipeline 30 | import pandas as pd 31 | import warnings 32 | import numpy as np 33 | import evaluate 34 | from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error 35 | from scipy.stats import spearmanr 36 | import matplotlib.pyplot as plt 37 | 38 | warnings.filterwarnings("ignore", message="Was asked to gather along dimension 0, but all input tensors were scalars") 39 | # define the input file 40 | input_file: str = 'aqua.csv' 41 | output_directory: str = './output' 42 | # Define the maximum sequence length 43 | max_length = 195 44 | 45 | 46 | class MyData: 47 | def __init__(self, i_data): 48 | self.data = i_data 49 | 50 | def get_split(self, train_ratio=0.8, valid_ratio=0.1, seed=None): 51 | n = len(self.data) 52 | indices = np.arange(n) 53 | if seed is not None: 54 | np.random.seed(seed) 55 | np.random.shuffle(indices) 56 | train_size = int(train_ratio * n) 57 | valid_size = int(valid_ratio * n) 58 | train_indices = indices[:train_size] 59 | valid_indices = indices[train_size:train_size + valid_size] 60 | test_indices = indices[train_size + valid_size:] 61 | i_train_data = self.data.iloc[train_indices].reset_index(drop=True) 62 | i_valid_data = self.data.iloc[valid_indices].reset_index(drop=True) 63 | i_test_data = self.data.iloc[test_indices].reset_index(drop=True) 64 | return i_train_data, i_valid_data, i_test_data 65 | 66 | 67 | class Input(Dataset): 68 | def __init__(self, i_data, i_tokenizer, i_max_length): 69 | self.data = i_data 70 | self.tokenizer = i_tokenizer 71 | self.max_length = i_max_length 72 | 73 | def __len__(self): 74 | return len(self.data) 75 | 76 | def __getitem__(self, idx): 77 | smiles = self.data.iloc[idx]["Standardized_SMILES"] 78 | inputs = self.tokenizer(smiles, return_tensors="pt", 
padding='max_length', truncation=True, 79 | max_length=self.max_length) 80 | inputs["input_ids"] = inputs["input_ids"].squeeze(0) 81 | inputs["attention_mask"] = inputs["attention_mask"].squeeze(0) 82 | if "token_type_ids" in inputs: 83 | inputs["token_type_ids"] = inputs["token_type_ids"].squeeze(0) 84 | inputs["labels"] = torch.tensor(self.data.iloc[idx]["median_WS"], dtype=torch.float).unsqueeze(0) 85 | return inputs 86 | 87 | 88 | # Define a callback for printing validation loss 89 | class PrintValidationLossCallback(TrainerCallback): 90 | def on_evaluate(self, args, state, control, **kwargs): 91 | if state is not None and hasattr(state, 'eval_loss'): 92 | print(f"Validation loss: {state.eval_loss:.4f}") 93 | 94 | 95 | def compute_metrics(eval_pred): 96 | logits, labels = eval_pred 97 | predictions = np.argmax(logits, axis=-1) 98 | return metric.compute(predictions=predictions, references=labels) 99 | 100 | 101 | # Read in solubility data and split 102 | def read_solubility(filename: str): 103 | my_data = pd.read_csv(filename) 104 | # Create an instance of the MyData class 105 | my_data = MyData(my_data) 106 | # Split your data into training, validation, and testing sets 107 | train_data, valid_data, test_data = my_data.get_split(seed=123) 108 | # pick out columns 109 | r_data = train_data[['Standardized_SMILES', 'median_WS']] 110 | r_valid = valid_data[['Standardized_SMILES', 'median_WS']] 111 | r_test = test_data[['Standardized_SMILES', 'median_WS']] 112 | return r_data, r_valid, r_test 113 | 114 | 115 | # retrieve the device to move the model to 116 | def get_device(): 117 | if torch.cuda.is_available(): 118 | dev = torch.device("cuda") 119 | print("Using NV GPU.") 120 | # The mps device in torch does repeatedly lead to a RuntimeError: Placeholder storage has not been allocated 121 | # on MPS device! 
122 | #elif torch.backends.mps.is_available(): 123 | # dev = torch.device("mps") 124 | # print("Using M1 GPU.") 125 | else: 126 | print("No GPU available, using the CPU instead.") 127 | dev = torch.device("cpu") 128 | return dev 129 | 130 | 131 | # Predict properties for new SMILES strings 132 | def predict_smiles(u_smiles, dev): 133 | preds = [] 134 | for i_smiles in u_smiles: 135 | # max_length=195 and move the inputs also to the device 136 | inputs = tokenizer(i_smiles, return_tensors="pt", padding='max_length', truncation=True, max_length=195).to(dev) 137 | with torch.no_grad(): 138 | outputs = model(**inputs) 139 | pred_property = outputs.logits.squeeze().item() 140 | preds.append(pred_property) 141 | r_mse = mean_squared_error(data["median_WS"], preds, squared=False) 142 | r2 = r2_score(data["median_WS"], preds) 143 | mae = mean_absolute_error(data["median_WS"], preds) 144 | correlation, p_value = spearmanr(data["median_WS"], preds) 145 | return r_mse, r2, mae, preds, correlation, p_value 146 | 147 | 148 | # display the results 149 | def display_results(dataset_type, in_r_mse, in_r2, in_mae, preds, correlation, p_val): 150 | print(dataset_type) 151 | print("N:", len(data["median_WS"])) 152 | print("R2:", in_r2) 153 | print("Root Mean Square Error:", in_r_mse) 154 | print("Mean Absolute Error:", in_mae) 155 | print("Spearman correlation:", correlation) 156 | print("p-value:", p_val) 157 | 158 | plt.scatter(data["median_WS"], preds) 159 | plt.xlabel("train['median_WS']") 160 | plt.ylabel("predictions") 161 | plt.title("Scatter Plot of " + dataset_type + " ['median_WS'] vs Predictions") 162 | plt.show() 163 | 164 | 165 | # assume test and predictions are two arrays of the same length 166 | # run it for prepared smiles data, set a string set_type for output, device, and a flag to save results 167 | def run_prediction(prep_smiles, set_type, dev, is_saved): 168 | out_r_mse, out_r2, out_mae, predictions, correlation, p_value = predict_smiles(prep_smiles, dev) 169 | 
display_results(set_type, out_r_mse, out_r2, out_mae, predictions, correlation, p_value) 170 | if is_saved: 171 | results_df = pd.DataFrame({"actual_WS": test["median_WS"], "predicted_WS": predictions}) 172 | results_df.to_csv("testset_results.csv", index=False) 173 | 174 | 175 | # 176 | # main program 177 | # 178 | 179 | # AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1) 180 | # Load a pretrained transformer model and tokenizer 181 | model_name = "DeepChem/ChemBERTa-77M-MTR" 182 | tokenizer = AutoTokenizer.from_pretrained(model_name) 183 | config = AutoConfig.from_pretrained(model_name) 184 | config.num_hidden_layers += 1 185 | model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1) 186 | 187 | # see if GPU and assign model (move model to the device) 188 | device = get_device() 189 | model.to(device) 190 | 191 | # Read and prepare the dataset for training 192 | data, valid, test = read_solubility(input_file) 193 | train_dataset = Input(data, tokenizer, max_length) 194 | validation_dataset = Input(valid, tokenizer, max_length) 195 | 196 | # Set up training arguments 197 | training_args = TrainingArguments( 198 | output_dir=output_directory, 199 | optim="adamw_torch", # switch optimizer to avoid warning 200 | num_train_epochs=100, # Train the model for 100 epochs 201 | per_device_train_batch_size=128, # Set the batch size to 128 202 | per_device_eval_batch_size=128, # Set the evaluation batch size to 128 203 | logging_steps=10, # Log training metrics every 100 steps 204 | eval_steps=10, # Evaluate the model every 100 steps 205 | save_steps=10, # Save the model every 100 steps 206 | seed=123, # Set the random seed for reproducibility 207 | evaluation_strategy="steps", # Evaluate the model every eval_steps steps 208 | load_best_model_at_end=True 209 | ) 210 | 211 | # Train the model 212 | trainer = Trainer( 213 | model=model, 214 | args=training_args, 215 | train_dataset=train_dataset, 216 | 
eval_dataset=validation_dataset, 217 | ) 218 | 219 | # Add the callback to the trainer 220 | trainer.add_callback(PrintValidationLossCallback()) 221 | 222 | metric = evaluate.load("accuracy") 223 | 224 | # Train the model 225 | trainer.train() 226 | 227 | # Save the model 228 | trainer.save_model("./output") 229 | 230 | # Create a prediction pipeline 231 | predictor = pipeline("text-classification", model=model, tokenizer=tokenizer) 232 | 233 | 234 | # Prepare new SMILES strings for prediction TRAINING-SET 235 | run_prediction(data['Standardized_SMILES'], "TRAINING SET", device, False) 236 | 237 | # Prepare new SMILES strings for prediction VALIDATION SET 238 | run_prediction(valid['Standardized_SMILES'], "VALIDATION SET", device, False) 239 | 240 | # Prepare new SMILES strings for prediction TEST SET 241 | run_prediction(test['Standardized_SMILES'], "TEST SET", device, True) 242 | -------------------------------------------------------------------------------- /LCW-Using-ChemBERTa-2-For-Property_Prediction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "e4fd208b-665e-4d83-861c-3c974f9a5ba4", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# Released under MIT License\n", 11 | "#\n", 12 | "# Copyright (c) 2023 Andrew SID Lang, Oral Roberts University, U.S.A.\n", 13 | "#\n", 14 | "# Copyright (c) 2023 Jan HR Woerner, Oral Roberts University, U.S.A.\n", 15 | "#\n", 16 | "# Copyright (c) 2023 Wei-Khiong (Wyatt) Chong, Advent Polytech Co., Ltd, Taiwan.\n", 17 | "#\n", 18 | "# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated\n", 19 | "# documentation files (the \"Software\"), to deal in the Software without restriction, including without limitation the\n", 20 | "# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and 
to\n", 21 | "# permit persons to whom the Software is furnished to do so, subject to the following conditions:\n", 22 | "#\n", 23 | "# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the\n", 24 | "# Software.\n", 25 | "#\n", 26 | "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO\n", 27 | "# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS\n", 28 | "# OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR\n", 29 | "# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n", 30 | "\n", 31 | "import torch\n", 32 | "from torch.utils.data import Dataset\n", 33 | "from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments\n", 34 | "import pandas as pd\n", 35 | "import warnings\n", 36 | "warnings.filterwarnings(\"ignore\", message=\"Was asked to gather along dimension 0, but all input tensors were scalars\")" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "id": "948c5fef-6380-4fd2-b82d-3cb5c0b7d774", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "class Input(Dataset):\n", 47 | " def __init__(self, i_data, i_tokenizer, i_max_length):\n", 48 | " self.data = i_data\n", 49 | " self.tokenizer = i_tokenizer\n", 50 | " self.max_length = i_max_length\n", 51 | "\n", 52 | " def __len__(self):\n", 53 | " return len(self.data)\n", 54 | "\n", 55 | " def __getitem__(self, idx):\n", 56 | " i_smiles = self.data.iloc[idx][\"Standardized_SMILES\"]\n", 57 | " i_inputs = self.tokenizer(i_smiles, return_tensors=\"pt\", padding='max_length', truncation=True, max_length=self.max_length)\n", 58 | " i_inputs[\"input_ids\"] = i_inputs[\"input_ids\"].squeeze(0)\n", 59 | " 
i_inputs[\"attention_mask\"] = i_inputs[\"attention_mask\"].squeeze(0)\n", 60 | " if \"token_type_ids\" in i_inputs:\n", 61 | " i_inputs[\"token_type_ids\"] = i_inputs[\"token_type_ids\"].squeeze(0)\n", 62 | " i_inputs[\"labels\"] = torch.tensor(self.data.iloc[idx][\"median_WS\"], dtype=torch.float).unsqueeze(0)\n", 63 | " return i_inputs" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "id": "10d8698a-096a-4e15-b454-2f25a4a19dc3", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "# Read in solubility data\n", 74 | "train_data = pd.read_csv('aqua_train.csv')\n", 75 | "test_data = pd.read_csv('aqua_test.csv')\n", 76 | "\n", 77 | "# pick out columns\n", 78 | "data = train_data[['Standardized_SMILES', 'median_WS']]\n", 79 | "test = test_data[['Standardized_SMILES', 'median_WS']]" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 4, 85 | "id": "f290f56e-2ea6-45bc-90b7-cbe705dde78b", 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "name": "stderr", 90 | "output_type": "stream", 91 | "text": [ 92 | "Some weights of the model checkpoint at DeepChem/ChemBERTa-77M-MTR were not used when initializing RobertaForSequenceClassification: ['regression.out_proj.bias', 'regression.dense.bias', 'regression.dense.weight', 'norm_std', 'regression.out_proj.weight', 'norm_mean']\n", 93 | "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 94 | "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 95 | "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MTR and are newly initialized: ['classifier.dense.bias', 'classifier.out_proj.bias', 'classifier.dense.weight', 'classifier.out_proj.weight']\n", 96 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 97 | ] 98 | }, 99 | { 100 | "name": "stdout", 101 | "output_type": "stream", 102 | "text": [ 103 | "Using GPU.\n" 104 | ] 105 | }, 106 | { 107 | "data": { 108 | "text/plain": [ 109 | "RobertaForSequenceClassification(\n", 110 | " (roberta): RobertaModel(\n", 111 | " (embeddings): RobertaEmbeddings(\n", 112 | " (word_embeddings): Embedding(600, 384, padding_idx=1)\n", 113 | " (position_embeddings): Embedding(515, 384, padding_idx=1)\n", 114 | " (token_type_embeddings): Embedding(1, 384)\n", 115 | " (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)\n", 116 | " (dropout): Dropout(p=0.144, inplace=False)\n", 117 | " )\n", 118 | " (encoder): RobertaEncoder(\n", 119 | " (layer): ModuleList(\n", 120 | " (0-2): 3 x RobertaLayer(\n", 121 | " (attention): RobertaAttention(\n", 122 | " (self): RobertaSelfAttention(\n", 123 | " (query): Linear(in_features=384, out_features=384, bias=True)\n", 124 | " (key): Linear(in_features=384, out_features=384, bias=True)\n", 125 | " (value): Linear(in_features=384, out_features=384, bias=True)\n", 126 | " (dropout): Dropout(p=0.109, inplace=False)\n", 127 | " )\n", 128 | " (output): RobertaSelfOutput(\n", 129 | " (dense): Linear(in_features=384, out_features=384, bias=True)\n", 130 | " 
(LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)\n", 131 | " (dropout): Dropout(p=0.144, inplace=False)\n", 132 | " )\n", 133 | " )\n", 134 | " (intermediate): RobertaIntermediate(\n", 135 | " (dense): Linear(in_features=384, out_features=464, bias=True)\n", 136 | " (intermediate_act_fn): GELUActivation()\n", 137 | " )\n", 138 | " (output): RobertaOutput(\n", 139 | " (dense): Linear(in_features=464, out_features=384, bias=True)\n", 140 | " (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)\n", 141 | " (dropout): Dropout(p=0.144, inplace=False)\n", 142 | " )\n", 143 | " )\n", 144 | " )\n", 145 | " )\n", 146 | " )\n", 147 | " (classifier): RobertaClassificationHead(\n", 148 | " (dense): Linear(in_features=384, out_features=384, bias=True)\n", 149 | " (dropout): Dropout(p=0.144, inplace=False)\n", 150 | " (out_proj): Linear(in_features=384, out_features=1, bias=True)\n", 151 | " )\n", 152 | ")" 153 | ] 154 | }, 155 | "execution_count": 4, 156 | "metadata": {}, 157 | "output_type": "execute_result" 158 | } 159 | ], 160 | "source": [ 161 | "# Load a pretrained transformer model and tokenizer\n", 162 | "model_name = \"DeepChem/ChemBERTa-77M-MTR\"\n", 163 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 164 | "config = AutoConfig.from_pretrained(model_name)\n", 165 | "config.num_hidden_layers += 1\n", 166 | "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)\n", 167 | "\n", 168 | "#see if GPU\n", 169 | "if torch.cuda.is_available(): \n", 170 | " device = torch.device(\"cuda\")\n", 171 | " print(\"Using GPU.\")\n", 172 | "else:\n", 173 | " print(\"No GPU available, using the CPU instead.\")\n", 174 | " device = torch.device(\"cpu\")\n", 175 | "# move model to the device\n", 176 | "model.to(device) " 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 5, 182 | "id": "513419d0-426e-4fcd-869e-7cec0f16e078", 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "data": { 
187 | "text/html": [ 188 | "\n", 189 | "
\n", 190 | " \n", 191 | " \n", 192 | " [2300/2300 09:37, Epoch 100/100]\n", 193 | "
\n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | "
StepTraining Loss
1006.768000
2001.259400
3001.048200
4000.960900
5000.910400
6000.857300
7000.819500
8000.794000
9000.781700
10000.757900
11000.744500
12000.724800
13000.711800
14000.708500
15000.701500
16000.684800
17000.681700
18000.671000
19000.661500
20000.660700
21000.657900
22000.660900
23000.656800

" 296 | ], 297 | "text/plain": [ 298 | "" 299 | ] 300 | }, 301 | "metadata": {}, 302 | "output_type": "display_data" 303 | } 304 | ], 305 | "source": [ 306 | "# max length of SMILES over both sets\n", 307 | "max_length = 195\n", 308 | "# Prepare the dataset for training\n", 309 | "train_dataset = Input(data, tokenizer,max_length)\n", 310 | "\n", 311 | "# Set up training arguments\n", 312 | "training_args = TrainingArguments(\n", 313 | " output_dir=\"./output\",\n", 314 | " num_train_epochs=100, # Number of training epochs\n", 315 | " per_device_train_batch_size=86, # Batch size\n", 316 | " logging_steps=100, # Log training metrics every 100 steps\n", 317 | " optim=\"adamw_torch\", # switch optimizer to avoid warning\n", 318 | " seed=123, # Set a random seed for reproducibility\n", 319 | ")\n", 320 | "\n", 321 | " \n", 322 | "# Train the model\n", 323 | "trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset,)\n", 324 | "trainer.train() \n", 325 | "trainer.save_model(\"./output\") # save model to output folder" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": 10, 331 | "id": "d3ad4bb7-baa3-49f1-a041-7d27159a0cb7", 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "from transformers import pipeline\n", 336 | "# Create a prediction pipeline\n", 337 | "\n", 338 | "predictor = pipeline(\"text-classification\", model=model, tokenizer=tokenizer)\n", 339 | "\n", 340 | "# Prepare new SMILES strings for prediction\n", 341 | "test_smiles = test['Standardized_SMILES']\n", 342 | "\n", 343 | "# Predict properties for new SMILES strings\n", 344 | "predictions = []\n", 345 | "for smiles in test_smiles:\n", 346 | " inputs = tokenizer(smiles, return_tensors=\"pt\", padding='max_length', truncation=True, max_length=195).to(device)\n", 347 | " with torch.no_grad():\n", 348 | " outputs = model(**inputs)\n", 349 | " predicted_property = outputs.logits.squeeze().item()\n", 350 | " predictions.append(predicted_property)" 
351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": 11, 356 | "id": "4eea13b3-f35f-4db9-a1a9-61f2d16faceb", 357 | "metadata": {}, 358 | "outputs": [ 359 | { 360 | "name": "stdout", 361 | "output_type": "stream", 362 | "text": [ 363 | "TEST SET\n", 364 | "N: 2552\n", 365 | "R2: 0.8054362061353308\n", 366 | "Root Mean Square Error: 1.0125407666739181\n", 367 | "Mean Absolute Error: 0.716880692248506\n", 368 | "Spearman correlation: 0.8995523590904071\n", 369 | "p-value: 0.0\n" 370 | ] 371 | }, 372 | { 373 | "data": { 374 | "image/png": "\n", 375 | "text/plain": [ 376 | "

" 377 | ] 378 | }, 379 | "metadata": { 380 | "needs_background": "light" 381 | }, 382 | "output_type": "display_data" 383 | } 384 | ], 385 | "source": [ 386 | "from sklearn.metrics import mean_squared_error\n", 387 | "from sklearn.metrics import r2_score\n", 388 | "from sklearn.metrics import mean_absolute_error\n", 389 | "\n", 390 | "r_mse = mean_squared_error(test[\"median_WS\"], predictions, squared=False)\n", 391 | "r2 = r2_score(test[\"median_WS\"], predictions)\n", 392 | "mae = mean_absolute_error(test[\"median_WS\"], predictions)\n", 393 | "\n", 394 | "print(\"TEST SET\")\n", 395 | "print(\"N:\", len(test[\"median_WS\"]))\n", 396 | "print(\"R2:\", r2)\n", 397 | "print(\"Root Mean Square Error:\", r_mse)\n", 398 | "print(\"Mean Absolute Error:\", mae)\n", 399 | "\n", 400 | "from scipy.stats import spearmanr\n", 401 | "\n", 402 | "# assume test and predictions are two arrays of the same length\n", 403 | "correlation, p_value = spearmanr(test[\"median_WS\"], predictions)\n", 404 | "\n", 405 | "print(\"Spearman correlation:\", correlation)\n", 406 | "print(\"p-value:\", p_value)\n", 407 | "\n", 408 | "import matplotlib.pyplot as plt\n", 409 | "\n", 410 | "plt.scatter(test[\"median_WS\"], predictions)\n", 411 | "plt.xlabel(\"Test_Set Median Water Solubility\")\n", 412 | "plt.ylabel(\"Predictions\")\n", 413 | "plt.title(\"Scatter Plot of Test_Set Median Water Solubility vs Predictions\")\n", 414 | "plt.show()" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 12, 420 | "id": "36823b0f-4d30-4536-a182-7041c3e56642", 421 | "metadata": {}, 422 | "outputs": [], 423 | "source": [ 424 | "# Prepare new SMILES strings for prediction\n", 425 | "train_smiles = data['Standardized_SMILES']\n", 426 | "\n", 427 | "# Predict properties for new SMILES strings\n", 428 | "train_predictions = []\n", 429 | "for smiles in train_smiles:\n", 430 | " inputs = tokenizer(smiles, return_tensors=\"pt\", padding='max_length', truncation=True, 
max_length=195).to(device) \n", 431 | " with torch.no_grad():\n", 432 | " outputs = model(**inputs)\n", 433 | " predicted_property = outputs.logits.squeeze().item()\n", 434 | " train_predictions.append(predicted_property)" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": 13, 440 | "id": "b9115caa-d8d9-44e9-9abe-941d082f6ce2", 441 | "metadata": {}, 442 | "outputs": [ 443 | { 444 | "name": "stdout", 445 | "output_type": "stream", 446 | "text": [ 447 | "TRAINING SET\n", 448 | "N: 7655\n", 449 | "R2: 0.871320792914473\n", 450 | "Root Mean Square Error: 0.8098064745023944\n", 451 | "Mean Absolute Error: 0.595570092893068\n", 452 | "Spearman correlation: 0.9277674381474756\n", 453 | "p-value: 0.0\n" 454 | ] 455 | }, 456 | { 457 | "data": { 458 | "image/png": "\n", 459 | "text/plain": [ 460 | "
" 461 | ] 462 | }, 463 | "metadata": { 464 | "needs_background": "light" 465 | }, 466 | "output_type": "display_data" 467 | } 468 | ], 469 | "source": [ 470 | "rmse = mean_squared_error(data[\"median_WS\"], train_predictions, squared=False)\n", 471 | "r2 = r2_score(data[\"median_WS\"], train_predictions)\n", 472 | "mae = mean_absolute_error(data[\"median_WS\"], train_predictions)\n", 473 | "\n", 474 | "print(\"TRAINING SET\")\n", 475 | "print(\"N:\", len(data[\"median_WS\"]))\n", 476 | "print(\"R2:\", r2)\n", 477 | "print(\"Root Mean Square Error:\", rmse)\n", 478 | "print(\"Mean Absolute Error:\", mae)\n", 479 | "\n", 480 | "# assume test and predictions are two arrays of the same length\n", 481 | "correlation, p_value = spearmanr(data[\"median_WS\"], train_predictions)\n", 482 | "\n", 483 | "print(\"Spearman correlation:\", correlation)\n", 484 | "print(\"p-value:\", p_value)\n", 485 | "\n", 486 | "import matplotlib.pyplot as plt\n", 487 | "\n", 488 | "plt.scatter(data[\"median_WS\"], train_predictions)\n", 489 | "plt.xlabel(\"Training_Set Median Water Solubility\")\n", 490 | "plt.ylabel(\"Predictions\")\n", 491 | "plt.title(\"Scatter Plot of Training_Set Median Water Solubility vs Predictions\")\n", 492 | "plt.show()" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": null, 498 | "id": "e4544830-cc4f-4eee-a692-cfd99cd9072e", 499 | "metadata": {}, 500 | "outputs": [], 501 | "source": [] 502 | } 503 | ], 504 | "metadata": { 505 | "kernelspec": { 506 | "display_name": "Python 3 (ipykernel)", 507 | "language": "python", 508 | "name": "python3" 509 | }, 510 | "language_info": { 511 | "codemirror_mode": { 512 | "name": "ipython", 513 | "version": 3 514 | }, 515 | "file_extension": ".py", 516 | "mimetype": "text/x-python", 517 | "name": "python", 518 | "nbconvert_exporter": "python", 519 | "pygments_lexer": "ipython3", 520 | "version": "3.9.5" 521 | } 522 | }, 523 | "nbformat": 4, 524 | "nbformat_minor": 5 525 | } 526 | 
-------------------------------------------------------------------------------- /LCW-Fine-Tuning-ChemBERTa-2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "e4fd208b-665e-4d83-861c-3c974f9a5ba4", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# Released under MIT License\n", 11 | "#\n", 12 | "# Copyright (c) 2023 Andrew SID Lang, Oral Roberts University, U.S.A.\n", 13 | "#\n", 14 | "# Copyright (c) 2023 Jan HR Woerner, Oral Roberts University, U.S.A.\n", 15 | "#\n", 16 | "# Copyright (c) 2023 Wei-Khiong (Wyatt) Chong, Advent Polytech Co., Ltd, Taiwan.\n", 17 | "#\n", 18 | "# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated\n", 19 | "# documentation files (the \"Software\"), to deal in the Software without restriction, including without limitation the\n", 20 | "# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to\n", 21 | "# permit persons to whom the Software is furnished to do so, subject to the following conditions:\n", 22 | "#\n", 23 | "# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the\n", 24 | "# Software.\n", 25 | "#\n", 26 | "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO\n", 27 | "# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS\n", 28 | "# OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR\n", 29 | "# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n", 30 | "\n", 31 | "import torch\n", 32 | "from torch.utils.data import Dataset\n", 33 | "from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, TrainerCallback\n", 34 | "import pandas as pd\n", 35 | "import warnings\n", 36 | "warnings.filterwarnings(\"ignore\", message=\"Was asked to gather along dimension 0, but all input tensors were scalars\")" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "id": "5d899440-4b67-43fe-bad7-555924feeed4", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "class MyData:\n", 47 | " def __init__(self, i_data):\n", 48 | " self.data = i_data\n", 49 | " \n", 50 | " def get_split(self, train_ratio=0.8, valid_ratio=0.1, seed=None):\n", 51 | " n = len(self.data)\n", 52 | " indices = np.arange(n)\n", 53 | " if seed is not None:\n", 54 | " np.random.seed(seed)\n", 55 | " np.random.shuffle(indices)\n", 56 | " train_size = int(train_ratio * n)\n", 57 | " valid_size = int(valid_ratio * n)\n", 58 | " train_indices = indices[:train_size]\n", 59 | " valid_indices = indices[train_size:train_size+valid_size]\n", 60 | " test_indices = indices[train_size+valid_size:]\n", 61 | " i_train_data = self.data.iloc[train_indices].reset_index(drop=True)\n", 62 | " i_valid_data = self.data.iloc[valid_indices].reset_index(drop=True)\n", 63 | " i_test_data = self.data.iloc[test_indices].reset_index(drop=True)\n", 64 | " return i_train_data, i_valid_data, i_test_data" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "id": "948c5fef-6380-4fd2-b82d-3cb5c0b7d774", 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "class Input(Dataset):\n", 
75 | " def __init__(self, i_data, i_tokenizer, i_max_length):\n", 76 | " self.data = i_data\n", 77 | " self.tokenizer = i_tokenizer\n", 78 | " self.max_length = i_max_length\n", 79 | "\n", 80 | " def __len__(self):\n", 81 | " return len(self.data)\n", 82 | "\n", 83 | " def __getitem__(self, idx):\n", 84 | " smiles = self.data.iloc[idx][\"Standardized_SMILES\"]\n", 85 | " inputs = self.tokenizer(smiles, return_tensors=\"pt\", padding='max_length', truncation=True, max_length=self.max_length)\n", 86 | " inputs[\"input_ids\"] = inputs[\"input_ids\"].squeeze(0)\n", 87 | " inputs[\"attention_mask\"] = inputs[\"attention_mask\"].squeeze(0)\n", 88 | " if \"token_type_ids\" in inputs:\n", 89 | " inputs[\"token_type_ids\"] = inputs[\"token_type_ids\"].squeeze(0)\n", 90 | " inputs[\"labels\"] = torch.tensor(self.data.iloc[idx][\"median_WS\"], dtype=torch.float).unsqueeze(0)\n", 91 | " return inputs" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 4, 97 | "id": "d2f6243f-76f2-463d-8087-ac80742fb0a3", 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "# Read in solubility data\n", 102 | "my_data = pd.read_csv('aqua.csv')\n", 103 | "\n", 104 | "# Create an instance of the MyData class\n", 105 | "my_data = MyData(my_data)\n", 106 | "\n", 107 | "# Split your data into training, validation, and testing sets\n", 108 | "train_data, valid_data, test_data = my_data.get_split(seed = 123)\n", 109 | "\n", 110 | "# pick out columns\n", 111 | "data = train_data[['Standardized_SMILES', 'median_WS']]\n", 112 | "valid = valid_data[['Standardized_SMILES', 'median_WS']]\n", 113 | "test = test_data[['Standardized_SMILES', 'median_WS']]" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 5, 119 | "id": "f290f56e-2ea6-45bc-90b7-cbe705dde78b", 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "name": "stderr", 124 | "output_type": "stream", 125 | "text": [ 126 | "Some weights of the model checkpoint at DeepChem/ChemBERTa-77M-MTR 
were not used when initializing RobertaForSequenceClassification: ['regression.dense.weight', 'regression.out_proj.weight', 'regression.out_proj.bias', 'norm_mean', 'norm_std', 'regression.dense.bias']\n", 127 | "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 128 | "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 129 | "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MTR and are newly initialized: ['classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.dense.weight', 'classifier.out_proj.bias']\n", 130 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 131 | ] 132 | }, 133 | { 134 | "name": "stdout", 135 | "output_type": "stream", 136 | "text": [ 137 | "Using GPU.\n" 138 | ] 139 | }, 140 | { 141 | "data": { 142 | "text/plain": [ 143 | "RobertaForSequenceClassification(\n", 144 | " (roberta): RobertaModel(\n", 145 | " (embeddings): RobertaEmbeddings(\n", 146 | " (word_embeddings): Embedding(600, 384, padding_idx=1)\n", 147 | " (position_embeddings): Embedding(515, 384, padding_idx=1)\n", 148 | " (token_type_embeddings): Embedding(1, 384)\n", 149 | " (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)\n", 150 | " (dropout): Dropout(p=0.144, inplace=False)\n", 151 | " )\n", 152 | " (encoder): RobertaEncoder(\n", 153 | " (layer): ModuleList(\n", 154 | " (0-2): 3 x RobertaLayer(\n", 155 | " (attention): RobertaAttention(\n", 156 | " (self): RobertaSelfAttention(\n", 157 | " (query): 
Linear(in_features=384, out_features=384, bias=True)\n", 158 | " (key): Linear(in_features=384, out_features=384, bias=True)\n", 159 | " (value): Linear(in_features=384, out_features=384, bias=True)\n", 160 | " (dropout): Dropout(p=0.109, inplace=False)\n", 161 | " )\n", 162 | " (output): RobertaSelfOutput(\n", 163 | " (dense): Linear(in_features=384, out_features=384, bias=True)\n", 164 | " (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)\n", 165 | " (dropout): Dropout(p=0.144, inplace=False)\n", 166 | " )\n", 167 | " )\n", 168 | " (intermediate): RobertaIntermediate(\n", 169 | " (dense): Linear(in_features=384, out_features=464, bias=True)\n", 170 | " (intermediate_act_fn): GELUActivation()\n", 171 | " )\n", 172 | " (output): RobertaOutput(\n", 173 | " (dense): Linear(in_features=464, out_features=384, bias=True)\n", 174 | " (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)\n", 175 | " (dropout): Dropout(p=0.144, inplace=False)\n", 176 | " )\n", 177 | " )\n", 178 | " )\n", 179 | " )\n", 180 | " )\n", 181 | " (classifier): RobertaClassificationHead(\n", 182 | " (dense): Linear(in_features=384, out_features=384, bias=True)\n", 183 | " (dropout): Dropout(p=0.144, inplace=False)\n", 184 | " (out_proj): Linear(in_features=384, out_features=1, bias=True)\n", 185 | " )\n", 186 | ")" 187 | ] 188 | }, 189 | "execution_count": 5, 190 | "metadata": {}, 191 | "output_type": "execute_result" 192 | } 193 | ], 194 | "source": [ 195 | "#AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)\n", 196 | "\n", 197 | "# Load a pretrained transformer model and tokenizer\n", 198 | "model_name = \"DeepChem/ChemBERTa-77M-MTR\"\n", 199 | "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 200 | "config = AutoConfig.from_pretrained(model_name)\n", 201 | "config.num_hidden_layers += 1\n", 202 | "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)\n", 203 | "\n", 204 | "#see if GPU\n", 205 | "if 
torch.cuda.is_available(): \n", 206 | " device = torch.device(\"cuda\")\n", 207 | " print(\"Using GPU.\")\n", 208 | "else:\n", 209 | " print(\"No GPU available, using the CPU instead.\")\n", 210 | " device = torch.device(\"cpu\")\n", 211 | "# move model to the device\n", 212 | "model.to(device) " 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 6, 218 | "id": "513419d0-426e-4fcd-869e-7cec0f16e078", 219 | "metadata": {}, 220 | "outputs": [ 221 | { 222 | "data": { 223 | "text/html": [ 224 | "\n", 225 | "
\n", 226 | " \n", 227 | " \n", 228 | " [1600/1600 11:03, Epoch 100/100]\n", 229 | "
\n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | 
" \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 
| " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 
691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " 
\n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | " \n", 865 | " \n", 866 | " \n", 867 | " \n", 868 | " \n", 869 | " \n", 870 | " \n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | " \n", 943 | " \n", 944 | " \n", 945 | " \n", 946 | " \n", 947 | " \n", 948 | " \n", 949 | " \n", 950 | " \n", 951 | " \n", 952 | " \n", 953 | " \n", 954 | " \n", 955 | " \n", 956 | " \n", 957 | " \n", 958 | " \n", 959 | " \n", 960 | " \n", 961 | " \n", 962 | " \n", 963 | " \n", 964 | " \n", 965 | " \n", 966 | " \n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | " \n", 978 | " \n", 979 | " \n", 980 | " \n", 981 | " \n", 982 | " \n", 983 | " \n", 984 | " \n", 985 | " \n", 986 | " \n", 987 | " \n", 988 | " \n", 989 | " \n", 990 | " \n", 991 | " \n", 992 | " \n", 993 | " \n", 994 | " \n", 995 | " \n", 996 | " \n", 997 | " \n", 998 | 
" \n", 999 | " \n", 1000 | " \n", 1001 | " \n", 1002 | " \n", 1003 | " \n", 1004 | " \n", 1005 | " \n", 1006 | " \n", 1007 | " \n", 1008 | " \n", 1009 | " \n", 1010 | " \n", 1011 | " \n", 1012 | " \n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | " \n", 1033 | " \n", 1034 | " \n", 1035 | " \n", 1036 | " \n", 1037 | " \n", 1038 | " \n", 1039 | " \n", 1040 | "
StepTraining LossValidation Loss
1014.49850013.903141
2013.31250012.314187
3011.3317009.889343
408.8333006.695904
505.9927004.188548
604.1795003.679473
703.3948002.733828
802.6909002.194749
902.1848001.792031
1001.9166001.556690
1101.6070001.459312
1201.4609001.370523
1301.3925001.330363
1401.2731001.267703
1501.3301001.248693
1601.2054001.236013
1701.1710001.196895
1801.2225001.181159
1901.1322001.157773
2001.1120001.164091
2101.1167001.133514
2201.1236001.133295
2301.0493001.128167
2401.0811001.099577
2501.0824001.096789
2600.9867001.089766
2701.0498001.079090
2801.0274001.067783
2901.0011001.057811
3001.0063001.100570
3100.9958001.029109
3201.0064001.044250
3301.0225001.036337
3400.9642001.032447
3500.9618001.010790
3600.9727001.015232
3700.9902001.014361
3800.9370001.014328
3900.9167001.013889
4000.9608000.983203
4100.9571000.980814
4200.9205000.981581
4300.9354000.983257
4400.8955000.974684
4500.9294000.982432
4600.9262000.971062
4700.9049000.969826
4800.9017000.960682
4900.9368000.976769
5000.8808000.964954
5100.9036000.949082
5200.9097000.966911
5300.8355000.936071
5400.9125000.960301
5500.8818000.935671
5600.8624000.946876
5700.8508000.935555
5800.9047000.928763
5900.8761000.926278
6000.8518000.938022
6100.8859000.917941
6200.8761000.938434
6300.8433000.916394
6400.8366000.931034
6500.8551000.909021
6600.8451000.922578
6700.8425000.908290
6800.8610000.914329
6900.8254000.920769
7000.8277000.896255
7100.8434000.894783
7200.8193000.904708
7300.8170000.909307
7400.8046000.910911
7500.8334000.887669
7600.8167000.900026
7700.8294000.894980
7800.8009000.889096
7900.8044000.896905
8000.8029000.894861
8100.8297000.900829
8200.7979000.890006
8300.7856000.893044
8400.7854000.884947
8500.8031000.903455
8600.7813000.896731
8700.7941000.903011
8800.7994000.877687
8900.7959000.900194
9000.7949000.883582
9100.7847000.887545
9200.7801000.894659
9300.7707000.909069
9400.7759000.883565
9500.8200000.876190
9600.7659000.899084
9700.7636000.884935
9800.7802000.904533
9900.7738000.880672
10000.7754000.875259
10100.7567000.905909
10200.7661000.888669
10300.7523000.898266
10400.7826000.879773
10500.7493000.881599
10600.7850000.883525
10700.7195000.879436
10800.7968000.876413
10900.7442000.889036
11000.7744000.877228
11100.7509000.879265
11200.7590000.879474
11300.7231000.883581
11400.7979000.891133
11500.7259000.875201
11600.7631000.874155
11700.7286000.876161
11800.7651000.879099
11900.7441000.873605
12000.7592000.879335
12100.7704000.874104
12200.7210000.877102
12300.7536000.882011
12400.7492000.874405
12500.7401000.882769
12600.7593000.884418
12700.7265000.873274
12800.7464000.872972
12900.7682000.877476
13000.7047000.882903
13100.7474000.873165
13200.7202000.869180
13300.7886000.879084
13400.7152000.875005
13500.7080000.871279
13600.7599000.879987
13700.7391000.873613
13800.7508000.866602
13900.7474000.874252
14000.7535000.870742
14100.7340000.870454
14200.7501000.865231
14300.7710000.871395
14400.7178000.876452
14500.7427000.872083
14600.7150000.867632
14700.7325000.870068
14800.7231000.870656
14900.7464000.871050
15000.7384000.869421
15100.7432000.867668
15200.7276000.871007
15300.7134000.874482
15400.6977000.874198
15500.7375000.873622
15600.7382000.872050
15700.7352000.872152
15800.7259000.872730
15900.7275000.872686
16000.7151000.872661

" 1041 | ], 1042 | "text/plain": [ 1043 | "" 1044 | ] 1045 | }, 1046 | "metadata": {}, 1047 | "output_type": "display_data" 1048 | } 1049 | ], 1050 | "source": [ 1051 | "from transformers import TrainerCallback\n", 1052 | "import numpy as np\n", 1053 | "import evaluate\n", 1054 | "\n", 1055 | "# Define the maximum sequence length\n", 1056 | "max_length = 195\n", 1057 | "\n", 1058 | "# Prepare the dataset for training\n", 1059 | "train_dataset = Input(data, tokenizer, max_length)\n", 1060 | "validation_dataset = Input(valid, tokenizer, max_length)\n", 1061 | "\n", 1062 | "# Set up training arguments\n", 1063 | "training_args = TrainingArguments(\n", 1064 | " output_dir=\"./output\",\n", 1065 | " optim=\"adamw_torch\", # switch optimizer to avoid warning\n", 1066 | " num_train_epochs=100, # Train the model for 100 epochs\n", 1067 | " per_device_train_batch_size=128, # Set the batch size to 128\n", 1068 | " per_device_eval_batch_size=128, # Set the evaluation batch size to 128\n", 1069 | " logging_steps=10, # Log training metrics every 10 steps\n", 1070 | " eval_steps=10, # Evaluate the model every 10 steps\n", 1071 | " save_steps=10, # Save a model checkpoint every 10 steps\n", 1072 | " seed=123, # Set the random seed for reproducibility\n", 1073 | " evaluation_strategy=\"steps\", # Evaluate the model every eval_steps steps\n", 1074 | " load_best_model_at_end=True\n", 1075 | ")\n", 1076 | "\n", 1077 | "# Create the trainer\n", 1078 | "trainer = Trainer(\n", 1079 | " model=model,\n", 1080 | " args=training_args,\n", 1081 | " train_dataset=train_dataset,\n", 1082 | " eval_dataset=validation_dataset,\n", 1083 | ")\n", 1084 | "\n", 1085 | "\n", 1086 | "# Define a callback for printing validation loss\n", 1087 | "class PrintValidationLossCallback(TrainerCallback):\n", 1088 | " def on_evaluate(self, args, state, control, **kwargs):\n", 1089 | " if state is not None and hasattr(state, 'eval_loss'):\n", 1090 | " print(f\"Validation loss: {state.eval_loss:.4f}\")\n", 1091 | "\n", 1092
| "# Add the callback to the trainer\n", 1093 | "trainer.add_callback(PrintValidationLossCallback())\n", 1094 | "\n", 1095 | "\n", 1096 | "\n", 1097 | "metric = evaluate.load(\"accuracy\")\n", 1098 | "\n", 1099 | "def compute_metrics(eval_pred):\n", 1100 | " logits, labels = eval_pred\n", 1101 | " predictions = np.argmax(logits, axis=-1)\n", 1102 | " return metric.compute(predictions=predictions, references=labels)\n", 1103 | "\n", 1104 | "# Train the model\n", 1105 | "trainer.train()\n", 1106 | "\n", 1107 | "# Save the model\n", 1108 | "trainer.save_model(\"./output\")" 1109 | ] 1110 | }, 1111 | { 1112 | "cell_type": "code", 1113 | "execution_count": 7, 1114 | "id": "8b2bc259-724d-42bb-8756-670d7ba21854", 1115 | "metadata": {}, 1116 | "outputs": [], 1117 | "source": [ 1118 | "from sklearn.metrics import mean_squared_error\n", 1119 | "from sklearn.metrics import r2_score\n", 1120 | "from sklearn.metrics import mean_absolute_error\n", 1121 | "from scipy.stats import spearmanr\n", 1122 | "import matplotlib.pyplot as plt\n", 1123 | "from transformers import pipeline\n", 1124 | "\n", 1125 | "# Create a prediction pipeline\n", 1126 | "predictor = pipeline(\"text-classification\", model=model, tokenizer=tokenizer)" 1127 | ] 1128 | }, 1129 | { 1130 | "cell_type": "code", 1131 | "execution_count": 8, 1132 | "id": "dd930758-128a-4097-811e-76ff0fba6c7c", 1133 | "metadata": {}, 1134 | "outputs": [ 1135 | { 1136 | "name": "stdout", 1137 | "output_type": "stream", 1138 | "text": [ 1139 | "TRAINING SET\n", 1140 | "N: 8165\n", 1141 | "R2: 0.8702982901268843\n", 1142 | "Root Mean Square Error: 0.8173305996108375\n", 1143 | "Mean Absolute Error: 0.5987025186064614\n", 1144 | "Spearman correlation: 0.930919173416945\n", 1145 | "p-value: 0.0\n" 1146 | ] 1147 | }, 1148 | { 1149 | "data": { 1150 | "image/png": "\n", 1151 | "text/plain": [ 1152 | "

# %% Evaluate the fine-tuned model on the training set
# SMILES strings of the training set
train_smiles = data['Standardized_SMILES']

# Switch to inference mode: the model may still be in training mode after
# trainer.train(), and active dropout would make the predictions noisy.
model.eval()

# Predict in batches instead of one forward pass per molecule.
eval_batch_size = 128
smiles_list = list(train_smiles)
predictions = []
with torch.no_grad():
    for start in range(0, len(smiles_list), eval_batch_size):
        inputs = tokenizer(
            smiles_list[start:start + eval_batch_size],
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=195,  # same limit used to build the training dataset
        ).to(device)
        outputs = model(**inputs)
        # One value per molecule: drop the trailing output dimension.
        predictions.extend(outputs.logits.squeeze(-1).tolist())

# RMSE is taken as the square root of the MSE; the `squared=False` shortcut is
# deprecated in recent scikit-learn releases.
r_mse = mean_squared_error(data["median_WS"], predictions) ** 0.5
r2 = r2_score(data["median_WS"], predictions)
mae = mean_absolute_error(data["median_WS"], predictions)

print("TRAINING SET")
print("N:", len(data["median_WS"]))
print("R2:", r2)
print("Root Mean Square Error:", r_mse)
print("Mean Absolute Error:", mae)

# Rank correlation between the measured values and the predictions
correlation, p_value = spearmanr(data["median_WS"], predictions)

print("Spearman correlation:", correlation)
print("p-value:", p_value)

plt.scatter(data["median_WS"], predictions)
plt.xlabel("train['median_WS']")
plt.ylabel("predictions")
plt.title("Scatter Plot of Training['median_WS'] vs Predictions")
plt.show()
"stream", 1206 | "text": [ 1207 | "VALIDATION SET\n", 1208 | "N: 1020\n", 1209 | "R2: 0.8346196871525771\n", 1210 | "Root Mean Square Error: 0.9301778088403094\n", 1211 | "Mean Absolute Error: 0.6894259539831057\n", 1212 | "Spearman correlation: 0.9164363654548762\n", 1213 | "p-value: 0.0\n" 1214 | ] 1215 | }, 1216 | { 1217 | "data": { 1218 | "image/png": "\n", 1219 | "text/plain": [ 1220 | "
# %% Evaluate the fine-tuned model on the validation set
# SMILES strings of the validation set
valid_smiles = valid['Standardized_SMILES']

# Switch to inference mode: the model may still be in training mode after
# trainer.train(), and active dropout would make the predictions noisy.
model.eval()

# Predict in batches instead of one forward pass per molecule.
eval_batch_size = 128
smiles_list = list(valid_smiles)
predictions = []
with torch.no_grad():
    for start in range(0, len(smiles_list), eval_batch_size):
        inputs = tokenizer(
            smiles_list[start:start + eval_batch_size],
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=195,  # same limit used to build the training dataset
        ).to(device)
        outputs = model(**inputs)
        # One value per molecule: drop the trailing output dimension.
        predictions.extend(outputs.logits.squeeze(-1).tolist())

# RMSE is taken as the square root of the MSE; the `squared=False` shortcut is
# deprecated in recent scikit-learn releases.
r_mse = mean_squared_error(valid["median_WS"], predictions) ** 0.5
r2 = r2_score(valid["median_WS"], predictions)
mae = mean_absolute_error(valid["median_WS"], predictions)

print("VALIDATION SET")
print("N:", len(valid["median_WS"]))
print("R2:", r2)
print("Root Mean Square Error:", r_mse)
print("Mean Absolute Error:", mae)

# Rank correlation between the measured values and the predictions
correlation, p_value = spearmanr(valid["median_WS"], predictions)

print("Spearman correlation:", correlation)
print("p-value:", p_value)

plt.scatter(valid["median_WS"], predictions)
plt.xlabel("validation['median_WS']")
plt.ylabel("predictions")
plt.title("Scatter Plot of Validation['median_WS'] vs Predictions")
plt.show()
"output_type": "stream", 1274 | "text": [ 1275 | "TEST SET\n", 1276 | "N: 1022\n", 1277 | "R2: 0.8223617048517198\n", 1278 | "Root Mean Square Error: 0.9380008871807847\n", 1279 | "Mean Absolute Error: 0.6807156079688436\n", 1280 | "Spearman correlation: 0.8991562895604099\n", 1281 | "p-value: 0.0\n" 1282 | ] 1283 | }, 1284 | { 1285 | "data": { 1286 | "image/png": "\n", 1287 | "text/plain": [ 1288 | "
# %% Evaluate the fine-tuned model on the test set
# SMILES strings of the test set
test_smiles = test['Standardized_SMILES']

# Switch to inference mode: the model may still be in training mode after
# trainer.train(), and active dropout would make the predictions noisy.
model.eval()

# Predict in batches instead of one forward pass per molecule.
eval_batch_size = 128
smiles_list = list(test_smiles)
test_predictions = []
with torch.no_grad():
    for start in range(0, len(smiles_list), eval_batch_size):
        inputs = tokenizer(
            smiles_list[start:start + eval_batch_size],
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=195,  # same limit used to build the training dataset
        ).to(device)
        outputs = model(**inputs)
        # One value per molecule: drop the trailing output dimension.
        test_predictions.extend(outputs.logits.squeeze(-1).tolist())

# Keep the name the earlier cells use, for anything that still reads it.
predictions = test_predictions

# RMSE is taken as the square root of the MSE; the `squared=False` shortcut is
# deprecated in recent scikit-learn releases.
r_mse = mean_squared_error(test["median_WS"], test_predictions) ** 0.5
r2 = r2_score(test["median_WS"], test_predictions)
mae = mean_absolute_error(test["median_WS"], test_predictions)

print("TEST SET")
print("N:", len(test["median_WS"]))
print("R2:", r2)
print("Root Mean Square Error:", r_mse)
print("Mean Absolute Error:", mae)

# Rank correlation between the measured values and the predictions
correlation, p_value = spearmanr(test["median_WS"], test_predictions)

print("Spearman correlation:", correlation)
print("p-value:", p_value)

plt.scatter(test["median_WS"], test_predictions)
plt.xlabel("test['median_WS']")
plt.ylabel("predictions")
plt.title("Scatter Plot of Test['median_WS'] vs Predictions")
plt.show()

# %% Save results
# Uses `test_predictions` rather than the shared `predictions` name, which every
# evaluation cell overwrites; re-running the training or validation cell before
# this one would otherwise pair test targets with the wrong predictions.
results_df = pd.DataFrame({"actual_WS": test["median_WS"], "predicted_WS": test_predictions})
results_df.to_csv("testset_results.csv", index=False)

# Notebook metadata recorded in the original .ipynb: kernel "Python 3 (ipykernel)",
# Python 3.9.5, nbformat 4.5.