├── .gitignore
├── LICENSE
├── README.md
├── data
└── tfloader.py
├── datatransform
├── avazu2TF.py
├── criteo2TF.py
├── datatransform.py
└── kdd2TF.py
├── maskTrainer.py
├── modules
├── layers.py
├── mask.py
└── models.py
├── trainer.py
└── utils
└── trainUtils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | .vscode/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Fuyuan Lyu Tommy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OptFS
2 | This repository contains the PyTorch implementation of the WWW 2023 paper:
3 | - **OptFS**: Optimizing Feature Set for Click-Through Rate Prediction
4 |
5 | ### Data Preprocessing
6 |
7 | You can preprocess the Criteo dataset with the command below. The Avazu and KDD12 datasets can be preprocessed in the same way by calling their own scripts (`datatransform/avazu2TF.py` and `datatransform/kdd2TF.py`).
8 |
9 | ```
10 | python datatransform/criteo2TF.py --store_stat --stats PATH_TO_STORE_STATS \
11 | --dataset RAW_DATASET_FILE --record PATH_TO_PROCESSED_DATASET \
12 | --threshold 2 --ratio 0.8 0.1 0.1
13 | ```
14 |
15 | Then you can find the statistics files (`feat_map.pkl`, `defaults.pkl` and `offset.pkl`) under the `PATH_TO_STORE_STATS` folder and your processed files in the TFRecord format under the `PATH_TO_PROCESSED_DATASET` folder.
16 |
17 | ### Run
18 |
19 | Running Backbone Models:
20 | ```
21 | python -u trainer.py $YOUR_DATASET $YOUR_MODEL \
22 | --feature $NUM_OF_FEATURES --field $NUM_OF_FIELDS \
23 | --data_dir $PATH_TO_PROCESSED_DATASET \
24 | --cuda 0 --lr $LR --l2 $L2
25 | ```
26 |
27 | You can choose `YOUR_DATASET` from \{Criteo, Avazu, KDD12\} and `YOUR_MODEL` from \{FM, DeepFM, DCN, IPNN\}.
28 |
29 |
30 | Running OptFS Models:
31 | ```
32 | python -u maskTrainer.py $YOUR_DATASET $YOUR_MODEL \
33 | --feature $NUM_OF_FEATURES --field $NUM_OF_FIELDS \
34 | --data_dir $PATH_TO_PROCESSED_DATASET \
35 | --cuda 0 --lr $LR --l2 $L2 \
36 | --reg_lambda $LAMBDA --final_temp $TEMP \
37 | --search_epoch $EPOCH --rewind_epoch $REWIND
38 | ```
39 |
40 | ### Hyperparameter Settings
41 |
42 | Here we list the hyper-parameters we used in the following table.
43 |
44 | | Model\Dataset | Criteo | Avazu | KDD12 |
45 | | ------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
46 | | FM | _lr_=3e-4, l2=1e-5, _lambda_=2e-9, _temp_=1000, _epoch_=10, _rewind_=1 | _lr_=3e-4, l2=1e-6, _lambda_=2e-9, _temp_=5000, _epoch_=5, _rewind_=4 | _lr_=3e-5, l2=1e-5, _lambda_=2e-9, _temp_=1000, _epoch_=10, _rewind_=0 |
47 | | DeepFM | _lr_=3e-4, l2=3e-5, _lambda_=5e-9, _temp_=200, _epoch_=15, _rewind_=1 | _lr_=3e-4, l2=3e-6, _lambda_=2e-9, _temp_=5000, _epoch_=5, _rewind_=3 | _lr_=3e-5, l2=1e-5, _lambda_=2e-9, _temp_=1000, _epoch_=10, _rewind_=0 |
48 | | DCN | _lr_=3e-4, l2=3e-6, _lambda_=1e-8, _temp_=10000, _epoch_=5, _rewind_=2 | _lr_=3e-4, l2=3e-6, _lambda_=2e-9, _temp_=5000, _epoch_=5, _rewind_=2 | _lr_=3e-5, l2=1e-5, _lambda_=5e-9, _temp_=1000, _epoch_=5, _rewind_=2 |
49 | | IPNN | _lr_=3e-4, l2=3e-6, _lambda_=5e-9, _temp_=2000, _epoch_=10, _rewind_=1 | _lr_=3e-4, l2=3e-6, _lambda_=2e-9, _temp_=5000, _epoch_=5, _rewind_=2 | _lr_=3e-5, l2=1e-5, _lambda_=2e-9, _temp_=1000, _epoch_=10, _rewind_=2 |
50 |
51 | The following procedure describes how we determine these hyper-parameters:
52 |
53 | First, we determine the hyper-parameters of the basic models by grid search: the learning rate and the l2 regularization. We select the optimal learning rate _lr_ from \{1e-3, 3e-4, 1e-4, 3e-5, 1e-5\} and l2 regularization from \{3e-4, 1e-4, 3e-5, 1e-5, 3e-6, 1e-6\}. The Adam optimizer and Xavier initialization are adopted. We empirically set the batch size to 4096, the embedding dimension to 16, and the MLP structure to [1024, 512, 256].
54 |
55 | Second, we tune the hyper-parameters introduced by the OptFS method. We select the regularization weight _lambda_ from \{1e-8, 5e-9, 2e-9, 1e-9\}, final temperature _temp_ from \{10000, 5000, 2000, 1000, 500, 200, 100\}, search epoch _epoch_ from \{5, 10, 15, 20\} and rewind epoch _rewind_ from \{0, 1, ..., _epoch_-1\}. During the tuning process, we fix the learning rate _lr_ and l2 regularization to the optimal values determined in the first step.
56 |
57 | ### Reference
58 |
59 | Kindly cite our paper using the bibliography below
60 | ```
61 | @inproceedings{DBLP:conf/www/LyuTLC0L23,
62 | author = {Fuyuan Lyu and
63 | Xing Tang and
64 | Dugang Liu and
65 | Liang Chen and
66 | Xiuqiang He and
67 | Xue Liu},
68 | title = {Optimizing Feature Set for Click-Through Rate Prediction},
69 | booktitle = {Proceedings of the {ACM} Web Conference 2023, {WWW} 2023},
70 | pages = {3386--3395},
71 | publisher = {{ACM}},
72 | address = {Austin, TX, USA},
73 | year = {2023},
74 | url = {https://doi.org/10.1145/3543507.3583545},
75 | doi = {10.1145/3543507.3583545},
76 | timestamp = {Mon, 28 Aug 2023 21:17:10 +0200},
77 | biburl = {https://dblp.org/rec/conf/www/LyuTLC0L23.bib},
78 | bibsource = {dblp computer science bibliography, https://dblp.org}
79 | }
80 | ```
81 |
--------------------------------------------------------------------------------
/data/tfloader.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import glob
3 | import torch
4 | import os
5 |
6 | repo_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
7 |
class _TFRecordLoader(object):
    """Shared implementation of the per-dataset TFRecord loaders.

    Each record is parsed as a float32 ``label`` of length ``SAMPLES`` and an
    int64 ``feature`` vector of length ``FIELDS``. Subclasses only set
    ``FIELDS``; the parsing and batching logic is identical for every dataset.
    """

    SAMPLES = 1
    FIELDS = None  # set by each concrete loader

    def __init__(self, tfrecord_path):
        """Remember the shard directory and build the record schema."""
        self.tfrecord_path = tfrecord_path
        self.description = {
            "label": tf.io.FixedLenFeature([self.SAMPLES], tf.float32),
            "feature": tf.io.FixedLenFeature([self.FIELDS], tf.int64),
        }

    def get_data(self, data_type, batch_size=1):
        """Yield ``(feature, label)`` batches as torch tensors.

        Args:
            data_type: file-name prefix of the shards to read; every file in
                ``tfrecord_path`` whose name starts with it is used.
            batch_size: number of records per yielded batch.
        """
        @tf.autograph.experimental.do_not_convert
        def read_data(raw_rec):
            example = tf.io.parse_single_example(raw_rec, self.description)
            return example['feature'], example['label']

        pattern = os.path.join(self.tfrecord_path, "{}*".format(data_type))
        # glob returns files in arbitrary order; sort so that the shard order
        # (and therefore the batch order) is reproducible across runs.
        files = sorted(glob.glob(pattern))
        ds = (
            tf.data.TFRecordDataset(files)
            .map(read_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .batch(batch_size)
            .prefetch(tf.data.experimental.AUTOTUNE)
        )
        for x, y in ds:
            yield torch.from_numpy(x.numpy()), torch.from_numpy(y.numpy())


class CriteoLoader(_TFRecordLoader):
    """Loader for records with 39 feature fields."""
    FIELDS = 39


class Avazuloader(_TFRecordLoader):
    """Loader for records with 24 feature fields."""
    FIELDS = 24


class KDD12loader(_TFRecordLoader):
    """Loader for records with 11 feature fields."""
    FIELDS = 11
77 |
--------------------------------------------------------------------------------
/datatransform/avazu2TF.py:
--------------------------------------------------------------------------------
1 | from datatransform import DataTransform
2 | from datetime import datetime, date
3 | import argparse
4 | from pathlib import Path
5 | import pandas as pd
6 | import numpy as np
# Command-line interface. Arguments are parsed at import time so that `args`
# is available to the `__main__` guard at the bottom of this script.
parser = argparse.ArgumentParser(
    description='Transfrom original data to TFRecord',
)

# Name of the label column in the raw file.
parser.add_argument('--label', type=str, default="click")
# When set, statistics are generated and stored instead of being loaded.
parser.add_argument("--store_stat", action="store_true")
# Minimum occurrence count passed to the feature filter.
parser.add_argument("--threshold", default=0, type=int)
# Input file, statistics directory and output directory.
parser.add_argument("--dataset", type=Path)
parser.add_argument("--stats", type=Path)
parser.add_argument("--record", type=Path)
# Split ratios, one float per partition.
parser.add_argument("--ratio", type=float, nargs='+')

args = parser.parse_args()
18 |
class AvazuTransform(DataTransform):
    """Convert the raw Avazu CSV file into train/validation/test TFRecords."""

    def __init__(self, dataset_path, path, stats_path, min_threshold, label_index, ratio, store_stat=False, seed=2021):
        """Store the conversion settings on top of the base-class state.

        Args:
            dataset_path: raw dataset file.
            path: output directory for the TFRecord shards.
            stats_path: directory holding the feature statistics.
            min_threshold: minimum occurrence count used when filtering features.
            label_index: name of the label column.
            ratio: three split ratios forwarded to ``random_split``.
            store_stat: whether to generate and store statistics.
            seed: random seed used for splitting.
        """
        super(AvazuTransform, self).__init__(dataset_path, stats_path, store_stat=store_stat, seed=seed)
        self.threshold = min_threshold
        self.label = label_index
        self.split = ratio
        self.path = path
        self.stats_path = stats_path

    def process(self):
        """Read, filter, split and serialise the dataset."""
        self._read(name=None, header=0, sep=",", label_index=self.label)
        if self.store_stat:
            self.generate_and_filter(threshold=self.threshold, label_index=self.label)
        # random_split returns (train, validation, test) in that order; the
        # previous unpacking order wrote the validation rows into the "test"
        # shards and the test rows into the "validation" shards.
        train, val, test = self.random_split(ratio=self.split)
        self.transform_tfrecord(train, self.path, "train", label_index=self.label)
        self.transform_tfrecord(test, self.path, "test", label_index=self.label)
        self.transform_tfrecord(val, self.path, "validation", label_index=self.label)

    def _process_x(self):
        """Derive ``weekday``, ``is_weekend`` and hour-of-day from ``hour``.

        The raw ``hour`` value is read as a ``YYMMDDHH`` string.
        """
        stamps = self.data["hour"].apply(str)

        def _weekday(stamp):
            day = date(int("20" + stamp[0:2]), int(stamp[2:4]), int(stamp[4:6]))
            # Same numbering as strftime("%w"): Sunday is 0, Saturday is 6.
            return day.isoweekday() % 7

        # Parse each date once and derive both columns from the result.
        weekday = stamps.apply(_weekday)
        self.data["weekday"] = weekday
        self.data["is_weekend"] = weekday.apply(lambda d: 1 if d in (0, 6) else 0)

        self.data["hour"] = stamps.apply(lambda stamp: int(stamp[6:8]))

    def _process_y(self):
        """Drop the ``id`` column so that it is not treated as a feature field."""
        self.data = self.data.drop("id", axis=1)
53 |
if __name__ == "__main__":
    # Build the converter from the parsed CLI arguments and run it.
    transformer = AvazuTransform(
        dataset_path=args.dataset,
        path=args.record,
        stats_path=args.stats,
        min_threshold=args.threshold,
        label_index=args.label,
        ratio=args.ratio,
        store_stat=args.store_stat,
    )
    transformer.process()
59 |
--------------------------------------------------------------------------------
/datatransform/criteo2TF.py:
--------------------------------------------------------------------------------
1 | from datatransform import DataTransform
2 | from datetime import datetime, date
3 | import argparse
4 | from pathlib import Path
5 | import pandas as pd
6 | import numpy as np
# Command-line interface. Arguments are parsed at import time so that `args`
# is available to the `__main__` guard at the bottom of this script.
parser = argparse.ArgumentParser(
    description='Transfrom original data to TFRecord',
)

# Name of the label column in the raw file.
parser.add_argument('--label', type=str, default="Label")
# When set, statistics are generated and stored instead of being loaded.
parser.add_argument("--store_stat", action="store_true")
# Minimum occurrence count passed to the feature filter.
parser.add_argument("--threshold", default=0, type=int)
# Input file, statistics directory and output directory.
parser.add_argument("--dataset", type=Path)
parser.add_argument("--stats", type=Path)
parser.add_argument("--record", type=Path)
# Split ratios, one float per partition.
parser.add_argument("--ratio", type=float, nargs='+')

args = parser.parse_args()
18 |
class CriteoTransform(DataTransform):
    """Convert the raw tab-separated Criteo file into train/validation/test TFRecords."""

    def __init__(self, dataset_path, path, stats_path, min_threshold, label_index, ratio, store_stat=False, seed=2021):
        """Store the conversion settings on top of the base-class state.

        Args:
            dataset_path: raw dataset file.
            path: output directory for the TFRecord shards.
            stats_path: directory holding the feature statistics.
            min_threshold: minimum occurrence count used when filtering features.
            label_index: name of the label column.
            ratio: three split ratios forwarded to ``random_split``.
            store_stat: whether to generate and store statistics.
            seed: random seed used for splitting.
        """
        super(CriteoTransform, self).__init__(dataset_path, stats_path, store_stat=store_stat, seed=seed)
        self.threshold = min_threshold
        self.label = label_index
        self.split = ratio
        self.path = path
        self.stats_path = stats_path
        # Column names: the label, 13 "I" columns and 26 "C" columns.
        self.name = "Label,I1,I2,I3,I4,I5,I6,I7,I8,I9,I10,I11,I12,I13,C1,C2,C3,C4,C5,C6,C7,C8,C9,C10,C11,C12,C13,C14,C15,C16,C17,C18,C19,C20,C21,C22,C23,C24,C25,C26".split(",")

    def process(self):
        """Read, filter, split and serialise the dataset."""
        self._read(name=self.name, header=None, sep="\t", label_index=self.label)
        if self.store_stat:
            # The bucketed "I" columns are white-listed, i.e. never filtered
            # by the occurrence threshold.
            self.generate_and_filter(
                threshold=self.threshold,
                label_index=self.label,
                white_list="I1,I2,I3,I4,I5,I6,I7,I8,I9,I10,I11,I12,I13".split(","),
            )
        # random_split returns (train, validation, test) in that order; the
        # previous unpacking order wrote the validation rows into the "test"
        # shards and the test rows into the "validation" shards.
        train, val, test = self.random_split(ratio=self.split)
        self.transform_tfrecord(train, self.path, "train", label_index=self.label)
        self.transform_tfrecord(test, self.path, "test", label_index=self.label)
        self.transform_tfrecord(val, self.path, "validation", label_index=self.label)

    def _process_x(self):
        """Bucket the 13 "I" columns in place."""
        def bucket(value):
            # Missing values are passed through unchanged.
            if pd.isna(value):
                return value
            # Values above 2 are compressed to floor(log(v)^2); the rest are
            # only truncated to int.
            if value > 2:
                return int(np.floor(np.log(value) ** 2))
            return int(value)

        for i in range(1, 14):
            col_name = "I{}".format(i)
            self.data[col_name] = self.data[col_name].apply(bucket)

    def _process_y(self):
        """The label column needs no processing for this dataset."""
        pass
53 |
if __name__ == "__main__":
    # Build the converter from the parsed CLI arguments and run it.
    transformer = CriteoTransform(
        dataset_path=args.dataset,
        path=args.record,
        stats_path=args.stats,
        min_threshold=args.threshold,
        label_index=args.label,
        ratio=args.ratio,
        store_stat=args.store_stat,
    )
    transformer.process()
59 |
--------------------------------------------------------------------------------
/datatransform/datatransform.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import tensorflow as tf
3 | import numpy as np
4 | import os
5 | from collections import defaultdict
6 | import math
7 | from pathlib import Path
8 | import shutil
9 | import pickle
10 | import tqdm
11 | import pandas as pd
12 | from sklearn.utils import shuffle
13 | from abc import abstractmethod
14 |
15 |
16 |
17 |
def feature_example(label, feature):
    """Serialise one sample as a ``tf.train.Example`` byte string.

    Args:
        label: scalar stored as a single-element float list under ``label``.
        feature: iterable of ints stored as an int64 list under ``feature``.

    Returns:
        The serialised ``tf.train.Example``.
    """
    label_field = tf.train.Feature(float_list=tf.train.FloatList(value=[label]))
    ids_field = tf.train.Feature(int64_list=tf.train.Int64List(value=feature))
    payload = tf.train.Features(feature={'label': label_field, 'feature': ids_field})
    return tf.train.Example(features=payload).SerializeToString()
25 |
26 |
27 | class DataTransform(object):
28 | def __init__(self, dataset_path, stats_path, store_stat = False, seed = 2021):
29 | self.data_path = dataset_path
30 | self.store_stat = store_stat
31 | self.seed = seed
32 | self.stats_path = stats_path
33 | self.feat_map = {}
34 | self.defaults = {}
35 |
36 | if store_stat:
37 | os.makedirs(self.stats_path, exist_ok=True)
38 | else:
39 | with open(self.stats_path.joinpath("feat_map.pkl"), 'rb') as fi:
40 | self.feat_map = pickle.load(fi)
41 | with open(self.stats_path.joinpath("defaults.pkl"), 'rb') as fi:
42 | self.defaults = pickle.load(fi)
43 | with open(self.stats_path.joinpath("offset.pkl"), 'rb') as fi:
44 | self.field_offset = pickle.load(fi)
45 |
46 |
47 | def _read(self, name= None, header = None, sep=None, label_index = ""):
48 | print("=====read data=====")
49 |
50 | self.data = pd.read_table(self.data_path, names=name, header=header, sep=sep)
51 |
52 | print(self.data)
53 | self._process_x()
54 | self._process_y()
55 | print(self.data)
56 |
57 | self.num_instances = self.data.shape[0]
58 | self.num_fields = self.data.shape[1] -1
59 |
60 | self.field_name = self.data.columns.values.tolist()
61 | assert self.num_fields == len(self.field_name) - 1
62 |
63 | num_features = 0
64 | for field in self.field_name:
65 | if field == label_index:
66 | continue
67 | num_features += self.data[field].unique().size
68 | print("===Data summary===")
69 | print("instances:{}, fields:{}, raw_features:{}".format(self.num_instances,self.num_fields, num_features))
70 |
71 | def generate_and_filter(self, threshold=0, label_index="", white_list = []):
72 | self.field_offset={}
73 | offset = 0
74 | for field in self.field_name:
75 | #print(field)
76 | if field == label_index:
77 | continue
78 | feat_count = self.data[field].value_counts(dropna=False).to_dict()
79 | if field not in white_list:
80 | unique_feat = [key for key, value in feat_count.items() if value >= threshold ]
81 | else:
82 | unique_feat = [key for key, value in feat_count.items()]
83 | field_feat_map = dict((field+"_" + str(j), i + offset) for i,j in enumerate(unique_feat))
84 | self.feat_map.update(field_feat_map)
85 | if len(feat_count) == len(unique_feat):
86 | offset += len(unique_feat)
87 | else:
88 | offset += len(unique_feat) + 1
89 | self.defaults.update({field: len(unique_feat)})
90 | self.field_offset.update({field:offset})
91 | print("After filtering features:{}".format(len(self.feat_map)))
92 |
93 | with open(self.stats_path.joinpath("feat_map.pkl"), 'wb') as fi:
94 | pickle.dump(self.feat_map, fi)
95 | with open(self.stats_path.joinpath("defaults.pkl"), 'wb') as fi:
96 | pickle.dump(self.defaults, fi)
97 | with open(self.stats_path.joinpath("offset.pkl"), 'wb') as fi:
98 | pickle.dump(self.field_offset, fi)
99 |
100 | def random_split(self, ratio=[]):
101 | assert len(ratio) == 3, "give three dataset ratio"
102 | train_data = self.data.sample(frac = ratio[0], replace=False,
103 | axis=0, random_state=self.seed)
104 | left_data = self.data[~self.data.index.isin(train_data.index)]
105 | val_data = left_data.sample(frac = ratio[1]/(ratio[1] + ratio[2]), replace=False,
106 | axis=0, random_state=self.seed)
107 | test_data = left_data[~left_data.index.isin(val_data.index)]
108 | print("===Train size:{}===".format( train_data.shape[0]))
109 | print("===Test size:{}===".format(test_data.shape[0]))
110 | print("===Validation size:{}===".format(val_data.shape[0]))
111 | return train_data, val_data, test_data
112 |
113 | def transform_tfrecord(self, data, record_path, flag, records=5e6, label_index=""):
114 | os.makedirs(record_path, exist_ok=True)
115 | part = 0
116 | instance_num = 0
117 | while records * part <= data.shape[0]:
118 | tf_writer = tf.io.TFRecordWriter(os.path.join(record_path, "{}_{:04d}.tfrecord".format(flag, part)))
119 | print("===write part {:04d}===".format(part))
120 | #pbar = tqdm.tqdm(total = int(records))
121 | tmp_data = data[int(part * records): int((part + 1) * records)]
122 | pbar = tqdm.tqdm(total = tmp_data.shape[0])
123 | for index,row in tmp_data.iterrows():
124 | label = None
125 | feature = []
126 | #oov = True
127 | for i in self.field_name:
128 | if i == label_index:
129 | label = float(row[i])
130 | continue
131 | #print(i+"_"+str(int(row[i])))
132 | feat_id = self.feat_map.setdefault(i+"_"+str(row[i]), self.field_offset[i] - 1)
133 | #oov = oov and (feat_id == self.field_offset[i])
134 | feature.append(feat_id)
135 | #if oov:
136 | #continue
137 | tf_writer.write(feature_example(label, feature))
138 | pbar.update(1)
139 | instance_num += 1
140 | tf_writer.close()
141 | pbar.close()
142 | part += 1
143 | print("real instance number:", instance_num)
144 |
145 | @abstractmethod
146 | def _process_x(self):
147 | pass
148 |
149 | @abstractmethod
150 | def _process_y(self):
151 | pass
152 |
153 | @abstractmethod
154 | def process(self):
155 | pass
156 |
--------------------------------------------------------------------------------
/datatransform/kdd2TF.py:
--------------------------------------------------------------------------------
1 | from datatransform import DataTransform
2 | from datetime import datetime, date
3 | import argparse
4 | from pathlib import Path
5 | import pandas as pd
6 | import numpy as np
# Command-line interface. Arguments are parsed at import time so that `args`
# is available to the `__main__` guard at the bottom of this script.
parser = argparse.ArgumentParser(
    description='Transfrom original data to TFRecord',
)

# Name of the label column in the raw file.
parser.add_argument('--label', type=str, default="Label")
# When set, statistics are generated and stored instead of being loaded.
parser.add_argument("--store_stat", action="store_true")
# Minimum occurrence count passed to the feature filter.
parser.add_argument("--threshold", default=0, type=int)
# Input file, statistics directory and output directory.
parser.add_argument("--dataset", type=Path)
parser.add_argument("--stats", type=Path)
parser.add_argument("--record", type=Path)
# Split ratios, one float per partition.
parser.add_argument("--ratio", type=float, nargs='+')

args = parser.parse_args()
18 |
class KDDTransform(DataTransform):
    """Convert the raw tab-separated KDD12 file into train/validation/test TFRecords."""

    def __init__(self, dataset_path, path, stats_path, min_threshold, label_index, ratio, store_stat=False, seed=2021):
        """Store the conversion settings on top of the base-class state.

        Args:
            dataset_path: raw dataset file.
            path: output directory for the TFRecord shards.
            stats_path: directory holding the feature statistics.
            min_threshold: minimum occurrence count used when filtering features.
            label_index: name of the label column.
            ratio: three split ratios forwarded to ``random_split``.
            store_stat: whether to generate and store statistics.
            seed: random seed used for splitting.
        """
        super(KDDTransform, self).__init__(dataset_path, stats_path, store_stat=store_stat, seed=seed)
        self.threshold = min_threshold
        self.label = label_index
        self.split = ratio
        self.path = path
        self.stats_path = stats_path
        self.name = "Label,I1,I2,I3,I4,I5,I6,I7,I8,I9,I10,I11".split(",")

    def process(self):
        """Read, filter, split and serialise the dataset."""
        self._read(name=self.name, header=None, label_index=self.label, sep="\t")
        if self.store_stat:
            self.generate_and_filter(threshold=self.threshold, label_index=self.label)
        # random_split returns (train, validation, test) in that order; the
        # previous unpacking order swapped the validation and test rows.
        train, val, test = self.random_split(ratio=self.split)
        # The writer defined on this class is `transform_tfrecord`; calling
        # the undefined `transform_tfrecord_kdd` raised AttributeError.
        self.transform_tfrecord(train, self.path, "train", label_index=self.label)
        self.transform_tfrecord(test, self.path, "test", label_index=self.label)
        self.transform_tfrecord(val, self.path, "validation", label_index=self.label)

    def _process_x(self):
        """Print the shape of the rows whose raw ``Label`` equals 1."""
        print(self.data[self.data["Label"] == 1].shape)

    def _process_y(self):
        """Binarise ``Label``: 0 stays 0, everything else becomes 1."""
        self.data["Label"] = self.data["Label"].apply(lambda x: 0 if x == 0 else 1)

    def transform_tfrecord(self, data, record_path, flag, records=5e6, label_index=""):
        """Write ``data`` as ``{flag}_{part:04d}.tfrecord`` shards.

        Unlike the base-class writer, feature values are cast to ``int``
        before being looked up in ``feat_map``.

        Args:
            data: DataFrame whose columns match ``self.field_name``.
            record_path: output directory (created if missing).
            flag: file-name prefix of the shards.
            records: maximum number of rows per shard; must be positive.
            label_index: name of the label column.

        Raises:
            ValueError: if ``records`` is not positive.
        """
        # These names are not imported at module level in this script, so
        # bring them into scope here.
        import os
        import tensorflow as tf
        import tqdm
        from datatransform import feature_example

        if records <= 0:
            raise ValueError("records must be positive")
        os.makedirs(record_path, exist_ok=True)
        total = data.shape[0]
        part = 0
        instance_num = 0
        # Part 0 is always written; further parts only while rows remain, so
        # no trailing empty shard is produced for exact multiples of `records`.
        while part == 0 or records * part < total:
            shard_path = os.path.join(record_path, "{}_{:04d}.tfrecord".format(flag, part))
            print("===write part {:04d}===".format(part))
            tmp_data = data.iloc[int(part * records): int((part + 1) * records)]
            with tf.io.TFRecordWriter(shard_path) as tf_writer, \
                    tqdm.tqdm(total=tmp_data.shape[0]) as pbar:
                for _, row in tmp_data.iterrows():
                    label = None
                    feature = []
                    for field in self.field_name:
                        if field == label_index:
                            label = float(row[field])
                            continue
                        # Unknown values fall back to the last id of the field.
                        feature.append(
                            self.feat_map.get(field + "_" + str(int(row[field])),
                                              self.field_offset[field] - 1)
                        )
                    tf_writer.write(feature_example(label, feature))
                    pbar.update(1)
                    instance_num += 1
            part += 1
        print("real instance number:", instance_num)

    # Alias kept so that the name previously used by `process` also resolves.
    transform_tfrecord_kdd = transform_tfrecord
75 |
if __name__ == "__main__":
    # Build the converter from the parsed CLI arguments and run it.
    transformer = KDDTransform(
        dataset_path=args.dataset,
        path=args.record,
        stats_path=args.stats,
        min_threshold=args.threshold,
        label_index=args.label,
        ratio=args.ratio,
        store_stat=args.store_stat,
    )
    transformer.process()
81 |
--------------------------------------------------------------------------------
/maskTrainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import logging
4 | import os, sys
5 | from pathlib import Path
6 | import numpy as np
7 | from sklearn import metrics
8 | from utils import trainUtils
9 | from modules import mask
10 |
# ---------------------------------------------------------------------------
# Command-line interface (parsed at import time).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="optfs trainer")
parser.add_argument("dataset", help="specify dataset", type=str)
parser.add_argument("model", help="specify model", type=str)

# Dataset description: all three are mandatory.
parser.add_argument("--feature", required=True, type=int, help="feature number")
parser.add_argument("--field", required=True, type=int, help="field number")
parser.add_argument("--data_dir", required=True, type=str, help="data directory")

# Optimisation settings.
parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
parser.add_argument("--l2", type=float, default=1e-5, help="L2 regularization")
parser.add_argument("--bsize", type=int, default=4096, help="batchsize")
parser.add_argument("--optim", type=str, default="Adam", help="optimizer type")
parser.add_argument("--max_epoch", type=int, default=20, help="maxmium epochs")
parser.add_argument("--save_dir", type=Path, help="model save directory")

# Network architecture.
parser.add_argument("--dim", type=int, default=16, help="embedding dimension")
parser.add_argument(
    "--mlp_dims", type=int, nargs='+', default=[1024, 512, 256],
    help="mlp layer size",
)
parser.add_argument(
    "--mlp_dropout", type=float, default=0.0,
    help="mlp dropout rate (default:0.0)",
)
parser.add_argument("--mlp_bn", action="store_true", help="mlp batch normalization")
parser.add_argument("--cross", type=int, default=3, help="cross layer")

# Device selection.
parser.add_argument(
    "--cuda", type=int, choices=range(-1, 8), default=-1,
    help="device info",
)

# Mask-search settings.
parser.add_argument("--mask_init", type=float, default=0.5, help="mask initial value")
parser.add_argument("--final_temp", type=float, default=200, help="final temperature")
parser.add_argument("--search_epoch", type=int, default=20, help="search epochs")
parser.add_argument("--rewind_epoch", type=int, default=1, help="rewind epoch")
parser.add_argument("--reg_lambda", type=float, default=1e-8, help="regularization rate")
args = parser.parse_args()

# Seed torch (CPU and all CUDA devices) and numpy with one fixed value.
my_seed = 2022
torch.manual_seed(my_seed)
torch.cuda.manual_seed_all(my_seed)
np.random.seed(my_seed)

# Process-wide environment: visible CUDA device, XLA flag, numexpr threads.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': str(args.cuda),
    'TF_XLA_FLAGS': '--tf_xla_enable_xla_devices',
    'NUMEXPR_NUM_THREADS': '8',
    'NUMEXPR_MAX_THREADS': '8',
})
55 |
56 | class Trainer(object):
def __init__(self, opt):
    """Build the data loader, network, loss and optimizers from ``opt``.

    Args:
        opt: dict with the keys ``lr``, ``l2``, ``bsize``, ``save_dir``,
            ``search_epoch``, ``rewind_epoch``, ``lambda``, ``final_temp``,
            ``dataset``, ``data_dir``, ``cuda``, ``model``, ``model_opt``
            and ``optimizer``.
    """
    self.lr = opt['lr']
    self.l2 = opt['l2']
    self.bs = opt['bsize']
    self.model_dir = opt["save_dir"]
    self.epochs = opt["search_epoch"]
    self.rewind_epoch = opt["rewind_epoch"]
    self.reg_lambda = opt["lambda"]
    # `search` multiplies the network temperature by this factor once per
    # epoch after the first, i.e. (search_epoch - 1) times, for an overall
    # factor of final_temp. With a single search epoch the factor is never
    # applied, so use 1.0 instead of dividing by zero.
    if opt["search_epoch"] > 1:
        self.temp_increase = opt["final_temp"] ** (1. / (opt["search_epoch"] - 1))
    else:
        self.temp_increase = 1.0
    self.dataloader = trainUtils.getDataLoader(opt["dataset"], opt["data_dir"])
    self.device = trainUtils.getDevice(opt["cuda"])
    self.network = mask.getModel(opt["model"], opt["model_opt"]).to(self.device)
    self.criterion = torch.nn.BCEWithLogitsLoss()
    self.optim = mask.getOptim(self.network, opt["optimizer"], self.lr, self.l2)
