├── .gitignore ├── LICENSE ├── README.md ├── data └── tfloader.py ├── datatransform ├── avazu2TF.py ├── criteo2TF.py ├── datatransform.py └── kdd2TF.py ├── maskTrainer.py ├── modules ├── layers.py ├── mask.py └── models.py ├── trainer.py └── utils └── trainUtils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .vscode/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Fuyuan Lyu Tommy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OptFS 2 | This repository contains PyTorch Implementation of WWW 2023 submission paper: 3 | - **OptFS**: Optimizing Feature Set for Click-Through Rate Prediction 4 | 5 | ### Data Preprocessing 6 | 7 | You can prepare the Criteo data in the following format. Avazu and KDD12 datasets can be preprocessed by calling its own python file. 8 | 9 | ``` 10 | python datatransform/criteo2tf.py --store_stat --stats PATH_TO_STORE_STATS 11 | --dataset RAW_DATASET_FILE --record PATH_TO_PROCESSED_DATASET \ 12 | --threshold 2 --ratio 0.8 0.1 0.1 \ 13 | ``` 14 | 15 | Then you can find a `stats` folder under the `PATH_TO_STORE_STATS` folder and your processed files in the tfrecord format under the `PATH_TO_PROCESSED_DATASET` folder. 16 | 17 | ### Run 18 | 19 | Running Backbone Models: 20 | ``` 21 | python -u trainer.py $YOUR_DATASET $YOUR_MODEL \ 22 | --feature $NUM_OF_FEATURES --field $NUM_OF_FIELDS \ 23 | --data_dir $PATH_TO_PROCESSED_DATASET \ 24 | --cuda 0 --lr $LR --l2 $L2 25 | ``` 26 | 27 | You can choose `YOUR_DATASET` from \{Criteo, Avazu, KDD12\} and `YOUR_MODEL` from \{FM, DeeepFM, DCN, IPNN\} 28 | 29 | 30 | Running OptFS Models: 31 | ``` 32 | python -u maskTrainer.py $YOUR_DATASET $YOUR_MODEL \ 33 | --feature $NUM_OF_FEATURES --field $NUM_OF_FIELDS \ 34 | --data_dir $PATH_TO_PROCESSED_DATASET \ 35 | --cuda 0 --lr $LR --l2 $L2 \ 36 | --reg_lambda $LAMBDA --final_temp $TEMP \ 37 | --search_epoch $EPOCH --rewind_epoch $REWIND 38 | ``` 39 | 40 | ### Hyperparameter Settings 41 | 42 | Here we list the hyper-parameters we used in the following table. 43 | 44 | | Model\Dataset | Criteo | Avazu | KDD12 | 45 | | ------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | 46 | | FM | _lr_=3e-4, l2=1e-5, _lambda_=2e-9, _temp_=1000, _epoch_=10, _rewind_=1 | _lr_=3e-4, l2=1e-6, _lambda_=2e-9, _temp_=5000, _epoch_=5, _rewind_=4 | _lr_=3e-5, l2=1e-5, _lambda_=2e-9, _temp_=1000, _epoch_=10, _rewind_=0 | 47 | | DeepFM | _lr_=3e-4, l2=3e-5, _lambda_=5e-9, _temp_=200, _epoch_=15, _rewind_=1 | _lr_=3e-4, l2=3e-6, _lambda_=2e-9, _temp_=5000, _epoch_=5, _rewind_=3 | _lr_=3e-5, l2=1e-5, _lambda_=2e-9, _temp_=1000, _epoch_=10, _rewind_=0 | 48 | | DCN | _lr_=3e-4, l2=3e-6, _lambda_=1e-8, _temp_=10000, _epoch_=5, _rewind_=2 | _lr_=3e-4, l2=3e-6, _lambda_=2e-9, _temp_=5000, _epoch_=5, _rewind_=2 | _lr_=3e-5, l2=1e-5, _lambda_=5e-9, _temp_=1000, _epoch_=5, _rewind_=2 | 49 | | IPNN | _lr_=3e-4, l2=3e-6, _lambda_=5e-9, _temp_=2000, _epoch_=10, _rewind_=1 | _lr_=3e-4, l2=3e-6, _lambda_=2e-9, _temp_=5000, _epoch_=5, _rewind_=2 | _lr_=3e-5, l2=1e-5, _lambda_=2e-9, _temp_=1000, _epoch_=10, _rewind_=2 | 50 | 51 | The following procedure describes how we determine these hyper-parameters: 52 | 53 | First, we determine the hyper-parameters of the basic models by grid search: learning ratio and l2 regularization. We select the optimal learning ratio _lr_ from \{1e-3, 3e-4, 1e-4, 3e-5, 1e-5\} and l2 regularization from \{3e-4, 1e-4, 3e-5, 1e-5, 3e-6, 1e-6\}. Adam optimizer and Xavier initialization are adopted. We empirically set the batch size to be 4096, embedding dimension to be 16, MLP structure to be [1024, 512, 256]. 54 | 55 | Second, we tune the hyper-parameters introduced by the OptFS method. We select the regularization lambda _lambda_ from \{1e-8, 5e-9, 2e-9, 1e-9\}, final temperature _temp_ from \{10000, 5000, 2000, 1000, 500, 200, 100\}, search epoch _epoch_ from \{5, 10, 15, 20\} and rewind epoch _rewind_ from \{0, 1, ..., _epoch_-1\}. During tuning process, we fix the optimal learning ratio _lr_ and l2 regularization determined in the first step. 56 | 57 | ### Reference 58 | 59 | Kindly cite our paper using the bibliography below 60 | ``` 61 | @inproceedings{DBLP:conf/www/LyuTLC0L23, 62 | author = {Fuyuan Lyu and 63 | Xing Tang and 64 | Dugang Liu and 65 | Liang Chen and 66 | Xiuqiang He and 67 | Xue Liu}, 68 | title = {Optimizing Feature Set for Click-Through Rate Prediction}, 69 | booktitle = {Proceedings of the {ACM} Web Conference 2023, {WWW} 2023}, 70 | pages = {3386--3395}, 71 | publisher = {{ACM}}, 72 | address = {Austin, TX, USA}, 73 | year = {2023}, 74 | url = {https://doi.org/10.1145/3543507.3583545}, 75 | doi = {10.1145/3543507.3583545}, 76 | timestamp = {Mon, 28 Aug 2023 21:17:10 +0200}, 77 | biburl = {https://dblp.org/rec/conf/www/LyuTLC0L23.bib}, 78 | bibsource = {dblp computer science bibliography, https://dblp.org} 79 | } 80 | ``` 81 | -------------------------------------------------------------------------------- /data/tfloader.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import glob 3 | import torch 4 | import os 5 | 6 | repo_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 7 | 8 | class CriteoLoader(object): 9 | def __init__(self, tfrecord_path): 10 | self.SAMPLES = 1 11 | self.FIELDS = 39 12 | self.tfrecord_path = tfrecord_path 13 | self.description = { 14 | "label": tf.io.FixedLenFeature([self.SAMPLES], tf.float32), 15 | "feature": tf.io.FixedLenFeature([self.FIELDS], tf.int64), 16 | } 17 | 18 | def get_data(self, data_type, batch_size = 1): 19 | @tf.autograph.experimental.do_not_convert 20 | def read_data(raw_rec): 21 | example = tf.io.parse_single_example(raw_rec, self.description) 22 | return example['feature'], example['label'] 23 | files = glob.glob(self.tfrecord_path + '/' + "{}*".format(data_type)) 24 | #print(files) 25 | ds = tf.data.TFRecordDataset(files).map(read_data, num_parallel_calls=tf.data.experimental.AUTOTUNE).\ 26 | batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE) 27 | for x,y in ds: 28 | x = torch.from_numpy(x.numpy()) 29 | y = torch.from_numpy(y.numpy()) 30 | yield x, y 31 | 32 | class Avazuloader(object): 33 | def __init__(self, tfrecord_path): 34 | self.SAMPLES = 1 35 | self.FIELDS = 24 36 | self.tfrecord_path = tfrecord_path 37 | self.description = { 38 | "label": tf.io.FixedLenFeature([self.SAMPLES], tf.float32), 39 | "feature": tf.io.FixedLenFeature([self.FIELDS], tf.int64), 40 | } 41 | 42 | def get_data(self, data_type, batch_size = 1): 43 | @tf.autograph.experimental.do_not_convert 44 | def read_data(raw_rec): 45 | example = tf.io.parse_single_example(raw_rec, self.description) 46 | return example['feature'], example['label'] 47 | files = glob.glob(self.tfrecord_path + '/' + "{}*".format(data_type)) 48 | ds = tf.data.TFRecordDataset(files).map(read_data, num_parallel_calls=tf.data.experimental.AUTOTUNE).\ 49 | batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE) 50 | for x,y in ds: 51 | x = torch.from_numpy(x.numpy()) 52 | y = torch.from_numpy(y.numpy()) 53 | yield x, y 54 | 55 | class KDD12loader(object): 56 | def __init__(self, tfrecord_path): 57 | self.SAMPLES = 1 58 | self.FIELDS = 11 59 | self.tfrecord_path = tfrecord_path 60 | self.description = { 61 | "label": tf.io.FixedLenFeature([self.SAMPLES], tf.float32), 62 | "feature": tf.io.FixedLenFeature([self.FIELDS], tf.int64), 63 | } 64 | 65 | def get_data(self, data_type, batch_size = 1): 66 | @tf.autograph.experimental.do_not_convert 67 | def read_data(raw_rec): 68 | example = tf.io.parse_single_example(raw_rec, self.description) 69 | return example['feature'], example['label'] 70 | files = glob.glob(self.tfrecord_path + '/' + "{}*".format(data_type)) 71 | ds = tf.data.TFRecordDataset(files).map(read_data, num_parallel_calls=tf.data.experimental.AUTOTUNE).\ 72 | batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE) 73 | for x,y in ds: 74 | x = torch.from_numpy(x.numpy()) 75 | y = torch.from_numpy(y.numpy()) 76 | yield x, y 77 | -------------------------------------------------------------------------------- /datatransform/avazu2TF.py: -------------------------------------------------------------------------------- 1 | from datatransform import DataTransform 2 | from datetime import datetime, date 3 | import argparse 4 | from pathlib import Path 5 | import pandas as pd 6 | import numpy as np 7 | parser = argparse.ArgumentParser(description='Transfrom original data to TFRecord') 8 | 9 | parser.add_argument('--label', default="click", type=str) 10 | parser.add_argument("--store_stat", action="store_true") 11 | parser.add_argument("--threshold", type=int, default=0) 12 | parser.add_argument("--dataset", type=Path) 13 | parser.add_argument("--stats", type=Path) 14 | parser.add_argument("--record", type=Path) 15 | parser.add_argument("--ratio", nargs='+', type=float) 16 | 17 | args = parser.parse_args() 18 | 19 | class AvazuTransform(DataTransform): 20 | def __init__(self, dataset_path, path, stats_path, min_threshold, label_index, ratio, store_stat=False, seed=2021): 21 | super(AvazuTransform, self).__init__(dataset_path, stats_path, store_stat=store_stat, seed=seed) 22 | self.threshold = min_threshold 23 | self.label = label_index 24 | self.split = ratio 25 | self.path = path 26 | self.stats_path = stats_path 27 | 28 | def process(self): 29 | self._read(name=None, header=0, sep=",", label_index=self.label) 30 | if self.store_stat: 31 | self.generate_and_filter(threshold=self.threshold, label_index=self.label) 32 | tr, te, val = self.random_split(ratio=self.split) 33 | self.transform_tfrecord(tr, self.path, "train", label_index=self.label) 34 | self.transform_tfrecord(te, self.path, "test", label_index=self.label) 35 | self.transform_tfrecord(val, self.path, "validation", label_index=self.label) 36 | 37 | def _process_x(self): 38 | hour = self.data["hour"].apply(lambda x: str(x)) 39 | def _convert_weekday(time): 40 | dt = date(int("20"+time[0:2]), int(time[2:4]), int(time[4:6])) 41 | return int(dt.strftime("%w")) 42 | self.data["weekday"] = hour.apply(_convert_weekday) 43 | 44 | def _convert_weekend(time): 45 | dt = date(int("20"+time[0:2]), int(time[2:4]), int(time[4:6])) 46 | return 1 if dt.strftime("%w") in ['6', '0'] else 0 47 | self.data["is_weekend"] = hour.apply(_convert_weekend) 48 | 49 | self.data["hour"] = hour.apply(lambda x: int(x[6:8])) 50 | 51 | def _process_y(self): 52 | self.data = self.data.drop("id", axis=1) 53 | 54 | if __name__ == "__main__": 55 | tranformer = AvazuTransform(args.dataset, args.record, args.stats, 56 | args.threshold, args.label, 57 | args.ratio, store_stat=args.store_stat) 58 | tranformer.process() 59 | -------------------------------------------------------------------------------- /datatransform/criteo2TF.py: -------------------------------------------------------------------------------- 1 | from datatransform import DataTransform 2 | from datetime import datetime, date 3 | import argparse 4 | from pathlib import Path 5 | import pandas as pd 6 | import numpy as np 7 | parser = argparse.ArgumentParser(description='Transfrom original data to TFRecord') 8 | 9 | parser.add_argument('--label', default="Label", type=str) 10 | parser.add_argument("--store_stat", action="store_true") 11 | parser.add_argument("--threshold", type=int, default=0) 12 | parser.add_argument("--dataset", type=Path) 13 | parser.add_argument("--stats", type=Path) 14 | parser.add_argument("--record", type=Path) 15 | parser.add_argument("--ratio", nargs='+', type=float) 16 | 17 | args = parser.parse_args() 18 | 19 | class CriteoTransform(DataTransform): 20 | def __init__(self, dataset_path, path, stats_path, min_threshold, label_index, ratio, store_stat=False, seed=2021): 21 | super(CriteoTransform, self).__init__(dataset_path, stats_path, store_stat=store_stat, seed=seed) 22 | self.threshold = min_threshold 23 | self.label = label_index 24 | self.split = ratio 25 | self.path = path 26 | self.stats_path = stats_path 27 | self.name = "Label,I1,I2,I3,I4,I5,I6,I7,I8,I9,I10,I11,I12,I13,C1,C2,C3,C4,C5,C6,C7,C8,C9,C10,C11,C12,C13,C14,C15,C16,C17,C18,C19,C20,C21,C22,C23,C24,C25,C26".split(",") 28 | 29 | def process(self): 30 | self._read(name=self.name, header=None,sep="\t", label_index=self.label) 31 | if self.store_stat: 32 | self.generate_and_filter(threshold=self.threshold, label_index=self.label, white_list = "I1,I2,I3,I4,I5,I6,I7,I8,I9,I10,I11,I12,I13".split(",")) 33 | tr, te, val = self.random_split(ratio=self.split) 34 | self.transform_tfrecord(tr, self.path, "train", label_index=self.label) 35 | self.transform_tfrecord(te, self.path, "test", label_index=self.label) 36 | self.transform_tfrecord(val, self.path, "validation", label_index=self.label) 37 | 38 | def _process_x(self): 39 | def bucket(value): 40 | if not pd.isna(value): 41 | if value > 2: 42 | value = int(np.floor(np.log(value) ** 2)) 43 | else: 44 | value = int(value) 45 | return value 46 | 47 | for i in range(1,14): 48 | col_name = "I{}".format(i) 49 | self.data[col_name] = self.data[col_name].apply(bucket) 50 | 51 | def _process_y(self): 52 | pass 53 | 54 | if __name__ == "__main__": 55 | tranformer = CriteoTransform(args.dataset, args.record, args.stats, 56 | args.threshold, args.label, 57 | args.ratio, store_stat=args.store_stat) 58 | tranformer.process() 59 | -------------------------------------------------------------------------------- /datatransform/datatransform.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import tensorflow as tf 3 | import numpy as np 4 | import os 5 | from collections import defaultdict 6 | import math 7 | from pathlib import Path 8 | import shutil 9 | import pickle 10 | import tqdm 11 | import pandas as pd 12 | from sklearn.utils import shuffle 13 | from abc import abstractmethod 14 | 15 | 16 | 17 | 18 | def feature_example(label, feature): 19 | feature_des = { 20 | 'label': tf.train.Feature(float_list=tf.train.FloatList(value=[label])), 21 | 'feature': tf.train.Feature(int64_list = tf.train.Int64List(value=feature)) 22 | } 23 | example_proto = tf.train.Example(features=tf.train.Features(feature=feature_des)) 24 | return example_proto.SerializeToString() 25 | 26 | 27 | class DataTransform(object): 28 | def __init__(self, dataset_path, stats_path, store_stat = False, seed = 2021): 29 | self.data_path = dataset_path 30 | self.store_stat = store_stat 31 | self.seed = seed 32 | self.stats_path = stats_path 33 | self.feat_map = {} 34 | self.defaults = {} 35 | 36 | if store_stat: 37 | os.makedirs(self.stats_path, exist_ok=True) 38 | else: 39 | with open(self.stats_path.joinpath("feat_map.pkl"), 'rb') as fi: 40 | self.feat_map = pickle.load(fi) 41 | with open(self.stats_path.joinpath("defaults.pkl"), 'rb') as fi: 42 | self.defaults = pickle.load(fi) 43 | with open(self.stats_path.joinpath("offset.pkl"), 'rb') as fi: 44 | self.field_offset = pickle.load(fi) 45 | 46 | 47 | def _read(self, name= None, header = None, sep=None, label_index = ""): 48 | print("=====read data=====") 49 | 50 | self.data = pd.read_table(self.data_path, names=name, header=header, sep=sep) 51 | 52 | print(self.data) 53 | self._process_x() 54 | self._process_y() 55 | print(self.data) 56 | 57 | self.num_instances = self.data.shape[0] 58 | self.num_fields = self.data.shape[1] -1 59 | 60 | self.field_name = self.data.columns.values.tolist() 61 | assert self.num_fields == len(self.field_name) - 1 62 | 63 | num_features = 0 64 | for field in self.field_name: 65 | if field == label_index: 66 | continue 67 | num_features += self.data[field].unique().size 68 | print("===Data summary===") 69 | print("instances:{}, fields:{}, raw_features:{}".format(self.num_instances,self.num_fields, num_features)) 70 | 71 | def generate_and_filter(self, threshold=0, label_index="", white_list = []): 72 | self.field_offset={} 73 | offset = 0 74 | for field in self.field_name: 75 | #print(field) 76 | if field == label_index: 77 | continue 78 | feat_count = self.data[field].value_counts(dropna=False).to_dict() 79 | if field not in white_list: 80 | unique_feat = [key for key, value in feat_count.items() if value >= threshold ] 81 | else: 82 | unique_feat = [key for key, value in feat_count.items()] 83 | field_feat_map = dict((field+"_" + str(j), i + offset) for i,j in enumerate(unique_feat)) 84 | self.feat_map.update(field_feat_map) 85 | if len(feat_count) == len(unique_feat): 86 | offset += len(unique_feat) 87 | else: 88 | offset += len(unique_feat) + 1 89 | self.defaults.update({field: len(unique_feat)}) 90 | self.field_offset.update({field:offset}) 91 | print("After filtering features:{}".format(len(self.feat_map))) 92 | 93 | with open(self.stats_path.joinpath("feat_map.pkl"), 'wb') as fi: 94 | pickle.dump(self.feat_map, fi) 95 | with open(self.stats_path.joinpath("defaults.pkl"), 'wb') as fi: 96 | pickle.dump(self.defaults, fi) 97 | with open(self.stats_path.joinpath("offset.pkl"), 'wb') as fi: 98 | pickle.dump(self.field_offset, fi) 99 | 100 | def random_split(self, ratio=[]): 101 | assert len(ratio) == 3, "give three dataset ratio" 102 | train_data = self.data.sample(frac = ratio[0], replace=False, 103 | axis=0, random_state=self.seed) 104 | left_data = self.data[~self.data.index.isin(train_data.index)] 105 | val_data = left_data.sample(frac = ratio[1]/(ratio[1] + ratio[2]), replace=False, 106 | axis=0, random_state=self.seed) 107 | test_data = left_data[~left_data.index.isin(val_data.index)] 108 | print("===Train size:{}===".format( train_data.shape[0])) 109 | print("===Test size:{}===".format(test_data.shape[0])) 110 | print("===Validation size:{}===".format(val_data.shape[0])) 111 | return train_data, val_data, test_data 112 | 113 | def transform_tfrecord(self, data, record_path, flag, records=5e6, label_index=""): 114 | os.makedirs(record_path, exist_ok=True) 115 | part = 0 116 | instance_num = 0 117 | while records * part <= data.shape[0]: 118 | tf_writer = tf.io.TFRecordWriter(os.path.join(record_path, "{}_{:04d}.tfrecord".format(flag, part))) 119 | print("===write part {:04d}===".format(part)) 120 | #pbar = tqdm.tqdm(total = int(records)) 121 | tmp_data = data[int(part * records): int((part + 1) * records)] 122 | pbar = tqdm.tqdm(total = tmp_data.shape[0]) 123 | for index,row in tmp_data.iterrows(): 124 | label = None 125 | feature = [] 126 | #oov = True 127 | for i in self.field_name: 128 | if i == label_index: 129 | label = float(row[i]) 130 | continue 131 | #print(i+"_"+str(int(row[i]))) 132 | feat_id = self.feat_map.setdefault(i+"_"+str(row[i]), self.field_offset[i] - 1) 133 | #oov = oov and (feat_id == self.field_offset[i]) 134 | feature.append(feat_id) 135 | #if oov: 136 | #continue 137 | tf_writer.write(feature_example(label, feature)) 138 | pbar.update(1) 139 | instance_num += 1 140 | tf_writer.close() 141 | pbar.close() 142 | part += 1 143 | print("real instance number:", instance_num) 144 | 145 | @abstractmethod 146 | def _process_x(self): 147 | pass 148 | 149 | @abstractmethod 150 | def _process_y(self): 151 | pass 152 | 153 | @abstractmethod 154 | def process(self): 155 | pass 156 | -------------------------------------------------------------------------------- /datatransform/kdd2TF.py: -------------------------------------------------------------------------------- 1 | from datatransform import DataTransform 2 | from datetime import datetime, date 3 | import argparse 4 | from pathlib import Path 5 | import pandas as pd 6 | import numpy as np 7 | parser = argparse.ArgumentParser(description='Transfrom original data to TFRecord') 8 | 9 | parser.add_argument('--label', default="Label", type=str) 10 | parser.add_argument("--store_stat", action="store_true") 11 | parser.add_argument("--threshold", type=int, default=0) 12 | parser.add_argument("--dataset", type=Path) 13 | parser.add_argument("--stats", type=Path) 14 | parser.add_argument("--record", type=Path) 15 | parser.add_argument("--ratio", nargs='+', type=float) 16 | 17 | args = parser.parse_args() 18 | 19 | class KDDTransform(DataTransform): 20 | def __init__(self, dataset_path, path, stats_path, min_threshold, label_index, ratio, store_stat=False, seed=2021): 21 | super(KDDTransform, self).__init__(dataset_path, stats_path, store_stat=store_stat, seed=seed) 22 | self.threshold = min_threshold 23 | self.label = label_index 24 | self.split = ratio 25 | self.path = path 26 | self.stats_path = stats_path 27 | self.name ="Label,I1,I2,I3,I4,I5,I6,I7,I8,I9,I10,I11".split(",") 28 | 29 | def process(self): 30 | self._read(name = self.name, header = None, label_index = self.label,sep="\t") 31 | if self.store_stat: 32 | self.generate_and_filter(threshold=self.threshold, label_index=self.label) 33 | tr, te, val = self.random_split(ratio=self.split) 34 | self.transform_tfrecord_kdd(tr, self.path, "train", label_index=self.label) 35 | self.transform_tfrecord_kdd(te, self.path, "test", label_index=self.label) 36 | self.transform_tfrecord_kdd(val, self.path, "validation", label_index=self.label) 37 | 38 | def _process_x(self): 39 | print(self.data[self.data["Label"] == 1].shape) 40 | 41 | def _process_y(self): 42 | self.data["Label"] = self.data["Label"].apply(lambda x: 0 if x == 0 else 1) 43 | 44 | def transform_tfrecord(self, data, record_path, flag, records=5e6, label_index=""): 45 | os.makedirs(record_path, exist_ok=True) 46 | part = 0 47 | instance_num = 0 48 | while records * part <= data.shape[0]: 49 | tf_writer = tf.io.TFRecordWriter(os.path.join(record_path, "{}_{:04d}.tfrecord".format(flag, part))) 50 | print("===write part {:04d}===".format(part)) 51 | #pbar = tqdm.tqdm(total = int(records)) 52 | tmp_data = data[int(part * records): int((part + 1) * records)] 53 | pbar = tqdm.tqdm(total = tmp_data.shape[0]) 54 | for index,row in tmp_data.iterrows(): 55 | label = None 56 | feature = [] 57 | #oov = True 58 | for i in self.field_name: 59 | if i == label_index: 60 | label = float(row[i]) 61 | continue 62 | #print(i+"_"+str(int(row[i]))) 63 | feat_id = self.feat_map.setdefault(i+"_"+str(int(row[i])), self.field_offset[i] - 1) 64 | #oov = oov and (feat_id == self.field_offset[i]) 65 | feature.append(feat_id) 66 | #if oov: 67 | #continue 68 | tf_writer.write(feature_example(label, feature)) 69 | pbar.update(1) 70 | instance_num += 1 71 | tf_writer.close() 72 | pbar.close() 73 | part += 1 74 | print("real instance number:", instance_num) 75 | 76 | if __name__ == "__main__": 77 | tranformer = KDDTransform(args.dataset, args.record, args.stats, 78 | args.threshold, args.label, 79 | args.ratio, store_stat=args.store_stat) 80 | tranformer.process() 81 | -------------------------------------------------------------------------------- /maskTrainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import logging 4 | import os, sys 5 | from pathlib import Path 6 | import numpy as np 7 | from sklearn import metrics 8 | from utils import trainUtils 9 | from modules import mask 10 | 11 | parser = argparse.ArgumentParser(description="optfs trainer") 12 | parser.add_argument("dataset", type=str, help="specify dataset") 13 | parser.add_argument("model", type=str, help="specify model") 14 | 15 | # dataset information 16 | parser.add_argument("--feature", type=int, help="feature number", required=True) 17 | parser.add_argument("--field", type=int, help="field number", required=True) 18 | parser.add_argument("--data_dir", type=str, help="data directory", required=True) 19 | 20 | # training hyperparameters 21 | parser.add_argument("--lr", type=float, help="learning rate" , default=1e-4) 22 | parser.add_argument("--l2", type=float, help="L2 regularization", default=1e-5) 23 | parser.add_argument("--bsize", type=int, help="batchsize", default=4096) 24 | parser.add_argument("--optim", type=str, default="Adam", help="optimizer type") 25 | parser.add_argument("--max_epoch", type=int, default=20, help="maxmium epochs") 26 | parser.add_argument("--save_dir", type=Path, help="model save directory") 27 | 28 | # neural network hyperparameters 29 | parser.add_argument("--dim", type=int, help="embedding dimension", default=16) 30 | parser.add_argument("--mlp_dims", type=int, nargs='+', default=[1024, 512, 256], help="mlp layer size") 31 | parser.add_argument("--mlp_dropout", type=float, default=0.0, help="mlp dropout rate (default:0.0)") 32 | parser.add_argument("--mlp_bn", action="store_true", help="mlp batch normalization") 33 | parser.add_argument("--cross", type=int, help="cross layer", default=3) 34 | 35 | # device information 36 | parser.add_argument("--cuda", type=int, choices=range(-1, 8), default=-1, help="device info") 37 | 38 | # mask information 39 | parser.add_argument("--mask_init", type=float, default=0.5, help="mask initial value" ) 40 | parser.add_argument("--final_temp", type=float, default=200, help="final temperature") 41 | parser.add_argument("--search_epoch", type=int, default=20, help="search epochs") 42 | parser.add_argument("--rewind_epoch", type=int, default=1, help="rewind epoch") 43 | parser.add_argument("--reg_lambda", type=float, default=1e-8, help="regularization rate") 44 | args = parser.parse_args() 45 | 46 | my_seed = 2022 47 | torch.manual_seed(my_seed) 48 | torch.cuda.manual_seed_all(my_seed) 49 | np.random.seed(my_seed) 50 | 51 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda) 52 | os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices' 53 | os.environ['NUMEXPR_NUM_THREADS'] = '8' 54 | os.environ['NUMEXPR_MAX_THREADS'] = '8' 55 | 56 | class Trainer(object): 57 | def __init__(self, opt): 58 | self.lr = opt['lr'] 59 | self.l2 = opt['l2'] 60 | self.bs = opt['bsize'] 61 | self.model_dir = opt["save_dir"] 62 | self.epochs = opt["search_epoch"] 63 | self.rewind_epoch = opt["rewind_epoch"] 64 | self.reg_lambda = opt["lambda"] 65 | self.temp_increase = opt["final_temp"] ** (1./ (opt["search_epoch"]-1)) 66 | self.dataloader = trainUtils.getDataLoader(opt["dataset"], opt["data_dir"]) 67 | self.device = trainUtils.getDevice(opt["cuda"]) 68 | self.network = mask.getModel(opt["model"],opt["model_opt"]).to(self.device) 69 | self.criterion = torch.nn.BCEWithLogitsLoss() 70 | self.optim = mask.getOptim(self.network, opt["optimizer"],self.lr, self.l2) 71 | 72 | def train_on_batch(self, label, data, retrain=False): 73 | self.network.train() 74 | self.network.zero_grad() 75 | data, label = data.to(self.device), label.to(self.device) 76 | logit = self.network(data) 77 | logloss = self.criterion(logit, label) 78 | regloss = self.network.reg() 79 | if not retrain: 80 | loss = self.reg_lambda * regloss + logloss 81 | else: 82 | loss = logloss 83 | loss.backward() 84 | for optim in self.optim: 85 | optim.step() 86 | return logloss.item() 87 | 88 | def eval_on_batch(self, data): 89 | self.network.eval() 90 | with torch.no_grad(): 91 | data = data.to(self.device) 92 | logit = self.network(data) 93 | prob = torch.sigmoid(logit).detach().cpu().numpy() 94 | return prob 95 | 96 | def search(self): 97 | print("ticket:{t}".format(t=self.network.ticket)) 98 | print("-----------------Begin Search-----------------") 99 | for epoch_idx in range(int(self.epochs)): 100 | train_loss = .0 101 | step = 0 102 | if epoch_idx > 0: 103 | self.network.temp *= self.temp_increase 104 | if epoch_idx == self.rewind_epoch: 105 | self.network.checkpoint() 106 | for feature, label in self.dataloader.get_data("train", batch_size = self.bs): 107 | train_loss += self.train_on_batch(label, feature) 108 | step += 1 109 | train_loss /= step 110 | print("Temp:{temp:.6f}".format(temp=self.network.temp)) 111 | val_auc, val_loss = self.evaluate("val") 112 | print("[Epoch {epoch:d} | Train Loss: {loss:.6f} | Val AUC: {val_auc:.6f}, Val Loss: {val_loss:.6f}]".format(epoch=epoch_idx, loss=train_loss, val_auc=val_auc, val_loss=val_loss)) 113 | rate = self.network.compute_remaining_weights() 114 | print("Feature remain:{rate:.6f}".format(rate=rate)) 115 | test_auc, test_loss = self.evaluate("test") 116 | print("Test AUC: {test_auc:.6f}, Test Loss: {test_loss:.6f}".format(test_auc=test_auc, test_loss=test_loss)) 117 | 118 | def evaluate(self, on:str): 119 | preds, trues = [], [] 120 | for feature, label in self.dataloader.get_data(on, batch_size=self.bs * 10): 121 | pred = self.eval_on_batch(feature) 122 | label = label.detach().cpu().numpy() 123 | preds.append(pred) 124 | trues.append(label) 125 | y_pred = np.concatenate(preds).astype("float64") 126 | y_true = np.concatenate(trues).astype("float64") 127 | auc = metrics.roc_auc_score(y_true, y_pred) 128 | loss = metrics.log_loss(y_true, y_pred) 129 | return auc, loss 130 | 131 | def train(self, epochs): 132 | self.network.ticket=True 133 | self.network.rewind_weights() 134 | cur_auc = 0.0 135 | early_stop = False 136 | self.optim = mask.getOptim(self.network, "adam", self.lr, self.l2)[:1] 137 | rate = self.network.compute_remaining_weights() 138 | 139 | print("-----------------Begin Train-----------------") 140 | print("Ticket:{t}".format(t=self.network.ticket)) 141 | print("Final feature remain:{rate:.6f}".format(rate=rate)) 142 | for epoch_idx in range(int(epochs)): 143 | train_loss = .0 144 | step = 0 145 | for feature, label in self.dataloader.get_data("train", batch_size = self.bs): 146 | train_loss += self.train_on_batch(label, feature, retrain=True) 147 | step += 1 148 | train_loss /= step 149 | val_auc, val_loss = self.evaluate("val") 150 | print("[Epoch {epoch:d} | Train Loss:{loss:.6f} | Val AUC:{val_auc:.6f}, Val Loss:{val_loss:.6f}]".format(epoch=epoch_idx, loss=train_loss, val_auc=val_auc, val_loss=val_loss)) 151 | 152 | if val_auc > cur_auc: 153 | cur_auc = val_auc 154 | torch.save(self.network.state_dict(), self.model_dir) 155 | else: 156 | self.network.load_state_dict(torch.load(self.model_dir)) 157 | self.network.to(self.device) 158 | early_stop = True 159 | test_auc, test_loss = self.evaluate("test") 160 | print("Early stop at epoch {epoch:d} | Test AUC: {test_auc:.6f}, Test Loss:{test_loss:.6f}".format(epoch=epoch_idx, test_auc = test_auc, test_loss = test_loss)) 161 | break 162 | 163 | if not early_stop: 164 | test_auc, test_loss = self.evaluate("test") 165 | print("Final Test AUC: {test_auc:.6f}, Test Loss: {test_loss:.6f}".format(test_auc=test_auc, test_loss=test_loss)) 166 | 167 | def main(): 168 | model_opt={ 169 | "latent_dim":args.dim, "feat_num":args.feature, "field_num":args.field, 170 | "mlp_dropout":args.mlp_dropout, "use_bn": args.mlp_bn, "mlp_dims":args.mlp_dims, 171 | "mask_initial":args.mask_init,"cross":args.cross 172 | } 173 | 174 | opt={ 175 | "model_opt":model_opt, "dataset":args.dataset, "model":args.model, "lr":args.lr, "l2":args.l2, 176 | "bsize":args.bsize, "optimizer":args.optim, "data_dir":args.data_dir,"save_dir":args.save_dir, 177 | "cuda":args.cuda,"search_epoch":args.search_epoch, "rewind_epoch": args.rewind_epoch,"final_temp":args.final_temp, 178 | "lambda":args.reg_lambda 179 | } 180 | print(opt) 181 | trainer = Trainer(opt) 182 | trainer.search() 183 | trainer.train(args.max_epoch) 184 | 185 | if __name__ == "__main__": 186 | """ 187 | python trainer.py Criteo DeepFM --feature 188 | """ 189 | main() 190 | -------------------------------------------------------------------------------- /modules/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | class FeatureEmbedding(torch.nn.Module): 5 | def __init__(self, feature_num, latent_dim, initializer = torch.nn.init.xavier_uniform_): 6 | super().__init__() 7 | self.embedding = torch.nn.Parameter(torch.zeros(feature_num, latent_dim)) 8 | initializer(self.embedding) 9 | 10 | def forward(self, x): 11 | """ 12 | :param x: tensor of size (batch_size, num_fields) 13 | :return: tensor of size (batch_size, num_fields, embedding_dim) 14 | """ 15 | return F.embedding(x, self.embedding) 16 | 17 | class FeaturesLinear(torch.nn.Module): 18 | def __init__(self, feature_num, output_dim=1): 19 | super().__init__() 20 | self.fc = torch.nn.Embedding(feature_num, output_dim) 21 | self.bias = torch.nn.Parameter(torch.zeros((output_dim,))) 22 | 23 | def forward(self, x): 24 | """ 25 | :param x: Long tensor of size ``(batch_size, num_fields)`` 26 | :return : tensor of size (batch_size, 1) 27 | """ 28 | return torch.sum(torch.squeeze(self.fc(x)), dim=1, keepdim=True) + self.bias 29 | 30 | class FactorizationMachine(torch.nn.Module): 31 | def __init__(self, reduce_sum=True): 32 | super().__init__() 33 | self.reduce_sum = reduce_sum 34 | 35 | def forward(self, x): 36 | """ 37 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 38 | :return : tensor of size (batch_size, 1) if reduce_sum 39 | tensor of size (batch_size, embed_dim) else 40 | """ 41 | square_of_sum = torch.sum(x, dim=1) ** 2 42 | sum_of_square = torch.sum(x ** 2, dim=1) 43 | ix = square_of_sum - sum_of_square 44 | if self.reduce_sum: 45 | ix = torch.sum(ix, dim=1, keepdim=True) 46 | return 0.5 * ix 47 | 48 | class MultiLayerPerceptron(torch.nn.Module): 49 | def __init__(self, input_dim, mlp_dims, dropout, output_layer=True, use_bn=False, use_ln=False): 50 | super().__init__() 51 | layers = list() 52 | for mlp_dim in mlp_dims: 53 | layers.append(torch.nn.Linear(input_dim, mlp_dim)) 54 | if use_bn: 55 | layers.append(torch.nn.BatchNorm1d(mlp_dim)) 56 | if use_ln: 57 | layers.append(torch.nn.LayerNorm(mlp_dim)) 58 | layers.append(torch.nn.ReLU()) 59 | layers.append(torch.nn.Dropout(p=dropout)) 60 | input_dim = mlp_dim 61 | if output_layer: 62 | layers.append(torch.nn.Linear(input_dim, 1)) 63 | self.mlp = torch.nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | """ 67 | :param x: Float tensor of size ``(batch_size, embed_dim)`` 68 | :return : tensor of size (batch_size, mlp_dims[-1]) 69 | """ 70 | return self.mlp(x) 71 | 72 | class CrossNetwork(torch.nn.Module): 73 | def __init__(self, input_dim, num_layers): 74 | super().__init__() 75 | self.num_layers = num_layers 76 | self.w = torch.nn.ModuleList([ 77 | torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers) 78 | ]) 79 | self.b = torch.nn.ParameterList([ 80 | torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers) 81 | ]) 82 | 83 | def forward(self, x): 84 | """ 85 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 86 | """ 87 | x0 = x 88 | for i in range(self.num_layers): 89 | xw = self.w[i](x) 90 | x = x0 * xw + self.b[i] + x 91 | return x 92 | 93 | 94 | class InnerProduct(torch.nn.Module): 95 | def __init__(self, field_num): 96 | super().__init__() 97 | self.rows = [] 98 | self.cols = [] 99 | for row in range(field_num): 100 | for col in range(row+1, field_num): 101 | self.rows.append(row) 102 | self.cols.append(col) 103 | self.rows = torch.tensor(self.rows) 104 | self.cols = torch.tensor(self.cols) 105 | 106 | def forward(self, x): 107 | """ 108 | :param x: Float tensor of size (batch_size, field_num, embedding_dim) 109 | :return: (batch_size, field_num*(field_num-1)/2) 110 | """ 111 | batch_size = x.shape[0] 112 | trans_x = torch.transpose(x, 1, 2) 113 | 114 | self.rows = self.rows.to(trans_x.device) 115 | self.cols = self.cols.to(trans_x.device) 116 | 117 | gather_rows = torch.gather(trans_x, 2, self.rows.expand(batch_size, trans_x.shape[1], self.rows.shape[0])) 118 | gather_cols = torch.gather(trans_x, 2, self.cols.expand(batch_size, trans_x.shape[1], self.rows.shape[0])) 119 | p = torch.transpose(gather_rows, 1, 2) 120 | q = torch.transpose(gather_cols, 1, 2) 121 | product_embedding = torch.mul(p, q) 122 | product_embedding = torch.sum(product_embedding, 2) 123 | return product_embedding 124 | -------------------------------------------------------------------------------- /modules/mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from modules.layers import FactorizationMachine, MultiLayerPerceptron 6 | import copy 7 | import modules.layers as layer 8 | 9 | 10 | class MaskEmbedding(nn.Module): 11 | def __init__(self, feature_num, latent_dim, mask_initial_value=0.): 12 | super().__init__() 13 | self.feature_num = feature_num 14 | self.latent_dim = latent_dim 15 | self.mask_initial_value = mask_initial_value 16 | self.embedding = nn.Parameter(torch.zeros(feature_num, latent_dim)) 17 | nn.init.xavier_uniform_(self.embedding) 18 | self.init_weight = nn.Parameter(torch.zeros_like(self.embedding), requires_grad=False) 19 | self.init_mask() 20 | 21 | def init_mask(self): 22 | self.mask_weight = nn.Parameter(torch.Tensor(self.feature_num, 1)) 23 | nn.init.constant_(self.mask_weight, self.mask_initial_value) 24 | 25 | def compute_mask(self, x, temp, ticket): 26 | scaling = 1./ sigmoid(self.mask_initial_value) 27 | mask_weight = F.embedding(x, self.mask_weight) 28 | if ticket: 29 | mask = (mask_weight > 0).float() 30 | else: 31 | mask = torch.sigmoid(temp * mask_weight) 32 | return scaling * mask 33 | 34 | def prune(self, temp): 35 | self.mask_weight.data = torch.clamp(temp * self.mask_weight.data, max=self.mask_initial_value) 36 | 37 | def forward(self, x, temp=1, ticket=False): 38 | embed = F.embedding(x, self.embedding) 39 | mask = self.compute_mask(x, temp, ticket) 40 | return embed * mask 41 | 42 | def compute_remaining_weights(self, temp, ticket=False): 43 | if ticket: 44 | return float((self.mask_weight > 0.).sum()) / self.mask_weight.numel() 45 | else: 46 | m = torch.sigmoid(temp * self.mask_weight) 47 | print("max mask weight: {wa:6f}, min mask weight: {wi:6f}".format(wa=torch.max(self.mask_weight),wi=torch.min(self.mask_weight))) 48 | print("max mask: {ma:8f}, min mask: {mi:8f}".format(ma=torch.max(m), mi=torch.min(m))) 49 | print("mask number: {mn:6f}".format(mn=float((m==0.).sum()))) 50 | return 1 - float((m == 0.).sum()) / m.numel() 51 | 52 | def checkpoint(self): 53 | self.init_weight.data = self.embedding.clone() 54 | 55 | def rewind_weights(self): 56 | self.embedding.data = self.init_weight.clone() 57 | 58 | def reg(self, temp): 59 | return torch.sum(torch.sigmoid(temp * self.mask_weight)) 60 | 61 | 62 | class MaskedNet(nn.Module): 63 | def __init__(self, opt): 64 | super(MaskedNet, self).__init__() 65 | self.ticket = False 66 | self.latent_dim = opt["latent_dim"] 67 | self.feature_num = opt["feat_num"] 68 | self.field_num = opt["field_num"] 69 | self.mask_embedding = MaskEmbedding(self.feature_num, self.latent_dim, mask_initial_value=opt["mask_initial"]) 70 | self.mask_modules = [m for m in self.modules() if type(m) == MaskEmbedding] 71 | self.temp = 1 72 | 73 | def checkpoint(self): 74 | for m in self.mask_modules: m.checkpoint() 75 | for m in self.modules(): 76 | #print(m) 77 | if isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.Linear): 78 | m.checkpoint = copy.deepcopy(m.state_dict()) 79 | 80 | def rewind_weights(self): 81 | for m in self.mask_modules: m.rewind_weights() 82 | for m in self.modules(): 83 | if isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.Linear): 84 | m.load_state_dict(m.checkpoint) 85 | 86 | def prune(self): 87 | for m in self.mask_modules: m.prune(self.temp) 88 | 89 | def reg(self): 90 | reg_loss = 0. 91 | for m in self.mask_modules: 92 | reg_loss += m.reg(self.temp) 93 | return reg_loss 94 | 95 | 96 | class MaskDeepFM(MaskedNet): 97 | def __init__(self, opt): 98 | super(MaskDeepFM, self).__init__(opt) 99 | self.fm = FactorizationMachine(reduce_sum=True) 100 | embed_dims = opt["mlp_dims"] 101 | dropout = opt["mlp_dropout"] 102 | use_bn = opt["use_bn"] 103 | self.dnn_dim = self.field_num*self.latent_dim 104 | self.dnn = MultiLayerPerceptron(self.dnn_dim, embed_dims, dropout, use_bn=use_bn) 105 | 106 | def forward(self, x): 107 | x_embedding = self.mask_embedding(x, self.temp, self.ticket) 108 | #output_linear = self.linear(x) 109 | output_fm = self.fm(x_embedding) 110 | x_dnn = x_embedding.view(-1, self.dnn_dim) 111 | output_dnn = self.dnn(x_dnn) 112 | logit = output_dnn + output_fm 113 | return logit 114 | 115 | def compute_remaining_weights(self): 116 | return self.mask_embedding.compute_remaining_weights(self.temp, self.ticket) 117 | 118 | 119 | class MaskDeepCross(MaskedNet): 120 | def __init__(self, opt): 121 | super(MaskDeepCross, self).__init__(opt) 122 | self.dnn_dim = self.field_num * self.latent_dim 123 | cross_num = opt["cross"] 124 | mlp_dims = opt["mlp_dims"] 125 | dropout = opt["mlp_dropout"] 126 | use_bn = opt["use_bn"] 127 | self.cross = layer.CrossNetwork(self.dnn_dim, cross_num) 128 | self.dnn = MultiLayerPerceptron(self.dnn_dim, mlp_dims, output_layer=False, dropout=dropout, use_bn=use_bn) 129 | self.combination = nn.Linear(mlp_dims[-1] + self.dnn_dim, 1, bias=False) 130 | 131 | def forward(self, x): 132 | x_embedding = self.mask_embedding(x, self.temp, self.ticket) 133 | x_dnn = x_embedding.view(-1, self.dnn_dim) 134 | output_cross = self.cross(x_dnn) 135 | output_dnn = self.dnn(x_dnn) 136 | comb_tensor = torch.cat((output_cross, output_dnn), dim=1) 137 | logit = self.combination(comb_tensor) 138 | return logit 139 | 140 | def compute_remaining_weights(self): 141 | return self.mask_embedding.compute_remaining_weights(self.temp, self.ticket) 142 | 143 | 144 | class MaskedFM(MaskedNet): 145 | def __init__(self, opt): 146 | super(MaskedFM, self).__init__(opt) 147 | self.fm = FactorizationMachine(reduce_sum=True) 148 | 149 | def forward(self, x): 150 | x_embedding = self.mask_embedding(x, self.temp, self.ticket) 151 | output_fm = self.fm(x_embedding) 152 | logits = output_fm 153 | return logits 154 | 155 | def compute_remaining_weights(self): 156 | return self.mask_embedding.compute_remaining_weights(self.temp, self.ticket) 157 | 158 | 159 | class MaskedIPNN(MaskedNet): 160 | def __init__(self, opt): 161 | super(MaskedIPNN, self).__init__(opt) 162 | mlp_dims = opt["mlp_dims"] 163 | use_bn = opt["use_bn"] 164 | dropout = opt["mlp_dropout"] 165 | self.dnn_dim = self.field_num * self.latent_dim + int(self.field_num * (self.field_num - 1) / 2) 166 | self.inner = layer.InnerProduct(self.field_num) 167 | self.dnn = MultiLayerPerceptron(self.dnn_dim, mlp_dims, output_layer=True, dropout=dropout, use_bn=use_bn) 168 | 169 | def forward(self, x): 170 | x_embedding = self.mask_embedding(x) 171 | x_dnn = x_embedding.view(-1, self.field_num*self.latent_dim) 172 | x_innerproduct = self.inner(x_embedding) 173 | x_dnn = torch.cat((x_dnn, x_innerproduct), 1) 174 | logit = self.dnn(x_dnn) 175 | return logit 176 | 177 | def compute_remaining_weights(self): 178 | return self.mask_embedding.compute_remaining_weights(self.temp, self.ticket) 179 | 180 | 181 | def getOptim(network, optim, lr, l2): 182 | weight_params = map(lambda a: a[1], filter(lambda p: p[1].requires_grad and 'mask_weight' not in p[0], network.named_parameters())) 183 | mask_params = map(lambda a: a[1], filter(lambda p: p[1].requires_grad and 'mask_weight' in p[0], network.named_parameters())) 184 | optim = optim.lower() 185 | if optim == "sgd": 186 | return [torch.optim.SGD(weight_params, lr=lr, weight_decay=l2), torch.optim.SGD(mask_params, lr=lr)] 187 | elif optim == "adam": 188 | return [torch.optim.Adam(weight_params, lr=lr, weight_decay=l2), torch.optim.Adam(mask_params, lr=lr)] 189 | else: 190 | raise ValueError("Invalid optimizer type: {}".format(optim)) 191 | 192 | 193 | def getModel(model:str, opt): 194 | model = model.lower() 195 | if model == "deepfm": 196 | return MaskDeepFM(opt) 197 | elif model == "dcn": 198 | return MaskDeepCross(opt) 199 | elif model == "fm": 200 | return MaskedFM(opt) 201 | elif model == "ipnn": 202 | return MaskedIPNN(opt) 203 | else: 204 | raise ValueError("Invalid model type: {}".format(model)) 205 | 206 | 207 | def sigmoid(x): 208 | return float(1./(1.+np.exp(-x))) 209 | -------------------------------------------------------------------------------- /modules/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules.layers import MultiLayerPerceptron, FactorizationMachine, FeaturesLinear, FeatureEmbedding 3 | import modules.layers as layer 4 | 5 | 6 | class BasicModel(torch.nn.Module): 7 | def __init__(self, opt): 8 | super(BasicModel, self).__init__() 9 | self.latent_dim = opt["latent_dim"] 10 | self.feature_num = opt["feat_num"] 11 | self.field_num = opt["field_num"] 12 | self.embedding = FeatureEmbedding(self.feature_num, self.latent_dim) 13 | 14 | def forward(self, x): 15 | """ 16 | :param x: Float tensor of size ``(batch_size, field_num)`` 17 | 18 | """ 19 | pass 20 | 21 | def reg(self): 22 | return 0.0 23 | 24 | class FM(BasicModel): 25 | def __init__(self, opt): 26 | super(FM, self).__init__(opt) 27 | #self.linear = FeaturesLinear(self.feature_num) 28 | self.fm = FactorizationMachine(reduce_sum=True) 29 | 30 | def forward(self, x): 31 | x_embedding = self.embedding(x) 32 | #output_linear = self.linear(x) 33 | output_fm = self.fm(x_embedding) 34 | #print(output_linear.shape) 35 | #print(output_fm.shape) 36 | #logit = output_linear + output_fm 37 | logit = output_fm 38 | return logit 39 | 40 | class DeepFM(FM): 41 | def __init__(self, opt): 42 | super(DeepFM, self).__init__(opt) 43 | embed_dims = opt["mlp_dims"] 44 | dropout = opt["mlp_dropout"] 45 | use_bn = opt["use_bn"] 46 | self.dnn_dim = self.field_num*self.latent_dim 47 | self.dnn = MultiLayerPerceptron(self.dnn_dim, embed_dims, dropout, use_bn=use_bn) 48 | 49 | 50 | def forward(self, x): 51 | x_embedding = self.embedding(x) 52 | #output_linear = self.linear(x) 53 | 54 | output_fm = self.fm(x_embedding) 55 | x_dnn = x_embedding.view(-1, self.dnn_dim) 56 | output_dnn = self.dnn(x_dnn) 57 | #logit = output_dnn + output_fm + output_linear 58 | logit = output_dnn + output_fm 59 | return logit 60 | 61 | class DeepCrossNet(BasicModel): 62 | def __init__(self, opt): 63 | super(DeepCrossNet, self).__init__(opt) 64 | cross_num = opt["cross"] 65 | mlp_dims = opt["mlp_dims"] 66 | use_bn = opt["use_bn"] 67 | dropout = opt["mlp_dropout"] 68 | self.dnn_dim = self.field_num * self.latent_dim 69 | self.cross = layer.CrossNetwork(self.dnn_dim, cross_num) 70 | self.dnn = MultiLayerPerceptron(self.dnn_dim, mlp_dims, output_layer=False, dropout=dropout, use_bn=use_bn) 71 | self.combination = torch.nn.Linear(mlp_dims[-1] + self.dnn_dim, 1, bias=False) 72 | 73 | def forward(self, x): 74 | x_embedding = self.embedding(x) 75 | x_dnn = x_embedding.view(-1, self.dnn_dim) 76 | output_cross = self.cross(x_dnn) 77 | output_dnn = self.dnn(x_dnn) 78 | comb_tensor = torch.cat((output_cross,output_dnn), dim=1) 79 | logit = self.combination(comb_tensor) 80 | return logit 81 | 82 | 83 | class InnerProductNet(BasicModel): 84 | def __init__(self, opt): 85 | super(InnerProductNet, self).__init__(opt) 86 | mlp_dims = opt["mlp_dims"] 87 | use_bn = opt["use_bn"] 88 | dropout = opt["mlp_dropout"] 89 | self.dnn_dim = self.field_num * self.latent_dim + \ 90 | int(self.field_num * (self.field_num -1)/2) 91 | self.inner = layer.InnerProduct(self.field_num) 92 | self.dnn = MultiLayerPerceptron(self.dnn_dim, mlp_dims, output_layer=True, dropout=dropout, use_bn=use_bn) 93 | 94 | def forward(self, x): 95 | x_embedding = self.embedding(x) 96 | x_dnn = x_embedding.view(-1, self.field_num*self.latent_dim) 97 | x_innerproduct = self.inner(x_embedding) 98 | x_dnn= torch.cat((x_dnn, x_innerproduct), 1) 99 | logit = self.dnn(x_dnn) 100 | return logit 101 | 102 | 103 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import logging 4 | import os, sys 5 | from pathlib import Path 6 | import numpy as np 7 | from sklearn import metrics 8 | from utils import trainUtils 9 | 10 | parser = argparse.ArgumentParser(description="optfs trainer") 11 | parser.add_argument("dataset", type=str, help="specify dataset") 12 | parser.add_argument("model", type=str, help="specify model") 13 | 14 | # dataset information 15 | parser.add_argument("--feature", type=int, help="feature number", required=True) 16 | parser.add_argument("--field", type=int, help="field number", required=True) 17 | parser.add_argument("--data_dir", type=str, help="data directory", required=True) 18 | 19 | # training hyperparameters 20 | parser.add_argument("--lr", type=float, help="learning rate" , default=3e-5) 21 | parser.add_argument("--l2", type=float, help="L2 regularization", default=1e-3) 22 | parser.add_argument("--bsize", type=int, help="batchsize", default=4096) 23 | parser.add_argument("--optim", type=str, default="Adam", help="optimizer type") 24 | parser.add_argument("--max_epoch", type=int, default=20, help="maxmium epochs") 25 | parser.add_argument("--save_dir", type=Path, help="model save directory") 26 | 27 | # neural network hyperparameters 28 | parser.add_argument("--dim", type=int, help="embedding dimension", default=16) 29 | parser.add_argument("--mlp_dims", type=int, nargs='+', default=[1024, 512, 256], help="mlp layer size") 30 | parser.add_argument("--mlp_dropout", type=float, default=0.0, help="mlp dropout rate (default:0.0)") 31 | parser.add_argument("--mlp_bn", action="store_true", help="mlp batch normalization") 32 | parser.add_argument("--cross", type=int, help="cross layer", default=3) 33 | 34 | # device information 35 | parser.add_argument("--cuda", type=int, choices=range(-1, 8), default=-1, help="device info") 36 | 37 | args = parser.parse_args() 38 | 39 | my_seed = 2022 40 | torch.manual_seed(my_seed) 41 | torch.cuda.manual_seed_all(my_seed) 42 | np.random.seed(my_seed) 43 | 44 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda) 45 | os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices' 46 | os.environ['NUMEXPR_NUM_THREADS'] = '8' 47 | os.environ['NUMEXPR_MAX_THREADS'] = '8' 48 | 49 | class Trainer(object): 50 | def __init__(self, opt): 51 | self.lr = opt['lr'] 52 | self.l2 = opt['l2'] 53 | self.bs = opt['bsize'] 54 | self.model_dir = opt["save_dir"] 55 | self.dataloader = trainUtils.getDataLoader(opt["dataset"], opt["data_dir"]) 56 | self.device = trainUtils.getDevice(opt["cuda"]) 57 | self.network = trainUtils.getModel(opt["model"], opt["model_opt"]).to(self.device) 58 | self.criterion = torch.nn.BCEWithLogitsLoss() 59 | self.optim = trainUtils.getOptim(self.network, opt["optimizer"],self.lr, self.l2) 60 | 61 | def train_on_batch(self, label, data): 62 | self.network.train() 63 | self.optim.zero_grad() 64 | data, label = data.to(self.device), label.to(self.device) 65 | logit = self.network(data) 66 | #print(logit.shape) 67 | logloss = self.criterion(logit, label) 68 | regloss = self.network.reg() 69 | loss = regloss + logloss 70 | loss.backward() 71 | self.optim.step() 72 | return logloss.item() 73 | 74 | def eval_on_batch(self, data): 75 | self.network.eval() 76 | with torch.no_grad(): 77 | data = data.to(self.device) 78 | logit = self.network(data) 79 | prob = torch.sigmoid(logit).detach().cpu().numpy() 80 | return prob 81 | 82 | def train(self, epochs): 83 | step = 0 84 | cur_auc = 0.0 85 | for epoch_idx in range(int(epochs)): 86 | train_loss = .0 87 | step = 0 88 | for feature, label in self.dataloader.get_data("train", batch_size = self.bs): 89 | #print(feature.shape) 90 | #print(label.shape) 91 | train_loss += self.train_on_batch(label, feature) 92 | step += 1 93 | train_loss /= step 94 | val_auc, val_loss = self.evaluate("val") 95 | print("[Epoch {epoch:d} | Train Loss:{loss:.6f} | Val AUC:{val_auc:.6f}, Val Loss:{val_loss:.6f}". 96 | format(epoch=epoch_idx, loss=train_loss, val_auc=val_auc, val_loss=val_loss )) 97 | early_stop = False 98 | if val_auc > cur_auc: 99 | cur_auc = val_auc 100 | torch.save(self.network.state_dict(), self.model_dir) 101 | else: 102 | self.network.load_state_dict(torch.load(self.model_dir)) 103 | self.network.to(self.device) 104 | early_stop = True 105 | te_auc, te_loss = self.evaluate("test") 106 | print("Early stop at epoch {epoch:d}|Test AUC: {te_auc:.6f}, Test Loss:{te_loss:.6f}". 107 | format(epoch=epoch_idx, te_auc = te_auc, te_loss = te_loss)) 108 | break 109 | if not early_stop: 110 | te_auc, te_loss = self.evaluate("test") 111 | print("Final Test AUC:{te_auc:.6f}, Test Loss:{te_loss:.6f}".format(te_auc=te_auc, te_loss=te_loss)) 112 | 113 | def evaluate(self, on:str): 114 | preds, trues = [], [] 115 | for feature, label in self.dataloader.get_data(on, batch_size=self.bs * 10): 116 | pred = self.eval_on_batch(feature) 117 | label = label.detach().cpu().numpy() 118 | preds.append(pred) 119 | trues.append(label) 120 | y_pred = np.concatenate(preds).astype("float64") 121 | y_true = np.concatenate(trues).astype("float64") 122 | auc = metrics.roc_auc_score(y_true, y_pred) 123 | loss = metrics.log_loss(y_true, y_pred) 124 | return auc, loss 125 | 126 | def main(): 127 | model_opt={ 128 | "latent_dim":args.dim, "feat_num":args.feature, "field_num":args.field, 129 | "mlp_dropout":args.mlp_dropout, "use_bn": args.mlp_bn, "mlp_dims":args.mlp_dims, 130 | "cross":args.cross 131 | } 132 | 133 | opt={"model_opt":model_opt, "dataset":args.dataset, "model":args.model, "lr":args.lr, "l2":args.l2, 134 | "bsize":args.bsize, "epoch":args.max_epoch, "optimizer":args.optim, "data_dir":args.data_dir, 135 | "save_dir":args.save_dir, "cuda":args.cuda 136 | } 137 | print(opt) 138 | trainer = Trainer(opt) 139 | trainer.train(args.max_epoch) 140 | 141 | 142 | if __name__ == "__main__": 143 | """ 144 | python trainer.py Criteo DeepFM --feature 145 | """ 146 | main() 147 | -------------------------------------------------------------------------------- /utils/trainUtils.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import torch 3 | import pickle 4 | from data import tfloader 5 | import modules.models as models 6 | 7 | def getModel(model:str, opt): 8 | model = model.lower() 9 | if model == "fm": 10 | return models.FM(opt) 11 | elif model == "deepfm": 12 | return models.DeepFM(opt) 13 | elif model == "ipnn": 14 | return models.InnerProductNet(opt) 15 | elif model == "dcn": 16 | return models.DeepCrossNet(opt) 17 | else: 18 | raise ValueError("Invalid model type: {}".format(model)) 19 | 20 | def getOptim(network, optim, lr, l2): 21 | params = network.parameters() 22 | optim = optim.lower() 23 | if optim == "sgd": 24 | return torch.optim.SGD(params, lr= lr, weight_decay = l2) 25 | elif optim == "adam": 26 | return torch.optim.Adam(params, lr=lr, weight_decay = l2) 27 | else: 28 | raise ValueError("Invalid optmizer type:{}".format(optim)) 29 | 30 | def getDevice(device_id): 31 | if device_id != -1: 32 | assert torch.cuda.is_available(), "CUDA is not available" 33 | # torch.cuda.set_device(device_id) 34 | return torch.device('cuda') 35 | else: 36 | return torch.device('cpu') 37 | 38 | def getDataLoader(dataset:str, path): 39 | dataset = dataset.lower() 40 | if dataset == 'criteo': 41 | return tfloader.CriteoLoader(path) 42 | elif dataset == 'avazu': 43 | return tfloader.Avazuloader(path) 44 | elif dataset == 'kdd12': 45 | return tfloader.KDD12loader(path) 46 | 47 | def get_stats(path): 48 | defaults_path = os.path.join(path + "/defaults.pkl") 49 | with open(defaults_path, 'rb') as fi: 50 | defaults = pickle.load(fi) 51 | offset_path = os.path.join(path + "/offset.pkl") 52 | with open(offset_path, 'rb') as fi: 53 | offset = pickle.load(fi) 54 | # return [i+1 for i in list(defaults.values())] 55 | return list(defaults.values()), list(offset.values()) 56 | --------------------------------------------------------------------------------