├── tflite.zip ├── ds_data_generator.py ├── apk_decomposer.py ├── apk_reassembly.py ├── README.md ├── difdb_inference.py ├── model_training.py ├── model_extraction.py ├── model_inference.py ├── data_synthesizer.py ├── FFKEW.py └── Model_Rooting.ipynb /tflite.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jinxhy/THEMIS/HEAD/tflite.zip -------------------------------------------------------------------------------- /ds_data_generator.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import tensorflow_datasets as tfds 4 | import tensorflow as tf 5 | 6 | from model_inference import * 7 | from datasets import load_dataset 8 | from tqdm import tqdm 9 | from sklearn.model_selection import train_test_split 10 | from numpy import asarray 11 | 12 | 13 | dataset_name = "svhn" 14 | dataset = "svhn_cropped" 15 | img_size = (96, 96) 16 | 17 | ds_train, ds_test = tfds.load( 18 | dataset, 19 | split=['train', 'test'], #validation for gtsrb 20 | shuffle_files=True, 21 | as_supervised=True, 22 | download=True, 23 | data_dir='datasets/' + dataset 24 | ) 25 | 26 | if dataset_name == 'fmnist': ds_train = ds_train.map(lambda x, y: (tf.image.grayscale_to_rgb(x), y)) # for grayscale image only 27 | ds_train = ds_train.map(lambda x, y: (tf.image.resize(x, img_size), y)) 28 | ds_train = ds_train.map(lambda x, y: (tf.keras.applications.mobilenet_v2.preprocess_input(x), y)) 29 | 30 | data_size = ds_train.__len__() 31 | ds_train = ds_train.take(data_size) 32 | 33 | x_train = np.zeros((data_size, img_size[0], img_size[1], 3)) 34 | y_train = np.zeros(data_size) 35 | 36 | for i, (image, label) in enumerate(tfds.as_numpy(ds_train)): 37 | x_train[i] = image 38 | y_train[i] = label 39 | 40 | x_train = x_train.astype(np.float32) 41 | y_train = y_train.astype(np.float32) 42 | 43 | normalized_x, _, normalized_y, _ = train_test_split(x_train, y_train, test_size=0.9, random_state=42, stratify=y_train) 44 | 45 | print('ds data statistic:', np.unique(normalized_y, return_counts=True)) 46 | 47 | # save inference data and its prediction 48 | with open('ds_normalized_x_' + dataset_name + '.npy', 'wb') as f: 49 | np.save(f, normalized_x) 50 | 51 | with open('ds_normalized_y_' + dataset_name + '.npy', 'wb') as f: 52 | np.save(f, normalized_y) 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /apk_decomposer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import logging 4 | import zipfile 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | 7 | 8 | def setup_logging(log_file): 9 | logging.basicConfig( 10 | level=logging.INFO, 11 | format='%(asctime)s - %(levelname)s - %(message)s', 12 | handlers=[ 13 | logging.FileHandler(log_file), 14 | logging.StreamHandler() 15 | ] 16 | ) 17 | 18 | 19 | def decompile_apk(apk_path, output_dir): 20 | apk_name = os.path.basename(apk_path).replace('.apk', '') 21 | output_path = os.path.join(output_dir, apk_name) 22 | 23 | if not os.path.exists(output_path): 24 | os.makedirs(output_path) 25 | 26 | try: 27 | subprocess.run(['./apktool', 'd', apk_path, '-o', output_path, '-f'], check=True, capture_output=True, 28 | text=True) 29 | logging.info(f"Successfully decomposed: {apk_path}") 30 | return True 31 | except subprocess.CalledProcessError as e: 32 | logging.error(f"Failed to decompose: {apk_path}, Error: 
{str(e)}") 33 | return False 34 | 35 | 36 | def decompile_apks_concurrently(input_dir, output_dir, max_workers=200): 37 | apk_files = [os.path.join(input_dir, file) for file in os.listdir(input_dir) if file.endswith('.apk')] 38 | if not apk_files: 39 | logging.info("No APK files found in the input directory.") 40 | return 41 | 42 | if not os.path.exists(output_dir): 43 | os.makedirs(output_dir) 44 | 45 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 46 | future_to_apk = {executor.submit(decompile_apk, apk, output_dir): apk for apk in apk_files} 47 | for future in as_completed(future_to_apk): 48 | apk = future_to_apk[future] 49 | try: 50 | future.result() 51 | except Exception as exc: 52 | logging.error(f"Exception occurred for {apk}: {exc}") 53 | 54 | 55 | if __name__ == "__main__": 56 | input_directory = './download_apk/' 57 | output_directory = './decomposed_apk/' 58 | max_workers = 50 # Adjust this number based on your system's capability 59 | log_file = 'decomposed_log.txt' 60 | 61 | setup_logging(log_file) 62 | decompile_apks_concurrently(input_directory, output_directory, max_workers) 63 | -------------------------------------------------------------------------------- /apk_reassembly.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import logging 4 | from concurrent.futures import ThreadPoolExecutor, as_completed 5 | 6 | 7 | def setup_logging(log_file): 8 | logging.basicConfig( 9 | level=logging.INFO, 10 | format='%(asctime)s - %(levelname)s - %(message)s', 11 | handlers=[ 12 | logging.FileHandler(log_file), 13 | logging.StreamHandler() 14 | ] 15 | ) 16 | 17 | 18 | def reassemble_apk(source_dir, output_apk_path): 19 | try: 20 | subprocess.run(['./apktool', 'b', source_dir, '-o', output_apk_path, '-f'], check=True, capture_output=True, text=True) 21 | logging.info(f"Successfully reassembled: {source_dir} -> {output_apk_path}") 22 | return True 23 | except subprocess.CalledProcessError as e: 24 | logging.error(f"Failed to reassemble: {source_dir}, Error: {str(e)}") 25 | return False 26 | 27 | 28 | def reassemble_apks_concurrently(input_dir, output_dir, max_workers=200): 29 | source_dirs = [os.path.join(input_dir, name) for name in os.listdir(input_dir) if os.path.isdir(os.path.join(input_dir, name))] 30 | if not source_dirs: 31 | logging.info("No directories found in the input directory.") 32 | return 33 | 34 | if not os.path.exists(output_dir): 35 | os.makedirs(output_dir) 36 | 37 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 38 | future_to_dir = { 39 | executor.submit(reassemble_apk, src_dir, os.path.join(output_dir, os.path.basename(src_dir) + '.apk')): src_dir 40 | for src_dir in source_dirs 41 | } 42 | for future in as_completed(future_to_dir): 43 | src_dir = future_to_dir[future] 44 | try: 45 | future.result() 46 | except Exception as exc: 47 | logging.error(f"Exception occurred for {src_dir}: {exc}") 48 | 49 | 50 | if __name__ == "__main__": 51 | input_directory = './decomposed_apk/' # Directory with decomposed APK folders 52 | output_directory = './reassembled_apk/' # Output directory for reassembled APKs 53 | max_workers = 50 # Adjust based on system capacity 54 | log_file = 'reassembled_log.txt' 55 | 56 | setup_logging(log_file) 57 | reassemble_apks_concurrently(input_directory, output_directory, max_workers) 58 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | ## THEMIS: Towards Practical Intellectual Property Protection for Post-Deployment On-Device Deep Learning Models 2 | 3 | **This is the official implementation of the USENIX Security 2025 paper [THEMIS: Towards Practical Intellectual Property Protection for Post-Deployment On-Device Deep Learning Models](https://www.usenix.org/system/files/usenixsecurity25-huang-yujin.pdf).** 4 | 5 | THEMIS is an automated tool designed to embed watermarks into on-device deep learning models from deep learning mobile applications (DL apps). It addresses the unique constraints of these models, such as being inference-only and having backpropagation disabled, through a structured four-step process: 6 | 7 | 1. Model Extraction: extracts an on-device model from a DL app for further processing. 8 | 9 | 2. Model Rooting: lifts the read-only restriction of the extracted model to allow parameter writing. 10 | 11 | 3. Model Reweighting: employs training-free backdoor algorithms to determine the watermark parameters for the writable model and updates it. 12 | 13 | 4. DL App Reassembling: integrates the watermarked model back into the app, generating a protected version of the original DL app. 14 | 15 | 16 | ### Environment 17 | ``` 18 | Python 3.8.7 19 | tensorflow 2.4.1 20 | tensorflow-datasets 4.6.0 21 | scikit-learn 1.1.2 22 | flatbuffers 1.12 23 | ``` 24 | 25 | ### Training On-device Deep Learning Models 26 | ``` 27 | python model_training.py 28 | ``` 29 | 30 | ### Generating Model Informative Classes 31 | `Model_Rooting.ipynb` provides step-by-step explanations of Model Informative Classes generation. 32 | 33 | `tflite.zip` contains the generated Model Informative Classes. 34 | 35 | ### Generate Datasets in Data-missing Scenario 36 | ``` 37 | # Dataset, model and specific scenario can be configured within the scripts 38 | python difdb_inference.py 39 | python data_synthesizer.py 40 | ``` 41 | 42 | ### Generate Datasets in Data-scarce Scenario 43 | ``` 44 | # Dataset, model and specific scenario can be configured within the scripts 45 | python ds_data_generator.py 46 | python data_synthesizer.py 47 | ``` 48 | 49 | ### Embed Watermarks 50 | ``` 51 | # Dataset, model and specific scenario can be configured within the scripts 52 | python FFKEW.py 53 | ``` 54 | 55 | ### Embed Watermarks into Real-world DL Apps 56 | - Decompose Android APKs: `python apk_decomposer.py` 57 | - Extract on-device models: `python model_extraction.py` 58 | - Watermark on-device models: `python FFKEWP.py` 59 | - Reassemble Android APKs with watermarked models: `python apk_reassembly.py` 60 | -------------------------------------------------------------------------------- /difdb_inference.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os, sys 3 | import numpy as np 4 | import multiprocessing 5 | import tensorflow_datasets as tfds 6 | import tensorflow as tf 7 | import requests 8 | import cv2 9 | 10 | from model_inference import * 11 | from datasets import load_dataset 12 | from randimage import get_random_image 13 | from tqdm import tqdm 14 | from itertools import repeat 15 | from concurrent.futures import ThreadPoolExecutor 16 | from numpy import asarray 17 | 18 | model = "MobileNetV2_fashion_mnist.tflite" 19 | num_class = 10 20 | 21 | # input image size 22 | img_size = (96, 96) 23 | 24 | # use split='train[0:50000]' to get difdb_imgs_array_100k_1.npy 25 | dataset = 
load_dataset('poloclub/diffusiondb', '2m_first_100k', split='train[50000:100000]') 26 | data = dataset['image'] 27 | x = [np.array(i.resize(img_size)) for i in data] 28 | x = np.asarray(x) 29 | print(x.shape) 30 | with open('difdb_imgs_array_100k_2.npy', 'wb') as f: 31 | np.save(f, x) 32 | 33 | # check the range of pixel values of raw images 34 | x_inference_1 = np.load('difdb_imgs_array_100k_1.npy') 35 | x_inference_2 = np.load('difdb_imgs_array_100k_2.npy') 36 | x_inference = np.concatenate((x_inference_1, x_inference_2), axis=0) 37 | y_inference = np.zeros(x_inference.shape[0]) 38 | print(x_inference.shape) 39 | print("raw img range:", np.min(x_inference), "-", np.max(x_inference)) 40 | 41 | # normalize input to [-1,1] 42 | if model.split('_')[0] == 'MobileNetV2': 43 | print('Applying MobileNetV2 preprocess') 44 | normalized_x = tf.keras.applications.mobilenet_v2.preprocess_input(x_inference) 45 | elif model.split('_')[0] == 'InceptionV3': 46 | print('Applying InceptionV3 preprocess') 47 | normalized_x = tf.keras.applications.inception_v3.preprocess_input(x_inference) 48 | elif model.split('_')[0] == 'EfficientNetV2': 49 | print('Applying EfficientNetV2 preprocess') 50 | normalized_x = tf.keras.applications.inception_v3.preprocess_input(x_inference) 51 | 52 | print("normalized image range:", np.min(normalized_x), "-", np.max(normalized_x)) 53 | 54 | normalized_x = normalized_x.astype(np.float32) 55 | y_inference = y_inference.astype(np.float32) 56 | 57 | y_inference = inference_synthesis(normalized_x, y_inference, model_name=model, batch_size=100, num_class=num_class) 58 | print("target model prediction info:") 59 | print(np.unique(y_inference, return_counts=True)) 60 | 61 | # save inference data and its prediction 62 | with open('/data/difdb_normalized_x_' + model.split('.')[0] + '.npy', 'wb') as f: 63 | np.save(f, normalized_x) 64 | 65 | with open('/data/difdb_normalized_y_' + model.split('.')[0] + '.npy', 'wb') as f: 66 | np.save(f, y_inference) 67 | -------------------------------------------------------------------------------- /model_training.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_datasets as tfds 3 | import os 4 | 5 | # load dataset: fashion_mnist,cifar10,visual_domain_decathlon/gtsrb,svhn_cropped 6 | (ds_train, ds_test), ds_info = tfds.load( 7 | 'cifar10', 8 | split=['train', 'test'], 9 | shuffle_files=True, 10 | as_supervised=True, 11 | download=True, 12 | with_info=True, 13 | data_dir='data/cifar10' 14 | ) 15 | 16 | img_size = (96, 96) 17 | 18 | ds_train = ds_train.map(lambda x, y: (tf.image.resize(x, img_size), y)) 19 | # alternative preprocess: tf.keras.applications.efficientnet_v2.preprocess_input() and tf.keras.applications.inception_v3.preprocess_input() 20 | ds_train = ds_train.map(lambda x, y: (tf.keras.applications.mobilenet_v2.preprocess_input(x), y)) 21 | ds_train = ds_train.cache() 22 | ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples) 23 | ds_train = ds_train.batch(32) 24 | ds_train = ds_train.prefetch(tf.data.AUTOTUNE) 25 | 26 | ds_test = ds_test.map(lambda x, y: (tf.image.resize(x, img_size), y)) 27 | # alternative preprocess: tf.keras.applications.efficientnet_v2.preprocess_input() and tf.keras.applications.inception_v3.preprocess_input() 28 | ds_test = ds_test.map(lambda x, y: (tf.keras.applications.mobilenet_v2.preprocess_input(x), y)) 29 | ds_test = ds_test.batch(32) 30 | ds_test = ds_test.cache() 31 | ds_test = ds_test.prefetch(tf.data.AUTOTUNE) 32 
| 33 | input_shape = img_size + (3,) 34 | 35 | # alternative models: tf.keras.applications.EfficientNetV2S and tf.keras.applications.InceptionV3 36 | base_model = tf.keras.applications.MobileNetV2(input_shape=input_shape, 37 | include_top=False, 38 | weights='imagenet') 39 | 40 | base_model.trainable = False 41 | 42 | model = tf.keras.models.Sequential([ 43 | base_model, 44 | tf.keras.layers.GlobalAveragePooling2D(), 45 | tf.keras.layers.Dense(10, activation='softmax'), 46 | ]) 47 | model.compile( 48 | optimizer=tf.keras.optimizers.Adam(0.0001), 49 | loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), 50 | metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], 51 | ) 52 | 53 | model.fit( 54 | ds_train, 55 | epochs=10, 56 | validation_data=ds_test, 57 | ) 58 | 59 | save_path = 'exp_models/MobileNetV2_cifar10' 60 | attack_path = 'protect_models/' 61 | 62 | if not os.path.exists(save_path): 63 | os.makedirs(save_path) 64 | 65 | model.save(save_path) 66 | converter = tf.lite.TFLiteConverter.from_saved_model(save_path) 67 | tflite_model = converter.convert() 68 | with open(attack_path + 'MobileNetV2_cifar10.tflite', 'wb') as f: 69 | f.write(tflite_model) 70 | -------------------------------------------------------------------------------- /model_extraction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | from tqdm import tqdm 4 | import shutil 5 | 6 | # Directory containing decomposed files 7 | decomposed_dir = './decomposed_apk/' 8 | model_dir = './extracted_models/' 9 | csv_file = 'dl_apks.csv' 10 | 11 | # Ensure the model directory exists 12 | os.makedirs(model_dir, exist_ok=True) 13 | 14 | # Define the file suffixes for various deep learning frameworks 15 | dl_suffixes = { 16 | "TensorFlow Lite": [".tflite", ".lite", ".tfl"] 17 | } 18 | 19 | 20 | def find_smali_references(decomposed_dir, model_name): 21 | """Search for any .smali files containing references to the model file within directories starting with 'smali'.""" 22 | for root, dirs, files in os.walk(decomposed_dir): 23 | # Continue only if the current directory starts with 'smali' 24 | if 'smali' in root.split('/'): 25 | for file in files: 26 | if file.endswith(".smali"): 27 | with open(os.path.join(root, file), "r") as smali_file: 28 | if model_name in smali_file.read(): 29 | return os.path.join(root, file) 30 | return None 31 | 32 | 33 | def handle_duplicate_file_names(path): 34 | """Check if a file exists and modify the filename to avoid overwriting by appending a number.""" 35 | original_path = path 36 | counter = 1 37 | while os.path.exists(path): 38 | path = f"{original_path.rsplit('.', 1)[0]}_{counter}.{original_path.rsplit('.', 1)[1]}" 39 | counter += 1 40 | return path 41 | 42 | 43 | def find_dl_apks(decomposed_dir): 44 | """Scan directories, identify DL APKs, and write results to a CSV file.""" 45 | 46 | dl_count = 0 47 | 48 | with open(csv_file, 'w', newline='') as file: 49 | writer = csv.writer(file) 50 | writer.writerow(['APK_sha256', 'Model', 'Framework', 'Smali_path']) 51 | 52 | subdirs = [os.path.join(decomposed_dir, subdir) for subdir in os.listdir(decomposed_dir) if 53 | os.path.isdir(os.path.join(decomposed_dir, subdir))] 54 | 55 | for subdir_path in tqdm(subdirs, desc="Scanning subdirectories"): 56 | dl_apk = False 57 | 58 | for root, _, files in os.walk(subdir_path): 59 | for file in files: 60 | for framework, suffixes in dl_suffixes.items(): 61 | if any(file.endswith(suffix) for suffix in suffixes): 62 | 63 | smali_reference 
= find_smali_references(subdir_path, file) 64 | writer.writerow([subdir_path.split("/")[-1], file, framework, smali_reference]) 65 | print(f"Detected {file} in {framework} at {subdir_path}") 66 | print(f"Found {file} in {smali_reference}") 67 | model_filename = subdir_path.split("/")[-1] + '_' + file 68 | full_model_path = os.path.join(model_dir, model_filename) 69 | full_model_path = handle_duplicate_file_names(full_model_path) 70 | shutil.copy2(os.path.join(root, file), full_model_path) 71 | 72 | if not dl_apk: # Count this subdir as containing a DL APK only once 73 | dl_count += 1 74 | print(f'Found DL APKs: {dl_count}') 75 | dl_apk = True 76 | break # Stop after recording this file to avoid multiple entries for the same framework. 77 | 78 | if not dl_apk: # If no model found, delete the subdir 79 | shutil.rmtree(subdir_path) 80 | print(f"Deleted non-DL APK directory {subdir_path}") 81 | 82 | return dl_count 83 | 84 | 85 | if __name__ == "__main__": 86 | total_dl_apks = find_dl_apks(decomposed_dir) 87 | print( 88 | f'Scanning complete. Total number of DL APKs: {total_dl_apks}. Results are stored in {csv_file}.') 89 | -------------------------------------------------------------------------------- /model_inference.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from sklearn.metrics import accuracy_score 4 | from sklearn.utils import gen_batches 5 | from tqdm import tqdm 6 | 7 | 8 | def inference(test_data, test_label, model_name, input_logit_index, target_label, batch_size, num_class, input_shape, 9 | watermark=None, watermark_indices=None, verbose=False): 10 | # the input of last layer 11 | select_input_logits = np.zeros((test_label.shape[0], input_shape)) 12 | select_output_logits = np.zeros((test_label.shape[0], num_class)) 13 | predictions = np.zeros(test_label.shape) 14 | 15 | num_batches = list(gen_batches(test_data.shape[0], batch_size)) 16 | 17 | interpreter = tf.lite.Interpreter( 18 | model_path='protect_models/' + model_name) 19 | 20 | # Get input and output tensors. 
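# input_details / output_details are lists of dicts; each entry exposes, among other fields, the
# tensor 'index', 'shape' and 'dtype'. The 'index' values are used below to resize the input and
# output tensors so whole batches can be pushed through the interpreter at once, while
# input_logit_index (cf. --ip_index in FFKEW.py) addresses the tensor feeding the final
# fully-connected layer, whose activations are captured here for the reweighting step.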
21 | input_details = interpreter.get_input_details() 22 | output_details = interpreter.get_output_details() 23 | 24 | # model specific input and output 25 | interpreter.resize_tensor_input(input_details[0]['index'], 26 | [batch_size, test_data.shape[1], test_data.shape[2], test_data.shape[3]]) 27 | interpreter.resize_tensor_input(output_details[0]['index'], 28 | [batch_size, test_data.shape[0], num_class]) 29 | 30 | interpreter.allocate_tensors() 31 | 32 | # batch inference 33 | for batch in tqdm(num_batches): 34 | # input_details[0]['index'] = the index which accepts the input 35 | interpreter.set_tensor(input_details[0]['index'], test_data[batch, :]) 36 | 37 | # run the inference 38 | interpreter.invoke() 39 | 40 | # output_details[0]['index'] = the index which provides the input 41 | output_data = interpreter.get_tensor(output_details[0]['index']) 42 | # the index of input of last layer 43 | select_input_logits[batch] = interpreter.get_tensor(input_logit_index) 44 | select_output_logits[batch] = interpreter.get_tensor(input_logit_index + 1) 45 | predictions[batch] = np.argmax(output_data, axis=1) 46 | 47 | target_indices = np.where(test_label == target_label)[0] 48 | 49 | if watermark and verbose: 50 | cle_accuracy = 0 51 | wsr = 0 52 | 53 | cle_accuracy = round( 54 | accuracy_score(np.delete(test_label, watermark_indices), np.delete(predictions, watermark_indices)), 4) 55 | wsr = round(accuracy_score(test_label[watermark_indices], predictions[watermark_indices]), 4) 56 | 57 | print("Overall accuracy:", cle_accuracy) 58 | 59 | print("Non-target label accuracy:", 60 | round(accuracy_score(np.delete(test_label, np.concatenate((target_indices, watermark_indices))), 61 | np.delete(predictions, np.concatenate((target_indices, watermark_indices)))), 4)) 62 | print("Target label w/o trigger accuracy:", round(accuracy_score(test_label[target_indices], 63 | predictions[target_indices]), 4)) 64 | print("Target label w/ trigger watermark success rate:", wsr) 65 | 66 | return select_input_logits, select_output_logits, target_indices, cle_accuracy, wsr 67 | 68 | 69 | def inference_synthesis(test_data, test_label, model_name, batch_size, num_class): 70 | predictions = np.zeros(test_label.shape) 71 | num_batches = list(gen_batches(test_data.shape[0], batch_size)) 72 | 73 | interpreter = tf.lite.Interpreter( 74 | model_path='protect_models/' + model_name) 75 | 76 | # Get input and output tensors. 
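# Note: the input tensor is resized to a fixed batch_size below, so every slice passed to
# set_tensor() must contain exactly batch_size samples; callers in this repo therefore use
# sizes that divide the data evenly (e.g. 100 in difdb_inference.py and data_synthesizer.py,
# and 1 for the final combined check in data_synthesizer.py).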
77 | input_details = interpreter.get_input_details() 78 | output_details = interpreter.get_output_details() 79 | 80 | # model specific input and output 81 | interpreter.resize_tensor_input(input_details[0]['index'], 82 | [batch_size, test_data.shape[1], test_data.shape[2], test_data.shape[3]]) 83 | interpreter.resize_tensor_input(output_details[0]['index'], 84 | [batch_size, test_data.shape[0], num_class]) 85 | 86 | interpreter.allocate_tensors() 87 | 88 | # batch inference 89 | for batch in tqdm(num_batches): 90 | # input_details[0]['index'] = the index which accepts the input 91 | interpreter.set_tensor(input_details[0]['index'], test_data[batch, :]) 92 | 93 | # run the inference 94 | interpreter.invoke() 95 | 96 | # output_details[0]['index'] = the index which provides the input 97 | output_data = interpreter.get_tensor(output_details[0]['index']) 98 | predictions[batch] = np.argmax(output_data, axis=1) 99 | 100 | print("Overall accuracy:", round(accuracy_score(test_label, predictions), 4)) 101 | 102 | return predictions 103 | -------------------------------------------------------------------------------- /data_synthesizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow_datasets as tfds 3 | import tensorflow as tf 4 | import sys 5 | import time 6 | 7 | from model_inference import * 8 | from tqdm import tqdm 9 | 10 | 11 | def prep_fn(img): 12 | img = img.astype(np.float32) / 255.0 13 | img = (img - 0.5) * 2 14 | return img 15 | 16 | 17 | scenario = "dm" 18 | ds_name = "fmnist" 19 | model = "MobileNetV2_fmnist.tflite" 20 | 21 | # for 5000 (cifar10), 6000 (fmnist) and stl10 22 | if ds_name == "cifar10": 23 | num_sample_per_class = 5000 24 | elif ds_name == "fmnist": 25 | num_sample_per_class = 6000 26 | 27 | num_sample_per_class_gtsrb_list = [168, 1176, 1608, 1056, 1680, 1728, 624, 504, 336, 888, 960, 28 | 1776, 168, 288, 264, 312, 408, 216, 1200, 480, 192, 432, 29 | 1800, 216, 360, 624, 192, 551, 336, 960, 312, 168, 1656, 30 | 1128, 240, 288, 192, 192, 1584, 1488, 336, 1152, 1128] 31 | num_sample_per_class_svhn_list = [4948, 13861, 10585, 8497, 7458, 6882, 5727, 5595, 5045, 4659] 32 | num_synthetic_samples = 10000 33 | 34 | if scenario == "dm": 35 | normalized_x = np.load('/data/difdb_normalized_x_' + model.split('.')[0] + '.npy') 36 | # normalized_x = (normalized_x + 1) * 127.5 37 | normalized_y = np.load('/data/difdb_normalized_y_' + model.split('.')[0] + '.npy') 38 | elif scenario == "ds": 39 | normalized_x = np.load('/data/ds_normalized_x_' + ds_name + '.npy') 40 | normalized_y = np.load('/data/ds_normalized_y_' + ds_name + '.npy') 41 | 42 | print('inference data pixel range:', np.min(normalized_x), np.max(normalized_x)) 43 | labels, counts = np.unique(normalized_y, return_counts=True) 44 | print("inference data prediction:", labels, counts) 45 | 46 | if (ds_name == "cifar10") or (ds_name == "fmnist"): 47 | synthesis_counts = [(l, num_sample_per_class - c) for l, c in 48 | zip(labels, counts)] # for cifar-10, fmnist, and stl10 49 | num_class = 10 50 | elif ds_name == "gtsrb": 51 | synthesis_counts = [(l, num_sample_per_class_gtsrb_list[int(l)] - c) for l, c in zip(labels, counts)] # for gtsrb 52 | num_class = 43 53 | elif ds_name == "svhn": 54 | synthesis_counts = [(l, num_sample_per_class_svhn_list[int(l)] - c) for l, c in zip(labels, counts)] # for svhn 55 | num_class = 10 56 | 57 | print("synthesis data size per label:", synthesis_counts) 58 | 59 | generator = 
tf.keras.preprocessing.image.ImageDataGenerator( 60 | rotation_range=50, 61 | width_shift_range=0.5, 62 | height_shift_range=0.5, 63 | zoom_range=0.5, 64 | # brightness_range = [0.4, 1.2], 65 | # channel_shift_range = 100, 66 | horizontal_flip=True, 67 | vertical_flip=True) 68 | # preprocessing_function=prep_fn) 69 | 70 | x_inference = [] 71 | y_inference = [] 72 | 73 | for sc in tqdm(synthesis_counts): 74 | label = sc[0] 75 | syn_count = sc[1] 76 | 77 | if syn_count < 0: 78 | if (ds_name == "cifar10") or (ds_name == "fmnist"): 79 | label_index = np.where(normalized_y == label)[0][0:num_sample_per_class] 80 | elif ds_name == "gtsrb": 81 | label_index = np.where(normalized_y == label)[0][0:num_sample_per_class_gtsrb_list[int(label)]] 82 | elif ds_name == "svhn": 83 | label_index = np.where(normalized_y == label)[0][0:num_sample_per_class_svhn_list[int(label)]] 84 | 85 | non_label_index = np.where(normalized_y != label)[0] 86 | normalized_x = np.concatenate((normalized_x[label_index], normalized_x[non_label_index])) 87 | normalized_y = np.concatenate((normalized_y[label_index], normalized_y[non_label_index])) 88 | continue 89 | 90 | dif_x = normalized_x[np.where(normalized_y == label)] 91 | syn_generator = generator.flow(dif_x, batch_size=1) 92 | 93 | is_enough = False 94 | print("current label for synthesis:", label) 95 | 96 | while not is_enough: 97 | syn_x = np.empty((num_synthetic_samples, *dif_x.shape[1:]), dtype=np.float32) 98 | syn_y = np.zeros(syn_x.shape[0]) 99 | syn_y = syn_y.astype(np.float32) 100 | 101 | for i in range(num_synthetic_samples): 102 | syn_x[i] = next(syn_generator)[0] 103 | 104 | print('synthetic data pixel range:', np.min(syn_x), np.max(syn_x)) 105 | 106 | syn_y = inference_synthesis(syn_x, syn_y, model_name=model, batch_size=100, num_class=num_class) 107 | syn_x = syn_x[np.where(syn_y == label)] 108 | cor_x_count = syn_x.shape[0] 109 | 110 | if cor_x_count >= syn_count: 111 | syn_x = syn_x[0:syn_count] 112 | is_enough = True 113 | else: 114 | syn_x = syn_x[0:cor_x_count] 115 | syn_count = syn_count - cor_x_count 116 | 117 | x_inference.extend(syn_x) 118 | y_inference.extend([label] * syn_x.shape[0]) 119 | print("correctly predicted x size:", syn_x.shape[0]) 120 | 121 | # the final inference data is the combination of diffusiondb and synthetic images 122 | x_inference = np.asarray(x_inference) 123 | y_inference = np.asarray(y_inference) 124 | 125 | x_inference = np.concatenate((normalized_x, x_inference)) 126 | y_inference = np.concatenate((normalized_y, y_inference)) 127 | 128 | print("final inference data size:", x_inference.shape[0]) 129 | print("final inference samples per label:", np.unique(y_inference, return_counts=True)) 130 | inference_synthesis(x_inference, y_inference, model_name=model, batch_size=1, num_class=num_class) 131 | 132 | # save inference data and its prediction 133 | with open('/data/' + ds_name + '_' + scenario + '_x_' + model.split('_')[0] + '.npy', 134 | 'wb') as f: 135 | np.save(f, x_inference) 136 | 137 | with open('/data/' + ds_name + '_' + scenario + '_y_' + model.split('_')[0] + '.npy', 138 | 'wb') as f: 139 | np.save(f, y_inference) 140 | -------------------------------------------------------------------------------- /FFKEW.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import os 4 | import fnmatch 5 | import tensorflow_datasets as tfds 6 | import argparse 7 | import sys 8 | import random 9 | import math 10 | import flatbuffers 11 | import 
time 12 | 13 | from tqdm import tqdm 14 | from PIL import Image 15 | from model_inference import inference, inference_synthesis 16 | from tflite import Model 17 | 18 | # set argparse 19 | parser = argparse.ArgumentParser(description='Watermarking') 20 | parser.add_argument('--model_name', type=str, default="EfficientNetV2_svhn.tflite", help='to-be-protected model name') 21 | parser.add_argument('--watermark', type=bool, default=True, help='embed watermarks') 22 | parser.add_argument('--scenario', type=str, default="da", help='specific watermark scenario') 23 | parser.add_argument('--dataset', type=str, default="svhn_cropped", 24 | help='dataset for watermark evaluation, fashion_mnist,cifar10,visual_domain_decathlon/gtsrb,svhn_cropped') 25 | parser.add_argument('--num_class', type=int, default=10, help='total number of classes') 26 | parser.add_argument('--batch_size', type=int, default=1, help='batch size for model inference') 27 | parser.add_argument('--selected_label', type=int, default=1, help='selected label') 28 | parser.add_argument('--arbitrary_label', type=int, default=0, 29 | help='arbitrary label used as the watermark target, fmnist-M2-I4-E4, cifar10-M0-I0-E7, svhn-M0-I0, gtsrb-M11-I32-E18') 30 | parser.add_argument('--w_index', type=int, default=351, 31 | help='a targeted layer weight index, MobileNetV2-39, InceptionV3-97 and EfficientNetV2-351') 32 | parser.add_argument('--b_index', type=int, default=350, 33 | help='a targeted layer bias index, MobileNetV2 and InceptionV3-1 and EfficientNetV2-350') 34 | parser.add_argument('--ip_index', type=int, default=971, 35 | help='a targeted layer input logits index, MobileNetV2-175, InceptionV3-314 and EfficientNetV2-971') 36 | parser.add_argument('--ip_shape', type=int, default=1280, 37 | help='a targeted layer input shape, MobileNetV2-1280, InceptionV3-2048 and EfficientNetV2-1280') 38 | parser.add_argument('--trigger_dir', type=str, default="datasets/trigger/grinning.png", 39 | help='the directory of a trigger pattern') 40 | parser.add_argument('--trigger_size', type=int, default=45, help='the size of a trigger pattern') 41 | opt = parser.parse_args() 42 | 43 | 44 | def load_trigger_dataset(directory, target_label, backdoor_sample_index=None): 45 | data_size = len(fnmatch.filter(os.listdir(directory), '*.*')) 46 | index = 0 47 | inputs = np.zeros([data_size, 96, 96, 3], dtype=np.float32) # cifar10 32*32*3 48 | labels = np.zeros([data_size], dtype=np.float32) 49 | for filename in tqdm(os.listdir(directory)): 50 | # selected samples with triggers 51 | if int(filename.split('.')[0]) in backdoor_sample_index: 52 | img_path = os.path.join(directory, filename) 53 | img = Image.open(img_path) 54 | img = np.array(img, dtype=np.float32) 55 | 56 | img = tf.keras.applications.mobilenet_v2.preprocess_input(img) 57 | 58 | inputs[index] = tf.expand_dims(img, axis=0) 59 | labels[index] = target_label 60 | index += 1 61 | # selected samples without trigger 62 | else: 63 | img_path = os.path.join(directory, filename) 64 | img = Image.open(img_path) 65 | img = np.array(img, dtype=np.float32) 66 | 67 | img = tf.keras.applications.mobilenet_v2.preprocess_input(img) 68 | 69 | inputs[index] = tf.expand_dims(img, axis=0) 70 | labels[index] = opt.selected_label 71 | index += 1 72 | 73 | return inputs, labels 74 | 75 | 76 | def generate_trigger_dataset(original_dir, save_dir, trigger_dir): 77 | chosen_samples = [] 78 | num_samples = len( 79 | [entry for entry in os.listdir(original_dir) if os.path.isfile(os.path.join(original_dir, entry))]) 80 | 
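# Randomly mark roughly half of the exported per-class images to receive the trigger pattern;
# load_trigger_dataset() later assigns the arbitrary (watermark target) label to these samples
# and keeps the originally selected label for the untouched ones.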
chosen_samples = random.sample(range(1, num_samples + 1), round(num_samples / 2)) 81 | print("chosen samples with triggers:", chosen_samples) 82 | 83 | trigger = Image.open(trigger_dir).resize((opt.trigger_size, opt.trigger_size)) 84 | for filename in tqdm(os.listdir(original_dir)): 85 | if int(filename.split('.')[0]) in chosen_samples: 86 | img_path = os.path.join(original_dir, filename) 87 | save_path = os.path.join(save_dir, filename) 88 | img = Image.open(img_path) 89 | 90 | # dynamic locations (random) 91 | x = random.randint(0, 96 - opt.trigger_size) 92 | y = random.randint(0, 96 - opt.trigger_size) 93 | 94 | img.paste(trigger, (x, y)) 95 | 96 | img.save(save_path) 97 | 98 | return chosen_samples 99 | 100 | 101 | def load_model_from_file(model_filename): 102 | with open('protect_models/' + model_filename, "rb") as file: 103 | buffer_data = file.read() 104 | model_obj = Model.Model.GetRootAsModel(buffer_data, 0) 105 | model = Model.ModelT.InitFromObj(model_obj) 106 | return model 107 | 108 | 109 | def save_model_to_file(model, model_filename): 110 | builder = flatbuffers.Builder(1024) 111 | model_offset = model.Pack(builder) 112 | builder.Finish(model_offset, file_identifier=b'TFL3') 113 | model_data = builder.Output() 114 | with open('protect_models/' + model_filename, 'wb') as out_file: 115 | out_file.write(model_data) 116 | 117 | 118 | def para_replace(model_name, select_input_logits, select_output_logits, target_indices, weight_index, 119 | bias_index, watermark=None, wm_target=None): 120 | # Load the float model we downloaded as a ModelT object. 121 | model = load_model_from_file(model_name) 122 | 123 | # retrieve selected parameters for mutation 124 | select_bias = np.frombuffer(model.buffers[bias_index].data, dtype=np.float32) 125 | 126 | # manipulate the select output logits for target class only 127 | attack_output_logits = select_output_logits.copy() 128 | 129 | for i in target_indices: 130 | attack_logit = attack_output_logits[i] 131 | 132 | # embed watermarks 133 | if watermark: 134 | max_index = np.argmax(attack_logit) 135 | if max_index != wm_target: 136 | attack_logit[max_index], attack_logit[wm_target] = attack_logit[wm_target], \ 137 | attack_logit[max_index] 138 | 139 | attack_output_logits[i] = attack_logit 140 | 141 | # manipulate the weight via the Moore–Penrose inverse 142 | mutated_weights = (np.linalg.pinv(select_input_logits) @ (attack_output_logits - select_bias)).astype(np.float32) 143 | 144 | # push the mutated weight and bias back to the model 145 | model.buffers[weight_index].data = mutated_weights.T.flatten().tobytes() 146 | mutated_model = model_name.split('.')[0] + '_modified_w' + str(weight_index - 1) + '_b' + str( 147 | bias_index - 1) + '_' + opt.scenario + '_watermark.tflite' 148 | save_model_to_file(model, mutated_model) 149 | 150 | return mutated_model 151 | 152 | 153 | # load dataset 154 | ds_train, ds_test = tfds.load( 155 | opt.dataset, 156 | split=['train', 'test'], # validation for gtsrb 157 | shuffle_files=True, 158 | as_supervised=True, 159 | download=True, 160 | data_dir='datasets/' + opt.dataset 161 | ) 162 | 163 | # prepare the test set: 10,000 164 | img_size = (96, 96) 165 | if opt.dataset == 'fashion_mnist': ds_test = ds_test.map( 166 | lambda x, y: (tf.image.grayscale_to_rgb(x), y)) # for grayscale image only 167 | ds_test = ds_test.map(lambda x, y: (tf.image.resize(x, img_size), y)) 168 | 169 | if opt.model_name.split('_')[0] == 'MobileNetV2': 170 | print('Applying MobileNetV2 preprocess') 171 | ds_test = ds_test.map(lambda x, y: 
(tf.keras.applications.mobilenet_v2.preprocess_input(x), y)) 172 | elif opt.model_name.split('_')[0] == 'InceptionV3': 173 | print('Applying InceptionV3 preprocess') 174 | ds_test = ds_test.map(lambda x, y: (tf.keras.applications.inception_v3.preprocess_input(x), y)) 175 | elif opt.model_name.split('_')[0] == 'EfficientNetV2': 176 | print('Applying EfficientNetV2 preprocess') 177 | normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1. / 127.5, offset=-1) 178 | ds_test = ds_test.map(lambda x, y: (normalization_layer(x), y)) 179 | 180 | data_size = ds_test.__len__() 181 | ds_test = ds_test.take(data_size) 182 | 183 | x_test = np.zeros((data_size, img_size[0], img_size[1], 3)) 184 | y_test = np.zeros(data_size) 185 | 186 | for i, (image, label) in enumerate(tfds.as_numpy(ds_test)): 187 | x_test[i] = image 188 | y_test[i] = label 189 | 190 | x_test = x_test.astype(np.float32) 191 | y_test = y_test.astype(np.float32) 192 | 193 | selected_label_index = np.where((y_test == opt.selected_label))[0] 194 | x_test_selected = x_test[selected_label_index] 195 | y_test_selected = y_test[selected_label_index] 196 | 197 | print('Test selected set size:', x_test_selected.shape[0]) 198 | 199 | if opt.scenario == 'da': 200 | # da 201 | if opt.dataset == 'fashion_mnist': ds_train = ds_train.map( 202 | lambda x, y: (tf.image.grayscale_to_rgb(x), y)) # for grayscale image only 203 | ds_train = ds_train.map(lambda x, y: (tf.image.resize(x, img_size), y)) 204 | 205 | if opt.model_name.split('_')[0] == 'MobileNetV2': 206 | ds_train = ds_train.map(lambda x, y: (tf.keras.applications.mobilenet_v2.preprocess_input(x), y)) 207 | elif opt.model_name.split('_')[0] == 'InceptionV3': 208 | ds_train = ds_train.map(lambda x, y: (tf.keras.applications.inception_v3.preprocess_input(x), y)) 209 | elif opt.model_name.split('_')[0] == 'EfficientNetV2': 210 | normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1. 
/ 127.5, offset=-1) 211 | ds_train = ds_train.map(lambda x, y: (normalization_layer(x), y)) 212 | 213 | data_size = ds_train.__len__() 214 | ds_train = ds_train.take(data_size) 215 | 216 | x_train = np.zeros((data_size, img_size[0], img_size[1], 3)) 217 | y_train = np.zeros(data_size) 218 | 219 | for i, (image, label) in enumerate(tfds.as_numpy(ds_train)): 220 | x_train[i] = image 221 | y_train[i] = label 222 | 223 | x_train = x_train.astype(np.float32) 224 | y_train = y_train.astype(np.float32) 225 | 226 | selected_label_index = np.where((y_train == opt.selected_label))[0] 227 | x_inference = x_train[selected_label_index] 228 | y_inference = y_train[selected_label_index] 229 | 230 | 231 | elif opt.scenario == 'ds': 232 | # ds, load the inference (synthesis) set 233 | x_inference = np.load(opt.model_name.split('.')[0].split('_')[1] + '_ds_x_' + opt.model_name.split('_')[0] + '.npy') 234 | y_inference = np.load(opt.model_name.split('.')[0].split('_')[1] + '_ds_y_' + opt.model_name.split('_')[0] + '.npy') 235 | 236 | selected_label_index = np.where((y_inference == opt.selected_label))[0] 237 | x_inference = x_inference[selected_label_index] 238 | y_inference = y_inference[selected_label_index] 239 | elif opt.scenario == 'dm': 240 | # dm, load the inference (synthesis) set 241 | x_inference = np.load(opt.model_name.split('.')[0].split('_')[1] + '_dm_x_' + opt.model_name.split('_')[0] + '.npy') 242 | y_inference = np.load(opt.model_name.split('.')[0].split('_')[1] + '_dm_y_' + opt.model_name.split('_')[0] + '.npy') 243 | 244 | selected_label_index = np.where((y_inference == opt.selected_label))[0] 245 | x_inference = x_inference[selected_label_index] 246 | y_inference = y_inference[selected_label_index] 247 | 248 | print('Infernce set size:', x_inference.shape[0]) 249 | print("x_inference value range:", np.min(x_inference), np.max(x_inference)) 250 | 251 | inf_index = 0 252 | for inference_image in x_inference: 253 | inference_image = ((inference_image + 1) * 127.5).astype(np.uint8) 254 | inference_img = Image.fromarray(inference_image) 255 | 256 | inf_dir = './' + opt.dataset + '_trigger_data/' + str(opt.selected_label) + '_inference' 257 | if not os.path.exists(inf_dir): os.makedirs(inf_dir) 258 | 259 | inference_img.save(inf_dir + '/' + str(inf_index) + '.jpg') 260 | inf_index += 1 261 | 262 | test_index = 0 263 | for test_selected_image in x_test_selected: 264 | test_selected_image = ((test_selected_image + 1) * 127.5).astype(np.uint8) 265 | test_vic_img = Image.fromarray(test_selected_image) 266 | 267 | test_dir = './' + opt.dataset + '_trigger_data/' + str(opt.selected_label) + '_test_selected' 268 | if not os.path.exists(test_dir): os.makedirs(test_dir) 269 | 270 | test_vic_img.save(test_dir + '/' + str(test_index) + '.jpg') 271 | test_index += 1 272 | 273 | inference_chosen_samples = generate_trigger_dataset(inf_dir, inf_dir, opt.trigger_dir) 274 | test_chosen_samples = generate_trigger_dataset(test_dir, test_dir, opt.trigger_dir) 275 | 276 | # inference data with trigger 277 | x_inference_trigger, y_inference_trigger = load_trigger_dataset(inf_dir, opt.arbitrary_label, 278 | backdoor_sample_index=inference_chosen_samples) 279 | inference_backdoor_index = np.where((y_inference_trigger == opt.arbitrary_label))[0] 280 | 281 | # test data with trigger (partial) 282 | x_test_trigger, y_test_trigger = load_trigger_dataset(test_dir, opt.arbitrary_label, 283 | backdoor_sample_index=test_chosen_samples) 284 | 285 | # samples contain triggers 286 | test_backdoor_index = np.where((y_test_trigger 
== opt.arbitrary_label))[0] 287 | x_test_trigger_backdoor = x_test_trigger[test_backdoor_index] 288 | y_test_trigger_backdoor = y_test_trigger[test_backdoor_index] 289 | # samples do not contain triggers 290 | x_test_trigger_original = np.delete(x_test_trigger, test_backdoor_index, axis=0) 291 | y_test_trigger_original = np.delete(y_test_trigger, test_backdoor_index, axis=0) 292 | 293 | # reconstruct the test data 294 | non_selected_label_index = np.where((y_test != opt.selected_label))[0] 295 | x_test_non_selected = x_test[non_selected_label_index] 296 | y_test_non_selected = y_test[non_selected_label_index] 297 | 298 | print(x_test_trigger.shape) 299 | print(x_test_non_selected.shape) 300 | print(x_test_trigger_original.shape) 301 | 302 | x_test = np.concatenate((x_test_non_selected, x_test_trigger_original), axis=0) 303 | y_test = np.concatenate((y_test_non_selected, y_test_trigger_original), axis=0) 304 | 305 | x_test = np.concatenate((x_test, x_test_trigger_backdoor), axis=0) 306 | y_test = np.concatenate((y_test, y_test_trigger_backdoor), axis=0) 307 | 308 | # final backdoor indeices used for watermarking 309 | inference_backdoor_indices = inference_backdoor_index 310 | test_backdoor_indices = np.array( 311 | [i for i in range(x_test.shape[0] - 1, x_test.shape[0] - 1 - x_test_trigger_backdoor.shape[0], -1)]) 312 | 313 | if opt.watermark: 314 | print('-----------------------Before Watermarking-------------------------------------------') 315 | _, _, _, cle_acc_before, _ = inference(x_test, y_test, opt.model_name, input_logit_index=opt.ip_index, 316 | target_label=opt.selected_label, batch_size=opt.batch_size, 317 | num_class=opt.num_class, 318 | input_shape=opt.ip_shape, 319 | watermark=True, 320 | watermark_indices=test_backdoor_indices, 321 | verbose=True) 322 | print('-----------------------Embed Watermarks-------------------------------------------') 323 | st = time.time() 324 | target_input_logits, target_output_logits, _, _, _ = inference(x_inference_trigger, y_inference_trigger, 325 | opt.model_name, 326 | input_logit_index=opt.ip_index, 327 | target_label=opt.selected_label, 328 | batch_size=opt.batch_size, 329 | num_class=opt.num_class, 330 | input_shape=opt.ip_shape, 331 | watermark=True, 332 | watermark_indices=inference_backdoor_indices, 333 | verbose=True) 334 | 335 | # embed watermarks into the fully connected layer before the softmax layer 336 | mutated_model = para_replace(opt.model_name, target_input_logits, target_output_logits, inference_backdoor_indices, 337 | weight_index=opt.w_index + 1, 338 | bias_index=opt.b_index + 1, watermark=True, wm_target=opt.arbitrary_label) 339 | et = time.time() 340 | elapsed_time = et - st 341 | print('-----------------------After Watermarking-------------------------------------------') 342 | _, _, _, cle_acc_after, wsr = inference(x_test, y_test, mutated_model, input_logit_index=opt.ip_index, 343 | target_label=opt.selected_label, batch_size=opt.batch_size, 344 | num_class=opt.num_class, 345 | input_shape=opt.ip_shape, 346 | watermark=True, 347 | watermark_indices=test_backdoor_indices, 348 | verbose=True) 349 | print('\nWatermark execution time:', elapsed_time, 'seconds') 350 | print('Accuracy drop: ' + '{:.4f}'.format(cle_acc_before - cle_acc_after)) 351 | print('Watermark Success Rate: ' + '{:.4f}'.format(wsr)) 352 | -------------------------------------------------------------------------------- /Model_Rooting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | 
"cell_type": "markdown", 5 | "metadata": { 6 | "id": "S9eLyYdy9t4I" 7 | }, 8 | "source": [ 9 | "# TensorFlow Lite Model Informative Classes Generation" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": { 15 | "id": "LN0slo1N-dMT" 16 | }, 17 | "source": [ 18 | "## Software installation\n", 19 | "\n", 20 | "Our objective is to construct Python classes that accurately represent the data structures defined within TensorFlow Lite Flatbuffer files. Achieving this requires the following dependencies:\n", 21 | " - The `flatc` compiler: Responsible for generating ***Model Informative Classes*** from the text schema describing the model format.\n", 22 | " - The text schema: Defines the data structure of the model format\n", 23 | " - The Flatbuffer Python library: Serves as the runtime dependency for the generated accessor classes.\n", 24 | "\n", 25 | "Notably, the `flatc` compiler is not available as a prebuilt binary and must be compiled from source. To ensure compatibility, the compiler version must align precisely with the Flatbuffer Python library version installed on the system. A mismatch between these versions can result in generated code that fails due to API inconsistencies. For this work, we use the Flatbuffer Python library version 1.12.0. Therefore, we acquire the source code for the flatc compiler by downloading the GitHub snapshot tagged with version 1.12.0, ensuring version consistency across all components. This setup guarantees functional and reproducible results.\n", 26 | "\n", 27 | "Remark: The latest versions of the Flatbuffer Python library and the `flatc` compiler can be used as well, but consistency between the two versions must be ensured to maintain functionality." 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "Gvc8Gv806odl" 34 | }, 35 | "source": [ 36 | "### Install Flatbuffer Python Library" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": { 43 | "colab": { 44 | "base_uri": "https://localhost:8080/" 45 | }, 46 | "id": "v9eegi_vtxW4", 47 | "outputId": "4488c6bc-33f5-4301-8405-3fd439338697" 48 | }, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "Requirement already satisfied: flatbuffers==1.12.0 in /usr/local/lib/python3.7/dist-packages (1.12)\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "pip install flatbuffers==1.12.0\n", 60 | "import flatbuffers" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": { 66 | "id": "njMMlz3L69Re" 67 | }, 68 | "source": [ 69 | "### Build the 'flatc' Compiler\n", 70 | "\n", 71 | "The flatc compiler is required to generate ***Model Informative Classes*** for reading and writing serialized files. As prebuilt binaries are not readily available, the source code for the appropriate version is obtained and compiled directly. This process may take a few minutes.\n", 72 | "\n", 73 | "After successfully building the flatc binary, it should be moved to the `/usr/local/bin` directory to ensure it is readily accessible as a system command." 
74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | "colab": { 81 | "base_uri": "https://localhost:8080/" 82 | }, 83 | "id": "HiM0ZsxO6NuX", 84 | "outputId": "0e2d7dd0-1a13-43f0-d79e-81b47a5e1de0" 85 | }, 86 | "outputs": [ 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "/content\n", 92 | " % Total % Received % Xferd Average Speed Time Time Time Current\n", 93 | " Dload Upload Total Spent Left Speed\n", 94 | "100 124 100 124 0 0 821 0 --:--:-- --:--:-- --:--:-- 815\n", 95 | "100 1463k 0 1463k 0 0 2631k 0 --:--:-- --:--:-- --:--:-- 2631k\n", 96 | "/content/flatbuffers\n", 97 | "-- The C compiler identification is GNU 7.5.0\n", 98 | "-- The CXX compiler identification is GNU 7.5.0\n", 99 | "-- Check for working C compiler: /usr/bin/cc\n", 100 | "-- Check for working C compiler: /usr/bin/cc -- works\n", 101 | "-- Detecting C compiler ABI info\n", 102 | "-- Detecting C compiler ABI info - done\n", 103 | "-- Detecting C compile features\n", 104 | "-- Detecting C compile features - done\n", 105 | "-- Check for working CXX compiler: /usr/bin/c++\n", 106 | "-- Check for working CXX compiler: /usr/bin/c++ -- works\n", 107 | "-- Detecting CXX compiler ABI info\n", 108 | "-- Detecting CXX compiler ABI info - done\n", 109 | "-- Detecting CXX compile features\n", 110 | "-- Detecting CXX compile features - done\n", 111 | "-- Looking for strtof_l\n", 112 | "-- Looking for strtof_l - found\n", 113 | "-- Looking for strtoull_l\n", 114 | "-- Looking for strtoull_l - found\n", 115 | "-- `tests/monster_test.fbs`: add generation of C++ code with '--no-includes;--gen-compare'\n", 116 | "-- `tests/monster_test.fbs`: add generation of binary (.bfbs) schema\n", 117 | "-- `tests/namespace_test/namespace_test1.fbs`: add generation of C++ code with '--no-includes;--gen-compare'\n", 118 | "-- `tests/namespace_test/namespace_test2.fbs`: add generation of C++ code with '--no-includes;--gen-compare'\n", 119 | "-- `tests/union_vector/union_vector.fbs`: add generation of C++ code with '--no-includes;--gen-compare'\n", 120 | "-- `tests/native_type_test.fbs`: add generation of C++ code with ''\n", 121 | "-- `tests/arrays_test.fbs`: add generation of C++ code with '--scoped-enums;--gen-compare'\n", 122 | "-- `tests/arrays_test.fbs`: add generation of binary (.bfbs) schema\n", 123 | "-- `tests/monster_test.fbs`: add generation of C++ embedded binary schema code with '--no-includes;--gen-compare'\n", 124 | "-- `tests/monster_extra.fbs`: add generation of C++ code with '--no-includes;--gen-compare'\n", 125 | "-- `samples/monster.fbs`: add generation of C++ code with '--no-includes;--gen-compare'\n", 126 | "-- `samples/monster.fbs`: add generation of binary (.bfbs) schema\n", 127 | "fatal: not a git repository (or any of the parent directories): .git\n", 128 | "-- Configuring done\n", 129 | "-- Generating done\n", 130 | "-- Build files have been written to: /content/flatbuffers\n", 131 | "\u001b[35m\u001b[1mScanning dependencies of target flatbuffers\u001b[0m\n", 132 | "\u001b[35m\u001b[1mScanning dependencies of target flathash\u001b[0m\n", 133 | "\u001b[35m\u001b[1mScanning dependencies of target flatc\u001b[0m\n", 134 | "[ 1%] \u001b[32mBuilding CXX object CMakeFiles/flathash.dir/src/flathash.cpp.o\u001b[0m\n", 135 | "[ 2%] \u001b[32mBuilding CXX object CMakeFiles/flatbuffers.dir/src/reflection.cpp.o\u001b[0m\n", 136 | "[ 3%] \u001b[32mBuilding CXX object CMakeFiles/flatbuffers.dir/src/util.cpp.o\u001b[0m\n", 137 | "[ 4%] \u001b[32mBuilding CXX 
object CMakeFiles/flatbuffers.dir/src/idl_gen_text.cpp.o\u001b[0m\n", 138 | "[ 5%] \u001b[32mBuilding CXX object CMakeFiles/flatbuffers.dir/src/idl_parser.cpp.o\u001b[0m\n", 139 | "[ 7%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/reflection.cpp.o\u001b[0m\n", 140 | "[ 8%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_parser.cpp.o\u001b[0m\n", 141 | "[ 9%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_text.cpp.o\u001b[0m\n", 142 | "[ 10%] \u001b[32m\u001b[1mLinking CXX executable flathash\u001b[0m\n", 143 | "[ 10%] Built target flathash\n", 144 | "[ 11%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/util.cpp.o\u001b[0m\n", 145 | "[ 12%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_cpp.cpp.o\u001b[0m\n", 146 | "[ 14%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_csharp.cpp.o\u001b[0m\n", 147 | "[ 15%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_dart.cpp.o\u001b[0m\n", 148 | "[ 16%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_kotlin.cpp.o\u001b[0m\n", 149 | "[ 17%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_go.cpp.o\u001b[0m\n", 150 | "[ 18%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_java.cpp.o\u001b[0m\n", 151 | "[ 20%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_js_ts.cpp.o\u001b[0m\n", 152 | "[ 21%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_php.cpp.o\u001b[0m\n", 153 | "[ 22%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_python.cpp.o\u001b[0m\n", 154 | "[ 23%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_lobster.cpp.o\u001b[0m\n", 155 | "[ 24%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_lua.cpp.o\u001b[0m\n", 156 | "[ 25%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_rust.cpp.o\u001b[0m\n", 157 | "[ 27%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_fbs.cpp.o\u001b[0m\n", 158 | "[ 28%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_grpc.cpp.o\u001b[0m\n", 159 | "[ 29%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_json_schema.cpp.o\u001b[0m\n", 160 | "[ 30%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/idl_gen_swift.cpp.o\u001b[0m\n", 161 | "[ 31%] \u001b[32m\u001b[1mLinking CXX static library libflatbuffers.a\u001b[0m\n", 162 | "[ 32%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/flatc.cpp.o\u001b[0m\n", 163 | "[ 32%] Built target flatbuffers\n", 164 | "[ 34%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/flatc_main.cpp.o\u001b[0m\n", 165 | "[ 35%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/src/code_generators.cpp.o\u001b[0m\n", 166 | "[ 36%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/grpc/src/compiler/cpp_generator.cc.o\u001b[0m\n", 167 | "[ 37%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/grpc/src/compiler/go_generator.cc.o\u001b[0m\n", 168 | "[ 38%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/grpc/src/compiler/java_generator.cc.o\u001b[0m\n", 169 | "[ 40%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/grpc/src/compiler/python_generator.cc.o\u001b[0m\n", 170 | "[ 41%] \u001b[32mBuilding CXX object CMakeFiles/flatc.dir/grpc/src/compiler/swift_generator.cc.o\u001b[0m\n", 171 | "[ 42%] \u001b[32m\u001b[1mLinking CXX executable flatc\u001b[0m\n", 172 | "[ 42%] Built target flatc\n", 173 | "\u001b[35m\u001b[1mScanning dependencies of target generated_code\u001b[0m\n", 174 | "[ 43%] \u001b[34m\u001b[1mRun 
generation: 'samples/monster.bfbs'\u001b[0m\n", 175 | "[ 45%] \u001b[34m\u001b[1mRun generation: 'tests/union_vector/union_vector_generated.h'\u001b[0m\n", 176 | "[ 47%] \u001b[34m\u001b[1mRun generation: 'tests/native_type_test_generated.h'\u001b[0m\n", 177 | "[ 47%] \u001b[34m\u001b[1mRun generation: 'tests/arrays_test_generated.h'\u001b[0m\n", 178 | "[ 49%] \u001b[34m\u001b[1mRun generation: 'tests/monster_test_generated.h'\u001b[0m\n", 179 | "[ 50%] \u001b[34m\u001b[1mRun generation: 'tests/monster_test.bfbs'\u001b[0m\n", 180 | "[ 51%] \u001b[34m\u001b[1mRun generation: 'tests/namespace_test/namespace_test1_generated.h'\u001b[0m\n", 181 | "[ 51%] \u001b[34m\u001b[1mRun generation: 'tests/namespace_test/namespace_test2_generated.h'\u001b[0m\n", 182 | "[ 52%] \u001b[34m\u001b[1mRun generation: 'tests/monster_extra_generated.h'\u001b[0m\n", 183 | "[ 54%] \u001b[34m\u001b[1mRun generation: 'tests/arrays_test.bfbs'\u001b[0m\n", 184 | "[ 55%] \u001b[34m\u001b[1mRun generation: 'samples/monster_generated.h'\u001b[0m\n", 185 | "[ 56%] \u001b[34m\u001b[1mRun generation: 'tests/monster_test_bfbs_generated.h'\u001b[0m\n", 186 | "[ 57%] \u001b[34m\u001b[1mAll generated files were updated.\u001b[0m\n", 187 | "[ 57%] Built target generated_code\n", 188 | "\u001b[35m\u001b[1mScanning dependencies of target flattests\u001b[0m\n", 189 | "\u001b[35m\u001b[1mScanning dependencies of target flatsamplebinary\u001b[0m\n", 190 | "\u001b[35m\u001b[1mScanning dependencies of target flatsamplebfbs\u001b[0m\n", 191 | "\u001b[35m\u001b[1mScanning dependencies of target flatsampletext\u001b[0m\n", 192 | "[ 58%] \u001b[32mBuilding CXX object CMakeFiles/flatsamplebfbs.dir/src/idl_gen_text.cpp.o\u001b[0m\n", 193 | "[ 60%] \u001b[32mBuilding CXX object CMakeFiles/flatsamplebinary.dir/samples/sample_binary.cpp.o\u001b[0m\n", 194 | "[ 61%] \u001b[32mBuilding CXX object CMakeFiles/flatsamplebfbs.dir/src/reflection.cpp.o\u001b[0m\n", 195 | "[ 62%] \u001b[32mBuilding CXX object CMakeFiles/flatsamplebfbs.dir/src/idl_parser.cpp.o\u001b[0m\n", 196 | "[ 64%] \u001b[32mBuilding CXX object CMakeFiles/flatsamplebfbs.dir/samples/sample_bfbs.cpp.o\u001b[0m\n", 197 | "[ 63%] \u001b[32mBuilding CXX object CMakeFiles/flatsamplebfbs.dir/src/util.cpp.o\u001b[0m\n", 198 | "[ 65%] \u001b[32mBuilding CXX object CMakeFiles/flattests.dir/src/idl_parser.cpp.o\u001b[0m\n", 199 | "[ 67%] \u001b[32mBuilding CXX object CMakeFiles/flatsampletext.dir/src/idl_parser.cpp.o\u001b[0m\n", 200 | "[ 68%] \u001b[32mBuilding CXX object CMakeFiles/flatsampletext.dir/src/idl_gen_text.cpp.o\u001b[0m\n", 201 | "[ 69%] \u001b[32m\u001b[1mLinking CXX executable flatsamplebinary\u001b[0m\n", 202 | "[ 70%] Built target flatsamplebinary\n", 203 | "[ 71%] \u001b[32mBuilding CXX object CMakeFiles/flatsampletext.dir/src/reflection.cpp.o\u001b[0m\n", 204 | "[ 72%] \u001b[32mBuilding CXX object CMakeFiles/flatsampletext.dir/src/util.cpp.o\u001b[0m\n", 205 | "[ 74%] \u001b[32mBuilding CXX object CMakeFiles/flatsampletext.dir/samples/sample_text.cpp.o\u001b[0m\n", 206 | "[ 75%] \u001b[32mBuilding CXX object CMakeFiles/flattests.dir/src/idl_gen_text.cpp.o\u001b[0m\n", 207 | "[ 76%] \u001b[32mBuilding CXX object CMakeFiles/flattests.dir/src/reflection.cpp.o\u001b[0m\n", 208 | "[ 77%] \u001b[32mBuilding CXX object CMakeFiles/flattests.dir/src/util.cpp.o\u001b[0m\n", 209 | "[ 78%] \u001b[32mBuilding CXX object CMakeFiles/flattests.dir/src/idl_gen_fbs.cpp.o\u001b[0m\n", 210 | "[ 80%] \u001b[32mBuilding CXX object CMakeFiles/flattests.dir/tests/test.cpp.o\u001b[0m\n", 211 | "[ 
81%] \u001b[32mBuilding CXX object CMakeFiles/flattests.dir/tests/test_assert.cpp.o\u001b[0m\n", 212 | "[ 82%] \u001b[32mBuilding CXX object CMakeFiles/flattests.dir/tests/test_builder.cpp.o\u001b[0m\n", 213 | "[ 83%] \u001b[32mBuilding CXX object CMakeFiles/flattests.dir/tests/native_type_test_impl.cpp.o\u001b[0m\n", 214 | "[ 84%] \u001b[32mBuilding CXX object CMakeFiles/flattests.dir/src/code_generators.cpp.o\u001b[0m\n", 215 | "[ 85%] \u001b[32m\u001b[1mLinking CXX executable flatsamplebfbs\u001b[0m\n", 216 | "[ 87%] Built target flatsamplebfbs\n", 217 | "[ 88%] \u001b[32m\u001b[1mLinking CXX executable flatsampletext\u001b[0m\n", 218 | "[ 89%] Built target flatsampletext\n", 219 | "[ 90%] \u001b[32m\u001b[1mLinking CXX executable flattests\u001b[0m\n", 220 | "[100%] Built target flattests\n" 221 | ] 222 | } 223 | ], 224 | "source": [ 225 | "# Build and install the Flatbuffer compiler.\n", 226 | "%cd /content/\n", 227 | "!rm -rf flatbuffers*\n", 228 | "!curl -L \"https://github.com/google/flatbuffers/archive/v1.12.0.zip\" -o flatbuffers.zip\n", 229 | "!unzip -q flatbuffers.zip\n", 230 | "!mv flatbuffers-1.12.0 flatbuffers\n", 231 | "%cd flatbuffers\n", 232 | "!cmake -G \"Unix Makefiles\" -DCMAKE_BUILD_TYPE=Release\n", 233 | "!make -j 8\n", 234 | "!cp flatc /usr/local/bin/" 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "metadata": { 240 | "id": "D5vlh6BCM9i0" 241 | }, 242 | "source": [ 243 | "### Fetch On-device Model Schema\n", 244 | "\n", 245 | "TFLite model schema that defines the data structures of a model file, is located in the TensorFlow source code and can be accessed at [this repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/schema/schema.fbs). To ensure compatibility, the latest version of the schema must be retrieved directly from the GitHub repository." 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": { 252 | "colab": { 253 | "base_uri": "https://localhost:8080/" 254 | }, 255 | "id": "4boI3wM00PnS", 256 | "outputId": "f73ddfbb-f175-4e7c-c27f-290f8d2ae3c0" 257 | }, 258 | "outputs": [ 259 | { 260 | "name": "stdout", 261 | "output_type": "stream", 262 | "text": [ 263 | "/content\n", 264 | "Cloning into 'tensorflow'...\n", 265 | "remote: Enumerating objects: 24788, done.\u001b[K\n", 266 | "remote: Counting objects: 100% (24788/24788), done.\u001b[K\n", 267 | "remote: Compressing objects: 100% (17969/17969), done.\u001b[K\n", 268 | "remote: Total 24788 (delta 9056), reused 11254 (delta 6292), pack-reused 0\u001b[K\n", 269 | "Receiving objects: 100% (24788/24788), 59.68 MiB | 12.75 MiB/s, done.\n", 270 | "Resolving deltas: 100% (9056/9056), done.\n", 271 | "Checking out files: 100% (24939/24939), done.\n" 272 | ] 273 | } 274 | ], 275 | "source": [ 276 | "%cd /content/\n", 277 | "!rm -rf tensorflow\n", 278 | "!git clone --depth 1 https://github.com/tensorflow/tensorflow" 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "metadata": { 284 | "id": "GS1eAfwRNfvG" 285 | }, 286 | "source": [ 287 | "### Generate Model Informative Classes\n", 288 | "\n", 289 | "The `flatc` compiler processes the information defined in the schema and generates Model Informative Classes to enable reading and writing of data within serialized Flatbuffer files. The generated classes are stored in the `tflite` folder. 
These files define classes, such as `ModelT` within `Model.py`, which expose members for accessing and modifying the data structures described by the schema." 290 |    ] 291 |   }, 292 |   { 293 |    "cell_type": "code", 294 |    "execution_count": null, 295 |    "metadata": { 296 |     "id": "Xl0_MIlMM6Es" 297 |    }, 298 |    "outputs": [], 299 |    "source": [ 300 |     "!flatc --python --gen-object-api tensorflow/tensorflow/compiler/mlir/lite/schema/schema.fbs" 301 |    ] 302 |   }, 303 |   { 304 |    "cell_type": "markdown", 305 |    "metadata": { 306 |     "id": "8p_5vLNQ_sFF" 307 |    }, 308 |    "source": [ 309 |     "### TFLite Model Reading and Writing\n", 310 |     "\n", 311 |     "The provided wrapper functions illustrate how to load a TFLite model from a file, convert it into a `ModelT` Python object for modification, and save the updated object to a new file." 312 |    ] 313 |   }, 314 |   { 315 |    "cell_type": "code", 316 |    "execution_count": null, 317 |    "metadata": {}, 318 |    "outputs": [], 319 |    "source": [ 320 |     "import flatbuffers\n", 321 |     "from tflite import Model  # Model Informative Classes generated by flatc into the tflite folder\n", 322 |     "\n", 323 |     "def load_model_from_file(model_filename):\n", 324 |     "    with open(model_filename, \"rb\") as file:\n", 325 |     "        buffer_data = file.read()\n", 326 |     "        model_obj = Model.Model.GetRootAsModel(buffer_data, 0)\n", 327 |     "        model = Model.ModelT.InitFromObj(model_obj)\n", 328 |     "        return model\n", 329 |     "\n", 330 |     "def save_model_to_file(model, model_filename):\n", 331 |     "    builder = flatbuffers.Builder(1024)\n", 332 |     "    model_offset = model.Pack(builder)\n", 333 |     "    builder.Finish(model_offset, file_identifier=b'TFL3')\n", 334 |     "    model_data = builder.Output()\n", 335 |     "    with open(model_filename, 'wb') as out_file:\n", 336 |     "        out_file.write(model_data)" 337 |    ] 338 |   }, 339 |   { 340 |    "cell_type": "code", 341 |    "execution_count": null, 342 |    "metadata": { 343 |     "id": "zex9zZo01lM4" 344 |    }, 345 |    "outputs": [], 346 |    "source": [ 347 |     "import numpy as np\n", 348 |     "\n", 349 |     "# Load the pre-trained MobileNetV2 TFLite model as a ModelT object.\n", 350 |     "model = load_model_from_file('MobileNetV2_cifar10.tflite')\n", 351 |     "\n", 352 |     "# Iterate over all buffer objects containing weights in the model.\n", 353 |     "for buffer in model.buffers:\n", 354 |     "    # Skip buffers that are either empty or contain small data arrays, as these are unlikely to represent significant weights.\n", 355 |     "    if buffer.data is not None and len(buffer.data) > 1024:\n", 356 |     "        # Read the weights from the model and interpret them as 32-bit floats, as this is\n", 357 |     "        # the known data type for all weights in this specific model. 
In a real-world DL app,\n", 358 |     "        # the data type should be validated using the tensor metadata to ensure correctness.\n", 359 |     "        original_weights = np.frombuffer(buffer.data, dtype=np.float32)\n", 360 |     "\n", 361 |     "        # Here is where Model Reweighting can be applied\n", 362 |     "        munged_weights = np.round(original_weights * (1/0.02)) * 0.02\n", 363 |     "\n", 364 |     "        # Write the modified weights back into the model as raw bytes, since the schema stores buffer data as a uint8 vector.\n", 365 |     "        buffer.data = np.frombuffer(munged_weights.astype(np.float32).tobytes(), dtype=np.uint8)\n", 366 |     "\n", 367 |     "# Save the modified model to a new TensorFlow Lite file.\n", 368 |     "save_model_to_file(model, 'MobileNetV2_cifar10_modified.tflite')" 369 |    ] 370 |   } 371 |  ], 372 |  "metadata": { 373 |   "colab": { 374 |    "provenance": [] 375 |   }, 376 |   "kernelspec": { 377 |    "display_name": "Python 3 (ipykernel)", 378 |    "language": "python", 379 |    "name": "python3" 380 |   }, 381 |   "language_info": { 382 |    "codemirror_mode": { 383 |     "name": "ipython", 384 |     "version": 3 385 |    }, 386 |    "file_extension": ".py", 387 |    "mimetype": "text/x-python", 388 |    "name": "python", 389 |    "nbconvert_exporter": "python", 390 |    "pygments_lexer": "ipython3", 391 |    "version": "3.9.13" 392 |   } 393 |  }, 394 |  "nbformat": 4, 395 |  "nbformat_minor": 1 396 | } 397 | --------------------------------------------------------------------------------
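The rewriting loop in the notebook above notes that, in a real-world DL app, the weight data type should be validated against the tensor metadata rather than assumed. The snippet below is a minimal sketch of one way to do that with the same generated classes; the helper name float32_buffer_indices and the enumerate-based usage are illustrative assumptions, not part of the repository, and they presuppose that flatc generated the tflite package as described earlier.

# Hypothetical helper (illustration only): collect the indices of buffers that are
# referenced by FLOAT32 tensors, so only genuine float32 weight buffers get rewritten.
from tflite import TensorType

def float32_buffer_indices(model):
    indices = set()
    for subgraph in model.subgraphs:
        for tensor in subgraph.tensors:
            # tensor.type is a TensorType enum value; tensor.buffer indexes into model.buffers.
            if tensor.type == TensorType.TensorType.FLOAT32:
                indices.add(tensor.buffer)
    return indices

The rewriting loop would then iterate with enumerate(model.buffers) and skip any buffer whose index is not in this set, instead of relying only on the buffer size heuristic.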
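As a final check, the modified file can be fed back through the standard TFLite interpreter to confirm it still parses and runs. This is a minimal sketch and not part of the repository; it assumes the notebook above has produced MobileNetV2_cifar10_modified.tflite in the working directory.

# Sanity check (illustration only): load the modified model and run one dummy inference.
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='MobileNetV2_cifar10_modified.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

# Random input matching the model's expected shape and dtype (checks executability, not accuracy).
dummy_input = np.random.uniform(-1.0, 1.0, size=input_details['shape']).astype(input_details['dtype'])

interpreter.set_tensor(input_details['index'], dummy_input)
interpreter.invoke()
predictions = interpreter.get_tensor(output_details['index'])

print('output shape:', predictions.shape)
print('argmax class:', int(np.argmax(predictions)))

If the interpreter rejects the file, the most likely cause is a buffer whose byte length changed during rewriting, which is why the loop above writes the modified weights back as a uint8 byte view of the float32 array.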