71 |
72 | def train_on_batch(self, label, data, retrain=False):
73 | self.network.train()
74 | self.network.zero_grad()
75 | data, label = data.to(self.device), label.to(self.device)
76 | logit = self.network(data)
77 | logloss = self.criterion(logit, label)
78 | regloss = self.network.reg()
79 | if not retrain:
80 | loss = self.reg_lambda * regloss + logloss
81 | else:
82 | loss = logloss
83 | loss.backward()
84 | for optim in self.optim:
85 | optim.step()
86 | return logloss.item()
87 |
88 | def eval_on_batch(self, data):
89 | self.network.eval()
90 | with torch.no_grad():
91 | data = data.to(self.device)
92 | logit = self.network(data)
93 | prob = torch.sigmoid(logit).detach().cpu().numpy()
94 | return prob
95 |
def search(self):
    """Run the search phase for ``self.epochs`` epochs.

    From the second epoch on, the network temperature is multiplied by
    ``temp_increase`` at the start of each epoch. At epoch ``rewind_epoch``
    the network's ``checkpoint()`` is called before training. Validation
    metrics and the remaining-feature rate are printed after every epoch,
    and test metrics once at the end.
    """
    print("ticket:{t}".format(t=self.network.ticket))
    print("-----------------Begin Search-----------------")
    for epoch_idx in range(int(self.epochs)):
        if epoch_idx > 0:
            self.network.temp *= self.temp_increase
        if epoch_idx == self.rewind_epoch:
            self.network.checkpoint()

        loss_sum, n_batches = .0, 0
        batches = self.dataloader.get_data("train", batch_size=self.bs)
        for feature, label in batches:
            loss_sum += self.train_on_batch(label, feature)
            n_batches += 1
        mean_loss = loss_sum / n_batches

        print("Temp:{temp:.6f}".format(temp=self.network.temp))
        val_auc, val_loss = self.evaluate("val")
        print("[Epoch {epoch:d} | Train Loss: {loss:.6f} | Val AUC: {val_auc:.6f}, Val Loss: {val_loss:.6f}]".format(epoch=epoch_idx, loss=mean_loss, val_auc=val_auc, val_loss=val_loss))
        remaining = self.network.compute_remaining_weights()
        print("Feature remain:{rate:.6f}".format(rate=remaining))

    test_auc, test_loss = self.evaluate("test")
    print("Test AUC: {test_auc:.6f}, Test Loss: {test_loss:.6f}".format(test_auc=test_auc, test_loss=test_loss))
