├── requirements.txt ├── normalization.py ├── README.md ├── evaluation_utils_ches20.py ├── evaluation_utils.py ├── run_trans_ascadf.sh ├── run_trans_ascadr.sh ├── run_trans_ches20.sh ├── data_utils_ches20.py ├── data_utils.py ├── fast_attention.py ├── transformer.py └── train_trans.py /requirements.txt: -------------------------------------------------------------------------------- 1 | # Tested with Python 3.8.10 2 | absl-py==2.3.1 3 | numpy==1.24.3 4 | scipy==1.10.1 5 | h5py==3.11.0 6 | tensorflow==2.13.0 7 | -------------------------------------------------------------------------------- /normalization.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | class LayerScaling(tf.keras.layers.Layer): 4 | def __init__(self, epsilon=0.001, **kwargs): 5 | super().__init__(**kwargs) 6 | 7 | self.epsilon = epsilon 8 | 9 | 10 | def build(self, input_shape): 11 | param_shape = [input_shape[-1]] 12 | 13 | self.gamma = self.add_weight( 14 | name='gamma', 15 | shape=param_shape, 16 | initializer='ones', 17 | regularizer=None, 18 | constraint=None, 19 | trainable=True) 20 | 21 | 22 | def call(self, inputs): 23 | input_shape = inputs.shape 24 | ndims = len(input_shape) 25 | 26 | broadcast_shape = [1] * ndims 27 | broadcast_shape[-1] = input_shape.dims[-1].value 28 | scale = tf.reshape(self.gamma, broadcast_shape) 29 | 30 | _, variance = tf.nn.moments(inputs, axes=-1, keepdims=True) 31 | outputs = inputs / tf.math.sqrt(variance + self.epsilon) * scale 32 | 33 | return outputs 34 | 35 | 36 | 37 | class LayerCentering(tf.keras.layers.Layer): 38 | def __init__(self, **kwargs): 39 | super().__init__(**kwargs) 40 | 41 | 42 | def build(self, input_shape): 43 | param_shape = [input_shape[-1]] 44 | 45 | self.beta = self.add_weight( 46 | name='beta', 47 | shape=param_shape, 48 | initializer='zeros', 49 | regularizer=None, 50 | constraint=None, 51 | trainable=True) 52 | 53 | 54 | def call(self, inputs): 55 | input_shape = inputs.shape 56 | ndims = len(input_shape) 57 | 58 | broadcast_shape = [1] * ndims 59 | broadcast_shape[-1] = input_shape.dims[-1].value 60 | offset = tf.reshape(self.beta, broadcast_shape) 61 | 62 | mean, _ = tf.nn.moments(inputs, axes=-1, keepdims=True) 63 | outputs = inputs - mean + offset 64 | 65 | return outputs 66 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EstraNet: An Efficient Shift-Invariant Transformer Network for Side Channel Analysis 2 | 3 | This repository contains the implementation of **EstraNet**, an efficient shift-invariant transformer network for Side-Channel Analysis. 4 | For more details, refer to the [paper](https://tches.iacr.org/index.php/TCHES/article/view/11255). 5 | 6 | --- 7 | ## Repository Structure 8 | - **`fast_attention.py`** – Implements the proposed GaussiP attention layer. 9 | - **`normalization.py`** – Implements the layer-centering normalization. 10 | - **`transformer.py`** – Defines the EstraNet model architecture. 11 | - **`train_trans.py`** – Training and evaluation script for EstraNet. 12 | - **`data_utils.py`** – Utilities for loading ASCADf and ASCADr datasets. 13 | - **`data_utils_ches20.py`** – Utilities for loading the CHES20 dataset. 14 | - **`evaluation_utils.py`** – Computes guessing entropy for ASCAD datasets. 15 | - **`evaluation_utils_ches20.py`** – Computes guessing entropy for CHES20 dataset. 
16 | - **`run_trans_<dataset>.sh`** – Bash scripts with predefined hyperparameters for specific datasets, where `<dataset>` is one of:
17 |   - **ASCADf** ([fixed key](https://github.com/ANSSI-FR/ASCAD/tree/master/ATMEGA_AES_v1/ATM_AES_v1_fixed_key))
18 |   - **ASCADr** ([random key](https://github.com/ANSSI-FR/ASCAD/tree/master/ATMEGA_AES_v1/ATM_AES_v1_variable_key))
19 |   - **CHES20** ([CHES CTF 2020](https://ctf.spook.dev/))
20 | 
21 | ---
22 | 
23 | ## Data Pre-processing:
24 | - For the **CHES CTF 2020** dataset, the traces are multiplied by a constant `0.004` to keep the feature value range within **[-120, 120]**.
25 | 
26 | ---
27 | 
28 | ## Tested on
29 | - Python 3.8.10
30 | - absl-py == 2.3.1
31 | - numpy == 1.24.3
32 | - scipy == 1.10.1
33 | - h5py == 3.11.0
34 | - tensorflow == 2.13.0
35 | 
36 | ---
37 | 
38 | ## Getting Started
39 | 
40 | 1. **Clone the repository:**
41 |    ```bash
42 |    git clone https://github.com/suvadeep-iitb/EstraNet.git
43 |    cd EstraNet
44 |    ```
45 | 2. **Install dependencies (Python >= 3.8 recommended):**
46 |    ```bash
47 |    pip install -r requirements.txt
48 |    ```
49 | 3. **Set dataset path in the bash script:**
50 |    ```
51 |    Open run_trans_<dataset>.sh and set the dataset path variable properly.
52 |    ```
53 | 4. **Train EstraNet:**
54 |    ```bash
55 |    bash run_trans_<dataset>.sh train
56 |    ```
57 | 5. **Perform Evaluation:**
58 |    ```bash
59 |    bash run_trans_<dataset>.sh test
60 |    ```
61 | 
62 | ----
63 | 
64 | ## Citation:
65 | ```
66 | @article{DBLP:journals/tches/HajraCM24,
67 |   author    = {Suvadeep Hajra and
68 |                Siddhartha Chowdhury and
69 |                Debdeep Mukhopadhyay},
70 |   title     = {EstraNet: An Efficient Shift-Invariant Transformer Network for Side-Channel
71 |                Analysis},
72 |   journal   = {{IACR} Trans. Cryptogr. Hardw. Embed. Syst.},
73 |   volume    = {2024},
74 |   number    = {1},
75 |   pages     = {336--374},
76 |   year      = {2024},
77 |   url       = {https://doi.org/10.46586/tches.v2024.i1.336-374},
78 |   doi       = {10.46586/TCHES.V2024.I1.336-374},
79 |   timestamp = {Sat, 08 Jun 2024 13:14:59 +0200},
80 |   biburl    = {https://dblp.org/rec/journals/tches/HajraCM24.bib},
81 |   bibsource = {dblp computer science bibliography, https://dblp.org}
82 | }
83 | ```
84 | 
--------------------------------------------------------------------------------
/evaluation_utils_ches20.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import special
3 | import h5py
4 | import tensorflow as tf
5 | 
6 | import os, sys
7 | 
8 | 
9 | def sbox_layer(x):
10 |     y1 = (x[0] & x[1]) ^ x[2]
11 |     y0 = (x[3] & x[0]) ^ x[1]
12 |     y3 = (y1 & x[3]) ^ x[0]
13 |     y2 = (y0 & y1) ^ x[3]
14 |     return np.stack([y0, y1, y2, y3], axis=1)
15 | 
16 | 
17 | def shuffle_all(predictions, nonces):
18 |     perm = np.random.permutation(predictions.shape[0])
19 |     predictions = predictions[perm]
20 |     nonces = nonces[perm]
21 | 
22 |     return predictions, nonces
23 | 
24 | 
25 | def get_log_prob(predictions, plaintext):
26 |     predictions = np.squeeze(predictions)
27 |     n_classes = predictions.shape[0]
28 |     keys = np.arange(n_classes, dtype=int)
29 |     x_xor_k = np.bitwise_xor(keys, plaintext)
30 |     z = np.take(sbox, x_xor_k)
31 |     log_prob = np.take(predictions, z)
32 | 
33 |     return log_prob
34 | 
35 | 
36 | def gen_key_bits():
37 |     values = np.arange(16, dtype=np.uint8).reshape(-1, 1)
38 |     key_bits = np.unpackbits(values, axis=1)[:, -4:]
39 |     for k in range(16):
40 |         t = key_bits[k, 0]
41 |         key_bits[k, 0] = key_bits[k, 3]
42 |         key_bits[k, 3] = t
43 |         t = key_bits[k, 1]
44 |         key_bits[k, 1] = key_bits[k, 2]
45 |         key_bits[k, 2] = t
46 |     return key_bits
47
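# Illustrative helper (added sketch, not part of the original code): gen_key_bits()
# enumerates the 16 candidate 4-bit sub-keys as little-endian bit vectors, i.e. row k
# holds [bit0, bit1, bit2, bit3] of k, matching how compute_key_rank() below packs the
# correct key with get_corr_key(). The hypothetical check below makes that explicit.
def _check_key_bits_order():
    key_bits = gen_key_bits()
    for k in range(16):
        packed = sum(int(key_bits[k, b]) << b for b in range(4))
        assert packed == k, "row %d is not the little-endian encoding of %d" % (k, k)
    return key_bits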
| 48 | 49 | def compute_key_rank(predictions, nonces, keys): 50 | n_samples, n_classes = predictions.shape 51 | nonces = (nonces[:n_samples] & 0x1) 52 | keys = np.squeeze(keys) 53 | 54 | predictions, nonces = shuffle_all(predictions, nonces) 55 | 56 | def get_corr_key(keys): 57 | corr_key = ((keys[0] & 0x1) << 0) 58 | corr_key |= ((keys[1] & 0x1) << 1) 59 | corr_key |= ((keys[2] & 0x1) << 2) 60 | corr_key |= ((keys[3] & 0x1) << 3) 61 | return corr_key 62 | corr_key = get_corr_key(keys) 63 | 64 | key_bits = gen_key_bits() 65 | n_keys = key_bits.shape[0] 66 | 67 | neg_log_prob = np.zeros((n_samples, n_keys)) 68 | for k in range(n_keys): 69 | key_rep = np.reshape(key_bits[k, :], [1, -1]) 70 | sbox_in = (nonces ^ key_rep).T 71 | sbox_out = (sbox_layer(sbox_in) & 0x1) 72 | sbox_out = sbox_out.astype(np.float32) 73 | scores = tf.reduce_mean( 74 | tf.nn.sigmoid_cross_entropy_with_logits(sbox_out, predictions), 75 | axis = 1 76 | ).numpy() 77 | neg_log_prob[:, k] = scores 78 | 79 | cum_neg_log_prob = np.zeros((n_samples, n_keys)) 80 | last_neg_log_prob = np.zeros((1, n_keys)) 81 | for i in range(n_samples): 82 | last_neg_log_prob += neg_log_prob[i] 83 | cum_neg_log_prob[i, :] = last_neg_log_prob 84 | 85 | sorted_keys = np.argsort(cum_neg_log_prob, axis=1) 86 | key_ranks = np.zeros((n_samples), dtype=int) - 1 87 | for i in range(n_samples): 88 | for j in range(n_keys): 89 | if sorted_keys[i, j] == corr_key: 90 | key_ranks[i] = j 91 | break 92 | 93 | for i in range(n_samples): 94 | assert key_ranks[i] >= 0, "Assertion failed at index %s" % i 95 | 96 | return key_ranks 97 | 98 | 99 | if __name__ == '__main__': 100 | data_path = sys.argv[1] 101 | 102 | data = h5py.File(data_path, 'r') 103 | 104 | for i in range(10): 105 | label = data['Profiling_traces']['labels'][i] 106 | ptest = data['Profiling_traces']['metadata'][i]['plaintext'][2] 107 | key = data['Profiling_traces']['metadata'][i]['key'][2] 108 | 109 | print(str(label)+'/'+str(sbox[np.bitwise_xor(ptest, key)])+'/'+str(key)) 110 | 111 | -------------------------------------------------------------------------------- /evaluation_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import special 3 | import h5py 4 | 5 | import os, sys 6 | 7 | # Rijndael S-box 8 | sbox = np.array( 9 | [0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 10 | 0x2b, 0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 11 | 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 12 | 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 13 | 0x71, 0xd8, 0x31, 0x15, 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 14 | 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 15 | 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 16 | 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 17 | 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, 0xd0, 0xef, 0xaa, 18 | 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 19 | 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 20 | 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec, 21 | 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 22 | 0x73, 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 23 | 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49, 24 | 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 25 | 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 26 | 0xea, 
0x65, 0x7a, 0xae, 0x08, 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 27 | 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, 0x70, 28 | 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 29 | 0x86, 0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 30 | 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, 0x8c, 0xa1, 31 | 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 32 | 0x54, 0xbb, 0x16]) 33 | 34 | 35 | def shuffle_all(predictions, plaintexts, keys): 36 | rng_state = np.random.get_state() 37 | np.random.shuffle(predictions) 38 | np.random.set_state(rng_state) 39 | np.random.shuffle(plaintexts) 40 | np.random.set_state(rng_state) 41 | np.random.shuffle(keys) 42 | 43 | return predictions, plaintexts, keys 44 | 45 | 46 | def get_log_prob(predictions, plaintext): 47 | predictions = np.squeeze(predictions) 48 | n_classes = predictions.shape[0] 49 | keys = np.arange(n_classes, dtype=int) 50 | x_xor_k = np.bitwise_xor(keys, plaintext) 51 | z = np.take(sbox, x_xor_k) 52 | log_prob = np.take(predictions, z) 53 | 54 | return log_prob 55 | 56 | 57 | def compute_key_rank(predictions, plaintexts, keys): 58 | n_samples, n_classes = predictions.shape 59 | plaintexts = plaintexts[:n_samples] 60 | keys = keys[:n_samples] 61 | predictions, plaintexts, keys = shuffle_all(predictions, plaintexts, keys) 62 | 63 | predictions = special.softmax(predictions, axis=1) 64 | predictions = np.log(predictions+1e-40) 65 | 66 | cum_log_prob = np.zeros((n_samples, n_classes)) 67 | last_log_prob = np.zeros((1, n_classes)) 68 | for i in range(n_samples): 69 | log_prob = get_log_prob(predictions[i, :], plaintexts[i]) 70 | last_log_prob += log_prob 71 | cum_log_prob[i, :] = last_log_prob 72 | 73 | sorted_keys = np.argsort(-cum_log_prob, axis=1) 74 | key_ranks = np.zeros((n_samples), dtype=int) - 1 75 | for i in range(n_samples): 76 | for j in range(n_classes): 77 | if sorted_keys[i, j] == keys[i]: 78 | key_ranks[i] = j 79 | break 80 | 81 | for i in range(n_samples): 82 | assert key_ranks[i] >= 0, "Assertion failed at index %s" % i 83 | 84 | return key_ranks 85 | 86 | 87 | if __name__ == '__main__': 88 | data_path = sys.argv[1] 89 | 90 | data = h5py.File(data_path, 'r') 91 | 92 | for i in range(10): 93 | label = data['Profiling_traces']['labels'][i] 94 | ptest = data['Profiling_traces']['metadata'][i]['plaintext'][2] 95 | key = data['Profiling_traces']['metadata'][i]['key'][2] 96 | 97 | print(str(label)+'/'+str(sbox[np.bitwise_xor(ptest, key)])+'/'+str(key)) 98 | 99 | -------------------------------------------------------------------------------- /run_trans_ascadf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # GPU config 4 | USE_TPU=False 5 | 6 | # Experiment (data/checkpoint/directory) config 7 | DATA_PATH= #Path to the .h5 file containing the dataset 8 | DATASET=ASCAD 9 | CKP_DIR=./ 10 | CKP_IDX=0 11 | WARM_START=False 12 | RESULT_PATH=results 13 | 14 | # Optimization config 15 | LEARNING_RATE=2.5e-4 16 | CLIP=0.25 17 | MIN_LR_RATIO=0.004 18 | INPUT_LENGTH=10000 # or 40000 19 | DATA_DESYNC=200 # 400 for input length 40000 20 | 21 | # Training config 22 | TRAIN_BSZ=16 23 | EVAL_BSZ=16 24 | TRAIN_STEPS=4000000 25 | WARMUP_STEPS=1000000 26 | ITERATIONS=20000 27 | SAVE_STEPS=40000 28 | 29 | # Model config 30 | N_LAYER=2 31 | D_MODEL=128 32 | D_HEAD=32 33 | N_HEAD=8 34 | D_INNER=256 35 | N_HEAD_SM=8 36 | D_HEAD_SM=16 37 | DROPOUT=0.05 38 | CONV_KERNEL_SIZE=3 # The kernel size of the first convolutional layer is set to 
11 39 | # This hyper-parameter set the kernel size of the remaining 40 | # convolutional layers 41 | N_CONV_LAYER=2 42 | POOL_SIZE=20 #8 43 | D_KERNEL_MAP=512 44 | BETA_HAT_2=150 45 | MODEL_NORM='preLC' 46 | HEAD_INIT='forward' 47 | SM_ATTN=True 48 | 49 | # Evaluation config 50 | MAX_EVAL_BATCH=100 51 | OUTPUT_ATTN=False 52 | 53 | 54 | if [[ $1 == 'train' ]]; then 55 | python train_trans.py \ 56 | --use_tpu=${USE_TPU} \ 57 | --data_path=${DATA_PATH} \ 58 | --dataset=${DATASET} \ 59 | --checkpoint_dir=${CKP_DIR} \ 60 | --warm_start=${WARM_START} \ 61 | --result_path=${RESULT_PATH} \ 62 | --learning_rate=${LEARNING_RATE} \ 63 | --clip=${CLIP} \ 64 | --min_lr_ratio=${MIN_LR_RATIO} \ 65 | --warmup_steps=${WARMUP_STEPS} \ 66 | --input_length=${INPUT_LENGTH} \ 67 | --data_desync=${DATA_DESYNC} \ 68 | --train_batch_size=${TRAIN_BSZ} \ 69 | --eval_batch_size=${EVAL_BSZ} \ 70 | --train_steps=${TRAIN_STEPS} \ 71 | --iterations=${ITERATIONS} \ 72 | --save_steps=${SAVE_STEPS} \ 73 | --n_layer=${N_LAYER} \ 74 | --d_model=${D_MODEL} \ 75 | --d_head=${D_HEAD} \ 76 | --n_head=${N_HEAD} \ 77 | --d_head=${D_HEAD} \ 78 | --d_inner=${D_INNER} \ 79 | --n_head_softmax=${N_HEAD_SM} \ 80 | --d_head_softmax=${D_HEAD_SM} \ 81 | --dropout=${DROPOUT} \ 82 | --dropatt=${DROPATT} \ 83 | --conv_kernel_size=${CONV_KERNEL_SIZE} \ 84 | --n_conv_layer=${N_CONV_LAYER} \ 85 | --pool_size=${POOL_SIZE} \ 86 | --d_kernel_map=${D_KERNEL_MAP} \ 87 | --beta_hat_2=${BETA_HAT_2} \ 88 | --model_normalization=${MODEL_NORM} \ 89 | --head_initialization=${HEAD_INIT} \ 90 | --softmax_attn=${SM_ATTN} \ 91 | --max_eval_batch=${MAX_EVAL_BATCH} \ 92 | --do_train=True 93 | elif [[ $1 == 'test' ]]; then 94 | python train_trans.py \ 95 | --use_tpu=${USE_TPU} \ 96 | --data_path=${DATA_PATH} \ 97 | --dataset=${DATASET} \ 98 | --checkpoint_dir=${CKP_DIR} \ 99 | --checkpoint_idx=${CKP_IDX} \ 100 | --warm_start=${WARM_START} \ 101 | --result_path=${RESULT_PATH} \ 102 | --learning_rate=${LEARNING_RATE} \ 103 | --clip=${CLIP} \ 104 | --min_lr_ratio=${MIN_LR_RATIO} \ 105 | --warmup_steps=${WARMUP_STEPS} \ 106 | --input_length=${INPUT_LENGTH} \ 107 | --train_batch_size=${TRAIN_BSZ} \ 108 | --eval_batch_size=${EVAL_BSZ} \ 109 | --train_steps=${TRAIN_STEPS} \ 110 | --iterations=${ITERATIONS} \ 111 | --save_steps=${SAVE_STEPS} \ 112 | --n_layer=${N_LAYER} \ 113 | --d_model=${D_MODEL} \ 114 | --d_head=${D_HEAD} \ 115 | --n_head=${N_HEAD} \ 116 | --d_head=${D_HEAD} \ 117 | --d_inner=${D_INNER} \ 118 | --n_head_softmax=${N_HEAD_SM} \ 119 | --d_head_softmax=${D_HEAD_SM} \ 120 | --dropout=${DROPOUT} \ 121 | --dropatt=${DROPATT} \ 122 | --conv_kernel_size=${CONV_KERNEL_SIZE} \ 123 | --n_conv_layer=${N_CONV_LAYER} \ 124 | --pool_size=${POOL_SIZE} \ 125 | --d_kernel_map=${D_KERNEL_MAP} \ 126 | --beta_hat_2=${BETA_HAT_2} \ 127 | --model_normalization=${MODEL_NORM} \ 128 | --head_initialization=${HEAD_INIT} \ 129 | --softmax_attn=${SM_ATTN} \ 130 | --max_eval_batch=${MAX_EVAL_BATCH} \ 131 | --output_attn=${OUTPUT_ATTN} \ 132 | --do_train=False 133 | else 134 | echo "unknown argument 1" 135 | fi 136 | -------------------------------------------------------------------------------- /run_trans_ascadr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # GPU config 4 | USE_TPU=False 5 | 6 | # Experiment (data/checkpoint/directory) config 7 | DATA_PATH=#Path to the .h5 file containing the dataset 8 | DATASET=ASCAD 9 | CKP_DIR=./ 10 | CKP_IDX=0 11 | WARM_START=False 12 | RESULT_PATH=results 13 | 14 | # Optimization 
config 15 | LEARNING_RATE=2.5e-4 16 | CLIP=0.25 17 | MIN_LR_RATIO=0.004 18 | INPUT_LENGTH=10000 # or 40000 19 | DATA_DESYNC=200 # 400 for input length 40K 20 | 21 | # Training config 22 | TRAIN_BSZ=16 23 | EVAL_BSZ=16 24 | TRAIN_STEPS=4000000 25 | WARMUP_STEPS=1000000 26 | ITERATIONS=20000 27 | SAVE_STEPS=40000 28 | 29 | # Model config 30 | N_LAYER=2 31 | D_MODEL=128 32 | D_HEAD=32 33 | N_HEAD=8 34 | D_INNER=256 35 | N_HEAD_SM=8 36 | D_HEAD_SM=16 37 | DROPOUT=0.05 38 | CONV_KERNEL_SIZE=3 # The kernel size of the first convolutional layer is set to 11 39 | # This hyper-parameter set the kernel size of the remaining 40 | # convolutional layers 41 | N_CONV_LAYER=2 42 | POOL_SIZE=10 43 | D_KERNEL_MAP=512 44 | BETA_HAT_2=50 # 50 for input length 10K and 200 for input length 40K 45 | MODEL_NORM='preLC' 46 | HEAD_INIT='forward' 47 | SM_ATTN=True 48 | 49 | # Evaluation config 50 | MAX_EVAL_BATCH=100 51 | OUTPUT_ATTN=False 52 | 53 | 54 | if [[ $1 == 'train' ]]; then 55 | python train_trans.py \ 56 | --use_tpu=${USE_TPU} \ 57 | --data_path=${DATA_PATH} \ 58 | --dataset=${DATASET} \ 59 | --checkpoint_dir=${CKP_DIR} \ 60 | --warm_start=${WARM_START} \ 61 | --result_path=${RESULT_PATH} \ 62 | --learning_rate=${LEARNING_RATE} \ 63 | --clip=${CLIP} \ 64 | --min_lr_ratio=${MIN_LR_RATIO} \ 65 | --warmup_steps=${WARMUP_STEPS} \ 66 | --input_length=${INPUT_LENGTH} \ 67 | --data_desync=${DATA_DESYNC} \ 68 | --train_batch_size=${TRAIN_BSZ} \ 69 | --eval_batch_size=${EVAL_BSZ} \ 70 | --train_steps=${TRAIN_STEPS} \ 71 | --iterations=${ITERATIONS} \ 72 | --save_steps=${SAVE_STEPS} \ 73 | --n_layer=${N_LAYER} \ 74 | --d_model=${D_MODEL} \ 75 | --d_head=${D_HEAD} \ 76 | --n_head=${N_HEAD} \ 77 | --d_head=${D_HEAD} \ 78 | --d_inner=${D_INNER} \ 79 | --n_head_softmax=${N_HEAD_SM} \ 80 | --d_head_softmax=${D_HEAD_SM} \ 81 | --dropout=${DROPOUT} \ 82 | --dropatt=${DROPATT} \ 83 | --conv_kernel_size=${CONV_KERNEL_SIZE} \ 84 | --n_conv_layer=${N_CONV_LAYER} \ 85 | --pool_size=${POOL_SIZE} \ 86 | --d_kernel_map=${D_KERNEL_MAP} \ 87 | --beta_hat_2=${BETA_HAT_2} \ 88 | --model_normalization=${MODEL_NORM} \ 89 | --head_initialization=${HEAD_INIT} \ 90 | --softmax_attn=${SM_ATTN} \ 91 | --max_eval_batch=${MAX_EVAL_BATCH} \ 92 | --do_train=True 93 | elif [[ $1 == 'test' ]]; then 94 | python train_trans.py \ 95 | --use_tpu=${USE_TPU} \ 96 | --data_path=${DATA_PATH} \ 97 | --dataset=${DATASET} \ 98 | --checkpoint_dir=${CKP_DIR} \ 99 | --checkpoint_idx=${CKP_IDX} \ 100 | --warm_start=${WARM_START} \ 101 | --result_path=${RESULT_PATH} \ 102 | --learning_rate=${LEARNING_RATE} \ 103 | --clip=${CLIP} \ 104 | --min_lr_ratio=${MIN_LR_RATIO} \ 105 | --warmup_steps=${WARMUP_STEPS} \ 106 | --input_length=${INPUT_LENGTH} \ 107 | --train_batch_size=${TRAIN_BSZ} \ 108 | --eval_batch_size=${EVAL_BSZ} \ 109 | --train_steps=${TRAIN_STEPS} \ 110 | --iterations=${ITERATIONS} \ 111 | --save_steps=${SAVE_STEPS} \ 112 | --n_layer=${N_LAYER} \ 113 | --d_model=${D_MODEL} \ 114 | --d_head=${D_HEAD} \ 115 | --n_head=${N_HEAD} \ 116 | --d_head=${D_HEAD} \ 117 | --d_inner=${D_INNER} \ 118 | --n_head_softmax=${N_HEAD_SM} \ 119 | --d_head_softmax=${D_HEAD_SM} \ 120 | --dropout=${DROPOUT} \ 121 | --dropatt=${DROPATT} \ 122 | --conv_kernel_size=${CONV_KERNEL_SIZE} \ 123 | --n_conv_layer=${N_CONV_LAYER} \ 124 | --pool_size=${POOL_SIZE} \ 125 | --d_kernel_map=${D_KERNEL_MAP} \ 126 | --beta_hat_2=${BETA_HAT_2} \ 127 | --model_normalization=${MODEL_NORM} \ 128 | --head_initialization=${HEAD_INIT} \ 129 | --softmax_attn=${SM_ATTN} \ 130 | 
--max_eval_batch=${MAX_EVAL_BATCH} \ 131 | --output_attn=${OUTPUT_ATTN} \ 132 | --do_train=False 133 | else 134 | echo "unknown argument 1" 135 | fi 136 | -------------------------------------------------------------------------------- /run_trans_ches20.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # GPU config 4 | USE_TPU=False 5 | 6 | # Experiment (data/checkpoint/directory) config 7 | DATA_PATH= # Path to the training and validation dataset. 8 | # Training dataset: /.npy 9 | # Validation dataset: /_valid.npy 10 | DATASET=CHES20 11 | CKP_DIR=./ 12 | CKP_IDX=0 13 | WARM_START=False 14 | RESULT_PATH=results 15 | 16 | # Optimization config 17 | LEARNING_RATE=2.5e-4 18 | CLIP=0.25 19 | MIN_LR_RATIO=0.004 20 | INPUT_LENGTH=10000 # or 40000 21 | DATA_DESYNC=200 # 400 for input length 40K 22 | 23 | # Training config 24 | TRAIN_BSZ=64 25 | EVAL_BSZ=16 26 | TRAIN_STEPS=4000000 27 | WARMUP_STEPS=2000000 28 | ITERATIONS=20000 29 | SAVE_STEPS=40000 30 | 31 | # Model config 32 | N_LAYER=2 33 | D_MODEL=128 34 | D_HEAD=32 35 | N_HEAD=8 36 | D_INNER=256 37 | N_HEAD_SM=8 38 | D_HEAD_SM=16 39 | DROPOUT=0.05 40 | CONV_KERNEL_SIZE=3 # The kernel size of the first convolutional layer is set to 11 41 | # This hyper-parameter set the kernel size of the remaining 42 | # convolutional layers 43 | N_CONV_LAYER=2 44 | POOL_SIZE=10 45 | D_KERNEL_MAP=512 46 | BETA_HAT_2=150 # 150 for input length 10K and 450 for input length 40K 47 | MODEL_NORM='preLC' 48 | HEAD_INIT='forward' 49 | SM_ATTN=True 50 | 51 | # Evaluation config 52 | MAX_EVAL_BATCH=100 53 | OUTPUT_ATTN=False 54 | 55 | 56 | if [[ $1 == 'train' ]]; then 57 | python train_trans.py \ 58 | --use_tpu=${USE_TPU} \ 59 | --data_path=${DATA_PATH} \ 60 | --dataset=${DATASET} \ 61 | --checkpoint_dir=${CKP_DIR} \ 62 | --warm_start=${WARM_START} \ 63 | --result_path=${RESULT_PATH} \ 64 | --learning_rate=${LEARNING_RATE} \ 65 | --clip=${CLIP} \ 66 | --min_lr_ratio=${MIN_LR_RATIO} \ 67 | --warmup_steps=${WARMUP_STEPS} \ 68 | --input_length=${INPUT_LENGTH} \ 69 | --data_desync=${DATA_DESYNC} \ 70 | --train_batch_size=${TRAIN_BSZ} \ 71 | --eval_batch_size=${EVAL_BSZ} \ 72 | --train_steps=${TRAIN_STEPS} \ 73 | --iterations=${ITERATIONS} \ 74 | --save_steps=${SAVE_STEPS} \ 75 | --n_layer=${N_LAYER} \ 76 | --d_model=${D_MODEL} \ 77 | --d_head=${D_HEAD} \ 78 | --n_head=${N_HEAD} \ 79 | --d_head=${D_HEAD} \ 80 | --d_inner=${D_INNER} \ 81 | --n_head_softmax=${N_HEAD_SM} \ 82 | --d_head_softmax=${D_HEAD_SM} \ 83 | --dropout=${DROPOUT} \ 84 | --dropatt=${DROPATT} \ 85 | --conv_kernel_size=${CONV_KERNEL_SIZE} \ 86 | --n_conv_layer=${N_CONV_LAYER} \ 87 | --pool_size=${POOL_SIZE} \ 88 | --d_kernel_map=${D_KERNEL_MAP} \ 89 | --beta_hat_2=${BETA_HAT_2} \ 90 | --model_normalization=${MODEL_NORM} \ 91 | --head_initialization=${HEAD_INIT} \ 92 | --softmax_attn=${SM_ATTN} \ 93 | --max_eval_batch=${MAX_EVAL_BATCH} \ 94 | --do_train=True 95 | elif [[ $1 == 'test' ]]; then 96 | python train_trans.py \ 97 | --use_tpu=${USE_TPU} \ 98 | --data_path=${DATA_PATH} \ 99 | --dataset=${DATASET} \ 100 | --checkpoint_dir=${CKP_DIR} \ 101 | --checkpoint_idx=${CKP_IDX} \ 102 | --warm_start=${WARM_START} \ 103 | --result_path=${RESULT_PATH} \ 104 | --learning_rate=${LEARNING_RATE} \ 105 | --clip=${CLIP} \ 106 | --min_lr_ratio=${MIN_LR_RATIO} \ 107 | --warmup_steps=${WARMUP_STEPS} \ 108 | --input_length=${INPUT_LENGTH} \ 109 | --train_batch_size=${TRAIN_BSZ} \ 110 | --eval_batch_size=${EVAL_BSZ} \ 111 | --train_steps=${TRAIN_STEPS} \ 112 | 
--iterations=${ITERATIONS} \ 113 | --save_steps=${SAVE_STEPS} \ 114 | --n_layer=${N_LAYER} \ 115 | --d_model=${D_MODEL} \ 116 | --d_head=${D_HEAD} \ 117 | --n_head=${N_HEAD} \ 118 | --d_head=${D_HEAD} \ 119 | --d_inner=${D_INNER} \ 120 | --n_head_softmax=${N_HEAD_SM} \ 121 | --d_head_softmax=${D_HEAD_SM} \ 122 | --dropout=${DROPOUT} \ 123 | --dropatt=${DROPATT} \ 124 | --conv_kernel_size=${CONV_KERNEL_SIZE} \ 125 | --n_conv_layer=${N_CONV_LAYER} \ 126 | --pool_size=${POOL_SIZE} \ 127 | --d_kernel_map=${D_KERNEL_MAP} \ 128 | --beta_hat_2=${BETA_HAT_2} \ 129 | --model_normalization=${MODEL_NORM} \ 130 | --head_initialization=${HEAD_INIT} \ 131 | --softmax_attn=${SM_ATTN} \ 132 | --max_eval_batch=${MAX_EVAL_BATCH} \ 133 | --output_attn=${OUTPUT_ATTN} \ 134 | --do_train=False 135 | else 136 | echo "unknown argument 1" 137 | fi 138 | -------------------------------------------------------------------------------- /data_utils_ches20.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import h5py 4 | 5 | import os, sys 6 | 7 | 8 | def sbox_layer(x): 9 | y1 = (x[0] & x[1]) ^ x[2] 10 | y0 = (x[3] & x[0]) ^ x[1] 11 | y3 = (y1 & x[3]) ^ x[0] 12 | y2 = (y0 & y1) ^ x[3] 13 | return np.stack([y0, y1, y2, y3], axis=1) 14 | 15 | 16 | class Dataset: 17 | def __init__(self, data_path, split, input_length, data_desync=0): 18 | self.data_path = data_path 19 | self.split = split 20 | self.input_length = input_length 21 | self.data_desync = data_desync 22 | 23 | data = np.load(data_path) 24 | self.traces = data['traces'] 25 | self.nonces = data['nonces'] 26 | self.umsk_keys = data['umsk_keys'] 27 | 28 | shift = 17 29 | self.nonces = self.nonces >> shift 30 | self.umsk_keys = self.umsk_keys >> shift 31 | if len(self.umsk_keys.shape) == 1: 32 | self.umsk_keys = np.reshape(self.umsk_keys, [1, -1]) 33 | 34 | sbox_in = np.bitwise_xor(self.nonces, self.umsk_keys) 35 | sbox_in = sbox_in.T 36 | sbox_out = sbox_layer(sbox_in) 37 | self.labels = (sbox_out & 0x1) 38 | self.labels = self.labels.astype(np.float32) 39 | 40 | assert (self.input_length + self.data_desync) <= self.traces.shape[1] 41 | self.traces = self.traces[:, :(self.input_length+self.data_desync)] 42 | 43 | self.num_samples = self.traces.shape[0] 44 | 45 | max_split_size = 2000000000//self.input_length 46 | split_idx = list(range(max_split_size, self.num_samples, max_split_size)) 47 | self.traces = np.split(self.traces, split_idx, axis=0) 48 | self.labels = np.split(self.labels, split_idx, axis=0) 49 | 50 | 51 | def GetTFRecords(self, batch_size, training=False): 52 | dataset = tf.data.Dataset.from_tensor_slices((self.traces[0], self.labels[0])) 53 | for traces, labels in zip(self.traces[1:], self.labels[1:]): 54 | temp_dataset = tf.data.Dataset.from_tensor_slices((traces, labels)) 55 | dataset.concatenate(temp_dataset) 56 | 57 | def shift(x, max_desync): 58 | ds = tf.random.uniform([1], 0, max_desync+1, tf.dtypes.int32) 59 | ds = tf.concat([[0], ds], 0) 60 | x = tf.slice(x, ds, [-1, self.input_length]) 61 | return x 62 | 63 | if training == True: 64 | if self.input_length < self.traces[0].shape[1]: 65 | return dataset.repeat() \ 66 | .shuffle(self.num_samples) \ 67 | .batch(batch_size//4) \ 68 | .map(lambda x, y: (shift(x, self.data_desync), y)) \ 69 | .unbatch() \ 70 | .batch(batch_size, drop_remainder=True) \ 71 | .map(lambda x, y: (tf.cast(x, tf.float32), y)) \ 72 | .prefetch(10) 73 | else: 74 | return dataset.repeat() \ 75 | .shuffle(self.num_samples) \ 76 | 
.batch(batch_size, drop_remainder=True) \ 77 | .map(lambda x, y: (tf.cast(x, tf.float32), y)) \ 78 | .prefetch(10) 79 | 80 | else: 81 | if self.input_length < self.traces[0].shape[1]: 82 | return dataset.batch(batch_size, drop_remainder=True) \ 83 | .map(lambda x, y: (shift(x, 0), y)) \ 84 | .map(lambda x, y: (tf.cast(x, tf.float32), y)) \ 85 | .prefetch(10) 86 | else: 87 | return dataset.batch(batch_size, drop_remainder=True) \ 88 | .map(lambda x, y: (tf.cast(x, tf.float32), y)) \ 89 | .prefetch(10) 90 | 91 | 92 | def GetDataset(self): 93 | return self.traces, self.labels 94 | 95 | 96 | if __name__ == '__main__': 97 | data_path = sys.argv[1] 98 | batch_size = int(sys.argv[2]) 99 | split = sys.argv[3] 100 | 101 | dataset = Dataset(data_path, split, 5) 102 | 103 | print("traces : "+str(dataset.traces.shape)) 104 | print("labels : "+str(dataset.labels.shape)) 105 | print("plaintext : "+str(dataset.plaintexts.shape)) 106 | print("keys : "+str(dataset.keys.shape)) 107 | print("traces ty : "+str(dataset.traces.dtype)) 108 | print("") 109 | print("") 110 | 111 | tfrecords = dataset.GetTFRecords(batch_size, training=True) 112 | iterator = iter(tfrecords) 113 | for i in range(1): 114 | tr, lbl = iterator.get_next() 115 | print(str(tr.shape)+' '+str(lbl.shape)) 116 | print(str(tr.dtype)+' '+str(lbl.dtype)) 117 | print(str(tr[:, :10])) 118 | print(str(lbl[:, :])) 119 | print("") 120 | 121 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import h5py 4 | 5 | import os, sys 6 | 7 | class Dataset: 8 | def __init__(self, data_path, split, input_length, data_desync=0): 9 | self.data_path = data_path 10 | self.split = split 11 | self.input_length = input_length 12 | self.data_desync = data_desync 13 | 14 | corpus = h5py.File(data_path, 'r') 15 | if split == 'train': 16 | split_key = 'Profiling_traces' 17 | elif split == 'test': 18 | split_key = 'Attack_traces' 19 | 20 | self.traces = corpus[split_key]['traces'][:, :(self.input_length+self.data_desync)] 21 | self.labels = np.reshape(corpus[split_key]['labels'][()], [-1, 1]) 22 | self.labels = self.labels.astype(np.int64) 23 | self.num_samples = self.traces.shape[0] 24 | 25 | #assert (self.input_length + self.data_desync) <= self.traces.shape[1] 26 | #self.traces = self.traces[:, :(self.input_length+self.data_desync)] 27 | 28 | max_split_size = 2000000000//self.input_length 29 | split_idx = list(range(max_split_size, self.num_samples, max_split_size)) 30 | self.traces = np.split(self.traces, split_idx, axis=0) 31 | self.labels = np.split(self.labels, split_idx, axis=0) 32 | 33 | #self.traces = self.traces.astype(np.float32) 34 | 35 | self.plaintexts = self.GetPlaintexts(corpus[split_key]['metadata']) 36 | self.masks = self.GetMasks(corpus[split_key]['metadata']) 37 | self.keys = self.GetKeys(corpus[split_key]['metadata']) 38 | 39 | 40 | def GetPlaintexts(self, metadata): 41 | plaintexts = [] 42 | for i in range(len(metadata)): 43 | plaintexts.append(metadata[i]['plaintext'][2]) 44 | return np.array(plaintexts) 45 | 46 | 47 | def GetKeys(self, metadata): 48 | keys = [] 49 | for i in range(len(metadata)): 50 | keys.append(metadata[i]['key'][2]) 51 | return np.array(keys) 52 | 53 | 54 | def GetMasks(self, metadata): 55 | masks = [] 56 | for i in range(len(metadata)): 57 | masks.append(np.array(metadata[i]['masks'])) 58 | masks = np.stack(masks, axis=0) 59 | return masks 60 | 61 
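    # Descriptive note (comment added for clarity, not in the original source):
    # GetTFRecords() below builds the tf.data input pipeline. During training, when
    # input_length is smaller than the stored trace length, traces are grouped into
    # small sub-batches, each sub-batch is cropped at a single random offset drawn
    # from [0, data_desync] by shift(), and the sub-batches are re-batched to the
    # full batch size. This random cropping implements the trace-desynchronization
    # augmentation controlled by DATA_DESYNC in the run_trans_<dataset>.sh scripts;
    # at evaluation time shift(x, 0) always crops from offset 0.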
| 62 | def GetTFRecords(self, batch_size, training=False): 63 | dataset = tf.data.Dataset.from_tensor_slices((self.traces[0], self.labels[0])) 64 | for traces, labels in zip(self.traces[1:], self.labels[1:]): 65 | temp_dataset = tf.data.Dataset.from_tensor_slices((traces, labels)) 66 | dataset.concatenate(temp_dataset) 67 | 68 | def shift(x, max_desync): 69 | ds = tf.random.uniform([1], 0, max_desync+1, tf.dtypes.int32) 70 | ds = tf.concat([[0], ds], 0) 71 | x = tf.slice(x, ds, [-1, self.input_length]) 72 | return x 73 | 74 | if training == True: 75 | if self.input_length < self.traces[0].shape[1]: 76 | return dataset.repeat() \ 77 | .shuffle(self.num_samples) \ 78 | .batch(batch_size//2) \ 79 | .map(lambda x, y: (shift(x, self.data_desync), y)) \ 80 | .unbatch() \ 81 | .batch(batch_size, drop_remainder=True) \ 82 | .map(lambda x, y: (tf.cast(x, tf.float32), y)) \ 83 | .prefetch(10) 84 | else: 85 | return dataset.repeat() \ 86 | .shuffle(self.num_samples) \ 87 | .batch(batch_size, drop_remainder=True) \ 88 | .map(lambda x, y: (tf.cast(x, tf.float32), y)) \ 89 | .prefetch(10) 90 | 91 | else: 92 | if self.input_length < self.traces[0].shape[1]: 93 | return dataset.batch(batch_size, drop_remainder=True) \ 94 | .map(lambda x, y: (shift(x, 0), y)) \ 95 | .map(lambda x, y: (tf.cast(x, tf.float32), y)) \ 96 | .prefetch(10) 97 | else: 98 | return dataset.batch(batch_size, drop_remainder=True) \ 99 | .map(lambda x, y: (tf.cast(x, tf.float32), y)) \ 100 | .prefetch(10) 101 | 102 | 103 | def GetDataset(self): 104 | return self.traces, self.labels 105 | 106 | 107 | if __name__ == '__main__': 108 | data_path = sys.argv[1] 109 | batch_size = int(sys.argv[2]) 110 | split = sys.argv[3] 111 | 112 | dataset = Dataset(data_path, split, 5) 113 | 114 | print("traces : "+str(dataset.traces.shape)) 115 | print("labels : "+str(dataset.labels.shape)) 116 | print("plaintext : "+str(dataset.plaintexts.shape)) 117 | print("keys : "+str(dataset.keys.shape)) 118 | print("traces ty : "+str(dataset.traces.dtype)) 119 | print("") 120 | print("") 121 | 122 | tfrecords = dataset.GetTFRecords(batch_size, training=True) 123 | iterator = iter(tfrecords) 124 | for i in range(1): 125 | tr, lbl = iterator.get_next() 126 | print(str(tr.shape)+' '+str(lbl.shape)) 127 | print(str(tr.dtype)+' '+str(lbl.dtype)) 128 | print(str(tr[:, :10])) 129 | print(str(lbl[:, :])) 130 | print("") 131 | 132 | -------------------------------------------------------------------------------- /fast_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | 6 | def shape_list(x): 7 | static = x.shape.as_list() 8 | dynamic = tf.shape(x) 9 | return [dynamic[i] if s is None else s for i, s in enumerate(static)] 10 | 11 | 12 | def gen_projection_matrix(m, d, seed=0): 13 | n_block = m // d 14 | block_list = [] 15 | cur_seed = seed 16 | for _ in range(n_block): 17 | block = tf.random.normal((d, d), seed=cur_seed) 18 | q, _ = tf.linalg.qr(block) 19 | q = tf.transpose(q) 20 | block_list.append(q) 21 | cur_seed += 1 22 | rem_rows = m - n_block * d 23 | if rem_rows > 0: 24 | block = tf.random.normal((d, d), seed=cur_seed) 25 | q, _ = tf.linalg.qr(block) 26 | q = tf.transpose(q) 27 | block_list.append(q[0:rem_rows]) 28 | proj_matrix = tf.experimental.numpy.vstack(block_list) 29 | cur_seed += 1 30 | 31 | multiplier = tf.norm(tf.random.normal((m, d), seed=cur_seed), axis=1) 32 | 33 | return tf.linalg.matmul(tf.linalg.diag(multiplier), proj_matrix) 34 | 35 | 
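# Illustrative helper (hedged sketch, not part of the original code): the matrix
# returned by gen_projection_matrix(m, d) stacks blocks of d orthogonal rows (obtained
# from QR decompositions of random Gaussian blocks) and rescales each row so its norm
# is distributed like the norm of a d-dimensional Gaussian vector. The hypothetical
# function below only makes those shapes and the per-block orthogonality explicit.
def _inspect_projection_matrix(m=8, d=4, seed=0):
    proj = gen_projection_matrix(m, d, seed=seed)                   # shape [m, d]
    first_block = proj[:d] / tf.norm(proj[:d], axis=1, keepdims=True)
    gram = tf.matmul(first_block, first_block, transpose_b=True)    # ~ identity matrix
    return proj.shape, gram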
36 | def positive_kernel_transformation(data, 37 | is_query, 38 | projection_matrix=None, 39 | numerical_stabilizer=0.000001): 40 | data_normalizer = 1.0 / (tf.dtypes.cast(data.shape[-1], tf.float32) ** 0.25) 41 | data= data_normalizer * data 42 | ratio = 1.0 / (tf.dtypes.cast(projection_matrix.shape[0], tf.float32) ** 0.5) 43 | data_dash = tf.einsum("blhd,md->blhm", data, projection_matrix) 44 | diag_data = tf.square(data) 45 | diag_data = tf.reduce_sum(diag_data, axis=-1, keepdims=True) 46 | last_dims_t = (len(data_dash.shape) - 1,) 47 | attention_dims_t = (len(data_dash.shape) - 3,) 48 | if is_query: 49 | data_dash = ratio * ( 50 | tf.math.exp(data_dash - diag_data - tf.math.reduce_max( 51 | data_dash, axis=-1, keepdims=True)) + numerical_stabilizer) 52 | else: 53 | data_dash = ratio * ( 54 | tf.math.exp(data_dash - diag_data - tf.math.reduce_max( 55 | data_dash, axis=[-3, -1], keepdims=True)) + numerical_stabilizer) 56 | 57 | return data_dash 58 | 59 | 60 | def fourier_kernel_transformation(data, projection_matrix): 61 | data_normalizer = 1.0 / (tf.dtypes.cast(data.shape[-1], tf.float32) ** 0.25) 62 | data = data_normalizer * data 63 | ratio = 1.0 / (tf.cast(projection_matrix.shape[0], tf.float32) ** 0.5) 64 | data_dash = tf.einsum("blhd,md->blhm", data, projection_matrix) 65 | data_sin = ratio * tf.math.sin(data_dash) 66 | data_cos = ratio * tf.math.cos(data_dash) 67 | 68 | return tf.concat([data_sin, data_cos], axis=-1) 69 | 70 | 71 | def attention_numerator(qs, ks, vs): 72 | kvs = tf.einsum("lbhm,lbhd->bhmd", ks, vs) 73 | return tf.einsum("lbhm,bhmd->lbhd", qs, kvs) 74 | 75 | 76 | def attention_denominator(qs, ks): 77 | all_ones = tf.ones([ks.shape[0]]) 78 | ks_sum = tf.einsum("lbhm,l->bhm", ks, all_ones) 79 | return tf.einsum("lbhm,bhm->lbh", qs, ks_sum) 80 | 81 | 82 | def linear_attention(value, 83 | query_pos_ft, 84 | key_pos_ft, 85 | projection_matrix=None, 86 | feature_map_type='fourier', 87 | normalize_attn=False): 88 | if feature_map_type == 'fourier': 89 | query_prime = fourier_kernel_transformation(query_pos_ft, projection_matrix) # [B,L,H,M] 90 | key_prime = fourier_kernel_transformation(key_pos_ft, projection_matrix) # [B,L,H,M] 91 | elif feature_map_type == 'positive': 92 | query_prime = positive_kernel_transformation(query_pos_ft, True, projection_matrix) # [B,L,H,M] 93 | key_prime = positive_kernel_transformation(key_pos_ft, False, projection_matrix) # [B,L,H,M] 94 | else: 95 | assert False, "feature_type must be in ['trig', 'positive']" 96 | 97 | query_prime = tf.transpose(query_prime, [1, 0, 2, 3]) # [L,B,H,M] 98 | key_prime = tf.transpose(key_prime, [1, 0, 2, 3]) # [L,B,H,M] 99 | value = tf.transpose(value, [1, 0, 2, 3]) # [L,B,H,D] 100 | 101 | av_attention = attention_numerator(query_prime, key_prime, value) 102 | av_attention = tf.transpose(av_attention, [1, 0, 2, 3]) 103 | if normalize_attn: 104 | attention_normalizer = attention_denominator(query_prime, key_prime) 105 | attention_normalizer = tf.transpose(attention_normalizer, [1, 0, 2]) 106 | attention_normalizer = tf.expand_dims(attention_normalizer, 107 | len(attention_normalizer.shape)) 108 | av_attention = av_attention / attention_normalizer 109 | return [av_attention, query_prime, key_prime] 110 | 111 | 112 | class SelfAttention(tf.keras.layers.Layer): 113 | def __init__(self, 114 | d_model, 115 | d_head, 116 | n_head, 117 | attention_dropout, 118 | feature_map_type='fourier', 119 | normalize_attn=False, 120 | d_kernel_map=128, 121 | head_init_range=(0, 1), 122 | **kwargs): 123 | 124 | super(SelfAttention, 
self).__init__() 125 | self.d_model = d_model 126 | self.size_per_head = d_head 127 | self.n_head = n_head 128 | self.attention_dropout = attention_dropout 129 | self.d_kernel_map = d_kernel_map 130 | self.feature_map_type = feature_map_type 131 | self.normalize_attn = normalize_attn 132 | self.head_init_range = head_init_range 133 | 134 | def _glorot_initializer(fan_in, fan_out): 135 | limit = math.sqrt(6.0 / (fan_in + fan_out)) 136 | return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit) 137 | 138 | attention_initializer = _glorot_initializer(self.d_model, self.n_head*self.size_per_head) 139 | 140 | self.value_weight = self.add_weight( 141 | "value_weight", 142 | shape=(self.d_model, self.n_head, self.size_per_head), 143 | initializer=attention_initializer, 144 | dtype=tf.float32, 145 | trainable=True) 146 | self.pos_ft_weight = self.add_weight( 147 | "pos_ft_weight", 148 | shape=(self.d_model, self.n_head, self.size_per_head), 149 | initializer=attention_initializer, 150 | dtype=tf.float32, 151 | trainable=False) 152 | self.pos_ft_scale = self.add_weight( 153 | "pos_ft_scale", 154 | shape=(1, 1, self.n_head, 1), 155 | initializer=tf.keras.initializers.Constant(1), 156 | dtype=tf.float32, 157 | trainable=True) 158 | 159 | head_left = self.head_init_range[0] 160 | head_right = self.head_init_range[1] 161 | head_range = head_right - head_left 162 | head_pos = tf.range(head_left+head_range/(2.*self.n_head), head_right, head_range/self.n_head) 163 | self.pos_ft_offsets = self.add_weight( 164 | "pos_ft_offests", 165 | shape=(1, 1, self.n_head, 1), 166 | initializer=tf.keras.initializers.Constant(head_pos), 167 | dtype=tf.float32, 168 | trainable=True) 169 | 170 | output_initializer = _glorot_initializer(self.n_head*self.size_per_head, self.d_model) 171 | self.output_weight = self.add_weight( 172 | "output_weight", 173 | shape=(self.n_head*self.size_per_head, self.d_model), 174 | initializer=output_initializer, 175 | dtype=tf.float32, 176 | trainable=True) 177 | self.output_dropout = tf.keras.layers.Dropout(self.attention_dropout) 178 | 179 | seed = np.random.randint(1e8, dtype=np.int32) 180 | projection_matrix = gen_projection_matrix( 181 | self.d_kernel_map, self.size_per_head, seed=seed) 182 | initializer = tf.keras.initializers.Constant(projection_matrix) 183 | self.projection_matrix = self.add_weight( 184 | "projection_matrix", 185 | shape=projection_matrix.shape, 186 | initializer=initializer, 187 | dtype=projection_matrix.dtype, 188 | trainable=False) 189 | 190 | 191 | def call(self, 192 | source_input, 193 | pos_ft, 194 | pos_ft_slopes, 195 | training): 196 | value = tf.einsum("bnm,mhd->bnhd", source_input, self.value_weight) 197 | pos_ft_projected = tf.einsum("bnm,mhd->bnhd", pos_ft, self.pos_ft_weight) 198 | pos_ft_slopes_projected = tf.einsum("bnm,mhd->bnhd", pos_ft_slopes, self.pos_ft_weight) 199 | 200 | query_pos_ft = self.pos_ft_scale*pos_ft_projected 201 | slope_pos = self.pos_ft_scale*pos_ft_slopes_projected 202 | key_pos_ft = query_pos_ft + self.pos_ft_offsets*slope_pos 203 | 204 | attention_outputs = linear_attention(value, 205 | query_pos_ft, key_pos_ft, 206 | self.projection_matrix, 207 | self.feature_map_type, 208 | self.normalize_attn) 209 | 210 | bsz, slen = shape_list(attention_outputs[0])[:2] 211 | 212 | norms = tf.norm(pos_ft_slopes_projected, axis=-1, keepdims=True)/float(slen) 213 | attention_outputs[0] = norms*attention_outputs[0] 214 | 215 | attention_outputs[0] = tf.reshape(attention_outputs[0], [bsz, slen, -1]) 216 | attention_outputs[0] = 
tf.einsum("bnm,md->bnd", attention_outputs[0], self.output_weight) 217 | attention_outputs[0] = self.output_dropout(attention_outputs[0], training=training) 218 | 219 | return attention_outputs 220 | 221 | 222 | -------------------------------------------------------------------------------- /transformer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from fast_attention import SelfAttention 3 | from normalization import LayerScaling, LayerCentering 4 | from tensorflow.keras.layers.experimental import SyncBatchNormalization 5 | 6 | 7 | def shape_list(x): 8 | static = x.shape.as_list() 9 | dynamic = tf.shape(x) 10 | return [dynamic[i] if s is None else s for i, s in enumerate(static)] 11 | 12 | 13 | class PositionalFeature(tf.keras.layers.Layer): 14 | def __init__(self, d_feature, beta_hat_2, **kwargs): 15 | super().__init__(**kwargs) 16 | 17 | self.slopes = tf.range(d_feature, 0, -4.0, dtype=tf.float32) / d_feature 18 | self.slopes = self.slopes * beta_hat_2 19 | 20 | def call(self, slen, bsz=None): 21 | pos_seq = tf.range(0, slen, 1.0, dtype=tf.float32) 22 | normalized_slopes = (1. / float(slen-1)) * self.slopes 23 | forward = tf.einsum("i,j->ij", pos_seq, normalized_slopes) 24 | backward = tf.reverse(forward, axis=[0]) 25 | neg_forward = -tf.identity(forward) 26 | neg_backward = -tf.identity(backward) 27 | pos_feature = tf.concat([forward, backward, neg_forward, neg_backward], -1) 28 | 29 | pos_feature_slopes = tf.concat( 30 | [tf.identity(normalized_slopes), 31 | -tf.identity(normalized_slopes), 32 | -tf.identity(normalized_slopes), 33 | tf.identity(normalized_slopes)], axis=0) 34 | pos_feature_slopes = float(slen-1)*tf.reshape(pos_feature_slopes, [1, -1]) 35 | 36 | if bsz is not None: 37 | pos_feature = tf.tile(pos_feature[None, :, :], [bsz, 1, 1]) 38 | pos_feature_slopes = tf.tile(pos_feature_slopes[None, :, :], [bsz, 1, 1]) 39 | else: 40 | pos_feature = pos_feature[None, :, :] 41 | pos_feature_slopes = pos_feature_slopes[None, :, :] 42 | return pos_feature, pos_feature_slopes 43 | 44 | 45 | class PositionwiseFF(tf.keras.layers.Layer): 46 | def __init__(self, d_model, d_inner, dropout, **kwargs): 47 | super().__init__(**kwargs) 48 | 49 | self.d_model = d_model 50 | self.d_inner = d_inner 51 | self.dropout = dropout 52 | 53 | self.layer_1 = tf.keras.layers.Dense( 54 | d_inner, activation=tf.nn.relu, name='layer_1' 55 | ) 56 | self.drop_1 = tf.keras.layers.Dropout(dropout, name='drop_1') 57 | self.layer_2 = tf.keras.layers.Dense(d_model, name='layer_2') 58 | self.drop_2 = tf.keras.layers.Dropout(dropout, name='drop_2') 59 | 60 | 61 | def call(self, inp, training=False): 62 | core_out = inp 63 | core_out = self.layer_1(core_out) 64 | core_out = self.drop_1(core_out, training=training) 65 | core_out = self.layer_2(core_out) 66 | core_out = self.drop_2(core_out, training=training) 67 | 68 | output = [core_out] 69 | return output 70 | 71 | 72 | class TransformerLayer(tf.keras.layers.Layer): 73 | def __init__( 74 | self, 75 | n_head, 76 | d_head, 77 | d_model, 78 | d_inner, 79 | dropout, 80 | feature_map_type, 81 | normalize_attn, 82 | d_kernel_map, 83 | model_normalization, 84 | head_init_range, 85 | **kwargs 86 | ): 87 | super().__init__(**kwargs) 88 | 89 | self.n_head = n_head 90 | self.d_head = d_head 91 | self.d_model = d_model 92 | self.d_inner = d_inner 93 | self.dropout = dropout 94 | self.feature_map_type = feature_map_type 95 | self.normalize_attn = normalize_attn 96 | self.d_kernel_map = d_kernel_map 97 | 
self.model_normalization = model_normalization 98 | self.head_init_range = head_init_range 99 | 100 | self.self_attn = SelfAttention( 101 | d_model=self.d_model, 102 | d_head=self.d_head, 103 | n_head=self.n_head, 104 | attention_dropout=self.dropout, 105 | feature_map_type=self.feature_map_type, 106 | normalize_attn=self.normalize_attn, 107 | d_kernel_map=self.d_kernel_map, 108 | head_init_range = self.head_init_range, 109 | name="tran_attn", 110 | ) 111 | self.pos_ff = PositionwiseFF( 112 | d_model=self.d_model, 113 | d_inner=self.d_inner, 114 | dropout=self.dropout, 115 | name="pos_ff", 116 | ) 117 | 118 | assert self.model_normalization in ['preLC', 'postLC', 'none'], "model_normalization must be one of 'preLC', 'postLC' or 'none'" 119 | if self.model_normalization in ['preLC', 'postLC']: 120 | self.lc1 = LayerCentering() 121 | self.lc2 = LayerCentering() 122 | 123 | 124 | def call(self, inputs, training=False): 125 | inp, pos_ft, pos_ft_slopes = inputs 126 | if self.model_normalization == 'preLC': 127 | attn_in = self.lc1(inp) 128 | else: 129 | attn_in = inp 130 | attn_outputs = self.self_attn(attn_in, pos_ft, pos_ft_slopes, 131 | training=training) 132 | attn_outputs[0] = attn_outputs[0] + inp 133 | if self.model_normalization == 'postLC': 134 | attn_outputs[0] = self.lc1(attn_outputs[0]) 135 | 136 | if self.model_normalization == 'preLC': 137 | ff_in = self.lc2(attn_outputs[0]) 138 | else: 139 | ff_in = attn_outputs[0] 140 | ff_output = self.pos_ff(ff_in, training=training) 141 | ff_output[0] = ff_output[0] + attn_outputs[0] 142 | if self.model_normalization == 'postLC': 143 | ff_output[0] = self.lc2(ff_output[0]) 144 | 145 | outputs = [ff_output[0]] + attn_outputs[1:] 146 | 147 | return outputs 148 | 149 | 150 | class SoftmaxAttention(tf.keras.layers.Layer): 151 | def __init__(self, d_model, n_head, d_head, **kwargs): 152 | super().__init__(**kwargs) 153 | 154 | self.d_model = d_model 155 | self.n_head = n_head 156 | self.d_head = d_head 157 | 158 | self.q_heads = self.add_weight( 159 | shape=(self.d_head, self.n_head), name="q_heads" 160 | ) 161 | self.k_net = tf.keras.layers.Dense( 162 | self.d_head * self.n_head, name="k_net" 163 | ) 164 | self.v_net = tf.keras.layers.Dense( 165 | self.d_head * self.n_head, name="v_net" 166 | ) 167 | 168 | self.scale = 1. 
/ (self.d_head ** 0.5) 169 | 170 | 171 | def build(self, input_shape): 172 | self.softmax_attn_smoothing = self.add_weight( 173 | "softmax_attn_smoothing", 174 | shape=(), 175 | initializer=tf.keras.initializers.Constant(0), 176 | dtype=tf.float32, 177 | trainable=False) 178 | 179 | 180 | def call(self, inp, softmax_attn_smoothing, training=False): 181 | bsz, slen = inp.shape[:2] 182 | if training: 183 | self.softmax_attn_smoothing.assign(softmax_attn_smoothing) 184 | 185 | k_head = self.k_net(inp) 186 | v_head = self.v_net(inp) 187 | 188 | k_head = tf.reshape(k_head, [-1, slen, self.d_head, self.n_head]) 189 | v_head = tf.reshape(v_head, [-1, slen, self.d_head, self.n_head]) 190 | 191 | attn_score = tf.einsum("bndh,dh->bnh", k_head, self.q_heads) 192 | attn_score = attn_score * self.scale * self.softmax_attn_smoothing 193 | 194 | attn_prob = tf.nn.softmax(attn_score, axis=1) 195 | 196 | attn_out = tf.einsum("bndh,bnh->bnhd", v_head, attn_prob) 197 | attn_out = tf.reshape(attn_out, [bsz, slen, -1]) 198 | 199 | return attn_out, attn_score 200 | 201 | 202 | class Transformer(tf.keras.Model): 203 | def __init__(self, n_layer, d_model, d_head, n_head, d_inner, 204 | d_head_softmax, n_head_softmax, dropout, n_classes, 205 | conv_kernel_size, n_conv_layer, pool_size, d_kernel_map, beta_hat_2, 206 | model_normalization, head_initialization='forward', 207 | softmax_attn=True, output_attn=False): 208 | 209 | super(Transformer, self).__init__() 210 | 211 | self.n_layer = n_layer 212 | self.d_model = d_model 213 | self.d_head = d_head 214 | self.n_head = n_head 215 | self.d_inner = d_inner 216 | self.d_head_softmax = d_head_softmax 217 | self.n_head_softmax = n_head_softmax 218 | self.feature_map_type = 'fourier' 219 | self.normalize_attn = False 220 | self.d_kernel_map = d_kernel_map 221 | self.beta_hat_2 = beta_hat_2 222 | self.model_normalization = model_normalization 223 | self.head_initialization = head_initialization 224 | self.softmax_attn = softmax_attn 225 | 226 | self.dropout = dropout 227 | 228 | self.n_classes = n_classes 229 | 230 | self.conv_kernel_size = conv_kernel_size 231 | self.n_conv_layer = n_conv_layer 232 | self.pool_size = pool_size 233 | 234 | self.output_attn = output_attn 235 | 236 | conv_filters = [min(8*2**i, self.d_model) for i in range(self.n_conv_layer-1)] + [self.d_model] 237 | 238 | self.conv_layers = [] 239 | self.norm_layers = [] 240 | self.relu_layers = [] 241 | self.pool_layers = [] 242 | 243 | for l in range(self.n_conv_layer): 244 | ks = 11 if l is 0 else self.conv_kernel_size 245 | self.conv_layers.append(tf.keras.layers.Conv1D(conv_filters[l], ks)) 246 | self.relu_layers.append(tf.keras.layers.ReLU()) 247 | self.pool_layers.append(tf.keras.layers.AveragePooling1D(self.pool_size, self.pool_size)) 248 | 249 | self.pos_feature = PositionalFeature(self.d_model, self.beta_hat_2) 250 | 251 | head_init_ranges = [] 252 | if self.head_initialization == 'forward': 253 | for i in range(self.n_layer): 254 | if i == 0: 255 | head_init_ranges.append((0., 0.5)) 256 | else: 257 | head_init_ranges.append((0., 1.0)) 258 | elif self.head_initialization == 'backward': 259 | for i in range(self.n_layer): 260 | if i == 0: 261 | head_init_ranges.append((-0.5, 0.0)) 262 | else: 263 | head_init_ranges.append((-1.0, 0.0)) 264 | elif self.head_initialization == 'symmetric': 265 | for i in range(self.n_layer): 266 | if i == 0: 267 | head_init_ranges.append((-0.5, 0.5)) 268 | else: 269 | head_init_ranges.append((-1.0, 1.0)) 270 | else: 271 | assert False, "head_initialization can be one 
of ['forward', 'backward', 'symmetric']" 272 | 273 | self.tran_layers = [] 274 | for i in range(self.n_layer): 275 | self.tran_layers.append( 276 | TransformerLayer( 277 | n_head=self.n_head, 278 | d_head=self.d_head, 279 | d_model=self.d_model, 280 | d_inner=self.d_inner, 281 | dropout=self.dropout, 282 | feature_map_type=self.feature_map_type, 283 | normalize_attn=self.normalize_attn, 284 | d_kernel_map=self.d_kernel_map, 285 | model_normalization=self.model_normalization, 286 | head_init_range = head_init_ranges[i], 287 | name='layers_._{}'.format(i) 288 | ) 289 | ) 290 | 291 | self.out_dropout = tf.keras.layers.Dropout(dropout, name='out_drop') 292 | 293 | if self.softmax_attn: 294 | self.out_attn = SoftmaxAttention(d_model=self.d_model, n_head=self.n_head_softmax, 295 | d_head=self.d_head_softmax) 296 | self.fc_output = tf.keras.layers.Dense(self.n_classes) 297 | 298 | def call(self, inp, softmax_attn_smoothing=1, training=False): 299 | # convert the input dimension from [bsz, len] to [bsz, len, 1] 300 | inp = tf.expand_dims(inp, axis=-1) 301 | 302 | # apply the convolution blocks 303 | for l in range(self.n_conv_layer): 304 | inp = self.conv_layers[l](inp) 305 | inp = self.relu_layers[l](inp) 306 | inp = self.pool_layers[l](inp) 307 | 308 | bsz, slen = shape_list(inp)[:2] 309 | 310 | pos_ft, pos_ft_slopes = self.pos_feature(slen, bsz) 311 | 312 | core_out = inp 313 | out_list = [] 314 | for i, layer in enumerate(self.tran_layers): 315 | all_out = layer([core_out, pos_ft, pos_ft_slopes], training=training) 316 | core_out = all_out[0] 317 | out_list.append(all_out[1:]) 318 | core_out = self.out_dropout(core_out, training=training) 319 | 320 | # take the evarage across the first (len) dimension to get the final representation 321 | if self.softmax_attn: 322 | core_out, softmax_attn_score = self.out_attn(core_out, softmax_attn_smoothing, training=training) 323 | else: 324 | softmax_attn_score = None 325 | output = tf.reduce_mean(core_out, axis=1) 326 | 327 | # ge the final scores for all classes 328 | scores = self.fc_output(output) 329 | 330 | for i in range(len(out_list)): 331 | for j in range(len(out_list[i])): 332 | out_list[i][j] = tf.transpose(out_list[i][j], [1, 0, 2, 3]) 333 | 334 | if self.output_attn: 335 | return [scores, out_list, softmax_attn_score] 336 | else: 337 | return [scores] 338 | 339 | 340 | -------------------------------------------------------------------------------- /train_trans.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import math 7 | import time 8 | import random 9 | import pickle 10 | 11 | from absl import flags 12 | import absl.logging as _logging # pylint: disable=unused-import 13 | 14 | import tensorflow as tf 15 | import data_utils 16 | import data_utils_ches20 17 | from transformer import Transformer 18 | import evaluation_utils 19 | import evaluation_utils_ches20 20 | import pickle 21 | 22 | import numpy as np 23 | 24 | # GPU config 25 | flags.DEFINE_bool("use_tpu", default=False, 26 | help="Use TPUs rather than plain CPUs.") 27 | 28 | # Experiment (data/checkpoint/directory) config 29 | flags.DEFINE_string("data_path", default="", 30 | help="Path to data file") 31 | flags.DEFINE_string("dataset", default="ASCAD", 32 | help="Name of the dataset (ASCAD, CHES20, AES_RD, DPAv42)") 33 | flags.DEFINE_string("checkpoint_dir", default=None, 34 | help="directory for saving 
checkpoint.") 35 | flags.DEFINE_integer("checkpoint_idx", default=0, 36 | help="checkpoints index to restore.") 37 | flags.DEFINE_bool("warm_start", default=False, 38 | help="Whether to warm start training from checkpoint.") 39 | flags.DEFINE_string("result_path", default="", 40 | help="Path for eval results") 41 | flags.DEFINE_bool("do_train", default=False, 42 | help="Whether to perform training or evaluation") 43 | 44 | # Optimization config 45 | flags.DEFINE_float("learning_rate", default=2.5e-4, 46 | help="Maximum learning rate.") 47 | flags.DEFINE_float("clip", default=0.25, 48 | help="Gradient clipping value.") 49 | # for cosine decay 50 | flags.DEFINE_float("min_lr_ratio", default=0.004, 51 | help="Minimum ratio learning rate.") 52 | flags.DEFINE_integer("warmup_steps", default=0, 53 | help="Number of steps for linear lr warmup.") 54 | flags.DEFINE_integer("input_length", default=700, 55 | help="The input length for TN model") 56 | flags.DEFINE_integer("data_desync", default=0, 57 | help="Max trace desync for data augmentation") 58 | 59 | # Training config 60 | flags.DEFINE_integer("train_batch_size", default=256, 61 | help="Size of train batch.") 62 | flags.DEFINE_integer("eval_batch_size", default=32, 63 | help="Size of valid batch.") 64 | flags.DEFINE_integer("train_steps", default=100000, 65 | help="Total number of training steps.") 66 | flags.DEFINE_integer("iterations", default=500, 67 | help="Number of iterations per repeat loop.") 68 | flags.DEFINE_integer("save_steps", default=10000, 69 | help="number of steps for model checkpointing.") 70 | 71 | # Model config 72 | flags.DEFINE_integer("n_layer", default=6, 73 | help="Number of layers.") 74 | flags.DEFINE_integer("d_model", default=128, 75 | help="Dimension of the model (d).") 76 | flags.DEFINE_integer("d_head", default=32, 77 | help="Dimension of each head (d_v).") 78 | flags.DEFINE_integer("n_head", default=4, 79 | help="Number of attention heads (H).") 80 | flags.DEFINE_integer("d_inner", default=256, 81 | help="Dimension of inner hidden size in positionwise feed-forward.") 82 | flags.DEFINE_integer("n_head_softmax", default=4, 83 | help="Number of attention heads in softmax attention") 84 | flags.DEFINE_integer("d_head_softmax", default=32, 85 | help="Dimension of each head in softmax attention") 86 | flags.DEFINE_integer("d_kernel_map", default=128, 87 | help="Dimension of the kernel feature map (d_e).") 88 | flags.DEFINE_integer("beta_hat_2", default=100, 89 | help="Distance based scaling in the kernel of self-attention") 90 | flags.DEFINE_float("dropout", default=0.1, 91 | help="Dropout rate.") 92 | flags.DEFINE_integer("conv_kernel_size", default=3, 93 | help="Kernel size of all but the first convolution layers") 94 | flags.DEFINE_integer("n_conv_layer", default=1, 95 | help="Number of convolutional blocks") 96 | flags.DEFINE_integer("pool_size", default=2, 97 | help="Pooling size of the average pooling layers") 98 | flags.DEFINE_string("model_normalization", default='preLC', 99 | help="Normalization type used to normalize layer, can be in ['preLC', 'postLC', 'none']") 100 | flags.DEFINE_string("head_initialization", default='forward', 101 | help="Type of the initialization of the positional attention heads, can be in ['forward', 'backward', 'symmetric']") 102 | flags.DEFINE_bool("softmax_attn", default='True', 103 | help="Whether to use softmax attention instead of global pooling") 104 | 105 | # Evaluation config 106 | flags.DEFINE_integer("max_eval_batch", default=-1, 107 | help="Set -1 to turn off.") 108 | 
flags.DEFINE_bool("output_attn", default=False, 109 | help="output attention probabilities") 110 | 111 | 112 | FLAGS = flags.FLAGS 113 | 114 | 115 | class LRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): 116 | def __init__(self, max_lr, tr_steps, wu_steps=0, min_lr_ratio=0.0): 117 | self.max_lr=max_lr 118 | self.tr_steps=tr_steps 119 | self.wu_steps=wu_steps 120 | self.min_lr_ratio=min_lr_ratio 121 | def __call__(self, step): 122 | step_float = tf.cast(step, tf.float32) 123 | wu_steps_float = tf.cast(self.wu_steps, tf.float32) 124 | tr_steps_float = tf.cast(self.tr_steps, tf.float32) 125 | max_lr_float =tf.cast(self.max_lr, tf.float32) 126 | min_lr_ratio_float = tf.cast(self.min_lr_ratio, tf.float32) 127 | 128 | # warmup learning rate using linear schedule 129 | wu_lr = (step_float/wu_steps_float) * max_lr_float 130 | 131 | # decay the learning rate using the cosine schedule 132 | global_step = tf.math.minimum(step_float-wu_steps_float, tr_steps_float-wu_steps_float) 133 | decay_steps = tr_steps_float-wu_steps_float 134 | pi = tf.constant(math.pi) 135 | cosine_decay = .5 * (1. + tf.math.cos(pi * global_step / decay_steps)) 136 | decayed = (1.-min_lr_ratio_float) * cosine_decay + min_lr_ratio_float 137 | decay_lr = max_lr_float * decayed 138 | return tf.cond(step < self.wu_steps, lambda: wu_lr, lambda: decay_lr) 139 | 140 | 141 | def create_model(n_classes): 142 | model = Transformer( 143 | n_layer = FLAGS.n_layer, 144 | d_model = FLAGS.d_model, 145 | d_head = FLAGS.d_head, 146 | n_head = FLAGS.n_head, 147 | d_inner = FLAGS.d_inner, 148 | n_head_softmax = FLAGS.n_head_softmax, 149 | d_head_softmax = FLAGS.d_head_softmax, 150 | dropout = FLAGS.dropout, 151 | n_classes = n_classes, 152 | conv_kernel_size = FLAGS.conv_kernel_size, 153 | n_conv_layer = FLAGS.n_conv_layer, 154 | pool_size = FLAGS.pool_size, 155 | d_kernel_map = FLAGS.d_kernel_map, 156 | beta_hat_2 = FLAGS.beta_hat_2, 157 | model_normalization = FLAGS.model_normalization, 158 | head_initialization = FLAGS.head_initialization, 159 | softmax_attn = FLAGS.softmax_attn, 160 | output_attn = FLAGS.output_attn 161 | ) 162 | 163 | return model 164 | 165 | 166 | def train(train_dataset, eval_dataset, num_train_batch, num_eval_batch, strategy, chk_name): 167 | # Ensure that the batch sizes are divisible by number of replicas in sync 168 | assert(FLAGS.train_batch_size % strategy.num_replicas_in_sync == 0) 169 | assert(FLAGS.eval_batch_size % strategy.num_replicas_in_sync == 0) 170 | 171 | ##### Create computational graph for train dataset 172 | train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset) 173 | ##### Create computational graph for eval dataset 174 | eval_dist_dataset = strategy.experimental_distribute_dataset(eval_dataset) 175 | 176 | if FLAGS.save_steps <= 0: 177 | FLAGS.save_steps = None 178 | else: 179 | # Set the FLAGS.save_steps to a value multiple of FLAGS.iterations 180 | if FLAGS.save_steps < FLAGS.iterations: 181 | FLAGS.save_steps = FLAGS.iterations 182 | else: 183 | FLAGS.save_steps = (FLAGS.save_steps // FLAGS.iterations) * \ 184 | FLAGS.iterations 185 | ##### Instantiate learning rate scheduler object 186 | lr_sch = LRSchedule( 187 | FLAGS.learning_rate, FLAGS.train_steps, \ 188 | FLAGS.warmup_steps, FLAGS.min_lr_ratio 189 | ) 190 | 191 | loss_dic_file = os.path.join(FLAGS.checkpoint_dir, 'loss.pkl') 192 | 193 | ##### Create computational graph for model 194 | with strategy.scope(): 195 | if FLAGS.dataset == 'CHES20': 196 | model = create_model(4) 197 | else: 198 | model = 
create_model(256) 199 | optimizer = tf.keras.optimizers.Adam(learning_rate=lr_sch) 200 | checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 201 | 202 | train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32) 203 | eval_loss = tf.keras.metrics.Mean('eval_loss', dtype=tf.float32) 204 | grad_norm = tf.keras.metrics.Mean('grad_norms', dtype=tf.float32) 205 | 206 | new_start = True 207 | if FLAGS.warm_start: 208 | options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") 209 | chk_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) 210 | if chk_path is None: 211 | tf.compat.v1.logging.info("Could not find any checkpoint, starting training from beginning") 212 | else: 213 | tf.compat.v1.logging.info("Found checkpoint: {}".format(chk_path)) 214 | try: 215 | checkpoint.restore(chk_path, options=options) 216 | tf.compat.v1.logging.info("Restored checkpoint: {}".format(chk_path)) 217 | new_start = False 218 | except: 219 | tf.compat.v1.logging.info("Could not restore checkpoint, starting training from beginning") 220 | 221 | if new_start == True: 222 | # Save the initial model 223 | chk_path = os.path.join(FLAGS.checkpoint_dir, chk_name) 224 | options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") 225 | save_path = checkpoint.save(chk_path, options=options) 226 | tf.compat.v1.logging.info("Model saved in path: {}".format(save_path)) 227 | 228 | loss_dic = {} 229 | pickle.dump(loss_dic, open(loss_dic_file, 'wb')) 230 | else: 231 | loss_dic = pickle.load(open(loss_dic_file, 'rb')) 232 | 233 | @tf.function 234 | def train_steps(iterator, steps, bsz, global_step): 235 | ###### Reset the states of the update variables 236 | train_loss.reset_states() 237 | grad_norm.reset_states() 238 | ###### The step function for one training step 239 | def step_fn(inps, lbls, global_step): 240 | lbls = tf.squeeze(lbls) 241 | with tf.GradientTape() as tape: 242 | softmax_attn_smoothing = 1. #tf.minimum(float(global_step)/FLAGS.train_steps, 1.) 
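# Note (added annotation, not part of the original source): softmax_attn_smoothing is forwarded to the
# softmax-attention pooling layer inside Transformer.call. It is fixed at 1.0 here; the commented-out
# expression above would instead ramp it linearly from 0 to 1 over the first FLAGS.train_steps steps.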
243 | logits = model(inps, softmax_attn_smoothing, training=True)[0] 244 | if FLAGS.dataset == 'CHES20': 245 | per_example_loss = tf.reduce_mean( 246 | tf.nn.sigmoid_cross_entropy_with_logits(lbls, logits), 247 | axis = 1 248 | ) 249 | else: 250 | per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(lbls, logits) 251 | avg_loss = tf.nn.compute_average_loss(per_example_loss, \ 252 | global_batch_size=bsz) 253 | variables = tape.watched_variables() 254 | gradients = tape.gradient(avg_loss, variables) 255 | clipped, gnorm = tf.clip_by_global_norm(gradients, FLAGS.clip) 256 | optimizer.apply_gradients(list(zip(clipped, variables))) 257 | train_loss.update_state(avg_loss * strategy.num_replicas_in_sync) 258 | grad_norm.update_state(gnorm) 259 | for _ in range(steps): 260 | global_step += 1 261 | inps, lbls = next(iterator) 262 | strategy.run(step_fn, args=(inps, lbls, global_step)) 263 | 264 | @tf.function 265 | def eval_steps(iterator, steps, bsz): 266 | ###### The step function for one evaluation step 267 | def step_fn(inps, lbls): 268 | lbls = tf.squeeze(lbls) 269 | logits = model(inps)[0] 270 | if FLAGS.dataset == 'CHES20': 271 | per_example_loss = tf.reduce_mean( 272 | tf.nn.sigmoid_cross_entropy_with_logits(lbls, logits), 273 | axis = 1 274 | ) 275 | else: 276 | per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(lbls, logits) 277 | avg_loss = tf.nn.compute_average_loss(per_example_loss, \ 278 | global_batch_size=bsz) 279 | eval_loss.update_state(avg_loss * strategy.num_replicas_in_sync) 280 | for _ in range(steps): 281 | inps, lbls = next(iterator) 282 | strategy.run(step_fn, args=(inps, lbls,)) 283 | 284 | tf.compat.v1.logging.info('Starting training ... ') 285 | train_iter = iter(train_dist_dataset) 286 | 287 | cur_step = optimizer.iterations.numpy() 288 | while cur_step < FLAGS.train_steps: 289 | train_steps(train_iter, tf.convert_to_tensor(FLAGS.iterations), \ 290 | FLAGS.train_batch_size, cur_step) 291 | 292 | cur_step = optimizer.iterations.numpy() 293 | cur_loss = train_loss.result() 294 | gnorm = grad_norm.result() 295 | lr_rate = lr_sch(cur_step) 296 | dic = {} 297 | 298 | tf.compat.v1.logging.info("[{:6d}] | gnorm {:5.2f} lr {:9.6f} " 299 | "| loss {:>5.2f}".format(cur_step, gnorm, lr_rate, cur_loss)) 300 | dic['gnorm'] = gnorm.numpy() 301 | dic['running_train_loss'] = cur_loss.numpy() 302 | 303 | if FLAGS.max_eval_batch <= 0: 304 | num_eval_iters = num_eval_batch 305 | else: 306 | num_eval_iters = min(FLAGS.max_eval_batch, num_eval_batch) 307 | 308 | eval_tr_iter = iter(train_dist_dataset) 309 | eval_loss.reset_states() 310 | eval_steps(eval_tr_iter, tf.convert_to_tensor(num_eval_iters), \ 311 | FLAGS.train_batch_size) 312 | 313 | cur_eval_loss = eval_loss.result() 314 | tf.compat.v1.logging.info("Train batches[{:5d}] |" 315 | " loss {:>5.2f}".format(num_eval_iters, cur_eval_loss)) 316 | dic['train_loss'] = cur_eval_loss.numpy() 317 | 318 | eval_va_iter = iter(eval_dist_dataset) 319 | eval_loss.reset_states() 320 | eval_steps(eval_va_iter, tf.convert_to_tensor(num_eval_iters), \ 321 | FLAGS.eval_batch_size) 322 | 323 | cur_eval_loss = eval_loss.result() 324 | tf.compat.v1.logging.info("Eval batches[{:5d}] |" 325 | " loss {:>5.2f}".format(num_eval_iters, cur_eval_loss)) 326 | dic['test_loss'] = cur_eval_loss.numpy() 327 | 328 | loss_dic[cur_step] = dic 329 | 330 | if FLAGS.save_steps is not None and (cur_step) % FLAGS.save_steps == 0: 331 | chk_path = os.path.join(FLAGS.checkpoint_dir, chk_name) 332 | options = 
tf.train.CheckpointOptions(experimental_io_device="/job:localhost") 333 | save_path = checkpoint.save(chk_path, options=options) 334 | tf.compat.v1.logging.info("Model saved in path: {}".format(save_path)) 335 | pickle.dump(loss_dic, open(loss_dic_file, 'wb')) 336 | 337 | if FLAGS.save_steps is not None and (cur_step) % FLAGS.save_steps != 0: 338 | chk_path = os.path.join(FLAGS.checkpoint_dir, chk_name) 339 | options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") 340 | save_path = checkpoint.save(chk_path, options=options) 341 | tf.compat.v1.logging.info("Model saved in path: {}".format(save_path)) 342 | pickle.dump(loss_dic, open(loss_dic_file, 'wb')) 343 | 344 | 345 | def evaluate(data, strategy, chk_name): 346 | # Ensure that the batch size is divisible by number of replicas in sync 347 | assert(FLAGS.eval_batch_size % strategy.num_replicas_in_sync == 0) 348 | 349 | ##### Create computational graph for model 350 | with strategy.scope(): 351 | if FLAGS.dataset == 'CHES20': 352 | model = create_model(4) 353 | else: 354 | model = create_model(256) 355 | optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate) 356 | checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 357 | 358 | options = tf.train.CheckpointOptions(experimental_io_device="/job:localhost") 359 | if FLAGS.checkpoint_idx <= 0: 360 | chk_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) 361 | if chk_path is None: 362 | tf.compat.v1.logging.info("Could not find any checkpoint") 363 | return None 364 | else: 365 | chk_path = os.path.join(FLAGS.checkpoint_dir, '%s-%s'%(chk_name, FLAGS.checkpoint_idx)) 366 | tf.compat.v1.logging.info("Restoring checkpoint: {}".format(chk_path)) 367 | try: 368 | checkpoint.read(chk_path, options=options).expect_partial() 369 | tf.compat.v1.logging.info("Restored checkpoint: {}".format(chk_path)) 370 | except: 371 | tf.compat.v1.logging.info("Could not restore checkpoint") 372 | return None 373 | 374 | if FLAGS.output_attn: 375 | output = model.predict(data, steps=FLAGS.max_eval_batch) 376 | else: 377 | output = model.predict(data) 378 | return output 379 | 380 | 381 | def print_hyperparams(): 382 | tf.compat.v1.logging.info("") 383 | tf.compat.v1.logging.info("") 384 | tf.compat.v1.logging.info("use_tpu : %s" % (FLAGS.use_tpu)) 385 | tf.compat.v1.logging.info("data_path : %s" % (FLAGS.data_path)) 386 | tf.compat.v1.logging.info("dataset : %s" % (FLAGS.dataset)) 387 | tf.compat.v1.logging.info("checkpoint_dir : %s" % (FLAGS.checkpoint_dir)) 388 | tf.compat.v1.logging.info("checkpoint_idx : %s" % (FLAGS.checkpoint_idx)) 389 | tf.compat.v1.logging.info("warm_start : %s" % (FLAGS.warm_start)) 390 | tf.compat.v1.logging.info("result_path : %s" % (FLAGS.result_path)) 391 | tf.compat.v1.logging.info("do_train : %s" % (FLAGS.do_train)) 392 | tf.compat.v1.logging.info("learning_rate : %s" % (FLAGS.learning_rate)) 393 | tf.compat.v1.logging.info("clip : %s" % (FLAGS.clip)) 394 | tf.compat.v1.logging.info("min_lr_ratio : %s" % (FLAGS.min_lr_ratio)) 395 | tf.compat.v1.logging.info("warmup_steps : %s" % (FLAGS.warmup_steps)) 396 | tf.compat.v1.logging.info("input_length : %s" % (FLAGS.input_length)) 397 | tf.compat.v1.logging.info("data_desync : %s" % (FLAGS.data_desync)) 398 | tf.compat.v1.logging.info("train_batch_size : %s" % (FLAGS.train_batch_size)) 399 | tf.compat.v1.logging.info("eval_batch_size : %s" % (FLAGS.eval_batch_size)) 400 | tf.compat.v1.logging.info("train_steps : %s" % (FLAGS.train_steps)) 401 | tf.compat.v1.logging.info("iterations : 
%s" % (FLAGS.iterations)) 402 | tf.compat.v1.logging.info("save_steps : %s" % (FLAGS.save_steps)) 403 | tf.compat.v1.logging.info("n_layer : %s" % (FLAGS.n_layer)) 404 | tf.compat.v1.logging.info("d_model : %s" % (FLAGS.d_model)) 405 | tf.compat.v1.logging.info("d_head : %s" % (FLAGS.d_head)) 406 | tf.compat.v1.logging.info("n_head : %s" % (FLAGS.n_head)) 407 | tf.compat.v1.logging.info("d_inner : %s" % (FLAGS.d_inner)) 408 | tf.compat.v1.logging.info("n_head_softmax : %s" % (FLAGS.n_head_softmax)) 409 | tf.compat.v1.logging.info("d_head_softmax : %s" % (FLAGS.d_head_softmax)) 410 | tf.compat.v1.logging.info("dropout : %s" % (FLAGS.dropout)) 411 | tf.compat.v1.logging.info("conv_kernel_size : %s" % (FLAGS.conv_kernel_size)) 412 | tf.compat.v1.logging.info("n_conv_layer : %s" % (FLAGS.n_conv_layer)) 413 | tf.compat.v1.logging.info("pool_size : %s" % (FLAGS.pool_size)) 414 | tf.compat.v1.logging.info("d_kernel_map : %s" % (FLAGS.d_kernel_map)) 415 | tf.compat.v1.logging.info("beta_hat_2 : %s" % (FLAGS.beta_hat_2)) 416 | tf.compat.v1.logging.info("model_normalization : %s" % (FLAGS.model_normalization)) 417 | tf.compat.v1.logging.info("head_initialization : %s" % (FLAGS.head_initialization)) 418 | tf.compat.v1.logging.info("softmax_attn : %s" % (FLAGS.softmax_attn)) 419 | tf.compat.v1.logging.info("max_eval_batch : %s" % (FLAGS.max_eval_batch)) 420 | tf.compat.v1.logging.info("output_attn : %s" % (FLAGS.output_attn)) 421 | tf.compat.v1.logging.info("") 422 | tf.compat.v1.logging.info("") 423 | 424 | 425 | 426 | def main(unused_argv): 427 | del unused_argv # Unused 428 | 429 | print_hyperparams() 430 | 431 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) 432 | 433 | if FLAGS.dataset == 'ASCAD': 434 | train_data = data_utils.Dataset(data_path=FLAGS.data_path, split="train", 435 | input_length=FLAGS.input_length, data_desync=FLAGS.data_desync) 436 | test_data = data_utils.Dataset(data_path=FLAGS.data_path, split="test", 437 | input_length=FLAGS.input_length, data_desync=FLAGS.data_desync) 438 | 439 | elif FLAGS.dataset == 'CHES20': 440 | if FLAGS.do_train: 441 | data_path = FLAGS.data_path + '.npz' 442 | train_data = data_utils_ches20.Dataset(data_path=data_path, split="train", 443 | input_length=FLAGS.input_length, data_desync=FLAGS.data_desync) 444 | data_path = FLAGS.data_path + '_valid.npz' 445 | test_data = data_utils_ches20.Dataset(data_path=data_path, split="test", 446 | input_length=FLAGS.input_length, data_desync=FLAGS.data_desync) 447 | else: 448 | data_path = FLAGS.data_path + '.npz' 449 | test_data = data_utils_ches20.Dataset(data_path=data_path, split="test", 450 | input_length=FLAGS.input_length, data_desync=FLAGS.data_desync) 451 | 452 | else: 453 | assert False 454 | 455 | if FLAGS.use_tpu: 456 | resolver = tf.distribute.cluster_resolver.TPUClusterResolver() 457 | tf.config.experimental_connect_to_cluster(resolver) 458 | tf.tpu.experimental.initialize_tpu_system(resolver) 459 | strategy = tf.distribute.experimental.TPUStrategy(resolver) 460 | else: 461 | strategy = tf.distribute.get_strategy() 462 | tf.compat.v1.logging.info("Number of accelerators: %s" % strategy.num_replicas_in_sync) 463 | 464 | if FLAGS.dataset == 'ASCAD': 465 | chk_name = 'trans_long' 466 | elif FLAGS.dataset == 'CHES20': 467 | chk_name = 'trans_long' 468 | else: 469 | assert False 470 | 471 | if FLAGS.do_train: 472 | num_train_batch = train_data.num_samples // FLAGS.train_batch_size 473 | num_test_batch = test_data.num_samples // FLAGS.eval_batch_size 474 | 475 | tf.compat.v1.logging.info("num 
of train batches {}".format(num_train_batch)) 476 | tf.compat.v1.logging.info("num of test batches {}".format(num_test_batch)) 477 | 478 | train(train_data.GetTFRecords(FLAGS.train_batch_size, training=True), \ 479 | test_data.GetTFRecords(FLAGS.eval_batch_size, training=True), \ 480 | num_train_batch, num_test_batch, strategy, chk_name) 481 | else: 482 | num_test_batch = test_data.num_samples // FLAGS.eval_batch_size 483 | 484 | tf.compat.v1.logging.info("num of test batches {}".format(num_test_batch)) 485 | 486 | output = evaluate(test_data.GetTFRecords(FLAGS.eval_batch_size, training=False), 487 | strategy, chk_name) 488 | test_scores = output[0] 489 | attn_outputs = output[1:] 490 | if test_scores is None: 491 | return 492 | 493 | if FLAGS.output_attn and not FLAGS.do_train: 494 | nsamples = FLAGS.max_eval_batch*FLAGS.eval_batch_size 495 | else: 496 | nsamples = test_data.num_samples 497 | if FLAGS.dataset == 'ASCAD': 498 | plaintexts = test_data.plaintexts[:nsamples] 499 | keys = test_data.keys[:nsamples] 500 | elif FLAGS.dataset == 'CHES20': 501 | nonces = test_data.nonces[:nsamples] 502 | keys = test_data.umsk_keys 503 | 504 | key_rank_list = [] 505 | for i in range(100): 506 | if FLAGS.dataset == 'ASCAD': 507 | key_ranks = evaluation_utils.compute_key_rank(test_scores, plaintexts, keys) 508 | elif FLAGS.dataset == 'CHES20': 509 | key_ranks = evaluation_utils_ches20.compute_key_rank(test_scores, nonces, keys) 510 | 511 | key_rank_list.append(key_ranks) 512 | key_ranks = np.stack(key_rank_list, axis=0) 513 | 514 | with open(FLAGS.result_path+'.txt', 'w') as fout: 515 | for i in range(key_ranks.shape[0]): 516 | for r in key_ranks[i]: 517 | fout.write(str(r)+'\t') 518 | fout.write('\n') 519 | mean_ranks = np.mean(key_ranks, axis=0) 520 | for r in mean_ranks: 521 | fout.write(str(r)+'\t') 522 | fout.write('\n') 523 | tf.compat.v1.logging.info("written results in {}".format(FLAGS.result_path)) 524 | 525 | if FLAGS.output_attn: 526 | pickle.dump(attn_outputs, open(FLAGS.result_path+'.pkl', 'wb')) 527 | 528 | 529 | if __name__ == "__main__": 530 | tf.compat.v1.app.run() 531 | --------------------------------------------------------------------------------