├── Dockerfile ├── preprocess_bindingdb.py ├── README.md ├── data.py ├── score_compounds.py ├── train.py └── models.py /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM conda/miniconda3 2 | # RUN conda install pytorch torchvision torchaudio cpuonly -c pytorch 3 | RUN conda install rdkit pytorch -c conda-forge -c pytorch 4 | # WORKDIR /workspace 5 | RUN python -m pip install scipy 6 | COPY . . 7 | ENTRYPOINT ["python", "score_compounds.py"] -------------------------------------------------------------------------------- /preprocess_bindingdb.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import math 3 | import pickle 4 | from tqdm import tqdm 5 | import numpy as np 6 | from rdkit.Chem import MolFromSmiles, MolToSmiles 7 | from sklearn.model_selection import train_test_split 8 | import collections 9 | from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect 10 | 11 | 12 | def featurize_mol(smiles): 13 | return np.array(GetMorganFingerprintAsBitVect(MolFromSmiles(smiles), 3)) 14 | 15 | 16 | mols = [] 17 | f = open('BindingDB_All.tsv', 'r') 18 | next(f) 19 | seqs = [] 20 | for i, row in tqdm(enumerate(csv.reader(f, delimiter='\t'))): 21 | # 8 or 10 for ki/kd, 9 or 11 for ic50/ec50 22 | if (row[8] or row[9] or row[10] or row[11]) and (10 < len([char for char in row[1] if char not in '()=@[]123456789']) < 70) and row[37] != 'NULL' and MolFromSmiles(row[1]): 23 | val = (row[10] if row[10] else (row[8] if row[8] else (row[9] if row[9] else row[11]))).replace('<', '').replace('>', '').strip() 24 | seqs.append(row[37].upper()) 25 | mols.append((MolToSmiles(MolFromSmiles(row[1])), math.log10(float(val) + 1e-10))) 26 | 27 | allowed_seqs = [seq for seq, count in collections.Counter(seqs).most_common() if count > 10] 28 | 29 | for seq in tqdm(allowed_seqs): 30 | vals = [mols[i][1] for i in range(len(mols)) if seqs[i] == seq] 31 | if not (True in [4 < val < 50 for val
in vals]): 32 | i = 0 33 | while i < len(mols): 34 | if seqs[i] == seq: 35 | del mols[i] 36 | del seqs[i] 37 | else: 38 | i += 1 39 | allowed_seqs = [seq for seq, count in collections.Counter(seqs).most_common() if count > 10] 40 | 41 | training_seqs, testing_seqs = train_test_split(allowed_seqs, test_size=100) 42 | training_seqs = set(training_seqs) 43 | testing_seqs = set(testing_seqs) 44 | train_mols, train_seqs = zip(*[(mols[i], seqs[i]) for i in range(len(mols)) if seqs[i] in training_seqs]) 45 | test_mols, test_seqs = zip(*[(mols[i], seqs[i]) for i in range(len(mols)) if seqs[i] in testing_seqs]) 46 | y_train = np.array([binding for _, binding in train_mols]) 47 | y_test = np.array([binding for _, binding in test_mols]) 48 | 49 | x_train = np.array([featurize_mol(smiles) for smiles, _ in train_mols], dtype=bool) 50 | x_test = np.array([featurize_mol(smiles) for smiles, _ in test_mols], dtype=bool) 51 | pickle.dump((x_train, x_test, y_train, y_test, train_seqs, test_seqs), open('bindingdb_data.pickle', 'wb')) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Few-Shot Compound Activity Prediction (FS-CAP) 2 | 3 | This repository contains code for the few-shot compound activity prediction (FS-CAP) algorithm. 4 | 5 | ## Docker instructions 6 | First add a trained model file to the same folder as the `Dockerfile` (or use the one provided [here](https://drive.google.com/file/d/1SD8H5j6U7gyZOI_oncZrEzDBrz-z7-Ng/view?usp=sharing), which was trained with 8 context compounds), then run `sudo docker build -t fscap .` to build the container, then run `sudo docker run fscap` along with any command line arguments. For example, `sudo docker run fscap --context_smiles "" --context_activities --query_smiles ""`. 7 | 8 | ## Requirements 9 | [RDKit](https://www.rdkit.org/docs/Install.html) is required. All code was tested in Python 3.10. 
The following pip packages are also required: 10 | ``` 11 | torch 12 | scipy 13 | scikit-learn 14 | numpy 15 | tqdm 16 | ``` 17 | 18 | ## Preprocessing 19 | We only provide code to preprocess BindingDB for training, but testing on other datasets with a trained model should be relatively straightforward using the `score_compounds.py` script. 20 | 21 | ### BindingDB 22 | `preprocess_bindingdb.py` contains code to extract and preprocess data from BindingDB. Running `python preprocess_bindingdb.py` loads data from `BindingDB_All.tsv`, which should be placed in the same folder beforehand, and produces a `bindingdb_data.pickle` file that is ready for training. For the paper, we used `BindingDB_All.tsv` from BindingDB's [Download](https://www.bindingdb.org/rwd/bind/chemsearch/marvin/SDFdownload.jsp?all_download=yes) page, available [here](https://www.bindingdb.org/bind/downloads/BindingDB_All_2022m8.tsv.zip). 23 | 24 | ## Training 25 | `train.py` contains the main script to train FS-CAP. By default, the model trains with 8 context compounds and saves TensorBoard logs to the `logs` folder. After training, it saves the model file to `model.pt`. Other model hyperparameters can be found and adjusted in the `config` variable in the `train.py` file. 26 | 27 | ## Inference 28 | `score_compounds.py` uses the trained model to perform inference on a given set of context and query compounds. The following parameters must be supplied: `--context_smiles` specifies the SMILES strings of the context molecules, separated by semicolons (e.g. `CCC;CCCC;CCCCC`), and `--context_activities` specifies the associated activities in nanomoles/liter (nM) (e.g. `1000;1;10` if the activities are 1000 nM, 1 nM, and 10 nM, respectively).
`--query_smiles` specifies the SMILES string(s) of the query molecule(s) (if multiple, separate with semicolons), `--model_file` specifies the path to the trained model (default `model.pt`), and `--encoding_dim` specifies the `encoding_dim` parameter used in training (default 512). The script prints to stdout the activity prediction of the query molecule(s) in nM, one prediction per line. 29 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import pickle 4 | import random 5 | from copy import deepcopy, copy 6 | import numpy as np 7 | 8 | 9 | class FSCAPDataset(Dataset): 10 | def __init__(self, x, y, targets, context_ranges): 11 | target_to_idxs = {} 12 | for i in range(len(x)): 13 | if targets[i] not in target_to_idxs: 14 | target_to_idxs[targets[i]] = [] 15 | target_to_idxs[targets[i]].append(i) 16 | y[y < 2] = 2 17 | y[y > 10] = 10 18 | for seq in target_to_idxs: 19 | seq_data = y[target_to_idxs[seq]].flatten() 20 | top_idxs = [] 21 | for start, end in context_ranges: 22 | top_idxs.append(torch.arange(0, len(seq_data))[(seq_data >= start) & (seq_data < end)]) 23 | top_idxs = [[target_to_idxs[seq][idx] for idx in idxs] for idxs in top_idxs] 24 | target_to_idxs[seq] = top_idxs 25 | self.x = x 26 | self.y = y 27 | self.targets = targets 28 | self.target_to_idxs = target_to_idxs 29 | self.avail_idxs = {target: [[] for _ in context_ranges] for target in self.target_to_idxs} 30 | 31 | def __len__(self): 32 | return len(self.x) 33 | 34 | def __getitem__(self, idx): 35 | target = self.targets[idx] 36 | context_idxs = [] 37 | for i in range(len(self.avail_idxs[target])): 38 | if not self.avail_idxs[target][i]: 39 | self.avail_idxs[target][i] = copy(self.target_to_idxs[target][i]) 40 | random.shuffle(self.avail_idxs[target][i]) 41 |
context_idxs.append(self.avail_idxs[target][i].pop()) 42 | return self.x[context_idxs], self.y[context_idxs], self.x[idx], self.y[idx], target 43 | 44 | 45 | def get_dataloaders(batch_size, context_ranges): 46 | x_train, x_test, y_train, y_test, train_seqs, test_seqs = pickle.load(open('bindingdb_data.pickle', 'rb')) 47 | 48 | valid_seqs = [] 49 | for line in open('clusterRes_rep_seq.fasta'): 50 | if not line.startswith('>'): 51 | valid_seqs.append(line.strip()) 52 | valid_idxs = [] 53 | for i in range(len(x_train)): 54 | if train_seqs[i] in valid_seqs: 55 | valid_idxs.append(i) 56 | valid_idxs = np.array(valid_idxs) 57 | x_train = x_train[valid_idxs] 58 | y_train = y_train[valid_idxs] 59 | train_seqs = [seq for seq in train_seqs if seq in valid_seqs] 60 | 61 | valid_idxs = [] 62 | for i in range(len(x_test)): 63 | if test_seqs[i] in valid_seqs: 64 | valid_idxs.append(i) 65 | valid_idxs = np.array(valid_idxs) 66 | x_test = x_test[valid_idxs] 67 | y_test = y_test[valid_idxs] 68 | test_seqs = [seq for seq in test_seqs if seq in valid_seqs] 69 | 70 | train_dataloader = DataLoader(FSCAPDataset(x_train, y_train, train_seqs, context_ranges), batch_size=batch_size, shuffle=True) 71 | test_dataloader = DataLoader(FSCAPDataset(x_test, y_test, test_seqs, context_ranges), batch_size=batch_size, shuffle=True) 72 | return train_dataloader, test_dataloader 73 | 74 | 75 | def get_for_eval(data_file='bindingdb_data.pickle'): 76 | x_train, x_test, y_train, y_test, train_seqs, test_seqs, token_to_idx = pickle.load(open(data_file, 'rb')) 77 | return token_to_idx, y_train.mean(), y_train.std(), max(x_train.max(), x_test.max()) + 1, x_train.shape[1] -------------------------------------------------------------------------------- /score_compounds.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from rdkit.Chem import MolFromSmiles 4 | from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect 5 | import 
torch 6 | from models import * 7 | from data import * 8 | import math 9 | 10 | 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | 13 | 14 | class FSCAP: 15 | def __init__(self, model_file): 16 | self.context_encoder = MLPEncoder(2048, config).to(device) 17 | self.query_encoder = MLPEncoder(2048, config).to(device) 18 | self.predictor = Predictor(config['encoding_dim'] * 2, config).to(device) 19 | context_encoder_dict, query_encoder_dict, predictor_dict = torch.load(model_file, map_location=device) 20 | self.context_encoder.load_state_dict(context_encoder_dict) 21 | self.query_encoder.load_state_dict(query_encoder_dict) 22 | self.predictor.load_state_dict(predictor_dict) 23 | self.context_encoder.eval() 24 | self.query_encoder.eval() 25 | self.predictor.eval() 26 | 27 | def predict(self, context_smiles, context_activities, queries): 28 | context_x = torch.tensor(np.array([self.featurize_mol(smile) for smile in context_smiles], dtype=bool)).unsqueeze(0) 29 | context_y = torch.tensor(np.array([self.clip_activity(math.log10(float(activity) + 1e-10)) for activity in context_activities])).unsqueeze(0) 30 | query_x = torch.tensor(np.array([self.featurize_mol(smile) for smile in queries], dtype=bool)) 31 | context_x, context_y, query_x = context_x.to(dtype=torch.float32, device=device), context_y.to(dtype=torch.float32, device=device).unsqueeze(-1), query_x.to(dtype=torch.float32, device=device) 32 | context = torch.zeros((len(context_smiles), len(context_x), config['encoding_dim']), device=device) 33 | for j in range(len(context_smiles)): 34 | context[j] = self.context_encoder(context_x[:, j, :], context_y[:, j, :]) 35 | context = context.mean(0) 36 | query = self.query_encoder(query_x) 37 | tiled_contexts = torch.zeros((len(queries), config['encoding_dim']), device=device) 38 | for i in range(len(queries)): 39 | tiled_contexts[i] = context 40 | x = torch.concat((tiled_contexts, query), dim=1) 41 | out = self.predictor(x) 42 | return (10 ** 
out.detach().cpu().flatten()).tolist() 43 | 44 | def featurize_mol(self, smiles): 45 | if not ((10 <= len([char for char in smiles if char not in '()=@[]123456789']) <= 70) and MolFromSmiles(smiles)): 46 | raise ValueError('smiles is invalid, or too long/short') 47 | return np.array(GetMorganFingerprintAsBitVect(MolFromSmiles(smiles), 3)) 48 | 49 | def clip_activity(self, val): 50 | if val < -2.5: 51 | val = -2.5 52 | if val > 6.5: 53 | val = 6.5 54 | return val 55 | 56 | 57 | if __name__ == '__main__': 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('--context_smiles', type=str, required=True) 60 | parser.add_argument('--context_activities', type=str, required=True) 61 | parser.add_argument('--query_smiles', type=str, required=True) 62 | parser.add_argument('--model_file', type=str, default='model.pt') 63 | parser.add_argument('--encoding_dim', type=int, default=512) 64 | args = parser.parse_args() 65 | config = {'encoding_dim': args.encoding_dim, 'd_model': args.encoding_dim, 'layer_width': 2048, 'encoder_batchnorm': True, 'pred_dropout': True, 'pred_batchnorm': False, 'pred_dropout_p': 0.1} # models.py reads these keys; the values mirror the defaults in train.py and are assumed to match the trained model 66 | fscap = FSCAP(args.model_file) 67 | for prediction in fscap.predict(args.context_smiles.split(';'), args.context_activities.split(';'), args.query_smiles.split(';')): 68 | print(prediction) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim 3 | from scipy.stats import linregress 4 | import numpy as np 5 | from models import * 6 | from data import * 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | 10 | context_num = 8 11 | 12 | 13 | config = { 14 | 'run_name': 'test', 15 | 'context_ranges': [(-50, 50)] * context_num, # unit is log10 nM 16 | 'val_freq': 1024, 17 | 'lr': 0.000040012, 18 | 'layer_width': 2048, 19 | 'batch_size': 1024, 20 | 'warmup_steps': 128, 21 | 'total_epochs': 2 ** 15, 22 | 'n_heads': 16, 23 | 'n_layers': 4, 24 | 'affinity_embed_layers': 1, 25 | 'init_range': 0.2, 26 | 'scalar_dropout': 0.15766, 27 | 'embed_dropout': 0.16668, 28 | 'final_dropout':
0.10161, 29 | 'pred_dropout': True, 30 | 'pred_batchnorm': False, 31 | 'pred_dropout_p': 0.1, 32 | 'encoder_batchnorm': True, 'd_model': 512 # assumed encoder output width (key read by models.py); 512 is the README default for encoding_dim 33 | } 34 | 35 | if config.get('simple', False): # 'simple' is not defined in config above, so it is treated as off 36 | config['dataloader_batch'] = 1024 // len(config['context_ranges']) 37 | else: 38 | config['dataloader_batch'] = 128 // len(config['context_ranges']) 39 | 40 | 41 | train_dataloader, test_dataloader = get_dataloaders(config['dataloader_batch'], config['context_ranges']) 42 | context_encoder = MLPEncoder(2048, config).cuda() 43 | query_encoder = MLPEncoder(2048, config).cuda() 44 | predictor = Predictor(config['d_model'] * 2, config).cuda() 45 | 46 | optimizer = optim.RAdam(list(context_encoder.parameters()) + list(query_encoder.parameters()) + list(predictor.parameters()), lr=config['lr']) 47 | 48 | warmup_scheduler = optim.lr_scheduler.LinearLR(optimizer, 0.0001, 1, total_iters=config['warmup_steps']) 49 | annealing_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['total_epochs']) 50 | scheduler = optim.lr_scheduler.SequentialLR(optimizer, [warmup_scheduler, annealing_scheduler], milestones=[config['warmup_steps']]) 51 | writer = SummaryWriter('logs/' + config['run_name']) 52 | 53 | epoch = 0 54 | while True: 55 | total_loss = 0 56 | count = 0 57 | for i, (context_x, context_y, query_x, query_y, _) in enumerate(train_dataloader): 58 | context_x, context_y, query_x, query_y = context_x.to(dtype=torch.float32, device='cuda'), context_y.to(dtype=torch.float32, device='cuda').unsqueeze(-1), query_x.to(dtype=torch.float32, device='cuda'), query_y.to(dtype=torch.float32, device='cuda').unsqueeze(-1) 59 | context = torch.zeros((len(config['context_ranges']), len(context_x), config['d_model']), device='cuda') 60 | for j in range(len(config['context_ranges'])): 61 | context[j] = context_encoder(context_x[:, j, :], context_y[:, j, :]) 62 | context = context.mean(0) 63 | query = query_encoder(query_x) 64 | x = torch.concat((context, query), dim=1) 65 | loss = torch.mean((predictor(x) -
query_y) ** 2) 66 | total_loss += loss.item() 67 | count += 1 68 | loss.backward() 69 | if i % (config['val_freq'] * (config['batch_size'] // config['dataloader_batch'])) == 0: 70 | writer.add_scalar('loss/train', total_loss / count, epoch) 71 | context_encoder.eval() 72 | query_encoder.eval() 73 | predictor.eval() 74 | with torch.no_grad(): 75 | loss = 0 76 | target_to_pred = {} 77 | target_to_real = {} 78 | all_pred = [] 79 | all_real = [] 80 | for j, (context_x, context_y, query_x, query_y, targets) in enumerate(test_dataloader): 81 | context_x, context_y, query_x, query_y = context_x.to(dtype=torch.float32, device='cuda'), context_y.to(dtype=torch.float32, device='cuda').unsqueeze(-1), query_x.to(dtype=torch.float32, device='cuda'), query_y.to(dtype=torch.float32, device='cuda').unsqueeze(-1) 82 | context = torch.zeros((len(config['context_ranges']), len(context_x), config['d_model']), device='cuda') 83 | for k in range(len(config['context_ranges'])): 84 | context[k] = context_encoder(context_x[:, k, :], context_y[:, k, :]) 85 | context = context.mean(0) 86 | query = query_encoder(query_x) 87 | x = torch.concat((context, query), dim=1) 88 | out = predictor(x) 89 | loss += torch.mean((out - query_y) ** 2).item() 90 | pred = out.cpu().numpy().flatten() 91 | real = query_y.cpu().numpy().flatten() 92 | all_pred.extend(pred) 93 | all_real.extend(real) 94 | for k, target in enumerate(targets): 95 | if target not in target_to_real: 96 | target_to_pred[target] = [] 97 | target_to_real[target] = [] 98 | target_to_pred[target].append(pred[k]) 99 | target_to_real[target].append(real[k]) 100 | writer.add_scalar('loss/test', loss / (j + 1), epoch) 101 | try: 102 | writer.add_scalar('corr/raw', linregress(all_pred, all_real).rvalue, epoch) 103 | except: 104 | writer.add_scalar('corr/raw', 0, epoch) 105 | corrs = [] 106 | for target in target_to_real: 107 | try: 108 | corrs.append(linregress(target_to_pred[target], target_to_real[target]).rvalue) 109 | except: 110 | 
corrs.append(0) 111 | writer.add_scalar('corr/per_target', np.mean(corrs), epoch) 112 | 113 | context_encoder.train() 114 | query_encoder.train() 115 | predictor.train() 116 | if i % (config['batch_size'] // config['dataloader_batch']) == 0: 117 | total_loss = 0 118 | count = 0 119 | optimizer.step() 120 | optimizer.zero_grad() 121 | scheduler.step() 122 | writer.add_scalar('lr', scheduler.get_last_lr()[0], epoch) 123 | epoch += 1 124 | if epoch == config['total_epochs']: 125 | torch.save((context_encoder.state_dict(), query_encoder.state_dict(), predictor.state_dict()), f'model.pt') 126 | exit() 127 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class MLPEncoder(nn.Module): 6 | def __init__(self, in_dim, config): 7 | super().__init__() 8 | if config['encoder_batchnorm']: 9 | self.fc = nn.Sequential(nn.Linear(in_dim, config['layer_width']), 10 | nn.BatchNorm1d(config['layer_width']), 11 | nn.ReLU(), 12 | nn.Linear(config['layer_width'], config['layer_width']), 13 | nn.BatchNorm1d(config['layer_width']), 14 | nn.ReLU(), 15 | nn.Linear(config['layer_width'], config['layer_width']), 16 | nn.BatchNorm1d(config['layer_width']), 17 | nn.ReLU(), 18 | nn.Linear(config['layer_width'], config['layer_width']), 19 | nn.BatchNorm1d(config['layer_width']), 20 | nn.ReLU(), 21 | nn.Linear(config['layer_width'], config['layer_width']), 22 | nn.BatchNorm1d(config['layer_width']), 23 | nn.ReLU(), 24 | nn.Linear(config['layer_width'], config['d_model'])) 25 | else: 26 | self.fc = nn.Sequential(nn.Linear(in_dim, config['layer_width']), 27 | nn.ReLU(), 28 | nn.Linear(config['layer_width'], config['layer_width']), 29 | nn.ReLU(), 30 | nn.Linear(config['layer_width'], config['layer_width']), 31 | nn.ReLU(), 32 | nn.Linear(config['layer_width'], config['layer_width']), 33 | nn.ReLU(), 34 | 
nn.Linear(config['layer_width'], config['layer_width']), 35 | nn.ReLU(), 36 | nn.Linear(config['layer_width'], config['d_model'])) 37 | 38 | def forward(self, x, scalar=None): 39 | if scalar != None: 40 | return self.fc(x * scalar) 41 | return self.fc(x) 42 | 43 | 44 | class Predictor(nn.Module): 45 | def __init__(self, in_dim, config): 46 | super().__init__() 47 | if config['pred_dropout']: 48 | if config['pred_batchnorm']: 49 | self.fc = nn.Sequential(nn.Linear(in_dim, 2048), 50 | nn.ReLU(), 51 | nn.Dropout(config['pred_dropout_p']), 52 | nn.BatchNorm1d(2048), 53 | nn.Linear(2048, 2048), 54 | nn.ReLU(), 55 | nn.Dropout(config['pred_dropout_p']), 56 | nn.BatchNorm1d(2048), 57 | nn.Linear(2048, 2048), 58 | nn.ReLU(), 59 | nn.Dropout(config['pred_dropout_p']), 60 | nn.BatchNorm1d(2048), 61 | nn.Linear(2048, 1024), 62 | nn.ReLU(), 63 | nn.Linear(1024, 1024), 64 | nn.ReLU(), 65 | nn.Linear(1024, 1)) 66 | else: 67 | self.fc = nn.Sequential(nn.Linear(in_dim, 2048), 68 | nn.ReLU(), 69 | nn.Dropout(config['pred_dropout_p']), 70 | nn.Linear(2048, 2048), 71 | nn.ReLU(), 72 | nn.Dropout(config['pred_dropout_p']), 73 | nn.Linear(2048, 2048), 74 | nn.ReLU(), 75 | nn.Dropout(config['pred_dropout_p']), 76 | nn.Linear(2048, 1024), 77 | nn.ReLU(), 78 | nn.Linear(1024, 1024), 79 | nn.ReLU(), 80 | nn.Linear(1024, 1)) 81 | else: 82 | if config['pred_batchnorm']: 83 | self.fc = nn.Sequential(nn.Linear(in_dim, 2048), 84 | nn.ReLU(), 85 | nn.BatchNorm1d(2048), 86 | nn.Linear(2048, 2048), 87 | nn.ReLU(), 88 | nn.BatchNorm1d(2048), 89 | nn.Linear(2048, 2048), 90 | nn.ReLU(), 91 | nn.BatchNorm1d(2048), 92 | nn.Linear(2048, 1024), 93 | nn.ReLU(), 94 | nn.Linear(1024, 1024), 95 | nn.ReLU(), 96 | nn.Linear(1024, 1)) 97 | else: 98 | self.fc = nn.Sequential(nn.Linear(in_dim, 2048), 99 | nn.ReLU(), 100 | nn.Linear(2048, 2048), 101 | nn.ReLU(), 102 | nn.Linear(2048, 2048), 103 | nn.ReLU(), 104 | nn.Linear(2048, 1024), 105 | nn.ReLU(), 106 | nn.Linear(1024, 1024), 107 | nn.ReLU(), 108 | 
nn.Linear(1024, 1)) 109 | 110 | def forward(self, x): 111 | return self.fc(x) 112 | --------------------------------------------------------------------------------
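A minimal sketch of the activity pre- and post-processing that `score_compounds.py` applies around the model, written with the standard library only so it can be run without RDKit or PyTorch. The clipping bounds and the `1e-10` offset are copied from `clip_activity` and `predict` in that script; the helper names `parse_activities`, `to_model_scale`, and `to_nanomolar` are hypothetical and do not exist in the repository.

```python
import math

# Clipping bounds copied from clip_activity() in score_compounds.py (log10 nM).
LOG_MIN, LOG_MAX = -2.5, 6.5


def parse_activities(arg):
    """Split a semicolon-separated string of nM values into floats."""
    return [float(item) for item in arg.split(';')]


def to_model_scale(activity_nm):
    """log10-transform an nM activity, then clip it to [LOG_MIN, LOG_MAX]."""
    # The small offset avoids log10(0), as in score_compounds.py.
    log_val = math.log10(activity_nm + 1e-10)
    return min(max(log_val, LOG_MIN), LOG_MAX)


def to_nanomolar(log_val):
    """Invert the log transform; the script prints 10 ** prediction."""
    return 10 ** log_val


if __name__ == '__main__':
    # Same example string as in the README's Inference section.
    for nm in parse_activities('1000;1;10'):
        print(nm, '->', to_model_scale(nm))
```

With the README's example, `--context_activities "1000;1;10"` maps to roughly 3, 0, and 1 on the log10 scale before it reaches the context encoder, and values outside the clipping range are saturated at -2.5 or 6.5.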