117 |
def evaluate(self, on:str):
    """Compute AUC and log-loss over one data split.

    :param on: split name forwarded to ``self.dataloader.get_data``
    :return: tuple ``(auc, log_loss)``
    """
    # Evaluation iterates with a batch ten times the training batch size.
    batches = self.dataloader.get_data(on, batch_size=self.bs * 10)
    scored = [
        (self.eval_on_batch(feature), label.detach().cpu().numpy())
        for feature, label in batches
    ]
    y_pred = np.concatenate([pred for pred, _ in scored]).astype("float64")
    y_true = np.concatenate([true for _, true in scored]).astype("float64")
    auc = metrics.roc_auc_score(y_true, y_pred)
    return auc, metrics.log_loss(y_true, y_pred)
130 |
def train(self, epochs):
    """Retrain the network in ticket mode with early stopping on validation AUC.

    Switches the network to ticket mode, rewinds its weights, keeps only the
    first optimiser returned by ``mask.getOptim`` and then trains for at most
    *epochs* epochs. The state dict is saved whenever validation AUC improves;
    on the first epoch without improvement the saved state is reloaded, test
    metrics are printed and training stops.

    :param epochs: maximum number of retraining epochs (converted with ``int``)
    """
    self.network.ticket=True
    self.network.rewind_weights()
    best_auc = 0.0
    self.optim = mask.getOptim(self.network, "adam", self.lr, self.l2)[:1]
    rate = self.network.compute_remaining_weights()

    print("-----------------Begin Train-----------------")
    print("Ticket:{t}".format(t=self.network.ticket))
    print("Final feature remain:{rate:.6f}".format(rate=rate))
    for epoch_idx in range(int(epochs)):
        batch_losses = [
            self.train_on_batch(label, feature, retrain=True)
            for feature, label in self.dataloader.get_data("train", batch_size = self.bs)
        ]
        train_loss = sum(batch_losses, .0) / len(batch_losses)
        val_auc, val_loss = self.evaluate("val")
        print("[Epoch {epoch:d} | Train Loss:{loss:.6f} | Val AUC:{val_auc:.6f}, Val Loss:{val_loss:.6f}]".format(epoch=epoch_idx, loss=train_loss, val_auc=val_auc, val_loss=val_loss))

        if val_auc > best_auc:
            # New best: remember the score and persist the weights.
            best_auc = val_auc
            torch.save(self.network.state_dict(), self.model_dir)
            continue

        # No improvement: restore the saved weights, report and stop.
        self.network.load_state_dict(torch.load(self.model_dir))
        self.network.to(self.device)
        test_auc, test_loss = self.evaluate("test")
        print("Early stop at epoch {epoch:d} | Test AUC: {test_auc:.6f}, Test Loss:{test_loss:.6f}".format(epoch=epoch_idx, test_auc = test_auc, test_loss = test_loss))
        break
    else:
        # Reached only when the loop was never broken by early stopping.
        test_auc, test_loss = self.evaluate("test")
        print("Final Test AUC: {test_auc:.6f}, Test Loss: {test_loss:.6f}".format(test_auc=test_auc, test_loss=test_loss))
