├── .clang-format
├── .flake8
├── .gitignore
├── .pre-commit-config.yaml
├── .style.yapf
├── README.md
├── baselines
│   ├── README.md
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── avazu.py
│   │   ├── custom.py
│   │   ├── dlrm_dataloader.py
│   │   └── synth.py
│   ├── dlrm_main.py
│   └── models
│       ├── __init__.py
│       ├── deepfm.py
│       └── dlrm.py
├── benchmark
│   ├── benchmark_cache.py
│   ├── benchmark_fbgemm_uvm.py
│   └── data_utils.py
├── docker
│   ├── Dockerfile
│   ├── Dockerfile_thu
│   └── launch.sh
├── license
├── pics
│   └── prefetch.png
├── recsys
│   ├── README.md
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── avazu.py
│   │   ├── criteo.py
│   │   ├── feature_counter.py
│   │   └── utils.py
│   ├── dlrm_main.py
│   ├── models
│   │   ├── __init__.py
│   │   └── dlrm.py
│   └── utils
│       ├── __init__.py
│       ├── dataloader
│       │   ├── __init__.py
│       │   ├── base_dataiter.py
│       │   └── cuda_stream_dataloader.py
│       ├── misc.py
│       └── preprocess_synth.py
├── requirements.txt
└── scripts
    ├── avazu.sh
    ├── kaggle.sh
    ├── preprocess
    │   ├── .gitignore
    │   ├── README.md
    │   ├── npy_preproc_avazu.py
    │   ├── npy_preproc_criteo.py
    │   ├── split_criteo_kaggle.py
    │   └── taobao
    │       ├── README.md
    │       ├── csv_to_txt.py
    │       ├── run_txt_to_npz.sh
    │       └── txt_to_npz.py
    ├── run.sh
    ├── terabyte.sh
    ├── torchrec_avazu.sh
    ├── torchrec_custom.sh
    ├── torchrec_kaggle.sh
    ├── torchrec_synth.sh
    └── torchrec_terabyte.sh

/.clang-format:
--------------------------------------------------------------------------------
1 | BasedOnStyle: Google
2 | 
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore =
3 |     ;W503 line break before binary operator
4 |     W503,
5 |     ;E203 whitespace before ':'
6 |     E203,
7 | 
8 | ; exclude file
9 | exclude =
10 |     .tox,
11 |     .git,
12 |     __pycache__,
13 |     build,
14 |     dist,
15 |     *.pyc,
16 |     *.egg-info,
17 |     .cache,
18 |     .eggs
19 | 
20 | max-line-length = 120
21 | 
22 | per-file-ignores = __init__.py:F401
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | *__pycache__
3 | criteo_kaggle/
4 | 
5 | # Distribution / packaging
6 | build/
7 | dist/
8 | *.egg-info/
9 | 
10 | # Logging
11 | *log
12 | *wandb
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |     - repo: https://github.com/pre-commit/mirrors-yapf
3 |       rev: v0.32.0
4 |       hooks:
5 |           - id: yapf
6 |             args: ['--style=.style.yapf', '--parallel', '--in-place']
7 |     - repo: https://github.com/pre-commit/mirrors-clang-format
8 |       rev: v13.0.1
9 |       hooks:
10 |           - id: clang-format
--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
1 | [style]
2 | based_on_style = google
3 | spaces_before_comment = 4
4 | split_before_logical_operator = true
5 | column_limit = 120
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## CachedEmbedding: larger embedding tables, smaller GPU memory budget
2 | 
3 | The embedding tables in deep learning recommendation system models are becoming extremely large and cannot fit in GPU memory.
4 | This project provides an efficient way to train extremely large recommendation models.
5 | The entire training runs on the GPU with synchronous parameter updates.
6 | 
7 | This project applies CachedEmbedding, which extends the vanilla
8 | [PyTorch EmbeddingBag](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag)
9 | with help from [ColossalAI](https://github.com/hpcaitech/ColossalAI).
10 | CachedEmbedding uses a [software cache approach](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.parallel.layers.html) to dynamically manage an extremely large embedding table across CPU and GPU memory.
11 | For example, this repo can train a DLRM model with a **91.10 GB** embedding table on the Criteo 1TB dataset while allocating just **3.75 GB** of CUDA memory on a single GPU!
12 | 
13 | To reduce the cache overhead, we designed a "far-sighted" cache mechanism.
14 | Instead of performing cache operations only on the current mini-batch, we fetch several mini-batches that will be used later and perform the cache query operations together.
15 | It also uses a pipeline to overlap the overhead of data loading with model training, as shown in the figure below.
16 | 
17 | ![Prefetch pipeline](./pics/prefetch.png)
18 | 
19 | Despite the extra cache indexing and CPU-GPU data movement overhead, the end-to-end performance of our system drops very little compared to torchrec.
20 | However, torchrec usually requires an order of magnitude more CUDA memory.
21 | Also, our software cache is implemented in pure PyTorch without any customized C++/CUDA kernels, so developers can customize or optimize it according to their needs.
22 | 
23 | ### Dataset
24 | 1. [Criteo Kaggle](https://www.kaggle.com/c/criteo-display-ad-challenge/data)
25 | 2. [Avazu](https://www.kaggle.com/c/avazu-ctr-prediction/data)
26 | 3. [Criteo 1TB](https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/)
27 | 
28 | The preprocessing steps are derived from
29 | [Torchrec's utilities](https://github.com/pytorch/torchrec/blob/main/torchrec/datasets/scripts/npy_preproc_criteo.py)
30 | and the [Avazu Kaggle community](https://www.kaggle.com/code/leejunseok97/deepfm-deepctr-torch).
31 | Please refer to the `scripts/preprocess` dir for details.
32 | 
33 | ### Usage
34 | 
35 | 1. Install dependencies
36 | 
37 | Install [ColossalAI](https://github.com/hpcaitech/ColossalAI) (commit id e8d8eda5e7a0619bd779e35065397679e1536dcd)
38 | 
39 | https://github.com/hpcaitech/ColossalAI
40 | 
41 | Install our customized [torchrec](https://github.com/hpcaitech/torchrec) (commit id e8d8eda5e7a0619bd779e35065397679e1536dcd)
42 | 
43 | https://github.com/hpcaitech/torchrec
44 | 
45 | Or, build a docker image using [docker/Dockerfile](./docker/Dockerfile).
46 | Or, use the prebuilt docker image on Docker Hub.
47 | 
48 | ```
49 | docker pull hpcaitech/cacheembedding:0.2.2
50 | ```
51 | 
52 | Launch a docker container:
53 | 
54 | ```
55 | bash ./docker/launch.sh
56 | ```
57 | 
58 | 2. Run
59 | 
60 | All the commands to run DLRM on the three datasets are presented in `scripts/run.sh`:
61 | ```
62 | bash scripts/run.sh
63 | ```
64 | 
65 | Set `--prefetch_num` to enable prefetching.
66 | 
67 | ### Model
68 | Currently, this repo only contains Facebook's DLRM model, and we are working on testing more recommendation models.
69 | 
70 | ### Performance
71 | 
72 | The DLRM performance on three datasets using the ColossalAI version (this repo) and torchrec (with UVM) is shown below. The cache ratio of FreqAwareEmbedding is set to 1%. The evaluation is conducted on a single A100 (80GB memory) and an AMD 7543 32-Core CPU (512GB memory).
73 | 
74 | |            | method     | AUROC over Test after 1 Epoch | Acc over Test | Throughput | Time to Train 1 Epoch | GPU memory allocated (GB) | GPU memory reserved (GB) | CPU memory usage (GB) |
75 | |:----------:|:----------:|:-----------------------------:|:-------------:|:----------:|:---------------------:|:-------------------------:|:------------------------:|:---------------------:|
76 | | criteo 1TB | ColossalAI | 0.791299403 | 0.967155457 | 42 it/s | 1h40m | 3.75 | 5.04 | 94.39 |
77 | |            | torchrec   | 0.79515636 | 0.967177451 | 45 it/s | 1h35m | 66.54 | 68.43 | 7.7 |
78 | | kaggle     | ColossalAI | 0.776755869 | 0.779025435 | 50 it/s | 49s | 0.9 | 2.14 | 34.66 |
79 | |            | torchrec   | 0.786652029 | 0.782288849 | 81 it/s | 30s | 16.13 | 17.99 | 13.89 |
80 | | avazu      | ColossalAI | 0.72732079 | 0.824390948 | 72 it/s | 31s | 0.31 | 1.06 | 16.89 |
81 | |            | torchrec   | 0.725972056 | 0.824484706 | 111 it/s | 21s | 4.53 | 5.83 | 12.25 |
82 | 
83 | ### Cite us
84 | ```
85 | @article{fang2022frequency,
86 |   title={A Frequency-aware Software Cache for Large Recommendation System Embeddings},
87 |   author={Fang, Jiarui and Zhang, Geng and Han, Jiatong and Li, Shenggui and Bian, Zhengda and Li, Yongbin and Liu, Jin and You, Yang},
88 |   journal={arXiv preprint arXiv:2208.05321},
89 |   year={2022}
90 | }
91 | ```
92 | 
--------------------------------------------------------------------------------
/baselines/README.md:
--------------------------------------------------------------------------------
1 | # FreqCacheEmbedding
2 | 
3 | This repo contains the implementation of FreqCacheEmbedding, which extends the vanilla
4 | [PyTorch EmbeddingBag](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag)
5 | with a cache mechanism to enable heterogeneous training of large-scale recommendation models.
6 | 
7 | ### Dataset
8 | 1. [Criteo Kaggle](https://www.kaggle.com/c/criteo-display-ad-challenge/data)
9 | 2. [Avazu](https://www.kaggle.com/c/avazu-ctr-prediction/data)
10 | 
11 | The preprocessing steps are derived from
12 | [Torchrec's utilities](https://github.com/pytorch/torchrec/blob/main/torchrec/datasets/scripts/npy_preproc_criteo.py)
13 | and the [Avazu Kaggle community](https://www.kaggle.com/code/leejunseok97/deepfm-deepctr-torch).
14 | Please refer to the `recsys/datasets/preprocess_scripts` dir for details.
15 | 
16 | At the time this repo was built, another commonly adopted dataset,
17 | [Criteo 1TB](https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/),
18 | was unavailable (see this [issue](https://github.com/pytorch/torchrec/issues/245)).
19 | We will add its preprocessing & running scripts very soon.
20 | 
21 | ### Command
22 | All the commands to run the FreqCacheEmbedding-enabled recommendation models are presented in `run.sh`.
23 | 
24 | ### Model
25 | Currently, this repo only contains DLRM & DeepFM models,
26 | and we are working on testing more recommendation models.
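### How the software cache works (illustrative sketch)

The snippet below is a minimal, self-contained sketch of the frequency-aware software-cache idea: the full embedding table lives in CPU memory, only a small buffer of hot rows lives on the GPU, and on a cache miss the least-frequently-used row is evicted. It is **not** the actual `CachedEmbeddingBag` implementation or API; all names here are made up for illustration, and real concerns such as gradient write-back, batched row movement, and prefetching are omitted.

```python
import torch


class ToyCachedRows:
    """Toy LFU-style software cache: full table on CPU, hot rows on GPU.

    Assumes cache_rows exceeds the number of unique ids in any batch.
    Gradient write-back and batched row movement are intentionally omitted.
    """

    def __init__(self, num_embeddings: int, dim: int, cache_rows: int, device: str = "cuda"):
        self.cpu_weight = torch.randn(num_embeddings, dim)              # full table (CPU)
        self.gpu_rows = torch.zeros(cache_rows, dim, device=device)     # small GPU buffer
        self.slot_of = {}                                               # row id -> GPU slot
        self.freq = torch.zeros(num_embeddings, dtype=torch.long)      # access counters
        self.capacity = cache_rows
        self.device = device

    def _admit(self, row: int, protected: set) -> int:
        if len(self.slot_of) < self.capacity:
            slot = len(self.slot_of)                                   # cache not full yet
        else:
            # Evict the least-frequently-used row not needed by the current batch.
            victim = min((r for r in self.slot_of if r not in protected),
                         key=lambda r: int(self.freq[r]))
            slot = self.slot_of.pop(victim)
        self.gpu_rows[slot].copy_(self.cpu_weight[row].to(self.device))
        self.slot_of[row] = slot
        return slot

    def lookup(self, ids: torch.Tensor) -> torch.Tensor:
        self.freq.index_add_(0, ids, torch.ones_like(ids))             # update frequency stats
        protected = set(ids.tolist())                                  # rows used by this batch
        slots = []
        for r in ids.tolist():
            slot = self.slot_of.get(r)
            if slot is None:                                           # cache miss: copy row CPU -> GPU
                slot = self._admit(r, protected)
            slots.append(slot)
        return self.gpu_rows[torch.tensor(slots, device=self.device)]


# Example (requires a CUDA device): a 1M-row table, but only 4096 rows on the GPU.
cache = ToyCachedRows(num_embeddings=1_000_000, dim=16, cache_rows=4096)
ids = torch.randint(0, 1_000_000, (256,))
vecs = cache.lookup(ids)                                               # (256, 16), on the GPU
```

Because recommendation ids follow a long-tail distribution, a small GPU buffer captures most lookups, which is why the cache ratio in the table above can be as low as 1%.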
-------------------------------------------------------------------------------- /baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpcaitech/CachedEmbedding/a2af3d7e7b0197519e6d018444688fcd9ba32c43/baselines/__init__.py -------------------------------------------------------------------------------- /baselines/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpcaitech/CachedEmbedding/a2af3d7e7b0197519e6d018444688fcd9ba32c43/baselines/data/__init__.py -------------------------------------------------------------------------------- /baselines/data/avazu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import IterableDataset 8 | from torchrec.datasets.utils import PATH_MANAGER_KEY, Batch 9 | from torchrec.datasets.criteo import BinaryCriteoUtils 10 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor 11 | 12 | CAT_FEATURE_COUNT = 13 13 | INT_FEATURE_COUNT = 8 14 | DEFAULT_LABEL_NAME = "click" 15 | DEFAULT_CAT_NAMES = [ 16 | 'C1', 17 | 'banner_pos', 18 | 'site_id', 19 | 'site_domain', 20 | 'site_category', 21 | 'app_id', 22 | 'app_domain', 23 | 'app_category', 24 | 'device_id', 25 | 'device_ip', 26 | 'device_model', 27 | 'device_type', 28 | 'device_conn_type', 29 | ] 30 | DEFAULT_INT_NAMES = ['C14', 'C15', 'C16', 'C17', 'C18', 'C19', 'C20', 'C21'] 31 | NUM_EMBEDDINGS_PER_FEATURE = '7,7,4737,7745,26,8552,559,36,2686408,6729486,8251,5,4' # 9445823 in total 32 | TOTAL_TRAINING_SAMPLES = 36_386_071 # 90% sample in train, 40428967 in total 33 | 34 | 35 | class AvazuIterDataPipe(IterableDataset): 36 | 37 | def __init__( 38 | self, 39 | dense_paths: List[str], 40 | sparse_paths: List[str], 41 | labels_paths: List[str], 42 | batch_size: int, 43 | rank: int, 44 | world_size: int, 45 | shuffle_batches: bool = False, 46 | mmap_mode: bool = False, 47 | hashes: Optional[List[int]] = None, 48 | path_manager_key: str = PATH_MANAGER_KEY, 49 | ) -> None: 50 | self.dense_paths = dense_paths 51 | self.sparse_paths = sparse_paths 52 | self.labels_paths = labels_paths 53 | self.batch_size = batch_size 54 | self.rank = rank 55 | self.world_size = world_size 56 | self.shuffle_batches = shuffle_batches 57 | self.mmap_mode = mmap_mode 58 | self.hashes = hashes 59 | self.path_manager_key = path_manager_key 60 | 61 | self._load_data_for_rank() 62 | self.num_rows_per_file: List[int] = [a.shape[0] for a in self.dense_arrs] 63 | self.num_batches: int = sum(self.num_rows_per_file) // batch_size 64 | 65 | # These values are the same for the KeyedJaggedTensors in all batches, so they 66 | # are computed once here. This avoids extra work from the KeyedJaggedTensor sync 67 | # functions. 
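        # Worked example of this layout: every sample contributes exactly one id per
        # categorical feature, so with CAT_FEATURE_COUNT = 13 and batch size B,
        # `lengths` is a vector of 13 * B ones, `offsets` is simply arange(0, 13 * B + 1),
        # and feature key f occupies values[f * B : (f + 1) * B] after the
        # transpose + reshape(-1) performed in _np_arrays_to_batch below.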
68 | self._num_ids_in_batch: int = CAT_FEATURE_COUNT * batch_size 69 | self.keys: List[str] = DEFAULT_CAT_NAMES 70 | self.lengths: torch.Tensor = torch.ones((self._num_ids_in_batch,), dtype=torch.int32) 71 | self.offsets: torch.Tensor = torch.arange(0, self._num_ids_in_batch + 1, dtype=torch.int32) 72 | self.stride = batch_size 73 | self.length_per_key: List[int] = CAT_FEATURE_COUNT * [batch_size] 74 | self.offset_per_key: List[int] = [batch_size * i for i in range(CAT_FEATURE_COUNT + 1)] 75 | self.index_per_key: Dict[str, int] = {key: i for (i, key) in enumerate(self.keys)} 76 | 77 | def _load_data_for_rank(self) -> None: 78 | file_idx_to_row_range = BinaryCriteoUtils.get_file_idx_to_row_range( 79 | lengths=[ 80 | BinaryCriteoUtils.get_shape_from_npy(path, path_manager_key=self.path_manager_key)[0] 81 | for path in self.dense_paths 82 | ], 83 | rank=self.rank, 84 | world_size=self.world_size, 85 | ) 86 | 87 | self.dense_arrs, self.sparse_arrs, self.labels_arrs = [], [], [] 88 | for _dtype, arrs, paths in zip( 89 | [np.float32, np.int64, np.int32], 90 | [self.dense_arrs, self.sparse_arrs, self.labels_arrs], 91 | [self.dense_paths, self.sparse_paths, self.labels_paths], 92 | ): 93 | for idx, (range_left, range_right) in file_idx_to_row_range.items(): 94 | arrs.append( 95 | BinaryCriteoUtils.load_npy_range( 96 | paths[idx], 97 | range_left, 98 | range_right - range_left + 1, 99 | path_manager_key=self.path_manager_key, 100 | mmap_mode=self.mmap_mode, 101 | ).astype(_dtype)) 102 | 103 | # When mmap_mode is enabled, the hash is applied in def __iter__, which is 104 | # where samples are batched during training. 105 | # Otherwise, the ML dataset is preloaded, and the hash is applied here in 106 | # the preload stage, as shown: 107 | if not self.mmap_mode and self.hashes is not None: 108 | hashes_np = np.array(self.hashes).reshape((1, CAT_FEATURE_COUNT)) 109 | for sparse_arr in self.sparse_arrs: 110 | sparse_arr %= hashes_np 111 | 112 | def _np_arrays_to_batch(self, dense: np.ndarray, sparse: np.ndarray, labels: np.ndarray) -> Batch: 113 | if self.shuffle_batches: 114 | # Shuffle all 3 in unison 115 | shuffler = np.random.permutation(len(dense)) 116 | dense = dense[shuffler] 117 | sparse = sparse[shuffler] 118 | labels = labels[shuffler] 119 | 120 | return Batch( 121 | dense_features=torch.from_numpy(dense), 122 | sparse_features=KeyedJaggedTensor( 123 | keys=self.keys, 124 | # transpose + reshape(-1) incurs an additional copy. 125 | values=torch.from_numpy(sparse.transpose(1, 0).reshape(-1)), 126 | lengths=self.lengths, 127 | offsets=self.offsets, 128 | stride=self.stride, 129 | length_per_key=self.length_per_key, 130 | offset_per_key=self.offset_per_key, 131 | index_per_key=self.index_per_key, 132 | ), 133 | labels=torch.from_numpy(labels.reshape(-1)), 134 | ) 135 | 136 | def __iter__(self) -> Iterator[Batch]: 137 | # Invariant: buffer never contains more than batch_size rows. 138 | buffer: Optional[List[np.ndarray]] = None 139 | 140 | def append_to_buffer(dense: np.ndarray, sparse: np.ndarray, labels: np.ndarray) -> None: 141 | nonlocal buffer 142 | if buffer is None: 143 | buffer = [dense, sparse, labels] 144 | else: 145 | for idx, arr in enumerate([dense, sparse, labels]): 146 | buffer[idx] = np.concatenate((buffer[idx], arr)) 147 | 148 | # Maintain a buffer that can contain up to batch_size rows. Fill buffer as 149 | # much as possible on each iteration. Only return a new batch when batch_size 150 | # rows are filled. 
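        # In other words, each loop iteration either emits a full batch (the buffer holds
        # exactly batch_size rows) or slices up to (batch_size - buffered) more rows out of
        # the current file; file_idx/row_idx track the read position across files, and any
        # trailing partial batch is dropped.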
151 |         file_idx = 0
152 |         row_idx = 0
153 |         batch_idx = 0
154 |         while batch_idx < self.num_batches:
155 |             buffer_row_count = 0 if buffer is None else buffer[0].shape[0]
156 |             if buffer_row_count == self.batch_size:
157 |                 yield self._np_arrays_to_batch(*buffer)
158 |                 batch_idx += 1
159 |                 buffer = None
160 |             else:
161 |                 rows_to_get = min(
162 |                     self.batch_size - buffer_row_count,
163 |                     self.num_rows_per_file[file_idx] - row_idx,
164 |                 )
165 |                 slice_ = slice(row_idx, row_idx + rows_to_get)
166 | 
167 |                 dense_inputs = self.dense_arrs[file_idx][slice_, :]
168 |                 sparse_inputs = self.sparse_arrs[file_idx][slice_, :]
169 |                 target_labels = self.labels_arrs[file_idx][slice_, :]
170 | 
171 |                 if self.mmap_mode and self.hashes is not None:
172 |                     sparse_inputs = sparse_inputs % np.array(self.hashes).reshape((1, CAT_FEATURE_COUNT))
173 | 
174 |                 append_to_buffer(
175 |                     dense_inputs,
176 |                     sparse_inputs,
177 |                     target_labels,
178 |                 )
179 |                 row_idx += rows_to_get
180 | 
181 |                 if row_idx >= self.num_rows_per_file[file_idx]:
182 |                     file_idx += 1
183 |                     row_idx = 0
184 | 
185 |     def __len__(self) -> int:
186 |         return self.num_batches
187 | 
--------------------------------------------------------------------------------
/baselines/data/custom.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Dict, Iterator, List, Optional
3 | 
4 | import torch
5 | from torch import distributed as dist
6 | from torch.utils.data import DataLoader, IterableDataset
7 | from torchrec.datasets.utils import PATH_MANAGER_KEY, Batch
8 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
9 | 
10 | # customizable factors:
11 | NUM_ROWS = int(2**25)    # number of samples
12 | E = [int(3e7), int(1e7), int(2e7), int(1e7), int(1e7), int(3e6), int(8e6), int(1e7),
13 |      int(1e6), int(1e6), int(1e6), int(1e6), int(5e6), int(4000), int(250), int(250)]    # unique embeddings per table
14 | s = 0.25    # long-tail skew
15 | POOLING_FACTOR = 2    # number of indices per table per sample
16 | 
17 | CAT_FEATURE_COUNT = len(E)
18 | NUM_EMBEDDINGS_PER_FEATURE = ",".join(str(e) for e in E)
19 | DEFAULT_CAT_NAMES = ["cat_{}".format(i) for i in range(CAT_FEATURE_COUNT)]
20 | DEFAULT_LABEL_NAME = "click"
21 | DEFAULT_INT_NAMES = ['rand_dense']
22 | TOTAL_TRAINING_SAMPLES = NUM_ROWS
23 | 
24 | 
25 | def update_settings():
26 |     global CAT_FEATURE_COUNT, DEFAULT_CAT_NAMES, NUM_EMBEDDINGS_PER_FEATURE
27 |     CAT_FEATURE_COUNT = len(E)
28 |     DEFAULT_CAT_NAMES = ["cat_{}".format(i) for i in range(CAT_FEATURE_COUNT)]
29 |     NUM_EMBEDDINGS_PER_FEATURE = ",".join(str(e) for e in E)
30 | 
31 | 
32 | class CustomIterDataPipe(IterableDataset):
33 | 
34 |     def __init__(
35 |         self,
36 |         batch_size: int,
37 |         rank: int,
38 |         world_size: int,
39 |         shuffle_batches: bool = False,
40 |         mmap_mode: bool = False,
41 |         hashes: Optional[List[int]] = None,
42 |         path_manager_key: str = PATH_MANAGER_KEY,
43 |     ) -> None:
44 |         self.batch_size = batch_size
45 |         self.rank = rank
46 |         self.world_size = world_size
47 |         self.shuffle_batches = shuffle_batches
48 |         self.mmap_mode = mmap_mode
49 |         self.hashes = hashes
50 |         self.path_manager_key = path_manager_key
51 |         self.keys: List[str] = DEFAULT_CAT_NAMES
52 |         self._num_ids_group_in_batch: int = CAT_FEATURE_COUNT * batch_size
53 |         self.lengths: torch.Tensor = torch.ones((self._num_ids_group_in_batch,), dtype=torch.int32) * POOLING_FACTOR
54 |         self.offsets: torch.Tensor = torch.arange(0, self._num_ids_group_in_batch + 1, dtype=torch.int32) * POOLING_FACTOR
55 |         self.num_batches = NUM_ROWS // self.batch_size // self.world_size
56 |         self.length_per_key: List[int] = CAT_FEATURE_COUNT * [batch_size * POOLING_FACTOR]
57 |         self.offset_per_key: List[int] = [batch_size * i * POOLING_FACTOR for i in range(CAT_FEATURE_COUNT + 1)]
58 |         self.index_per_key: Dict[str, int] = {key: i for (i, key) in enumerate(self.keys)}
59 |         self.stride = batch_size
60 |         # for random generation
61 |         self.min_sample_list = [(1 / e) ** s for e in E]
62 |         self.max_sample_list = [1.0 for e in E]
63 |         # for iter
64 |         self.iter_count = 0
65 | 
66 |     def __len__(self) -> int:
67 |         return self.num_batches
68 | 
69 |     def __iter__(self) -> Iterator[Batch]:
70 |         while self.iter_count < self.num_batches:
71 |             indices = []
72 |             for min_sample, max_sample in zip(self.min_sample_list, self.max_sample_list):
73 |                 # Long-tail index generation by inverse-transform sampling of a power law:
74 |                 # draw u ~ Uniform[(1/e)^s, 1] and map it to floor(u^(-1/s)) - 1, which lies
75 |                 # in [0, e - 1] and hits small indices far more often (skew controlled by s).
76 |                 rand_float = torch.rand(self.batch_size * POOLING_FACTOR, dtype=torch.float64)
77 |                 sample_float = rand_float * (max_sample - min_sample) + min_sample
78 |                 indices.append(torch.floor(1 / (sample_float ** (1 / s))).long() - 1)
79 |             self.iter_count += 1
80 |             yield self._make_batch(torch.cat(indices))
81 | 
82 |     def _make_batch(self, indices):
83 |         ret = Batch(
84 |             dense_features=torch.rand(self.stride, 1),
85 |             sparse_features=KeyedJaggedTensor(
86 |                 keys=self.keys,
87 |                 values=indices,
88 |                 lengths=self.lengths,
89 |                 offsets=self.offsets,
90 |                 stride=self.stride,
91 |                 length_per_key=self.length_per_key,
92 |                 offset_per_key=self.offset_per_key,
93 |                 index_per_key=self.index_per_key,
94 |             ),
95 |             labels=torch.randint(2, (self.stride,))
96 |         )
97 |         return ret
98 | 
99 | 
100 | def get_custom_data_loader(
101 |         args: argparse.Namespace,
102 |         stage: str) -> DataLoader:
103 |     rank = dist.get_rank()
104 |     world_size = dist.get_world_size()
105 |     if stage == "train":
106 |         dataloader = DataLoader(
107 |             CustomIterDataPipe(
108 |                 batch_size=args.batch_size,
109 |                 rank=rank,
110 |                 world_size=world_size,
111 |                 shuffle_batches=args.shuffle_batches,
112 |                 hashes=None
113 |             ),
114 |             batch_size=None,
115 |             pin_memory=args.pin_memory,
116 |             collate_fn=lambda x: x,
117 |         )
118 |     else:
119 |         dataloader = []
120 |     return dataloader
121 | 
--------------------------------------------------------------------------------
/baselines/data/dlrm_dataloader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
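# This module builds the stage-specific dataloaders used by dlrm_main.py: random,
# Criteo-Kaggle in-memory, Criteo-1TB petastorm, Avazu, synthetic (fbgemm), and
# custom long-tail datasets; see get_dataloader() at the bottom of the file.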
7 | 
8 | import argparse
9 | import os
10 | from typing import List, Optional, Tuple, Dict
11 | import glob
12 | import numpy as np
13 | 
14 | import torch
15 | from torch import distributed as dist
16 | from torch.utils.data import DataLoader, IterableDataset
17 | from torchrec.datasets.criteo import (
18 |     CAT_FEATURE_COUNT,
19 |     DEFAULT_CAT_NAMES,
20 |     DEFAULT_INT_NAMES,
21 |     DEFAULT_LABEL_NAME,
22 |     DAYS,
23 |     InMemoryBinaryCriteoIterDataPipe,
24 | )
25 | from torchrec.datasets.random import RandomRecDataset
26 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
27 | from torchrec.datasets.utils import Batch
28 | from petastorm import make_batch_reader
29 | from pyarrow.parquet import ParquetDataset
30 | from .avazu import AvazuIterDataPipe
31 | from .synth import get_synth_data_loader
32 | from .custom import get_custom_data_loader
33 | STAGES = ["train", "val", "test"]
34 | KAGGLE_NUM_EMBEDDINGS_PER_FEATURE = '1460,583,10131227,2202608,305,24,12517,633,3,93145,5683,8351593,3194,27,14992,' \
35 |                                     '5461306,10,5652,2173,4,7046547,18,15,286181,105,142572'    # For criteo kaggle
36 | KAGGLE_TOTAL_TRAINING_SAMPLES = 39_291_954    # days 0-6 for criteo kaggle; 45,840,617 samples in total
37 | TERABYTE_NUM_EMBEDDINGS_PER_FEATURE = "45833188,36746,17245,7413,20243,3,7114,1441,62,29275261,1572176,345138,10," \
38 |                                       "2209,11267,128,4,974,14,48937457,11316796,40094537,452104,12606,104,35"
39 | 
40 | 
41 | def _get_random_dataloader(args: argparse.Namespace,) -> DataLoader:
42 |     return DataLoader(
43 |         RandomRecDataset(
44 |             keys=DEFAULT_CAT_NAMES,
45 |             batch_size=args.batch_size,
46 |             hash_size=args.num_embeddings,
47 |             hash_sizes=args.num_embeddings_per_feature if hasattr(args, "num_embeddings_per_feature") else None,
48 |             manual_seed=args.seed if hasattr(args, "seed") else None,
49 |             ids_per_feature=1,
50 |             num_dense=len(DEFAULT_INT_NAMES),
51 |         ),
52 |         batch_size=None,
53 |         batch_sampler=None,
54 |         pin_memory=args.pin_memory,
55 |         num_workers=0,
56 |     )
57 | 
58 | 
59 | def _get_in_memory_dataloader(
60 |     args: argparse.Namespace,
61 |     stage: str,
62 | ) -> DataLoader:
63 |     files = os.listdir(args.in_memory_binary_criteo_path)
64 | 
65 |     def is_final_day(s: str) -> bool:
66 |         return f"day_{(7 if args.kaggle else DAYS) - 1}" in s
67 | 
68 |     if stage == "train":
69 |         # Train set gets all data except the final day.
70 |         files = list(filter(lambda s: not is_final_day(s), files))
71 |         rank = dist.get_rank()
72 |         world_size = dist.get_world_size()
73 |     else:
74 |         # Validation set gets the first half of the final day's samples. Test set gets
75 |         # the other half.
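        # Halving trick: with W training ranks, validation reads as ranks [0, W) and
        # test as ranks [W, 2W) over a doubled world_size of 2W, so the two stages
        # split the final day's samples 50/50 without any extra files.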
76 |         files = list(filter(is_final_day, files))
77 |         rank = (dist.get_rank() if stage == "val" else dist.get_rank() + dist.get_world_size())
78 |         world_size = dist.get_world_size() * 2
79 | 
80 |     stage_files: List[List[str]] = [
81 |         sorted(map(
82 |             lambda x: os.path.join(args.in_memory_binary_criteo_path, x),
83 |             filter(lambda s: kind in s, files),
84 |         )) for kind in ["dense", "sparse", "labels"]
85 |     ]
86 |     dataloader = DataLoader(
87 |         InMemoryBinaryCriteoIterDataPipe(
88 |             *stage_files,    # pyre-ignore[6]
89 |             batch_size=args.batch_size,
90 |             rank=rank,
91 |             world_size=world_size,
92 |             shuffle_batches=args.shuffle_batches,
93 |             hashes=args.num_embeddings_per_feature if args.num_embeddings is None else
94 |             ([args.num_embeddings] * CAT_FEATURE_COUNT),
95 |         ),
96 |         batch_size=None,
97 |         pin_memory=args.pin_memory,
98 |         collate_fn=lambda x: x,
99 |     )
100 |     return dataloader
101 | 
102 | 
103 | def get_avazu_data_loader(args, stage):
104 |     files = os.listdir(args.in_memory_binary_criteo_path)
105 | 
106 |     if stage == "train":
107 |         files = list(filter(lambda s: "train" in s, files))
108 |         rank = dist.get_rank()
109 |         world_size = dist.get_world_size()
110 |     else:
111 |         # Validation set gets the first half of the held-out samples. Test set gets
112 |         # the other half.
113 |         files = list(filter(lambda s: "train" not in s, files))
114 |         rank = (dist.get_rank() if stage == "val" else dist.get_rank() + dist.get_world_size())
115 |         world_size = dist.get_world_size() * 2
116 | 
117 |     stage_files: List[List[str]] = [
118 |         sorted(map(
119 |             lambda x: os.path.join(args.in_memory_binary_criteo_path, x),
120 |             filter(lambda s: kind in s, files),
121 |         )) for kind in ["dense", "sparse", "label"]
122 |     ]
123 | 
124 |     dataloader = DataLoader(
125 |         AvazuIterDataPipe(
126 |             *stage_files,    # pyre-ignore[6]
127 |             batch_size=args.batch_size,
128 |             rank=rank,
129 |             world_size=world_size,
130 |             shuffle_batches=args.shuffle_batches,
131 |             hashes=args.num_embeddings_per_feature if args.num_embeddings is None else
132 |             ([args.num_embeddings] * CAT_FEATURE_COUNT),
133 |         ),
134 |         batch_size=None,
135 |         pin_memory=args.pin_memory,
136 |         collate_fn=lambda x: x,
137 |     )
138 |     return dataloader
139 | 
140 | 
141 | class PetastormDataReader(IterableDataset):
142 |     """
143 |     This is a compromise solution for the Criteo terabyte dataset;
144 |     please see solution 3 in: https://github.com/uber/petastorm/issues/508
145 | 
146 |     In the training stage, the dataloader in each rank extracts random samples from the whole dataset,
147 |     so the batches in each rank are not guaranteed to be unique.
148 |     In the validation stage, all the samples are evaluated in each rank,
149 |     so that each rank produces the correct result.
150 |     """
151 | 
152 |     def __init__(self,
153 |                  paths,
154 |                  batch_size,
155 |                  rank=None,
156 |                  world_size=None,
157 |                  shuffle_batches=False,
158 |                  hashes=None,
159 |                  seed=1024,
160 |                  drop_last=True):
161 |         self.dataset = ParquetDataset(paths, use_legacy_dataset=False)
162 |         self.batch_size = batch_size
163 |         self.rank = rank
164 |         self.world_size = world_size
165 |         self.shuffle_batches = shuffle_batches
166 |         self.hashes = np.array(hashes).reshape((1, CAT_FEATURE_COUNT)) if hashes is not None else None
167 | 
168 |         self._num_ids_in_batch: int = CAT_FEATURE_COUNT * batch_size
169 |         self.keys: List[str] = DEFAULT_CAT_NAMES
170 |         self.lengths: torch.Tensor = torch.ones((self._num_ids_in_batch,), dtype=torch.int32)
171 |         self.offsets: torch.Tensor = torch.arange(0, self._num_ids_in_batch + 1, dtype=torch.int32)
172 |         self.stride = batch_size
173 |         self.length_per_key: List[int] = CAT_FEATURE_COUNT * [batch_size]
174 |         self.offset_per_key: List[int] = [batch_size * i for i in range(CAT_FEATURE_COUNT + 1)]
175 |         self.index_per_key: Dict[str, int] = {key: i for (i, key) in enumerate(self.keys)}
176 |         self.seed = seed
177 | 
178 |         self.drop_last = drop_last
179 |         if drop_last:
180 |             self.num_batches = sum([fragment.metadata.num_rows for fragment in self.dataset.fragments
181 |                                    ]) // self.batch_size
182 |         else:
183 |             self.num_batches = (sum([fragment.metadata.num_rows
184 |                                      for fragment in self.dataset.fragments]) + self.batch_size - 1) // self.batch_size
185 |         if self.world_size is not None:
186 |             self.num_batches = self.num_batches // world_size
187 | 
188 |     def __iter__(self):
189 |         buffer: Optional[List[np.ndarray]] = None
190 |         count = 0
191 | 
192 |         def append_to_buffer(_dense: np.ndarray, _sparse: np.ndarray, _labels: np.ndarray) -> None:
193 |             nonlocal buffer
194 |             if buffer is None:
195 |                 buffer = [_dense, _sparse, _labels]
196 |             else:
197 |                 buffer[0] = np.concatenate([buffer[0], _dense], axis=0)
198 |                 buffer[1] = np.concatenate([buffer[1], _sparse], axis=1)
199 |                 buffer[2] = np.concatenate([buffer[2], _labels], axis=0)
200 | 
201 |         with make_batch_reader(
202 |                 list(map(lambda x: "file://" + x, self.dataset.files)),
203 |                 num_epochs=1,
204 |                 workers_count=1,    # for reproducibility
205 |         ) as reader:
206 |             # note that `batch` here is just a chunk of samples read by petastorm, not a `batch` consumed by models
207 |             for batch in reader:
208 |                 labels = getattr(batch, DEFAULT_LABEL_NAME)
209 |                 sparse = np.concatenate([getattr(batch, col_name).reshape(1, -1) for col_name in DEFAULT_CAT_NAMES],
210 |                                         axis=0)
211 |                 dense = np.concatenate([getattr(batch, col_name).reshape(-1, 1) for col_name in DEFAULT_INT_NAMES],
212 |                                        axis=1)
213 |                 start_idx = 0
214 |                 while start_idx < dense.shape[0]:
215 |                     buffer_size = 0 if buffer is None else buffer[0].shape[0]
216 |                     if buffer_size == self.batch_size:
217 |                         yield self._batch_ndarray(*buffer)
218 |                         buffer = None
219 |                         count += 1
220 |                         if count == self.num_batches:
221 |                             return    # PEP 479: raising StopIteration inside a generator is a RuntimeError
222 |                     else:
223 |                         rows_to_get = min(self.batch_size - buffer_size, dense.shape[0] - start_idx)
224 |                         label_chunk = labels[start_idx:start_idx + rows_to_get]
225 |                         sparse_chunk = sparse[:, start_idx:start_idx + rows_to_get]
226 |                         dense_chunk = dense[start_idx:start_idx + rows_to_get, :]
227 |                         append_to_buffer(dense_chunk, sparse_chunk, label_chunk)
228 |                         start_idx += rows_to_get
229 |         if buffer is not None and not self.drop_last:
230 |             yield
self._batch_ndarray(*buffer)
231 | 
232 |     def _batch_ndarray(self, dense: np.ndarray, sparse: np.ndarray, labels: np.ndarray):
233 |         if self.shuffle_batches:
234 |             # Shuffle all 3 in unison
235 |             shuffler = np.random.permutation(len(dense))
236 |             dense = dense[shuffler]
237 |             sparse = sparse[:, shuffler]
238 |             labels = labels[shuffler]
239 | 
240 |         return Batch(
241 |             dense_features=torch.from_numpy(dense),
242 |             sparse_features=KeyedJaggedTensor(
243 |                 keys=self.keys,
244 |                 values=torch.from_numpy(sparse.reshape(-1)),
245 |                 lengths=self.lengths,
246 |                 offsets=self.offsets,
247 |                 stride=self.stride,
248 |                 length_per_key=self.length_per_key,
249 |                 offset_per_key=self.offset_per_key,
250 |                 index_per_key=self.index_per_key,
251 |             ),
252 |             labels=torch.from_numpy(labels.reshape(-1)),
253 |         )
254 | 
255 |     def __len__(self):
256 |         return self.num_batches
257 | 
258 | 
259 | def _get_petastorm_dataloader(args, stage):
260 |     if stage == "train":
261 |         data_split = "train"
262 |     elif stage == "val":
263 |         data_split = "validation"
264 |     else:
265 |         data_split = "test"
266 | 
267 |     file_num = len(glob.glob(os.path.join(args.in_memory_binary_criteo_path, data_split, "*.parquet")))
268 |     files = [os.path.join(args.in_memory_binary_criteo_path, data_split, f"part_{i}.parquet") for i in range(file_num)]
269 | 
270 |     dataloader = DataLoader(PetastormDataReader(files,
271 |                                                 args.batch_size,
272 |                                                 rank=dist.get_rank() if stage == "train" else None,
273 |                                                 world_size=dist.get_world_size() if stage == "train" else None,
274 |                                                 hashes=args.num_embeddings_per_feature),
275 |                             batch_size=None,
276 |                             pin_memory=False,
277 |                             collate_fn=lambda x: x,
278 |                             num_workers=0)
279 | 
280 |     return dataloader
281 | 
282 | 
283 | def get_dataloader(args: argparse.Namespace, backend: str, stage: str) -> DataLoader:
284 |     """
285 |     Gets the desired dataloader from dlrm_main command line options. Depending on those
286 |     options, this function returns a DataLoader wrapped around a RandomRecDataset, an
287 |     InMemoryBinaryCriteoIterDataPipe, an AvazuIterDataPipe, a PetastormDataReader, or a synthetic/custom data pipe.
288 | 
289 |     Args:
290 |         args (argparse.Namespace): Command line options supplied to dlrm_main.py's main
291 |             function.
292 |         backend (str): "nccl" or "gloo".
293 |         stage (str): "train", "val", or "test".
294 | 
295 |     Returns:
296 |         dataloader (DataLoader): PyTorch dataloader for the specified options.
297 | 
298 |     """
299 |     stage = stage.lower()
300 |     if stage not in STAGES:
301 |         raise ValueError(f"Supplied stage was {stage}.
Must be one of {STAGES}.") 302 | 303 | args.pin_memory = ((backend == "nccl") if not hasattr(args, "pin_memory") else args.pin_memory) 304 | 305 | if (not hasattr(args, "in_memory_binary_criteo_path") or args.in_memory_binary_criteo_path is None): 306 | return _get_random_dataloader(args) 307 | elif "criteo" in args.in_memory_binary_criteo_path: 308 | if args.kaggle: 309 | return _get_in_memory_dataloader(args, stage) 310 | else: 311 | return _get_petastorm_dataloader(args, stage) 312 | elif "avazu" in args.in_memory_binary_criteo_path: 313 | return get_avazu_data_loader(args, stage) 314 | elif "embedding_bag" in args.in_memory_binary_criteo_path: 315 | # dlrm dataset: https://github.com/facebookresearch/dlrm_datasets 316 | return get_synth_data_loader(args, stage) 317 | elif "custom" in args.in_memory_binary_criteo_path: 318 | return get_custom_data_loader(args, stage) 319 | -------------------------------------------------------------------------------- /baselines/data/synth.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torch import distributed as dist 4 | import numpy as np 5 | from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union 6 | from colossalai.nn.parallel.layers.cache_embedding import CachedEmbeddingBag 7 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor 8 | from fbgemm_gpu.split_table_batched_embeddings_ops import SplitTableBatchedEmbeddingBagsCodegen, EmbeddingLocation, ComputeDevice, CacheAlgorithm 9 | import time 10 | import numpy as np 11 | import torch 12 | from torch.utils.data import IterableDataset 13 | from torchrec.datasets.utils import PATH_MANAGER_KEY, Batch 14 | from torchrec.datasets.criteo import BinaryCriteoUtils 15 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor 16 | from torch.utils.data import DataLoader, IterableDataset 17 | from torch.autograd.profiler import record_function 18 | import os 19 | 20 | 21 | 22 | # 52667139 as default 23 | CHOSEN_TABLES = [0, 2, 3, 4, 5, 7, 8, 9, 10, 12, 15, 18, 22, 27, 28] 24 | NUM_EMBEDDINGS_PER_FEATURE = '8015999, 9997799, 6138289, 21886, 204008, 6148, 282795, \ 25 | 1316, 3639992, 319, 3394206, 12203324, 4091851, 11641, 4657566' 26 | 27 | CAT_FEATURE_COUNT = len(CHOSEN_TABLES) 28 | INT_FEATURE_COUNT = 1 29 | DEFAULT_LABEL_NAME = "click" 30 | DEFAULT_INT_NAMES = ['rand_dense'] 31 | BATCH_SIZE = 65536 # batch_size of one file 32 | DEFAULT_CAT_NAMES = ["cat_{}".format(i) for i in range(len(CHOSEN_TABLES))] 33 | 34 | def choose_data_size(size: str): 35 | global CHOSEN_TABLES 36 | global NUM_EMBEDDINGS_PER_FEATURE 37 | global CAT_FEATURE_COUNT 38 | global DEFAULT_CAT_NAMES 39 | if size == '52M': 40 | pass 41 | elif size == '4M': 42 | # 4210897 43 | CHOSEN_TABLES = [5, 8, 37, 54, 71, 72, 73, 74, 85, 86, 89, 95, 96, 97, 107, 131, 163, 185, 196, 204, 211] 44 | NUM_EMBEDDINGS_PER_FEATURE = '204008, 282795, 539726, 153492, 11644, 11645, 13858, 5632, 60121, \ 45 | 11711, 11645, 43335, 4843, 67919, 6539, 17076, 11579, 866124, 711855, 302001, 873349' 46 | elif size == '512M': 47 | # 512196316 48 | CHOSEN_TABLES = [301, 302, 303, 305, 306, 307, 309, 310, 311, 312, 313, 316, 317, 318, 319, 320, 321, 322, 323, 49 | 325, 326, 327, 328, 330, 335, 336, 337, 338, 340, 341, 343, 344, 345, 346, 347, 348, 349, 350, 50 | 351, 352, 353, 354, 356, 357, 358, 359, 360, 361, 362, 363, 365, 366, 367, 368, 370, 371, 372, 51 | 375, 378, 379, 381, 382, 383, 384, 385, 386, 388, 389, 390, 391, 392, 393, 394, 395, 
396, 398, 52 | 399, 400, 401, 403, 405, 406, 407, 410, 413, 414, 415, 416, 417] 53 | NUM_EMBEDDINGS_PER_FEATURE = '5999929, 5999885, 5999976, 5999981, 5999901, 5999929, 5999885, 5987787, 6000000, 5999929, \ 54 | 5998095, 3000000, 5999981, 2999993, 5999981, 5092210, 4999972, 5999976, 5998595, 5999548, 1999882, 4998224, 5999929, \ 55 | 5014074, 5999986, 5999978, 5999941, 5999816, 5997022, 5999975, 5999685, 5999981, 5999738, 5999380, 5966699, 5975615, \ 56 | 5908896, 5999996, 5999996, 5999983, 5734426, 5997022, 5999975, 5999929, 5999996, 5999239, 5989271, 5999477, 5999981, \ 57 | 5999887, 5999929, 5999506, 5999996, 5999548, 5998472, 5922238, 5999975, 5987787, 2999964, 5999983, 5999930, 5979767, \ 58 | 5999139, 5775261, 5999681, 4999929, 5963607, 5999967, 2999835, 5997068, 5998595, 5999996, 5992524, 5999997, 5999932, \ 59 | 5999878, 5999929, 5999857, 5999981, 5999981, 5999796, 5999995, 5994671, 5999329, 5997068, 5999981, 5973566, 5999407, 5966699' 60 | elif size == '2G': 61 | # 2328942965 62 | CHOSEN_TABLES = [i for i in range(856)] 63 | NUM_EMBEDDINGS_PER_FEATURE = '8015999, 1, 9997799, 6138289, 21886, 204008, 220, 6148, 282795, 1316, 3639992, 1, 319, 1, 1, 3394206, 1, 1, 12203324, 1, 1, 1, 4091851, 1, 1, 1, 1, 11641, 4657566, 11645, 1815316, 6618925, 1, 1, 1146, 1, 1, 539726, 1, 1, 1, 4972387, 2169014, 1, 1, 3912105, 272, 3102763, 1, 1, 4230916, 5878, 1, 11645, 153492, 6618919, 1, 4868981, 1, 11709, 3985291, 1, 5956409, 1, 1, 1, 1, 1875019, 1, 381, 1, 11644, 11645, 13858, 5632, 1, 1, 6600102, 6618930, 1, 5412960, 371, 5272932, 2555079, 1, 60121, 11711, 2553205, 2647912, 11645, 1, 5798421, 350642, 1, 1, 43335, 4843, 67919, 1, 1, 3239310, 1, 1, 6855076, 1, 1, 1, 6539, 1, 1, 111, 5990989, 1, 6516585, 1, 68, 1, 5758479, 2448698, 1, 6618898, 2614073, 3309464, 1, 6107319, 1, 1, 1928793, 4535618, 1, 309, 17076, 4950876, 304795, 4970566, 11209763, 5585, 2207735, 6618921, 1941, 5659, 5690, 1029648, 5662, 4718, 6385214, 5641, 1150, 5653, 6618924, 1, 339750, 1, 6112009, 589094, 2844205, 1, 6618929, 1, 1, 5667, 5167062, 2542266, 11579, 6147171, 951851, 6448758, 5253, 826, 1, 1997119, 6363150, 6614703, 2199, 6461842, 913043, 1, 1, 1, 1, 1, 1283500, 1, 6316718, 11579, 866124, 3660331, 1, 4032709, 1, 1, 3232, 2065, 6584597, 1, 1, 711855, 5672538, 1, 248, 1, 1, 1, 1, 302001, 4006173, 1, 1, 19623, 1, 4673098, 873349, 8026000, 2323, 1680975, 1, 1, 5710807, 2999962, 5999910, 5925217, 4997507, 5999548, 2999938, 4999774, 5999707, 5999710, 5764956, 5999992, 1, 2999941, 5982534, 1, 5999927, 4978274, 5999983, 5999997, 5999912, 5908896, 5999955, 5999935, 5999836, 5999983, 1, 5999477, 5999805, 5998095, 1, 5989511, 1999998, 4999998, 6000000, 5999929, 1, 5999993, 1, 1, 5999885, 5999867, 5999929, 1, 1, 5999962, 1, 2999898, 5998777, 5999934, 1, 1, 5992524, 5999737, 2999538, 5999870, 5992524, 5999975, 1, 5710807, 1, 1, 4932124, 5918154, 1, 5997068, 5999982, 1, 5998551, 5999994, 5999870, 4999919, 5999944, 5999904, 4999740, 5922605, 5975615, 5999816, 998848, 5999926, 5999816, 4999991, 5999861, 1, 5999929, 5999885, 5999976, 1, 5999981, 5999901, 5999929, 1, 5999885, 5987787, 6000000, 5999929, 5998095, 1, 1, 3000000, 5999981, 2999993, 5999981, 5092210, 4999972, 5999976, 5998595, 1, 5999548, 1999882, 4998224, 5999929, 1, 5014074, 1, 1, 1, 1, 5999986, 5999978, 5999941, 5999816, 1, 5997022, 5999975, 1, 5999685, 5999981, 5999738, 5999380, 5966699, 5975615, 5908896, 5999996, 5999996, 5999983, 5734426, 5997022, 1, 5999975, 5999929, 5999996, 5999239, 5989271, 5999477, 5999981, 5999887, 1, 5999929, 5999506, 5999996, 5999548, 1, 
5998472, 5922238, 5999975, 1, 1, 5987787, 1, 1, 2999964, 5999983, 1, 5999930, 5979767, 5999139, 5775261, 5999681, 4999929, 1, 5963607, 5999967, 2999835, 5997068, 5998595, 5999996, 5992524, 5999997, 5999932, 1, 5999878, 5999929, 5999857, 5999981, 1, 5999981, 1, 5999796, 5999995, 5994671, 1, 1, 5999329, 1, 1, 5997068, 5999981, 5973566, 5999407, 5966699, 5966699, 5734426, 5999975, 5999976, 1, 1, 1, 2982258, 5999816, 5999929, 5999981, 1, 5999974, 5998583, 5966699, 5999870, 5999828, 5997903, 5999854, 5999685, 1, 1, 1, 5999954, 5999981, 1, 1, 5999967, 5999681, 5999932, 1, 5999857, 5971899, 5999972, 5999932, 1, 5999979, 1, 5998507, 5999927, 5999981, 5998595, 5999975, 1, 5949097, 5999239, 5999821, 1, 1, 5922605, 1, 5998473, 5999878, 5992217, 5999600, 1, 1, 5965155, 5999932, 5999971, 5922605, 5999854, 1, 5999963, 1, 5998777, 5999816, 1, 5999594, 1, 5733183, 5999396, 1, 1, 1, 5999440, 1, 5871486, 5999975, 1, 5999911, 5963607, 5997079, 1, 5999772, 1, 5999857, 5999863, 5999477, 5966699, 5999975, 5999996, 5999976, 1, 5999981, 6, 135, 422, 509, 922, 13, 16, 2, 1000, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1245, 5, 3136, 2704, 4999983, 4999983, 4999937, 4000, 1, 1, 1, 1, 1, 1, 1, 5, 6, 5, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 8, 2, 2, 1, 1, 2, 2, 8, 999908, 1443441, 21812, 21656, 14, 14, 1, 61, 2043, 2, 137, 13, 5, 10, 10, 10, 10, 9, 10, 12, 11, 10, 11, 11, 16, 10, 10, 11, 16, 1, 1, 1, 12543437, 1, 4999993, 4999986, 3999995, 4999863, 12543669, 1, 4999998, 12543650, 4999985, 12540349, 1, 1, 12532399, 4999998, 12540535, 397606, 1415772, 12270355, 9765324, 1, 1, 10287986, 1, 10794735, 10498728, 10965644, 4667, 43200, 43200, 43200, 43200, 629, 226, 215, 24, 4999983, 1028203, 20, 10, 1, 1, 1, 1, 12542, 2, 85, 2, 2, 2, 8, 2, 4, 2, 2, 2, 2, 2, 4, 6, 2, 2, 2, 8, 4, 9810, 2, 5884133, 1, 1, 2, 4, 2, 2, 1, 2, 4, 2, 18, 18, 2, 4, 6137684, 3309463, 1, 5607839, 1, 1, 1, 6567668, 1, 1, 4093617, 1, 4390473, 6305161, 1, 1, 1, 3779483, 1, 5303395, 1, 1, 6618931, 1, 1, 3640447, 3102628, 2542714, 269, 1, 1, 612, 6107445, 3978163, 6607315, 4868894, 1, 3983088, 5419139, 1, 271, 3911645, 2553341, 983482, 272, 1, 1600937, 1, 1266741, 1520037, 3704018, 1, 1, 3345638, 6618817, 3117219, 1, 1, 1877662, 1876652, 3309463, 4378405, 847629, 1661, 621, 624, 3667331, 1, 269, 2614200, 1, 1, 1, 6618759, 204, 1, 6618922, 1, 5998824, 5999974, 4977045, 5999994, 5999995, 5999993, 5999981, 5999957, 5999908, 5984869, 5999994, 1, 5999548, 1, 1, 5999831, 5999978, 5999396, 5999908, 5999953, 1, 1, 1, 5999750, 5999958, 5999477, 1, 5999981, 5999548, 5999953, 1, 5999548, 1, 1, 5999986, 5999975, 1, 5999908, 1, 1, 5999975, 1, 5999548, 5998836, 5999477, 5999737, 5999708, 5999737, 5999783, 1, 1, 5999901, 5999708, 5999711, 5999967, 5999548, 5999548, 1, 1, 5999783, 1, 5999694, 5999520, 5999975, 1, 1, 1, 5999477, 1, 5999975, 5991327, 5975615, 1, 4494193, 5999918, 1, 5999725, 1, 5999995, 1, 1, 5999477, 4998981, 5999975, 5999329, 5976924, 1, 2, 4999540, 5999138, 3994232, 1' 64 | else: 65 | raise NotImplementedError() 66 | CAT_FEATURE_COUNT = len(CHOSEN_TABLES) 67 | DEFAULT_CAT_NAMES = ["cat_{}".format(i) for i in range(len(CHOSEN_TABLES))] 68 | 69 | class SynthIterDataPipe(IterableDataset): 70 | def __init__( 71 | self, 72 | sparse_paths: List[str], 73 | batch_size: int, 74 | rank: int, 75 | world_size: int, 76 | shuffle_batches: bool = False, 77 | mmap_mode: bool = False, 78 | hashes: Optional[List[int]] = None, 79 | path_manager_key: str = PATH_MANAGER_KEY, 80 | ) -> None: 81 | self.sparse_paths = sparse_paths, 82 | self.batch_size = batch_size 83 | self.rank = rank 
84 | self.world_size = world_size 85 | self.shuffle_batches = shuffle_batches 86 | self.mmap_mode = mmap_mode 87 | self.hashes = hashes 88 | self.path_manager_key = path_manager_key 89 | 90 | self.indices_per_table_per_file = [] 91 | self.offsets_per_table_per_file = [] 92 | self.lengths_per_table_per_file = [] 93 | self.num_rows_per_file = [] 94 | 95 | self._buffer = None 96 | for file in self.sparse_paths[0]: 97 | print("load file: ", file) 98 | indices_per_table, offsets_per_table, lengths_per_table = self._load_single_file(file) 99 | self.indices_per_table_per_file.append(indices_per_table) 100 | self.offsets_per_table_per_file.append(offsets_per_table) 101 | self.lengths_per_table_per_file.append(lengths_per_table) 102 | self.num_rows_per_file.append(offsets_per_table[0].shape[0]) 103 | self.num_batches = sum(self.num_rows_per_file) // self.batch_size 104 | self.keys: List[str] = DEFAULT_CAT_NAMES 105 | self.stride = batch_size 106 | 107 | def __iter__(self) -> Iterator[Batch]: 108 | # self._buffer structure: 109 | ''' 110 | self._buffer[0]: List of sparse_indices per table 111 | self._buffer[1]: List of sparse_lengths per table 112 | ''' 113 | def append_to_buffer(sparse_indices: List[torch.Tensor], sparse_lengths: List[torch.Tensor]): 114 | if self._buffer is None: 115 | self._buffer = [sparse_indices, sparse_lengths] 116 | else: 117 | for tb_idx, (sparse_indices_table, sparse_lengths_table) in enumerate(zip(sparse_indices, sparse_lengths)): 118 | self._buffer[0][tb_idx] = torch.cat((self._buffer[0][tb_idx], sparse_indices_table)) 119 | self._buffer[1][tb_idx] = torch.cat((self._buffer[1][tb_idx], sparse_lengths_table)) 120 | 121 | file_idx = 0 122 | row_idx = 0 123 | batch_idx = 0 124 | while batch_idx < self.num_batches: 125 | buffer_row_count = 0 if self._buffer is None else self._buffer[1][0].shape[0] 126 | if buffer_row_count == self.batch_size: 127 | yield self._make_batch(*self._buffer) 128 | batch_idx += 1 129 | self._buffer = None 130 | else: 131 | rows_to_get = min( 132 | self.batch_size - buffer_row_count, 133 | BATCH_SIZE - row_idx 134 | ) 135 | # slice_ = slice(row_idx, row_idx + rows_to_get) 136 | with record_function("## load_batch ##"): 137 | sparse_indices, sparse_lengths = self._load_slice_batch(self.indices_per_table_per_file[file_idx], 138 | self.offsets_per_table_per_file[file_idx], 139 | self.lengths_per_table_per_file[file_idx], 140 | row_idx, 141 | rows_to_get 142 | ) 143 | append_to_buffer(sparse_indices, sparse_lengths) 144 | row_idx += rows_to_get 145 | if row_idx >= BATCH_SIZE: 146 | file_idx += 1 147 | row_idx = 0 148 | 149 | def __len__(self) -> int: 150 | return self.num_batches 151 | 152 | def _load_single_file(self, file_path): 153 | # TODO: shard loading 154 | # file_idx_to_row_range = BinaryCriteoUtils.get_file_idx_to_row_range( 155 | # lengths=[BATCH_SIZE], 156 | # rank=self.rank, 157 | # world_size=self.world_size 158 | # ) 159 | # for _,(range_left, range_right) in file_idx_to_row_range.items(): 160 | # rank_range_left = range_left 161 | # rank_range_right = range_right 162 | rank_range_left = 0 163 | rank_range_right = 65535 164 | indices, offsets, lengths = torch.load(file_path) 165 | indices = indices.int() 166 | offsets = offsets.int() 167 | lengths = lengths.int() 168 | indices_per_table = [] 169 | offsets_per_table = [] 170 | lengths_per_table = [] 171 | for i in CHOSEN_TABLES: 172 | start_pos = offsets[i * BATCH_SIZE + rank_range_left] 173 | end_pos = offsets[i * BATCH_SIZE + rank_range_right + 1] 174 | part = indices[start_pos:end_pos] 
175 | indices_per_table.append(part) 176 | lengths_per_table.append(lengths[i][rank_range_left:rank_range_right + 1]) 177 | offsets_per_table.append(torch.cumsum( 178 | torch.cat((torch.tensor([0]), lengths[i][rank_range_left:rank_range_right + 1])), 0 179 | )) 180 | return indices_per_table, offsets_per_table, lengths_per_table 181 | 182 | def _load_random_batch(self, indices_per_table, offsets_per_table, lengths_per_table, choose): 183 | chosen_indices_list = [] 184 | chosen_lengths_list = [] 185 | # choose = torch.randint(0, offsets_per_table[0].shape[0] - 1, (self.batch_size,)) 186 | for indices, offsets, lengths in zip(indices_per_table, offsets_per_table, lengths_per_table): 187 | chosen_lengths_list.append(lengths[choose]) 188 | start_list = offsets[choose] 189 | end_list = offsets[1:][choose] 190 | chosen_indices_atoms = [] 191 | for start, end in zip(start_list, end_list): 192 | chosen_indices_atoms.append(indices[start: end]) 193 | chosen_indices_list.append(torch.cat(chosen_indices_atoms, 0)) 194 | return chosen_indices_list, chosen_lengths_list 195 | 196 | def _load_slice_batch(self, indices_per_table, offsets_per_table, lengths_per_table, row_start, row_length): 197 | chosen_indices_list = [] 198 | chosen_lengths_list = [] 199 | for indices, offsets, lengths in zip(indices_per_table, offsets_per_table, lengths_per_table): 200 | chosen_lengths_list.append(lengths.narrow(0, row_start, row_length)) 201 | start = offsets[row_start] 202 | end = offsets[row_start + row_length] 203 | chosen_indices_list.append(indices.narrow(0, start, end - start)) 204 | return chosen_indices_list, chosen_lengths_list 205 | 206 | def _make_batch(self, chosen_indices_list, chosen_lengths_list): 207 | batch_size = chosen_lengths_list[0].shape[0] 208 | ret = Batch( 209 | dense_features=torch.rand(batch_size,1), 210 | sparse_features=KeyedJaggedTensor( 211 | keys=self.keys, 212 | values=torch.cat(chosen_indices_list), 213 | lengths=torch.cat(chosen_lengths_list), 214 | stride=batch_size, 215 | ), 216 | labels=torch.randint(2, (batch_size,)) 217 | ) 218 | return ret 219 | def get_synth_data_loader( 220 | args: argparse.Namespace, 221 | stage: str) -> DataLoader: 222 | files = os.listdir(args.in_memory_binary_criteo_path) 223 | files = filter(lambda s: "fbgemm_t856_bs65536_" in s, files) 224 | files = [os.path.join(args.in_memory_binary_criteo_path, x) for x in files] 225 | rank = dist.get_rank() 226 | world_size = dist.get_world_size() 227 | if stage == "train": 228 | dataloader = DataLoader( 229 | SynthIterDataPipe( 230 | files, 231 | batch_size=args.batch_size, 232 | rank=rank, 233 | world_size=world_size, 234 | shuffle_batches=args.shuffle_batches, 235 | hashes = None 236 | ), 237 | batch_size=None, 238 | pin_memory=args.pin_memory, 239 | collate_fn=lambda x: x, 240 | ) 241 | else : 242 | dataloader = [] 243 | return dataloader -------------------------------------------------------------------------------- /baselines/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .dlrm import DLRMTrain 2 | 3 | __all__ = ['DLRMTrain'] 4 | -------------------------------------------------------------------------------- /baselines/models/deepfm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from typing import List 9 | 10 | import torch 11 | from torch import nn 12 | from torchrec import EmbeddingBagCollection, KeyedJaggedTensor 13 | from torchrec.modules.deepfm import DeepFM, FactorizationMachine 14 | from torchrec.sparse.jagged_tensor import KeyedTensor 15 | 16 | 17 | class SparseArch(nn.Module): 18 | """ 19 | Processes the sparse features of the DeepFMNN model. Does embedding lookups for all 20 | EmbeddingBag and embedding features of each collection. 21 | Args: 22 | embedding_bag_collection (EmbeddingBagCollection): represents a collection of 23 | pooled embeddings. 24 | Example:: 25 | eb1_config = EmbeddingBagConfig( 26 | name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] 27 | ) 28 | eb2_config = EmbeddingBagConfig( 29 | name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] 30 | ) 31 | ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) 32 | # 0 1 2 <-- batch 33 | # 0 [0,1] None [2] 34 | # 1 [3] [4] [5,6,7] 35 | # ^ 36 | # feature 37 | features = KeyedJaggedTensor.from_offsets_sync( 38 | keys=["f1", "f2"], 39 | values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), 40 | offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), 41 | ) 42 | sparse_arch(features) 43 | """ 44 | 45 | def __init__(self, embedding_bag_collection: EmbeddingBagCollection) -> None: 46 | super().__init__() 47 | self.embedding_bag_collection: EmbeddingBagCollection = embedding_bag_collection 48 | 49 | def forward( 50 | self, 51 | features: KeyedJaggedTensor, 52 | ) -> KeyedTensor: 53 | """ 54 | Args: 55 | features (KeyedJaggedTensor): 56 | Returns: 57 | KeyedJaggedTensor: an output KJT of size F * D X B. 58 | """ 59 | return self.embedding_bag_collection(features) 60 | 61 | 62 | class DenseArch(nn.Module): 63 | """ 64 | Processes the dense features of DeepFMNN model. Output layer is sized to 65 | the embedding_dimension of the EmbeddingBagCollection embeddings. 66 | Args: 67 | in_features (int): dimensionality of the dense input features. 68 | hidden_layer_size (int): sizes of the hidden layers in the DenseArch. 69 | embedding_dim (int): the same size of the embedding_dimension of sparseArch. 70 | device (torch.device): default compute device. 71 | Example:: 72 | B = 20 73 | D = 3 74 | in_features = 10 75 | dense_arch = DenseArch(in_features=10, hidden_layer_size=10, embedding_dim=D) 76 | dense_embedded = dense_arch(torch.rand((B, 10))) 77 | """ 78 | 79 | def __init__( 80 | self, 81 | in_features: int, 82 | hidden_layer_size: int, 83 | embedding_dim: int, 84 | ) -> None: 85 | super().__init__() 86 | self.model: nn.Module = nn.Sequential( 87 | nn.Linear(in_features, hidden_layer_size), 88 | nn.ReLU(), 89 | nn.Linear(hidden_layer_size, embedding_dim), 90 | nn.ReLU(), 91 | ) 92 | 93 | def forward(self, features: torch.Tensor) -> torch.Tensor: 94 | """ 95 | Args: 96 | features (torch.Tensor): size B X `num_features`. 97 | Returns: 98 | torch.Tensor: an output tensor of size B X D. 99 | """ 100 | return self.model(features) 101 | 102 | 103 | class FMInteractionArch(nn.Module): 104 | """ 105 | Processes the output of both `SparseArch` (sparse_features) and `DenseArch` 106 | (dense_features) and apply the general DeepFM interaction according to the 107 | external source of DeepFM paper: https://arxiv.org/pdf/1703.04247.pdf 108 | The output dimension is expected to be a cat of `dense_features`, D. 
109 | Args: 110 | fm_in_features (int): the input dimension of `dense_module` in DeepFM. For 111 | example, if the input embeddings is [randn(3, 2, 3), randn(3, 4, 5)], then 112 | the `fm_in_features` should be: 2 * 3 + 4 * 5. 113 | sparse_feature_names (List[str]): length of F. 114 | deep_fm_dimension (int): output of the deep interaction (DI) in the DeepFM arch. 115 | Example:: 116 | D = 3 117 | B = 10 118 | keys = ["f1", "f2"] 119 | F = len(keys) 120 | fm_inter_arch = FMInteractionArch(sparse_feature_names=keys) 121 | dense_features = torch.rand((B, D)) 122 | sparse_features = KeyedTensor( 123 | keys=keys, 124 | length_per_key=[D, D], 125 | values=torch.rand((B, D * F)), 126 | ) 127 | cat_fm_output = fm_inter_arch(dense_features, sparse_features) 128 | """ 129 | 130 | def __init__( 131 | self, 132 | fm_in_features: int, 133 | sparse_feature_names: List[str], 134 | deep_fm_dimension: int, 135 | ) -> None: 136 | super().__init__() 137 | self.sparse_feature_names: List[str] = sparse_feature_names 138 | self.deep_fm = DeepFM( 139 | dense_module=nn.Sequential( 140 | nn.Linear(fm_in_features, deep_fm_dimension), 141 | nn.ReLU(), 142 | ) 143 | ) 144 | self.fm = FactorizationMachine() 145 | 146 | def forward( 147 | self, dense_features: torch.Tensor, sparse_features: KeyedTensor 148 | ) -> torch.Tensor: 149 | """ 150 | Args: 151 | dense_features (torch.Tensor): tensor of size B X D. 152 | sparse_features (KeyedJaggedTensor): KJT of size F * D X B. 153 | Returns: 154 | torch.Tensor: an output tensor of size B X (D + DI + 1). 155 | """ 156 | if len(self.sparse_feature_names) == 0: 157 | return dense_features 158 | 159 | tensor_list: List[torch.Tensor] = [dense_features] 160 | # dense/sparse interaction 161 | # size B X F 162 | for feature_name in self.sparse_feature_names: 163 | tensor_list.append(sparse_features[feature_name]) 164 | 165 | deep_interaction = self.deep_fm(tensor_list) 166 | fm_interaction = self.fm(tensor_list) 167 | 168 | return torch.cat([dense_features, deep_interaction, fm_interaction], dim=1) 169 | 170 | 171 | class OverArch(nn.Module): 172 | """ 173 | Final Arch - simple MLP. The output is just one target. 174 | Args: 175 | in_features (int): the output dimension of the interaction arch. 176 | Example:: 177 | B = 20 178 | over_arch = OverArch() 179 | logits = over_arch(torch.rand((B, 10))) 180 | """ 181 | 182 | def __init__(self, in_features: int) -> None: 183 | super().__init__() 184 | self.model: nn.Module = nn.Sequential( 185 | nn.Linear(in_features, 1), 186 | nn.Sigmoid(), 187 | ) 188 | 189 | def forward(self, features: torch.Tensor) -> torch.Tensor: 190 | """ 191 | Args: 192 | features (torch.Tensor): 193 | Returns: 194 | torch.Tensor: an output tensor of size B X 1. 195 | """ 196 | return self.model(features) 197 | 198 | 199 | class SimpleDeepFMNN(nn.Module): 200 | """ 201 | Basic recsys module with DeepFM arch. Processes sparse features by 202 | learning pooled embeddings for each feature. Learns the relationship between 203 | dense features and sparse features by projecting dense features into the same 204 | embedding space. 
199 | class SimpleDeepFMNN(nn.Module):
200 |     """
201 |     Basic recsys module with DeepFM arch. Processes sparse features by
202 |     learning pooled embeddings for each feature. Learns the relationship between
203 |     dense features and sparse features by projecting dense features into the same
204 |     embedding space. Learns the interaction among those dense and sparse features
205 |     by deep_fm proposed in this paper: https://arxiv.org/pdf/1703.04247.pdf
206 |     The module assumes all sparse features have the same embedding dimension
207 |     (i.e., each `EmbeddingBagConfig` uses the same embedding_dim).
208 |     The following notation is used throughout the documentation for the models:
209 |     * F: number of sparse features
210 |     * D: embedding_dimension of sparse features
211 |     * B: batch size
212 |     * num_features: number of dense features
213 |     Args:
214 |         num_dense_features (int): the number of input dense features.
215 |         embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags
216 |             used to define `SparseArch`.
217 |         hidden_layer_size (int): the hidden layer size used in dense module.
218 |         deep_fm_dimension (int): the output layer size used in `deep_fm`'s deep
219 |             interaction module.
220 |     Example::
221 |         B = 2
222 |         D = 8
223 |         eb1_config = EmbeddingBagConfig(
224 |             name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"]
225 |         )
226 |         eb2_config = EmbeddingBagConfig(
227 |             name="t2",
228 |             embedding_dim=D,
229 |             num_embeddings=100,
230 |             feature_names=["f2"],
231 |         )
232 |         ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
233 |         sparse_nn = SimpleDeepFMNN(
234 |             num_dense_features=100, embedding_bag_collection=ebc, hidden_layer_size=20, deep_fm_dimension=5
235 |         )
236 |         features = torch.rand((B, 100))
237 |         #     0       1
238 |         # 0   [1,2] [4,5]
239 |         # 1   [4,3] [2,9]
240 |         # ^
241 |         # feature
242 |         sparse_features = KeyedJaggedTensor.from_offsets_sync(
243 |             keys=["f1", "f3"],
244 |             values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]),
245 |             offsets=torch.tensor([0, 2, 4, 6, 8]),
246 |         )
247 |         logits = sparse_nn(
248 |             dense_features=features,
249 |             sparse_features=sparse_features,
250 |         )
251 |     """
252 | 
253 |     def __init__(
254 |         self,
255 |         num_dense_features: int,
256 |         embedding_bag_collection: EmbeddingBagCollection,
257 |         hidden_layer_size: int,
258 |         deep_fm_dimension: int,
259 |     ) -> None:
260 |         super().__init__()
261 |         assert (
262 |             len(embedding_bag_collection.embedding_bag_configs()) > 0
263 |         ), "At least one embedding bag is required"
264 |         for i in range(1, len(embedding_bag_collection.embedding_bag_configs())):
265 |             conf_prev = embedding_bag_collection.embedding_bag_configs()[i - 1]
266 |             conf = embedding_bag_collection.embedding_bag_configs()[i]
267 |             assert (
268 |                 conf_prev.embedding_dim == conf.embedding_dim
269 |             ), "All EmbeddingBagConfigs must have the same dimension"
270 |         embedding_dim: int = embedding_bag_collection.embedding_bag_configs()[
271 |             0
272 |         ].embedding_dim
273 | 
274 |         feature_names = []
275 | 
276 |         fm_in_features = embedding_dim
277 |         for conf in embedding_bag_collection.embedding_bag_configs():
278 |             for feat in conf.feature_names:
279 |                 feature_names.append(feat)
280 |                 fm_in_features += conf.embedding_dim
281 | 
282 |         self.sparse_arch = SparseArch(embedding_bag_collection)
283 |         self.dense_arch = DenseArch(
284 |             in_features=num_dense_features,
285 |             hidden_layer_size=hidden_layer_size,
286 |             embedding_dim=embedding_dim,
287 |         )
288 |         self.inter_arch = FMInteractionArch(
289 |             fm_in_features=fm_in_features,
290 |             sparse_feature_names=feature_names,
291 |             deep_fm_dimension=deep_fm_dimension,
292 |         )
293 |         over_in_features = embedding_dim + deep_fm_dimension + 1
294 |         self.over_arch = OverArch(over_in_features)
295 | 
296 |     def forward(
297 |         self,
298 |         dense_features: torch.Tensor,
299 |         sparse_features: KeyedJaggedTensor,
300 |     ) -> torch.Tensor:
301 | 
""" 302 | Args: 303 | dense_features (torch.Tensor): the dense features. 304 | sparse_features (KeyedJaggedTensor): the sparse features. 305 | Returns: 306 | torch.Tensor: logits with size B X 1. 307 | """ 308 | embedded_dense = self.dense_arch(dense_features) 309 | embedded_sparse = self.sparse_arch(sparse_features) 310 | concatenated_dense = self.inter_arch( 311 | dense_features=embedded_dense, sparse_features=embedded_sparse 312 | ) 313 | logits = self.over_arch(concatenated_dense) 314 | return logits -------------------------------------------------------------------------------- /baselines/models/dlrm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from typing import Dict, List, Optional, Tuple 9 | 10 | import torch 11 | from torch import nn 12 | from torchrec.modules.embedding_modules import EmbeddingBagCollection 13 | from torchrec.modules.mlp import MLP 14 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor 15 | from torchrec.datasets.utils import Batch 16 | 17 | 18 | def choose(n: int, k: int) -> int: 19 | """ 20 | Simple implementation of math.comb for Python 3.7 compatibility. 21 | """ 22 | if 0 <= k <= n: 23 | ntok = 1 24 | ktok = 1 25 | for t in range(1, min(k, n - k) + 1): 26 | ntok *= n 27 | ktok *= t 28 | n -= 1 29 | return ntok // ktok 30 | else: 31 | return 0 32 | 33 | 34 | class SparseArch(nn.Module): 35 | """ 36 | Processes the sparse features of DLRM. Does embedding lookups for all EmbeddingBag 37 | and embedding features of each collection. 38 | 39 | Args: 40 | embedding_bag_collection (EmbeddingBagCollection): represents a collection of 41 | pooled embeddings. 42 | 43 | Example:: 44 | 45 | eb1_config = EmbeddingBagConfig( 46 | name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] 47 | ) 48 | eb2_config = EmbeddingBagConfig( 49 | name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] 50 | ) 51 | 52 | ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) 53 | sparse_arch = SparseArch(embedding_bag_collection) 54 | 55 | # 0 1 2 <-- batch 56 | # 0 [0,1] None [2] 57 | # 1 [3] [4] [5,6,7] 58 | # ^ 59 | # feature 60 | features = KeyedJaggedTensor.from_offsets_sync( 61 | keys=["f1", "f2"], 62 | values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), 63 | offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), 64 | ) 65 | 66 | sparse_embeddings = sparse_arch(features) 67 | """ 68 | 69 | def __init__(self, embedding_bag_collection: EmbeddingBagCollection) -> None: 70 | super().__init__() 71 | self.embedding_bag_collection: EmbeddingBagCollection = embedding_bag_collection 72 | emb_config = self.embedding_bag_collection.embedding_bag_configs() 73 | assert emb_config, "Embedding bag collection cannot be empty!" 74 | self.D: int = emb_config[0].embedding_dim 75 | self._sparse_feature_names: List[str] = [ 76 | name for conf in embedding_bag_collection.embedding_bag_configs() for name in conf.feature_names 77 | ] 78 | 79 | self.F: int = len(self._sparse_feature_names) 80 | 81 | def forward( 82 | self, 83 | features: KeyedJaggedTensor, 84 | ) -> torch.Tensor: 85 | """ 86 | Args: 87 | features (KeyedJaggedTensor): an input tensor of sparse features. 88 | 89 | Returns: 90 | torch.Tensor: tensor of shape B X F X D. 
91 | """ 92 | sparse_features: KeyedTensor = self.embedding_bag_collection(features) 93 | B: int = features.stride() 94 | 95 | sparse: Dict[str, torch.Tensor] = sparse_features.to_dict() 96 | sparse_values: List[torch.Tensor] = [] 97 | for name in self.sparse_feature_names: 98 | sparse_values.append(sparse[name]) 99 | return torch.cat(sparse_values, dim=1).view(B, self.F, self.D) 100 | 101 | @property 102 | def sparse_feature_names(self) -> List[str]: 103 | return self._sparse_feature_names 104 | 105 | 106 | class DenseArch(nn.Module): 107 | """ 108 | Processes the dense features of DLRM model. 109 | 110 | Args: 111 | in_features (int): dimensionality of the dense input features. 112 | layer_sizes (List[int]): list of layer sizes. 113 | device (Optional[torch.device]): default compute device. 114 | 115 | Example:: 116 | 117 | B = 20 118 | D = 3 119 | dense_arch = DenseArch(10, layer_sizes=[15, D]) 120 | dense_embedded = dense_arch(torch.rand((B, 10))) 121 | """ 122 | 123 | def __init__( 124 | self, 125 | in_features: int, 126 | layer_sizes: List[int], 127 | device: Optional[torch.device] = None, 128 | ) -> None: 129 | super().__init__() 130 | self.model: nn.Module = MLP(in_features, layer_sizes, bias=True, activation="relu", device=device) 131 | 132 | def forward(self, features: torch.Tensor) -> torch.Tensor: 133 | """ 134 | Args: 135 | features (torch.Tensor): an input tensor of dense features. 136 | 137 | Returns: 138 | torch.Tensor: an output tensor of size B X D. 139 | """ 140 | return self.model(features) 141 | 142 | 143 | class InteractionArch(nn.Module): 144 | """ 145 | Processes the output of both `SparseArch` (sparse_features) and `DenseArch` 146 | (dense_features). Returns the pairwise dot product of each sparse feature pair, 147 | the dot product of each sparse features with the output of the dense layer, 148 | and the dense layer itself (all concatenated). 149 | 150 | .. note:: 151 | The dimensionality of the `dense_features` (D) is expected to match the 152 | dimensionality of the `sparse_features` so that the dot products between them 153 | can be computed. 154 | 155 | 156 | Args: 157 | num_sparse_features (int): F. 158 | 159 | Example:: 160 | 161 | D = 3 162 | B = 10 163 | keys = ["f1", "f2"] 164 | F = len(keys) 165 | inter_arch = InteractionArch(num_sparse_features=len(keys)) 166 | 167 | dense_features = torch.rand((B, D)) 168 | sparse_features = torch.rand((B, F, D)) 169 | 170 | # B X (D + F + F choose 2) 171 | concat_dense = inter_arch(dense_features, sparse_features) 172 | """ 173 | 174 | def __init__(self, num_sparse_features: int, num_dense_features: int = 1) -> None: 175 | super().__init__() 176 | self.F: int = num_sparse_features 177 | self.num_dense_features = num_dense_features 178 | self.register_buffer( 179 | 'triu_indices', 180 | torch.triu_indices(self.F + self.num_dense_features, self.F + self.num_dense_features, 181 | offset=1).requires_grad_(False), False) 182 | 183 | def forward(self, dense_features: torch.Tensor, sparse_features: torch.Tensor) -> torch.Tensor: 184 | """ 185 | Args: 186 | dense_features (torch.Tensor): an input tensor of size B X D. 187 | sparse_features (torch.Tensor): an input tensor of size B X F X D. 188 | 189 | Returns: 190 | torch.Tensor: an output tensor of size B X (D + F + F choose 2). 
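            For instance, with D = 3 and F = 2 the output width is
            3 + 2 + 1 = 6: the dense passthrough, two dense/sparse dot
            products, and one sparse/sparse dot product.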
191 | """ 192 | if self.F <= 0: 193 | return dense_features 194 | 195 | if self.num_dense_features <= 0: 196 | combined_values = sparse_features 197 | else: 198 | # b f+1 d 199 | combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features), dim=1) 200 | 201 | # dense/sparse + sparse/sparse interaction 202 | # size B X (F + F choose 2) 203 | interactions = torch.bmm(combined_values, torch.transpose(combined_values, 1, 2)) 204 | interactions_flat = interactions[:, self.triu_indices[0], self.triu_indices[1]] 205 | 206 | return torch.cat((dense_features, interactions_flat), dim=1) 207 | 208 | 209 | class OverArch(nn.Module): 210 | """ 211 | Final Arch of DLRM - simple MLP over OverArch. 212 | 213 | Args: 214 | in_features (int): size of the input. 215 | layer_sizes (List[int]): sizes of the layers of the `OverArch`. 216 | device (Optional[torch.device]): default compute device. 217 | 218 | Example:: 219 | 220 | B = 20 221 | D = 3 222 | over_arch = OverArch(10, [5, 1]) 223 | logits = over_arch(torch.rand((B, 10))) 224 | """ 225 | 226 | def __init__( 227 | self, 228 | in_features: int, 229 | layer_sizes: List[int], 230 | device: Optional[torch.device] = None, 231 | ) -> None: 232 | super().__init__() 233 | if len(layer_sizes) <= 1: 234 | raise ValueError("OverArch must have multiple layers.") 235 | self.model: nn.Module = nn.Sequential( 236 | MLP( 237 | in_features, 238 | layer_sizes[:-1], 239 | bias=True, 240 | activation="relu", 241 | device=device, 242 | ), 243 | nn.Linear(layer_sizes[-2], layer_sizes[-1], bias=True, device=device), 244 | ) 245 | 246 | def forward(self, features: torch.Tensor) -> torch.Tensor: 247 | """ 248 | Args: 249 | features (torch.Tensor): 250 | 251 | Returns: 252 | torch.Tensor: size B X layer_sizes[-1] 253 | """ 254 | return self.model(features) 255 | 256 | 257 | class DLRM(nn.Module): 258 | """ 259 | Recsys model from "Deep Learning Recommendation Model for Personalization and 260 | Recommendation Systems" (https://arxiv.org/abs/1906.00091). Processes sparse 261 | features by learning pooled embeddings for each feature. Learns the relationship 262 | between dense features and sparse features by projecting dense features into the 263 | same embedding space. Also, learns the pairwise relationships between sparse 264 | features. 265 | 266 | The module assumes all sparse features have the same embedding dimension 267 | (i.e. each EmbeddingBagConfig uses the same embedding_dim). 268 | 269 | The following notation is used throughout the documentation for the models: 270 | 271 | * F: number of sparse features 272 | * D: embedding_dimension of sparse features 273 | * B: batch size 274 | * num_features: number of dense features 275 | 276 | Args: 277 | embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags 278 | used to define `SparseArch`. 279 | dense_in_features (int): the dimensionality of the dense input features. 280 | dense_arch_layer_sizes (List[int]): the layer sizes for the `DenseArch`. 281 | over_arch_layer_sizes (List[int]): the layer sizes for the `OverArch`. 282 | The output dimension of the `InteractionArch` should not be manually 283 | specified here. 284 | dense_device (Optional[torch.device]): default compute device. 
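    .. note::
        The input size of the `OverArch` is derived, not configured:
        over_in_features = D + F + F * (F - 1) / 2. With D = 8 and F = 3
        sparse features, as in the example below, the interaction arch
        emits 8 + 3 + 3 = 14 features per sample.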
285 | 
286 |     Example::
287 | 
288 |         B = 2
289 |         D = 8
290 | 
291 |         eb1_config = EmbeddingBagConfig(
292 |             name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"]
293 |         )
294 |         eb2_config = EmbeddingBagConfig(
295 |             name="t2",
296 |             embedding_dim=D,
297 |             num_embeddings=100,
298 |             feature_names=["f2"],
299 |         )
300 | 
301 |         ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
302 |         model = DLRM(
303 |             embedding_bag_collection=ebc,
304 |             dense_in_features=100,
305 |             dense_arch_layer_sizes=[20, D],
306 |             over_arch_layer_sizes=[5, 1],
307 |         )
308 | 
309 |         features = torch.rand((B, 100))
310 | 
311 |         #     0       1
312 |         # 0   [1,2] [4,5]
313 |         # 1   [4,3] [2,9]
314 |         # ^
315 |         # feature
316 |         sparse_features = KeyedJaggedTensor.from_offsets_sync(
317 |             keys=["f1", "f3"],
318 |             values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]),
319 |             offsets=torch.tensor([0, 2, 4, 6, 8]),
320 |         )
321 | 
322 |         logits = model(
323 |             dense_features=features,
324 |             sparse_features=sparse_features,
325 |         )
326 |     """
327 | 
328 |     def __init__(
329 |         self,
330 |         embedding_bag_collection: EmbeddingBagCollection,
331 |         dense_in_features: int,
332 |         dense_arch_layer_sizes: List[int],
333 |         over_arch_layer_sizes: List[int],
334 |         dense_device: Optional[torch.device] = None,
335 |     ) -> None:
336 |         super().__init__()
337 | 
338 |         # For torchrec version 0.1.x
339 |         # emb_configs = embedding_bag_collection.embedding_bag_configs
340 |         # this is for torchrec version 0.2.0
341 |         emb_configs = embedding_bag_collection.embedding_bag_configs()
342 |         assert (len(emb_configs) > 0), "At least one embedding bag is required"
343 |         for i in range(1, len(emb_configs)):
344 |             conf_prev = emb_configs[i - 1]
345 |             conf = emb_configs[i]
346 |             assert (
347 |                 conf_prev.embedding_dim == conf.embedding_dim), "All EmbeddingBagConfigs must have the same dimension"
348 |         embedding_dim: int = emb_configs[0].embedding_dim
349 |         if dense_arch_layer_sizes[-1] != embedding_dim:
350 |             raise ValueError(f"embedding_bag_collection dimension ({embedding_dim}) and final dense "
351 |                              f"arch layer size ({dense_arch_layer_sizes[-1]}) must match.")
352 | 
353 |         self.sparse_arch: SparseArch = SparseArch(embedding_bag_collection)
354 |         num_sparse_features: int = len(self.sparse_arch.sparse_feature_names)
355 | 
356 |         self.dense_arch = DenseArch(
357 |             in_features=dense_in_features,
358 |             layer_sizes=dense_arch_layer_sizes,
359 |             device=dense_device,
360 |         )
361 |         self.inter_arch = InteractionArch(num_sparse_features=num_sparse_features)
362 | 
363 |         over_in_features: int = (embedding_dim + choose(num_sparse_features, 2) + num_sparse_features)
364 |         self.over_arch = OverArch(
365 |             in_features=over_in_features,
366 |             layer_sizes=over_arch_layer_sizes,
367 |             device=dense_device,
368 |         )
369 | 
370 |     def forward(
371 |         self,
372 |         dense_features: torch.Tensor,
373 |         sparse_features: KeyedJaggedTensor,
374 |     ) -> torch.Tensor:
375 |         """
376 |         Args:
377 |             dense_features (torch.Tensor): the dense features.
378 |             sparse_features (KeyedJaggedTensor): the sparse features.
379 | 
380 |         Returns:
381 |             torch.Tensor: logits.
382 |         """
383 |         embedded_dense = self.dense_arch(dense_features)
384 |         embedded_sparse = self.sparse_arch(sparse_features)
385 |         concatenated_dense = self.inter_arch(dense_features=embedded_dense, sparse_features=embedded_sparse)
386 |         logits = self.over_arch(concatenated_dense)
387 |         return logits
388 | 
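# ---------------------------------------------------------------------------
# Illustrative smoke test (a minimal sketch, not part of the original file):
# it instantiates the docstring example end-to-end. Table sizes and tensor
# shapes are assumptions chosen only to keep the demo cheap to run.
def _demo_dlrm() -> None:
    from torchrec.modules.embedding_configs import EmbeddingBagConfig

    B, D = 2, 8
    ebc = EmbeddingBagCollection(tables=[
        EmbeddingBagConfig(name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"]),
        EmbeddingBagConfig(name="t2", embedding_dim=D, num_embeddings=100, feature_names=["f2"]),
    ])
    model = DLRM(
        embedding_bag_collection=ebc,
        dense_in_features=100,
        dense_arch_layer_sizes=[20, D],    # last size must equal the embedding dim
        over_arch_layer_sizes=[5, 1],
    )
    sparse_features = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f3", "f2"],
        values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2]),
        offsets=torch.tensor([0, 2, 4, 6, 8, 9, 10]),
    )
    logits = model(dense_features=torch.rand((B, 100)), sparse_features=sparse_features)
    assert logits.shape == (B, 1)
# ---------------------------------------------------------------------------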
389 | 
390 | class DLRMTrain(nn.Module):
391 |     """
392 |     nn.Module to wrap DLRM model to use with train_pipeline.
393 | 
394 |     DLRM Recsys model from "Deep Learning Recommendation Model for Personalization and
395 |     Recommendation Systems" (https://arxiv.org/abs/1906.00091). Processes sparse
396 |     features by learning pooled embeddings for each feature. Learns the relationship
397 |     between dense features and sparse features by projecting dense features into the
398 |     same embedding space. Also, learns the pairwise relationships between sparse
399 |     features.
400 | 
401 |     The module assumes all sparse features have the same embedding dimension
402 |     (i.e., each EmbeddingBagConfig uses the same embedding_dim).
403 | 
404 |     Args:
405 |         embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags
406 |             used to define SparseArch.
407 |         dense_in_features (int): the dimensionality of the dense input features.
408 |         dense_arch_layer_sizes (list[int]): the layer sizes for the DenseArch.
409 |         over_arch_layer_sizes (list[int]): the layer sizes for the OverArch. NOTE: The
410 |             output dimension of the InteractionArch should not be manually specified
411 |             here.
412 |         dense_device (Optional[torch.device]): default compute device.
413 | 
414 |     Call Args:
415 |         batch: batch used with criteo and random data from torchrec.datasets
416 | 
417 |     Returns:
418 |         Tuple[loss, Tuple[loss, logits, labels]]
419 | 
420 |     Example::
421 | 
422 |         ebc = EmbeddingBagCollection(tables=[ebc_config])
423 |         model = DLRMTrain(
424 |             embedding_bag_collection=ebc,
425 |             dense_in_features=100,
426 |             dense_arch_layer_sizes=[20, D],  # D = embedding_dim of ebc_config
427 |             over_arch_layer_sizes=[5, 1],
428 |         )
429 |     """
430 | 
431 |     def __init__(
432 |         self,
433 |         embedding_bag_collection: EmbeddingBagCollection,
434 |         dense_in_features: int,
435 |         dense_arch_layer_sizes: List[int],
436 |         over_arch_layer_sizes: List[int],
437 |         dense_device: Optional[torch.device] = None,
438 |     ) -> None:
439 |         super().__init__()
440 |         self.model = DLRM(
441 |             embedding_bag_collection=embedding_bag_collection,
442 |             dense_in_features=dense_in_features,
443 |             dense_arch_layer_sizes=dense_arch_layer_sizes,
444 |             over_arch_layer_sizes=over_arch_layer_sizes,
445 |             dense_device=dense_device,
446 |         )
447 |         self.loss_fn: nn.Module = nn.BCEWithLogitsLoss()
448 | 
449 |     def forward(self, batch: Batch) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
450 |         logits = self.model(batch.dense_features, batch.sparse_features)
451 |         logits = logits.squeeze(-1)    # squeeze only the trailing dim so a batch of size 1 survives
452 |         loss = self.loss_fn(logits, batch.labels.float())
453 | 
454 |         return loss, (loss.detach(), logits.detach(), batch.labels.detach())
455 | 
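# ---------------------------------------------------------------------------
# Illustrative training step (a hedged sketch, not part of the original file):
# it shows the (loss, (loss, logits, labels)) contract of DLRMTrain with a
# hand-built Batch. All sizes and the SGD learning rate are assumptions.
def _demo_dlrm_train_step() -> None:
    from torchrec.modules.embedding_configs import EmbeddingBagConfig

    B, D = 2, 8
    ebc = EmbeddingBagCollection(tables=[
        EmbeddingBagConfig(name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1"]),
    ])
    train_model = DLRMTrain(
        embedding_bag_collection=ebc,
        dense_in_features=10,
        dense_arch_layer_sizes=[20, D],
        over_arch_layer_sizes=[5, 1],
    )
    batch = Batch(
        dense_features=torch.rand((B, 10)),
        sparse_features=KeyedJaggedTensor.from_offsets_sync(
            keys=["f1"],
            values=torch.tensor([1, 2, 3]),
            offsets=torch.tensor([0, 2, 3]),
        ),
        labels=torch.randint(0, 2, (B,)),
    )
    optimizer = torch.optim.SGD(train_model.parameters(), lr=0.1)
    loss, (detached_loss, logits, labels) = train_model(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
# ---------------------------------------------------------------------------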
-------------------------------------------------------------------------------- /benchmark/benchmark_cache.py: --------------------------------------------------------------------------------
1 | """
2 | 1. hit rate
3 | 2. bandwidth
4 | 3. read / load
5 | 4. elapsed time
6 | """
7 | import itertools
8 | from tqdm import tqdm
9 | from contexttimer import Timer
10 | from contextlib import nullcontext
11 | import numpy as np
12 | 
13 | import torch
14 | from torch.profiler import profile, ProfilerActivity, schedule, tensorboard_trace_handler
15 | 
16 | from colossalai.nn.parallel.layers import CachedEmbeddingBag, EvictionStrategy
17 | from recsys.datasets.criteo import get_id_freq_map
18 | from data_utils import get_dataloader, NUM_EMBED, CRITEO_PATH
19 | 
20 | 
21 | def benchmark_cache_embedding(batch_size,
22 |                               embedding_dim,
23 |                               cache_ratio,
24 |                               id_freq_map=None,
25 |                               warmup_ratio=0.,
26 |                               use_limit_buf=True,
27 |                               use_lfu=False):
28 |     dataloader = get_dataloader('train', batch_size)
29 |     cuda_row_num = int(cache_ratio * NUM_EMBED)
30 |     print(f"batch size: {batch_size}, "
31 |           f"num of batches: {len(dataloader)}, "
32 |           f"cached rows: {cuda_row_num}, cached_ratio {cuda_row_num / NUM_EMBED}")
33 |     data_iter = iter(dataloader)
34 | 
35 |     torch.cuda.reset_peak_memory_stats()
36 |     device = torch.device('cuda:0')
37 | 
38 |     with Timer() as timer:
39 |         model = CachedEmbeddingBag(NUM_EMBED, embedding_dim, sparse=True, include_last_offset=True, \
40 |             evict_strategy=EvictionStrategy.LFU if use_lfu else EvictionStrategy.DATASET).to(device)
41 |         # model = torch.nn.EmbeddingBag(NUM_EMBED, embedding_dim, sparse=True, include_last_offset=True).to(device)
42 |     print(f"model init: {timer.elapsed:.2f}s")
43 | 
44 |     grad = None
45 |     print(
46 |         f'after reorder max_memory_allocated {torch.cuda.max_memory_allocated()/1e9} GB, max_memory_reserved {torch.cuda.max_memory_reserved()/1e9} GB'
47 |     )
48 |     torch.cuda.reset_peak_memory_stats()
49 | 
50 |     with Timer() as timer:
51 |         with tqdm(bar_format='{n_fmt}it {rate_fmt} {postfix}') as t:
52 |             # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
53 |             #              schedule=schedule(wait=0, warmup=21, active=2, repeat=1),
54 |             #              profile_memory=True,
55 |             #              on_trace_ready=tensorboard_trace_handler(
56 |             #                  f"log/b{batch_size}-e{embedding_dim}-num_chunk{cuda_row_num}-chunk_size{cache_lines}")) as prof:
57 |             with nullcontext():
58 |                 for it in itertools.count():
59 |                     batch = next(data_iter)
60 |                     sparse_feature = batch.sparse_features.to(device)
61 | 
62 |                     res = model(sparse_feature.values(), sparse_feature.offsets())
63 | 
64 |                     grad = torch.randn_like(res) if grad is None else grad
65 |                     res.backward(grad)
66 | 
67 |                     model.zero_grad()
68 |                     # prof.step()
69 | 
70 |                     t.update()
71 |                     if it == 200:
72 |                         break
73 | 
74 |     if hasattr(model, 'cache_weight_mgr'):
75 |         model.cache_weight_mgr.print_comm_stats()
76 | 
77 | 
78 | if __name__ == "__main__":
79 |     with Timer() as timer:
80 |         id_freq_map = get_id_freq_map(CRITEO_PATH)
81 |     print(f"Counting sparse features in dataset costs: {timer.elapsed:.2f} s")
82 | 
83 |     batch_size = [2048]
84 |     embed_dim = 32
85 |     cache_ratio = [0.02]
86 | 
87 |     # # row-wise cache
88 |     # for bs in batch_size:
89 |     #     for cs in cuda_row_num:
90 |     #         main(bs, embed_dim, cuda_row_num=cs, cache_lines=1, embed_type='row')
91 | 
92 |     # sweep cache ratio, warmup ratio and limit-buffer settings
93 |     for bs in batch_size:
94 |         for cr in cache_ratio:
95 |             for warmup_ratio in [0.7]:
96 |                 for use_buf in [False, True]:
97 |                     try:
98 |                         benchmark_cache_embedding(bs,
99 |                                                   embed_dim,
100 |                                                   cache_ratio=cr,
101 |                                                   id_freq_map=id_freq_map,
102 |                                                   warmup_ratio=warmup_ratio,
103 |                                                   use_limit_buf=use_buf)
104 |                         print('=' * 50 + '\n')
105 | 
106 |                     except AssertionError as ae:
107 |                         print(f"batch size: {bs}, cache ratio: {cr}, raise error: {ae}")
108 |                         print('=' * 50 + '\n')
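# ---------------------------------------------------------------------------
# Optional helper (a hedged sketch, not used by the benchmark above): CUDA
# kernels launch asynchronously, so wall-clock timing can mis-attribute work.
# CUDA events bracket the work on-device; `fn` is any zero-argument callable,
# e.g. one forward/backward pass of the loop above.
def _time_cuda(fn, iters=10):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    # elapsed_time returns milliseconds; report the per-iteration average
    return start.elapsed_time(end) / iters
# ---------------------------------------------------------------------------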
-------------------------------------------------------------------------------- /benchmark/benchmark_fbgemm_uvm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from colossalai.nn.parallel.layers.cache_embedding import CachedEmbeddingBag 4 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor 5 | from fbgemm_gpu.split_table_batched_embeddings_ops import SplitTableBatchedEmbeddingBagsCodegen, EmbeddingLocation, ComputeDevice, CacheAlgorithm 6 | import time 7 | 8 | ########### GLOBAL SETTINGS ################## 9 | 10 | 11 | BATCH_SIZE = 65536 12 | TABLLE_NUM = 856 13 | FILE_LIST = [f"/data/scratch/RecSys/embedding_bag/fbgemm_t856_bs65536_{i}.pt" for i in range(16)] 14 | KEYS = [] 15 | for i in range(TABLLE_NUM): 16 | KEYS.append("table_{}".format(i)) 17 | EMBEDDING_DIM = 128 18 | # Full dataset is too big 19 | # CHOSEN_TABLES = [0, 2, 3, 4, 5, 7, 8, 9, 10, 12, 15, 18, 22, 27, 28] 20 | # CHOSEN_TABLES = [5, 8, 37, 54, 71,72,73,74,85,86,89,95,96,97,107,131,163, 185, 196, 204, 211, ] 21 | CHOSEN_TABLES = [i for i in range(300,418)] 22 | TEST_ITER = 100 23 | TEST_BATCH_SIZE = 16384 24 | WARMUP_ITERS = 5 25 | ############################################## 26 | 27 | 28 | def load_file(file_path): 29 | indices, offsets, lengths = torch.load(file_path) 30 | indices = indices.int().cuda() 31 | offsets = offsets.int().cuda() 32 | lengths = lengths.int().cuda() 33 | num_embeddings_per_table = [] 34 | indices_per_table = [] 35 | lengths_per_table = [] 36 | offsets_per_table = [] 37 | for i in range(TABLLE_NUM): 38 | if i not in CHOSEN_TABLES: 39 | continue 40 | start_pos = offsets[i * BATCH_SIZE] 41 | end_pos = offsets[i * BATCH_SIZE + BATCH_SIZE] 42 | part = indices[start_pos:end_pos] 43 | indices_per_table.append(part) 44 | lengths_per_table.append(lengths[i]) 45 | offsets_per_table.append(torch.cumsum( 46 | torch.cat((torch.tensor([0]).cuda(), lengths[i])), 0 47 | )) 48 | if part.numel() == 0: 49 | num_embeddings_per_table.append(0) 50 | else: 51 | num_embeddings_per_table.append(torch.max(part).int().item() + 1) 52 | return indices_per_table, offsets_per_table, lengths_per_table, num_embeddings_per_table 53 | 54 | 55 | def load_file_kjt(file_path): 56 | indices, offsets, lengths = torch.load(file_path) 57 | length_per_key = [] 58 | for i in range(TABLLE_NUM): 59 | length_per_key.append(lengths[i]) 60 | ret = KeyedJaggedTensor(KEYS, indices, offsets=offsets, 61 | length_per_key=length_per_key) 62 | return ret 63 | 64 | 65 | def load_random_batch(indices_per_table, offsets_per_table, lengths_per_table, batch_size=4096): 66 | chosen_indices_list = [] 67 | chosen_lengths_list = [] 68 | choose = torch.randint( 69 | 0, offsets_per_table[0].shape[0] - 1, (batch_size,)).cuda() 70 | for indices, offsets, lengths in zip(indices_per_table, offsets_per_table, lengths_per_table): 71 | 72 | chosen_lengths_list.append(lengths[choose]) 73 | start_list = offsets[choose] 74 | end_list = offsets[choose + 1] 75 | chosen_indices_atoms = [] 76 | for start, end in zip(start_list, end_list): 77 | chosen_indices_atoms.append(indices[start: end]) 78 | chosen_indices_list.append(torch.cat(chosen_indices_atoms, 0)) 79 | return chosen_indices_list, chosen_lengths_list 80 | 81 | 82 | def merge_to_kjt(indices_list, lengths_list, length_per_key) -> KeyedJaggedTensor: 83 | values = torch.cat(indices_list) 84 | lengths = torch.cat(lengths_list) 85 | return KeyedJaggedTensor( 86 | keys=[KEYS[i] for i in CHOSEN_TABLES], 87 | values=values, 88 | 
lengths=lengths,
89 |         length_per_key=length_per_key,
90 |     )
91 | 
92 | 
93 | def test(iter_num=1, batch_size=4096):
94 |     print("loading file")
95 |     indices_per_table, offsets_per_table, lengths_per_table, num_embeddings_per_table = load_file(
96 |         FILE_LIST[0])
97 |     table_idx_offset_list = np.cumsum([0] + num_embeddings_per_table)
98 |     fae = CachedEmbeddingBag(
99 |         num_embeddings=sum(num_embeddings_per_table),
100 |         embedding_dim=EMBEDDING_DIM,
101 |         sparse=True,
102 |         include_last_offset=True,
103 |         cache_ratio=0.05,
104 |         pin_weight=True,
105 |     )
106 |     fae_forwarding_time = 0.0
107 |     fae_backwarding_time = 0.0
108 |     grad_fae = None
109 |     managed_type = (EmbeddingLocation.MANAGED_CACHING)
110 |     uvm = SplitTableBatchedEmbeddingBagsCodegen(
111 |         embedding_specs=[(
112 |             num_embeddings,
113 |             EMBEDDING_DIM,
114 |             managed_type,
115 |             ComputeDevice.CUDA,
116 |         ) for num_embeddings in num_embeddings_per_table],
117 |         cache_load_factor=0.05,
118 |         cache_algorithm=CacheAlgorithm.LFU
119 |     )
120 |     uvm.init_embedding_weights_uniform(-0.5, 0.5)
121 |     # print(sum(num_embeddings_per_table))
122 |     # print(uvm.weights_uvm.shape)
123 |     uvm_forwarding_time = 0.0
124 |     uvm_backwarding_time = 0.0
125 |     grad_uvm = None
126 |     print("testing:")
127 |     for iter in range(iter_num):
128 |         # load batch
129 |         chosen_indices_list, chosen_lengths_list = load_random_batch(indices_per_table,
130 |                                                                      offsets_per_table, lengths_per_table, batch_size)
131 |         features = merge_to_kjt(chosen_indices_list,
132 |                                 chosen_lengths_list, num_embeddings_per_table)
133 |         print("iter {} batch loaded.".format(iter))
134 | 
135 |         # fae
136 |         with torch.no_grad():
137 |             values = features.values().long()
138 |             offsets = features.offsets().long()
139 |             weights = features.weights_or_none()
140 |             batch_size = len(features.offsets()) // len(features.keys())
141 |             if weights is not None and not torch.is_floating_point(weights):
142 |                 weights = None
143 |             split_view = torch.tensor_split(
144 |                 values, features.offset_per_key()[1:-1], dim=0)
145 |             for i, chunk in enumerate(split_view):
146 |                 torch.add(chunk, table_idx_offset_list[i], out=chunk)
147 |         start = time.time()
148 |         output = fae(values, offsets, weights)
149 |         ret = torch.cat(output.split(batch_size), 1)
150 |         if iter >= WARMUP_ITERS:    # report averages only after warm-up, else the denominator is <= 0
151 |             fae_forwarding_time += time.time() - start
152 |             print("fae forwarded. avg time = {} s".format(
153 |                 fae_forwarding_time / (iter + 1 - WARMUP_ITERS)))
154 |         grad_fae = torch.randn_like(ret) if grad_fae is None else grad_fae
155 |         start = time.time()
156 |         ret.backward(grad_fae)
157 |         if iter >= WARMUP_ITERS:
158 |             fae_backwarding_time += time.time() - start
159 |             print("fae backwarded. avg time = {} s".format(
160 |                 fae_backwarding_time / (iter + 1 - WARMUP_ITERS)))
161 |         fae.zero_grad()
162 | 
163 |         # uvm
164 |         start = time.time()
165 |         ret = uvm(features.values().long(), features.offsets().long())
166 |         if iter >= WARMUP_ITERS:
167 |             uvm_forwarding_time += time.time() - start
168 |             print("uvm forwarded. avg time = {} s".format(
169 |                 uvm_forwarding_time / (iter + 1 - WARMUP_ITERS)))
170 |         grad_uvm = torch.randn_like(ret) if grad_uvm is None else grad_uvm
171 |         start = time.time()
172 |         ret.backward(grad_uvm)
173 |         if iter >= WARMUP_ITERS:
174 |             uvm_backwarding_time += time.time() - start
175 |             print("uvm backwarded. avg time = {} s".format(
176 |                 uvm_backwarding_time / (iter + 1 - WARMUP_ITERS)))
177 |         uvm.zero_grad()
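# To run the cache-vs-UVM comparison above (rather than the table-size scan
# below), uncomment the `test(...)` call at the bottom of this file; note
# that `iter_num` must exceed WARMUP_ITERS for the reported averages to be
# meaningful, e.g.:
#
#     test(TEST_ITER, TEST_BATCH_SIZE)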
178 | 
179 | 
180 | # test(TEST_ITER, TEST_BATCH_SIZE)
181 | num_embeddings_per_table = None
182 | for i in range(16):
183 |     indices_per_table, offsets_per_table, lengths_per_table, num_embeddings_per_table1 = load_file(FILE_LIST[i])
184 |     if num_embeddings_per_table is None:
185 |         num_embeddings_per_table = num_embeddings_per_table1
186 |     else:
187 |         for j, num in enumerate(num_embeddings_per_table1):    # j, not i, to avoid shadowing the file-loop index
188 |             num_embeddings_per_table[j] = max(num_embeddings_per_table[j], num)
189 | print(num_embeddings_per_table)
190 | print(sum(num_embeddings_per_table))
-------------------------------------------------------------------------------- /benchmark/data_utils.py: --------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.data import DataLoader
3 | from recsys.datasets import criteo
4 | 
5 | CRITEO_PATH = "../criteo_kaggle_data"
6 | NUM_EMBED = 33762577
7 | 
8 | 
9 | def get_dataloader(stage, batch_size):
10 |     hash_sizes = list(map(int, criteo.KAGGLE_NUM_EMBEDDINGS_PER_FEATURE.split(',')))
11 |     files = os.listdir(CRITEO_PATH)
12 | 
13 |     def is_final_day(s):
14 |         return "day_6" in s
15 | 
16 |     rank, world_size = 0, 1
17 |     if stage == "train":
18 |         # Train set gets all data except the final day.
19 |         files = list(filter(lambda s: not is_final_day(s), files))
20 |     else:
21 |         # Validation set gets the first half of the final day's samples. Test set gets
22 |         # the other half.
23 |         files = list(filter(is_final_day, files))
24 |         rank = rank if stage == "val" else (rank + world_size)
25 |         world_size = world_size * 2
26 | 
27 |     stage_files = [
28 |         sorted(map(
29 |             lambda x: os.path.join(CRITEO_PATH, x),
30 |             filter(lambda s: kind in s, files),
31 |         )) for kind in ["dense", "sparse", "labels"]
32 |     ]
33 |     dataloader = DataLoader(
34 |         criteo.InMemoryBinaryCriteoIterDataPipe(
35 |             *stage_files,    # pyre-ignore[6]
36 |             batch_size=batch_size,
37 |             rank=rank,
38 |             world_size=world_size,
39 |             shuffle_batches=True,
40 |             hashes=hash_sizes),
41 |         batch_size=None,
42 |         pin_memory=True,
43 |         collate_fn=lambda x: x,
44 |     )
45 |     return dataloader
46 | 
-------------------------------------------------------------------------------- /docker/Dockerfile: --------------------------------------------------------------------------------
1 | # Build on the HPC-AI Tech PyTorch 1.12 / CUDA 11.3 base image; see
2 | # Dockerfile_thu for a variant that installs from the Tsinghua pip mirror.
3 | FROM hpcaitech/pytorch-cuda:1.12.0-11.3.0
4 | 
5 | RUN pip install --no-cache-dir petastorm[torch]==0.11.5
6 | 
7 | # install torchrec
8 | RUN python3 -m pip install --no-cache-dir torchrec==0.2.0
9 | 
10 | # alternative: install the torchrec wheel directly
11 | # RUN wget https://download.pytorch.org/whl/torchrec-0.1.1-py39-none-any.whl && \
12 | #     python3 -m pip install --no-cache-dir torchrec-0.1.1-py39-none-any.whl && \
13 | #     rm torchrec-0.1.1-py39-none-any.whl
14 | 
15 | # updated with hpcaitech version
16 | RUN git clone https://github.com/hpcaitech/torchrec.git && cd torchrec && pip install .
17 | 
18 | # install colossalai
19 | RUN git clone https://github.com/hpcaitech/ColossalAI.git && \
20 |     cd ColossalAI/ && \
21 |     python3 -m pip install --no-cache-dir -r requirements/requirements.txt && \
22 |     python3 -m pip install --no-cache-dir . && \
23 |     cd ..
&& \ 24 | yes | rm -r ColossalAI/ 25 | -------------------------------------------------------------------------------- /docker/Dockerfile_thu: -------------------------------------------------------------------------------- 1 | # Domestic cloud servers often have network issues with pip, 2 | # so we need to pip install from tsinghua mirror 3 | FROM hpcaitech/pytorch-cuda:1.12.0-11.3.0 4 | 5 | RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir petastorm[torch]==0.11.5 6 | 7 | #install torchrec 8 | RUN python3 -m pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir torchrec==0.2.0 9 | 10 | #install torchrec 11 | # RUN wget https://download.pytorch.org/whl/torchrec-0.1.1-py39-none-any.whl && \ 12 | # python3 -m pip install --no-cache-dir torchrec-0.1.1-py39-none-any.whl && \ 13 | # rm torchrec-0.1.1-py39-none-any.whl 14 | 15 | # updated with hpcaitech version 16 | RUN git clone https://github.com/hpcaitech/torchrec.git && cd torchrec && pip install . 17 | 18 | # install colossalai 19 | RUN git clone https://github.com/hpcaitech/ColossalAI.git && \ 20 | cd ColossalAI/ && \ 21 | python3 -m pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir -r requirements/requirements.txt && \ 22 | python3 -m pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir . && \ 23 | cd .. && \ 24 | yes | rm -r ColossalAI/ 25 | -------------------------------------------------------------------------------- /docker/launch.sh: -------------------------------------------------------------------------------- 1 | DATASET_PATH=/data/scratch/RecSys 2 | 3 | docker run --rm -it -e CUDA_VISIBLE_DEVICES=0,1,2,3 -e PYTHONPATH=/workspace/code -v `pwd`:/workspace/code -v ${DATASET_PATH}:/data -w /workspace/code --ipc=host hpcaitech/cacheembedding:0.1.4 /bin/bash 4 | -------------------------------------------------------------------------------- /license: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021- HPC-AI Technology Inc. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /pics/prefetch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpcaitech/CachedEmbedding/a2af3d7e7b0197519e6d018444688fcd9ba32c43/pics/prefetch.png -------------------------------------------------------------------------------- /recsys/README.md: -------------------------------------------------------------------------------- 1 | # Build a scalable embedding from scratch for training recommendation models -------------------------------------------------------------------------------- /recsys/__init__.py: -------------------------------------------------------------------------------- 1 | # Build a scalable system from scratch for training recommendation models 2 | -------------------------------------------------------------------------------- /recsys/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpcaitech/CachedEmbedding/a2af3d7e7b0197519e6d018444688fcd9ba32c43/recsys/datasets/__init__.py -------------------------------------------------------------------------------- /recsys/datasets/avazu.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.utils.data.datapipes as dp 7 | from torch.utils.data import IterDataPipe, IterableDataset, DataLoader 8 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor 9 | from torchrec.datasets.utils import LoadFiles, ReadLinesFromCSV, PATH_MANAGER_KEY, Batch 10 | from torchrec.datasets.criteo import BinaryCriteoUtils 11 | 12 | from .feature_counter import GlobalFeatureCounter 13 | 14 | CAT_FEATURE_COUNT = 13 15 | INT_FEATURE_COUNT = 8 16 | DAYS = 10 17 | DEFAULT_LABEL_NAME = "click" 18 | DEFAULT_CAT_NAMES = [ 19 | 'C1', 20 | 'banner_pos', 21 | 'site_id', 22 | 'site_domain', 23 | 'site_category', 24 | 'app_id', 25 | 'app_domain', 26 | 'app_category', 27 | 'device_id', 28 | 'device_ip', 29 | 'device_model', 30 | 'device_type', 31 | 'device_conn_type', 32 | ] 33 | DEFAULT_INT_NAMES = ['C14', 'C15', 'C16', 'C17', 'C18', 'C19', 'C20', 'C21'] 34 | NUM_EMBEDDINGS_PER_FEATURE = '7,7,4737,7745,26,8552,559,36,2686408,6729486,8251,5,4' # 9445823 in total 35 | TOTAL_TRAINING_SAMPLES = 36_386_071 # 90% sample in train, 40428967 in total 36 | 37 | 38 | def _default_row_mapper(row): 39 | _label = row[1] 40 | _sparse = row[3:5] 41 | for i in range(5, 14): # 9 42 | try: 43 | _c = int(row[i], 16) 44 | except ValueError: 45 | _c = 0 46 | _sparse.append(_c) 47 | _sparse += row[14:24] 48 | 49 | return _sparse, _label 50 | 51 | 52 | class AvazuIterDataPipe(IterDataPipe): 53 | 54 | def __init__(self, path, row_mapper=_default_row_mapper): 55 | self.path = path 56 | self.row_mapper = row_mapper 57 | 58 | def __iter__(self): 59 | """ 60 | iterate over the data file, and apply the transform row_mapper to each row 61 | """ 62 | datapipe = LoadFiles([self.path], mode='r', path_manager_key='avazu') 63 | datapipe = ReadLinesFromCSV(datapipe, delimiter=',', skip_first_line=True) 64 | if self.row_mapper is not None: 65 | datapipe = dp.iter.Mapper(datapipe, self.row_mapper) 66 | yield from datapipe 67 | 68 | 69 | class InMemoryAvazuIterDataPipe(IterableDataset): 70 | 71 | def __init__(self, 72 | dense_paths, 73 | sparse_paths, 74 | label_paths, 75 | batch_size, 76 | rank, 77 | world_size, 78 | shuffle_batches=False, 79 | 
mmap_mode=False, 80 | hashes=None, 81 | path_manager_key=PATH_MANAGER_KEY, 82 | assigned_tables = None): 83 | if assigned_tables is not None: 84 | # tablewise mode 85 | self.assigned_tables = np.array(assigned_tables) 86 | else: 87 | # full table mode 88 | self.assigned_tables = np.arange(CAT_FEATURE_COUNT) 89 | 90 | self.dense_paths = dense_paths 91 | self.sparse_paths = sparse_paths 92 | self.label_paths = label_paths 93 | self.batch_size = batch_size 94 | self.rank = rank 95 | self.world_size = world_size 96 | self.shuffle_batches = shuffle_batches 97 | self.mmap_mode = mmap_mode 98 | if hashes is not None: 99 | self.hashes = [] 100 | for i, length in enumerate(hashes): 101 | if i in self.assigned_tables: 102 | self.hashes.append(length) 103 | self.hashes = np.array(self.hashes).reshape(1, -1) 104 | else: 105 | self.hashes = None 106 | self.path_manager_key = path_manager_key 107 | 108 | self.sparse_offsets = np.array([0, *np.cumsum(self.hashes)[:-1]], dtype=np.int64).reshape( 109 | 1, -1) if self.hashes is not None else None 110 | 111 | self._load_data() 112 | self.num_rows_per_file = [a.shape[0] for a in self.dense_arrs] 113 | self.num_batches: int = sum(self.num_rows_per_file) // batch_size 114 | 115 | self._num_ids_in_batch: int = len(self.assigned_tables) * batch_size 116 | self.keys = [DEFAULT_CAT_NAMES[i] for i in self.assigned_tables] 117 | self.lengths = torch.ones((self._num_ids_in_batch,), dtype=torch.int32) 118 | self.offsets = torch.arange(0, self._num_ids_in_batch + 1, dtype=torch.int32) 119 | self.stride = batch_size 120 | self.length_per_key = len(self.assigned_tables) * [batch_size] 121 | self.offset_per_key = [batch_size * i for i in range(len(self.assigned_tables) + 1)] 122 | self.index_per_key = {key: i for (i, key) in enumerate(self.keys)} 123 | 124 | def _load_data(self): 125 | file_idx_to_row_range = BinaryCriteoUtils.get_file_idx_to_row_range(lengths=[ 126 | BinaryCriteoUtils.get_shape_from_npy(path, path_manager_key=self.path_manager_key)[0] 127 | for path in self.sparse_paths 128 | ], 129 | rank=self.rank, 130 | world_size=self.world_size) 131 | self.dense_arrs, self.sparse_arrs, self.labels_arrs = [], [], [] 132 | for _dtype, arrs, paths in zip([np.float32, np.int64, np.int32], 133 | [self.dense_arrs, self.sparse_arrs, self.labels_arrs], 134 | [self.dense_paths, self.sparse_paths, self.label_paths]): 135 | for idx, (range_left, range_right) in file_idx_to_row_range.items(): 136 | arrs.append( 137 | BinaryCriteoUtils.load_npy_range(paths[idx], 138 | range_left, 139 | range_right - range_left + 1, 140 | path_manager_key=self.path_manager_key, 141 | mmap_mode=self.mmap_mode).astype(_dtype)) 142 | expand_hashes = np.ones(CAT_FEATURE_COUNT, dtype=np.int64).reshape(1, -1) 143 | expand_sparse_offsets = np.ones(CAT_FEATURE_COUNT, dtype=np.int64).reshape(1, -1) 144 | for i, table in enumerate(self.assigned_tables): 145 | expand_hashes[0, table] = self.hashes[0, i] 146 | expand_sparse_offsets[0, table] = self.sparse_offsets[0, i] 147 | if not self.mmap_mode and self.hashes is not None: 148 | for sparse_arr in self.sparse_arrs: 149 | sparse_arr %= expand_hashes 150 | sparse_arr += expand_sparse_offsets 151 | 152 | def __iter__(self): 153 | buffer = None 154 | 155 | def append_to_buffer(dense: np.ndarray, sparse: np.ndarray, labels: np.ndarray) -> None: 156 | nonlocal buffer 157 | if buffer is None: 158 | buffer = [dense, sparse, labels] 159 | else: 160 | for idx, arr in enumerate([dense, sparse, labels]): 161 | buffer[idx] = np.concatenate((buffer[idx], arr)) 162 | 163 | 
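        # Worked example (illustrative): with batch_size = 4 and shard files of
        # 10 and 6 rows, the loop below yields rows 0-3 and 4-7 of file 0, then
        # rows 8-9 of file 0 joined with rows 0-1 of file 1, and so on;
        # num_batches = (10 + 6) // 4 = 4, so any trailing remainder is dropped.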
# Maintain a buffer that can contain up to batch_size rows. Fill buffer as 164 | # much as possible on each iteration. Only return a new batch when batch_size 165 | # rows are filled. 166 | file_idx = 0 167 | row_idx = 0 168 | batch_idx = 0 169 | while batch_idx < self.num_batches: 170 | buffer_row_count = 0 if buffer is None else buffer[0].shape[0] 171 | if buffer_row_count == self.batch_size: 172 | yield self._np_arrays_to_batch(*buffer) 173 | batch_idx += 1 174 | buffer = None 175 | else: 176 | rows_to_get = min( 177 | self.batch_size - buffer_row_count, 178 | self.num_rows_per_file[file_idx] - row_idx, 179 | ) 180 | slice_ = slice(row_idx, row_idx + rows_to_get) 181 | 182 | dense_inputs = self.dense_arrs[file_idx][slice_, :] 183 | sparse_inputs = self.sparse_arrs[file_idx][slice_, :].take(self.assigned_tables, -1) 184 | target_labels = self.labels_arrs[file_idx][slice_, :] 185 | 186 | if self.mmap_mode and self.hashes is not None: 187 | sparse_inputs %= self.hashes 188 | sparse_inputs += self.sparse_offsets 189 | 190 | append_to_buffer( 191 | dense_inputs, 192 | sparse_inputs, 193 | target_labels, 194 | ) 195 | row_idx += rows_to_get 196 | 197 | if row_idx >= self.num_rows_per_file[file_idx]: 198 | file_idx += 1 199 | row_idx = 0 200 | 201 | def _np_arrays_to_batch(self, dense: np.ndarray, sparse: np.ndarray, labels: np.ndarray) -> Batch: 202 | if self.shuffle_batches: 203 | # Shuffle all 3 in unison 204 | shuffler = np.random.permutation(sparse.shape[0]) 205 | dense = dense[shuffler] 206 | sparse = sparse[shuffler] 207 | labels = labels[shuffler] 208 | 209 | return Batch( 210 | dense_features=torch.from_numpy(dense), 211 | sparse_features=KeyedJaggedTensor( 212 | keys=self.keys, 213 | # transpose + reshape(-1) incurs an additional copy. 214 | values=torch.from_numpy(sparse.transpose(1, 0).reshape(-1)), 215 | lengths=self.lengths, 216 | offsets=self.offsets, 217 | stride=self.stride, 218 | length_per_key=self.length_per_key, 219 | offset_per_key=self.offset_per_key, 220 | index_per_key=self.index_per_key, 221 | ), 222 | labels=torch.from_numpy(labels.reshape(-1)), 223 | ) 224 | 225 | def __len__(self) -> int: 226 | return self.num_batches 227 | 228 | 229 | def get_dataloader(args, stage, rank, world_size, assigned_tables = None): 230 | stage = stage.lower() 231 | 232 | files = os.listdir(args.dataset_dir) 233 | 234 | if stage == 'train': 235 | files = list(filter(lambda s: 'train' in s, files)) 236 | else: 237 | files = list(filter(lambda s: 'train' not in s, files)) 238 | rank = rank if stage == "val" else (rank + world_size) 239 | world_size = world_size * 2 240 | 241 | stage_files = [ 242 | sorted(map( 243 | lambda s: os.path.join(args.dataset_dir, s), 244 | filter(lambda _f: kind in _f, files), 245 | )) for kind in ["dense", "sparse", "label"] 246 | ] 247 | 248 | dataloader = DataLoader( 249 | InMemoryAvazuIterDataPipe(*stage_files, 250 | batch_size=args.batch_size, 251 | rank=rank, 252 | world_size=world_size, 253 | shuffle_batches=args.shuffle_batches, 254 | hashes=args.num_embeddings_per_feature, 255 | assigned_tables=assigned_tables), 256 | batch_size=None, 257 | pin_memory=args.pin_memory, 258 | collate_fn=lambda x: x, 259 | ) 260 | 261 | return dataloader 262 | 263 | 264 | def get_id_freq_map(path): 265 | files = os.listdir(path) 266 | files = list(filter(lambda s: "sparse" in s, files)) 267 | files = [os.path.join(path, _f) for _f in files] 268 | 269 | feature_count = GlobalFeatureCounter(files, list(map(int, NUM_EMBEDDINGS_PER_FEATURE.split(',')))) 270 | id_freq_map = 
torch.from_numpy(feature_count.compute()) 271 | return id_freq_map 272 | -------------------------------------------------------------------------------- /recsys/datasets/feature_counter.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import random 3 | from tqdm import tqdm 4 | 5 | import numpy as np 6 | from .criteo import DEFAULT_CAT_NAMES 7 | from petastorm import make_batch_reader 8 | from pyarrow.parquet import ParquetDataset 9 | 10 | 11 | class GlobalFeatureCounter: 12 | """ 13 | compute the global statistics of the whole training set 14 | """ 15 | 16 | def __init__(self, datafiles, hash_sizes): 17 | self.datafiles = datafiles 18 | self.hash_sizes = np.array(hash_sizes).reshape(1, -1) 19 | self.offsets = np.array([0, *np.cumsum(hash_sizes)[:-1]]).reshape(1, -1) 20 | 21 | def compute(self): 22 | id_freq_map = np.zeros(self.hash_sizes.sum(), dtype=np.int64) 23 | for _f in self.datafiles: 24 | arr = np.load(_f) 25 | arr %= self.hash_sizes 26 | arr += self.offsets 27 | flattened = arr.reshape(-1) 28 | id_freq_map += np.bincount(flattened, minlength=self.hash_sizes.sum()) 29 | return id_freq_map 30 | 31 | class PetastormCounter: 32 | 33 | def __init__(self, datafiles, hash_sizes, subsample_fraction=0.2, seed=1024): 34 | self.datafiles = datafiles 35 | self.total_features = sum(hash_sizes) 36 | 37 | self.offsets = np.array([0, *np.cumsum(hash_sizes)[:-1]]).reshape(1, -1) 38 | self.subsample_fraction = subsample_fraction 39 | self.seed = seed 40 | 41 | def compute(self): 42 | _id_freq_map = np.zeros(self.total_features, dtype=np.int64) 43 | 44 | files = self.datafiles 45 | random.seed(self.seed) 46 | random.shuffle(files) 47 | if 0. < self.subsample_fraction < 1.: 48 | files = files[:int(np.ceil(len(files) * self.subsample_fraction))] 49 | 50 | dataset = ParquetDataset(files, use_legacy_dataset=False) 51 | with make_batch_reader(list(map(lambda x: "file://" + x, dataset.files)), num_epochs=1) as reader: 52 | for batch in tqdm(reader, 53 | ncols=0, 54 | desc="Collecting id-freq map", 55 | total=sum([fragment.metadata.num_row_groups for fragment in dataset.fragments])): 56 | sparse = np.concatenate([getattr(batch, col_name).reshape(-1, 1) for col_name in DEFAULT_CAT_NAMES], 57 | axis=1) 58 | sparse = (sparse + self.offsets).reshape(-1) 59 | _id_freq_map += np.bincount(sparse, minlength=self.total_features) 60 | return _id_freq_map 61 | -------------------------------------------------------------------------------- /recsys/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.distributed as dist 4 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor 5 | from torchrec.datasets.utils import Batch 6 | 7 | 8 | class KJTAllToAll: 9 | """ 10 | Different from the module defined in torchrec. 11 | 12 | Basically, this class conducts all_gather with all_to_all collective. 
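
    Each rank replicates its local lengths and values to every peer via
    dist.all_to_all, so that afterwards every rank can rebuild the full,
    global KeyedJaggedTensor batch from the gathered pieces.

    Usage sketch (assumes torch.distributed is already initialized):

        collector = KJTAllToAll(dist.group.WORLD)
        global_kjt = collector.all_to_all(local_kjt)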
13 | """ 14 | 15 | def __init__(self, group): 16 | self.group = group 17 | self.rank = dist.get_rank(group) 18 | self.world_size = dist.get_world_size(group) 19 | 20 | @torch.no_grad() 21 | def all_to_all(self, kjt): 22 | if self.world_size == 1: 23 | return kjt 24 | # TODO: add sample weights 25 | values, lengths = kjt.values(), kjt.lengths() 26 | keys, batch_size = kjt.keys(), kjt.stride() 27 | 28 | # collect global data 29 | length_list = [lengths if i == self.rank else lengths.clone() for i in range(self.world_size)] 30 | all_length_list = [torch.empty_like(lengths) for _ in range(self.world_size)] 31 | dist.all_to_all(all_length_list, length_list, group=self.group) 32 | 33 | intermediate_all_length_list = [_length.view(-1, batch_size) for _length in all_length_list] 34 | all_length_per_key_list = [torch.sum(_length, dim=1).cpu().tolist() for _length in intermediate_all_length_list] 35 | 36 | all_value_length = [torch.sum(each).item() for each in all_length_list] 37 | value_list = [values if i == self.rank else values.clone() for i in range(self.world_size)] 38 | all_value_list = [ 39 | torch.empty(_length, dtype=values.dtype, device=values.device) for _length in all_value_length 40 | ] 41 | dist.all_to_all(all_value_list, value_list, group=self.group) 42 | 43 | all_value_list = [ 44 | torch.split(_values, _length_per_key) # [ key size, variable value size ] 45 | for _values, _length_per_key in zip(all_value_list, all_length_per_key_list) # world size 46 | ] 47 | all_values = torch.cat([torch.cat(values_per_key) for values_per_key in zip(*all_value_list)]) 48 | 49 | all_lengths = torch.cat(intermediate_all_length_list, dim=1).view(-1) 50 | return KeyedJaggedTensor.from_lengths_sync( 51 | keys=keys, 52 | values=all_values, 53 | lengths=all_lengths, 54 | ) 55 | 56 | 57 | class KJTTransform: 58 | 59 | def __init__(self, dataloader, hashes=None): 60 | self.batch_size = dataloader.batch_size 61 | self.cats = dataloader.cat_names 62 | self.conts = dataloader.cont_names 63 | self.labels = dataloader.label_names 64 | self.sparse_offset = torch.tensor( 65 | [0, *np.cumsum(hashes)[:-1]], dtype=torch.long, device=torch.cuda.current_device()).view(1, -1) \ 66 | if hashes is not None else None 67 | 68 | _num_ids_in_batch = len(self.cats) * self.batch_size 69 | self.lengths = torch.ones((_num_ids_in_batch,), dtype=torch.int32) 70 | self.offsets = torch.arange(0, _num_ids_in_batch + 1, dtype=torch.int32) 71 | self.length_per_key = len(self.cats) * [self.batch_size] 72 | self.offset_per_key = [self.batch_size * i for i in range(len(self.cats) + 1)] 73 | self.index_per_key = {key: i for (i, key) in enumerate(self.cats)} 74 | 75 | def transform(self, batch): 76 | sparse, dense = [], [] 77 | for col in self.cats: 78 | sparse.append(batch[0][col]) 79 | sparse = torch.cat(sparse, dim=1) 80 | if self.sparse_offset is not None: 81 | sparse += self.sparse_offset 82 | for col in self.conts: 83 | dense.append(batch[0][col]) 84 | dense = torch.cat(dense, dim=1) 85 | 86 | return Batch( 87 | dense_features=dense, 88 | sparse_features=KeyedJaggedTensor( 89 | keys=self.cats, 90 | values=sparse.transpose(1, 0).contiguous().view(-1), 91 | lengths=self.lengths, 92 | offsets=self.offsets, 93 | stride=self.batch_size, 94 | length_per_key=self.length_per_key, 95 | offset_per_key=self.offset_per_key, 96 | index_per_key=self.index_per_key, 97 | ), 98 | labels=batch[1], 99 | ) 100 | -------------------------------------------------------------------------------- /recsys/models/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpcaitech/CachedEmbedding/a2af3d7e7b0197519e6d018444688fcd9ba32c43/recsys/models/__init__.py -------------------------------------------------------------------------------- /recsys/models/dlrm.py: -------------------------------------------------------------------------------- 1 | # The infrastructures of DLRM are mainly inspired by TorchRec: 2 | # https://github.com/pytorch/torchrec/blob/main/torchrec/models/dlrm.py 3 | import os 4 | import torch 5 | from contextlib import nullcontext 6 | import torch.nn as nn 7 | from torch.nn.parallel import DistributedDataParallel as DDP 8 | from torch.profiler import record_function 9 | from typing import List 10 | from baselines.models.dlrm import DenseArch, OverArch, InteractionArch, choose 11 | from ..utils import get_time_elapsed 12 | from ..datasets.utils import KJTAllToAll 13 | from ..utils import prepare_tablewise_config 14 | import colossalai 15 | from colossalai.nn.parallel.layers import ParallelCachedEmbeddingBag, EvictionStrategy, \ 16 | TablewiseEmbeddingBagConfig, ParallelCachedEmbeddingBagTablewise 17 | from colossalai.core import global_context as gpc 18 | from colossalai.context.parallel_mode import ParallelMode 19 | import numpy as np 20 | from typing import Union 21 | from torchrec import KeyedJaggedTensor 22 | 23 | dist_logger = colossalai.logging.get_dist_logger() 24 | 25 | 26 | def sparse_embedding_shape_hook(embeddings, feature_size, batch_size): 27 | return embeddings.view(feature_size, batch_size, -1).transpose(0, 1) 28 | 29 | def sparse_embedding_shape_hook_for_tablewise(embeddings, feature_size, batch_size): 30 | return embeddings.view(embeddings.shape[0], feature_size, -1) 31 | 32 | class FusedSparseModules(nn.Module): 33 | 34 | def __init__(self, 35 | num_embeddings_per_feature, 36 | embedding_dim, 37 | fused_op='all_to_all', 38 | reduction_mode='sum', 39 | sparse=False, 40 | output_device_type=None, 41 | use_cache=False, 42 | cache_ratio=0.01, 43 | id_freq_map=None, 44 | warmup_ratio=0.7, 45 | buffer_size=50_000, 46 | is_dist_dataloader=True, 47 | use_lfu_eviction=False, 48 | use_tablewise_parallel=False, 49 | dataset: str = None): 50 | super(FusedSparseModules, self).__init__() 51 | self.sparse_feature_num = len(num_embeddings_per_feature) 52 | if use_cache: 53 | if use_tablewise_parallel: 54 | # establist config list 55 | world_size = torch.distributed.get_world_size() 56 | embedding_bag_config_list = prepare_tablewise_config( 57 | num_embeddings_per_feature, 0.01, id_freq_map, dataset, world_size) 58 | self.embed = ParallelCachedEmbeddingBagTablewise( 59 | embedding_bag_config_list, 60 | embedding_dim, 61 | sparse=sparse, 62 | mode=reduction_mode, 63 | include_last_offset=True, 64 | warmup_ratio=warmup_ratio, 65 | buffer_size=buffer_size, 66 | evict_strategy=EvictionStrategy.LFU if use_lfu_eviction else EvictionStrategy.DATASET, 67 | ) 68 | self.shape_hook = sparse_embedding_shape_hook_for_tablewise 69 | else: 70 | self.embed = ParallelCachedEmbeddingBag( 71 | sum(num_embeddings_per_feature), 72 | embedding_dim, 73 | sparse=sparse, 74 | mode=reduction_mode, 75 | include_last_offset=True, 76 | cache_ratio=cache_ratio, 77 | ids_freq_mapping=id_freq_map, 78 | warmup_ratio=warmup_ratio, 79 | buffer_size=buffer_size, 80 | evict_strategy=EvictionStrategy.LFU if use_lfu_eviction else EvictionStrategy.DATASET 81 | ) 82 | self.shape_hook = sparse_embedding_shape_hook 83 | else: 84 | raise NotImplementedError("Other 
EmbeddingBags are under development") 85 | 86 | if is_dist_dataloader: 87 | self.kjt_collector = KJTAllToAll(gpc.get_group(ParallelMode.GLOBAL)) 88 | else: 89 | self.kjt_collector = None 90 | 91 | def forward(self, sparse_features : Union[List, KeyedJaggedTensor], cache_op: bool = True): 92 | self.embed.set_cache_op(cache_op) 93 | if self.kjt_collector: 94 | with record_function("(zhg)KJT AllToAll collective"): 95 | sparse_features = self.kjt_collector.all_to_all(sparse_features) 96 | 97 | if isinstance(sparse_features, list): 98 | batch_size = sparse_features[2] 99 | flattened_sparse_embeddings = self.embed( 100 | sparse_features[0], 101 | sparse_features[1], 102 | shape_hook=lambda x: self.shape_hook(x, self.sparse_feature_num , batch_size), 103 | ) 104 | elif isinstance(sparse_features, KeyedJaggedTensor): 105 | batch_size = sparse_features.stride() 106 | flattened_sparse_embeddings = self.embed( 107 | sparse_features.values(), 108 | sparse_features.offsets(), 109 | shape_hook=lambda x: self.shape_hook(x, self.sparse_feature_num , batch_size), 110 | ) 111 | else: 112 | raise TypeError 113 | return flattened_sparse_embeddings 114 | 115 | 116 | class FusedDenseModules(nn.Module): 117 | """ 118 | Fusing dense operations of DLRM into a single module 119 | """ 120 | 121 | def __init__(self, embedding_dim, num_sparse_features, dense_in_features, dense_arch_layer_sizes, 122 | over_arch_layer_sizes): 123 | super(FusedDenseModules, self).__init__() 124 | if dense_in_features <= 0: 125 | self.dense_arch = nn.Identity() 126 | over_in_features = choose(num_sparse_features, 2) 127 | num_dense = 0 128 | else: 129 | self.dense_arch = DenseArch(in_features=dense_in_features, layer_sizes=dense_arch_layer_sizes) 130 | over_in_features = (embedding_dim + choose(num_sparse_features, 2) + num_sparse_features) 131 | num_dense = 1 132 | 133 | self.inter_arch = InteractionArch(num_sparse_features=num_sparse_features, num_dense_features=num_dense) 134 | self.over_arch = OverArch(in_features=over_in_features, layer_sizes=over_arch_layer_sizes) 135 | 136 | def forward(self, dense_features, embedded_sparse_features): 137 | embedded_dense_features = self.dense_arch(dense_features) 138 | concat_dense = self.inter_arch(dense_features=embedded_dense_features, sparse_features=embedded_sparse_features) 139 | logits = self.over_arch(concat_dense) 140 | 141 | return logits 142 | 143 | 144 | class HybridParallelDLRM(nn.Module): 145 | """ 146 | Model parallelized Embedding followed by Data parallelized dense modules 147 | """ 148 | 149 | def __init__(self, 150 | num_embeddings_per_feature, 151 | embedding_dim, 152 | num_sparse_features, 153 | dense_in_features, 154 | dense_arch_layer_sizes, 155 | over_arch_layer_sizes, 156 | dense_device, 157 | sparse_device, 158 | sparse=False, 159 | fused_op='all_to_all', 160 | use_cache=False, 161 | cache_ratio=0.01, 162 | id_freq_map=None, 163 | warmup_ratio=0.7, 164 | buffer_size=50_000, 165 | is_dist_dataloader=True, 166 | use_lfu_eviction=False, 167 | use_tablewise=False, 168 | dataset: str = None): 169 | 170 | super(HybridParallelDLRM, self).__init__() 171 | if use_cache and sparse_device.type != dense_device.type: 172 | raise ValueError(f"Sparse device must be the same as dense device, " 173 | f"however we got {sparse_device.type} for sparse, {dense_device.type} for dense") 174 | 175 | self.dense_device = dense_device 176 | self.sparse_device = sparse_device 177 | 178 | self.sparse_modules = FusedSparseModules(num_embeddings_per_feature, 179 | embedding_dim, 180 | 
fused_op=fused_op, 181 | sparse=sparse, 182 | output_device_type=dense_device.type, 183 | use_cache=use_cache, 184 | cache_ratio=cache_ratio, 185 | id_freq_map=id_freq_map, 186 | warmup_ratio=warmup_ratio, 187 | buffer_size=buffer_size, 188 | is_dist_dataloader=is_dist_dataloader, 189 | use_lfu_eviction=use_lfu_eviction, 190 | use_tablewise_parallel=use_tablewise, 191 | dataset=dataset 192 | ).to(sparse_device) 193 | self.dense_modules = DDP(module=FusedDenseModules(embedding_dim, num_sparse_features, dense_in_features, 194 | dense_arch_layer_sizes, 195 | over_arch_layer_sizes).to(dense_device), 196 | device_ids=[0 if os.environ.get("NVT_TAG", None) else gpc.get_global_rank()], 197 | process_group=gpc.get_group(ParallelMode.GLOBAL), 198 | gradient_as_bucket_view=True, 199 | broadcast_buffers=False, 200 | static_graph=True) 201 | 202 | # precompute for parallelized embedding 203 | param_amount = sum(num_embeddings_per_feature) * embedding_dim 204 | param_storage = self.sparse_modules.embed.element_size() * param_amount 205 | param_amount += sum(p.numel() for p in self.dense_modules.parameters()) 206 | param_storage += sum(p.numel() * p.element_size() for p in self.dense_modules.parameters()) 207 | # 208 | buffer_amount = sum(b.numel() for b in self.sparse_modules.buffers()) + \ 209 | sum(b.numel() for b in self.dense_modules.buffers()) 210 | buffer_storage = sum(b.numel() * b.element_size() for b in self.sparse_modules.buffers()) + \ 211 | sum(b.numel() * b.element_size() for b in self.dense_modules.buffers()) 212 | stat_str = f"Number of model parameters: {param_amount:,}, storage overhead: {param_storage/1024**3:.2f} GB. " \ 213 | f"Number of model buffers: {buffer_amount:,}, storage overhead: {buffer_storage/1024**3:.2f} GB." 214 | self.stat_str = stat_str 215 | 216 | def forward(self, dense_features, sparse_features, inspect_time=False, 217 | cache_op = True): 218 | ctx1 = get_time_elapsed(dist_logger, "embedding lookup in forward pass") \ 219 | if inspect_time else nullcontext() 220 | with ctx1: 221 | with record_function("Embedding lookup:"): 222 | # B // world size, sparse feature dim, embedding dim 223 | embedded_sparse = self.sparse_modules(sparse_features, cache_op) 224 | 225 | ctx2 = get_time_elapsed(dist_logger, "dense operations in forward pass") \ 226 | if inspect_time else nullcontext() 227 | with ctx2: 228 | with record_function("Dense operations:"): 229 | # B // world size, 1 230 | logits = self.dense_modules(dense_features, embedded_sparse) 231 | 232 | return logits 233 | 234 | def model_stats(self, prefix=""): 235 | return f"{prefix}: {self.stat_str}" 236 | -------------------------------------------------------------------------------- /recsys/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc import get_mem_info, compute_throughput, get_time_elapsed, Timer, get_partition, \ 2 | TrainValTestResults, count_parameters, prepare_tablewise_config, get_tablewise_rank_arrange 3 | from .dataloader import CudaStreamDataIter, FiniteDataIter 4 | 5 | __all__ = [ 6 | 'get_mem_info', 'compute_throughput', 'get_time_elapsed', 'Timer', 'get_partition', 'CudaStreamDataIter', 7 | 'FiniteDataIter', 'TrainValTestResults', 'count_parameters', 'prepare_tablewise_config', 8 | 'get_tablewise_rank_arrange' 9 | ] 10 | -------------------------------------------------------------------------------- /recsys/utils/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.cuda_stream_dataloader import CudaStreamDataIter, FiniteDataIter 2 | 3 | __all__ = ['CudaStreamDataIter', 'FiniteDataIter'] 4 | -------------------------------------------------------------------------------- /recsys/utils/dataloader/base_dataiter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | 4 | from abc import ABC, abstractmethod 5 | from typing import Optional 6 | import torch 7 | from torch.utils.data import DataLoader 8 | 9 | 10 | class BaseStreamDataIter(ABC): 11 | 12 | def __init__(self, loader: DataLoader): 13 | self.loader = loader 14 | self.iter = iter(loader) 15 | self.stream = torch.cuda.Stream() 16 | self._preload() 17 | 18 | @staticmethod 19 | def _move_tensor(element): 20 | if torch.is_tensor(element): 21 | if not element.is_cuda: 22 | return element.cuda(non_blocking=True) 23 | return element 24 | 25 | @staticmethod 26 | def _record_tensor(element, stream: torch.cuda.Stream) -> None: 27 | if torch.is_tensor(element): 28 | element.record_stream(stream) 29 | 30 | def record_stream(self, data, stream: Optional[torch.cuda.Stream] = None) -> None: 31 | if stream is None: 32 | stream = torch.cuda.current_stream() 33 | 34 | if isinstance(data, torch.Tensor): 35 | data.record_stream(stream) 36 | elif isinstance(data, (list, tuple)): 37 | for element in data: 38 | if isinstance(element, dict): 39 | for _k, v in element.items(): 40 | self._record_tensor(v, stream) 41 | else: 42 | self._record_tensor(element, stream) 43 | elif isinstance(data, dict): 44 | for _k, v in data.items(): 45 | self._record_tensor(v, stream) 46 | else: 47 | raise TypeError( 48 | f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") 49 | 50 | def to_cuda(self, data): 51 | if isinstance(data, torch.Tensor): 52 | data = data.cuda(non_blocking=True) 53 | elif isinstance(data, (list, tuple)): 54 | data_to_return = [] 55 | for element in data: 56 | if isinstance(element, dict): 57 | data_to_return.append({k: self._move_tensor(v) for k, v in element.items()}) 58 | else: 59 | data_to_return.append(self._move_tensor(element)) 60 | data = data_to_return 61 | elif isinstance(data, dict): 62 | data = {k: self._move_tensor(v) for k, v in data.items()} 63 | else: 64 | raise TypeError( 65 | f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") 66 | return data 67 | 68 | @abstractmethod 69 | def _preload(self): 70 | pass 71 | 72 | @abstractmethod 73 | def _reset(self): 74 | pass 75 | 76 | @abstractmethod 77 | def __next__(self): 78 | pass 79 | 80 | @abstractmethod 81 | def __iter__(self): 82 | pass 83 | 84 | -------------------------------------------------------------------------------- /recsys/utils/dataloader/cuda_stream_dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | from typing import TypeVar, Iterator 4 | 5 | import torch 6 | from torch.utils.data import Sampler, Dataset, DataLoader 7 | 8 | from .base_dataiter import BaseStreamDataIter 9 | 10 | 11 | class CudaStreamDataIter(BaseStreamDataIter): 12 | """ 13 | A data iterator that supports batch prefetching with the help of cuda stream. 
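
    The next batch is copied host-to-device on a dedicated side stream while
    the current batch is being consumed; __next__ synchronizes via
    wait_stream and extends tensor lifetimes across streams with
    record_stream. Note that this iterator restarts the underlying loader
    once it is exhausted and thus never raises StopIteration; use
    FiniteDataIter below for single-epoch iteration.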
14 | """ 15 | 16 | def __init__(self, loader: DataLoader): 17 | super().__init__(loader) 18 | 19 | def _preload(self): 20 | try: 21 | self.batch_data = next(self.iter) 22 | 23 | except StopIteration: 24 | self.batch_data = None 25 | self._reset() 26 | return 27 | 28 | with torch.cuda.stream(self.stream): 29 | self.batch_data = self.to_cuda(self.batch_data) 30 | 31 | def _reset(self): 32 | self.iter = iter(self.loader) 33 | self.stream = torch.cuda.Stream() 34 | self._preload() 35 | 36 | def __next__(self): 37 | torch.cuda.current_stream().wait_stream(self.stream) 38 | batch_data = self.batch_data 39 | 40 | if batch_data is not None: 41 | self.record_stream(batch_data, torch.cuda.current_stream()) 42 | 43 | self._preload() 44 | return batch_data 45 | 46 | def __iter__(self): 47 | return self 48 | 49 | 50 | class FiniteDataIter(BaseStreamDataIter): 51 | 52 | def _reset(self): 53 | self.iter = iter(self.loader) 54 | self.stream = torch.cuda.Stream() 55 | self._preload() 56 | 57 | def __init__(self, data_loader): 58 | super(FiniteDataIter, self).__init__(data_loader) 59 | 60 | def _preload(self): 61 | try: 62 | self.batch_data = next(self.iter) 63 | 64 | with torch.cuda.stream(self.stream): 65 | self.batch_data = self.batch_data.to(torch.cuda.current_device(), non_blocking=True) 66 | 67 | except StopIteration: 68 | self.batch_data = None 69 | 70 | def __next__(self): 71 | torch.cuda.current_stream().wait_stream(self.stream) 72 | batch_data = self.batch_data 73 | if batch_data is not None: 74 | batch_data.record_stream(torch.cuda.current_stream()) 75 | else: 76 | raise StopIteration() 77 | 78 | self._preload() 79 | return batch_data 80 | 81 | def __iter__(self): 82 | return self 83 | -------------------------------------------------------------------------------- /recsys/utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import psutil 3 | from contextlib import contextmanager 4 | import time 5 | from time import perf_counter 6 | from dataclasses import dataclass, field 7 | from typing import List, Optional 8 | from colossalai.nn.parallel.layers import TablewiseEmbeddingBagConfig 9 | 10 | import numpy as np 11 | 12 | @dataclass 13 | class TrainValTestResults: 14 | val_accuracies: List[float] = field(default_factory=list) 15 | val_aurocs: List[float] = field(default_factory=list) 16 | test_accuracy: Optional[float] = None 17 | test_auroc: Optional[float] = None 18 | 19 | def count_parameters(model, prefix=''): 20 | trainable_param = sum(p.numel() for p in model.parameters() if p.requires_grad) 21 | param_amount = sum(p.numel() for p in model.parameters()) 22 | buffer_amount = sum(b.numel() for b in model.buffers()) 23 | param_storage = sum([p.numel() * p.element_size() for p in model.parameters()]) 24 | buffer_storage = sum([b.numel() * b.element_size() for b in model.buffers()]) 25 | stats_str = f'{prefix}: {trainable_param:,}.' + '\n' 26 | stats_str += f"Number of model parameters: {param_amount:,}, storage overhead: {param_storage/1024**3:.2f} GB. " 27 | stats_str += f"Number of model buffers: {buffer_amount:,}, storage overhead: {buffer_storage/1024**3:.2f} GB." 
28 | return stats_str 29 | 30 | 31 | def get_mem_info(prefix=''): 32 | return f'{prefix}GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB, ' \ 33 | f'GPU memory reserved: {torch.cuda.memory_reserved() /1024**3:.2f} GB, ' \ 34 | f'CPU memory usage: {psutil.Process().memory_info().rss / 1024**3:.2f} GB' 35 | 36 | 37 | @contextmanager 38 | def compute_throughput(batch_size) -> float: 39 | start = perf_counter() 40 | yield lambda: batch_size / ((perf_counter() - start) * 1000) 41 | 42 | 43 | @contextmanager 44 | def get_time_elapsed(logger, repr: str): 45 | timer = Timer() 46 | timer.start() 47 | yield 48 | elapsed = timer.stop() 49 | logger.info(f"Time elapsed for {repr}: {elapsed:.4f}s", ranks=[0]) 50 | 51 | 52 | class Timer: 53 | """A timer object which helps to log the execution times, and provides different tools to assess the times. 54 | """ 55 | 56 | def __init__(self): 57 | self._started = False 58 | self._start_time = time.time() 59 | self._elapsed = 0 60 | self._history = [] 61 | 62 | @property 63 | def has_history(self): 64 | return len(self._history) != 0 65 | 66 | @property 67 | def current_time(self) -> float: 68 | torch.cuda.synchronize() 69 | return time.time() 70 | 71 | def start(self): 72 | """Firstly synchronize cuda, reset the clock and then start the timer. 73 | """ 74 | self._elapsed = 0 75 | torch.cuda.synchronize() 76 | self._start_time = time.time() 77 | self._started = True 78 | 79 | def lap(self): 80 | """lap time and return elapsed time 81 | """ 82 | return self.current_time - self._start_time 83 | 84 | def stop(self, keep_in_history: bool = False): 85 | """Stop the timer and record the start-stop time interval. 86 | 87 | Args: 88 | keep_in_history (bool, optional): Whether does it record into history 89 | each start-stop interval, defaults to False. 90 | Returns: 91 | int: Start-stop interval. 92 | """ 93 | torch.cuda.synchronize() 94 | end_time = time.time() 95 | elapsed = end_time - self._start_time 96 | if keep_in_history: 97 | self._history.append(elapsed) 98 | self._elapsed = elapsed 99 | self._started = False 100 | return elapsed 101 | 102 | def get_history_mean(self): 103 | """Mean of all history start-stop time intervals. 104 | 105 | Returns: 106 | int: Mean of time intervals 107 | """ 108 | return sum(self._history) / len(self._history) 109 | 110 | def get_history_sum(self): 111 | """Add up all the start-stop time intervals. 112 | 113 | Returns: 114 | int: Sum of time intervals. 115 | """ 116 | return sum(self._history) 117 | 118 | def get_elapsed_time(self): 119 | """Return the last start-stop time interval. 120 | 121 | Returns: 122 | int: The last time interval. 
123 | 124 | Note: 125 | Use it only when timer is not in progress 126 | """ 127 | assert not self._started, 'Timer is still in progress' 128 | return self._elapsed 129 | 130 | def reset(self): 131 | """Clear up the timer and its history 132 | """ 133 | self._history = [] 134 | self._started = False 135 | self._elapsed = 0 136 | 137 | 138 | def get_partition(embedding_dim, rank, world_size): 139 | if world_size == 1: 140 | return 0, embedding_dim, True 141 | 142 | assert embedding_dim >= world_size, \ 143 | f"Embedding dimension {embedding_dim} must be larger than the world size " \ 144 | f"{world_size} of the process group" 145 | chunk_size = embedding_dim // world_size 146 | threshold = embedding_dim % world_size 147 | # if embedding dim is divisible by world size 148 | if threshold == 0: 149 | return rank * chunk_size, (rank + 1) * chunk_size, True 150 | 151 | # align with the split strategy of torch.tensor_split 152 | size_list = [chunk_size + 1 if i < threshold else chunk_size for i in range(world_size)] 153 | offset = sum(size_list[:rank]) 154 | return offset, offset + size_list[rank], False 155 | 156 | 157 | def prepare_tablewise_config(num_embeddings_per_feature, 158 | cache_ratio, 159 | id_freq_map_total=None, 160 | dataset="criteo_kaggle", 161 | world_size=2): 162 | # WARNING, prototype. only support criteo_kaggle dataset and world_size == 2, 4 163 | # TODO: automatic arrange 164 | embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [] 165 | rank_arrange = get_tablewise_rank_arrange(dataset, world_size) 166 | table_offsets = np.array([0, *np.cumsum(num_embeddings_per_feature)]) 167 | for i, num_embeddings in enumerate(num_embeddings_per_feature): 168 | ids_freq_mapping = None 169 | if id_freq_map_total != None: 170 | ids_freq_mapping = id_freq_map_total[table_offsets[i] : table_offsets[i + 1]] 171 | cuda_row_num = int(cache_ratio * num_embeddings) + 2000 172 | if cuda_row_num > num_embeddings: 173 | cuda_row_num = num_embeddings 174 | embedding_bag_config_list.append( 175 | TablewiseEmbeddingBagConfig( 176 | num_embeddings=num_embeddings, 177 | cuda_row_num=cuda_row_num, 178 | assigned_rank=rank_arrange[i], 179 | ids_freq_mapping=ids_freq_mapping 180 | ) 181 | ) 182 | return embedding_bag_config_list 183 | 184 | def get_tablewise_rank_arrange(dataset=None, world_size=0): 185 | if 'criteo' in dataset and 'kaggle' in dataset: 186 | if world_size == 1: 187 | rank_arrange = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 188 | elif world_size == 2: 189 | rank_arrange = [0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0] 190 | elif world_size == 3: 191 | rank_arrange = [2, 1, 0, 1, 1, 2, 2, 1, 0, 0, 1, 1, 0, 1, 0, 2, 0, 2, 2, 0, 2, 2, 0, 1, 1, 0] 192 | elif world_size == 4: 193 | rank_arrange = [3, 1, 0, 3, 1, 0, 2, 1, 0, 2, 3, 1, 3, 1, 2, 3, 1, 2, 3, 0, 2, 0, 0, 2, 3, 2] 194 | elif world_size == 8: 195 | rank_arrange = [6, 6, 0, 4, 7, 2, 5, 7, 0, 5, 7, 1, 7, 3, 5, 3, 1, 6, 6, 0, 2, 2, 1, 4, 3, 4] 196 | else : 197 | raise NotImplementedError("Other Tablewise settings are under development") 198 | elif 'criteo' in dataset: 199 | if world_size == 1: 200 | rank_arrange = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 201 | elif world_size == 2: 202 | rank_arrange = [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0] 203 | elif world_size == 4: 204 | rank_arrange = [1, 3, 3, 3, 3, 0, 2, 2, 1, 2, 2, 2, 0, 1, 2, 1, 0, 1, 0, 0, 2, 3, 3, 3, 1, 0] 205 | else : 206 | 
raise NotImplementedError("Other Tablewise settings are under development") 207 | else: 208 | raise NotImplementedError("Other Tablewise settings are under development") 209 | return rank_arrange -------------------------------------------------------------------------------- /recsys/utils/preprocess_synth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # a = torch.tensor([1, 3, 5, 7, 14, 3, 40, 23]) 3 | # b, c = torch.unique(a, sorted=True, return_inverse=True) 4 | # print(b) 5 | # print(c) 6 | 7 | BATCH_SIZE = 65536 8 | TABLLE_NUM = 856 9 | FILE_LIST = [f"/data/scratch/RecSys/embedding_bag/fbgemm_t856_bs65536_{i}.pt" for i in range( 10 | 16)] + ["/data/scratch/RecSys/embedding_bag/fbgemm_t856_bs65536.pt"] 11 | KEYS = [] 12 | for i in range(TABLLE_NUM): 13 | KEYS.append("table_{}".format(i)) 14 | 15 | CHOSEN_TABLES = [i for i in range(0, 856)] 16 | 17 | def load_file(file_path, cuda=True): 18 | indices, offsets, lengths = torch.load(file_path) 19 | if cuda: 20 | indices = indices.int().cuda() 21 | offsets = offsets.int().cuda() 22 | lengths = lengths.int().cuda() 23 | else : 24 | indices = indices.int() 25 | offsets = offsets.int() 26 | lengths = lengths.int() 27 | indices_per_table = [] 28 | for i in range(TABLLE_NUM): 29 | if i not in CHOSEN_TABLES: 30 | continue 31 | start_pos = offsets[i * BATCH_SIZE] 32 | end_pos = offsets[i * BATCH_SIZE + BATCH_SIZE] 33 | part = indices[start_pos:end_pos] 34 | indices_per_table.append(part) 35 | return indices_per_table 36 | 37 | if __name__ == "__main__": 38 | indices_per_table_list= [] 39 | indices_per_table_length_list = [] 40 | for i, file in enumerate(FILE_LIST): 41 | indices_per_table = load_file(file, cuda=False) 42 | print("loaded ", file) 43 | for j, indices in enumerate(indices_per_table): 44 | if i == 0: 45 | indices_per_table_list.append([indices]) 46 | indices_per_table_length_list.append([indices.shape[0]]) 47 | else: 48 | indices_per_table_list[j].append(indices) 49 | indices_per_table_length_list[j].append(indices.shape[0]) 50 | 51 | # unique op for each table: 52 | for i, (indices_list, length_list) in enumerate(zip(indices_per_table_list, indices_per_table_length_list)): 53 | catted = torch.cat(indices_list) 54 | _, processed = torch.unique(catted, sorted=True, return_inverse=True) 55 | indices_per_table_list[i] = torch.split(processed, length_list) 56 | 57 | # save to each file: 58 | for i, file in enumerate(FILE_LIST): 59 | _, offsets, lengths = torch.load(file) 60 | indices_per_table = [indices_in_table[i] for indices_in_table in indices_per_table_list] 61 | reconcatenate = torch.cat(indices_per_table) 62 | torch.save((reconcatenate, offsets, lengths), 63 | f"/home/lccsr/data2/embedding_bag_processed/fbgemm_t856_bs65536_processed_{i}.pt") 64 | print("saved, ", i) 65 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpcaitech/CachedEmbedding/a2af3d7e7b0197519e6d018444688fcd9ba32c43/requirements.txt -------------------------------------------------------------------------------- /scripts/avazu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # For Colossalai enabled recsys 4 | # avazu 5 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p1_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \ 6 | # --dataset_dir /data/avazu_sample --pin_memory 
--shuffle_batches \ 7 | # --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \ 8 | # --profile_dir "tensorboard_log/avazu/w1_p1_16k" --buffer_size 0 --use_overlap --cache_sets 94458 9 | # 10 | torchx run -s local_cwd -cfg log_dir=log/avazu/w2_p1_16k dist.ddp -j 1x2 --script recsys/dlrm_main.py -- \ 11 | --dataset_dir /data/scratch/RecSys/avazu_sample --pin_memory --shuffle_batches \ 12 | --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \ 13 | --profile_dir "tensorboard_log/avazu/w2_p1_16k" --buffer_size 0 --use_overlap --cache_sets 94458 14 | 15 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w4_p1_16k dist.ddp -j 1x4 --script recsys/dlrm_main.py -- \ 16 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \ 17 | # --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \ 18 | # --profile_dir "tensorboard_log/avazu/w4_p1_16k" --buffer_size 0 --use_overlap --cache_sets 94458 19 | # 20 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w8_p1_16k dist.ddp -j 1x8 --script recsys/dlrm_main.py -- \ 21 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \ 22 | # --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \ 23 | # --profile_dir "tensorboard_log/avazu/w8_p1_16k" --buffer_size 0 --use_overlap --cache_sets 94458 24 | # 25 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p1_32k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \ 26 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \ 27 | # --learning_rate 1. --batch_size 32768 --use_sparse_embed_grad --use_cache --use_freq \ 28 | # --profile_dir "tensorboard_log/avazu/w1_p1_32k" --buffer_size 0 --use_overlap --cache_sets 94458 29 | # 30 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p1_8k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \ 31 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \ 32 | # --learning_rate 1. --batch_size 8192 --use_sparse_embed_grad --use_cache --use_freq \ 33 | # --profile_dir "tensorboard_log/avazu/w1_p1_8k" --buffer_size 0 --use_overlap --cache_sets 94458 34 | # 35 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p1_4k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \ 36 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \ 37 | # --learning_rate 1. --batch_size 4096 --use_sparse_embed_grad --use_cache --use_freq \ 38 | # --profile_dir "tensorboard_log/avazu/w1_p1_4k" --buffer_size 0 --use_overlap --cache_sets 94458 39 | # 40 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p1_2k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \ 41 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \ 42 | # --learning_rate 1. --batch_size 2048 --use_sparse_embed_grad --use_cache --use_freq \ 43 | # --profile_dir "tensorboard_log/avazu/w1_p1_2k" --buffer_size 0 --use_overlap --cache_sets 94458 44 | # 45 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p1_1k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \ 46 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \ 47 | # --learning_rate 1. 
--batch_size 1024 --use_sparse_embed_grad --use_cache --use_freq \ 48 | # --profile_dir "tensorboard_log/avazu/w1_p1_1k" --buffer_size 0 --use_overlap --cache_sets 94458 49 | # 50 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p10_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \ 51 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \ 52 | # --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \ 53 | # --profile_dir "tensorboard_log/avazu/w1_p10_16k" --buffer_size 0 --use_overlap --cache_sets 944582 54 | # 55 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p5_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \ 56 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \ 57 | # --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \ 58 | # --profile_dir "tensorboard_log/avazu/w1_p5_16k" --buffer_size 0 --use_overlap --cache_sets 472291 59 | # 60 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p2_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \ 61 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \ 62 | # --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \ 63 | # --profile_dir "tensorboard_log/avazu/w1_p2_16k" --buffer_size 0 --use_overlap --cache_sets 188916 64 | # 65 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p0_5_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \ 66 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \ 67 | # --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \ 68 | # --profile_dir "tensorboard_log/avazu/w1_p0_5_16k" --buffer_size 0 --use_overlap --cache_sets 47229 69 | # 70 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p0_1_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \ 71 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \ 72 | # --learning_rate 1. 
--batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \ 73 | # --profile_dir "tensorboard_log/avazu/w1_p0_1_16k" --buffer_size 0 --use_overlap --cache_sets 9445 74 | # -------------------------------------------------------------------------------- /scripts/kaggle.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # For Colossalai enabled recsys 4 | # criteo kaggle 5 | export DATAPATH=/data/scratch/RecSys/criteo_kaggle_data/ 6 | # export DATAPATH=/data/criteo_kaggle_data/ 7 | export GPUNUM=1 8 | # export BATCHSIZE=16384 9 | # export BATCHSIZE=1024 10 | export CACHESIZE=337625 11 | export CACHERATIO=0.01 12 | export USE_LFU=1 13 | export USE_TABLE_SHARD=0 14 | export EVAL_ACC=0 15 | export PREFETCH_NUM=2 16 | export USE_ASYNC=0 17 | 18 | if [[ ${USE_LFU} == 1 ]]; then 19 | LFU_FLAG="--use_lfu" 20 | else 21 | export LFU_FLAG="" 22 | fi 23 | 24 | 25 | if [[ ${USE_TABLE_SHARD} == 1 ]]; then 26 | TABLE_SHARD_FLAG="--use_tablewise" 27 | else 28 | export TABLE_SHARD_FLAG="" 29 | fi 30 | 31 | if [[ ${EVAL_ACC} == 1 ]]; then 32 | EVAL_ACC_FLAG="--eval_acc" 33 | else 34 | export EVAL_ACC_FLAG="" 35 | fi 36 | 37 | if [[ ${USE_ASYNC} == 1 ]]; then 38 | ASYNC_FLAG="--use_cache_mgr_async_copy" 39 | else 40 | export ASYNC_FLAG="" 41 | fi 42 | 43 | set_n_least_used_CUDA_VISIBLE_DEVICES() { 44 | local n=${1:-"9999"} 45 | echo "GPU Memory Usage:" 46 | local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ 47 | | tail -n +2 \ 48 | | nl -v 0 \ 49 | | tee /dev/tty \ 50 | | sort -g -k 2 \ 51 | | awk '{print $1}' \ 52 | | head -n $n) 53 | export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') 54 | echo "Now CUDA_VISIBLE_DEVICES is set to:" 55 | echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" 56 | } 57 | 58 | 59 | mkdir -p colo_logs 60 | for BATCHSIZE in 16384 #16384 #1024 2048 4096 16384 61 | do 62 | for PREFETCH_NUM in 1 #1 16 63 | do 64 | for USE_ASYNC in 0 65 | do 66 | set_n_least_used_CUDA_VISIBLE_DEVICES ${GPUNUM} 67 | 68 | TASK_NAME="gpu_${GPUNUM}_bs_${BATCHSIZE}_pf_${PREFETCH_NUM}_cache_${CACHERATIO}_async_${USE_ASYNC}" 69 | torchx run -s local_cwd -cfg log_dir=log/kaggle/${TASK_NAME} dist.ddp -j 1x${GPUNUM} --script recsys/dlrm_main.py -- \ 70 | --dataset_dir ${DATAPATH} --pin_memory --shuffle_batches \ 71 | --learning_rate 1. --batch_size ${BATCHSIZE} --use_sparse_embed_grad --use_cache --use_freq ${LFU_FLAG} ${TABLE_SHARD_FLAG} ${EVAL_ACC_FLAG} ${ASYNC_FLAG} \ 72 | --profile_dir "tensorboard_log/kaggle/${TASK_NAME}" --buffer_size 0 --use_overlap --cache_ratio ${CACHERATIO} --prefetch_num ${PREFETCH_NUM} 2>&1 | tee colo_logs/colo_${TASK_NAME}.txt 73 | done 74 | done 75 | done 76 | 77 | # torchx run -s local_cwd -cfg log_dir=log/kaggle/w${GPUNUM}_p1_16k dist.ddp -j 1x${GPUNUM} --script recsys/dlrm_main.py -- \ 78 | # --dataset_dir ${DATAPATH} --pin_memory --shuffle_batches \ 79 | # --learning_rate 1. 
--batch_size ${BATCHSIZE} --use_sparse_embed_grad --use_cache --use_freq --use_lfu \ 80 | # --profile_dir "tensorboard_log/kaggle/w${GPUNUM}_p1_16k" --buffer_size 0 --use_overlap --cache_sets ${CACHESIZE} 2>&1 | tee logs/colo_${GPUNUM}_${BATCHSIZE}_${CACHESIZE}.txt 81 | -------------------------------------------------------------------------------- /scripts/preprocess/.gitignore: -------------------------------------------------------------------------------- 1 | /* 2 | !/.gitignore 3 | !/npy_preproc_criteo.py 4 | !/split_criteo_kaggle.py 5 | !/npy_preproc_avazu.py 6 | !/taobao 7 | !README.md 8 | -------------------------------------------------------------------------------- /scripts/preprocess/README.md: -------------------------------------------------------------------------------- 1 | # Dataset preprocess 2 | ## Criteo Kaggle 3 | 1. download data from: https://ailab.criteo.com/ressources/ 4 | 2. convert tsvs to npy files 5 | ```bash 6 | python npy_preproc_criteo.py --input_dir --output_dir 7 | ``` 8 | 3. split to train/val/test files 9 | ```bash 10 | python split_criteo_kaggle.py 11 | ``` 12 | You might need to change the `'/data/scratch/RecSys/criteo_kaggle_npy'` to `` in this file 13 | 14 | ## Avazu 15 | 1. download data from: https://www.kaggle.com/c/avazu-ctr-prediction/data 16 | 2. convert tsv files to npy files: 17 | ```bash 18 | python npy_preproc_avazu.py --input_dir --output_dir 19 | ``` 20 | 3. split train/val/test files 21 | ```bash 22 | python npy_preproc_avazu.py --input_dir --output_dir --is_split 23 | ``` 24 | 25 | ## Criteo Terabyte 26 | 1. download tsv source file from: https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/ 27 | Note that this requires around 1TB disk space. 28 | 2. clone the torchrec repo: 29 | ```bash 30 | git clone https://github.com/pytorch/torchrec.git 31 | cd torchrec/torchrec/datasets/scripts/nvt/ 32 | ``` 33 | 3. conduct the first two python script 34 | ```bash 35 | python convert_tsv_to_parquet.py -i -o 36 | python process_criteo_parquet.py -b -s 37 | ``` 38 | You might need to use the dockerfile to install nvtabular, 39 | since its installation requires a CUDA version different from our experiment setup -------------------------------------------------------------------------------- /scripts/preprocess/npy_preproc_avazu.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | 6 | from recsys.datasets.avazu import CAT_FEATURE_COUNT, AvazuIterDataPipe, TOTAL_TRAINING_SAMPLES 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument("--input_dir", 13 | type=str, 14 | required=True, 15 | help="path to the dir where the csv file train is downloaded and unzipped") 16 | 17 | parser.add_argument("--output_dir", 18 | type=str, 19 | required=True, 20 | help="path to which the train/val/test splits are saved") 21 | 22 | parser.add_argument("--is_split", action='store_true') 23 | return parser.parse_args() 24 | 25 | 26 | def main(): 27 | # Note: this scripts is broken, to align with our experiments, 28 | # please refer to https://www.kaggle.com/code/leejunseok97/deepfm-deepctr-torch 29 | # Basically, the C14-C21 column of the resulting sparse files should be further split to the dense files. 
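    # Expected flow is two passes: first run without --is_split to encode the
    # raw csv under --input_dir into sparse.npy / label.npy, then rerun with
    # --input_dir pointing at those npy files plus --is_split to cut out the
    # train and val/test splits.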
30 | args = parse_args() 31 | 32 | if args.is_split: 33 | if not os.path.exists(args.input_dir): 34 | raise ValueError(f"{args.input_dir} has existed") 35 | 36 | if not os.path.exists(args.output_dir): 37 | os.makedirs(args.output_dir) 38 | 39 | for _t in ("sparse", 'label'): 40 | npy = np.load(os.path.join(args.input_dir, f"{_t}.npy")) 41 | train_split = npy[:TOTAL_TRAINING_SAMPLES] 42 | np.save(os.path.join(args.output_dir, f"train_{_t}.npy"), train_split) 43 | val_test_split = npy[TOTAL_TRAINING_SAMPLES:] 44 | np.save(os.path.join(args.output_dir, f"val_test_{_t}.npy"), val_test_split) 45 | del npy 46 | 47 | else: 48 | if not os.path.exists(args.output_dir): 49 | os.makedirs(args.output_dir) 50 | sparse_output_file_path = os.path.join(args.output_dir, "sparse.npy") 51 | label_output_file_path = os.path.join(args.output_dir, "label.npy") 52 | 53 | sparse, labels = [], [] 54 | for row_sparse, row_label in AvazuIterDataPipe(args.input_dir): 55 | sparse.append(row_sparse) 56 | labels.append(row_label) 57 | 58 | sparse_np = np.array(sparse, dtype=np.int32) 59 | del sparse 60 | labels_np = np.array(labels, dtype=np.int32).reshape(-1, 1) 61 | del labels 62 | 63 | for f_path, arr in [(sparse_output_file_path, sparse_np), (label_output_file_path, labels_np)]: 64 | np.save(f_path, arr) 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /scripts/preprocess/npy_preproc_criteo.py: -------------------------------------------------------------------------------- 1 | # This script preprocesses Criteo dataset tsv files to binary (npy) files. 2 | 3 | # In order to exploit InMemoryBinaryCriteoIterDataPipe to accelerate the loading of data, 4 | # this file is modified from torchrec/datasets/scripts/npy_preproc_criteo.py and 5 | # torchrec.datasets.criteo 6 | # 7 | # Usage: 8 | # python npy_preproc_criteo.py --input_dir /where/criteo_kaggle/train.txt --output_dir /where/to/save/.npy 9 | # 10 | # You may need additional modifications for the file name, 11 | # and you can evenly split the whole dataset into 7 days by split_criteo_kaggle.py 12 | 13 | import argparse 14 | import os 15 | import sys 16 | from typing import List, Tuple 17 | 18 | import numpy as np 19 | from torchrec.datasets.criteo import CriteoIterDataPipe, INT_FEATURE_COUNT, CAT_FEATURE_COUNT 20 | from torchrec.datasets.utils import PATH_MANAGER_KEY, safe_cast 21 | from iopath.common.file_io import PathManagerFactory 22 | 23 | 24 | def tsv_to_npys( 25 | in_file: str, 26 | out_dense_file: str, 27 | out_sparse_file: str, 28 | out_labels_file: str, 29 | path_manager_key: str = PATH_MANAGER_KEY, 30 | ): 31 | """ 32 | For criteo kaggle 33 | """ 34 | 35 | def row_mapper(row: List[str]) -> Tuple[List[int], List[int], int]: 36 | label = safe_cast(row[0], int, 0) 37 | dense = [safe_cast(row[i], int, 0) for i in range(1, 1 + INT_FEATURE_COUNT)] 38 | sparse = [ 39 | int(safe_cast(row[i], str, "0") or "0", 16) 40 | for i in range(1 + INT_FEATURE_COUNT, 1 + INT_FEATURE_COUNT + CAT_FEATURE_COUNT) 41 | ] 42 | return dense, sparse, label # pyre-ignore[7] 43 | 44 | dense, sparse, labels = [], [], [] 45 | for (row_dense, row_sparse, row_label) in CriteoIterDataPipe([in_file], row_mapper=row_mapper): 46 | dense.append(row_dense) 47 | sparse.append(row_sparse) 48 | labels.append(row_label) 49 | 50 | dense_np = np.array(dense, dtype=np.int32) 51 | del dense 52 | sparse_np = np.array(sparse, dtype=np.int32) 53 | del sparse 54 | labels_np = np.array(labels, dtype=np.int32) 55 | 
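    # NOTE: the shift applied below moves the minimum dense value to 2 before
    # taking the log, keeping every transformed value strictly positive.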
del labels 56 | 57 | # Why log +3? 58 | dense_np -= (dense_np.min() - 2) 59 | dense_np = np.log(dense_np, dtype=np.float32) 60 | 61 | labels_np = labels_np.reshape((-1, 1)) 62 | path_manager = PathManagerFactory().get(path_manager_key) 63 | for (fname, arr) in [ 64 | (out_dense_file, dense_np), 65 | (out_sparse_file, sparse_np), 66 | (out_labels_file, labels_np), 67 | ]: 68 | with path_manager.open(fname, "wb") as fout: 69 | np.save(fout, arr) 70 | 71 | 72 | def parse_args(argv: List[str]) -> argparse.Namespace: 73 | parser = argparse.ArgumentParser(description="Criteo tsv -> npy preprocessing script.") 74 | parser.add_argument( 75 | "--input_dir", 76 | type=str, 77 | required=True, 78 | help="Input directory containing Criteo tsv files. Files in the directory " 79 | "should be named day_{0-23}.", 80 | ) 81 | parser.add_argument( 82 | "--output_dir", 83 | type=str, 84 | required=True, 85 | help="Output directory to store npy files.", 86 | ) 87 | return parser.parse_args(argv) 88 | 89 | 90 | def main(argv: List[str]) -> None: 91 | """ 92 | This function preprocesses the raw Criteo tsvs into the format (npy binary) 93 | expected by InMemoryBinaryCriteoIterDataPipe. 94 | 95 | Args: 96 | argv (List[str]): Command line args. 97 | 98 | Returns: 99 | None. 100 | """ 101 | 102 | args = parse_args(argv) 103 | input_dir = args.input_dir 104 | output_dir = args.output_dir 105 | 106 | for f in os.listdir(input_dir): 107 | in_file_path = os.path.join(input_dir, f) 108 | dense_out_file_path = os.path.join(output_dir, f + "_dense.npy") 109 | sparse_out_file_path = os.path.join(output_dir, f + "_sparse.npy") 110 | labels_out_file_path = os.path.join(output_dir, f + "_labels.npy") 111 | print(f"Processing {in_file_path}. Outputs will be saved to {dense_out_file_path}" 112 | f", {sparse_out_file_path}, and {labels_out_file_path}...") 113 | tsv_to_npys( 114 | in_file_path, 115 | dense_out_file_path, 116 | sparse_out_file_path, 117 | labels_out_file_path, 118 | ) 119 | print(f"Done processing {in_file_path}.") 120 | 121 | 122 | if __name__ == "__main__": 123 | main(sys.argv[1:]) 124 | -------------------------------------------------------------------------------- /scripts/preprocess/split_criteo_kaggle.py: -------------------------------------------------------------------------------- 1 | # This script adapts criteo kaggle dataset into 7 days' data 2 | # 3 | # Please alter the arguments of the two functions here. 
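# (namely the data_dir/output_dir given to main() and the sparse npy path given
# to get_num_embeddings_per_feature() in the __main__ block at the bottom)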
4 | # 5 | # Usage: 6 | # python split_criteo_kaggle.py 7 | 8 | import numpy as np 9 | import os 10 | 11 | from torchrec.datasets.criteo import BinaryCriteoUtils, CAT_FEATURE_COUNT 12 | 13 | 14 | def main(data_dir, output_dir, days=7): 15 | STAGES = ("labels", "dense", "sparse") 16 | files = [os.path.join(data_dir, f"train.txt_{split}.npy") for split in STAGES] 17 | total_rows = BinaryCriteoUtils.get_shape_from_npy(files[0])[0] 18 | 19 | indices = list(range(0, total_rows, total_rows // days)) 20 | ranges = [] 21 | for i in range(len(indices) - 1): 22 | left_idx = indices[i] 23 | right_idx = indices[i + 1] if i < len(indices) - 2 else total_rows 24 | ranges.append((left_idx, right_idx)) 25 | 26 | for _s, _f in zip(STAGES, files): 27 | for day, (left_idx, right_idx) in enumerate(ranges): 28 | chunk = BinaryCriteoUtils.load_npy_range(_f, left_idx, right_idx - left_idx) 29 | output_fname = f"day_{day}_{_s}.npy" 30 | np.save(os.path.join(output_dir, output_fname), chunk) 31 | 32 | 33 | def get_num_embeddings_per_feature(path_to_sparse): 34 | sparse = np.load(path_to_sparse) 35 | assert sparse.shape[1] == CAT_FEATURE_COUNT 36 | 37 | nums = [] 38 | for i in range(CAT_FEATURE_COUNT): 39 | nums.append(len(np.unique(sparse[:, i]))) 40 | print(','.join(map(lambda x: str(x), nums))) 41 | 42 | 43 | if __name__ == '__main__': 44 | main('/data/scratch/RecSys/criteo_kaggle_npy', 'criteo_kaggle') 45 | get_num_embeddings_per_feature('/data/scratch/RecSys/criteo_kaggle_npy/train.txt_sparse.npy') 46 | -------------------------------------------------------------------------------- /scripts/preprocess/taobao/README.md: -------------------------------------------------------------------------------- 1 | ### Preprocessing Scripts for Taobao User Behavior Dataset 2 | 3 | Credit: https://github.com/STAR-Laboratory/Accelerating-RecSys-Training 4 | 5 | 6 | Set the global constants in `csv_to_txt.py` before running: `python ./csv_to_txt.py`. 7 | 8 | Set the raw data paths and desired output data paths before running `bash ./run_txt_to_npz.sh` -------------------------------------------------------------------------------- /scripts/preprocess/taobao/csv_to_txt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 UIC-Paper 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | # 6 | # Source: https://github.com/UIC-Paper/MIMN 7 | 8 | import pickle as pkl 9 | import pandas as pd 10 | import random 11 | import numpy as np 12 | import argparse 13 | 14 | RAW_DATA_FILE = "./UserBehavior.csv" 15 | DATASET_PKL = "./dataset.pkl" 16 | Test_File = "./taobao_test.txt" 17 | Train_File = "./taobao_train.txt" 18 | Train_handle = open(Train_File, 'w') 19 | Test_handle = open(Test_File, 'w') 20 | Feature_handle = open("./taobao_feature.pkl", 'wb') 21 | 22 | MAX_LEN_ITEM = 200 23 | 24 | 25 | def to_df(file_name): 26 | df = pd.read_csv(RAW_DATA_FILE, header=None, names=['uid', 'iid', 'cid', 'btag', 'time']) 27 | return df 28 | 29 | 30 | def remap(df): 31 | item_key = sorted(df['iid'].unique().tolist()) 32 | item_len = len(item_key) 33 | item_map = dict(zip(item_key, range(item_len))) 34 | 35 | df['iid'] = df['iid'].map(lambda x: item_map[x]) 36 | 37 | user_key = sorted(df['uid'].unique().tolist()) 38 | user_len = len(user_key) 39 | user_map = dict(zip(user_key, range(item_len, item_len + user_len))) 40 | df['uid'] = df['uid'].map(lambda x: user_map[x]) 41 | 42 | cate_key = sorted(df['cid'].unique().tolist()) 43 | cate_len = len(cate_key) 44 | cate_map = dict(zip(cate_key, range(user_len + item_len, user_len + item_len + cate_len))) 45 | df['cid'] = df['cid'].map(lambda x: cate_map[x]) 46 | 47 | btag_key = sorted(df['btag'].unique().tolist()) 48 | btag_len = len(btag_key) 49 | btag_map = dict(zip(btag_key, range(user_len + item_len + cate_len, user_len + item_len + cate_len + btag_len))) 50 | df['btag'] = df['btag'].map(lambda x: btag_map[x]) 51 | 52 | print(item_len, user_len, cate_len, btag_len) 53 | return df, item_len, user_len + item_len + cate_len + btag_len + 1 #+1 is for unknown target btag 54 | 55 | 56 | def gen_user_item_group(df, item_cnt, feature_size): 57 | user_df = df.sort_values(['uid', 'time']).groupby('uid') 58 | item_df = df.sort_values(['iid', 'time']).groupby('iid') 59 | 60 | print("group completed") 61 | return user_df, item_df 62 | 63 | 64 | def gen_dataset(user_df, item_df, item_cnt, feature_size, dataset_pkl): 65 | train_sample_list = [] 66 | test_sample_list = [] 67 | 68 | # get each user's last touch point time 69 | 70 | print(len(user_df)) 71 | 72 | user_last_touch_time = [] 73 | for uid, hist in user_df: 74 | user_last_touch_time.append(hist['time'].tolist()[-1]) 75 | print("get user last touch time completed") 76 | 77 | user_last_touch_time_sorted = sorted(user_last_touch_time) 78 | split_time = user_last_touch_time_sorted[int(len(user_last_touch_time_sorted) * 0.7)] 79 | 80 | cnt = 0 81 | for uid, hist in user_df: 82 | cnt += 1 83 | print(cnt) 84 | item_hist = hist['iid'].tolist() 85 | cate_hist = hist['cid'].tolist() 86 | btag_hist = hist['btag'].tolist() 87 | target_item_time = hist['time'].tolist()[-1] 88 | 89 | target_item = item_hist[-1] 90 | target_item_cate = cate_hist[-1] 91 | target_item_btag = feature_size 92 | label = 1 93 | test = (target_item_time > split_time) 94 | 95 | # neg sampling 96 | neg = random.randint(0, 1) 97 | if neg == 1: 98 | label = 0 99 | while target_item == item_hist[-1]: 100 | target_item = random.randint(0, item_cnt - 1) 101 | target_item_cate = item_df.get_group(target_item)['cid'].tolist()[0] 102 | target_item_btag = feature_size 103 | 104 | # the item history part of the sample 105 | item_part = [] 106 | for i in range(len(item_hist) - 1): 107 | item_part.append([uid, item_hist[i], cate_hist[i], btag_hist[i]]) 108 | item_part.append([uid, target_item, target_item_cate, target_item_btag]) 109 | # item_part_len = 
min(len(item_part), MAX_LEN_ITEM) 110 | 111 | # choose the item side information: which user has clicked the target item 112 | # padding history with 0 113 | if len(item_part) <= MAX_LEN_ITEM: 114 | item_part_pad = [[0] * 4] * (MAX_LEN_ITEM - len(item_part)) + item_part 115 | else: 116 | item_part_pad = item_part[len(item_part) - MAX_LEN_ITEM:len(item_part)] 117 | 118 | # gen sample 119 | # sample = (label, item_part_pad, item_part_len, user_part_pad, user_part_len) 120 | 121 | if test: 122 | # test_set.append(sample) 123 | cat_list = [] 124 | item_list = [] 125 | # btag_list = [] 126 | for i in range(len(item_part_pad)): 127 | item_list.append(item_part_pad[i][1]) 128 | cat_list.append(item_part_pad[i][2]) 129 | # cat_list.append(item_part_pad[i][0]) 130 | test_sample_list.append( 131 | str(uid) + "\t" + str(target_item) + "\t" + str(target_item_cate) + "\t" + str(label) + "\t" + 132 | ",".join(map(str, item_list)) + "\t" + ",".join(map(str, cat_list)) + "\n") 133 | else: 134 | cat_list = [] 135 | item_list = [] 136 | # btag_list = [] 137 | for i in range(len(item_part_pad)): 138 | item_list.append(item_part_pad[i][1]) 139 | cat_list.append(item_part_pad[i][2]) 140 | train_sample_list.append( 141 | str(uid) + "\t" + str(target_item) + "\t" + str(target_item_cate) + "\t" + str(label) + "\t" + 142 | ",".join(map(str, item_list)) + "\t" + ",".join(map(str, cat_list)) + "\n") 143 | 144 | train_sample_length_quant = len(train_sample_list) / 256 * 256 145 | test_sample_length_quant = len(test_sample_list) / 256 * 256 146 | 147 | print("train_sample_length_quant", train_sample_length_quant) 148 | print("length", len(train_sample_list)) 149 | train_sample_list = train_sample_list[:int(train_sample_length_quant)] 150 | test_sample_list = test_sample_list[:int(test_sample_length_quant)] 151 | random.shuffle(train_sample_list) 152 | print("length", len(train_sample_list)) 153 | return train_sample_list, test_sample_list 154 | 155 | 156 | def produce_neg_item_hist_with_cate(train_file, test_file): 157 | item_dict = {} 158 | sample_count = 0 159 | hist_seq = 0 160 | for line in train_file: 161 | units = line.strip().split("\t") 162 | item_hist_list = units[4].split(",") 163 | cate_hist_list = units[5].split(",") 164 | hist_list = zip(item_hist_list, cate_hist_list) 165 | hist_list = list(hist_list) 166 | #hist_seq_list = list(hist_list) 167 | hist_seq = len(hist_list) 168 | sample_count += 1 169 | for item in hist_list: 170 | item_dict.setdefault(str(item), 0) 171 | 172 | #print("hist_list : ", hist_list) 173 | 174 | for line in test_file: 175 | units = line.strip().split("\t") 176 | item_hist_list = units[4].split(",") 177 | cate_hist_list = units[5].split(",") 178 | hist_list = zip(item_hist_list, cate_hist_list) 179 | hist_list = list(hist_list) 180 | #hist_seq_list = list(hist_list) 181 | hist_seq = len(hist_list) 182 | sample_count += 1 183 | for item in hist_list: 184 | item_dict.setdefault(str(item), 0) 185 | 186 | #print("item_dict : ", item_dict) 187 | del (item_dict["('0', '0')"]) 188 | keys_list = list(item_dict.keys()) 189 | keys_list = np.array(keys_list) 190 | print("item_dict.keys()", keys_list.shape) 191 | neg_array = np.random.choice(keys_list, (sample_count, hist_seq + 20)) 192 | neg_list = neg_array.tolist() 193 | sample_count = 0 194 | 195 | for line in train_file: 196 | units = line.strip().split("\t") 197 | item_hist_list = units[4].split(",") 198 | cate_hist_list = units[5].split(",") 199 | hist_list = zip(item_hist_list, cate_hist_list) 200 | hist_list = list(hist_list) 201 | 
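        # collect hist_seq negative (item, cate) pairs for this sample from the
        # pre-drawn pool, skipping pairs that occur in the user's real history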
156 | def produce_neg_item_hist_with_cate(train_file, test_file):
157 |     item_dict = {}
158 |     sample_count = 0
159 |     hist_seq = 0
160 |     for line in train_file:
161 |         units = line.strip().split("\t")
162 |         item_hist_list = units[4].split(",")
163 |         cate_hist_list = units[5].split(",")
164 |         hist_list = zip(item_hist_list, cate_hist_list)
165 |         hist_list = list(hist_list)
166 |         # hist_seq_list = list(hist_list)
167 |         hist_seq = len(hist_list)
168 |         sample_count += 1
169 |         for item in hist_list:
170 |             item_dict.setdefault(str(item), 0)
171 | 
172 |     # print("hist_list : ", hist_list)
173 | 
174 |     for line in test_file:
175 |         units = line.strip().split("\t")
176 |         item_hist_list = units[4].split(",")
177 |         cate_hist_list = units[5].split(",")
178 |         hist_list = zip(item_hist_list, cate_hist_list)
179 |         hist_list = list(hist_list)
180 |         # hist_seq_list = list(hist_list)
181 |         hist_seq = len(hist_list)
182 |         sample_count += 1
183 |         for item in hist_list:
184 |             item_dict.setdefault(str(item), 0)
185 | 
186 |     # print("item_dict : ", item_dict)
187 |     del item_dict["('0', '0')"]    # drop the zero-padding (item, cate) pair
188 |     keys_list = list(item_dict.keys())
189 |     keys_list = np.array(keys_list)
190 |     print("item_dict.keys()", keys_list.shape)
191 |     neg_array = np.random.choice(keys_list, (sample_count, hist_seq + 20))
192 |     neg_list = neg_array.tolist()
193 |     sample_count = 0
194 | 
195 |     for line in train_file:
196 |         units = line.strip().split("\t")
197 |         item_hist_list = units[4].split(",")
198 |         cate_hist_list = units[5].split(",")
199 |         hist_list = zip(item_hist_list, cate_hist_list)
200 |         hist_list = list(hist_list)
201 |         # hist_seq_list = list(hist_list)
202 |         hist_seq = len(hist_list)
203 |         neg_hist_list = []
204 |         for item in neg_list[sample_count]:
205 |             item = eval(item)    # safe here: keys were produced by str(tuple) above
206 |             if item not in hist_list:
207 |                 neg_hist_list.append(item)
208 |             if len(neg_hist_list) == hist_seq:
209 |                 break
210 |         sample_count += 1
211 |         neg_item_list, neg_cate_list = zip(*neg_hist_list)
212 |         Train_handle.write(line.strip() + "\t" + ",".join(neg_item_list) + "\t" + ",".join(neg_cate_list) + "\n")
213 | 
214 |     for line in test_file:
215 |         units = line.strip().split("\t")
216 |         item_hist_list = units[4].split(",")
217 |         cate_hist_list = units[5].split(",")
218 |         hist_list = zip(item_hist_list, cate_hist_list)
219 |         hist_list = list(hist_list)
220 |         # hist_seq_list = list(hist_list)
221 |         hist_seq = len(hist_list)
222 |         neg_hist_list = []
223 |         for item in neg_list[sample_count]:
224 |             item = eval(item)    # safe here: keys were produced by str(tuple) above
225 |             if item not in hist_list:
226 |                 neg_hist_list.append(item)
227 |             if len(neg_hist_list) == hist_seq:
228 |                 break
229 |         sample_count += 1
230 |         neg_item_list, neg_cate_list = zip(*neg_hist_list)
231 |         Test_handle.write(line.strip() + "\t" + ",".join(neg_item_list) + "\t" + ",".join(neg_cate_list) + "\n")
232 | 
233 | 
234 | def main():
235 |     df = to_df(RAW_DATA_FILE)
236 |     df, item_cnt, feature_size = remap(df)
237 |     print("feature_size", item_cnt, feature_size)
238 |     feature_total_num = feature_size + 1
239 |     pkl.dump(feature_total_num, Feature_handle)
240 | 
241 |     user_df, item_df = gen_user_item_group(df, item_cnt, feature_size)
242 |     train_sample_list, test_sample_list = gen_dataset(user_df, item_df, item_cnt, feature_size, DATASET_PKL)
243 |     produce_neg_item_hist_with_cate(train_sample_list, test_sample_list)
244 | 
245 | 
246 | if __name__ == '__main__':
247 |     main()
248 | 
--------------------------------------------------------------------------------
/scripts/preprocess/taobao/run_txt_to_npz.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
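# Note (assumption, inferred from txt_to_npz.py rather than stated here):
# --arch-embedding-size encodes the per-table vocabulary sizes as
# "<num_users>-<num_items>-<num_categories>"; the values below match the
# taobao counts noted in txt_to_npz.py (987994 users, 4162024 items,
# 9439 categories).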
7 | 
8 | tbsm_py="python txt_to_npz.py "
9 | 
10 | $tbsm_py --datatype="taobao" \
11 |     --num-train-pts=690000 --num-val-pts=300000 --points-per-user=10 \
12 |     --numpy-rand-seed=123 --arch-embedding-size="987994-4162024-9439" \
13 |     --raw-train-file=./taobao_train.txt \
14 |     --raw-test-file=./taobao_test.txt \
15 |     --pro-train-file=./taobao_train_t20.npz \
16 |     --pro-val-file=./taobao_val_t20.npz \
17 |     --pro-test-file=./taobao_test_t20.npz \
18 |     --ts-length=20
--------------------------------------------------------------------------------
/scripts/preprocess/taobao/txt_to_npz.py:
--------------------------------------------------------------------------------
1 | from os import path
2 | import sys
3 | import argparse
4 | import numpy as np
5 | 
6 | 
7 | class TaobaoTxtToNpz:
8 | 
9 |     def __init__(
10 |         self,
11 |         datatype,
12 |         mode,
13 |         ts_length=20,
14 |         points_per_user=4,
15 |         numpy_rand_seed=7,
16 |         raw_path="",
17 |         pro_data="",
18 |         spa_fea_sizes="",
19 |         num_pts=1,    # pts to train or test
20 |     ):
21 |         # save arguments
22 |         if mode == "train":
23 |             self.numpy_rand_seed = numpy_rand_seed
24 |         else:
25 |             self.numpy_rand_seed = numpy_rand_seed + 31
26 |         self.mode = mode
27 |         # save dataset parameters
28 |         self.total = num_pts    # number of lines in txt to process
29 |         self.ts_length = ts_length
30 |         self.points_per_user = points_per_user    # pos and neg points per user
31 |         self.spa_fea_sizes = spa_fea_sizes
32 |         self.M = 200    # max history length
33 | 
34 |         # split the datafile into path and filename
35 |         lstr = raw_path.split("/")
36 |         self.d_path = "/".join(lstr[0:-1]) + "/"
37 |         self.d_file = lstr[-1]
38 | 
39 |         # preprocess data if needed
40 |         if path.exists(str(pro_data)):
41 |             print("Reading pre-processed data=%s" % (str(pro_data)))
42 |             file = str(pro_data)
43 |         else:
44 |             file = str(pro_data)
45 |             levels = np.array([int(v) for v in self.spa_fea_sizes.split("-")])    # e.g. "987994-4162024-9439"
46 |             if datatype == "taobao":
47 |                 self.Unum = levels[0]    # 987994 num of users
48 |                 self.Inum = levels[1]    # 4162024 num of items
49 |                 self.Cnum = levels[2]    # 9439 num of categories
50 |                 print("Reading raw data=%s" % (str(raw_path)))
51 |                 if self.mode == "test":
52 |                     self.build_taobao_test(
53 |                         raw_path,
54 |                         file,
55 |                     )
56 |                 else:
57 |                     self.build_taobao_train_or_val(
58 |                         raw_path,
59 |                         file,
60 |                     )
61 |             elif datatype == "synthetic":
62 |                 self.build_synthetic_train_or_val(file)
63 |         # load data
64 |         with np.load(file) as data:
65 |             self.X_cat = data["X_cat"]
66 |             self.X_int = data["X_int"]
67 |             self.y = data["y"]
68 | 
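# A minimal consumer sketch (assuming the file names used by run_txt_to_npz.sh):
# each produced .npz stores X_cat with shape (3, N, ts_length + 1) holding the
# user/item/category id sequences, X_int with shape (1, N, ts_length + 1)
# holding the normalized time feature, and y with shape (N,) holding labels, so
# a downstream reader can do e.g.
#
#     with np.load("./taobao_train_t20.npz") as data:
#         X_cat, X_int, y = data["X_cat"], data["X_int"], data["y"]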
69 |     # common part between train/val and test generation
70 |     # truncates (if needed) and shuffles data points
71 |     def truncate_and_save(self, out_file, do_shuffle, t, users, items, cats, times, y):
72 |         # truncate: if some users' histories were too short to generate the
73 |         # requested points, trim the unused tail of the pre-allocated matrices.
74 |         if t < self.total_out:
75 |             users = users[:t, :]
76 |             items = items[:t, :]
77 |             cats = cats[:t, :]
78 |             times = times[:t, :]
79 |             y = y[:t]
80 | 
81 |         # shuffle
82 |         if do_shuffle:
83 |             indices = np.arange(len(y))
84 |             indices = np.random.permutation(indices)
85 |             users = users[indices]
86 |             items = items[indices]
87 |             cats = cats[indices]
88 |             times = times[indices]
89 |             y = y[indices]
90 | 
91 |         N = len(y)
92 |         X_cat = np.zeros((3, N, self.ts_length + 1), dtype="i4")    # 4 byte int
93 |         X_int = np.zeros((1, N, self.ts_length + 1), dtype=np.float64)
94 |         X_cat[0, :, :] = users
95 |         X_cat[1, :, :] = items
96 |         X_cat[2, :, :] = cats
97 |         X_int[0, :, :] = times
98 | 
99 |         # saving to compressed numpy file
100 |         if not path.exists(out_file):
101 |             np.savez_compressed(
102 |                 out_file,
103 |                 X_cat=X_cat,
104 |                 X_int=X_int,
105 |                 y=y,
106 |             )
107 |         return
108 | 
109 |     # processes raw train or validation data into the npz format required by
110 |     # training. For train data, each line of the raw datafile produces several
111 |     # randomly chosen datapoints (at most points_per_user per user); for
112 |     # validation data it produces one datapoint per user.
113 |     def build_taobao_train_or_val(self, raw_path, out_file):
114 |         with open(str(raw_path)) as f:
115 |             for i, _ in enumerate(f):
116 |                 if i % 50000 == 0:
117 |                     print("pre-processing line: ", i)
118 |         self.total = min(self.total, i + 1)
119 | 
120 |         print("total lines: ", self.total)
121 | 
122 |         self.total_out = self.total * self.points_per_user * 2    # pos + neg points
123 |         print("Total number of points in raw datafile: ", self.total)
124 |         print("Total number of points in output will be at most: ", self.total_out)
125 |         np.random.seed(self.numpy_rand_seed)
126 |         r_target = np.arange(0, self.M - 1)
127 | 
128 |         time = np.arange(self.ts_length + 1, dtype=np.int32) / (self.ts_length + 1)
129 |         # time = np.ones(self.ts_length + 1, dtype=np.int32)
130 | 
131 |         users = np.zeros((self.total_out, self.ts_length + 1), dtype="i4")    # 4 byte int
132 |         items = np.zeros((self.total_out, self.ts_length + 1), dtype="i4")    # 4 byte int
133 |         cats = np.zeros((self.total_out, self.ts_length + 1), dtype="i4")    # 4 byte int
134 |         times = np.zeros((self.total_out, self.ts_length + 1), dtype=np.float64)
135 |         y = np.zeros(self.total_out, dtype="i4")    # 4 byte int
136 | 
137 |         # determine how many datapoints to take from each user based on the length of
138 |         # user behavior sequence
139 |         # ind=0, 1, 2, 3,... t < 10, 20, 30, 40, 50, 60, ...
140 |         k = 20
141 |         regime = np.zeros(k, dtype=np.int64)
142 |         regime[1], regime[2], regime[3] = 1, 3, 6
143 |         for j in range(4, k):
144 |             regime[j] = self.points_per_user
145 |         if self.mode == "val":
146 |             self.points_per_user = 1
147 |         for j in range(k):
148 |             regime[j] = np.min([regime[j], self.points_per_user])
149 |         last = self.M - 1    # max index of last item
150 | 
151 |         # try to generate the desired number of points (time series) per each user.
152 |         # if history is short it may not succeed to generate sufficiently different
153 |         # time series for a particular user.
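        # Illustrative arithmetic (an editorial example, assumed from the regime
        # table above): with M = 200 the last valid index is last = 199, so a user
        # whose first real item sits at index first = 120 gets
        # ind = (199 - 120) // 10 = 7 and contributes regime[7] = points_per_user
        # positive and negative points, while a user with first = 180 gets
        # ind = 1 and contributes only one point of each kind.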
154 |         t, t_pos, t_neg, t_short = 0, 0, 0, 0
155 |         with open(str(raw_path)) as f:
156 |             for i, line in enumerate(f):
157 |                 if i % 1000 == 0:
158 |                     print("processing line: ", i, t, t_pos, t_neg, t_short)
159 |                 if i >= self.total:
160 |                     break
161 |                 units = line.strip().split("\t")
162 |                 item_hist_list = units[4].split(",")
163 |                 cate_hist_list = units[5].split(",")
164 |                 neg_item_hist_list = units[6].split(",")
165 |                 neg_cate_hist_list = units[7].split(",")
166 |                 user = np.array(np.maximum(np.int32(units[0]) - self.Inum, 0), dtype=np.int32)
167 |                 # y[i] = np.int32(units[3])
168 |                 items_ = np.array(list(map(lambda x: np.maximum(np.int32(x), 0), item_hist_list)), dtype=np.int32)
169 |                 cats_ = np.array(list(map(lambda x: np.maximum(np.int32(x) - self.Inum - self.Unum, 0),
170 |                                           cate_hist_list)),
171 |                                  dtype=np.int32)
172 |                 neg_items_ = np.array(list(map(lambda x: np.maximum(np.int32(x), 0), neg_item_hist_list)),
173 |                                       dtype=np.int32)
174 |                 neg_cats_ = np.array(list(
175 |                     map(lambda x: np.maximum(np.int32(x) - self.Inum - self.Unum, 0), neg_cate_hist_list)),
176 |                                      dtype=np.int32)
177 | 
178 |                 # select datapoints
179 |                 first = np.argmax(items_ > 0)
180 |                 ind = int((last - first) // 10)    # index into regime array
181 |                 # pos
182 |                 for _ in range(regime[ind]):
183 |                     a1 = min(first + self.ts_length, last - 1)
184 |                     end = np.random.randint(a1, last)
185 |                     indices = np.arange(end - self.ts_length, end + 1)
186 |                     if items_[indices[0]] == 0:
187 |                         t_short += 1
188 |                     items[t] = items_[indices]
189 |                     cats[t] = cats_[indices]
190 |                     users[t] = np.full(self.ts_length + 1, user)
191 |                     times[t] = time
192 |                     y[t] = 1
193 |                     # check
194 |                     if np.any(users[t] < 0) or np.any(items[t] < 0) \
195 |                             or np.any(cats[t] < 0):
196 |                         sys.exit("Categorical feature less than zero after "
197 |                                  "processing. Aborting...")
198 |                     t += 1
199 |                     t_pos += 1
200 |                 # neg
201 |                 for _ in range(regime[ind]):
202 |                     a1 = min(first + self.ts_length - 1, last - 1)
203 |                     end = np.random.randint(a1, last)
204 |                     indices = np.arange(end - self.ts_length + 1, end + 1)
205 |                     if items_[indices[0]] == 0:
206 |                         t_short += 1
207 |                     items[t, :-1] = items_[indices]
208 |                     cats[t, :-1] = cats_[indices]
209 |                     neg_indices = np.random.choice(r_target, 1, replace=False)    # random final item
210 |                     items[t, -1] = neg_items_[neg_indices]
211 |                     cats[t, -1] = neg_cats_[neg_indices]
212 |                     users[t] = np.full(self.ts_length + 1, user)
213 |                     times[t] = time
214 |                     y[t] = 0
215 |                     # check
216 |                     if np.any(users[t] < 0) or np.any(items[t] < 0) \
217 |                             or np.any(cats[t] < 0):
218 |                         sys.exit("Categorical feature less than zero after "
219 |                                  "processing. Aborting...")
Aborting...") 220 | t += 1 221 | t_neg += 1 222 | 223 | print("total points, pos points, neg points: ", t, t_pos, t_neg) 224 | 225 | self.truncate_and_save(out_file, True, t, users, items, cats, times, y) 226 | return 227 | 228 | # processes raw test datafile into npz format required to be used by 229 | # inference step, produces one datapoint per user by taking last ts-length items 230 | def build_taobao_test(self, raw_path, out_file): 231 | 232 | with open(str(raw_path)) as f: 233 | for i, _ in enumerate(f): 234 | if i % 50000 == 0: 235 | print("pre-processing line: ", i) 236 | self.total = i + 1 237 | 238 | self.total_out = self.total # pos + neg points 239 | print("ts_length: ", self.ts_length) 240 | print("Total number of points in raw datafile: ", self.total) 241 | print("Total number of points in output will be at most: ", self.total_out) 242 | 243 | time = np.arange(self.ts_length + 1, dtype=np.int32) / (self.ts_length + 1) 244 | 245 | users = np.zeros((self.total_out, self.ts_length + 1), dtypei4="") # 4 byte int 246 | items = np.zeros((self.total_out, self.ts_length + 1), dtype="i4") # 4 byte int 247 | cats = np.zeros((self.total_out, self.ts_length + 1), dtype="i4") # 4 byte int 248 | times = np.zeros((self.total_out, self.ts_length + 1), dtype=np.float) 249 | y = np.zeros(self.total_out, dtype="i4") # 4 byte int 250 | 251 | # try to generate the desired number of points (time series) per each user. 252 | # if history is short it may not succeed to generate sufficiently different 253 | # time series for a particular user. 254 | t, t_pos, t_neg = 0, 0, 0 255 | with open(str(raw_path)) as f: 256 | for i, line in enumerate(f): 257 | if i % 1000 == 0: 258 | print("processing line: ", i, t, t_pos, t_neg) 259 | if i >= self.total: 260 | break 261 | units = line.strip().split("\t") 262 | item_hist_list = units[4].split(",") 263 | cate_hist_list = units[5].split(",") 264 | 265 | user = np.array(np.maximum(np.int32(units[0]) - self.Inum, 0), dtype=np.int32) 266 | y[t] = np.int32(units[3]) 267 | items_ = np.array(list(map(lambda x: np.maximum(np.int32(x), 0), item_hist_list)), dtype=np.int32) 268 | cats_ = np.array(list(map(lambda x: np.maximum(np.int32(x) - self.Inum - self.Unum, 0), 269 | cate_hist_list)), 270 | dtype=np.int32) 271 | 272 | # get pts 273 | items[t] = items_[-(self.ts_length + 1):] 274 | cats[t] = cats_[-(self.ts_length + 1):] 275 | users[t] = np.full(self.ts_length + 1, user) 276 | times[t] = time 277 | # check 278 | if np.any(users[t] < 0) or np.any(items[t] < 0) \ 279 | or np.any(cats[t] < 0): 280 | sys.exit("Categorical feature less than zero after \ 281 | processing. 
Aborting...") 282 | if y[t] == 1: 283 | t_pos += 1 284 | else: 285 | t_neg += 1 286 | t += 1 287 | 288 | print("total points, pos points, neg points: ", t, t_pos, t_neg) 289 | 290 | self.truncate_and_save(out_file, False, t, users, items, cats, times, y) 291 | return 292 | 293 | # builds small synthetic data mimicking the structure of taobao data 294 | def build_synthetic_train_or_val(self, out_file): 295 | 296 | np.random.seed(123) 297 | fea_sizes = np.fromstring(self.spa_fea_sizes, dtype=int, sep="-") 298 | maxval = np.min(fea_sizes) 299 | num_s = len(fea_sizes) 300 | X_cat = np.random.randint(maxval, size=(num_s, self.total, self.ts_length + 1), dtype="i4") # 4 byte int 301 | X_int = np.random.uniform(0, 1, size=(1, self.total, self.ts_length + 1)) 302 | y = np.random.randint(0, 2, self.total, dtype="i4") # 4 byte int 303 | 304 | # saving to compressed numpy file 305 | if not path.exists(out_file): 306 | np.savez_compressed( 307 | out_file, 308 | X_cat=X_cat, 309 | X_int=X_int, 310 | y=y, 311 | ) 312 | return 313 | 314 | 315 | # creates a loader (train, val or test data) to be used in the main training loop 316 | # or during inference step 317 | def make_tbsm_data_and_loader(args, mode): 318 | if mode == "train": 319 | raw = args.raw_train_file 320 | proc = args.pro_train_file 321 | numpts = args.num_train_pts 322 | elif mode == "val": 323 | raw = args.raw_train_file 324 | proc = args.pro_val_file 325 | numpts = args.num_val_pts 326 | else: 327 | raw = args.raw_test_file 328 | proc = args.pro_test_file 329 | numpts = 1 330 | 331 | TaobaoTxtToNpz( 332 | args.datatype, 333 | mode, 334 | args.ts_length, 335 | args.points_per_user, 336 | args.numpy_rand_seed, 337 | raw, 338 | proc, 339 | args.arch_embedding_size, 340 | numpts, 341 | ) 342 | 343 | 344 | def main(args): 345 | make_tbsm_data_and_loader(args, 'train') 346 | make_tbsm_data_and_loader(args, 'val') 347 | make_tbsm_data_and_loader(args, 'test') 348 | 349 | 350 | if __name__ == '__main__': 351 | parser = argparse.ArgumentParser() 352 | parser.add_argument("--datatype", type=str, default="taobao") 353 | parser.add_argument("--raw-train-file", type=str, default="./input/train.txt") 354 | parser.add_argument("--pro-train-file", type=str, default="./output/train.npz") 355 | parser.add_argument("--raw-test-file", type=str, default="./input/test.txt") 356 | parser.add_argument("--pro-test-file", type=str, default="./output/test.npz") 357 | parser.add_argument("--pro-val-file", type=str, default="./output/val.npz") 358 | parser.add_argument("--ts-length", type=int, default=20) 359 | parser.add_argument("--num-train-pts", type=int, default=100) 360 | parser.add_argument("--num-val-pts", type=int, default=20) 361 | parser.add_argument("--points-per-user", type=int, default=10) 362 | parser.add_argument("--arch-embedding-size", type=str, default="4-3-2") # vectors 363 | parser.add_argument("--numpy-rand-seed", type=int, default=123) 364 | args = parser.parse_args() 365 | main(args) 366 | -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # For Colossalai enabled recsys 4 | # bash scripts/kaggle.sh 5 | 6 | bash scripts/torchrec_kaggle.sh 7 | 8 | # bash scripts/avazu.sh 9 | 10 | bash scripts/torchrec_avazu.sh 11 | 12 | # bash scripts/terabyte.sh 13 | 14 | bash scripts/torchrec_terabyte.sh 15 | -------------------------------------------------------------------------------- /scripts/terabyte.sh: 
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # For ColossalAI-enabled recsys
3 | 
4 | # criteo terabyte
5 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p1_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
6 |     --dataset_dir /data/criteo_preproc/ \
7 |     --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
8 |     --profile_dir "tensorboard_log/terabyte/w1_p1_16k" --buffer_size 0 --use_overlap --cache_sets 1779442
9 | 
10 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w2_p1_16k dist.ddp -j 1x2 --script recsys/dlrm_main.py -- \
11 |     --dataset_dir /data/criteo_preproc/ \
12 |     --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
13 |     --profile_dir "tensorboard_log/terabyte/w2_p1_16k" --buffer_size 0 --use_overlap --cache_sets 1779442
14 | 
15 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w4_p1_16k dist.ddp -j 1x4 --script recsys/dlrm_main.py -- \
16 |     --dataset_dir /data/criteo_preproc/ \
17 |     --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
18 |     --profile_dir "tensorboard_log/terabyte/w4_p1_16k" --buffer_size 0 --use_overlap --cache_sets 1779442
19 | 
20 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w8_p1_16k dist.ddp -j 1x8 --script recsys/dlrm_main.py -- \
21 |     --dataset_dir /data/criteo_preproc/ \
22 |     --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
23 |     --profile_dir "tensorboard_log/terabyte/w8_p1_16k" --buffer_size 0 --use_overlap --cache_sets 1779442
24 | 
25 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p1_32k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
26 |     --dataset_dir /data/criteo_preproc/ \
27 |     --learning_rate 1. --batch_size 32768 --use_sparse_embed_grad --use_cache --use_freq \
28 |     --profile_dir "tensorboard_log/terabyte/w1_p1_32k" --buffer_size 0 --use_overlap --cache_sets 1779442
29 | 
30 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p1_8k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
31 |     --dataset_dir /data/criteo_preproc/ \
32 |     --learning_rate 1. --batch_size 8192 --use_sparse_embed_grad --use_cache --use_freq \
33 |     --profile_dir "tensorboard_log/terabyte/w1_p1_8k" --buffer_size 0 --use_overlap --cache_sets 1779442
34 | 
35 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p1_4k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
36 |     --dataset_dir /data/criteo_preproc/ \
37 |     --learning_rate 1. --batch_size 4096 --use_sparse_embed_grad --use_cache --use_freq \
38 |     --profile_dir "tensorboard_log/terabyte/w1_p1_4k" --buffer_size 0 --use_overlap --cache_sets 1779442
39 | 
40 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p1_2k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
41 |     --dataset_dir /data/criteo_preproc/ \
42 |     --learning_rate 1. --batch_size 2048 --use_sparse_embed_grad --use_cache --use_freq \
43 |     --profile_dir "tensorboard_log/terabyte/w1_p1_2k" --buffer_size 0 --use_overlap --cache_sets 1779442
44 | 
45 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p1_1k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
46 |     --dataset_dir /data/criteo_preproc/ \
47 |     --learning_rate 1. --batch_size 1024 --use_sparse_embed_grad --use_cache --use_freq \
48 |     --profile_dir "tensorboard_log/terabyte/w1_p1_1k" --buffer_size 0 --use_overlap --cache_sets 1779442
49 | 
50 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p10_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
51 |     --dataset_dir /data/criteo_preproc/ \
52 |     --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
53 |     --profile_dir "tensorboard_log/terabyte/w1_p10_16k" --buffer_size 0 --use_overlap --cache_sets 17794427
54 | 
55 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p5_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
56 |     --dataset_dir /data/criteo_preproc/ \
57 |     --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
58 |     --profile_dir "tensorboard_log/terabyte/w1_p5_16k" --buffer_size 0 --use_overlap --cache_sets 8897213
59 | 
60 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p2_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
61 |     --dataset_dir /data/criteo_preproc/ \
62 |     --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
63 |     --profile_dir "tensorboard_log/terabyte/w1_p2_16k" --buffer_size 0 --use_overlap --cache_sets 3558885
64 | 
65 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p0_1_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
66 |     --dataset_dir /data/criteo_preproc/ \
67 |     --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
68 |     --profile_dir "tensorboard_log/terabyte/w1_p0_1_16k" --buffer_size 0 --use_overlap --cache_sets 177944
69 | 
70 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p0_5_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
71 |     --dataset_dir /data/criteo_preproc/ \
72 |     --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
73 |     --profile_dir "tensorboard_log/terabyte/w1_p0_5_16k" --buffer_size 0 --use_overlap --cache_sets 889721
74 | 
--------------------------------------------------------------------------------
/scripts/torchrec_avazu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # For TorchRec baseline
4 | torchx run -s local_cwd -cfg log_dir=log/torchrec_avazu/w1_16k dist.ddp -j 1x1 --script baselines/dlrm_main.py -- \
5 |     --in_memory_binary_criteo_path /data/avazu_sample --embedding_dim 128 --pin_memory \
6 |     --over_arch_layer_sizes "1024,1024,512,256,1" --dense_arch_layer_sizes "512,256,128" --shuffle_batches \
7 |     --learning_rate 1. --batch_size 16384 --profile_dir "tensorboard_log/torchrec_avazu/w1_16k"
8 | 
9 | torchx run -s local_cwd -cfg log_dir=log/torchrec_avazu/w2_16k dist.ddp -j 1x2 --script baselines/dlrm_main.py -- \
10 |     --in_memory_binary_criteo_path /data/avazu_sample --embedding_dim 128 --pin_memory \
11 |     --over_arch_layer_sizes "1024,1024,512,256,1" --dense_arch_layer_sizes "512,256,128" --shuffle_batches \
12 |     --learning_rate 1. --batch_size 8192 --profile_dir "tensorboard_log/torchrec_avazu/w2_16k"
13 | 
14 | torchx run -s local_cwd -cfg log_dir=log/torchrec_avazu/w4_16k dist.ddp -j 1x4 --script baselines/dlrm_main.py -- \
15 |     --in_memory_binary_criteo_path /data/avazu_sample --embedding_dim 128 --pin_memory \
16 |     --over_arch_layer_sizes "1024,1024,512,256,1" --dense_arch_layer_sizes "512,256,128" --shuffle_batches \
17 |     --learning_rate 1. --batch_size 4096 --profile_dir "tensorboard_log/torchrec_avazu/w4_16k"
18 | 
19 | torchx run -s local_cwd -cfg log_dir=log/torchrec_avazu/w8_16k dist.ddp -j 1x8 --script baselines/dlrm_main.py -- \
20 |     --in_memory_binary_criteo_path /data/avazu_sample --embedding_dim 128 --pin_memory \
21 |     --over_arch_layer_sizes "1024,1024,512,256,1" --dense_arch_layer_sizes "512,256,128" --shuffle_batches \
22 |     --learning_rate 1. --batch_size 2048 --profile_dir "tensorboard_log/torchrec_avazu/w8_16k"
--------------------------------------------------------------------------------
/scripts/torchrec_custom.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 | 
4 | export PYTHONPATH=$HOME/codes/torchrec:$PYTHONPATH
5 | 
6 | set_n_least_used_CUDA_VISIBLE_DEVICES() {
7 |     local n=${1:-"9999"}
8 |     echo "GPU Memory Usage:"
9 |     local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
10 |         | tail -n +2 \
11 |         | nl -v 0 \
12 |         | tee /dev/tty \
13 |         | sort -g -k 2 \
14 |         | awk '{print $1}' \
15 |         | head -n $n)
16 |     export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
17 |     echo "Now CUDA_VISIBLE_DEVICES is set to:"
18 |     echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
19 | }
20 | 
21 | # export DATAPATH=/data/scratch/RecSys/embedding_bag
22 | export DATAPATH=custom
23 | export EVAL_ACC=0
24 | export EMB_DIM=128
25 | export POOLING_FACTOR=8
26 | 
27 | if [[ ${EVAL_ACC} == 1 ]]; then
28 |     EVAL_ACC_FLAG="--eval_acc"
29 | else
30 |     export EVAL_ACC_FLAG=""
31 | fi
32 | 
33 | 
34 | mkdir -p logs
35 | for PREFETCH_NUM in 1    # 32 4 8 16
36 | do
37 | for GPUNUM in 1 2 4 8    # 4 8 # 1 # 2
38 | do
39 | for BATCHSIZE in 8192    # 2048 4096 1024 # 8192 512 # 16384 8192 4096 2048 1024 512
40 | do
41 | for SHARDTYPE in "table"    # "tablecolumn" "column" "row" "tablerow" "table"
42 | do
43 | for KERNELTYPE in "colossalai"    # "fused" # "uvm_lfu" # "colossalai" # "uvm_lfu" # "colossalai"
44 | do
45 |     # For TorchRec baseline
46 |     set_n_least_used_CUDA_VISIBLE_DEVICES ${GPUNUM}
47 |     export PLAN=g${GPUNUM}_bs_${BATCHSIZE}_pool_${POOLING_FACTOR}_${SHARDTYPE}_pf_${PREFETCH_NUM}_eb_${EMB_DIM}_${KERNELTYPE}_custom
48 |     rm -rf ./tensorboard_log/torchrec_custom/
49 |     # env CUDA_LAUNCH_BLOCKING=1
50 |     # timeout -s SIGKILL 30m
51 |     torchx run -s local_cwd -cfg log_dir=log/torchrec_custom/${PLAN} dist.ddp -j 1x${GPUNUM} --script baselines/dlrm_main.py -- \
52 |         --in_memory_binary_criteo_path ${DATAPATH} --kaggle --embedding_dim ${EMB_DIM} --pin_memory --cache_ratio 0.01 \
53 |         --over_arch_layer_sizes "1024,1024,512,256,1" --dense_arch_layer_sizes "512,256,${EMB_DIM}" --shuffle_batches \
54 |         --learning_rate 1. --batch_size ${BATCHSIZE} --profile_dir "" --shard_type ${SHARDTYPE} --kernel_type ${KERNELTYPE} \
55 |         --prefetch_num ${PREFETCH_NUM} --pooling_factor ${POOLING_FACTOR} ${EVAL_ACC_FLAG} 2>&1 | tee logs/torchrec_${PLAN}.txt
56 | done
57 | done
58 | done
59 | done
60 | done
61 | 
--------------------------------------------------------------------------------
/scripts/torchrec_kaggle.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 | # export PYTHONPATH=$HOME/codes/torchrec:$PYTHONPATH
4 | # export DATAPATH=/data/scratch/RecSys/criteo_kaggle_data/
5 | export DATAPATH=/data/criteo_kaggle_data/
6 | 
7 | set_n_least_used_CUDA_VISIBLE_DEVICES() {
8 |     local n=${1:-"9999"}
9 |     echo "GPU Memory Usage:"
10 |     local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
11 |         | tail -n +2 \
12 |         | nl -v 0 \
13 |         | tee /dev/tty \
14 |         | sort -g -k 2 \
15 |         | awk '{print $1}' \
16 |         | head -n $n)
17 |     export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
18 |     echo "Now CUDA_VISIBLE_DEVICES is set to:"
19 |     echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
20 | }
21 | 
22 | export GPUNUM=1
23 | 
24 | for EMB_DIM in 128    # 64 96
25 | do
26 | for PREFETCH_NUM in 1    # 1 8 16 32
27 | do
28 | for GPUNUM in 1 2 4 8
29 | do
30 | for KERNELTYPE in "colossalai"    # "fused" # "colossalai"
31 | do
32 | for BATCHSIZE in 8192    # 16384 8192 4096 2048 1024 512
33 | do
34 | for SHARDTYPE in "table"    # "column" "row" "tablecolumn" "tablerow" "table"
35 | # for SHARDTYPE in "tablerow"
36 | do
37 |     # For TorchRec baseline
38 |     set_n_least_used_CUDA_VISIBLE_DEVICES ${GPUNUM}
39 |     # experiment tag used below for the log dir and log file name
40 |     export PLAN=g${GPUNUM}_bs_${BATCHSIZE}_${SHARDTYPE}_pf_${PREFETCH_NUM}_eb_${EMB_DIM}_${KERNELTYPE}
41 |     rm -rf ./tensorboard_log/torchrec_kaggle/w${GPUNUM}_${BATCHSIZE}_${SHARDTYPE}
42 | 
43 |     LOG_DIR=./logs/${KERNELTYPE}_${SHARDTYPE}_logs
44 |     mkdir -p ${LOG_DIR}
45 | 
46 |     torchx run -s local_cwd -cfg log_dir=log/torchrec_kaggle/${PLAN} dist.ddp -j 1x${GPUNUM} --script baselines/dlrm_main.py -- \
47 |         --in_memory_binary_criteo_path ${DATAPATH} --kaggle --embedding_dim ${EMB_DIM} --pin_memory --cache_ratio 0.20 \
48 |         --over_arch_layer_sizes "1024,1024,512,256,1" --dense_arch_layer_sizes "512,256,${EMB_DIM}" --shuffle_batches --eval_acc \
49 |         --learning_rate 1. --batch_size ${BATCHSIZE} --profile_dir "" --shard_type ${SHARDTYPE} --kernel_type ${KERNELTYPE} --prefetch_num ${PREFETCH_NUM} ${EVAL_ACC_FLAG} 2>&1 | tee ./${LOG_DIR}/torchrec_${PLAN}.txt
50 | done
51 | done
52 | done
53 | done
54 | done
55 | done
56 | 
--------------------------------------------------------------------------------
/scripts/torchrec_synth.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 | 
4 | export PYTHONPATH=$HOME/codes/torchrec:$PYTHONPATH
5 | 
6 | set_n_least_used_CUDA_VISIBLE_DEVICES() {
7 |     local n=${1:-"9999"}
8 |     echo "GPU Memory Usage:"
9 |     local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
10 |         | tail -n +2 \
11 |         | nl -v 0 \
12 |         | tee /dev/tty \
13 |         | sort -g -k 2 \
14 |         | awk '{print $1}' \
15 |         | head -n $n)
16 |     export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
17 |     echo "Now CUDA_VISIBLE_DEVICES is set to:"
18 |     echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
19 | }
20 | 
21 | export DATAPATH=/data/scratch/RecSys/embedding_bag
22 | # export DATAPATH=custom
23 | export SCALE="512M"    # "4M" "52M" "512M"
24 | export EVAL_ACC=0
25 | export EMB_DIM=128
26 | 
27 | if [[ ${EVAL_ACC} == 1 ]]; then
28 |     EVAL_ACC_FLAG="--eval_acc"
29 | else
30 |     export EVAL_ACC_FLAG=""
31 | fi
32 | 
33 | # local batch size
34 | # 4
35 | # export BATCHSIZE=1024
36 | # export BATCHSIZE=1024
37 | 
38 | # export BATCHSIZE=4096
39 | # 2
40 | # export BATCHSIZE=8192
41 | # 1
42 | 
43 | mkdir -p logs
44 | for PREFETCH_NUM in 1 32 4 8 16    # 8 16 32
45 | do
46 | for GPUNUM in 1    # 1 # 2
47 | do
48 | for BATCHSIZE in 256    # 2048 4096 1024 # 8192 512 # 16384 8192 4096 2048 1024 512
49 | do
50 | for SHARDTYPE in "table"
51 | do
52 | for KERNELTYPE in "colossalai"    # "uvm_lfu" # "colossalai" # "uvm_lfu" # "colossalai"
53 | do
54 |     # For TorchRec baseline
55 |     set_n_least_used_CUDA_VISIBLE_DEVICES ${GPUNUM}
56 |     export PLAN=g${GPUNUM}_bs_${BATCHSIZE}_${SHARDTYPE}_pf_${PREFETCH_NUM}_eb_${EMB_DIM}
57 |     rm -rf ./tensorboard_log/torchrec_synth/
58 |     # env CUDA_LAUNCH_BLOCKING=1
59 |     # timeout -s SIGKILL 30m
60 |     torchx run -s local_cwd -cfg log_dir=log/torchrec_synth/${PLAN} dist.ddp -j 1x${GPUNUM} --script baselines/dlrm_main.py -- \
61 |         --in_memory_binary_criteo_path ${DATAPATH} --kaggle --embedding_dim ${EMB_DIM} --pin_memory \
62 |         --over_arch_layer_sizes "1024,1024,512,256,1" --dense_arch_layer_sizes "512,256,${EMB_DIM}" --shuffle_batches \
63 |         --learning_rate 1. --batch_size ${BATCHSIZE} --profile_dir "" --shard_type ${SHARDTYPE} --kernel_type ${KERNELTYPE} \
64 |         --synth_size ${SCALE} --prefetch_num ${PREFETCH_NUM} ${EVAL_ACC_FLAG} 2>&1 | tee logs/torchrec_${PLAN}.txt
65 | done
66 | done
67 | done
68 | done
69 | done
70 | 
--------------------------------------------------------------------------------
/scripts/torchrec_terabyte.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # set -xsv
3 | 
4 | export DATASETPATH=/data/scratch/RecSys/criteo_preproc/
5 | export PYTHONPATH=$HOME/codes/torchrec:$PYTHONPATH
6 | 
7 | set_n_least_used_CUDA_VISIBLE_DEVICES() {
8 |     local n=${1:-"9999"}
9 |     echo "GPU Memory Usage:"
10 |     local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
11 |         | tail -n +2 \
12 |         | nl -v 0 \
13 |         | tee /dev/tty \
14 |         | sort -g -k 2 \
15 |         | awk '{print $1}' \
16 |         | head -n $n)
17 |     export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
18 |     echo "Now CUDA_VISIBLE_DEVICES is set to:"
19 |     echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
20 | }
21 | 
22 | export LOG_DIR="/data2/users/lcfjr/logs_1tb/b"
23 | mkdir -p ${LOG_DIR}
24 | 
25 | export GPUNUM=2
26 | 
27 | # prefetch mini-batch number.
28 | export PREFETCH_NUM=4
29 | export EVAL_ACC=0
30 | export KERNELTYPE="fused"
31 | export SHARDTYPE="table"
32 | export BATCHSIZE=1024
33 | export EMB_DIM=128
34 | export CACHERATIO=0.1
35 | 
36 | if [[ ${EVAL_ACC} == 1 ]]; then
37 |     EVAL_ACC_FLAG="--eval_acc"
38 | else
39 |     export EVAL_ACC_FLAG=""
40 | fi
41 | 
42 | 
43 | batch_size_list=(8192)
44 | gpu_num_list=(1)
45 | 
46 | 
47 | for ((i=0;i<${#batch_size_list[@]};i++)); do
48 | 
49 | for KERNELTYPE in "colossalai"
50 | do
51 | for CACHERATIO in 0.05
52 | do
53 | 
54 |     export BATCHSIZE=${batch_size_list[i]}
55 |     export GPUNUM=${gpu_num_list[i]}
56 | 
57 |     set_n_least_used_CUDA_VISIBLE_DEVICES ${GPUNUM}
58 | 
59 |     export PLAN=k_${KERNELTYPE}_g_${GPUNUM}_bs_${BATCHSIZE}_sd_${SHARDTYPE}_pf_${PREFETCH_NUM}_eb_${EMB_DIM}_cache_${CACHERATIO}
60 |     echo "training batchsize" ${BATCHSIZE} "gpunum" ${GPUNUM}
61 | 
62 |     torchx run -s local_cwd -cfg log_dir=log/torchrec_terabyte/w1_16k dist.ddp -j 1x${GPUNUM} --script baselines/dlrm_main.py -- \
63 |         --in_memory_binary_criteo_path ${DATASETPATH} --embedding_dim ${EMB_DIM} --pin_memory \
64 |         --over_arch_layer_sizes "1024,1024,512,256,1" --dense_arch_layer_sizes "512,256,128" --shuffle_batches \
65 |         --learning_rate 1. --batch_size ${BATCHSIZE} --shard_type ${SHARDTYPE} --kernel_type ${KERNELTYPE} --prefetch_num ${PREFETCH_NUM} ${EVAL_ACC_FLAG} \
66 |         --profile_dir "" --limit_train_samples 102400000 --cache_ratio ${CACHERATIO} 2>&1 | tee ${LOG_DIR}/torchrec_${PLAN}.txt
67 | 
68 | done
69 | done
70 | done
--------------------------------------------------------------------------------