166 |
def main():
    """Build the option dictionaries from the parsed ``args`` and run search, then retraining."""
    # Options consumed by the model constructors.
    model_opt = {
        "latent_dim": args.dim,
        "feat_num": args.feature,
        "field_num": args.field,
        "mlp_dropout": args.mlp_dropout,
        "use_bn": args.mlp_bn,
        "mlp_dims": args.mlp_dims,
        "mask_initial": args.mask_init,
        "cross": args.cross,
    }

    # Options consumed by the Trainer itself.
    opt = {
        "model_opt": model_opt,
        "dataset": args.dataset,
        "model": args.model,
        "lr": args.lr,
        "l2": args.l2,
        "bsize": args.bsize,
        "optimizer": args.optim,
        "data_dir": args.data_dir,
        "save_dir": args.save_dir,
        "cuda": args.cuda,
        "search_epoch": args.search_epoch,
        "rewind_epoch": args.rewind_epoch,
        "final_temp": args.final_temp,
        "lambda": args.reg_lambda,
    }
    print(opt)
    runner = Trainer(opt)
    runner.search()
    runner.train(args.max_epoch)
184 |
# Script entry point: main() runs only when the file is executed directly.
# NOTE(review): the usage string below names trainer.py; this section appears
# to be the mask trainer script -- confirm the intended command line.
if __name__ == "__main__":
    """
    python trainer.py Criteo DeepFM --feature
    """
    main()
190 |
--------------------------------------------------------------------------------
/modules/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
class FeatureEmbedding(torch.nn.Module):
    """Trainable embedding table of shape ``(feature_num, latent_dim)``."""

    def __init__(self, feature_num, latent_dim, initializer = torch.nn.init.xavier_uniform_):
        super().__init__()
        # Allocate as zeros, then let the supplied initializer fill it in place.
        table = torch.zeros(feature_num, latent_dim)
        self.embedding = torch.nn.Parameter(table)
        initializer(self.embedding)

    def forward(self, x):
        """
        :param x: index tensor of size (batch_size, num_fields)
        :return: tensor of size (batch_size, num_fields, embedding_dim)
        """
        return F.embedding(input=x, weight=self.embedding)
16 |
class FeaturesLinear(torch.nn.Module):
    """First-order (linear) term: sum of per-feature weights plus a bias."""

    def __init__(self, feature_num, output_dim=1):
        super().__init__()
        self.fc = torch.nn.Embedding(feature_num, output_dim)
        self.bias = torch.nn.Parameter(torch.zeros((output_dim,)))

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :return : tensor of size ``(batch_size, output_dim)``
        """
        # Reduce over the field axis explicitly. The former dimension-less
        # squeeze also dropped the batch/field axis whenever it had size 1,
        # which made single-row batches fail.
        return torch.sum(self.fc(x), dim=1) + self.bias
29 |
class FactorizationMachine(torch.nn.Module):
    """Pairwise interaction term: half of (square of sum minus sum of squares)."""

    def __init__(self, reduce_sum=True):
        super().__init__()
        self.reduce_sum = reduce_sum

    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
        :return : tensor of size (batch_size, 1) if reduce_sum
                  tensor of size (batch_size, embed_dim) else
        """
        field_total = x.sum(dim=1)
        interaction = field_total.pow(2) - x.pow(2).sum(dim=1)
        if not self.reduce_sum:
            return 0.5 * interaction
        return 0.5 * interaction.sum(dim=1, keepdim=True)
47 |
class MultiLayerPerceptron(torch.nn.Module):
    """Stack of Linear -> [BatchNorm1d] -> [LayerNorm] -> ReLU -> Dropout blocks.

    When ``output_layer`` is true a final ``Linear(..., 1)`` is appended.
    """

    def __init__(self, input_dim, mlp_dims, dropout, output_layer=True, use_bn=False, use_ln=False):
        super().__init__()
        stack = []
        width = input_dim
        for hidden in mlp_dims:
            block = [torch.nn.Linear(width, hidden)]
            if use_bn:
                block.append(torch.nn.BatchNorm1d(hidden))
            if use_ln:
                block.append(torch.nn.LayerNorm(hidden))
            block += [torch.nn.ReLU(), torch.nn.Dropout(p=dropout)]
            stack.extend(block)
            # The next block consumes this block's output width.
            width = hidden
        if output_layer:
            stack.append(torch.nn.Linear(width, 1))
        self.mlp = torch.nn.Sequential(*stack)

    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, input_dim)``
        :return : tensor of size (batch_size, 1) when built with an output
                  layer, otherwise (batch_size, mlp_dims[-1])
        """
        return self.mlp(x)
71 |
class CrossNetwork(torch.nn.Module):
    """Cross layers: ``x <- x0 * w_i(x) + b_i + x`` repeated ``num_layers`` times."""

    def __init__(self, input_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.w = torch.nn.ModuleList(
            [torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)]
        )
        self.b = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)]
        )

    def forward(self, x):
        """
        :param x: Float tensor whose last dimension is ``input_dim``
        :return: tensor with the same shape as ``x``
        """
        x0 = x
        for weight, bias in zip(self.w, self.b):
            # weight(x) has a trailing dimension of 1 and broadcasts against x0.
            x = x0 * weight(x) + bias + x
        return x
92 |
93 |
class InnerProduct(torch.nn.Module):
    """Inner products of every unordered pair of field embeddings."""

    def __init__(self, field_num):
        super().__init__()
        pairs = [(row, col) for row in range(field_num) for col in range(row + 1, field_num)]
        # Explicit long dtype: an empty pair list (field_num < 2) would
        # otherwise produce a float tensor that cannot be used as an index.
        self.rows = torch.tensor([row for row, _ in pairs], dtype=torch.long)
        self.cols = torch.tensor([col for _, col in pairs], dtype=torch.long)

    def forward(self, x):
        """
        :param x: Float tensor of size (batch_size, field_num, embedding_dim)
        :return: (batch_size, field_num*(field_num-1)/2)
        """
        # Keep the index tensors on the input's device; they are plain
        # attributes, so Module.to() does not move them.
        if self.rows.device != x.device:
            self.rows = self.rows.to(x.device)
            self.cols = self.cols.to(x.device)

        # Select both members of every pair along the field axis, then reduce
        # the element-wise product over the embedding axis.
        p = x.index_select(1, self.rows)
        q = x.index_select(1, self.cols)
        return torch.sum(torch.mul(p, q), 2)
124 |
--------------------------------------------------------------------------------
/modules/mask.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from modules.layers import FactorizationMachine, MultiLayerPerceptron
6 | import copy
7 | import modules.layers as layer
8 |
9 |
class MaskEmbedding(nn.Module):
    """Embedding table with one learnable mask weight per feature row.

    ``init_weight`` holds a snapshot of the embedding used by
    ``rewind_weights``; it starts as a copy of the freshly initialised table
    and is refreshed by ``checkpoint``.
    """

    def __init__(self, feature_num, latent_dim, mask_initial_value=0.):
        super().__init__()
        self.feature_num = feature_num
        self.latent_dim = latent_dim
        self.mask_initial_value = mask_initial_value
        self.embedding = nn.Parameter(torch.zeros(feature_num, latent_dim))
        nn.init.xavier_uniform_(self.embedding)
        # Start the snapshot from the initial table: rewinding before any
        # checkpoint() then restores the initialisation instead of silently
        # zeroing every embedding.
        self.init_weight = nn.Parameter(self.embedding.detach().clone(), requires_grad=False)
        self.init_mask()

    def init_mask(self):
        """Create the ``(feature_num, 1)`` mask weights filled with the initial value."""
        self.mask_weight = nn.Parameter(
            torch.full((self.feature_num, 1), float(self.mask_initial_value))
        )

    def compute_mask(self, x, temp, ticket):
        """Return the scaled mask for indices *x*: hard 0/1 if *ticket*, else ``sigmoid(temp * w)``."""
        # Scaling makes the soft mask equal to 1 at the initial mask value when temp is 1.
        scaling = 1./ sigmoid(self.mask_initial_value)
        mask_weight = F.embedding(x, self.mask_weight)
        if ticket:
            mask = (mask_weight > 0).float()
        else:
            mask = torch.sigmoid(temp * mask_weight)
        return scaling * mask

    def prune(self, temp):
        """Rescale the mask weights by *temp* and cap them at the initial mask value."""
        self.mask_weight.data = torch.clamp(temp * self.mask_weight.data, max=self.mask_initial_value)

    def forward(self, x, temp=1, ticket=False):
        """Look up the embeddings for *x* and multiply them by their masks."""
        embed = F.embedding(x, self.embedding)
        mask = self.compute_mask(x, temp, ticket)
        return embed * mask

    def compute_remaining_weights(self, temp, ticket=False):
        """Return the fraction of features whose mask is non-zero.

        In ticket mode a feature counts as kept when its mask weight is
        positive; otherwise the soft mask is compared with exactly 0 and some
        diagnostics are printed.
        """
        if ticket:
            return float((self.mask_weight > 0.).sum()) / self.mask_weight.numel()
        else:
            m = torch.sigmoid(temp * self.mask_weight)
            print("max mask weight: {wa:6f}, min mask weight: {wi:6f}".format(wa=torch.max(self.mask_weight),wi=torch.min(self.mask_weight)))
            print("max mask: {ma:8f}, min mask: {mi:8f}".format(ma=torch.max(m), mi=torch.min(m)))
            print("mask number: {mn:6f}".format(mn=float((m==0.).sum())))
            return 1 - float((m == 0.).sum()) / m.numel()

    def checkpoint(self):
        """Store a detached copy of the current embedding as the rewind target."""
        self.init_weight.data = self.embedding.detach().clone()

    def rewind_weights(self):
        """Reset the embedding to the stored snapshot."""
        self.embedding.data = self.init_weight.detach().clone()

    def reg(self, temp):
        """Return the sum of the soft mask values, used as the sparsity penalty."""
        return torch.sum(torch.sigmoid(temp * self.mask_weight))
60 |
61 |
class MaskedNet(nn.Module):
    """Base class for models built on a ``MaskEmbedding``.

    Holds the shared temperature / ticket state and forwards checkpoint,
    rewind, prune and regularisation calls to the mask modules.
    """

    def __init__(self, opt):
        super(MaskedNet, self).__init__()
        self.ticket = False
        self.latent_dim = opt["latent_dim"]
        self.feature_num = opt["feat_num"]
        self.field_num = opt["field_num"]
        self.mask_embedding = MaskEmbedding(self.feature_num, self.latent_dim, mask_initial_value=opt["mask_initial"])
        # isinstance (rather than an exact type comparison) also picks up
        # subclasses of MaskEmbedding.
        self.mask_modules = [m for m in self.modules() if isinstance(m, MaskEmbedding)]
        self.temp = 1

    def checkpoint(self):
        """Snapshot the mask modules and every BatchNorm1d / Linear layer."""
        for m in self.mask_modules:
            m.checkpoint()
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.Linear)):
                # Deep copy so later training does not alter the snapshot.
                m.checkpoint = copy.deepcopy(m.state_dict())

    def rewind_weights(self):
        """Restore the state captured by the last ``checkpoint()`` call.

        BatchNorm1d / Linear layers carry no snapshot until ``checkpoint()``
        has run, so calling this earlier raises ``AttributeError``.
        """
        for m in self.mask_modules:
            m.rewind_weights()
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.Linear)):
                m.load_state_dict(m.checkpoint)

    def prune(self):
        """Apply ``prune`` with the current temperature to every mask module."""
        for m in self.mask_modules:
            m.prune(self.temp)

    def reg(self):
        """Return the summed regularisation term of all mask modules."""
        reg_loss = 0.
        for m in self.mask_modules:
            reg_loss += m.reg(self.temp)
        return reg_loss
94 |
95 |
class MaskDeepFM(MaskedNet):
    """DeepFM on masked embeddings: FM term plus an MLP over the flattened embeddings."""

    def __init__(self, opt):
        super(MaskDeepFM, self).__init__(opt)
        self.fm = FactorizationMachine(reduce_sum=True)
        hidden_dims, drop_rate, batch_norm = opt["mlp_dims"], opt["mlp_dropout"], opt["use_bn"]
        self.dnn_dim = self.field_num*self.latent_dim
        self.dnn = MultiLayerPerceptron(self.dnn_dim, hidden_dims, drop_rate, use_bn=batch_norm)

    def forward(self, x):
        """Return the logit as the sum of the MLP output and the FM output."""
        masked = self.mask_embedding(x, self.temp, self.ticket)
        fm_out = self.fm(masked)
        deep_out = self.dnn(masked.view(-1, self.dnn_dim))
        return deep_out + fm_out

    def compute_remaining_weights(self):
        """Delegate to the mask embedding using the current temperature and ticket flag."""
        return self.mask_embedding.compute_remaining_weights(self.temp, self.ticket)
117 |
118 |
class MaskDeepCross(MaskedNet):
    """Deep & Cross network on masked embeddings."""

    def __init__(self, opt):
        super(MaskDeepCross, self).__init__(opt)
        self.dnn_dim = self.field_num * self.latent_dim
        cross_layers, hidden_dims = opt["cross"], opt["mlp_dims"]
        drop_rate, batch_norm = opt["mlp_dropout"], opt["use_bn"]
        self.cross = layer.CrossNetwork(self.dnn_dim, cross_layers)
        self.dnn = MultiLayerPerceptron(self.dnn_dim, hidden_dims, output_layer=False, dropout=drop_rate, use_bn=batch_norm)
        # Final projection over the concatenated cross and MLP outputs.
        self.combination = nn.Linear(hidden_dims[-1] + self.dnn_dim, 1, bias=False)

    def forward(self, x):
        """Return the logit from the concatenated cross-network and MLP branches."""
        flat = self.mask_embedding(x, self.temp, self.ticket).view(-1, self.dnn_dim)
        cross_out = self.cross(flat)
        deep_out = self.dnn(flat)
        return self.combination(torch.cat((cross_out, deep_out), dim=1))

    def compute_remaining_weights(self):
        """Delegate to the mask embedding using the current temperature and ticket flag."""
        return self.mask_embedding.compute_remaining_weights(self.temp, self.ticket)
142 |
143 |
class MaskedFM(MaskedNet):
    """Factorization machine on masked embeddings (pairwise term only)."""

    def __init__(self, opt):
        super(MaskedFM, self).__init__(opt)
        self.fm = FactorizationMachine(reduce_sum=True)

    def forward(self, x):
        """Return the FM output of the masked embeddings as the logit."""
        return self.fm(self.mask_embedding(x, self.temp, self.ticket))

    def compute_remaining_weights(self):
        """Delegate to the mask embedding using the current temperature and ticket flag."""
        return self.mask_embedding.compute_remaining_weights(self.temp, self.ticket)
157 |
158 |
class MaskedIPNN(MaskedNet):
    """Inner-product network on masked embeddings.

    The MLP consumes the flattened embeddings concatenated with the pairwise
    inner products of the field embeddings.
    """

    def __init__(self, opt):
        super(MaskedIPNN, self).__init__(opt)
        mlp_dims = opt["mlp_dims"]
        use_bn = opt["use_bn"]
        dropout = opt["mlp_dropout"]
        # Flattened embeddings plus one inner product per unordered field pair.
        self.dnn_dim = self.field_num * self.latent_dim + int(self.field_num * (self.field_num - 1) / 2)
        self.inner = layer.InnerProduct(self.field_num)
        self.dnn = MultiLayerPerceptron(self.dnn_dim, mlp_dims, output_layer=True, dropout=dropout, use_bn=use_bn)

    def forward(self, x):
        """Return the logit for index tensor *x*."""
        # Pass temperature and ticket flag like the sibling models do. Without
        # them the embedding always used temp=1 and a soft mask, so
        # temperature changes and ticket mode had no effect on this model.
        x_embedding = self.mask_embedding(x, self.temp, self.ticket)
        x_dnn = x_embedding.view(-1, self.field_num*self.latent_dim)
        x_innerproduct = self.inner(x_embedding)
        x_dnn = torch.cat((x_dnn, x_innerproduct), 1)
        logit = self.dnn(x_dnn)
        return logit

    def compute_remaining_weights(self):
        """Delegate to the mask embedding using the current temperature and ticket flag."""
        return self.mask_embedding.compute_remaining_weights(self.temp, self.ticket)
179 |
180 |
def getOptim(network, optim, lr, l2):
    """Build two optimisers: one for ordinary weights, one for mask weights.

    Trainable parameters whose name contains ``'mask_weight'`` go to the
    second optimiser, which is created without weight decay; all other
    trainable parameters go to the first one with ``weight_decay=l2``.

    :return: list ``[weight_optimiser, mask_optimiser]``
    :raises ValueError: if *optim* is neither "sgd" nor "adam" (case-insensitive)
    """
    weight_params, mask_params = [], []
    for name, param in network.named_parameters():
        if not param.requires_grad:
            continue
        bucket = mask_params if 'mask_weight' in name else weight_params
        bucket.append(param)

    optim = optim.lower()
    if optim == "sgd":
        factory = torch.optim.SGD
    elif optim == "adam":
        factory = torch.optim.Adam
    else:
        raise ValueError("Invalid optimizer type: {}".format(optim))
    return [factory(weight_params, lr=lr, weight_decay=l2), factory(mask_params, lr=lr)]
191 |
192 |
def getModel(model:str, opt):
    """Instantiate the masked model named *model* (case-insensitive).

    :param model: one of "deepfm", "dcn", "fm", "ipnn"
    :param opt: option dictionary passed to the model constructor
    :raises ValueError: for any other model name
    """
    key = model.lower()
    # Constructors are wrapped in lambdas so only the requested class is looked up.
    builders = {
        "deepfm": lambda: MaskDeepFM(opt),
        "dcn": lambda: MaskDeepCross(opt),
        "fm": lambda: MaskedFM(opt),
        "ipnn": lambda: MaskedIPNN(opt),
    }
    if key not in builders:
        raise ValueError("Invalid model type: {}".format(key))
    return builders[key]()
205 |
206 |
def sigmoid(x):
    """Return the logistic function of scalar *x* as a Python float."""
    denominator = 1. + np.exp(-x)
    return float(1. / denominator)
209 |
--------------------------------------------------------------------------------
/modules/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from modules.layers import MultiLayerPerceptron, FactorizationMachine, FeaturesLinear, FeatureEmbedding
3 | import modules.layers as layer
4 |
5 |
class BasicModel(torch.nn.Module):
    """Common base of the unmasked models: stores sizes and owns the feature embedding."""

    def __init__(self, opt):
        super(BasicModel, self).__init__()
        self.latent_dim, self.feature_num, self.field_num = (
            opt["latent_dim"],
            opt["feat_num"],
            opt["field_num"],
        )
        self.embedding = FeatureEmbedding(self.feature_num, self.latent_dim)

    def forward(self, x):
        """
        Placeholder; the base implementation returns None.

        :param x: tensor of size ``(batch_size, field_num)``
        """
        return None

    def reg(self):
        """Regularisation term of the base model: always 0.0."""
        return 0.0
23 |
class FM(BasicModel):
    """Factorization machine using only the pairwise interaction term.

    The first-order linear term is not part of the model; the logit is the
    FM output alone.
    """

    def __init__(self, opt):
        super(FM, self).__init__(opt)
        self.fm = FactorizationMachine(reduce_sum=True)

    def forward(self, x):
        """Return the FM output of the embedded input as the logit."""
        return self.fm(self.embedding(x))
39 |
class DeepFM(FM):
    """DeepFM: the FM term plus an MLP over the flattened field embeddings."""

    def __init__(self, opt):
        super(DeepFM, self).__init__(opt)
        hidden_dims, drop_rate, batch_norm = opt["mlp_dims"], opt["mlp_dropout"], opt["use_bn"]
        self.dnn_dim = self.field_num*self.latent_dim
        self.dnn = MultiLayerPerceptron(self.dnn_dim, hidden_dims, drop_rate, use_bn=batch_norm)

    def forward(self, x):
        """Return the logit as MLP output plus FM output (no linear term)."""
        embedded = self.embedding(x)
        fm_out = self.fm(embedded)
        deep_out = self.dnn(embedded.view(-1, self.dnn_dim))
        return deep_out + fm_out
60 |
class DeepCrossNet(BasicModel):
    """Deep & Cross network: cross branch and MLP branch joined by a linear layer."""

    def __init__(self, opt):
        super(DeepCrossNet, self).__init__(opt)
        cross_layers, hidden_dims = opt["cross"], opt["mlp_dims"]
        batch_norm, drop_rate = opt["use_bn"], opt["mlp_dropout"]
        self.dnn_dim = self.field_num * self.latent_dim
        self.cross = layer.CrossNetwork(self.dnn_dim, cross_layers)
        self.dnn = MultiLayerPerceptron(self.dnn_dim, hidden_dims, output_layer=False, dropout=drop_rate, use_bn=batch_norm)
        # Final projection over the concatenated cross and MLP outputs.
        self.combination = torch.nn.Linear(hidden_dims[-1] + self.dnn_dim, 1, bias=False)

    def forward(self, x):
        """Return the logit from the concatenated cross-network and MLP branches."""
        flat = self.embedding(x).view(-1, self.dnn_dim)
        cross_out = self.cross(flat)
        deep_out = self.dnn(flat)
        return self.combination(torch.cat((cross_out, deep_out), dim=1))
81 |
82 |
class InnerProductNet(BasicModel):
    """Inner-product network: MLP over flattened embeddings and pairwise inner products."""

    def __init__(self, opt):
        super(InnerProductNet, self).__init__(opt)
        hidden_dims, batch_norm, drop_rate = opt["mlp_dims"], opt["use_bn"], opt["mlp_dropout"]
        flat_width = self.field_num * self.latent_dim
        pair_count = int(self.field_num * (self.field_num -1)/2)
        self.dnn_dim = flat_width + pair_count
        self.inner = layer.InnerProduct(self.field_num)
        self.dnn = MultiLayerPerceptron(self.dnn_dim, hidden_dims, output_layer=True, dropout=drop_rate, use_bn=batch_norm)

    def forward(self, x):
        """Return the logit for index tensor *x*."""
        embedded = self.embedding(x)
        flat = embedded.view(-1, self.field_num*self.latent_dim)
        pairwise = self.inner(embedded)
        return self.dnn(torch.cat((flat, pairwise), 1))
101 |
102 |
103 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import logging
4 | import os, sys
5 | from pathlib import Path
6 | import numpy as np
7 | from sklearn import metrics
8 | from utils import trainUtils
9 |
# Command-line interface. Parsing happens at import time (see parse_args
# below), so importing this module requires valid command-line arguments.
parser = argparse.ArgumentParser(description="optfs trainer")
parser.add_argument("dataset", type=str, help="specify dataset")
parser.add_argument("model", type=str, help="specify model")

# dataset information (all three are mandatory)
parser.add_argument("--feature", type=int, help="feature number", required=True)
parser.add_argument("--field", type=int, help="field number", required=True)
parser.add_argument("--data_dir", type=str, help="data directory", required=True)

# training hyperparameters
parser.add_argument("--lr", type=float, help="learning rate" , default=3e-5)
parser.add_argument("--l2", type=float, help="L2 regularization", default=1e-3)
parser.add_argument("--bsize", type=int, help="batchsize", default=4096)
parser.add_argument("--optim", type=str, default="Adam", help="optimizer type")
parser.add_argument("--max_epoch", type=int, default=20, help="maxmium epochs")
# NOTE(review): --save_dir has no default, so it is None unless supplied, yet
# Trainer hands it to torch.save / torch.load -- confirm it is always passed.
parser.add_argument("--save_dir", type=Path, help="model save directory")

# neural network hyperparameters
parser.add_argument("--dim", type=int, help="embedding dimension", default=16)
parser.add_argument("--mlp_dims", type=int, nargs='+', default=[1024, 512, 256], help="mlp layer size")
parser.add_argument("--mlp_dropout", type=float, default=0.0, help="mlp dropout rate (default:0.0)")
parser.add_argument("--mlp_bn", action="store_true", help="mlp batch normalization")
parser.add_argument("--cross", type=int, help="cross layer", default=3)

# device information (-1 selects the CPU in trainUtils.getDevice)
parser.add_argument("--cuda", type=int, choices=range(-1, 8), default=-1, help="device info")

args = parser.parse_args()

# Fixed seed for torch (CPU and all CUDA devices) and NumPy.
my_seed = 2022
torch.manual_seed(my_seed)
torch.cuda.manual_seed_all(my_seed)
np.random.seed(my_seed)

# Process-wide environment settings: visible GPU and thread limits.
# NOTE(review): CUDA_VISIBLE_DEVICES is set after torch has been imported;
# presumably this relies on CUDA not being initialised yet -- confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['NUMEXPR_NUM_THREADS'] = '8'
os.environ['NUMEXPR_MAX_THREADS'] = '8'
48 |
class Trainer(object):
    """Trainer for the unmasked baseline models with early stopping on validation AUC."""

    def __init__(self, opt):
        """Read hyperparameters from *opt* and build loader, device, network, loss and optimiser."""
        self.lr = opt['lr']
        self.l2 = opt['l2']
        self.bs = opt['bsize']
        self.model_dir = opt["save_dir"]
        self.dataloader = trainUtils.getDataLoader(opt["dataset"], opt["data_dir"])
        self.device = trainUtils.getDevice(opt["cuda"])
        self.network = trainUtils.getModel(opt["model"], opt["model_opt"]).to(self.device)
        self.criterion = torch.nn.BCEWithLogitsLoss()
        self.optim = trainUtils.getOptim(self.network, opt["optimizer"],self.lr, self.l2)

    def train_on_batch(self, label, data):
        """Run one optimisation step; returns the batch log-loss as a float."""
        self.network.train()
        self.optim.zero_grad()
        data, label = data.to(self.device), label.to(self.device)
        logit = self.network(data)
        logloss = self.criterion(logit, label)
        regloss = self.network.reg()
        loss = regloss + logloss
        loss.backward()
        self.optim.step()
        return logloss.item()

    def eval_on_batch(self, data):
        """Return ``sigmoid(logit)`` for one batch as a NumPy array (no gradients)."""
        self.network.eval()
        with torch.no_grad():
            data = data.to(self.device)
            logit = self.network(data)
            prob = torch.sigmoid(logit).detach().cpu().numpy()
        return prob

    def train(self, epochs):
        """Train for at most *epochs* epochs, stopping when validation AUC stops improving.

        The state dict is saved to ``self.model_dir`` whenever validation AUC
        improves. On the first epoch without improvement the saved state is
        reloaded, test metrics are printed and training stops. If no early
        stop occurred (including ``epochs == 0``) the final test metrics are
        printed instead.
        """
        cur_auc = 0.0
        # Defined before the loop so the final report also works when the
        # loop body never runs; this used to raise UnboundLocalError.
        early_stop = False
        for epoch_idx in range(int(epochs)):
            train_loss = .0
            step = 0
            for feature, label in self.dataloader.get_data("train", batch_size = self.bs):
                train_loss += self.train_on_batch(label, feature)
                step += 1
            train_loss /= step
            val_auc, val_loss = self.evaluate("val")
            print("[Epoch {epoch:d} | Train Loss:{loss:.6f} | Val AUC:{val_auc:.6f}, Val Loss:{val_loss:.6f}".
                  format(epoch=epoch_idx, loss=train_loss, val_auc=val_auc, val_loss=val_loss ))
            if val_auc > cur_auc:
                cur_auc = val_auc
                torch.save(self.network.state_dict(), self.model_dir)
            else:
                # No improvement: fall back to the saved weights and stop.
                self.network.load_state_dict(torch.load(self.model_dir))
                self.network.to(self.device)
                early_stop = True
                te_auc, te_loss = self.evaluate("test")
                print("Early stop at epoch {epoch:d}|Test AUC: {te_auc:.6f}, Test Loss:{te_loss:.6f}".
                      format(epoch=epoch_idx, te_auc = te_auc, te_loss = te_loss))
                break
        if not early_stop:
            te_auc, te_loss = self.evaluate("test")
            print("Final Test AUC:{te_auc:.6f}, Test Loss:{te_loss:.6f}".format(te_auc=te_auc, te_loss=te_loss))

    def evaluate(self, on:str):
        """Return ``(auc, log_loss)`` over the split *on*, using a 10x batch size."""
        preds, trues = [], []
        for feature, label in self.dataloader.get_data(on, batch_size=self.bs * 10):
            pred = self.eval_on_batch(feature)
            label = label.detach().cpu().numpy()
            preds.append(pred)
            trues.append(label)
        y_pred = np.concatenate(preds).astype("float64")
        y_true = np.concatenate(trues).astype("float64")
        auc = metrics.roc_auc_score(y_true, y_pred)
        loss = metrics.log_loss(y_true, y_pred)
        return auc, loss
125 |
def main():
    """Build the option dictionaries from the parsed ``args`` and train the model."""
    # Options consumed by the model constructors.
    model_opt = {
        "latent_dim": args.dim,
        "feat_num": args.feature,
        "field_num": args.field,
        "mlp_dropout": args.mlp_dropout,
        "use_bn": args.mlp_bn,
        "mlp_dims": args.mlp_dims,
        "cross": args.cross,
    }

    # Options consumed by the Trainer itself.
    opt = {
        "model_opt": model_opt,
        "dataset": args.dataset,
        "model": args.model,
        "lr": args.lr,
        "l2": args.l2,
        "bsize": args.bsize,
        "epoch": args.max_epoch,
        "optimizer": args.optim,
        "data_dir": args.data_dir,
        "save_dir": args.save_dir,
        "cuda": args.cuda,
    }
    print(opt)
    runner = Trainer(opt)
    runner.train(args.max_epoch)
140 |
141 |
# Script entry point: main() runs only when the file is executed directly.
# The string below is a usage hint; it is an expression with no effect.
if __name__ == "__main__":
    """
    python trainer.py Criteo DeepFM --feature
    """
    main()
147 |
--------------------------------------------------------------------------------
/utils/trainUtils.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | import torch
3 | import pickle
4 | from data import tfloader
5 | import modules.models as models
6 |
def getModel(model:str, opt):
    """Instantiate the baseline model named *model* (case-insensitive).

    :param model: one of "fm", "deepfm", "ipnn", "dcn"
    :param opt: option dictionary passed to the model constructor
    :raises ValueError: for any other model name
    """
    key = model.lower()
    # Map each accepted name to its class name in ``models``; the attribute
    # is only looked up once the name is known to be valid.
    class_names = {
        "fm": "FM",
        "deepfm": "DeepFM",
        "ipnn": "InnerProductNet",
        "dcn": "DeepCrossNet",
    }
    if key not in class_names:
        raise ValueError("Invalid model type: {}".format(key))
    return getattr(models, class_names[key])(opt)
19 |
def getOptim(network, optim, lr, l2):
    """Create one optimiser over all parameters of *network*.

    :param optim: "sgd" or "adam" (case-insensitive)
    :param lr: learning rate
    :param l2: weight decay
    :raises ValueError: for any other optimiser name
    """
    params = network.parameters()
    kind = optim.lower()
    if kind not in ("sgd", "adam"):
        raise ValueError("Invalid optmizer type:{}".format(kind))
    factory = torch.optim.SGD if kind == "sgd" else torch.optim.Adam
    return factory(params, lr=lr, weight_decay=l2)
29 |
def getDevice(device_id):
    """Return the CPU device for ``-1``, otherwise the default CUDA device.

    Any id other than -1 only asserts that CUDA is available; the id itself
    does not select a particular GPU here.
    """
    if device_id == -1:
        return torch.device('cpu')
    assert torch.cuda.is_available(), "CUDA is not available"
    return torch.device('cuda')
37 |
def getDataLoader(dataset:str, path):
    """Return the loader for *dataset* (case-insensitive) reading from *path*.

    :param dataset: one of "criteo", "avazu", "kdd12"
    :param path: data location passed to the loader constructor
    :raises ValueError: for any other dataset name (previously this returned
        None and failed later, at the first use of the loader)
    """
    dataset = dataset.lower()
    if dataset == 'criteo':
        return tfloader.CriteoLoader(path)
    elif dataset == 'avazu':
        return tfloader.Avazuloader(path)
    elif dataset == 'kdd12':
        return tfloader.KDD12loader(path)
    else:
        # Same failure style as getModel / getOptim in this module.
        raise ValueError("Invalid dataset type: {}".format(dataset))
46 |
def get_stats(path):
    """Load the pickled ``defaults`` and ``offset`` dictionaries stored under *path*.

    :param path: directory (``str`` or ``os.PathLike``) holding
        ``defaults.pkl`` and ``offset.pkl``
    :return: tuple ``(list of defaults values, list of offset values)``
    """
    loaded = []
    for filename in ("defaults.pkl", "offset.pkl"):
        # os.path.join with two arguments also accepts Path objects, unlike
        # the earlier string concatenation.
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        with open(os.path.join(path, filename), 'rb') as fi:
            loaded.append(pickle.load(fi))
    defaults, offset = loaded
    return list(defaults.values()), list(offset.values())
56 |
--------------------------------------------------------------------------------