├── .clang-format
├── .flake8
├── .gitignore
├── .pre-commit-config.yaml
├── .style.yapf
├── README.md
├── baselines
│   ├── README.md
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── avazu.py
│   │   ├── custom.py
│   │   ├── dlrm_dataloader.py
│   │   └── synth.py
│   ├── dlrm_main.py
│   └── models
│       ├── __init__.py
│       ├── deepfm.py
│       └── dlrm.py
├── benchmark
│   ├── benchmark_cache.py
│   ├── benchmark_fbgemm_uvm.py
│   └── data_utils.py
├── docker
│   ├── Dockerfile
│   ├── Dockerfile_thu
│   └── launch.sh
├── license
├── pics
│   └── prefetch.png
├── recsys
│   ├── README.md
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── avazu.py
│   │   ├── criteo.py
│   │   ├── feature_counter.py
│   │   └── utils.py
│   ├── dlrm_main.py
│   ├── models
│   │   ├── __init__.py
│   │   └── dlrm.py
│   └── utils
│       ├── __init__.py
│       ├── dataloader
│       │   ├── __init__.py
│       │   ├── base_dataiter.py
│       │   └── cuda_stream_dataloader.py
│       ├── misc.py
│       └── preprocess_synth.py
├── requirements.txt
└── scripts
    ├── avazu.sh
    ├── kaggle.sh
    ├── preprocess
    │   ├── .gitignore
    │   ├── README.md
    │   ├── npy_preproc_avazu.py
    │   ├── npy_preproc_criteo.py
    │   ├── split_criteo_kaggle.py
    │   └── taobao
    │       ├── README.md
    │       ├── csv_to_txt.py
    │       ├── run_txt_to_npz.sh
    │       └── txt_to_npz.py
    ├── run.sh
    ├── terabyte.sh
    ├── torchrec_avazu.sh
    ├── torchrec_custom.sh
    ├── torchrec_kaggle.sh
    ├── torchrec_synth.sh
    └── torchrec_terabyte.sh
/.clang-format:
--------------------------------------------------------------------------------
1 | BasedOnStyle: Google
2 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore =
3 | ;W503 line break before binary operator
4 | W503,
5 | ;E203 whitespace before ':'
6 | E203,
7 |
8 | ; exclude file
9 | exclude =
10 | .tox,
11 | .git,
12 | __pycache__,
13 | build,
14 | dist,
15 | *.pyc,
16 | *.egg-info,
17 | .cache,
18 | .eggs
19 |
20 | max-line-length = 120
21 |
22 | per-file-ignores = __init__.py:F401
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | *__pycache__
3 | criteo_kaggle/
4 |
5 | # Distribution / packaging
6 | build/
7 | dist/
8 | *.egg-info/
9 |
10 | # Logging
11 | *log
12 | *wandb
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/mirrors-yapf
3 | rev: v0.32.0
4 | hooks:
5 | - id: yapf
6 | args: ['--style=.style.yapf', '--parallel', '--in-place']
7 | - repo: https://github.com/pre-commit/mirrors-clang-format
8 | rev: v13.0.1
9 | hooks:
10 | - id: clang-format
11 |
--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
1 | [style]
2 | based_on_style = google
3 | spaces_before_comment = 4
4 | split_before_logical_operator = true
5 | column_limit = 120
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## CachedEmbedding: larger embedding tables, smaller GPU memory budget
2 |
3 | The embedding tables in deep learning recommendation system models are becoming extremely large and can no longer fit in GPU memory.
4 | This project provides an efficient way to train such extremely large recommendation system models.
5 | The entire training runs on the GPU, with parameters updated synchronously.
6 |
7 | This project applies CachedEmbedding, which extends the vanilla
8 | [PyTorch EmbeddingBag](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag)
9 | with the help of [ColossalAI](https://github.com/hpcaitech/ColossalAI).
10 | CachedEmbedding uses a [software cache approach](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.parallel.layers.html) to dynamically manage an extremely large embedding table across the CPU and GPU memory space.
11 | For example, this repo can train a DLRM model with a **91.10 GB** embedding table on the Criteo 1TB dataset while allocating just **3.75 GB** of CUDA memory on a single GPU!
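
As a rough sketch of the idea (not necessarily this repo's exact API; the constructor arguments, `cache_ratio` in particular, are assumptions here), a cached embedding bag behaves like a drop-in `torch.nn.EmbeddingBag` whose full weight lives in CPU memory while only the hot rows occupy a small CUDA buffer:

```
import torch
from colossalai.nn.parallel.layers.cache_embedding import CachedEmbeddingBag

# Sketch only: the real signature may differ. The full table stays in CPU
# memory; roughly cache_ratio of its rows are cached on the GPU.
bag = CachedEmbeddingBag(
    num_embeddings=10_000_000,
    embedding_dim=128,
    cache_ratio=0.01,    # assumed knob, mirroring the 1% cache ratio used below
)

ids = torch.randint(0, 10_000_000, (8192,), device="cuda")
offsets = torch.arange(0, ids.numel(), 2, device="cuda")    # bags of two ids
pooled = bag(ids, offsets)    # misses are filled from the CPU copy on the fly
```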
12 |
13 | To reduce the overhead of the cache, we designed a "far-sighted" cache mechanism.
14 | Instead of performing cache operations only on the current mini-batch, we prefetch several mini-batches that will be used later and perform their cache queries together.
15 | It also uses a pipeline to overlap the overhead of data loading with model training, as shown in the following figure.
16 |
17 | ![prefetch pipeline](./pics/prefetch.png)
18 |
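To see why batching the cache queries helps, here is a tiny self-contained illustration (not code from this repo): ids shared across the prefetched mini-batches only need to be moved from CPU to GPU once.

```
import torch

# Illustrative only: gather the ids of the next k mini-batches and resolve
# the union of their unique ids in one bulk cache query.
torch.manual_seed(0)
batches = [torch.randint(0, 100_000, (8192,)) for _ in range(4)]

per_batch = sum(b.unique().numel() for b in batches)    # 4 separate queries
bulk = torch.cat(batches).unique().numel()              # 1 far-sighted query

print(per_batch, bulk)    # bulk <= per_batch: shared ids are fetched only once
```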
19 | Despite the extra cache indexing and CPU-GPU data movement overhead, the end-to-end performance of our system drops very little compared to torchrec,
20 | while torchrec usually requires an order of magnitude more CUDA memory.
21 | Also, our software cache is implemented in pure PyTorch without any customized C++/CUDA kernels, so developers can customize or optimize it according to their needs.
22 |
23 | ### Dataset
24 | 1. [Criteo Kaggle](https://www.kaggle.com/c/criteo-display-ad-challenge/data)
25 | 2. [Avazu](https://www.kaggle.com/c/avazu-ctr-prediction/data)
26 | 3. [Criteo 1TB](https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/)
27 |
28 | The preprocessing steps are derived from
29 | [Torchrec's utilities](https://github.com/pytorch/torchrec/blob/main/torchrec/datasets/scripts/npy_preproc_criteo.py)
30 | and the [Avazu Kaggle community](https://www.kaggle.com/code/leejunseok97/deepfm-deepctr-torch).
31 | Please refer to the `scripts/preprocess` dir for details.
32 |
33 | ### Usage
34 |
35 | 1. Install dependencies
36 |
37 | Install [ColossalAI](https://github.com/hpcaitech/ColossalAI) (commit id e8d8eda5e7a0619bd779e35065397679e1536dcd)
38 |
39 | https://github.com/hpcaitech/ColossalAI
40 |
41 | Install our customized [torchrec](https://github.com/hpcaitech/torchrec) (commit id e8d8eda5e7a0619bd779e35065397679e1536dcd)
42 |
43 | https://github.com/hpcaitech/torchrec
44 |
45 | Alternatively, build a Docker image using [docker/Dockerfile](./docker/Dockerfile),
46 | or pull the prebuilt image from Docker Hub:
47 |
48 | ```
49 | docker pull hpcaitech/cacheembedding:0.2.2
50 | ```
51 |
52 | Then launch a Docker container:
53 |
54 | ```
55 | bash ./docker/launch.sh
56 | ```
57 |
58 | 2. Run
59 |
60 | All the commands to run DLRM on the three datasets are presented in `scripts/run.sh`:
61 | ```
62 | bash scripts/run.sh
63 | ```
64 |
65 | Set `--prefetch_num` to use prefetching.
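
For instance, a hypothetical single-GPU invocation could look like the following. `--prefetch_num` comes from this repo and the other flags appear in `baselines/data/dlrm_dataloader.py`, but the launcher and the concrete values are assumptions; see `scripts/run.sh` for the real commands.

```
torchrun --nproc_per_node=1 baselines/dlrm_main.py \
    --kaggle \
    --in_memory_binary_criteo_path /path/to/criteo_kaggle \
    --batch_size 16384 \
    --prefetch_num 4
```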
66 |
67 | ### Model
68 | Currently, this repo only contains Facebook's DLRM model; we are working on testing more recommendation models.
69 |
70 | ### Performance
71 |
72 | The DLRM performance on the three datasets using the ColossalAI version (this repo) and torchrec (with UVM) is shown below. The cache ratio of FreqAwareEmbedding is set to 1%. The evaluation was conducted on a single A100 (80 GB memory) and an AMD 7543 32-core CPU (512 GB memory).
73 |
74 | | | method | AUROC over Test after 1 Epoch | Acc over test | Throughput | Time to Train 1 Epoch | GPU memory allocated (GB) | GPU memory reserved (GB) | CPU memory usage (GB) |
75 | |:----------:|:----------:|:-----------------------------:|:-------------:|:----------:|:---------------------:|:-------------------------:|:------------------------:|:---------------------:|
76 | | criteo 1TB | ColossalAI | 0.791299403 | 0.967155457 | 42 it/s | 1h40m | 3.75 | 5.04 | 94.39 |
77 | | | torchrec | 0.79515636 | 0.967177451 | 45 it/s | 1h35m | 66.54 | 68.43 | 7.7 |
78 | | kaggle | ColossalAI | 0.776755869 | 0.779025435 | 50 it/s | 49s | 0.9 | 2.14 | 34.66 |
79 | | | torchrec | 0.786652029 | 0.782288849 | 81 it/s | 30s | 16.13 | 17.99 | 13.89 |
80 | | avazu | ColossalAI | 0.72732079 | 0.824390948 | 72 it/s | 31s | 0.31 | 1.06 | 16.89 |
81 | | | torchrec | 0.725972056 | 0.824484706 | 111 it/s | 21s | 4.53 | 5.83 | 12.25 |
82 |
83 | ### Cite us
84 | ```
85 | @article{fang2022frequency,
86 | title={A Frequency-aware Software Cache for Large Recommendation System Embeddings},
87 | author={Fang, Jiarui and Zhang, Geng and Han, Jiatong and Li, Shenggui and Bian, Zhengda and Li, Yongbin and Liu, Jin and You, Yang},
88 | journal={arXiv preprint arXiv:2208.05321},
89 | year={2022}
90 | }
91 | ```
92 |
--------------------------------------------------------------------------------
/baselines/README.md:
--------------------------------------------------------------------------------
1 | # FreqCacheEmbedding
2 |
3 | This repo contains the implementation of FreqCacheEmbedding, which extends the vanilla
4 | [PyTorch EmbeddingBag](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag)
5 | with cache mechanism to enable heterogeneous training for large scale recommendation models.
6 |
7 | ### Dataset
8 | 1. [Criteo Kaggle](https://www.kaggle.com/c/criteo-display-ad-challenge/data)
9 | 2. [Avazu](https://www.kaggle.com/c/avazu-ctr-prediction/data)
10 |
11 | The preprocessing steps are derived from
12 | [Torchrec's utilities](https://github.com/pytorch/torchrec/blob/main/torchrec/datasets/scripts/npy_preproc_criteo.py)
13 | and the [Avazu Kaggle community](https://www.kaggle.com/code/leejunseok97/deepfm-deepctr-torch).
14 | Please refer to the `recsys/datasets/preprocess_scripts` dir for details.
15 |
16 | At the time this repo was built, another commonly adopted dataset,
17 | [Criteo 1TB](https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/),
18 | was unavailable (see this [issue](https://github.com/pytorch/torchrec/issues/245)).
19 | We will add its preprocessing & running scripts soon.
20 |
21 | ### Command
22 | All the commands to run the FreqCacheEmbedding enabled recommendations models are presented in `run.sh`
23 |
24 | ### Model
25 | Currently, this repo only contains DLRM & DeepFM models,
26 | and we are working on testing more recommendation models.
--------------------------------------------------------------------------------
/baselines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpcaitech/CachedEmbedding/a2af3d7e7b0197519e6d018444688fcd9ba32c43/baselines/__init__.py
--------------------------------------------------------------------------------
/baselines/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpcaitech/CachedEmbedding/a2af3d7e7b0197519e6d018444688fcd9ba32c43/baselines/data/__init__.py
--------------------------------------------------------------------------------
/baselines/data/avazu.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
4 |
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import IterableDataset
8 | from torchrec.datasets.utils import PATH_MANAGER_KEY, Batch
9 | from torchrec.datasets.criteo import BinaryCriteoUtils
10 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
11 |
12 | CAT_FEATURE_COUNT = 13
13 | INT_FEATURE_COUNT = 8
14 | DEFAULT_LABEL_NAME = "click"
15 | DEFAULT_CAT_NAMES = [
16 | 'C1',
17 | 'banner_pos',
18 | 'site_id',
19 | 'site_domain',
20 | 'site_category',
21 | 'app_id',
22 | 'app_domain',
23 | 'app_category',
24 | 'device_id',
25 | 'device_ip',
26 | 'device_model',
27 | 'device_type',
28 | 'device_conn_type',
29 | ]
30 | DEFAULT_INT_NAMES = ['C14', 'C15', 'C16', 'C17', 'C18', 'C19', 'C20', 'C21']
31 | NUM_EMBEDDINGS_PER_FEATURE = '7,7,4737,7745,26,8552,559,36,2686408,6729486,8251,5,4' # 9445823 in total
32 | TOTAL_TRAINING_SAMPLES = 36_386_071 # 90% sample in train, 40428967 in total
33 |
34 |
35 | class AvazuIterDataPipe(IterableDataset):
36 |
37 | def __init__(
38 | self,
39 | dense_paths: List[str],
40 | sparse_paths: List[str],
41 | labels_paths: List[str],
42 | batch_size: int,
43 | rank: int,
44 | world_size: int,
45 | shuffle_batches: bool = False,
46 | mmap_mode: bool = False,
47 | hashes: Optional[List[int]] = None,
48 | path_manager_key: str = PATH_MANAGER_KEY,
49 | ) -> None:
50 | self.dense_paths = dense_paths
51 | self.sparse_paths = sparse_paths
52 | self.labels_paths = labels_paths
53 | self.batch_size = batch_size
54 | self.rank = rank
55 | self.world_size = world_size
56 | self.shuffle_batches = shuffle_batches
57 | self.mmap_mode = mmap_mode
58 | self.hashes = hashes
59 | self.path_manager_key = path_manager_key
60 |
61 | self._load_data_for_rank()
62 | self.num_rows_per_file: List[int] = [a.shape[0] for a in self.dense_arrs]
63 | self.num_batches: int = sum(self.num_rows_per_file) // batch_size
64 |
65 | # These values are the same for the KeyedJaggedTensors in all batches, so they
66 | # are computed once here. This avoids extra work from the KeyedJaggedTensor sync
67 | # functions.
68 | self._num_ids_in_batch: int = CAT_FEATURE_COUNT * batch_size
69 | self.keys: List[str] = DEFAULT_CAT_NAMES
70 | self.lengths: torch.Tensor = torch.ones((self._num_ids_in_batch,), dtype=torch.int32)
71 | self.offsets: torch.Tensor = torch.arange(0, self._num_ids_in_batch + 1, dtype=torch.int32)
72 | self.stride = batch_size
73 | self.length_per_key: List[int] = CAT_FEATURE_COUNT * [batch_size]
74 | self.offset_per_key: List[int] = [batch_size * i for i in range(CAT_FEATURE_COUNT + 1)]
75 | self.index_per_key: Dict[str, int] = {key: i for (i, key) in enumerate(self.keys)}
76 |
77 | def _load_data_for_rank(self) -> None:
78 | file_idx_to_row_range = BinaryCriteoUtils.get_file_idx_to_row_range(
79 | lengths=[
80 | BinaryCriteoUtils.get_shape_from_npy(path, path_manager_key=self.path_manager_key)[0]
81 | for path in self.dense_paths
82 | ],
83 | rank=self.rank,
84 | world_size=self.world_size,
85 | )
86 |
87 | self.dense_arrs, self.sparse_arrs, self.labels_arrs = [], [], []
88 | for _dtype, arrs, paths in zip(
89 | [np.float32, np.int64, np.int32],
90 | [self.dense_arrs, self.sparse_arrs, self.labels_arrs],
91 | [self.dense_paths, self.sparse_paths, self.labels_paths],
92 | ):
93 | for idx, (range_left, range_right) in file_idx_to_row_range.items():
94 | arrs.append(
95 | BinaryCriteoUtils.load_npy_range(
96 | paths[idx],
97 | range_left,
98 | range_right - range_left + 1,
99 | path_manager_key=self.path_manager_key,
100 | mmap_mode=self.mmap_mode,
101 | ).astype(_dtype))
102 |
103 | # When mmap_mode is enabled, the hash is applied in def __iter__, which is
104 | # where samples are batched during training.
105 | # Otherwise, the ML dataset is preloaded, and the hash is applied here in
106 | # the preload stage, as shown:
107 | if not self.mmap_mode and self.hashes is not None:
108 | hashes_np = np.array(self.hashes).reshape((1, CAT_FEATURE_COUNT))
109 | for sparse_arr in self.sparse_arrs:
110 | sparse_arr %= hashes_np
111 |
112 | def _np_arrays_to_batch(self, dense: np.ndarray, sparse: np.ndarray, labels: np.ndarray) -> Batch:
113 | if self.shuffle_batches:
114 | # Shuffle all 3 in unison
115 | shuffler = np.random.permutation(len(dense))
116 | dense = dense[shuffler]
117 | sparse = sparse[shuffler]
118 | labels = labels[shuffler]
119 |
120 | return Batch(
121 | dense_features=torch.from_numpy(dense),
122 | sparse_features=KeyedJaggedTensor(
123 | keys=self.keys,
124 | # transpose + reshape(-1) incurs an additional copy.
125 | values=torch.from_numpy(sparse.transpose(1, 0).reshape(-1)),
126 | lengths=self.lengths,
127 | offsets=self.offsets,
128 | stride=self.stride,
129 | length_per_key=self.length_per_key,
130 | offset_per_key=self.offset_per_key,
131 | index_per_key=self.index_per_key,
132 | ),
133 | labels=torch.from_numpy(labels.reshape(-1)),
134 | )
135 |
136 | def __iter__(self) -> Iterator[Batch]:
137 | # Invariant: buffer never contains more than batch_size rows.
138 | buffer: Optional[List[np.ndarray]] = None
139 |
140 | def append_to_buffer(dense: np.ndarray, sparse: np.ndarray, labels: np.ndarray) -> None:
141 | nonlocal buffer
142 | if buffer is None:
143 | buffer = [dense, sparse, labels]
144 | else:
145 | for idx, arr in enumerate([dense, sparse, labels]):
146 | buffer[idx] = np.concatenate((buffer[idx], arr))
147 |
148 | # Maintain a buffer that can contain up to batch_size rows. Fill buffer as
149 | # much as possible on each iteration. Only return a new batch when batch_size
150 | # rows are filled.
151 | file_idx = 0
152 | row_idx = 0
153 | batch_idx = 0
154 | while batch_idx < self.num_batches:
155 | buffer_row_count = 0 if buffer is None else buffer[0].shape[0]
156 | if buffer_row_count == self.batch_size:
157 | yield self._np_arrays_to_batch(*buffer)
158 | batch_idx += 1
159 | buffer = None
160 | else:
161 | rows_to_get = min(
162 | self.batch_size - buffer_row_count,
163 | self.num_rows_per_file[file_idx] - row_idx,
164 | )
165 | slice_ = slice(row_idx, row_idx + rows_to_get)
166 |
167 | dense_inputs = self.dense_arrs[file_idx][slice_, :]
168 | sparse_inputs = self.sparse_arrs[file_idx][slice_, :]
169 | target_labels = self.labels_arrs[file_idx][slice_, :]
170 |
171 | if self.mmap_mode and self.hashes is not None:
172 | sparse_inputs = sparse_inputs % np.array(self.hashes).reshape((1, CAT_FEATURE_COUNT))
173 |
174 | append_to_buffer(
175 | dense_inputs,
176 | sparse_inputs,
177 | target_labels,
178 | )
179 | row_idx += rows_to_get
180 |
181 | if row_idx >= self.num_rows_per_file[file_idx]:
182 | file_idx += 1
183 | row_idx = 0
184 |
185 | def __len__(self) -> int:
186 | return self.num_batches
187 |
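# Layout note: the KeyedJaggedTensor built in _np_arrays_to_batch above
# expects its values grouped feature-major -- all ids of the first feature
# across the batch, then the second, and so on -- which is why the (B, F)
# sparse array is transposed before flattening:
#
#   sparse = np.array([[10, 20, 30],      # B=2 samples, F=3 features
#                      [11, 21, 31]])
#   sparse.transpose(1, 0).reshape(-1)    # -> [10, 11, 20, 21, 30, 31]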
--------------------------------------------------------------------------------
/baselines/data/custom.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from torch import distributed as dist
4 | import numpy as np
5 | from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
6 | from colossalai.nn.parallel.layers.cache_embedding import CachedEmbeddingBag
7 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
8 | from fbgemm_gpu.split_table_batched_embeddings_ops import SplitTableBatchedEmbeddingBagsCodegen, EmbeddingLocation, ComputeDevice, CacheAlgorithm
9 | import time
10 | import numpy as np
11 | import torch
12 | from torch.utils.data import IterableDataset
13 | from torchrec.datasets.utils import PATH_MANAGER_KEY, Batch
14 | from torchrec.datasets.criteo import BinaryCriteoUtils
15 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
16 | from torch.utils.data import DataLoader, IterableDataset
17 | import os
18 |
19 | # customizable factors:
20 | NUM_ROWS = int(2**25) # samples num
21 | E = [int(3e7), int(1e7), int(2e7), int(1e7), int(1e7), int(3e6), int(8e6), int(
22 | 1e7), int(1e6), int(1e6), int(1e6), int(1e6), int(5e6), int(4000), int(250), int(250)] # unique embeddings num
23 | s = 0.25 # long-tail skew
24 | POOLING_FACTOR = 2 # indices num per table per sample
25 |
26 |
27 | CAT_FEATURE_COUNT = len(E)
28 | NUM_EMBEDDINGS_PER_FEATURE = ""
29 | for e in E:
30 | NUM_EMBEDDINGS_PER_FEATURE += (str(e) + ',')
31 | NUM_EMBEDDINGS_PER_FEATURE = NUM_EMBEDDINGS_PER_FEATURE[:-1]
32 | DEFAULT_CAT_NAMES = ["cat_{}".format(i) for i in range(CAT_FEATURE_COUNT)]
33 | DEFAULT_LABEL_NAME = "click"
34 | DEFAULT_INT_NAMES = ['rand_dense']
35 | TOTAL_TRAINING_SAMPLES = NUM_ROWS
36 |
37 |
38 | def update_settings():
39 | global CAT_FEATURE_COUNT, DEFAULT_CAT_NAMES, NUM_EMBEDDINGS_PER_FEATURE
40 | CAT_FEATURE_COUNT = len(E)
41 | DEFAULT_CAT_NAMES = ["cat_{}".format(i) for i in range(CAT_FEATURE_COUNT)]
42 | NUM_EMBEDDINGS_PER_FEATURE = ""
43 | for e in E:
44 | NUM_EMBEDDINGS_PER_FEATURE += (str(e) + ',')
45 | NUM_EMBEDDINGS_PER_FEATURE = NUM_EMBEDDINGS_PER_FEATURE[:-1]
46 |
47 |
48 | class CustomIterDataPipe(IterableDataset):
49 | def __init__(
50 | self,
51 | batch_size: int,
52 | rank: int,
53 | world_size: int,
54 | shuffle_batches: bool = False,
55 | mmap_mode: bool = False,
56 | hashes: Optional[List[int]] = None,
57 | path_manager_key: str = PATH_MANAGER_KEY,
58 | ) -> None:
59 | self.batch_size = batch_size
60 | self.rank = rank
61 | self.world_size = world_size
62 | self.shuffle_batches = shuffle_batches
63 | self.mmap_mode = mmap_mode
64 | self.hashes = hashes
65 | self.path_manager_key = path_manager_key
66 | self.keys: List[str] = DEFAULT_CAT_NAMES
67 | self._num_ids_group_in_batch: int = CAT_FEATURE_COUNT * batch_size
68 | self.lengths: torch.Tensor = torch.ones((self._num_ids_group_in_batch,), dtype=torch.int32) * POOLING_FACTOR
69 | self.offsets: torch.Tensor = torch.arange(0, self._num_ids_group_in_batch + 1, dtype=torch.int32) * POOLING_FACTOR
70 | self.num_batches = NUM_ROWS // self.batch_size // self.world_size
71 | self.length_per_key: List[int] = CAT_FEATURE_COUNT * [batch_size * POOLING_FACTOR]
72 | self.offset_per_key: List[int] = [batch_size * i * POOLING_FACTOR for i in range(CAT_FEATURE_COUNT + 1)]
73 | self.index_per_key: Dict[str, int] = {key: i for (i, key) in enumerate(self.keys)}
74 | self.stride = batch_size
75 | # for random generation
76 | self.min_sample_list = [(1 / e) ** s for e in E]
77 | self.max_sample_list = [1.0 for e in E]
78 | # for iter
79 | self.iter_count = 0
80 |
81 | def __len__(self) -> int:
82 | return self.num_batches
83 |
84 | def __iter__(self) -> Iterator[Batch]:
85 | while self.iter_count < self.num_batches:
86 | indices = []
87 | for min_sample, max_sample in zip(self.min_sample_list, self.max_sample_list):
88 | # long-tail random idx generation
89 | rand_float = torch.rand(self.batch_size * POOLING_FACTOR, dtype=torch.float64)
90 | sample_float = rand_float * (max_sample - min_sample) + min_sample
91 | indices.append(torch.floor(1 / (sample_float ** (1 / s))).long() - 1)
92 | self.iter_count += 1
93 | yield self._make_batch(torch.cat(indices))
94 |
95 | def _make_batch(self, indices):
96 | ret = Batch(
97 | dense_features=torch.rand(self.stride, 1),
98 | sparse_features=KeyedJaggedTensor(
99 | keys=self.keys,
100 | values=indices,
101 | lengths=self.lengths,
102 | offsets=self.offsets,
103 | stride=self.stride,
104 | length_per_key=self.length_per_key,
105 | offset_per_key=self.offset_per_key,
106 | index_per_key=self.index_per_key,
107 | ),
108 | labels=torch.randint(2, (self.stride,))
109 | )
110 | return ret
111 |
112 | def get_custom_data_loader(
113 | args: argparse.Namespace,
114 | stage: str) -> DataLoader:
115 | rank = dist.get_rank()
116 | world_size = dist.get_world_size()
117 | if stage == "train":
118 | dataloader = DataLoader(
119 | CustomIterDataPipe(
120 | batch_size=args.batch_size,
121 | rank=rank,
122 | world_size=world_size,
123 | shuffle_batches=args.shuffle_batches,
124 | hashes=None
125 | ),
126 | batch_size=None,
127 | pin_memory=args.pin_memory,
128 | collate_fn=lambda x: x,
129 | )
130 |     else:
131 | dataloader = []
132 | return dataloader
133 |
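# Sampling note: in CustomIterDataPipe.__iter__ above, u is drawn uniformly
# from [(1/e)**s, 1] and mapped through floor(u**(-1/s)) - 1, i.e. inverse-CDF
# sampling from a long-tailed (Zipf-like) distribution over [0, e). A
# standalone sanity check of that mapping:
#
#   e, s = 10_000, 0.25
#   u = torch.rand(1_000_000, dtype=torch.float64) * (1 - (1/e)**s) + (1/e)**s
#   idx = torch.floor(1 / u ** (1 / s)).long() - 1
#   assert 0 <= idx.min() and idx.max() < e   # ids stay in range
#   (idx < 10).float().mean()                 # most mass on the smallest ids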
--------------------------------------------------------------------------------
/baselines/data/dlrm_dataloader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import argparse
9 | import os
10 | from typing import List, Optional, Tuple, Dict
11 | import glob
12 | import numpy as np
13 |
14 | import torch
15 | from torch import distributed as dist
16 | from torch.utils.data import DataLoader, IterableDataset
17 | from torchrec.datasets.criteo import (
18 | CAT_FEATURE_COUNT,
19 | DEFAULT_CAT_NAMES,
20 | DEFAULT_INT_NAMES,
21 | DEFAULT_LABEL_NAME,
22 | DAYS,
23 | InMemoryBinaryCriteoIterDataPipe,
24 | )
25 | from torchrec.datasets.random import RandomRecDataset
26 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
27 | from torchrec.datasets.utils import Batch
28 | from petastorm import make_batch_reader
29 | from pyarrow.parquet import ParquetDataset
30 | from .avazu import AvazuIterDataPipe
31 | from .synth import get_synth_data_loader
32 | from .custom import get_custom_data_loader
33 | STAGES = ["train", "val", "test"]
34 | KAGGLE_NUM_EMBEDDINGS_PER_FEATURE = '1460,583,10131227,2202608,305,24,12517,633,3,93145,5683,8351593,3194,27,14992,' \
35 | '5461306,10,5652,2173,4,7046547,18,15,286181,105,142572' # For criteo kaggle
36 | KAGGLE_TOTAL_TRAINING_SAMPLES = 39_291_954 # 0-6 days for criteo kaggle, 45,840,617 samples in total
37 | TERABYTE_NUM_EMBEDDINGS_PER_FEATURE = "45833188,36746,17245,7413,20243,3,7114,1441,62,29275261,1572176,345138,10," \
38 | "2209,11267,128,4,974,14,48937457,11316796,40094537,452104,12606,104,35"
39 |
40 |
41 | def _get_random_dataloader(args: argparse.Namespace,) -> DataLoader:
42 | return DataLoader(
43 | RandomRecDataset(
44 | keys=DEFAULT_CAT_NAMES,
45 | batch_size=args.batch_size,
46 | hash_size=args.num_embeddings,
47 | hash_sizes=args.num_embeddings_per_feature if hasattr(args, "num_embeddings_per_feature") else None,
48 | manual_seed=args.seed if hasattr(args, "seed") else None,
49 | ids_per_feature=1,
50 | num_dense=len(DEFAULT_INT_NAMES),
51 | ),
52 | batch_size=None,
53 | batch_sampler=None,
54 | pin_memory=args.pin_memory,
55 | num_workers=0,
56 | )
57 |
58 |
59 | def _get_in_memory_dataloader(
60 | args: argparse.Namespace,
61 | stage: str,
62 | ) -> DataLoader:
63 | files = os.listdir(args.in_memory_binary_criteo_path)
64 |
65 | def is_final_day(s: str) -> bool:
66 | return f"day_{(7 if args.kaggle else DAYS) - 1}" in s
67 |
68 | if stage == "train":
69 | # Train set gets all data except from the final day.
70 | files = list(filter(lambda s: not is_final_day(s), files))
71 | rank = dist.get_rank()
72 | world_size = dist.get_world_size()
73 | else:
74 | # Validation set gets the first half of the final day's samples. Test set get
75 | # the other half.
76 | files = list(filter(is_final_day, files))
77 | rank = (dist.get_rank() if stage == "val" else dist.get_rank() + dist.get_world_size())
78 | world_size = dist.get_world_size() * 2
79 |
80 | stage_files: List[List[str]] = [
81 | sorted(map(
82 | lambda x: os.path.join(args.in_memory_binary_criteo_path, x),
83 | filter(lambda s: kind in s, files),
84 | )) for kind in ["dense", "sparse", "labels"]
85 | ]
86 | dataloader = DataLoader(
87 | InMemoryBinaryCriteoIterDataPipe(
88 | *stage_files, # pyre-ignore[6]
89 | batch_size=args.batch_size,
90 | rank=rank,
91 | world_size=world_size,
92 | shuffle_batches=args.shuffle_batches,
93 | hashes=args.num_embeddings_per_feature if args.num_embeddings is None else
94 | ([args.num_embeddings] * CAT_FEATURE_COUNT),
95 | ),
96 | batch_size=None,
97 | pin_memory=args.pin_memory,
98 | collate_fn=lambda x: x,
99 | )
100 | return dataloader
101 |
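# Sharding note: for val/test above, both stages read the same final-day
# files, but the world is virtually doubled. With 2 real ranks, val maps to
# shards {0, 1} and test to shards {2, 3} of a virtual world of size 4, so
# val covers the first half of the final day's rows and test the second:
#
#   real_world = 2
#   for stage in ("val", "test"):
#       for real_rank in range(real_world):
#           rank = real_rank if stage == "val" else real_rank + real_world
#           print(stage, real_rank, "-> shard", rank, "of", real_world * 2)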
102 |
103 | def get_avazu_data_loader(args, stage):
104 | files = os.listdir(args.in_memory_binary_criteo_path)
105 |
106 | if stage == "train":
107 | files = list(filter(lambda s: "train" in s, files))
108 | rank = dist.get_rank()
109 | world_size = dist.get_world_size()
110 | else:
111 | # Validation set gets the first half of the final day's samples. Test set get
112 | # the other half.
113 | files = list(filter(lambda s: "train" not in s, files))
114 | rank = (dist.get_rank() if stage == "val" else dist.get_rank() + dist.get_world_size())
115 | world_size = dist.get_world_size() * 2
116 |
117 | stage_files: List[List[str]] = [
118 | sorted(map(
119 | lambda x: os.path.join(args.in_memory_binary_criteo_path, x),
120 | filter(lambda s: kind in s, files),
121 | )) for kind in ["dense", "sparse", "label"]
122 | ]
123 |
124 | dataloader = DataLoader(
125 | AvazuIterDataPipe(
126 | *stage_files, # pyre-ignore[6]
127 | batch_size=args.batch_size,
128 | rank=rank,
129 | world_size=world_size,
130 | shuffle_batches=args.shuffle_batches,
131 | hashes=args.num_embeddings_per_feature if args.num_embeddings is None else
132 | ([args.num_embeddings] * CAT_FEATURE_COUNT),
133 | ),
134 | batch_size=None,
135 | pin_memory=args.pin_memory,
136 | collate_fn=lambda x: x,
137 | )
138 | return dataloader
139 |
140 |
141 | class PetastormDataReader(IterableDataset):
142 | """
143 |     This is a compromise solution for the Criteo Terabyte dataset;
144 |     see solution 3 in: https://github.com/uber/petastorm/issues/508
145 |
146 |     Basically, in the training stage the dataloader in each rank extracts random samples
147 |     from the whole dataset, so the batches in each rank are not guaranteed to be unique.
148 |     In the validation stage, all the samples are evaluated in each rank,
149 |     so that each rank produces the correct result.
150 | """
151 |
152 | def __init__(self,
153 | paths,
154 | batch_size,
155 | rank=None,
156 | world_size=None,
157 | shuffle_batches=False,
158 | hashes=None,
159 | seed=1024,
160 | drop_last=True):
161 | self.dataset = ParquetDataset(paths, use_legacy_dataset=False)
162 | self.batch_size = batch_size
163 | self.rank = rank
164 | self.world_size = world_size
165 | self.shuffle_batches = shuffle_batches
166 | self.hashes = np.array(hashes).reshape((1, CAT_FEATURE_COUNT)) if hashes is not None else None
167 |
168 | self._num_ids_in_batch: int = CAT_FEATURE_COUNT * batch_size
169 | self.keys: List[str] = DEFAULT_CAT_NAMES
170 | self.lengths: torch.Tensor = torch.ones((self._num_ids_in_batch,), dtype=torch.int32)
171 | self.offsets: torch.Tensor = torch.arange(0, self._num_ids_in_batch + 1, dtype=torch.int32)
172 | self.stride = batch_size
173 | self.length_per_key: List[int] = CAT_FEATURE_COUNT * [batch_size]
174 | self.offset_per_key: List[int] = [batch_size * i for i in range(CAT_FEATURE_COUNT + 1)]
175 | self.index_per_key: Dict[str, int] = {key: i for (i, key) in enumerate(self.keys)}
176 | self.seed = seed
177 |
178 | self.drop_last = drop_last
179 | if drop_last:
180 | self.num_batches = sum([fragment.metadata.num_rows for fragment in self.dataset.fragments
181 | ]) // self.batch_size
182 | else:
183 | self.num_batches = (sum([fragment.metadata.num_rows
184 | for fragment in self.dataset.fragments]) + self.batch_size - 1) // self.batch_size
185 | if self.world_size is not None:
186 | self.num_batches = self.num_batches // world_size
187 |
188 | def __iter__(self):
189 | buffer: Optional[List[np.ndarray]] = None
190 | count = 0
191 |
192 | def append_to_buffer(_dense: np.ndarray, _sparse: np.ndarray, _labels: np.ndarray) -> None:
193 | nonlocal buffer
194 | if buffer is None:
195 | buffer = [_dense, _sparse, _labels]
196 | else:
197 | buffer[0] = np.concatenate([buffer[0], _dense], axis=0)
198 | buffer[1] = np.concatenate([buffer[1], _sparse], axis=1)
199 | buffer[2] = np.concatenate([buffer[2], _labels], axis=0)
200 |
201 | with make_batch_reader(
202 | list(map(lambda x: "file://" + x, self.dataset.files)),
203 | num_epochs=1,
204 | workers_count=1, # for reproducibility
205 | ) as reader:
206 | # note that `batch` here is just a bunch of samples read by petastorm instead of `batch` consumed by models
207 | for batch in reader:
208 | labels = getattr(batch, DEFAULT_LABEL_NAME)
209 | sparse = np.concatenate([getattr(batch, col_name).reshape(1, -1) for col_name in DEFAULT_CAT_NAMES],
210 | axis=0)
211 | dense = np.concatenate([getattr(batch, col_name).reshape(-1, 1) for col_name in DEFAULT_INT_NAMES],
212 | axis=1)
213 | start_idx = 0
214 | while start_idx < dense.shape[0]:
215 | buffer_size = 0 if buffer is None else buffer[0].shape[0]
216 | if buffer_size == self.batch_size:
217 | yield self._batch_ndarray(*buffer)
218 | buffer = None
219 | count += 1
220 | if count == self.num_batches:
221 |                             return    # PEP 479: raising StopIteration in a generator becomes RuntimeError
222 | else:
223 | rows_to_get = min(self.batch_size - buffer_size, dense.shape[0] - start_idx)
224 | label_chunk = labels[start_idx:start_idx + rows_to_get]
225 | sparse_chunk = sparse[:, start_idx:start_idx + rows_to_get]
226 | dense_chunk = dense[start_idx:start_idx + rows_to_get, :]
227 | append_to_buffer(dense_chunk, sparse_chunk, label_chunk)
228 | start_idx += rows_to_get
229 | if buffer is not None and not self.drop_last:
230 | yield self._batch_ndarray(*buffer)
231 |
232 | def _batch_ndarray(self, dense: np.ndarray, sparse: np.ndarray, labels: np.ndarray):
233 | if self.shuffle_batches:
234 | # Shuffle all 3 in unison
235 | shuffler = np.random.permutation(len(dense))
236 | dense = dense[shuffler]
237 | sparse = sparse[:, shuffler]
238 | labels = labels[shuffler]
239 |
240 | return Batch(
241 | dense_features=torch.from_numpy(dense),
242 | sparse_features=KeyedJaggedTensor(
243 | keys=self.keys,
244 | values=torch.from_numpy(sparse.reshape(-1)),
245 | lengths=self.lengths,
246 | offsets=self.offsets,
247 | stride=self.stride,
248 | length_per_key=self.length_per_key,
249 | offset_per_key=self.offset_per_key,
250 | index_per_key=self.index_per_key,
251 | ),
252 | labels=torch.from_numpy(labels.reshape(-1)),
253 | )
254 |
255 | def __len__(self):
256 | return self.num_batches
257 |
258 |
259 | def _get_petastorm_dataloader(args, stage):
260 | if stage == "train":
261 | data_split = "train"
262 | elif stage == "val":
263 | data_split = "validation"
264 | else:
265 | data_split = "test"
266 |
267 | file_num = len(glob.glob(os.path.join(args.in_memory_binary_criteo_path, data_split, "*.parquet")))
268 | files = [os.path.join(args.in_memory_binary_criteo_path, data_split, f"part_{i}.parquet") for i in range(file_num)]
269 |
270 | dataloader = DataLoader(PetastormDataReader(files,
271 | args.batch_size,
272 | rank=dist.get_rank() if stage == "train" else None,
273 | world_size=dist.get_world_size() if stage == "train" else None,
274 | hashes=args.num_embeddings_per_feature),
275 | batch_size=None,
276 | pin_memory=False,
277 | collate_fn=lambda x: x,
278 | num_workers=0)
279 |
280 | return dataloader
281 |
282 |
283 | def get_dataloader(args: argparse.Namespace, backend: str, stage: str) -> DataLoader:
284 | """
285 |     Gets the desired dataloader from dlrm_main command line options. Depending on
286 |     `args.in_memory_binary_criteo_path`, this function returns a DataLoader wrapped around one of:
287 |     RandomRecDataset, InMemoryBinaryCriteoIterDataPipe, PetastormDataReader, AvazuIterDataPipe, SynthIterDataPipe, or CustomIterDataPipe.
288 |
289 | Args:
290 | args (argparse.Namespace): Command line options supplied to dlrm_main.py's main
291 | function.
292 | backend (str): "nccl" or "gloo".
293 | stage (str): "train", "val", or "test".
294 |
295 | Returns:
296 | dataloader (DataLoader): PyTorch dataloader for the specified options.
297 |
298 | """
299 | stage = stage.lower()
300 | if stage not in STAGES:
301 | raise ValueError(f"Supplied stage was {stage}. Must be one of {STAGES}.")
302 |
303 | args.pin_memory = ((backend == "nccl") if not hasattr(args, "pin_memory") else args.pin_memory)
304 |
305 | if (not hasattr(args, "in_memory_binary_criteo_path") or args.in_memory_binary_criteo_path is None):
306 | return _get_random_dataloader(args)
307 | elif "criteo" in args.in_memory_binary_criteo_path:
308 | if args.kaggle:
309 | return _get_in_memory_dataloader(args, stage)
310 | else:
311 | return _get_petastorm_dataloader(args, stage)
312 | elif "avazu" in args.in_memory_binary_criteo_path:
313 | return get_avazu_data_loader(args, stage)
314 | elif "embedding_bag" in args.in_memory_binary_criteo_path:
315 | # dlrm dataset: https://github.com/facebookresearch/dlrm_datasets
316 | return get_synth_data_loader(args, stage)
317 | elif "custom" in args.in_memory_binary_criteo_path:
318 | return get_custom_data_loader(args, stage)
319 |     else:
320 |         raise ValueError(f"Unrecognized dataset in path: {args.in_memory_binary_criteo_path}")
--------------------------------------------------------------------------------
/baselines/data/synth.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from torch import distributed as dist
4 | import numpy as np
5 | from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
6 | from colossalai.nn.parallel.layers.cache_embedding import CachedEmbeddingBag
7 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
8 | from fbgemm_gpu.split_table_batched_embeddings_ops import SplitTableBatchedEmbeddingBagsCodegen, EmbeddingLocation, ComputeDevice, CacheAlgorithm
9 | import time
10 | import numpy as np
11 | import torch
12 | from torch.utils.data import IterableDataset
13 | from torchrec.datasets.utils import PATH_MANAGER_KEY, Batch
14 | from torchrec.datasets.criteo import BinaryCriteoUtils
15 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
16 | from torch.utils.data import DataLoader, IterableDataset
17 | from torch.autograd.profiler import record_function
18 | import os
19 |
20 |
21 |
22 | # 52667139 as default
23 | CHOSEN_TABLES = [0, 2, 3, 4, 5, 7, 8, 9, 10, 12, 15, 18, 22, 27, 28]
24 | NUM_EMBEDDINGS_PER_FEATURE = '8015999, 9997799, 6138289, 21886, 204008, 6148, 282795, \
25 | 1316, 3639992, 319, 3394206, 12203324, 4091851, 11641, 4657566'
26 |
27 | CAT_FEATURE_COUNT = len(CHOSEN_TABLES)
28 | INT_FEATURE_COUNT = 1
29 | DEFAULT_LABEL_NAME = "click"
30 | DEFAULT_INT_NAMES = ['rand_dense']
31 | BATCH_SIZE = 65536 # batch_size of one file
32 | DEFAULT_CAT_NAMES = ["cat_{}".format(i) for i in range(len(CHOSEN_TABLES))]
33 |
34 | def choose_data_size(size: str):
35 | global CHOSEN_TABLES
36 | global NUM_EMBEDDINGS_PER_FEATURE
37 | global CAT_FEATURE_COUNT
38 | global DEFAULT_CAT_NAMES
39 | if size == '52M':
40 | pass
41 | elif size == '4M':
42 | # 4210897
43 | CHOSEN_TABLES = [5, 8, 37, 54, 71, 72, 73, 74, 85, 86, 89, 95, 96, 97, 107, 131, 163, 185, 196, 204, 211]
44 | NUM_EMBEDDINGS_PER_FEATURE = '204008, 282795, 539726, 153492, 11644, 11645, 13858, 5632, 60121, \
45 | 11711, 11645, 43335, 4843, 67919, 6539, 17076, 11579, 866124, 711855, 302001, 873349'
46 | elif size == '512M':
47 | # 512196316
48 | CHOSEN_TABLES = [301, 302, 303, 305, 306, 307, 309, 310, 311, 312, 313, 316, 317, 318, 319, 320, 321, 322, 323,
49 | 325, 326, 327, 328, 330, 335, 336, 337, 338, 340, 341, 343, 344, 345, 346, 347, 348, 349, 350,
50 | 351, 352, 353, 354, 356, 357, 358, 359, 360, 361, 362, 363, 365, 366, 367, 368, 370, 371, 372,
51 | 375, 378, 379, 381, 382, 383, 384, 385, 386, 388, 389, 390, 391, 392, 393, 394, 395, 396, 398,
52 | 399, 400, 401, 403, 405, 406, 407, 410, 413, 414, 415, 416, 417]
53 | NUM_EMBEDDINGS_PER_FEATURE = '5999929, 5999885, 5999976, 5999981, 5999901, 5999929, 5999885, 5987787, 6000000, 5999929, \
54 | 5998095, 3000000, 5999981, 2999993, 5999981, 5092210, 4999972, 5999976, 5998595, 5999548, 1999882, 4998224, 5999929, \
55 | 5014074, 5999986, 5999978, 5999941, 5999816, 5997022, 5999975, 5999685, 5999981, 5999738, 5999380, 5966699, 5975615, \
56 | 5908896, 5999996, 5999996, 5999983, 5734426, 5997022, 5999975, 5999929, 5999996, 5999239, 5989271, 5999477, 5999981, \
57 | 5999887, 5999929, 5999506, 5999996, 5999548, 5998472, 5922238, 5999975, 5987787, 2999964, 5999983, 5999930, 5979767, \
58 | 5999139, 5775261, 5999681, 4999929, 5963607, 5999967, 2999835, 5997068, 5998595, 5999996, 5992524, 5999997, 5999932, \
59 | 5999878, 5999929, 5999857, 5999981, 5999981, 5999796, 5999995, 5994671, 5999329, 5997068, 5999981, 5973566, 5999407, 5966699'
60 | elif size == '2G':
61 | # 2328942965
62 | CHOSEN_TABLES = [i for i in range(856)]
63 | NUM_EMBEDDINGS_PER_FEATURE = '8015999, 1, 9997799, 6138289, 21886, 204008, 220, 6148, 282795, 1316, 3639992, 1, 319, 1, 1, 3394206, 1, 1, 12203324, 1, 1, 1, 4091851, 1, 1, 1, 1, 11641, 4657566, 11645, 1815316, 6618925, 1, 1, 1146, 1, 1, 539726, 1, 1, 1, 4972387, 2169014, 1, 1, 3912105, 272, 3102763, 1, 1, 4230916, 5878, 1, 11645, 153492, 6618919, 1, 4868981, 1, 11709, 3985291, 1, 5956409, 1, 1, 1, 1, 1875019, 1, 381, 1, 11644, 11645, 13858, 5632, 1, 1, 6600102, 6618930, 1, 5412960, 371, 5272932, 2555079, 1, 60121, 11711, 2553205, 2647912, 11645, 1, 5798421, 350642, 1, 1, 43335, 4843, 67919, 1, 1, 3239310, 1, 1, 6855076, 1, 1, 1, 6539, 1, 1, 111, 5990989, 1, 6516585, 1, 68, 1, 5758479, 2448698, 1, 6618898, 2614073, 3309464, 1, 6107319, 1, 1, 1928793, 4535618, 1, 309, 17076, 4950876, 304795, 4970566, 11209763, 5585, 2207735, 6618921, 1941, 5659, 5690, 1029648, 5662, 4718, 6385214, 5641, 1150, 5653, 6618924, 1, 339750, 1, 6112009, 589094, 2844205, 1, 6618929, 1, 1, 5667, 5167062, 2542266, 11579, 6147171, 951851, 6448758, 5253, 826, 1, 1997119, 6363150, 6614703, 2199, 6461842, 913043, 1, 1, 1, 1, 1, 1283500, 1, 6316718, 11579, 866124, 3660331, 1, 4032709, 1, 1, 3232, 2065, 6584597, 1, 1, 711855, 5672538, 1, 248, 1, 1, 1, 1, 302001, 4006173, 1, 1, 19623, 1, 4673098, 873349, 8026000, 2323, 1680975, 1, 1, 5710807, 2999962, 5999910, 5925217, 4997507, 5999548, 2999938, 4999774, 5999707, 5999710, 5764956, 5999992, 1, 2999941, 5982534, 1, 5999927, 4978274, 5999983, 5999997, 5999912, 5908896, 5999955, 5999935, 5999836, 5999983, 1, 5999477, 5999805, 5998095, 1, 5989511, 1999998, 4999998, 6000000, 5999929, 1, 5999993, 1, 1, 5999885, 5999867, 5999929, 1, 1, 5999962, 1, 2999898, 5998777, 5999934, 1, 1, 5992524, 5999737, 2999538, 5999870, 5992524, 5999975, 1, 5710807, 1, 1, 4932124, 5918154, 1, 5997068, 5999982, 1, 5998551, 5999994, 5999870, 4999919, 5999944, 5999904, 4999740, 5922605, 5975615, 5999816, 998848, 5999926, 5999816, 4999991, 5999861, 1, 5999929, 5999885, 5999976, 1, 5999981, 5999901, 5999929, 1, 5999885, 5987787, 6000000, 5999929, 5998095, 1, 1, 3000000, 5999981, 2999993, 5999981, 5092210, 4999972, 5999976, 5998595, 1, 5999548, 1999882, 4998224, 5999929, 1, 5014074, 1, 1, 1, 1, 5999986, 5999978, 5999941, 5999816, 1, 5997022, 5999975, 1, 5999685, 5999981, 5999738, 5999380, 5966699, 5975615, 5908896, 5999996, 5999996, 5999983, 5734426, 5997022, 1, 5999975, 5999929, 5999996, 5999239, 5989271, 5999477, 5999981, 5999887, 1, 5999929, 5999506, 5999996, 5999548, 1, 5998472, 5922238, 5999975, 1, 1, 5987787, 1, 1, 2999964, 5999983, 1, 5999930, 5979767, 5999139, 5775261, 5999681, 4999929, 1, 5963607, 5999967, 2999835, 5997068, 5998595, 5999996, 5992524, 5999997, 5999932, 1, 5999878, 5999929, 5999857, 5999981, 1, 5999981, 1, 5999796, 5999995, 5994671, 1, 1, 5999329, 1, 1, 5997068, 5999981, 5973566, 5999407, 5966699, 5966699, 5734426, 5999975, 5999976, 1, 1, 1, 2982258, 5999816, 5999929, 5999981, 1, 5999974, 5998583, 5966699, 5999870, 5999828, 5997903, 5999854, 5999685, 1, 1, 1, 5999954, 5999981, 1, 1, 5999967, 5999681, 5999932, 1, 5999857, 5971899, 5999972, 5999932, 1, 5999979, 1, 5998507, 5999927, 5999981, 5998595, 5999975, 1, 5949097, 5999239, 5999821, 1, 1, 5922605, 1, 5998473, 5999878, 5992217, 5999600, 1, 1, 5965155, 5999932, 5999971, 5922605, 5999854, 1, 5999963, 1, 5998777, 5999816, 1, 5999594, 1, 5733183, 5999396, 1, 1, 1, 5999440, 1, 5871486, 5999975, 1, 5999911, 5963607, 5997079, 1, 5999772, 1, 5999857, 5999863, 5999477, 5966699, 5999975, 5999996, 5999976, 1, 5999981, 6, 135, 422, 509, 922, 
13, 16, 2, 1000, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1245, 5, 3136, 2704, 4999983, 4999983, 4999937, 4000, 1, 1, 1, 1, 1, 1, 1, 5, 6, 5, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 8, 2, 2, 1, 1, 2, 2, 8, 999908, 1443441, 21812, 21656, 14, 14, 1, 61, 2043, 2, 137, 13, 5, 10, 10, 10, 10, 9, 10, 12, 11, 10, 11, 11, 16, 10, 10, 11, 16, 1, 1, 1, 12543437, 1, 4999993, 4999986, 3999995, 4999863, 12543669, 1, 4999998, 12543650, 4999985, 12540349, 1, 1, 12532399, 4999998, 12540535, 397606, 1415772, 12270355, 9765324, 1, 1, 10287986, 1, 10794735, 10498728, 10965644, 4667, 43200, 43200, 43200, 43200, 629, 226, 215, 24, 4999983, 1028203, 20, 10, 1, 1, 1, 1, 12542, 2, 85, 2, 2, 2, 8, 2, 4, 2, 2, 2, 2, 2, 4, 6, 2, 2, 2, 8, 4, 9810, 2, 5884133, 1, 1, 2, 4, 2, 2, 1, 2, 4, 2, 18, 18, 2, 4, 6137684, 3309463, 1, 5607839, 1, 1, 1, 6567668, 1, 1, 4093617, 1, 4390473, 6305161, 1, 1, 1, 3779483, 1, 5303395, 1, 1, 6618931, 1, 1, 3640447, 3102628, 2542714, 269, 1, 1, 612, 6107445, 3978163, 6607315, 4868894, 1, 3983088, 5419139, 1, 271, 3911645, 2553341, 983482, 272, 1, 1600937, 1, 1266741, 1520037, 3704018, 1, 1, 3345638, 6618817, 3117219, 1, 1, 1877662, 1876652, 3309463, 4378405, 847629, 1661, 621, 624, 3667331, 1, 269, 2614200, 1, 1, 1, 6618759, 204, 1, 6618922, 1, 5998824, 5999974, 4977045, 5999994, 5999995, 5999993, 5999981, 5999957, 5999908, 5984869, 5999994, 1, 5999548, 1, 1, 5999831, 5999978, 5999396, 5999908, 5999953, 1, 1, 1, 5999750, 5999958, 5999477, 1, 5999981, 5999548, 5999953, 1, 5999548, 1, 1, 5999986, 5999975, 1, 5999908, 1, 1, 5999975, 1, 5999548, 5998836, 5999477, 5999737, 5999708, 5999737, 5999783, 1, 1, 5999901, 5999708, 5999711, 5999967, 5999548, 5999548, 1, 1, 5999783, 1, 5999694, 5999520, 5999975, 1, 1, 1, 5999477, 1, 5999975, 5991327, 5975615, 1, 4494193, 5999918, 1, 5999725, 1, 5999995, 1, 1, 5999477, 4998981, 5999975, 5999329, 5976924, 1, 2, 4999540, 5999138, 3994232, 1'
64 | else:
65 | raise NotImplementedError()
66 | CAT_FEATURE_COUNT = len(CHOSEN_TABLES)
67 | DEFAULT_CAT_NAMES = ["cat_{}".format(i) for i in range(len(CHOSEN_TABLES))]
68 |
69 | class SynthIterDataPipe(IterableDataset):
70 | def __init__(
71 | self,
72 | sparse_paths: List[str],
73 | batch_size: int,
74 | rank: int,
75 | world_size: int,
76 | shuffle_batches: bool = False,
77 | mmap_mode: bool = False,
78 | hashes: Optional[List[int]] = None,
79 | path_manager_key: str = PATH_MANAGER_KEY,
80 | ) -> None:
81 |         self.sparse_paths = sparse_paths
82 | self.batch_size = batch_size
83 | self.rank = rank
84 | self.world_size = world_size
85 | self.shuffle_batches = shuffle_batches
86 | self.mmap_mode = mmap_mode
87 | self.hashes = hashes
88 | self.path_manager_key = path_manager_key
89 |
90 | self.indices_per_table_per_file = []
91 | self.offsets_per_table_per_file = []
92 | self.lengths_per_table_per_file = []
93 | self.num_rows_per_file = []
94 |
95 | self._buffer = None
96 |         for file in self.sparse_paths:
97 | print("load file: ", file)
98 | indices_per_table, offsets_per_table, lengths_per_table = self._load_single_file(file)
99 | self.indices_per_table_per_file.append(indices_per_table)
100 | self.offsets_per_table_per_file.append(offsets_per_table)
101 | self.lengths_per_table_per_file.append(lengths_per_table)
102 | self.num_rows_per_file.append(offsets_per_table[0].shape[0])
103 | self.num_batches = sum(self.num_rows_per_file) // self.batch_size
104 | self.keys: List[str] = DEFAULT_CAT_NAMES
105 | self.stride = batch_size
106 |
107 | def __iter__(self) -> Iterator[Batch]:
108 |         '''
109 |         self._buffer structure:
110 |             self._buffer[0]: List of sparse_indices per table
111 |             self._buffer[1]: List of sparse_lengths per table
112 |         '''
113 | def append_to_buffer(sparse_indices: List[torch.Tensor], sparse_lengths: List[torch.Tensor]):
114 | if self._buffer is None:
115 | self._buffer = [sparse_indices, sparse_lengths]
116 | else:
117 | for tb_idx, (sparse_indices_table, sparse_lengths_table) in enumerate(zip(sparse_indices, sparse_lengths)):
118 | self._buffer[0][tb_idx] = torch.cat((self._buffer[0][tb_idx], sparse_indices_table))
119 | self._buffer[1][tb_idx] = torch.cat((self._buffer[1][tb_idx], sparse_lengths_table))
120 |
121 | file_idx = 0
122 | row_idx = 0
123 | batch_idx = 0
124 | while batch_idx < self.num_batches:
125 | buffer_row_count = 0 if self._buffer is None else self._buffer[1][0].shape[0]
126 | if buffer_row_count == self.batch_size:
127 | yield self._make_batch(*self._buffer)
128 | batch_idx += 1
129 | self._buffer = None
130 | else:
131 | rows_to_get = min(
132 | self.batch_size - buffer_row_count,
133 | BATCH_SIZE - row_idx
134 | )
135 | # slice_ = slice(row_idx, row_idx + rows_to_get)
136 | with record_function("## load_batch ##"):
137 | sparse_indices, sparse_lengths = self._load_slice_batch(self.indices_per_table_per_file[file_idx],
138 | self.offsets_per_table_per_file[file_idx],
139 | self.lengths_per_table_per_file[file_idx],
140 | row_idx,
141 | rows_to_get
142 | )
143 | append_to_buffer(sparse_indices, sparse_lengths)
144 | row_idx += rows_to_get
145 | if row_idx >= BATCH_SIZE:
146 | file_idx += 1
147 | row_idx = 0
148 |
149 | def __len__(self) -> int:
150 | return self.num_batches
151 |
152 | def _load_single_file(self, file_path):
153 | # TODO: shard loading
154 | # file_idx_to_row_range = BinaryCriteoUtils.get_file_idx_to_row_range(
155 | # lengths=[BATCH_SIZE],
156 | # rank=self.rank,
157 | # world_size=self.world_size
158 | # )
159 | # for _,(range_left, range_right) in file_idx_to_row_range.items():
160 | # rank_range_left = range_left
161 | # rank_range_right = range_right
162 | rank_range_left = 0
163 | rank_range_right = 65535
164 | indices, offsets, lengths = torch.load(file_path)
165 | indices = indices.int()
166 | offsets = offsets.int()
167 | lengths = lengths.int()
168 | indices_per_table = []
169 | offsets_per_table = []
170 | lengths_per_table = []
171 | for i in CHOSEN_TABLES:
172 | start_pos = offsets[i * BATCH_SIZE + rank_range_left]
173 | end_pos = offsets[i * BATCH_SIZE + rank_range_right + 1]
174 | part = indices[start_pos:end_pos]
175 | indices_per_table.append(part)
176 | lengths_per_table.append(lengths[i][rank_range_left:rank_range_right + 1])
177 | offsets_per_table.append(torch.cumsum(
178 | torch.cat((torch.tensor([0]), lengths[i][rank_range_left:rank_range_right + 1])), 0
179 | ))
180 | return indices_per_table, offsets_per_table, lengths_per_table
181 |
182 | def _load_random_batch(self, indices_per_table, offsets_per_table, lengths_per_table, choose):
183 | chosen_indices_list = []
184 | chosen_lengths_list = []
185 | # choose = torch.randint(0, offsets_per_table[0].shape[0] - 1, (self.batch_size,))
186 | for indices, offsets, lengths in zip(indices_per_table, offsets_per_table, lengths_per_table):
187 | chosen_lengths_list.append(lengths[choose])
188 | start_list = offsets[choose]
189 | end_list = offsets[1:][choose]
190 | chosen_indices_atoms = []
191 | for start, end in zip(start_list, end_list):
192 | chosen_indices_atoms.append(indices[start: end])
193 | chosen_indices_list.append(torch.cat(chosen_indices_atoms, 0))
194 | return chosen_indices_list, chosen_lengths_list
195 |
196 | def _load_slice_batch(self, indices_per_table, offsets_per_table, lengths_per_table, row_start, row_length):
197 | chosen_indices_list = []
198 | chosen_lengths_list = []
199 | for indices, offsets, lengths in zip(indices_per_table, offsets_per_table, lengths_per_table):
200 | chosen_lengths_list.append(lengths.narrow(0, row_start, row_length))
201 | start = offsets[row_start]
202 | end = offsets[row_start + row_length]
203 | chosen_indices_list.append(indices.narrow(0, start, end - start))
204 | return chosen_indices_list, chosen_lengths_list
205 |
206 | def _make_batch(self, chosen_indices_list, chosen_lengths_list):
207 | batch_size = chosen_lengths_list[0].shape[0]
208 | ret = Batch(
209 | dense_features=torch.rand(batch_size,1),
210 | sparse_features=KeyedJaggedTensor(
211 | keys=self.keys,
212 | values=torch.cat(chosen_indices_list),
213 | lengths=torch.cat(chosen_lengths_list),
214 | stride=batch_size,
215 | ),
216 | labels=torch.randint(2, (batch_size,))
217 | )
218 | return ret
219 | def get_synth_data_loader(
220 | args: argparse.Namespace,
221 | stage: str) -> DataLoader:
222 | files = os.listdir(args.in_memory_binary_criteo_path)
223 | files = filter(lambda s: "fbgemm_t856_bs65536_" in s, files)
224 | files = [os.path.join(args.in_memory_binary_criteo_path, x) for x in files]
225 | rank = dist.get_rank()
226 | world_size = dist.get_world_size()
227 | if stage == "train":
228 | dataloader = DataLoader(
229 | SynthIterDataPipe(
230 | files,
231 | batch_size=args.batch_size,
232 | rank=rank,
233 | world_size=world_size,
234 | shuffle_batches=args.shuffle_batches,
235 |                 hashes=None
236 | ),
237 | batch_size=None,
238 | pin_memory=args.pin_memory,
239 | collate_fn=lambda x: x,
240 | )
241 |     else:
242 | dataloader = []
243 | return dataloader
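# Layout note: each fbgemm_t856_bs65536 file consumed by _load_single_file
# stores one flat (indices, offsets, lengths) triple covering 856 tables x
# 65536 rows, so table i owns offsets[i * 65536 : (i + 1) * 65536 + 1]. A
# miniature example of the same layout (2 tables x 3 rows):
#
#   lengths = torch.tensor([[1, 2, 1],     # ids per row, table 0
#                           [2, 1, 1]])    # ids per row, table 1
#   offsets = torch.cat((torch.tensor([0]), lengths.reshape(-1).cumsum(0)))
#   indices = torch.arange(int(offsets[-1]))
#   i, BATCH = 1, 3                        # slice out table 1
#   indices[offsets[i * BATCH]:offsets[(i + 1) * BATCH]], lengths[i]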
--------------------------------------------------------------------------------
/baselines/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .dlrm import DLRMTrain
2 |
3 | __all__ = ['DLRMTrain']
4 |
--------------------------------------------------------------------------------
/baselines/models/deepfm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | from typing import List
9 |
10 | import torch
11 | from torch import nn
12 | from torchrec import EmbeddingBagCollection, KeyedJaggedTensor
13 | from torchrec.modules.deepfm import DeepFM, FactorizationMachine
14 | from torchrec.sparse.jagged_tensor import KeyedTensor
15 |
16 |
17 | class SparseArch(nn.Module):
18 | """
19 | Processes the sparse features of the DeepFMNN model. Does embedding lookups for all
20 | EmbeddingBag and embedding features of each collection.
21 | Args:
22 | embedding_bag_collection (EmbeddingBagCollection): represents a collection of
23 | pooled embeddings.
24 | Example::
25 | eb1_config = EmbeddingBagConfig(
26 | name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
27 | )
28 | eb2_config = EmbeddingBagConfig(
29 | name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
30 | )
31 | ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
32 | # 0 1 2 <-- batch
33 | # 0 [0,1] None [2]
34 | # 1 [3] [4] [5,6,7]
35 | # ^
36 | # feature
37 | features = KeyedJaggedTensor.from_offsets_sync(
38 | keys=["f1", "f2"],
39 | values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
40 | offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
41 | )
42 | sparse_arch(features)
43 | """
44 |
45 | def __init__(self, embedding_bag_collection: EmbeddingBagCollection) -> None:
46 | super().__init__()
47 | self.embedding_bag_collection: EmbeddingBagCollection = embedding_bag_collection
48 |
49 | def forward(
50 | self,
51 | features: KeyedJaggedTensor,
52 | ) -> KeyedTensor:
53 | """
54 | Args:
55 | features (KeyedJaggedTensor):
56 | Returns:
57 |             KeyedTensor: an output KT of size F * D X B.
58 | """
59 | return self.embedding_bag_collection(features)
60 |
61 |
62 | class DenseArch(nn.Module):
63 | """
64 | Processes the dense features of DeepFMNN model. Output layer is sized to
65 | the embedding_dimension of the EmbeddingBagCollection embeddings.
66 | Args:
67 | in_features (int): dimensionality of the dense input features.
68 | hidden_layer_size (int): sizes of the hidden layers in the DenseArch.
69 | embedding_dim (int): the same size of the embedding_dimension of sparseArch.
70 | device (torch.device): default compute device.
71 | Example::
72 | B = 20
73 | D = 3
74 | in_features = 10
75 | dense_arch = DenseArch(in_features=10, hidden_layer_size=10, embedding_dim=D)
76 | dense_embedded = dense_arch(torch.rand((B, 10)))
77 | """
78 |
79 | def __init__(
80 | self,
81 | in_features: int,
82 | hidden_layer_size: int,
83 | embedding_dim: int,
84 | ) -> None:
85 | super().__init__()
86 | self.model: nn.Module = nn.Sequential(
87 | nn.Linear(in_features, hidden_layer_size),
88 | nn.ReLU(),
89 | nn.Linear(hidden_layer_size, embedding_dim),
90 | nn.ReLU(),
91 | )
92 |
93 | def forward(self, features: torch.Tensor) -> torch.Tensor:
94 | """
95 | Args:
96 | features (torch.Tensor): size B X `num_features`.
97 | Returns:
98 | torch.Tensor: an output tensor of size B X D.
99 | """
100 | return self.model(features)
101 |
102 |
103 | class FMInteractionArch(nn.Module):
104 | """
105 | Processes the output of both `SparseArch` (sparse_features) and `DenseArch`
106 | (dense_features) and apply the general DeepFM interaction according to the
107 | external source of DeepFM paper: https://arxiv.org/pdf/1703.04247.pdf
108 | The output dimension is expected to be a cat of `dense_features`, D.
109 | Args:
110 | fm_in_features (int): the input dimension of `dense_module` in DeepFM. For
111 |             example, if the input embeddings are [randn(3, 2, 3), randn(3, 4, 5)], then
112 | the `fm_in_features` should be: 2 * 3 + 4 * 5.
113 | sparse_feature_names (List[str]): length of F.
114 | deep_fm_dimension (int): output of the deep interaction (DI) in the DeepFM arch.
115 | Example::
116 | D = 3
117 | B = 10
118 | keys = ["f1", "f2"]
119 | F = len(keys)
120 |         fm_inter_arch = FMInteractionArch(fm_in_features=D + D * F, sparse_feature_names=keys, deep_fm_dimension=15)
121 | dense_features = torch.rand((B, D))
122 | sparse_features = KeyedTensor(
123 | keys=keys,
124 | length_per_key=[D, D],
125 | values=torch.rand((B, D * F)),
126 | )
127 | cat_fm_output = fm_inter_arch(dense_features, sparse_features)
128 | """
129 |
130 | def __init__(
131 | self,
132 | fm_in_features: int,
133 | sparse_feature_names: List[str],
134 | deep_fm_dimension: int,
135 | ) -> None:
136 | super().__init__()
137 | self.sparse_feature_names: List[str] = sparse_feature_names
138 | self.deep_fm = DeepFM(
139 | dense_module=nn.Sequential(
140 | nn.Linear(fm_in_features, deep_fm_dimension),
141 | nn.ReLU(),
142 | )
143 | )
144 | self.fm = FactorizationMachine()
145 |
146 | def forward(
147 | self, dense_features: torch.Tensor, sparse_features: KeyedTensor
148 | ) -> torch.Tensor:
149 | """
150 | Args:
151 | dense_features (torch.Tensor): tensor of size B X D.
152 |             sparse_features (KeyedTensor): KT holding a B X D pooled embedding per feature.
153 | Returns:
154 | torch.Tensor: an output tensor of size B X (D + DI + 1).
155 | """
156 | if len(self.sparse_feature_names) == 0:
157 | return dense_features
158 |
159 | tensor_list: List[torch.Tensor] = [dense_features]
160 | # dense/sparse interaction
161 |         # each appended pooled embedding has size B X D
162 | for feature_name in self.sparse_feature_names:
163 | tensor_list.append(sparse_features[feature_name])
164 |
165 | deep_interaction = self.deep_fm(tensor_list)
166 | fm_interaction = self.fm(tensor_list)
167 |
168 | return torch.cat([dense_features, deep_interaction, fm_interaction], dim=1)
169 |
170 |
171 | class OverArch(nn.Module):
172 | """
173 | Final Arch - simple MLP. The output is just one target.
174 | Args:
175 | in_features (int): the output dimension of the interaction arch.
176 | Example::
177 | B = 20
178 |         over_arch = OverArch(in_features=10)
179 | logits = over_arch(torch.rand((B, 10)))
180 | """
181 |
182 | def __init__(self, in_features: int) -> None:
183 | super().__init__()
184 | self.model: nn.Module = nn.Sequential(
185 | nn.Linear(in_features, 1),
186 | nn.Sigmoid(),
187 | )
188 |
189 | def forward(self, features: torch.Tensor) -> torch.Tensor:
190 | """
191 | Args:
192 |             features (torch.Tensor): the interaction arch output, of size B X (D + DI + 1).
193 | Returns:
194 | torch.Tensor: an output tensor of size B X 1.
195 | """
196 | return self.model(features)
197 |
198 |
199 | class SimpleDeepFMNN(nn.Module):
200 | """
201 | Basic recsys module with DeepFM arch. Processes sparse features by
202 | learning pooled embeddings for each feature. Learns the relationship between
203 | dense features and sparse features by projecting dense features into the same
204 | embedding space. Learns the interaction among those dense and sparse features
205 | by deep_fm proposed in this paper: https://arxiv.org/pdf/1703.04247.pdf
206 | The module assumes all sparse features have the same embedding dimension
207 |     (i.e., each `EmbeddingBagConfig` uses the same embedding_dim).
208 | The following notation is used throughout the documentation for the models:
209 | * F: number of sparse features
210 | * D: embedding_dimension of sparse features
211 | * B: batch size
212 | * num_features: number of dense features
213 | Args:
214 | num_dense_features (int): the number of input dense features.
215 | embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags
216 | used to define `SparseArch`.
217 | hidden_layer_size (int): the hidden layer size used in dense module.
218 | deep_fm_dimension (int): the output layer size used in `deep_fm`'s deep
219 | interaction module.
220 | Example::
221 | B = 2
222 | D = 8
223 | eb1_config = EmbeddingBagConfig(
224 | name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"]
225 | )
226 | eb2_config = EmbeddingBagConfig(
227 | name="t2",
228 | embedding_dim=D,
229 | num_embeddings=100,
230 | feature_names=["f2"],
231 | )
232 | ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
233 | sparse_nn = SimpleDeepFMNN(
234 |             num_dense_features=100, embedding_bag_collection=ebc, hidden_layer_size=20, deep_fm_dimension=5
235 | )
236 | features = torch.rand((B, 100))
237 | # 0 1
238 | # 0 [1,2] [4,5]
239 | # 1 [4,3] [2,9]
240 | # ^
241 | # feature
242 | sparse_features = KeyedJaggedTensor.from_offsets_sync(
243 | keys=["f1", "f3"],
244 | values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]),
245 | offsets=torch.tensor([0, 2, 4, 6, 8]),
246 | )
247 | logits = sparse_nn(
248 | dense_features=features,
249 | sparse_features=sparse_features,
250 | )
251 | """
252 |
253 | def __init__(
254 | self,
255 | num_dense_features: int,
256 | embedding_bag_collection: EmbeddingBagCollection,
257 | hidden_layer_size: int,
258 | deep_fm_dimension: int,
259 | ) -> None:
260 | super().__init__()
261 | assert (
262 | len(embedding_bag_collection.embedding_bag_configs()) > 0
263 | ), "At least one embedding bag is required"
264 | for i in range(1, len(embedding_bag_collection.embedding_bag_configs())):
265 | conf_prev = embedding_bag_collection.embedding_bag_configs()[i - 1]
266 | conf = embedding_bag_collection.embedding_bag_configs()[i]
267 | assert (
268 | conf_prev.embedding_dim == conf.embedding_dim
269 | ), "All EmbeddingBagConfigs must have the same dimension"
270 | embedding_dim: int = embedding_bag_collection.embedding_bag_configs()[
271 | 0
272 | ].embedding_dim
273 |
274 | feature_names = []
275 |
276 | fm_in_features = embedding_dim
277 | for conf in embedding_bag_collection.embedding_bag_configs():
278 | for feat in conf.feature_names:
279 | feature_names.append(feat)
280 |                 fm_in_features += conf.embedding_dim
281 |
282 | self.sparse_arch = SparseArch(embedding_bag_collection)
283 | self.dense_arch = DenseArch(
284 | in_features=num_dense_features,
285 | hidden_layer_size=hidden_layer_size,
286 | embedding_dim=embedding_dim,
287 | )
288 | self.inter_arch = FMInteractionArch(
289 | fm_in_features=fm_in_features,
290 | sparse_feature_names=feature_names,
291 | deep_fm_dimension=deep_fm_dimension,
292 | )
293 | over_in_features = embedding_dim + deep_fm_dimension + 1
294 | self.over_arch = OverArch(over_in_features)
295 |
296 | def forward(
297 | self,
298 | dense_features: torch.Tensor,
299 | sparse_features: KeyedJaggedTensor,
300 | ) -> torch.Tensor:
301 | """
302 | Args:
303 | dense_features (torch.Tensor): the dense features.
304 | sparse_features (KeyedJaggedTensor): the sparse features.
305 | Returns:
306 | torch.Tensor: logits with size B X 1.
307 | """
308 | embedded_dense = self.dense_arch(dense_features)
309 | embedded_sparse = self.sparse_arch(sparse_features)
310 | concatenated_dense = self.inter_arch(
311 | dense_features=embedded_dense, sparse_features=embedded_sparse
312 | )
313 | logits = self.over_arch(concatenated_dense)
314 | return logits
--------------------------------------------------------------------------------
/baselines/models/dlrm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | from typing import Dict, List, Optional, Tuple
9 |
10 | import torch
11 | from torch import nn
12 | from torchrec.modules.embedding_modules import EmbeddingBagCollection
13 | from torchrec.modules.mlp import MLP
14 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
15 | from torchrec.datasets.utils import Batch
16 |
17 |
18 | def choose(n: int, k: int) -> int:
19 | """
20 | Simple implementation of math.comb for Python 3.7 compatibility.
21 | """
22 | if 0 <= k <= n:
23 | ntok = 1
24 | ktok = 1
25 | for t in range(1, min(k, n - k) + 1):
26 | ntok *= n
27 | ktok *= t
28 | n -= 1
29 | return ntok // ktok
30 | else:
31 | return 0
32 |
33 |
34 | class SparseArch(nn.Module):
35 | """
36 |     Processes the sparse features of DLRM. Performs the embedding lookup for each
37 |     feature of every EmbeddingBag in the collection.
38 |
39 | Args:
40 | embedding_bag_collection (EmbeddingBagCollection): represents a collection of
41 | pooled embeddings.
42 |
43 | Example::
44 |
45 | eb1_config = EmbeddingBagConfig(
46 | name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"]
47 | )
48 | eb2_config = EmbeddingBagConfig(
49 | name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
50 | )
51 |
52 | ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
53 |         sparse_arch = SparseArch(ebc)
54 |
55 | # 0 1 2 <-- batch
56 | # 0 [0,1] None [2]
57 | # 1 [3] [4] [5,6,7]
58 | # ^
59 | # feature
60 | features = KeyedJaggedTensor.from_offsets_sync(
61 | keys=["f1", "f2"],
62 | values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
63 | offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]),
64 | )
65 |
66 | sparse_embeddings = sparse_arch(features)
67 | """
68 |
69 | def __init__(self, embedding_bag_collection: EmbeddingBagCollection) -> None:
70 | super().__init__()
71 | self.embedding_bag_collection: EmbeddingBagCollection = embedding_bag_collection
72 | emb_config = self.embedding_bag_collection.embedding_bag_configs()
73 | assert emb_config, "Embedding bag collection cannot be empty!"
74 | self.D: int = emb_config[0].embedding_dim
75 | self._sparse_feature_names: List[str] = [
76 | name for conf in embedding_bag_collection.embedding_bag_configs() for name in conf.feature_names
77 | ]
78 |
79 | self.F: int = len(self._sparse_feature_names)
80 |
81 | def forward(
82 | self,
83 | features: KeyedJaggedTensor,
84 | ) -> torch.Tensor:
85 | """
86 | Args:
87 | features (KeyedJaggedTensor): an input tensor of sparse features.
88 |
89 | Returns:
90 | torch.Tensor: tensor of shape B X F X D.
91 | """
92 | sparse_features: KeyedTensor = self.embedding_bag_collection(features)
93 | B: int = features.stride()
94 |
95 | sparse: Dict[str, torch.Tensor] = sparse_features.to_dict()
96 | sparse_values: List[torch.Tensor] = []
97 | for name in self.sparse_feature_names:
98 | sparse_values.append(sparse[name])
99 | return torch.cat(sparse_values, dim=1).view(B, self.F, self.D)
100 |
101 | @property
102 | def sparse_feature_names(self) -> List[str]:
103 | return self._sparse_feature_names
104 |
105 |
106 | class DenseArch(nn.Module):
107 | """
108 | Processes the dense features of DLRM model.
109 |
110 | Args:
111 | in_features (int): dimensionality of the dense input features.
112 | layer_sizes (List[int]): list of layer sizes.
113 | device (Optional[torch.device]): default compute device.
114 |
115 | Example::
116 |
117 | B = 20
118 | D = 3
119 | dense_arch = DenseArch(10, layer_sizes=[15, D])
120 | dense_embedded = dense_arch(torch.rand((B, 10)))
121 | """
122 |
123 | def __init__(
124 | self,
125 | in_features: int,
126 | layer_sizes: List[int],
127 | device: Optional[torch.device] = None,
128 | ) -> None:
129 | super().__init__()
130 | self.model: nn.Module = MLP(in_features, layer_sizes, bias=True, activation="relu", device=device)
131 |
132 | def forward(self, features: torch.Tensor) -> torch.Tensor:
133 | """
134 | Args:
135 | features (torch.Tensor): an input tensor of dense features.
136 |
137 | Returns:
138 | torch.Tensor: an output tensor of size B X D.
139 | """
140 | return self.model(features)
141 |
142 |
143 | class InteractionArch(nn.Module):
144 | """
145 | Processes the output of both `SparseArch` (sparse_features) and `DenseArch`
146 | (dense_features). Returns the pairwise dot product of each sparse feature pair,
147 | the dot product of each sparse features with the output of the dense layer,
148 | and the dense layer itself (all concatenated).
149 |
150 | .. note::
151 | The dimensionality of the `dense_features` (D) is expected to match the
152 | dimensionality of the `sparse_features` so that the dot products between them
153 | can be computed.
154 |
155 |
156 | Args:
157 | num_sparse_features (int): F.
158 |
159 | Example::
160 |
161 | D = 3
162 | B = 10
163 | keys = ["f1", "f2"]
164 | F = len(keys)
165 | inter_arch = InteractionArch(num_sparse_features=len(keys))
166 |
167 | dense_features = torch.rand((B, D))
168 | sparse_features = torch.rand((B, F, D))
169 |
170 | # B X (D + F + F choose 2)
171 | concat_dense = inter_arch(dense_features, sparse_features)
172 | """
173 |
174 | def __init__(self, num_sparse_features: int, num_dense_features: int = 1) -> None:
175 | super().__init__()
176 | self.F: int = num_sparse_features
177 | self.num_dense_features = num_dense_features
178 | self.register_buffer(
179 | 'triu_indices',
180 | torch.triu_indices(self.F + self.num_dense_features, self.F + self.num_dense_features,
181 | offset=1).requires_grad_(False), False)
182 |
183 | def forward(self, dense_features: torch.Tensor, sparse_features: torch.Tensor) -> torch.Tensor:
184 | """
185 | Args:
186 | dense_features (torch.Tensor): an input tensor of size B X D.
187 | sparse_features (torch.Tensor): an input tensor of size B X F X D.
188 |
189 | Returns:
190 | torch.Tensor: an output tensor of size B X (D + F + F choose 2).
191 | """
192 | if self.F <= 0:
193 | return dense_features
194 |
195 | if self.num_dense_features <= 0:
196 | combined_values = sparse_features
197 | else:
198 | # b f+1 d
199 | combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features), dim=1)
200 |
201 | # dense/sparse + sparse/sparse interaction
202 | # size B X (F + F choose 2)
203 | interactions = torch.bmm(combined_values, torch.transpose(combined_values, 1, 2))
204 | interactions_flat = interactions[:, self.triu_indices[0], self.triu_indices[1]]
205 |
206 | return torch.cat((dense_features, interactions_flat), dim=1)
207 |
208 |
209 | class OverArch(nn.Module):
210 | """
211 |     Final Arch of DLRM - a simple MLP that produces the prediction logits.
212 |
213 | Args:
214 | in_features (int): size of the input.
215 | layer_sizes (List[int]): sizes of the layers of the `OverArch`.
216 | device (Optional[torch.device]): default compute device.
217 |
218 | Example::
219 |
220 | B = 20
221 | D = 3
222 | over_arch = OverArch(10, [5, 1])
223 | logits = over_arch(torch.rand((B, 10)))
224 | """
225 |
226 | def __init__(
227 | self,
228 | in_features: int,
229 | layer_sizes: List[int],
230 | device: Optional[torch.device] = None,
231 | ) -> None:
232 | super().__init__()
233 | if len(layer_sizes) <= 1:
234 | raise ValueError("OverArch must have multiple layers.")
235 | self.model: nn.Module = nn.Sequential(
236 | MLP(
237 | in_features,
238 | layer_sizes[:-1],
239 | bias=True,
240 | activation="relu",
241 | device=device,
242 | ),
243 | nn.Linear(layer_sizes[-2], layer_sizes[-1], bias=True, device=device),
244 | )
245 |
246 | def forward(self, features: torch.Tensor) -> torch.Tensor:
247 | """
248 | Args:
249 |             features (torch.Tensor): the output of the interaction arch.
250 |
251 | Returns:
252 | torch.Tensor: size B X layer_sizes[-1]
253 | """
254 | return self.model(features)
255 |
256 |
257 | class DLRM(nn.Module):
258 | """
259 | Recsys model from "Deep Learning Recommendation Model for Personalization and
260 | Recommendation Systems" (https://arxiv.org/abs/1906.00091). Processes sparse
261 | features by learning pooled embeddings for each feature. Learns the relationship
262 | between dense features and sparse features by projecting dense features into the
263 | same embedding space. Also, learns the pairwise relationships between sparse
264 | features.
265 |
266 | The module assumes all sparse features have the same embedding dimension
267 | (i.e. each EmbeddingBagConfig uses the same embedding_dim).
268 |
269 | The following notation is used throughout the documentation for the models:
270 |
271 | * F: number of sparse features
272 | * D: embedding_dimension of sparse features
273 | * B: batch size
274 | * num_features: number of dense features
275 |
276 | Args:
277 | embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags
278 | used to define `SparseArch`.
279 | dense_in_features (int): the dimensionality of the dense input features.
280 | dense_arch_layer_sizes (List[int]): the layer sizes for the `DenseArch`.
281 | over_arch_layer_sizes (List[int]): the layer sizes for the `OverArch`.
282 | The output dimension of the `InteractionArch` should not be manually
283 | specified here.
284 | dense_device (Optional[torch.device]): default compute device.
285 |
286 | Example::
287 |
288 | B = 2
289 | D = 8
290 |
291 | eb1_config = EmbeddingBagConfig(
292 | name="t1", embedding_dim=D, num_embeddings=100, feature_names=["f1", "f3"]
293 | )
294 | eb2_config = EmbeddingBagConfig(
295 | name="t2",
296 | embedding_dim=D,
297 | num_embeddings=100,
298 | feature_names=["f2"],
299 | )
300 |
301 | ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
302 | model = DLRM(
303 | embedding_bag_collection=ebc,
304 | dense_in_features=100,
305 | dense_arch_layer_sizes=[20],
306 | over_arch_layer_sizes=[5, 1],
307 | )
308 |
309 | features = torch.rand((B, 100))
310 |
311 | # 0 1
312 | # 0 [1,2] [4,5]
313 | # 1 [4,3] [2,9]
314 | # ^
315 | # feature
316 | sparse_features = KeyedJaggedTensor.from_offsets_sync(
317 | keys=["f1", "f3"],
318 | values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]),
319 | offsets=torch.tensor([0, 2, 4, 6, 8]),
320 | )
321 |
322 | logits = model(
323 | dense_features=features,
324 | sparse_features=sparse_features,
325 | )
326 | """
327 |
328 | def __init__(
329 | self,
330 | embedding_bag_collection: EmbeddingBagCollection,
331 | dense_in_features: int,
332 | dense_arch_layer_sizes: List[int],
333 | over_arch_layer_sizes: List[int],
334 | dense_device: Optional[torch.device] = None,
335 | ) -> None:
336 | super().__init__()
337 |
338 | # For torchrec version 0.1.x
339 | # emb_configs = embedding_bag_collection.embedding_bag_configs
340 | # this is for torchrec version 0.2.0
341 | emb_configs = embedding_bag_collection.embedding_bag_configs()
342 | assert (len(emb_configs) > 0), "At least one embedding bag is required"
343 | for i in range(1, len(emb_configs)):
344 | conf_prev = emb_configs[i - 1]
345 | conf = emb_configs[i]
346 | assert (
347 | conf_prev.embedding_dim == conf.embedding_dim), "All EmbeddingBagConfigs must have the same dimension"
348 | embedding_dim: int = emb_configs[0].embedding_dim
349 | if dense_arch_layer_sizes[-1] != embedding_dim:
350 |             raise ValueError(f"embedding_bag_collection dimension ({embedding_dim}) and final dense "
351 |                              f"arch layer size ({dense_arch_layer_sizes[-1]}) must match.")
352 |
353 | self.sparse_arch: SparseArch = SparseArch(embedding_bag_collection)
354 | num_sparse_features: int = len(self.sparse_arch.sparse_feature_names)
355 |
356 | self.dense_arch = DenseArch(
357 | in_features=dense_in_features,
358 | layer_sizes=dense_arch_layer_sizes,
359 | device=dense_device,
360 | )
361 | self.inter_arch = InteractionArch(num_sparse_features=num_sparse_features)
362 |
363 | over_in_features: int = (embedding_dim + choose(num_sparse_features, 2) + num_sparse_features)
364 | self.over_arch = OverArch(
365 | in_features=over_in_features,
366 | layer_sizes=over_arch_layer_sizes,
367 | device=dense_device,
368 | )
369 |
370 | def forward(
371 | self,
372 | dense_features: torch.Tensor,
373 | sparse_features: KeyedJaggedTensor,
374 | ) -> torch.Tensor:
375 | """
376 | Args:
377 | dense_features (torch.Tensor): the dense features.
378 | sparse_features (KeyedJaggedTensor): the sparse features.
379 |
380 | Returns:
381 | torch.Tensor: logits.
382 | """
383 | embedded_dense = self.dense_arch(dense_features)
384 | embedded_sparse = self.sparse_arch(sparse_features)
385 | concatenated_dense = self.inter_arch(dense_features=embedded_dense, sparse_features=embedded_sparse)
386 | logits = self.over_arch(concatenated_dense)
387 | return logits
388 |
389 |
390 | class DLRMTrain(nn.Module):
391 | """
392 | nn.Module to wrap DLRM model to use with train_pipeline.
393 |
394 | DLRM Recsys model from "Deep Learning Recommendation Model for Personalization and
395 | Recommendation Systems" (https://arxiv.org/abs/1906.00091). Processes sparse
396 | features by learning pooled embeddings for each feature. Learns the relationship
397 | between dense features and sparse features by projecting dense features into the
398 | same embedding space. Also, learns the pairwise relationships between sparse
399 | features.
400 |
401 | The module assumes all sparse features have the same embedding dimension
402 |     (i.e., each EmbeddingBagConfig uses the same embedding_dim).
403 |
404 | Args:
405 | embedding_bag_collection (EmbeddingBagCollection): collection of embedding bags
406 | used to define SparseArch.
407 | dense_in_features (int): the dimensionality of the dense input features.
408 | dense_arch_layer_sizes (list[int]): the layer sizes for the DenseArch.
409 | over_arch_layer_sizes (list[int]): the layer sizes for the OverArch. NOTE: The
410 | output dimension of the InteractionArch should not be manually specified
411 | here.
412 | dense_device: (Optional[torch.device]).
413 |
414 | Call Args:
415 | batch: batch used with criteo and random data from torchrec.datasets
416 |
417 | Returns:
418 | Tuple[loss, Tuple[loss, logits, labels]]
419 |
420 | Example::
421 |
422 |         ebc = EmbeddingBagCollection(tables=[ebc_config])
423 | model = DLRMTrain(
424 | embedding_bag_collection=ebc,
425 | dense_in_features=100,
426 | dense_arch_layer_sizes=[20],
427 | over_arch_layer_sizes=[5, 1],
428 | )
429 | """
430 |
431 | def __init__(
432 | self,
433 | embedding_bag_collection: EmbeddingBagCollection,
434 | dense_in_features: int,
435 | dense_arch_layer_sizes: List[int],
436 | over_arch_layer_sizes: List[int],
437 | dense_device: Optional[torch.device] = None,
438 | ) -> None:
439 | super().__init__()
440 | self.model = DLRM(
441 | embedding_bag_collection=embedding_bag_collection,
442 | dense_in_features=dense_in_features,
443 | dense_arch_layer_sizes=dense_arch_layer_sizes,
444 | over_arch_layer_sizes=over_arch_layer_sizes,
445 | dense_device=dense_device,
446 | )
447 | self.loss_fn: nn.Module = nn.BCEWithLogitsLoss()
448 |
449 | def forward(self, batch: Batch) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
450 | logits = self.model(batch.dense_features, batch.sparse_features)
451 | logits = logits.squeeze()
452 | loss = self.loss_fn(logits, batch.labels.float())
453 |
454 | return loss, (loss.detach(), logits.detach(), batch.labels.detach())
455 |
--------------------------------------------------------------------------------
/benchmark/benchmark_cache.py:
--------------------------------------------------------------------------------
1 | """Benchmark metrics for CachedEmbeddingBag:
2 | 1. hit rate
3 | 2. bandwidth
4 | 3. read / load
5 | 4. elapsed time
6 | """
7 | import itertools
8 | from tqdm import tqdm
9 | from contexttimer import Timer
10 | from contextlib import nullcontext
11 | import numpy as np
12 |
13 | import torch
14 | from torch.profiler import profile, ProfilerActivity, schedule, tensorboard_trace_handler
15 |
16 | from colossalai.nn.parallel.layers import CachedEmbeddingBag, EvictionStrategy
17 | from recsys.datasets.criteo import get_id_freq_map
18 | from data_utils import get_dataloader, NUM_EMBED, CRITEO_PATH
19 |
20 |
21 | def benchmark_cache_embedding(batch_size,
22 | embedding_dim,
23 | cache_ratio,
24 | id_freq_map=None,
25 | warmup_ratio=0.,
26 | use_limit_buf=True,
27 | use_lfu=False):
28 | dataloader = get_dataloader('train', batch_size)
29 | cuda_row_num = int(cache_ratio * NUM_EMBED)
30 | print(f"batch size: {batch_size}, "
31 | f"num of batches: {len(dataloader)}, "
32 | f"cached rows: {cuda_row_num}, cached_ratio {cuda_row_num / NUM_EMBED}")
33 | data_iter = iter(dataloader)
34 |
35 | torch.cuda.reset_peak_memory_stats()
36 | device = torch.device('cuda:0')
37 |
38 | with Timer() as timer:
39 | model = CachedEmbeddingBag(NUM_EMBED, embedding_dim, sparse=True, include_last_offset=True, \
40 | evict_strategy=EvictionStrategy.LFU if use_lfu else EvictionStrategy.DATASET).to(device)
41 | # model = torch.nn.EmbeddingBag(NUM_EMBED, embedding_dim, sparse=True, include_last_offset=True).to(device)
42 | print(f"model init: {timer.elapsed:.2f}s")
43 |
44 | grad = None
45 | print(
46 |         f'after reorder max_memory_allocated {torch.cuda.max_memory_allocated()/1e9} GB, max_memory_reserved {torch.cuda.max_memory_reserved()/1e9} GB'
47 | )
48 | torch.cuda.reset_peak_memory_stats()
49 |
50 | with Timer() as timer:
51 | with tqdm(bar_format='{n_fmt}it {rate_fmt} {postfix}') as t:
52 | # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
53 | # schedule=schedule(wait=0, warmup=21, active=2, repeat=1),
54 | # profile_memory=True,
55 | # on_trace_ready=tensorboard_trace_handler(
56 | # f"log/b{batch_size}-e{embedding_dim}-num_chunk{cuda_row_num}-chunk_size{cache_lines}")) as prof:
57 | with nullcontext():
58 | for it in itertools.count():
59 | batch = next(data_iter)
60 | sparse_feature = batch.sparse_features.to(device)
61 |
62 | res = model(sparse_feature.values(), sparse_feature.offsets())
63 |
64 | grad = torch.randn_like(res) if grad is None else grad
65 | res.backward(grad)
66 |
67 | model.zero_grad()
68 | # prof.step()
69 |
70 | t.update()
71 | if it == 200:
72 | break
73 |
74 | if hasattr(model, 'cache_weight_mgr'):
75 | model.cache_weight_mgr.print_comm_stats()
76 |
77 |
78 | if __name__ == "__main__":
79 | with Timer() as timer:
80 | id_freq_map = get_id_freq_map(CRITEO_PATH)
81 | print(f"Counting sparse features in dataset costs: {timer.elapsed:.2f} s")
82 |
83 | batch_size = [2048]
84 | embed_dim = 32
85 | cache_ratio = [0.02]
86 |
87 | # # row-wise cache
88 | # for bs in batch_size:
89 | # for cs in cuda_row_num:
90 | # main(bs, embed_dim, cuda_row_num=cs, cache_lines=1, embed_type='row')
91 |
92 | # chunk-wise cache
93 | for bs in batch_size:
94 | for cr in cache_ratio:
95 | for warmup_ratio in [0.7]:
96 | for use_buf in [False, True]:
97 | try:
98 | benchmark_cache_embedding(bs,
99 | embed_dim,
100 | cache_ratio=cr,
101 | id_freq_map=id_freq_map,
102 | warmup_ratio=warmup_ratio,
103 | use_limit_buf=use_buf)
104 | print('=' * 50 + '\n')
105 |
106 | except AssertionError as ae:
107 | print(f"batch size: {bs}, cache ratio: {cr}, raise error: {ae}")
108 | print('=' * 50 + '\n')
--------------------------------------------------------------------------------
/benchmark/benchmark_fbgemm_uvm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from colossalai.nn.parallel.layers.cache_embedding import CachedEmbeddingBag
4 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
5 | from fbgemm_gpu.split_table_batched_embeddings_ops import SplitTableBatchedEmbeddingBagsCodegen, EmbeddingLocation, ComputeDevice, CacheAlgorithm
6 | import time
7 |
8 | ########### GLOBAL SETTINGS ##################
9 |
10 |
11 | BATCH_SIZE = 65536
12 | TABLE_NUM = 856
13 | FILE_LIST = [f"/data/scratch/RecSys/embedding_bag/fbgemm_t856_bs65536_{i}.pt" for i in range(16)]
14 | KEYS = []
15 | for i in range(TABLE_NUM):
16 | KEYS.append("table_{}".format(i))
17 | EMBEDDING_DIM = 128
18 | # Full dataset is too big
19 | # CHOSEN_TABLES = [0, 2, 3, 4, 5, 7, 8, 9, 10, 12, 15, 18, 22, 27, 28]
20 | # CHOSEN_TABLES = [5, 8, 37, 54, 71,72,73,74,85,86,89,95,96,97,107,131,163, 185, 196, 204, 211, ]
21 | CHOSEN_TABLES = list(range(300, 418))
22 | TEST_ITER = 100
23 | TEST_BATCH_SIZE = 16384
24 | WARMUP_ITERS = 5
25 | ##############################################
26 |
27 |
28 | def load_file(file_path):
29 | indices, offsets, lengths = torch.load(file_path)
30 | indices = indices.int().cuda()
31 | offsets = offsets.int().cuda()
32 | lengths = lengths.int().cuda()
33 | num_embeddings_per_table = []
34 | indices_per_table = []
35 | lengths_per_table = []
36 | offsets_per_table = []
37 |     for i in range(TABLE_NUM):
38 | if i not in CHOSEN_TABLES:
39 | continue
40 | start_pos = offsets[i * BATCH_SIZE]
41 | end_pos = offsets[i * BATCH_SIZE + BATCH_SIZE]
42 | part = indices[start_pos:end_pos]
43 | indices_per_table.append(part)
44 | lengths_per_table.append(lengths[i])
45 | offsets_per_table.append(torch.cumsum(
46 | torch.cat((torch.tensor([0]).cuda(), lengths[i])), 0
47 | ))
48 | if part.numel() == 0:
49 | num_embeddings_per_table.append(0)
50 | else:
51 | num_embeddings_per_table.append(torch.max(part).int().item() + 1)
52 | return indices_per_table, offsets_per_table, lengths_per_table, num_embeddings_per_table
53 |
54 |
55 | def load_file_kjt(file_path):
56 | indices, offsets, lengths = torch.load(file_path)
57 | length_per_key = []
58 |     for i in range(TABLE_NUM):
59 | length_per_key.append(lengths[i])
60 | ret = KeyedJaggedTensor(KEYS, indices, offsets=offsets,
61 | length_per_key=length_per_key)
62 | return ret
63 |
64 |
65 | def load_random_batch(indices_per_table, offsets_per_table, lengths_per_table, batch_size=4096):
66 | chosen_indices_list = []
67 | chosen_lengths_list = []
68 | choose = torch.randint(
69 | 0, offsets_per_table[0].shape[0] - 1, (batch_size,)).cuda()
70 | for indices, offsets, lengths in zip(indices_per_table, offsets_per_table, lengths_per_table):
71 |
72 | chosen_lengths_list.append(lengths[choose])
73 | start_list = offsets[choose]
74 | end_list = offsets[choose + 1]
75 | chosen_indices_atoms = []
76 | for start, end in zip(start_list, end_list):
77 | chosen_indices_atoms.append(indices[start: end])
78 | chosen_indices_list.append(torch.cat(chosen_indices_atoms, 0))
79 | return chosen_indices_list, chosen_lengths_list
80 |
81 |
82 | def merge_to_kjt(indices_list, lengths_list, length_per_key) -> KeyedJaggedTensor:
83 | values = torch.cat(indices_list)
84 | lengths = torch.cat(lengths_list)
85 | return KeyedJaggedTensor(
86 | keys=[KEYS[i] for i in CHOSEN_TABLES],
87 | values=values,
88 | lengths=lengths,
89 | length_per_key=length_per_key,
90 | )
91 |
92 |
93 | def test(iter_num=1, batch_size=4096):
94 | print("loading file")
95 | indices_per_table, offsets_per_table, lengths_per_table, num_embeddings_per_table = load_file(
96 | FILE_LIST[0])
97 | table_idx_offset_list = np.cumsum([0] + num_embeddings_per_table)
98 | fae = CachedEmbeddingBag(
99 | num_embeddings=sum(num_embeddings_per_table),
100 | embedding_dim=EMBEDDING_DIM,
101 | sparse=True,
102 | include_last_offset=True,
103 | cache_ratio=0.05,
104 | pin_weight=True,
105 | )
106 | fae_forwarding_time = 0.0
107 | fae_backwarding_time = 0.0
108 | grad_fae = None
109 |     managed_type = EmbeddingLocation.MANAGED_CACHING
110 | uvm = SplitTableBatchedEmbeddingBagsCodegen(
111 | embedding_specs=[(
112 | num_embeddings,
113 | EMBEDDING_DIM,
114 | managed_type,
115 | ComputeDevice.CUDA,
116 | ) for num_embeddings in num_embeddings_per_table],
117 | cache_load_factor=0.05,
118 | cache_algorithm=CacheAlgorithm.LFU
119 | )
120 | uvm.init_embedding_weights_uniform(-0.5, 0.5)
121 | # print(sum(num_embeddings_per_table))
122 | # print(uvm.weights_uvm.shape)
123 | uvm_forwarding_time = 0.0
124 | uvm_backwarding_time = 0.0
125 | grad_uvm = None
126 | print("testing:")
127 | for iter in range(iter_num):
128 | # load batch
129 | chosen_indices_list, chosen_lengths_list = load_random_batch(indices_per_table,
130 | offsets_per_table, lengths_per_table, batch_size)
131 | features = merge_to_kjt(chosen_indices_list,
132 | chosen_lengths_list, num_embeddings_per_table)
133 | print("iter {} batch loaded.".format(iter))
134 |
135 | # fae
136 | with torch.no_grad():
137 | values = features.values().long()
138 | offsets = features.offsets().long()
139 | weights = features.weights_or_none()
140 | batch_size = len(features.offsets()) // len(features.keys())
141 | if weights is not None and not torch.is_floating_point(weights):
142 | weights = None
143 | split_view = torch.tensor_split(
144 | values, features.offset_per_key()[1:-1], dim=0)
145 | for i, chunk in enumerate(split_view):
146 | torch.add(chunk, table_idx_offset_list[i], out=chunk)
147 | start = time.time()
148 | output = fae(values, offsets, weights)
149 | ret = torch.cat(output.split(batch_size), 1)
150 |             if iter >= WARMUP_ITERS:
151 |                 fae_forwarding_time += time.time() - start
152 |                 print("fae forwarded. avg time = {} s".format(
153 |                     fae_forwarding_time / (iter + 1 - WARMUP_ITERS)))
154 | grad_fae = torch.randn_like(ret) if grad_fae is None else grad_fae
155 | start = time.time()
156 | ret.backward(grad_fae)
157 |             if iter >= WARMUP_ITERS:
158 |                 fae_backwarding_time += time.time() - start
159 |                 print("fae backwarded. avg time = {} s".format(
160 |                     fae_backwarding_time / (iter + 1 - WARMUP_ITERS)))
161 | fae.zero_grad()
162 |
163 | # uvm
164 | start = time.time()
165 | ret = uvm(features.values().long(), features.offsets().long())
166 |         if iter >= WARMUP_ITERS:
167 |             uvm_forwarding_time += time.time() - start
168 |             print("uvm forwarded. avg time = {} s".format(
169 |                 uvm_forwarding_time / (iter + 1 - WARMUP_ITERS)))
170 | grad_uvm = torch.randn_like(ret) if grad_uvm is None else grad_uvm
171 | start = time.time()
172 | ret.backward(grad_uvm)
173 |         if iter >= WARMUP_ITERS:
174 |             uvm_backwarding_time += time.time() - start
175 |             print("uvm backwarded. avg time = {} s".format(
176 |                 uvm_backwarding_time / (iter + 1 - WARMUP_ITERS)))
177 | uvm.zero_grad()
178 |
179 |
180 | # test(TEST_ITER, TEST_BATCH_SIZE)
181 | num_embeddings_per_table = None
182 | for i in range(16):
183 | indices_per_table, offsets_per_table, lengths_per_table, num_embeddings_per_table1 = load_file(FILE_LIST[i])
184 |     if num_embeddings_per_table is None:
185 | num_embeddings_per_table = num_embeddings_per_table1
186 | else:
187 |         for j, num in enumerate(num_embeddings_per_table1):
188 |             num_embeddings_per_table[j] = max(num_embeddings_per_table[j], num)
189 | print(num_embeddings_per_table)
190 | print(sum(num_embeddings_per_table))
--------------------------------------------------------------------------------
/benchmark/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.data import DataLoader
3 | from recsys.datasets import criteo
4 |
5 | CRITEO_PATH = "../criteo_kaggle_data"
6 | NUM_EMBED = 33762577
7 |
8 |
9 | def get_dataloader(stage, batch_size):
10 | hash_sizes = list(map(int, criteo.KAGGLE_NUM_EMBEDDINGS_PER_FEATURE.split(',')))
11 | files = os.listdir(CRITEO_PATH)
12 |
13 | def is_final_day(s):
14 | return "day_6" in s
15 |
16 | rank, world_size = 0, 1
17 | if stage == "train":
18 | # Train set gets all data except from the final day.
19 | files = list(filter(lambda s: not is_final_day(s), files))
20 | else:
21 |         # Validation set gets the first half of the final day's samples. Test set gets
22 | # the other half.
23 | files = list(filter(is_final_day, files))
24 | rank = rank if stage == "val" else (rank + world_size)
25 | world_size = world_size * 2
26 |
27 | stage_files = [
28 | sorted(map(
29 | lambda x: os.path.join(CRITEO_PATH, x),
30 | filter(lambda s: kind in s, files),
31 | )) for kind in ["dense", "sparse", "labels"]
32 | ]
33 | dataloader = DataLoader(
34 | criteo.InMemoryBinaryCriteoIterDataPipe(
35 | *stage_files, # pyre-ignore[6]
36 | batch_size=batch_size,
37 | rank=rank,
38 | world_size=world_size,
39 | shuffle_batches=True,
40 | hashes=hash_sizes),
41 | batch_size=None,
42 | pin_memory=True,
43 | collate_fn=lambda x: x,
44 | )
45 | return dataloader
46 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # Standard image. See Dockerfile_thu for a variant that installs from the
2 | # Tsinghua pip mirror (domestic cloud servers often have network issues with pip).
3 | FROM hpcaitech/pytorch-cuda:1.12.0-11.3.0
4 |
5 | RUN pip install --no-cache-dir petastorm[torch]==0.11.5
6 |
7 | # install torchrec from PyPI
8 | RUN python3 -m pip install --no-cache-dir torchrec==0.2.0
9 |
10 | # alternative: install torchrec from a prebuilt wheel
11 | # RUN wget https://download.pytorch.org/whl/torchrec-0.1.1-py39-none-any.whl && \
12 | # python3 -m pip install --no-cache-dir torchrec-0.1.1-py39-none-any.whl && \
13 | # rm torchrec-0.1.1-py39-none-any.whl
14 |
15 | # updated with hpcaitech version
16 | RUN git clone https://github.com/hpcaitech/torchrec.git && cd torchrec && pip install .
17 |
18 | # install colossalai
19 | RUN git clone https://github.com/hpcaitech/ColossalAI.git && \
20 | cd ColossalAI/ && \
21 |     python3 -m pip install --no-cache-dir -r requirements/requirements.txt && \
22 |     python3 -m pip install --no-cache-dir . && \
23 | cd .. && \
24 | yes | rm -r ColossalAI/
25 |
--------------------------------------------------------------------------------
/docker/Dockerfile_thu:
--------------------------------------------------------------------------------
1 | # Domestic cloud servers often have network issues with pip,
2 | # so we need to pip install from tsinghua mirror
3 | FROM hpcaitech/pytorch-cuda:1.12.0-11.3.0
4 |
5 | RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir petastorm[torch]==0.11.5
6 |
7 | # install torchrec from PyPI (via the mirror)
8 | RUN python3 -m pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir torchrec==0.2.0
9 |
10 | # alternative: install torchrec from a prebuilt wheel
11 | # RUN wget https://download.pytorch.org/whl/torchrec-0.1.1-py39-none-any.whl && \
12 | # python3 -m pip install --no-cache-dir torchrec-0.1.1-py39-none-any.whl && \
13 | # rm torchrec-0.1.1-py39-none-any.whl
14 |
15 | # updated with hpcaitech version
16 | RUN git clone https://github.com/hpcaitech/torchrec.git && cd torchrec && pip install .
17 |
18 | # install colossalai
19 | RUN git clone https://github.com/hpcaitech/ColossalAI.git && \
20 | cd ColossalAI/ && \
21 | python3 -m pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir -r requirements/requirements.txt && \
22 | python3 -m pip install -i https://pypi.tuna.tsinghua.edu.cn/simple --no-cache-dir . && \
23 | cd .. && \
24 | yes | rm -r ColossalAI/
25 |
--------------------------------------------------------------------------------
/docker/launch.sh:
--------------------------------------------------------------------------------
1 | DATASET_PATH=/data/scratch/RecSys
2 |
3 | docker run --rm -it -e CUDA_VISIBLE_DEVICES=0,1,2,3 -e PYTHONPATH=/workspace/code -v $(pwd):/workspace/code -v ${DATASET_PATH}:/data -w /workspace/code --ipc=host hpcaitech/cacheembedding:0.1.4 /bin/bash
4 |
--------------------------------------------------------------------------------
/license:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2021- HPC-AI Technology Inc.
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/pics/prefetch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpcaitech/CachedEmbedding/a2af3d7e7b0197519e6d018444688fcd9ba32c43/pics/prefetch.png
--------------------------------------------------------------------------------
/recsys/README.md:
--------------------------------------------------------------------------------
1 | # Build a scalable embedding from scratch for training recommendation models
--------------------------------------------------------------------------------
/recsys/__init__.py:
--------------------------------------------------------------------------------
1 | # Build a scalable system from scratch for training recommendation models
2 |
--------------------------------------------------------------------------------
/recsys/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpcaitech/CachedEmbedding/a2af3d7e7b0197519e6d018444688fcd9ba32c43/recsys/datasets/__init__.py
--------------------------------------------------------------------------------
/recsys/datasets/avazu.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.utils.data.datapipes as dp
7 | from torch.utils.data import IterDataPipe, IterableDataset, DataLoader
8 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
9 | from torchrec.datasets.utils import LoadFiles, ReadLinesFromCSV, PATH_MANAGER_KEY, Batch
10 | from torchrec.datasets.criteo import BinaryCriteoUtils
11 |
12 | from .feature_counter import GlobalFeatureCounter
13 |
14 | CAT_FEATURE_COUNT = 13
15 | INT_FEATURE_COUNT = 8
16 | DAYS = 10
17 | DEFAULT_LABEL_NAME = "click"
18 | DEFAULT_CAT_NAMES = [
19 | 'C1',
20 | 'banner_pos',
21 | 'site_id',
22 | 'site_domain',
23 | 'site_category',
24 | 'app_id',
25 | 'app_domain',
26 | 'app_category',
27 | 'device_id',
28 | 'device_ip',
29 | 'device_model',
30 | 'device_type',
31 | 'device_conn_type',
32 | ]
33 | DEFAULT_INT_NAMES = ['C14', 'C15', 'C16', 'C17', 'C18', 'C19', 'C20', 'C21']
34 | NUM_EMBEDDINGS_PER_FEATURE = '7,7,4737,7745,26,8552,559,36,2686408,6729486,8251,5,4' # 9445823 in total
35 | TOTAL_TRAINING_SAMPLES = 36_386_071 # 90% sample in train, 40428967 in total
36 |
37 |
38 | def _default_row_mapper(row):
39 | _label = row[1]
40 | _sparse = row[3:5]
41 | for i in range(5, 14): # 9
42 | try:
43 | _c = int(row[i], 16)
44 | except ValueError:
45 | _c = 0
46 | _sparse.append(_c)
47 | _sparse += row[14:24]
48 |
49 | return _sparse, _label
50 |
51 |
52 | class AvazuIterDataPipe(IterDataPipe):
53 |
54 | def __init__(self, path, row_mapper=_default_row_mapper):
55 | self.path = path
56 | self.row_mapper = row_mapper
57 |
58 | def __iter__(self):
59 | """
60 |         Iterate over the data file and apply the row_mapper transform to each row.
61 | """
62 | datapipe = LoadFiles([self.path], mode='r', path_manager_key='avazu')
63 | datapipe = ReadLinesFromCSV(datapipe, delimiter=',', skip_first_line=True)
64 | if self.row_mapper is not None:
65 | datapipe = dp.iter.Mapper(datapipe, self.row_mapper)
66 | yield from datapipe
67 |
68 |
69 | class InMemoryAvazuIterDataPipe(IterableDataset):
70 |
71 | def __init__(self,
72 | dense_paths,
73 | sparse_paths,
74 | label_paths,
75 | batch_size,
76 | rank,
77 | world_size,
78 | shuffle_batches=False,
79 | mmap_mode=False,
80 | hashes=None,
81 | path_manager_key=PATH_MANAGER_KEY,
82 |                  assigned_tables=None):
83 | if assigned_tables is not None:
84 | # tablewise mode
85 | self.assigned_tables = np.array(assigned_tables)
86 | else:
87 | # full table mode
88 | self.assigned_tables = np.arange(CAT_FEATURE_COUNT)
89 |
90 | self.dense_paths = dense_paths
91 | self.sparse_paths = sparse_paths
92 | self.label_paths = label_paths
93 | self.batch_size = batch_size
94 | self.rank = rank
95 | self.world_size = world_size
96 | self.shuffle_batches = shuffle_batches
97 | self.mmap_mode = mmap_mode
98 | if hashes is not None:
99 | self.hashes = []
100 | for i, length in enumerate(hashes):
101 | if i in self.assigned_tables:
102 | self.hashes.append(length)
103 | self.hashes = np.array(self.hashes).reshape(1, -1)
104 | else:
105 | self.hashes = None
106 | self.path_manager_key = path_manager_key
107 |
108 | self.sparse_offsets = np.array([0, *np.cumsum(self.hashes)[:-1]], dtype=np.int64).reshape(
109 | 1, -1) if self.hashes is not None else None
110 |
111 | self._load_data()
112 | self.num_rows_per_file = [a.shape[0] for a in self.dense_arrs]
113 | self.num_batches: int = sum(self.num_rows_per_file) // batch_size
114 |
115 | self._num_ids_in_batch: int = len(self.assigned_tables) * batch_size
116 | self.keys = [DEFAULT_CAT_NAMES[i] for i in self.assigned_tables]
117 | self.lengths = torch.ones((self._num_ids_in_batch,), dtype=torch.int32)
118 | self.offsets = torch.arange(0, self._num_ids_in_batch + 1, dtype=torch.int32)
119 | self.stride = batch_size
120 | self.length_per_key = len(self.assigned_tables) * [batch_size]
121 | self.offset_per_key = [batch_size * i for i in range(len(self.assigned_tables) + 1)]
122 | self.index_per_key = {key: i for (i, key) in enumerate(self.keys)}
123 |
124 | def _load_data(self):
125 | file_idx_to_row_range = BinaryCriteoUtils.get_file_idx_to_row_range(lengths=[
126 | BinaryCriteoUtils.get_shape_from_npy(path, path_manager_key=self.path_manager_key)[0]
127 | for path in self.sparse_paths
128 | ],
129 | rank=self.rank,
130 | world_size=self.world_size)
131 | self.dense_arrs, self.sparse_arrs, self.labels_arrs = [], [], []
132 | for _dtype, arrs, paths in zip([np.float32, np.int64, np.int32],
133 | [self.dense_arrs, self.sparse_arrs, self.labels_arrs],
134 | [self.dense_paths, self.sparse_paths, self.label_paths]):
135 | for idx, (range_left, range_right) in file_idx_to_row_range.items():
136 | arrs.append(
137 | BinaryCriteoUtils.load_npy_range(paths[idx],
138 | range_left,
139 | range_right - range_left + 1,
140 | path_manager_key=self.path_manager_key,
141 | mmap_mode=self.mmap_mode).astype(_dtype))
142 | expand_hashes = np.ones(CAT_FEATURE_COUNT, dtype=np.int64).reshape(1, -1)
143 | expand_sparse_offsets = np.ones(CAT_FEATURE_COUNT, dtype=np.int64).reshape(1, -1)
144 | for i, table in enumerate(self.assigned_tables):
145 | expand_hashes[0, table] = self.hashes[0, i]
146 | expand_sparse_offsets[0, table] = self.sparse_offsets[0, i]
147 | if not self.mmap_mode and self.hashes is not None:
148 | for sparse_arr in self.sparse_arrs:
149 | sparse_arr %= expand_hashes
150 | sparse_arr += expand_sparse_offsets
151 |
152 | def __iter__(self):
153 | buffer = None
154 |
155 | def append_to_buffer(dense: np.ndarray, sparse: np.ndarray, labels: np.ndarray) -> None:
156 | nonlocal buffer
157 | if buffer is None:
158 | buffer = [dense, sparse, labels]
159 | else:
160 | for idx, arr in enumerate([dense, sparse, labels]):
161 | buffer[idx] = np.concatenate((buffer[idx], arr))
162 |
163 | # Maintain a buffer that can contain up to batch_size rows. Fill buffer as
164 | # much as possible on each iteration. Only return a new batch when batch_size
165 | # rows are filled.
166 | file_idx = 0
167 | row_idx = 0
168 | batch_idx = 0
169 | while batch_idx < self.num_batches:
170 | buffer_row_count = 0 if buffer is None else buffer[0].shape[0]
171 | if buffer_row_count == self.batch_size:
172 | yield self._np_arrays_to_batch(*buffer)
173 | batch_idx += 1
174 | buffer = None
175 | else:
176 | rows_to_get = min(
177 | self.batch_size - buffer_row_count,
178 | self.num_rows_per_file[file_idx] - row_idx,
179 | )
180 | slice_ = slice(row_idx, row_idx + rows_to_get)
181 |
182 | dense_inputs = self.dense_arrs[file_idx][slice_, :]
183 | sparse_inputs = self.sparse_arrs[file_idx][slice_, :].take(self.assigned_tables, -1)
184 | target_labels = self.labels_arrs[file_idx][slice_, :]
185 |
186 | if self.mmap_mode and self.hashes is not None:
187 | sparse_inputs %= self.hashes
188 | sparse_inputs += self.sparse_offsets
189 |
190 | append_to_buffer(
191 | dense_inputs,
192 | sparse_inputs,
193 | target_labels,
194 | )
195 | row_idx += rows_to_get
196 |
197 | if row_idx >= self.num_rows_per_file[file_idx]:
198 | file_idx += 1
199 | row_idx = 0
200 |
201 | def _np_arrays_to_batch(self, dense: np.ndarray, sparse: np.ndarray, labels: np.ndarray) -> Batch:
202 | if self.shuffle_batches:
203 | # Shuffle all 3 in unison
204 | shuffler = np.random.permutation(sparse.shape[0])
205 | dense = dense[shuffler]
206 | sparse = sparse[shuffler]
207 | labels = labels[shuffler]
208 |
209 | return Batch(
210 | dense_features=torch.from_numpy(dense),
211 | sparse_features=KeyedJaggedTensor(
212 | keys=self.keys,
213 | # transpose + reshape(-1) incurs an additional copy.
214 | values=torch.from_numpy(sparse.transpose(1, 0).reshape(-1)),
215 | lengths=self.lengths,
216 | offsets=self.offsets,
217 | stride=self.stride,
218 | length_per_key=self.length_per_key,
219 | offset_per_key=self.offset_per_key,
220 | index_per_key=self.index_per_key,
221 | ),
222 | labels=torch.from_numpy(labels.reshape(-1)),
223 | )
224 |
225 | def __len__(self) -> int:
226 | return self.num_batches
227 |
228 |
229 | def get_dataloader(args, stage, rank, world_size, assigned_tables=None):
230 | stage = stage.lower()
231 |
232 | files = os.listdir(args.dataset_dir)
233 |
234 | if stage == 'train':
235 | files = list(filter(lambda s: 'train' in s, files))
236 | else:
237 | files = list(filter(lambda s: 'train' not in s, files))
238 | rank = rank if stage == "val" else (rank + world_size)    # val ranks take the first half of the doubled world, test the second
239 | world_size = world_size * 2    # val and test together partition the non-train files
240 |
241 | stage_files = [
242 | sorted(map(
243 | lambda s: os.path.join(args.dataset_dir, s),
244 | filter(lambda _f: kind in _f, files),
245 | )) for kind in ["dense", "sparse", "label"]
246 | ]
247 |
248 | dataloader = DataLoader(
249 | InMemoryAvazuIterDataPipe(*stage_files,
250 | batch_size=args.batch_size,
251 | rank=rank,
252 | world_size=world_size,
253 | shuffle_batches=args.shuffle_batches,
254 | hashes=args.num_embeddings_per_feature,
255 | assigned_tables=assigned_tables),
256 | batch_size=None,
257 | pin_memory=args.pin_memory,
258 | collate_fn=lambda x: x,
259 | )
260 |
261 | return dataloader
262 |
263 |
264 | def get_id_freq_map(path):
265 | files = os.listdir(path)
266 | files = list(filter(lambda s: "sparse" in s, files))
267 | files = [os.path.join(path, _f) for _f in files]
268 |
269 | feature_count = GlobalFeatureCounter(files, list(map(int, NUM_EMBEDDINGS_PER_FEATURE.split(','))))
270 | id_freq_map = torch.from_numpy(feature_count.compute())
271 | return id_freq_map
272 |
--------------------------------------------------------------------------------
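A note on the `KeyedJaggedTensor` layout built by `_np_arrays_to_batch` above: every assigned table contributes exactly one id per sample, so `lengths` is all ones, `offsets` is a plain arange, and `values` is the `(B, F)` sparse matrix transposed to table-major order. A minimal sketch with illustrative shapes, assuming `torchrec` is installed:

```python
# Minimal sketch of the fixed-stride KJT layout used by _np_arrays_to_batch.
import numpy as np
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

batch_size, num_tables = 4, 2
sparse = np.arange(batch_size * num_tables, dtype=np.int64).reshape(batch_size, num_tables)  # (B, F)

num_ids = num_tables * batch_size
kjt = KeyedJaggedTensor(
    keys=[f"cat_{i}" for i in range(num_tables)],
    # transpose + reshape(-1): all ids of table 0 first, then table 1, ...
    values=torch.from_numpy(sparse.transpose(1, 0).reshape(-1)),
    lengths=torch.ones((num_ids,), dtype=torch.int32),       # one id per (table, sample)
    offsets=torch.arange(0, num_ids + 1, dtype=torch.int32),  # hence a plain arange
    stride=batch_size,
)
print(kjt["cat_0"].values())    # ids of table 0 across the whole batch
```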
/recsys/datasets/feature_counter.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import random
3 | from tqdm import tqdm
4 |
5 | import numpy as np
6 | from .criteo import DEFAULT_CAT_NAMES
7 | from petastorm import make_batch_reader
8 | from pyarrow.parquet import ParquetDataset
9 |
10 |
11 | class GlobalFeatureCounter:
12 | """
13 | compute the global statistics of the whole training set
14 | """
15 |
16 | def __init__(self, datafiles, hash_sizes):
17 | self.datafiles = datafiles
18 | self.hash_sizes = np.array(hash_sizes).reshape(1, -1)
19 | self.offsets = np.array([0, *np.cumsum(hash_sizes)[:-1]]).reshape(1, -1)
20 |
21 | def compute(self):
22 | id_freq_map = np.zeros(self.hash_sizes.sum(), dtype=np.int64)
23 | for _f in self.datafiles:
24 | arr = np.load(_f)
25 | arr %= self.hash_sizes
26 | arr += self.offsets
27 | flattened = arr.reshape(-1)
28 | id_freq_map += np.bincount(flattened, minlength=self.hash_sizes.sum())
29 | return id_freq_map
30 |
31 | class PetastormCounter:
32 |
33 | def __init__(self, datafiles, hash_sizes, subsample_fraction=0.2, seed=1024):
34 | self.datafiles = datafiles
35 | self.total_features = sum(hash_sizes)
36 |
37 | self.offsets = np.array([0, *np.cumsum(hash_sizes)[:-1]]).reshape(1, -1)
38 | self.subsample_fraction = subsample_fraction
39 | self.seed = seed
40 |
41 | def compute(self):
42 | _id_freq_map = np.zeros(self.total_features, dtype=np.int64)
43 |
44 | files = self.datafiles
45 | random.seed(self.seed)
46 | random.shuffle(files)
47 | if 0. < self.subsample_fraction < 1.:
48 | files = files[:int(np.ceil(len(files) * self.subsample_fraction))]
49 |
50 | dataset = ParquetDataset(files, use_legacy_dataset=False)
51 | with make_batch_reader(list(map(lambda x: "file://" + x, dataset.files)), num_epochs=1) as reader:
52 | for batch in tqdm(reader,
53 | ncols=0,
54 | desc="Collecting id-freq map",
55 | total=sum([fragment.metadata.num_row_groups for fragment in dataset.fragments])):
56 | sparse = np.concatenate([getattr(batch, col_name).reshape(-1, 1) for col_name in DEFAULT_CAT_NAMES],
57 | axis=1)
58 | sparse = (sparse + self.offsets).reshape(-1)
59 | _id_freq_map += np.bincount(sparse, minlength=self.total_features)
60 | return _id_freq_map
61 |
--------------------------------------------------------------------------------
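For reference, a usage sketch of `GlobalFeatureCounter` mirroring `get_id_freq_map` in `recsys/datasets/avazu.py`; the file names and hash sizes below are placeholders:

```python
# Usage sketch for GlobalFeatureCounter (paths and hash sizes are placeholders).
import torch
from recsys.datasets.feature_counter import GlobalFeatureCounter

hash_sizes = [10, 20, 30]                          # per-table vocabulary sizes
files = ["day_0_sparse.npy", "day_1_sparse.npy"]   # hypothetical (N, 3) int id arrays

counter = GlobalFeatureCounter(files, hash_sizes)
id_freq_map = torch.from_numpy(counter.compute())  # one count per global id
assert id_freq_map.numel() == sum(hash_sizes)
```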
/recsys/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.distributed as dist
4 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
5 | from torchrec.datasets.utils import Batch
6 |
7 |
8 | class KJTAllToAll:
9 | """
10 | Different from the module of the same name defined in torchrec:
11 |
12 | this class implements an all_gather of KJTs via the all_to_all collective.
13 | """
14 |
15 | def __init__(self, group):
16 | self.group = group
17 | self.rank = dist.get_rank(group)
18 | self.world_size = dist.get_world_size(group)
19 |
20 | @torch.no_grad()
21 | def all_to_all(self, kjt):
22 | if self.world_size == 1:
23 | return kjt
24 | # TODO: add sample weights
25 | values, lengths = kjt.values(), kjt.lengths()
26 | keys, batch_size = kjt.keys(), kjt.stride()
27 |
28 | # collect global data
29 | length_list = [lengths if i == self.rank else lengths.clone() for i in range(self.world_size)]
30 | all_length_list = [torch.empty_like(lengths) for _ in range(self.world_size)]
31 | dist.all_to_all(all_length_list, length_list, group=self.group)
32 |
33 | intermediate_all_length_list = [_length.view(-1, batch_size) for _length in all_length_list]
34 | all_length_per_key_list = [torch.sum(_length, dim=1).cpu().tolist() for _length in intermediate_all_length_list]
35 |
36 | all_value_length = [torch.sum(each).item() for each in all_length_list]
37 | value_list = [values if i == self.rank else values.clone() for i in range(self.world_size)]
38 | all_value_list = [
39 | torch.empty(_length, dtype=values.dtype, device=values.device) for _length in all_value_length
40 | ]
41 | dist.all_to_all(all_value_list, value_list, group=self.group)
42 |
43 | all_value_list = [
44 | torch.split(_values, _length_per_key) # [ key size, variable value size ]
45 | for _values, _length_per_key in zip(all_value_list, all_length_per_key_list) # world size
46 | ]
47 | all_values = torch.cat([torch.cat(values_per_key) for values_per_key in zip(*all_value_list)])
48 |
49 | all_lengths = torch.cat(intermediate_all_length_list, dim=1).view(-1)
50 | return KeyedJaggedTensor.from_lengths_sync(
51 | keys=keys,
52 | values=all_values,
53 | lengths=all_lengths,
54 | )
55 |
56 |
57 | class KJTTransform:
58 |
59 | def __init__(self, dataloader, hashes=None):
60 | self.batch_size = dataloader.batch_size
61 | self.cats = dataloader.cat_names
62 | self.conts = dataloader.cont_names
63 | self.labels = dataloader.label_names
64 | self.sparse_offset = torch.tensor(
65 | [0, *np.cumsum(hashes)[:-1]], dtype=torch.long, device=torch.cuda.current_device()).view(1, -1) \
66 | if hashes is not None else None
67 |
68 | _num_ids_in_batch = len(self.cats) * self.batch_size
69 | self.lengths = torch.ones((_num_ids_in_batch,), dtype=torch.int32)
70 | self.offsets = torch.arange(0, _num_ids_in_batch + 1, dtype=torch.int32)
71 | self.length_per_key = len(self.cats) * [self.batch_size]
72 | self.offset_per_key = [self.batch_size * i for i in range(len(self.cats) + 1)]
73 | self.index_per_key = {key: i for (i, key) in enumerate(self.cats)}
74 |
75 | def transform(self, batch):
76 | sparse, dense = [], []
77 | for col in self.cats:
78 | sparse.append(batch[0][col])
79 | sparse = torch.cat(sparse, dim=1)
80 | if self.sparse_offset is not None:
81 | sparse += self.sparse_offset
82 | for col in self.conts:
83 | dense.append(batch[0][col])
84 | dense = torch.cat(dense, dim=1)
85 |
86 | return Batch(
87 | dense_features=dense,
88 | sparse_features=KeyedJaggedTensor(
89 | keys=self.cats,
90 | values=sparse.transpose(1, 0).contiguous().view(-1),
91 | lengths=self.lengths,
92 | offsets=self.offsets,
93 | stride=self.batch_size,
94 | length_per_key=self.length_per_key,
95 | offset_per_key=self.offset_per_key,
96 | index_per_key=self.index_per_key,
97 | ),
98 | labels=batch[1],
99 | )
100 |
--------------------------------------------------------------------------------
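`KJTTransform.transform` expects NVTabular-style batches: a pair of (dict of per-column tensors, label tensor), read via `batch[0][col]` and `batch[1]`. A minimal sketch under that assumption; the `SimpleNamespace` loader is a hypothetical stand-in carrying only the attributes `__init__` reads:

```python
# Sketch of the batch format KJTTransform consumes: (dict of column tensors, labels).
from types import SimpleNamespace
import torch
from recsys.datasets.utils import KJTTransform

loader = SimpleNamespace(batch_size=4, cat_names=["c0", "c1"],
                         cont_names=["d0"], label_names=["label"])
kjt_tf = KJTTransform(loader)    # hashes=None: no per-table id offset is added

batch = ({"c0": torch.randint(0, 10, (4, 1)), "c1": torch.randint(0, 10, (4, 1)),
          "d0": torch.rand(4, 1)},
         torch.randint(0, 2, (4,)))
out = kjt_tf.transform(batch)
print(out.sparse_features.keys(), out.dense_features.shape)
```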
/recsys/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpcaitech/CachedEmbedding/a2af3d7e7b0197519e6d018444688fcd9ba32c43/recsys/models/__init__.py
--------------------------------------------------------------------------------
/recsys/models/dlrm.py:
--------------------------------------------------------------------------------
1 | # The infrastructures of DLRM are mainly inspired by TorchRec:
2 | # https://github.com/pytorch/torchrec/blob/main/torchrec/models/dlrm.py
3 | import os
4 | import torch
5 | from contextlib import nullcontext
6 | import torch.nn as nn
7 | from torch.nn.parallel import DistributedDataParallel as DDP
8 | from torch.profiler import record_function
9 | from typing import List
10 | from baselines.models.dlrm import DenseArch, OverArch, InteractionArch, choose
11 | from ..utils import get_time_elapsed
12 | from ..datasets.utils import KJTAllToAll
13 | from ..utils import prepare_tablewise_config
14 | import colossalai
15 | from colossalai.nn.parallel.layers import ParallelCachedEmbeddingBag, EvictionStrategy, \
16 | TablewiseEmbeddingBagConfig, ParallelCachedEmbeddingBagTablewise
17 | from colossalai.core import global_context as gpc
18 | from colossalai.context.parallel_mode import ParallelMode
19 | import numpy as np
20 | from typing import Union
21 | from torchrec import KeyedJaggedTensor
22 |
23 | dist_logger = colossalai.logging.get_dist_logger()
24 |
25 |
26 | def sparse_embedding_shape_hook(embeddings, feature_size, batch_size):
27 | return embeddings.view(feature_size, batch_size, -1).transpose(0, 1)
28 |
29 | def sparse_embedding_shape_hook_for_tablewise(embeddings, feature_size, batch_size):
30 | return embeddings.view(embeddings.shape[0], feature_size, -1)
31 |
32 | class FusedSparseModules(nn.Module):
33 |
34 | def __init__(self,
35 | num_embeddings_per_feature,
36 | embedding_dim,
37 | fused_op='all_to_all',
38 | reduction_mode='sum',
39 | sparse=False,
40 | output_device_type=None,
41 | use_cache=False,
42 | cache_ratio=0.01,
43 | id_freq_map=None,
44 | warmup_ratio=0.7,
45 | buffer_size=50_000,
46 | is_dist_dataloader=True,
47 | use_lfu_eviction=False,
48 | use_tablewise_parallel=False,
49 | dataset: str = None):
50 | super(FusedSparseModules, self).__init__()
51 | self.sparse_feature_num = len(num_embeddings_per_feature)
52 | if use_cache:
53 | if use_tablewise_parallel:
54 | # establish the per-table config list
55 | world_size = torch.distributed.get_world_size()
56 | embedding_bag_config_list = prepare_tablewise_config(
57 | num_embeddings_per_feature, cache_ratio, id_freq_map, dataset, world_size)
58 | self.embed = ParallelCachedEmbeddingBagTablewise(
59 | embedding_bag_config_list,
60 | embedding_dim,
61 | sparse=sparse,
62 | mode=reduction_mode,
63 | include_last_offset=True,
64 | warmup_ratio=warmup_ratio,
65 | buffer_size=buffer_size,
66 | evict_strategy=EvictionStrategy.LFU if use_lfu_eviction else EvictionStrategy.DATASET,
67 | )
68 | self.shape_hook = sparse_embedding_shape_hook_for_tablewise
69 | else:
70 | self.embed = ParallelCachedEmbeddingBag(
71 | sum(num_embeddings_per_feature),
72 | embedding_dim,
73 | sparse=sparse,
74 | mode=reduction_mode,
75 | include_last_offset=True,
76 | cache_ratio=cache_ratio,
77 | ids_freq_mapping=id_freq_map,
78 | warmup_ratio=warmup_ratio,
79 | buffer_size=buffer_size,
80 | evict_strategy=EvictionStrategy.LFU if use_lfu_eviction else EvictionStrategy.DATASET
81 | )
82 | self.shape_hook = sparse_embedding_shape_hook
83 | else:
84 | raise NotImplementedError("Other EmbeddingBags are under development")
85 |
86 | if is_dist_dataloader:
87 | self.kjt_collector = KJTAllToAll(gpc.get_group(ParallelMode.GLOBAL))
88 | else:
89 | self.kjt_collector = None
90 |
91 | def forward(self, sparse_features: Union[List, KeyedJaggedTensor], cache_op: bool = True):
92 | self.embed.set_cache_op(cache_op)
93 | if self.kjt_collector:
94 | with record_function("(zhg)KJT AllToAll collective"):
95 | sparse_features = self.kjt_collector.all_to_all(sparse_features)
96 |
97 | if isinstance(sparse_features, list):
98 | batch_size = sparse_features[2]
99 | flattened_sparse_embeddings = self.embed(
100 | sparse_features[0],
101 | sparse_features[1],
102 | shape_hook=lambda x: self.shape_hook(x, self.sparse_feature_num, batch_size),
103 | )
104 | elif isinstance(sparse_features, KeyedJaggedTensor):
105 | batch_size = sparse_features.stride()
106 | flattened_sparse_embeddings = self.embed(
107 | sparse_features.values(),
108 | sparse_features.offsets(),
109 | shape_hook=lambda x: self.shape_hook(x, self.sparse_feature_num, batch_size),
110 | )
111 | else:
112 | raise TypeError
113 | return flattened_sparse_embeddings
114 |
115 |
116 | class FusedDenseModules(nn.Module):
117 | """
118 | Fusing dense operations of DLRM into a single module
119 | """
120 |
121 | def __init__(self, embedding_dim, num_sparse_features, dense_in_features, dense_arch_layer_sizes,
122 | over_arch_layer_sizes):
123 | super(FusedDenseModules, self).__init__()
124 | if dense_in_features <= 0:
125 | self.dense_arch = nn.Identity()
126 | over_in_features = choose(num_sparse_features, 2)
127 | num_dense = 0
128 | else:
129 | self.dense_arch = DenseArch(in_features=dense_in_features, layer_sizes=dense_arch_layer_sizes)
130 | over_in_features = (embedding_dim + choose(num_sparse_features, 2) + num_sparse_features)
131 | num_dense = 1
132 |
133 | self.inter_arch = InteractionArch(num_sparse_features=num_sparse_features, num_dense_features=num_dense)
134 | self.over_arch = OverArch(in_features=over_in_features, layer_sizes=over_arch_layer_sizes)
135 |
136 | def forward(self, dense_features, embedded_sparse_features):
137 | embedded_dense_features = self.dense_arch(dense_features)
138 | concat_dense = self.inter_arch(dense_features=embedded_dense_features, sparse_features=embedded_sparse_features)
139 | logits = self.over_arch(concat_dense)
140 |
141 | return logits
142 |
143 |
144 | class HybridParallelDLRM(nn.Module):
145 | """
146 | Model-parallel embedding followed by data-parallel dense modules.
147 | """
148 |
149 | def __init__(self,
150 | num_embeddings_per_feature,
151 | embedding_dim,
152 | num_sparse_features,
153 | dense_in_features,
154 | dense_arch_layer_sizes,
155 | over_arch_layer_sizes,
156 | dense_device,
157 | sparse_device,
158 | sparse=False,
159 | fused_op='all_to_all',
160 | use_cache=False,
161 | cache_ratio=0.01,
162 | id_freq_map=None,
163 | warmup_ratio=0.7,
164 | buffer_size=50_000,
165 | is_dist_dataloader=True,
166 | use_lfu_eviction=False,
167 | use_tablewise=False,
168 | dataset: str = None):
169 |
170 | super(HybridParallelDLRM, self).__init__()
171 | if use_cache and sparse_device.type != dense_device.type:
172 | raise ValueError(f"Sparse device must be the same as dense device, "
173 | f"however we got {sparse_device.type} for sparse, {dense_device.type} for dense")
174 |
175 | self.dense_device = dense_device
176 | self.sparse_device = sparse_device
177 |
178 | self.sparse_modules = FusedSparseModules(num_embeddings_per_feature,
179 | embedding_dim,
180 | fused_op=fused_op,
181 | sparse=sparse,
182 | output_device_type=dense_device.type,
183 | use_cache=use_cache,
184 | cache_ratio=cache_ratio,
185 | id_freq_map=id_freq_map,
186 | warmup_ratio=warmup_ratio,
187 | buffer_size=buffer_size,
188 | is_dist_dataloader=is_dist_dataloader,
189 | use_lfu_eviction=use_lfu_eviction,
190 | use_tablewise_parallel=use_tablewise,
191 | dataset=dataset
192 | ).to(sparse_device)
193 | self.dense_modules = DDP(module=FusedDenseModules(embedding_dim, num_sparse_features, dense_in_features,
194 | dense_arch_layer_sizes,
195 | over_arch_layer_sizes).to(dense_device),
196 | device_ids=[0 if os.environ.get("NVT_TAG", None) else gpc.get_global_rank()],
197 | process_group=gpc.get_group(ParallelMode.GLOBAL),
198 | gradient_as_bucket_view=True,
199 | broadcast_buffers=False,
200 | static_graph=True)
201 |
202 | # precompute for parallelized embedding
203 | param_amount = sum(num_embeddings_per_feature) * embedding_dim
204 | param_storage = self.sparse_modules.embed.element_size() * param_amount
205 | param_amount += sum(p.numel() for p in self.dense_modules.parameters())
206 | param_storage += sum(p.numel() * p.element_size() for p in self.dense_modules.parameters())
207 | #
208 | buffer_amount = sum(b.numel() for b in self.sparse_modules.buffers()) + \
209 | sum(b.numel() for b in self.dense_modules.buffers())
210 | buffer_storage = sum(b.numel() * b.element_size() for b in self.sparse_modules.buffers()) + \
211 | sum(b.numel() * b.element_size() for b in self.dense_modules.buffers())
212 | stat_str = f"Number of model parameters: {param_amount:,}, storage overhead: {param_storage/1024**3:.2f} GB. " \
213 | f"Number of model buffers: {buffer_amount:,}, storage overhead: {buffer_storage/1024**3:.2f} GB."
214 | self.stat_str = stat_str
215 |
216 | def forward(self, dense_features, sparse_features, inspect_time=False,
217 | cache_op=True):
218 | ctx1 = get_time_elapsed(dist_logger, "embedding lookup in forward pass") \
219 | if inspect_time else nullcontext()
220 | with ctx1:
221 | with record_function("Embedding lookup:"):
222 | # B // world size, sparse feature dim, embedding dim
223 | embedded_sparse = self.sparse_modules(sparse_features, cache_op)
224 |
225 | ctx2 = get_time_elapsed(dist_logger, "dense operations in forward pass") \
226 | if inspect_time else nullcontext()
227 | with ctx2:
228 | with record_function("Dense operations:"):
229 | # B // world size, 1
230 | logits = self.dense_modules(dense_features, embedded_sparse)
231 |
232 | return logits
233 |
234 | def model_stats(self, prefix=""):
235 | return f"{prefix}: {self.stat_str}"
236 |
--------------------------------------------------------------------------------
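The shape hooks above undo the flattening performed by the embedding kernel: in the default (non-tablewise) path the lookup returns a table-major `(F * B, D)` tensor, which `sparse_embedding_shape_hook` reshapes to `(B, F, D)`. A tiny sketch with illustrative sizes:

```python
# Sketch: sparse_embedding_shape_hook turns a table-major (F * B, D) lookup
# result into the (B, F, D) layout the interaction arch expects.
import torch

F, B, D = 3, 4, 2    # illustrative: number of tables, batch size, embedding dim
flat = torch.arange(F * B * D, dtype=torch.float32).view(F * B, D)

reshaped = flat.view(F, B, -1).transpose(0, 1)    # what the hook computes
assert reshaped.shape == (B, F, D)
print(reshaped[0])    # row 0: [table_0[sample 0], table_1[sample 0], table_2[sample 0]]
```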
/recsys/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .misc import get_mem_info, compute_throughput, get_time_elapsed, Timer, get_partition, \
2 | TrainValTestResults, count_parameters, prepare_tablewise_config, get_tablewise_rank_arrange
3 | from .dataloader import CudaStreamDataIter, FiniteDataIter
4 |
5 | __all__ = [
6 | 'get_mem_info', 'compute_throughput', 'get_time_elapsed', 'Timer', 'get_partition', 'CudaStreamDataIter',
7 | 'FiniteDataIter', 'TrainValTestResults', 'count_parameters', 'prepare_tablewise_config',
8 | 'get_tablewise_rank_arrange'
9 | ]
10 |
--------------------------------------------------------------------------------
/recsys/utils/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | from .cuda_stream_dataloader import CudaStreamDataIter, FiniteDataIter
2 |
3 | __all__ = ['CudaStreamDataIter', 'FiniteDataIter']
4 |
--------------------------------------------------------------------------------
/recsys/utils/dataloader/base_dataiter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 |
4 | from abc import ABC, abstractmethod
5 | from typing import Optional
6 | import torch
7 | from torch.utils.data import DataLoader
8 |
9 |
10 | class BaseStreamDataIter(ABC):
11 |
12 | def __init__(self, loader: DataLoader):
13 | self.loader = loader
14 | self.iter = iter(loader)
15 | self.stream = torch.cuda.Stream()
16 | self._preload()
17 |
18 | @staticmethod
19 | def _move_tensor(element):
20 | if torch.is_tensor(element):
21 | if not element.is_cuda:
22 | return element.cuda(non_blocking=True)
23 | return element
24 |
25 | @staticmethod
26 | def _record_tensor(element, stream: torch.cuda.Stream) -> None:
27 | if torch.is_tensor(element):
28 | element.record_stream(stream)
29 |
30 | def record_stream(self, data, stream: Optional[torch.cuda.Stream] = None) -> None:
31 | if stream is None:
32 | stream = torch.cuda.current_stream()
33 |
34 | if isinstance(data, torch.Tensor):
35 | data.record_stream(stream)
36 | elif isinstance(data, (list, tuple)):
37 | for element in data:
38 | if isinstance(element, dict):
39 | for _k, v in element.items():
40 | self._record_tensor(v, stream)
41 | else:
42 | self._record_tensor(element, stream)
43 | elif isinstance(data, dict):
44 | for _k, v in data.items():
45 | self._record_tensor(v, stream)
46 | else:
47 | raise TypeError(
48 | f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
49 |
50 | def to_cuda(self, data):
51 | if isinstance(data, torch.Tensor):
52 | data = data.cuda(non_blocking=True)
53 | elif isinstance(data, (list, tuple)):
54 | data_to_return = []
55 | for element in data:
56 | if isinstance(element, dict):
57 | data_to_return.append({k: self._move_tensor(v) for k, v in element.items()})
58 | else:
59 | data_to_return.append(self._move_tensor(element))
60 | data = data_to_return
61 | elif isinstance(data, dict):
62 | data = {k: self._move_tensor(v) for k, v in data.items()}
63 | else:
64 | raise TypeError(
65 | f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
66 | return data
67 |
68 | @abstractmethod
69 | def _preload(self):
70 | pass
71 |
72 | @abstractmethod
73 | def _reset(self):
74 | pass
75 |
76 | @abstractmethod
77 | def __next__(self):
78 | pass
79 |
80 | @abstractmethod
81 | def __iter__(self):
82 | pass
83 |
84 |
--------------------------------------------------------------------------------
/recsys/utils/dataloader/cuda_stream_dataloader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | from typing import TypeVar, Iterator
4 |
5 | import torch
6 | from torch.utils.data import Sampler, Dataset, DataLoader
7 |
8 | from .base_dataiter import BaseStreamDataIter
9 |
10 |
11 | class CudaStreamDataIter(BaseStreamDataIter):
12 | """
13 | A data iterator that supports batch prefetching with the help of cuda stream.
14 | """
15 |
16 | def __init__(self, loader: DataLoader):
17 | super().__init__(loader)
18 |
19 | def _preload(self):
20 | try:
21 | self.batch_data = next(self.iter)
22 |
23 | except StopIteration:
24 | self.batch_data = None
25 | self._reset()
26 | return
27 |
28 | with torch.cuda.stream(self.stream):
29 | self.batch_data = self.to_cuda(self.batch_data)
30 |
31 | def _reset(self):
32 | self.iter = iter(self.loader)
33 | self.stream = torch.cuda.Stream()
34 | self._preload()
35 |
36 | def __next__(self):
37 | torch.cuda.current_stream().wait_stream(self.stream)
38 | batch_data = self.batch_data
39 |
40 | if batch_data is not None:
41 | self.record_stream(batch_data, torch.cuda.current_stream())
42 |
43 | self._preload()
44 | return batch_data
45 |
46 | def __iter__(self):
47 | return self
48 |
49 |
50 | class FiniteDataIter(BaseStreamDataIter):
51 |
52 | def _reset(self):
53 | self.iter = iter(self.loader)
54 | self.stream = torch.cuda.Stream()
55 | self._preload()
56 |
57 | def __init__(self, data_loader):
58 | super(FiniteDataIter, self).__init__(data_loader)
59 |
60 | def _preload(self):
61 | try:
62 | self.batch_data = next(self.iter)
63 |
64 | with torch.cuda.stream(self.stream):
65 | self.batch_data = self.batch_data.to(torch.cuda.current_device(), non_blocking=True)
66 |
67 | except StopIteration:
68 | self.batch_data = None
69 |
70 | def __next__(self):
71 | torch.cuda.current_stream().wait_stream(self.stream)
72 | batch_data = self.batch_data
73 | if batch_data is not None:
74 | batch_data.record_stream(torch.cuda.current_stream())
75 | else:
76 | raise StopIteration()
77 |
78 | self._preload()
79 | return batch_data
80 |
81 | def __iter__(self):
82 | return self
83 |
--------------------------------------------------------------------------------
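A usage sketch: wrapping an ordinary `DataLoader` so that the host-to-device copy of the next batch overlaps with compute on the current one (assumes a CUDA device; the dataset is a stand-in). Since `CudaStreamDataIter` silently restarts the loader at epoch end, the loop is bounded explicitly:

```python
# Usage sketch for CudaStreamDataIter (requires a CUDA device).
import torch
from torch.utils.data import DataLoader, TensorDataset
from recsys.utils.dataloader import CudaStreamDataIter

dataset = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,)))
loader = DataLoader(dataset, batch_size=32, pin_memory=True)    # pinning enables async H2D copies

stream_iter = CudaStreamDataIter(loader)    # wraps around epochs forever
for _ in range(10):
    features, labels = next(stream_iter)    # the next batch's copy is already in flight
    assert features.is_cuda and labels.is_cuda
```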
/recsys/utils/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import psutil
3 | from contextlib import contextmanager
4 | import time
5 | from time import perf_counter
6 | from dataclasses import dataclass, field
7 | from typing import List, Optional
8 | from colossalai.nn.parallel.layers import TablewiseEmbeddingBagConfig
9 |
10 | import numpy as np
11 |
12 | @dataclass
13 | class TrainValTestResults:
14 | val_accuracies: List[float] = field(default_factory=list)
15 | val_aurocs: List[float] = field(default_factory=list)
16 | test_accuracy: Optional[float] = None
17 | test_auroc: Optional[float] = None
18 |
19 | def count_parameters(model, prefix=''):
20 | trainable_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
21 | param_amount = sum(p.numel() for p in model.parameters())
22 | buffer_amount = sum(b.numel() for b in model.buffers())
23 | param_storage = sum([p.numel() * p.element_size() for p in model.parameters()])
24 | buffer_storage = sum([b.numel() * b.element_size() for b in model.buffers()])
25 | stats_str = f'{prefix}: {trainable_param:,}.' + '\n'
26 | stats_str += f"Number of model parameters: {param_amount:,}, storage overhead: {param_storage/1024**3:.2f} GB. "
27 | stats_str += f"Number of model buffers: {buffer_amount:,}, storage overhead: {buffer_storage/1024**3:.2f} GB."
28 | return stats_str
29 |
30 |
31 | def get_mem_info(prefix=''):
32 | return f'{prefix}GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB, ' \
33 | f'GPU memory reserved: {torch.cuda.memory_reserved() /1024**3:.2f} GB, ' \
34 | f'CPU memory usage: {psutil.Process().memory_info().rss / 1024**3:.2f} GB'
35 |
36 |
37 | @contextmanager
38 | def compute_throughput(batch_size) -> float:
39 | start = perf_counter()
40 | yield lambda: batch_size / ((perf_counter() - start) * 1000)    # samples per millisecond
41 |
42 |
43 | @contextmanager
44 | def get_time_elapsed(logger, repr: str):
45 | timer = Timer()
46 | timer.start()
47 | yield
48 | elapsed = timer.stop()
49 | logger.info(f"Time elapsed for {repr}: {elapsed:.4f}s", ranks=[0])
50 |
51 |
52 | class Timer:
53 | """A timer object which helps to log the execution times, and provides different tools to assess the times.
54 | """
55 |
56 | def __init__(self):
57 | self._started = False
58 | self._start_time = time.time()
59 | self._elapsed = 0
60 | self._history = []
61 |
62 | @property
63 | def has_history(self):
64 | return len(self._history) != 0
65 |
66 | @property
67 | def current_time(self) -> float:
68 | torch.cuda.synchronize()
69 | return time.time()
70 |
71 | def start(self):
72 | """Firstly synchronize cuda, reset the clock and then start the timer.
73 | """
74 | self._elapsed = 0
75 | torch.cuda.synchronize()
76 | self._start_time = time.time()
77 | self._started = True
78 |
79 | def lap(self):
80 | """lap time and return elapsed time
81 | """
82 | return self.current_time - self._start_time
83 |
84 | def stop(self, keep_in_history: bool = False):
85 | """Stop the timer and record the start-stop time interval.
86 |
87 | Args:
88 | keep_in_history (bool, optional): Whether to record the start-stop
89 | interval in the history. Defaults to False.
90 | Returns:
91 | float: Start-stop interval in seconds.
92 | """
93 | torch.cuda.synchronize()
94 | end_time = time.time()
95 | elapsed = end_time - self._start_time
96 | if keep_in_history:
97 | self._history.append(elapsed)
98 | self._elapsed = elapsed
99 | self._started = False
100 | return elapsed
101 |
102 | def get_history_mean(self):
103 | """Mean of all history start-stop time intervals.
104 |
105 | Returns:
106 | float: Mean of the recorded time intervals.
107 | """
108 | return sum(self._history) / len(self._history)
109 |
110 | def get_history_sum(self):
111 | """Add up all the start-stop time intervals.
112 |
113 | Returns:
114 | float: Sum of the recorded time intervals.
115 | """
116 | return sum(self._history)
117 |
118 | def get_elapsed_time(self):
119 | """Return the last start-stop time interval.
120 |
121 | Returns:
122 | float: The last time interval.
123 |
124 | Note:
125 | Use it only when the timer is not running.
126 | """
127 | assert not self._started, 'Timer is still in progress'
128 | return self._elapsed
129 |
130 | def reset(self):
131 | """Clear up the timer and its history
132 | """
133 | self._history = []
134 | self._started = False
135 | self._elapsed = 0
136 |
137 |
138 | def get_partition(embedding_dim, rank, world_size):
139 | if world_size == 1:
140 | return 0, embedding_dim, True
141 |
142 | assert embedding_dim >= world_size, \
143 | f"Embedding dimension {embedding_dim} must be larger than the world size " \
144 | f"{world_size} of the process group"
145 | chunk_size = embedding_dim // world_size
146 | threshold = embedding_dim % world_size
147 | # if embedding dim is divisible by world size
148 | if threshold == 0:
149 | return rank * chunk_size, (rank + 1) * chunk_size, True
150 |
151 | # align with the split strategy of torch.tensor_split
152 | size_list = [chunk_size + 1 if i < threshold else chunk_size for i in range(world_size)]
153 | offset = sum(size_list[:rank])
154 | return offset, offset + size_list[rank], False
155 |
156 |
157 | def prepare_tablewise_config(num_embeddings_per_feature,
158 | cache_ratio,
159 | id_freq_map_total=None,
160 | dataset="criteo_kaggle",
161 | world_size=2):
162 | # WARNING: prototype. Only supports the Criteo datasets, with the world sizes
163 | # enumerated in get_tablewise_rank_arrange. TODO: automatic arrangement
164 | embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
165 | rank_arrange = get_tablewise_rank_arrange(dataset, world_size)
166 | table_offsets = np.array([0, *np.cumsum(num_embeddings_per_feature)])
167 | for i, num_embeddings in enumerate(num_embeddings_per_feature):
168 | ids_freq_mapping = None
169 | if id_freq_map_total is not None:
170 | ids_freq_mapping = id_freq_map_total[table_offsets[i] : table_offsets[i + 1]]
171 | cuda_row_num = int(cache_ratio * num_embeddings) + 2000
172 | if cuda_row_num > num_embeddings:
173 | cuda_row_num = num_embeddings
174 | embedding_bag_config_list.append(
175 | TablewiseEmbeddingBagConfig(
176 | num_embeddings=num_embeddings,
177 | cuda_row_num=cuda_row_num,
178 | assigned_rank=rank_arrange[i],
179 | ids_freq_mapping=ids_freq_mapping
180 | )
181 | )
182 | return embedding_bag_config_list
183 |
184 | def get_tablewise_rank_arrange(dataset=None, world_size=0):
185 | if 'criteo' in dataset and 'kaggle' in dataset:
186 | if world_size == 1:
187 | rank_arrange = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
188 | elif world_size == 2:
189 | rank_arrange = [0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0]
190 | elif world_size == 3:
191 | rank_arrange = [2, 1, 0, 1, 1, 2, 2, 1, 0, 0, 1, 1, 0, 1, 0, 2, 0, 2, 2, 0, 2, 2, 0, 1, 1, 0]
192 | elif world_size == 4:
193 | rank_arrange = [3, 1, 0, 3, 1, 0, 2, 1, 0, 2, 3, 1, 3, 1, 2, 3, 1, 2, 3, 0, 2, 0, 0, 2, 3, 2]
194 | elif world_size == 8:
195 | rank_arrange = [6, 6, 0, 4, 7, 2, 5, 7, 0, 5, 7, 1, 7, 3, 5, 3, 1, 6, 6, 0, 2, 2, 1, 4, 3, 4]
196 | else:
197 | raise NotImplementedError("Other Tablewise settings are under development")
198 | elif 'criteo' in dataset:
199 | if world_size == 1:
200 | rank_arrange = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
201 | elif world_size == 2:
202 | rank_arrange = [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0]
203 | elif world_size == 4:
204 | rank_arrange = [1, 3, 3, 3, 3, 0, 2, 2, 1, 2, 2, 2, 0, 1, 2, 1, 0, 1, 0, 0, 2, 3, 3, 3, 1, 0]
205 | else:
206 | raise NotImplementedError("Other Tablewise settings are under development")
207 | else:
208 | raise NotImplementedError("Other Tablewise settings are under development")
209 | return rank_arrange
--------------------------------------------------------------------------------
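A quick check that `get_partition` above matches `torch.tensor_split`'s chunking along the embedding dimension, as the comment in the code claims (sizes are illustrative):

```python
# Sketch: get_partition's column offsets agree with torch.tensor_split.
import torch
from recsys.utils import get_partition

embedding_dim, world_size = 10, 3
weight = torch.randn(5, embedding_dim)
chunks = torch.tensor_split(weight, world_size, dim=1)    # widths 4, 3, 3

for rank in range(world_size):
    start, end, evenly_divided = get_partition(embedding_dim, rank, world_size)
    assert torch.equal(weight[:, start:end], chunks[rank])
    print(rank, (start, end), "evenly divisible:", evenly_divided)
```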
/recsys/utils/preprocess_synth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # a = torch.tensor([1, 3, 5, 7, 14, 3, 40, 23])
3 | # b, c = torch.unique(a, sorted=True, return_inverse=True)
4 | # print(b)
5 | # print(c)
6 |
7 | BATCH_SIZE = 65536
8 | TABLE_NUM = 856
9 | FILE_LIST = [f"/data/scratch/RecSys/embedding_bag/fbgemm_t856_bs65536_{i}.pt" for i in range(
10 | 16)] + ["/data/scratch/RecSys/embedding_bag/fbgemm_t856_bs65536.pt"]
11 | KEYS = []
12 | for i in range(TABLE_NUM):
13 | KEYS.append("table_{}".format(i))
14 |
15 | CHOSEN_TABLES = [i for i in range(0, 856)]
16 |
17 | def load_file(file_path, cuda=True):
18 | indices, offsets, lengths = torch.load(file_path)
19 | if cuda:
20 | indices = indices.int().cuda()
21 | offsets = offsets.int().cuda()
22 | lengths = lengths.int().cuda()
23 | else:
24 | indices = indices.int()
25 | offsets = offsets.int()
26 | lengths = lengths.int()
27 | indices_per_table = []
28 | for i in range(TABLE_NUM):
29 | if i not in CHOSEN_TABLES:
30 | continue
31 | start_pos = offsets[i * BATCH_SIZE]
32 | end_pos = offsets[i * BATCH_SIZE + BATCH_SIZE]
33 | part = indices[start_pos:end_pos]
34 | indices_per_table.append(part)
35 | return indices_per_table
36 |
37 | if __name__ == "__main__":
38 | indices_per_table_list = []
39 | indices_per_table_length_list = []
40 | for i, file in enumerate(FILE_LIST):
41 | indices_per_table = load_file(file, cuda=False)
42 | print("loaded ", file)
43 | for j, indices in enumerate(indices_per_table):
44 | if i == 0:
45 | indices_per_table_list.append([indices])
46 | indices_per_table_length_list.append([indices.shape[0]])
47 | else:
48 | indices_per_table_list[j].append(indices)
49 | indices_per_table_length_list[j].append(indices.shape[0])
50 |
51 | # unique op for each table:
52 | for i, (indices_list, length_list) in enumerate(zip(indices_per_table_list, indices_per_table_length_list)):
53 | catted = torch.cat(indices_list)
54 | _, processed = torch.unique(catted, sorted=True, return_inverse=True)
55 | indices_per_table_list[i] = torch.split(processed, length_list)
56 |
57 | # save to each file:
58 | for i, file in enumerate(FILE_LIST):
59 | _, offsets, lengths = torch.load(file)
60 | indices_per_table = [indices_in_table[i] for indices_in_table in indices_per_table_list]
61 | reconcatenate = torch.cat(indices_per_table)
62 | torch.save((reconcatenate, offsets, lengths),
63 | f"/home/lccsr/data2/embedding_bag_processed/fbgemm_t856_bs65536_processed_{i}.pt")
64 | print("saved, ", i)
65 |
--------------------------------------------------------------------------------
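The core of the script above is `torch.unique(..., return_inverse=True)`: the inverse indices re-encode each table's raw ids into a dense `0..K-1` range, shrinking the effective vocabulary (this is also what the commented-out demo at the top of the file illustrates):

```python
# Sketch of the id-compaction trick used by preprocess_synth.py.
import torch

raw_ids = torch.tensor([1, 3, 5, 7, 14, 3, 40, 23])
uniques, remapped = torch.unique(raw_ids, sorted=True, return_inverse=True)
print(uniques)     # tensor([ 1,  3,  5,  7, 14, 23, 40])
print(remapped)    # tensor([0, 1, 2, 3, 4, 1, 6, 5])
```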
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpcaitech/CachedEmbedding/a2af3d7e7b0197519e6d018444688fcd9ba32c43/requirements.txt
--------------------------------------------------------------------------------
/scripts/avazu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # For Colossalai enabled recsys
4 | # avazu
5 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p1_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
6 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \
7 | # --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
8 | # --profile_dir "tensorboard_log/avazu/w1_p1_16k" --buffer_size 0 --use_overlap --cache_sets 94458
9 | #
10 | torchx run -s local_cwd -cfg log_dir=log/avazu/w2_p1_16k dist.ddp -j 1x2 --script recsys/dlrm_main.py -- \
11 | --dataset_dir /data/scratch/RecSys/avazu_sample --pin_memory --shuffle_batches \
12 | --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
13 | --profile_dir "tensorboard_log/avazu/w2_p1_16k" --buffer_size 0 --use_overlap --cache_sets 94458
14 |
15 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w4_p1_16k dist.ddp -j 1x4 --script recsys/dlrm_main.py -- \
16 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \
17 | # --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
18 | # --profile_dir "tensorboard_log/avazu/w4_p1_16k" --buffer_size 0 --use_overlap --cache_sets 94458
19 | #
20 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w8_p1_16k dist.ddp -j 1x8 --script recsys/dlrm_main.py -- \
21 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \
22 | # --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
23 | # --profile_dir "tensorboard_log/avazu/w8_p1_16k" --buffer_size 0 --use_overlap --cache_sets 94458
24 | #
25 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p1_32k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
26 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \
27 | # --learning_rate 1. --batch_size 32768 --use_sparse_embed_grad --use_cache --use_freq \
28 | # --profile_dir "tensorboard_log/avazu/w1_p1_32k" --buffer_size 0 --use_overlap --cache_sets 94458
29 | #
30 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p1_8k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
31 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \
32 | # --learning_rate 1. --batch_size 8192 --use_sparse_embed_grad --use_cache --use_freq \
33 | # --profile_dir "tensorboard_log/avazu/w1_p1_8k" --buffer_size 0 --use_overlap --cache_sets 94458
34 | #
35 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p1_4k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
36 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \
37 | # --learning_rate 1. --batch_size 4096 --use_sparse_embed_grad --use_cache --use_freq \
38 | # --profile_dir "tensorboard_log/avazu/w1_p1_4k" --buffer_size 0 --use_overlap --cache_sets 94458
39 | #
40 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p1_2k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
41 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \
42 | # --learning_rate 1. --batch_size 2048 --use_sparse_embed_grad --use_cache --use_freq \
43 | # --profile_dir "tensorboard_log/avazu/w1_p1_2k" --buffer_size 0 --use_overlap --cache_sets 94458
44 | #
45 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p1_1k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
46 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \
47 | # --learning_rate 1. --batch_size 1024 --use_sparse_embed_grad --use_cache --use_freq \
48 | # --profile_dir "tensorboard_log/avazu/w1_p1_1k" --buffer_size 0 --use_overlap --cache_sets 94458
49 | #
50 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p10_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
51 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \
52 | # --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
53 | # --profile_dir "tensorboard_log/avazu/w1_p10_16k" --buffer_size 0 --use_overlap --cache_sets 944582
54 | #
55 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p5_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
56 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \
57 | # --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
58 | # --profile_dir "tensorboard_log/avazu/w1_p5_16k" --buffer_size 0 --use_overlap --cache_sets 472291
59 | #
60 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p2_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
61 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \
62 | # --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
63 | # --profile_dir "tensorboard_log/avazu/w1_p2_16k" --buffer_size 0 --use_overlap --cache_sets 188916
64 | #
65 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p0_5_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
66 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \
67 | # --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
68 | # --profile_dir "tensorboard_log/avazu/w1_p0_5_16k" --buffer_size 0 --use_overlap --cache_sets 47229
69 | #
70 | # torchx run -s local_cwd -cfg log_dir=log/avazu/w1_p0_1_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
71 | # --dataset_dir /data/avazu_sample --pin_memory --shuffle_batches \
72 | # --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
73 | # --profile_dir "tensorboard_log/avazu/w1_p0_1_16k" --buffer_size 0 --use_overlap --cache_sets 9445
74 | #
--------------------------------------------------------------------------------
/scripts/kaggle.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # For Colossalai enabled recsys
4 | # criteo kaggle
5 | export DATAPATH=/data/scratch/RecSys/criteo_kaggle_data/
6 | # export DATAPATH=/data/criteo_kaggle_data/
7 | export GPUNUM=1
8 | # export BATCHSIZE=16384
9 | # export BATCHSIZE=1024
10 | export CACHESIZE=337625
11 | export CACHERATIO=0.01
12 | export USE_LFU=1
13 | export USE_TABLE_SHARD=0
14 | export EVAL_ACC=0
15 | export PREFETCH_NUM=2
16 | export USE_ASYNC=0
17 |
18 | if [[ ${USE_LFU} == 1 ]]; then
19 | LFU_FLAG="--use_lfu"
20 | else
21 | export LFU_FLAG=""
22 | fi
23 |
24 |
25 | if [[ ${USE_TABLE_SHARD} == 1 ]]; then
26 | TABLE_SHARD_FLAG="--use_tablewise"
27 | else
28 | export TABLE_SHARD_FLAG=""
29 | fi
30 |
31 | if [[ ${EVAL_ACC} == 1 ]]; then
32 | EVAL_ACC_FLAG="--eval_acc"
33 | else
34 | export EVAL_ACC_FLAG=""
35 | fi
36 |
37 | if [[ ${USE_ASYNC} == 1 ]]; then
38 | ASYNC_FLAG="--use_cache_mgr_async_copy"
39 | else
40 | export ASYNC_FLAG=""
41 | fi
42 |
43 | set_n_least_used_CUDA_VISIBLE_DEVICES() {
44 | local n=${1:-"9999"}
45 | echo "GPU Memory Usage:"
46 | local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
47 | | tail -n +2 \
48 | | nl -v 0 \
49 | | tee /dev/tty \
50 | | sort -g -k 2 \
51 | | awk '{print $1}' \
52 | | head -n $n)
53 | export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
54 | echo "Now CUDA_VISIBLE_DEVICES is set to:"
55 | echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
56 | }
57 |
58 |
59 | mkdir -p colo_logs
60 | for BATCHSIZE in 16384 #16384 #1024 2048 4096 16384
61 | do
62 | for PREFETCH_NUM in 1 #1 16
63 | do
64 | for USE_ASYNC in 0
65 | do
66 | set_n_least_used_CUDA_VISIBLE_DEVICES ${GPUNUM}
67 |
68 | TASK_NAME="gpu_${GPUNUM}_bs_${BATCHSIZE}_pf_${PREFETCH_NUM}_cache_${CACHERATIO}_async_${USE_ASYNC}"
69 | torchx run -s local_cwd -cfg log_dir=log/kaggle/${TASK_NAME} dist.ddp -j 1x${GPUNUM} --script recsys/dlrm_main.py -- \
70 | --dataset_dir ${DATAPATH} --pin_memory --shuffle_batches \
71 | --learning_rate 1. --batch_size ${BATCHSIZE} --use_sparse_embed_grad --use_cache --use_freq ${LFU_FLAG} ${TABLE_SHARD_FLAG} ${EVAL_ACC_FLAG} ${ASYNC_FLAG} \
72 | --profile_dir "tensorboard_log/kaggle/${TASK_NAME}" --buffer_size 0 --use_overlap --cache_ratio ${CACHERATIO} --prefetch_num ${PREFETCH_NUM} 2>&1 | tee colo_logs/colo_${TASK_NAME}.txt
73 | done
74 | done
75 | done
76 |
77 | # torchx run -s local_cwd -cfg log_dir=log/kaggle/w${GPUNUM}_p1_16k dist.ddp -j 1x${GPUNUM} --script recsys/dlrm_main.py -- \
78 | # --dataset_dir ${DATAPATH} --pin_memory --shuffle_batches \
79 | # --learning_rate 1. --batch_size ${BATCHSIZE} --use_sparse_embed_grad --use_cache --use_freq --use_lfu \
80 | # --profile_dir "tensorboard_log/kaggle/w${GPUNUM}_p1_16k" --buffer_size 0 --use_overlap --cache_sets ${CACHESIZE} 2>&1 | tee logs/colo_${GPUNUM}_${BATCHSIZE}_${CACHESIZE}.txt
81 |
--------------------------------------------------------------------------------
/scripts/preprocess/.gitignore:
--------------------------------------------------------------------------------
1 | /*
2 | !/.gitignore
3 | !/npy_preproc_criteo.py
4 | !/split_criteo_kaggle.py
5 | !/npy_preproc_avazu.py
6 | !/taobao
7 | !README.md
8 |
--------------------------------------------------------------------------------
/scripts/preprocess/README.md:
--------------------------------------------------------------------------------
1 | # Dataset preprocessing
2 | ## Criteo Kaggle
3 | 1. download data from: https://ailab.criteo.com/ressources/
4 | 2. convert the tsv files to npy files
5 | ```bash
6 | python npy_preproc_criteo.py --input_dir <raw_tsv_dir> --output_dir <npy_dir>
7 | ```
8 | 3. split to train/val/test files
9 | ```bash
10 | python split_criteo_kaggle.py
11 | ```
12 | You might need to change `'/data/scratch/RecSys/criteo_kaggle_npy'` in this file to your own npy directory
13 |
14 | ## Avazu
15 | 1. download data from: https://www.kaggle.com/c/avazu-ctr-prediction/data
16 | 2. convert the csv file to npy files:
17 | ```bash
18 | python npy_preproc_avazu.py --input_dir <csv_dir> --output_dir <npy_dir>
19 | ```
20 | 3. split into train/val/test files
21 | ```bash
22 | python npy_preproc_avazu.py --input_dir <npy_dir> --output_dir <split_dir> --is_split
23 | ```
24 |
25 | ## Criteo Terabyte
26 | 1. download tsv source file from: https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/
27 | Note that this requires around 1TB disk space.
28 | 2. clone the torchrec repo:
29 | ```bash
30 | git clone https://github.com/pytorch/torchrec.git
31 | cd torchrec/torchrec/datasets/scripts/nvt/
32 | ```
33 | 3. run the first two Python scripts
34 | ```bash
35 | python convert_tsv_to_parquet.py -i <tsv_dir> -o <parquet_dir>
36 | python process_criteo_parquet.py -b -s
37 | ```
38 | You might need to use the Dockerfile to install NVTabular,
39 | since its installation requires a CUDA version different from our experiment setup.
--------------------------------------------------------------------------------
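As a concrete instance of the Criteo Kaggle steps in the README above (the `/path/to/...` directories are placeholders):

```bash
# End-to-end Criteo Kaggle preprocessing; directories are placeholders.
python npy_preproc_criteo.py --input_dir /path/to/criteo_kaggle --output_dir /path/to/criteo_kaggle_npy
# split_criteo_kaggle.py hard-codes its paths in __main__; edit them first.
python split_criteo_kaggle.py
```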
/scripts/preprocess/npy_preproc_avazu.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import torch
5 |
6 | from recsys.datasets.avazu import CAT_FEATURE_COUNT, AvazuIterDataPipe, TOTAL_TRAINING_SAMPLES
7 |
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser()
11 |
12 | parser.add_argument("--input_dir",
13 | type=str,
14 | required=True,
15 | help="path to the dir where the csv file train is downloaded and unzipped")
16 |
17 | parser.add_argument("--output_dir",
18 | type=str,
19 | required=True,
20 | help="path to which the train/val/test splits are saved")
21 |
22 | parser.add_argument("--is_split", action='store_true')
23 | return parser.parse_args()
24 |
25 |
26 | def main():
27 | # Note: this script is broken; to align with our experiments,
28 | # please refer to https://www.kaggle.com/code/leejunseok97/deepfm-deepctr-torch
29 | # Basically, the C14-C21 column of the resulting sparse files should be further split to the dense files.
30 | args = parse_args()
31 |
32 | if args.is_split:
33 | if not os.path.exists(args.input_dir):
34 | raise ValueError(f"{args.input_dir} has existed")
35 |
36 | if not os.path.exists(args.output_dir):
37 | os.makedirs(args.output_dir)
38 |
39 | for _t in ("sparse", 'label'):
40 | npy = np.load(os.path.join(args.input_dir, f"{_t}.npy"))
41 | train_split = npy[:TOTAL_TRAINING_SAMPLES]
42 | np.save(os.path.join(args.output_dir, f"train_{_t}.npy"), train_split)
43 | val_test_split = npy[TOTAL_TRAINING_SAMPLES:]
44 | np.save(os.path.join(args.output_dir, f"val_test_{_t}.npy"), val_test_split)
45 | del npy
46 |
47 | else:
48 | if not os.path.exists(args.output_dir):
49 | os.makedirs(args.output_dir)
50 | sparse_output_file_path = os.path.join(args.output_dir, "sparse.npy")
51 | label_output_file_path = os.path.join(args.output_dir, "label.npy")
52 |
53 | sparse, labels = [], []
54 | for row_sparse, row_label in AvazuIterDataPipe(args.input_dir):
55 | sparse.append(row_sparse)
56 | labels.append(row_label)
57 |
58 | sparse_np = np.array(sparse, dtype=np.int32)
59 | del sparse
60 | labels_np = np.array(labels, dtype=np.int32).reshape(-1, 1)
61 | del labels
62 |
63 | for f_path, arr in [(sparse_output_file_path, sparse_np), (label_output_file_path, labels_np)]:
64 | np.save(f_path, arr)
65 |
66 |
67 | if __name__ == "__main__":
68 | main()
69 |
--------------------------------------------------------------------------------
/scripts/preprocess/npy_preproc_criteo.py:
--------------------------------------------------------------------------------
1 | # This script preprocesses Criteo dataset tsv files to binary (npy) files.
2 |
3 | # In order to exploit InMemoryBinaryCriteoIterDataPipe to accelerate the loading of data,
4 | # this file is modified from torchrec/datasets/scripts/npy_preproc_criteo.py and
5 | # torchrec.datasets.criteo
6 | #
7 | # Usage:
8 | # python npy_preproc_criteo.py --input_dir /where/criteo_kaggle/train.txt --output_dir /where/to/save/.npy
9 | #
10 | # You may need additional modifications for the file name,
11 | # and you can evenly split the whole dataset into 7 days by split_criteo_kaggle.py
12 |
13 | import argparse
14 | import os
15 | import sys
16 | from typing import List, Tuple
17 |
18 | import numpy as np
19 | from torchrec.datasets.criteo import CriteoIterDataPipe, INT_FEATURE_COUNT, CAT_FEATURE_COUNT
20 | from torchrec.datasets.utils import PATH_MANAGER_KEY, safe_cast
21 | from iopath.common.file_io import PathManagerFactory
22 |
23 |
24 | def tsv_to_npys(
25 | in_file: str,
26 | out_dense_file: str,
27 | out_sparse_file: str,
28 | out_labels_file: str,
29 | path_manager_key: str = PATH_MANAGER_KEY,
30 | ):
31 | """
32 | For criteo kaggle
33 | """
34 |
35 | def row_mapper(row: List[str]) -> Tuple[List[int], List[int], int]:
36 | label = safe_cast(row[0], int, 0)
37 | dense = [safe_cast(row[i], int, 0) for i in range(1, 1 + INT_FEATURE_COUNT)]
38 | sparse = [
39 | int(safe_cast(row[i], str, "0") or "0", 16)
40 | for i in range(1 + INT_FEATURE_COUNT, 1 + INT_FEATURE_COUNT + CAT_FEATURE_COUNT)
41 | ]
42 | return dense, sparse, label # pyre-ignore[7]
43 |
44 | dense, sparse, labels = [], [], []
45 | for (row_dense, row_sparse, row_label) in CriteoIterDataPipe([in_file], row_mapper=row_mapper):
46 | dense.append(row_dense)
47 | sparse.append(row_sparse)
48 | labels.append(row_label)
49 |
50 | dense_np = np.array(dense, dtype=np.int32)
51 | del dense
52 | sparse_np = np.array(sparse, dtype=np.int32)
53 | del sparse
54 | labels_np = np.array(labels, dtype=np.int32)
55 | del labels
56 |
57 | # Shift so the global minimum becomes 2, keeping log inputs strictly positive.
58 | dense_np -= (dense_np.min() - 2)
59 | dense_np = np.log(dense_np, dtype=np.float32)
60 |
61 | labels_np = labels_np.reshape((-1, 1))
62 | path_manager = PathManagerFactory().get(path_manager_key)
63 | for (fname, arr) in [
64 | (out_dense_file, dense_np),
65 | (out_sparse_file, sparse_np),
66 | (out_labels_file, labels_np),
67 | ]:
68 | with path_manager.open(fname, "wb") as fout:
69 | np.save(fout, arr)
70 |
71 |
72 | def parse_args(argv: List[str]) -> argparse.Namespace:
73 | parser = argparse.ArgumentParser(description="Criteo tsv -> npy preprocessing script.")
74 | parser.add_argument(
75 | "--input_dir",
76 | type=str,
77 | required=True,
78 | help="Input directory containing Criteo tsv files. Files in the directory "
79 | "should be named day_{0-23}.",
80 | )
81 | parser.add_argument(
82 | "--output_dir",
83 | type=str,
84 | required=True,
85 | help="Output directory to store npy files.",
86 | )
87 | return parser.parse_args(argv)
88 |
89 |
90 | def main(argv: List[str]) -> None:
91 | """
92 | This function preprocesses the raw Criteo tsvs into the format (npy binary)
93 | expected by InMemoryBinaryCriteoIterDataPipe.
94 |
95 | Args:
96 | argv (List[str]): Command line args.
97 |
98 | Returns:
99 | None.
100 | """
101 |
102 | args = parse_args(argv)
103 | input_dir = args.input_dir
104 | output_dir = args.output_dir
105 |
106 | for f in os.listdir(input_dir):
107 | in_file_path = os.path.join(input_dir, f)
108 | dense_out_file_path = os.path.join(output_dir, f + "_dense.npy")
109 | sparse_out_file_path = os.path.join(output_dir, f + "_sparse.npy")
110 | labels_out_file_path = os.path.join(output_dir, f + "_labels.npy")
111 | print(f"Processing {in_file_path}. Outputs will be saved to {dense_out_file_path}"
112 | f", {sparse_out_file_path}, and {labels_out_file_path}...")
113 | tsv_to_npys(
114 | in_file_path,
115 | dense_out_file_path,
116 | sparse_out_file_path,
117 | labels_out_file_path,
118 | )
119 | print(f"Done processing {in_file_path}.")
120 |
121 |
122 | if __name__ == "__main__":
123 | main(sys.argv[1:])
124 |
--------------------------------------------------------------------------------
/scripts/preprocess/split_criteo_kaggle.py:
--------------------------------------------------------------------------------
1 | # This script adapts criteo kaggle dataset into 7 days' data
2 | #
3 | # Please alter the arguments of the two functions here.
4 | #
5 | # Usage:
6 | # python split_criteo_kaggle.py
7 |
8 | import numpy as np
9 | import os
10 |
11 | from torchrec.datasets.criteo import BinaryCriteoUtils, CAT_FEATURE_COUNT
12 |
13 |
14 | def main(data_dir, output_dir, days=7):
15 | STAGES = ("labels", "dense", "sparse")
16 | files = [os.path.join(data_dir, f"train.txt_{split}.npy") for split in STAGES]
17 | total_rows = BinaryCriteoUtils.get_shape_from_npy(files[0])[0]
18 |
19 | indices = list(range(0, total_rows, total_rows // days))
20 | ranges = []
21 | for i in range(len(indices) - 1):
22 | left_idx = indices[i]
23 | right_idx = indices[i + 1] if i < len(indices) - 2 else total_rows
24 | ranges.append((left_idx, right_idx))
25 |
26 | for _s, _f in zip(STAGES, files):
27 | for day, (left_idx, right_idx) in enumerate(ranges):
28 | chunk = BinaryCriteoUtils.load_npy_range(_f, left_idx, right_idx - left_idx)
29 | output_fname = f"day_{day}_{_s}.npy"
30 | np.save(os.path.join(output_dir, output_fname), chunk)
31 |
32 |
33 | def get_num_embeddings_per_feature(path_to_sparse):
34 | sparse = np.load(path_to_sparse)
35 | assert sparse.shape[1] == CAT_FEATURE_COUNT
36 |
37 | nums = []
38 | for i in range(CAT_FEATURE_COUNT):
39 | nums.append(len(np.unique(sparse[:, i])))
40 | print(','.join(map(str, nums)))
41 |
42 |
43 | if __name__ == '__main__':
44 | main('/data/scratch/RecSys/criteo_kaggle_npy', 'criteo_kaggle')
45 | get_num_embeddings_per_feature('/data/scratch/RecSys/criteo_kaggle_npy/train.txt_sparse.npy')
46 |
--------------------------------------------------------------------------------
/scripts/preprocess/taobao/README.md:
--------------------------------------------------------------------------------
1 | ### Preprocessing Scripts for Taobao User Behavior Dataset
2 |
3 | Credit: https://github.com/STAR-Laboratory/Accelerating-RecSys-Training
4 |
5 |
6 | Set the global constants in `csv_to_txt.py` before running: `python ./csv_to_txt.py`.
7 |
8 | Set the raw data paths and desired output data paths before running `bash ./run_txt_to_npz.sh`
--------------------------------------------------------------------------------
/scripts/preprocess/taobao/csv_to_txt.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 UIC-Paper
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | #
6 | # Source: https://github.com/UIC-Paper/MIMN
7 |
8 | import pickle as pkl
9 | import pandas as pd
10 | import random
11 | import numpy as np
12 | import argparse
13 |
14 | RAW_DATA_FILE = "./UserBehavior.csv"
15 | DATASET_PKL = "./dataset.pkl"
16 | Test_File = "./taobao_test.txt"
17 | Train_File = "./taobao_train.txt"
18 | Train_handle = open(Train_File, 'w')
19 | Test_handle = open(Test_File, 'w')
20 | Feature_handle = open("./taobao_feature.pkl", 'wb')
21 |
22 | MAX_LEN_ITEM = 200
23 |
24 |
25 | def to_df(file_name):
26 | df = pd.read_csv(file_name, header=None, names=['uid', 'iid', 'cid', 'btag', 'time'])
27 | return df
28 |
29 |
30 | def remap(df):
31 | item_key = sorted(df['iid'].unique().tolist())
32 | item_len = len(item_key)
33 | item_map = dict(zip(item_key, range(item_len)))
34 |
35 | df['iid'] = df['iid'].map(lambda x: item_map[x])
36 |
37 | user_key = sorted(df['uid'].unique().tolist())
38 | user_len = len(user_key)
39 | user_map = dict(zip(user_key, range(item_len, item_len + user_len)))
40 | df['uid'] = df['uid'].map(lambda x: user_map[x])
41 |
42 | cate_key = sorted(df['cid'].unique().tolist())
43 | cate_len = len(cate_key)
44 | cate_map = dict(zip(cate_key, range(user_len + item_len, user_len + item_len + cate_len)))
45 | df['cid'] = df['cid'].map(lambda x: cate_map[x])
46 |
47 | btag_key = sorted(df['btag'].unique().tolist())
48 | btag_len = len(btag_key)
49 | btag_map = dict(zip(btag_key, range(user_len + item_len + cate_len, user_len + item_len + cate_len + btag_len)))
50 | df['btag'] = df['btag'].map(lambda x: btag_map[x])
51 |
52 | print(item_len, user_len, cate_len, btag_len)
53 | return df, item_len, user_len + item_len + cate_len + btag_len + 1 #+1 is for unknown target btag
54 |
55 |
56 | def gen_user_item_group(df, item_cnt, feature_size):
57 | user_df = df.sort_values(['uid', 'time']).groupby('uid')
58 | item_df = df.sort_values(['iid', 'time']).groupby('iid')
59 |
60 | print("group completed")
61 | return user_df, item_df
62 |
63 |
64 | def gen_dataset(user_df, item_df, item_cnt, feature_size, dataset_pkl):
65 | train_sample_list = []
66 | test_sample_list = []
67 |
68 | # get each user's last touch point time
69 |
70 | print(len(user_df))
71 |
72 | user_last_touch_time = []
73 | for uid, hist in user_df:
74 | user_last_touch_time.append(hist['time'].tolist()[-1])
75 | print("get user last touch time completed")
76 |
77 | user_last_touch_time_sorted = sorted(user_last_touch_time)
78 | split_time = user_last_touch_time_sorted[int(len(user_last_touch_time_sorted) * 0.7)]
79 |
80 | cnt = 0
81 | for uid, hist in user_df:
82 | cnt += 1
83 |         if cnt % 100000 == 0: print(cnt)
84 | item_hist = hist['iid'].tolist()
85 | cate_hist = hist['cid'].tolist()
86 | btag_hist = hist['btag'].tolist()
87 | target_item_time = hist['time'].tolist()[-1]
88 |
89 | target_item = item_hist[-1]
90 | target_item_cate = cate_hist[-1]
91 | target_item_btag = feature_size
92 | label = 1
93 | test = (target_item_time > split_time)
94 |
95 | # neg sampling
96 | neg = random.randint(0, 1)
97 | if neg == 1:
98 | label = 0
99 | while target_item == item_hist[-1]:
100 | target_item = random.randint(0, item_cnt - 1)
101 | target_item_cate = item_df.get_group(target_item)['cid'].tolist()[0]
102 | target_item_btag = feature_size
103 |
104 | # the item history part of the sample
105 | item_part = []
106 | for i in range(len(item_hist) - 1):
107 | item_part.append([uid, item_hist[i], cate_hist[i], btag_hist[i]])
108 | item_part.append([uid, target_item, target_item_cate, target_item_btag])
109 | # item_part_len = min(len(item_part), MAX_LEN_ITEM)
110 |
111 | # choose the item side information: which user has clicked the target item
112 | # padding history with 0
113 | if len(item_part) <= MAX_LEN_ITEM:
114 | item_part_pad = [[0] * 4] * (MAX_LEN_ITEM - len(item_part)) + item_part
115 | else:
116 | item_part_pad = item_part[len(item_part) - MAX_LEN_ITEM:len(item_part)]
117 |
118 | # gen sample
119 | # sample = (label, item_part_pad, item_part_len, user_part_pad, user_part_len)
120 |
121 | if test:
122 | # test_set.append(sample)
123 | cat_list = []
124 | item_list = []
125 | # btag_list = []
126 | for i in range(len(item_part_pad)):
127 | item_list.append(item_part_pad[i][1])
128 | cat_list.append(item_part_pad[i][2])
129 | # cat_list.append(item_part_pad[i][0])
130 | test_sample_list.append(
131 | str(uid) + "\t" + str(target_item) + "\t" + str(target_item_cate) + "\t" + str(label) + "\t" +
132 | ",".join(map(str, item_list)) + "\t" + ",".join(map(str, cat_list)) + "\n")
133 | else:
134 | cat_list = []
135 | item_list = []
136 | # btag_list = []
137 | for i in range(len(item_part_pad)):
138 | item_list.append(item_part_pad[i][1])
139 | cat_list.append(item_part_pad[i][2])
140 | train_sample_list.append(
141 | str(uid) + "\t" + str(target_item) + "\t" + str(target_item_cate) + "\t" + str(label) + "\t" +
142 | ",".join(map(str, item_list)) + "\t" + ",".join(map(str, cat_list)) + "\n")
143 |
144 |     train_sample_length_quant = len(train_sample_list) // 256 * 256
145 |     test_sample_length_quant = len(test_sample_list) // 256 * 256
146 |
147 | print("train_sample_length_quant", train_sample_length_quant)
148 | print("length", len(train_sample_list))
149 | train_sample_list = train_sample_list[:int(train_sample_length_quant)]
150 | test_sample_list = test_sample_list[:int(test_sample_length_quant)]
151 | random.shuffle(train_sample_list)
152 | print("length", len(train_sample_list))
153 | return train_sample_list, test_sample_list
154 |
155 |
156 | def produce_neg_item_hist_with_cate(train_file, test_file):
157 | item_dict = {}
158 | sample_count = 0
159 | hist_seq = 0
160 | for line in train_file:
161 | units = line.strip().split("\t")
162 | item_hist_list = units[4].split(",")
163 | cate_hist_list = units[5].split(",")
164 | hist_list = zip(item_hist_list, cate_hist_list)
165 | hist_list = list(hist_list)
166 | #hist_seq_list = list(hist_list)
167 | hist_seq = len(hist_list)
168 | sample_count += 1
169 | for item in hist_list:
170 | item_dict.setdefault(str(item), 0)
171 |
172 | #print("hist_list : ", hist_list)
173 |
174 | for line in test_file:
175 | units = line.strip().split("\t")
176 | item_hist_list = units[4].split(",")
177 | cate_hist_list = units[5].split(",")
178 | hist_list = zip(item_hist_list, cate_hist_list)
179 | hist_list = list(hist_list)
180 | #hist_seq_list = list(hist_list)
181 | hist_seq = len(hist_list)
182 | sample_count += 1
183 | for item in hist_list:
184 | item_dict.setdefault(str(item), 0)
185 |
186 | #print("item_dict : ", item_dict)
187 |     del item_dict["('0', '0')"]
188 | keys_list = list(item_dict.keys())
189 | keys_list = np.array(keys_list)
190 | print("item_dict.keys()", keys_list.shape)
191 | neg_array = np.random.choice(keys_list, (sample_count, hist_seq + 20))
192 | neg_list = neg_array.tolist()
193 | sample_count = 0
194 |
195 | for line in train_file:
196 | units = line.strip().split("\t")
197 | item_hist_list = units[4].split(",")
198 | cate_hist_list = units[5].split(",")
199 | hist_list = zip(item_hist_list, cate_hist_list)
200 | hist_list = list(hist_list)
201 | #hist_seq_list = list(hist_list)
202 | hist_seq = len(hist_list)
203 | neg_hist_list = []
204 | for item in neg_list[sample_count]:
205 | item = eval(item)
206 | if item not in hist_list:
207 | neg_hist_list.append(item)
208 | if len(neg_hist_list) == hist_seq:
209 | break
210 | sample_count += 1
211 | neg_item_list, neg_cate_list = zip(*neg_hist_list)
212 | Train_handle.write(line.strip() + "\t" + ",".join(neg_item_list) + "\t" + ",".join(neg_cate_list) + "\n")
213 |
214 | for line in test_file:
215 | units = line.strip().split("\t")
216 | item_hist_list = units[4].split(",")
217 | cate_hist_list = units[5].split(",")
218 | hist_list = zip(item_hist_list, cate_hist_list)
219 | hist_list = list(hist_list)
220 | #hist_seq_list = list(hist_list)
221 | hist_seq = len(hist_list)
222 | neg_hist_list = []
223 | for item in neg_list[sample_count]:
224 | item = eval(item)
225 | if item not in hist_list:
226 | neg_hist_list.append(item)
227 | if len(neg_hist_list) == hist_seq:
228 | break
229 | sample_count += 1
230 | neg_item_list, neg_cate_list = zip(*neg_hist_list)
231 | Test_handle.write(line.strip() + "\t" + ",".join(neg_item_list) + "\t" + ",".join(neg_cate_list) + "\n")
232 |
233 |
234 | def main():
235 | df = to_df(RAW_DATA_FILE)
236 | df, item_cnt, feature_size = remap(df)
237 | print("feature_size", item_cnt, feature_size)
238 | feature_total_num = feature_size + 1
239 | pkl.dump(feature_total_num, Feature_handle)
240 |
241 | user_df, item_df = gen_user_item_group(df, item_cnt, feature_size)
242 | train_sample_list, test_sample_list = gen_dataset(user_df, item_df, item_cnt, feature_size, DATASET_PKL)
243 | produce_neg_item_hist_with_cate(train_sample_list, test_sample_list)
244 |
245 |
246 | if __name__ == '__main__':
247 | main()
248 |
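249 | # Output record format (illustrative): each line of taobao_train.txt / taobao_test.txt
250 | # is a tab-separated record of 8 fields:
251 | #   uid, target_item, target_cate, label, item_hist, cate_hist, neg_item_hist, neg_cate_hist
252 | # where the *_hist fields are comma-separated ID lists (MAX_LEN_ITEM entries after
253 | # padding/truncation).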
--------------------------------------------------------------------------------
/scripts/preprocess/taobao/run_txt_to_npz.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | tbsm_py="python txt_to_npz.py "
9 |
10 | $tbsm_py --datatype="taobao" \
11 | --num-train-pts=690000 --num-val-pts=300000 --points-per-user=10 \
12 | --numpy-rand-seed=123 --arch-embedding-size="987994-4162024-9439" \
13 | --raw-train-file=./taobao_train.txt \
14 | --raw-test-file=./taobao_test.txt \
15 | --pro-train-file=./taobao_train_t20.npz \
16 | --pro-val-file=./taobao_val_t20.npz \
17 | --pro-test-file=./taobao_test_t20.npz \
18 | --ts-length=20
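19 |
20 | # Produces taobao_train_t20.npz, taobao_val_t20.npz and taobao_test_t20.npz in the
21 | # current directory (per the --pro-*-file flags above).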
--------------------------------------------------------------------------------
/scripts/preprocess/taobao/txt_to_npz.py:
--------------------------------------------------------------------------------
1 | from os import path
2 | import sys
3 | import argparse
4 | import numpy as np
5 |
6 |
7 | class TaobaoTxtToNpz:
8 |
9 | def __init__(
10 | self,
11 | datatype,
12 | mode,
13 | ts_length=20,
14 | points_per_user=4,
15 | numpy_rand_seed=7,
16 | raw_path="",
17 | pro_data="",
18 | spa_fea_sizes="",
19 | num_pts=1, # pts to train or test
20 | ):
21 | # save arguments
22 | if mode == "train":
23 | self.numpy_rand_seed = numpy_rand_seed
24 | else:
25 | self.numpy_rand_seed = numpy_rand_seed + 31
26 | self.mode = mode
27 | # save dataset parameters
28 | self.total = num_pts # number of lines in txt to process
29 | self.ts_length = ts_length
30 | self.points_per_user = points_per_user # pos and neg points per user
31 | self.spa_fea_sizes = spa_fea_sizes
32 | self.M = 200 # max history length
33 |
34 | # split the datafile into path and filename
35 | lstr = raw_path.split("/")
36 | self.d_path = "/".join(lstr[0:-1]) + "/"
37 | self.d_file = lstr[-1]
38 |
39 | # preprocess data if needed
40 | if path.exists(str(pro_data)):
41 | print("Reading pre-processed data=%s" % (str(pro_data)))
42 | file = str(pro_data)
43 | else:
44 | file = str(pro_data)
45 |             levels = np.array(self.spa_fea_sizes.split("-"), dtype=int)
46 | if datatype == "taobao":
47 | self.Unum = levels[0] # 987994 num of users
48 | self.Inum = levels[1] # 4162024 num of items
49 | self.Cnum = levels[2] # 9439 num of categories
50 | print("Reading raw data=%s" % (str(raw_path)))
51 | if self.mode == "test":
52 | self.build_taobao_test(
53 | raw_path,
54 | file,
55 | )
56 | else:
57 | self.build_taobao_train_or_val(
58 | raw_path,
59 | file,
60 | )
61 | elif datatype == "synthetic":
62 | self.build_synthetic_train_or_val(file,)
63 | # load data
64 | with np.load(file) as data:
65 | self.X_cat = data["X_cat"]
66 | self.X_int = data["X_int"]
67 | self.y = data["y"]
68 |
69 | # common part between train/val and test generation
70 | # truncates (if needed) and shuffles data points
71 | def truncate_and_save(self, out_file, do_shuffle, t, users, items, cats, times, y):
72 |         # truncate: if some users' histories were too short to generate the full
73 |         # quota of points, truncate the unused portion of the pre-allocated matrices.
74 | if t < self.total_out:
75 | users = users[:t, :]
76 | items = items[:t, :]
77 | cats = cats[:t, :]
78 | times = times[:t, :]
79 | y = y[:t]
80 |
81 | # shuffle
82 | if do_shuffle:
83 | indices = np.arange(len(y))
84 | indices = np.random.permutation(indices)
85 | users = users[indices]
86 | items = items[indices]
87 | cats = cats[indices]
88 | times = times[indices]
89 | y = y[indices]
90 |
91 | N = len(y)
92 | X_cat = np.zeros((3, N, self.ts_length + 1), dtype="i4") # 4 byte int
93 |         X_int = np.zeros((1, N, self.ts_length + 1), dtype=np.float64)
94 | X_cat[0, :, :] = users
95 | X_cat[1, :, :] = items
96 | X_cat[2, :, :] = cats
97 | X_int[0, :, :] = times
98 |
99 | # saving to compressed numpy file
100 | if not path.exists(out_file):
101 | np.savez_compressed(
102 | out_file,
103 | X_cat=X_cat,
104 | X_int=X_int,
105 | y=y,
106 | )
107 | return
108 |
109 |     # processes raw train or validation data into the npz format required by training.
110 |     # For train data, each line in the raw datafile produces several randomly chosen
111 |     # datapoints (at most points_per_user per user); for validation data it produces
112 |     # one datapoint per user.
113 | def build_taobao_train_or_val(self, raw_path, out_file):
114 | with open(str(raw_path)) as f:
115 | for i, _ in enumerate(f):
116 | if i % 50000 == 0:
117 | print("pre-processing line: ", i)
118 | self.total = min(self.total, i + 1)
119 |
120 | print("total lines: ", self.total)
121 |
122 | self.total_out = self.total * self.points_per_user * 2 # pos + neg points
123 | print("Total number of points in raw datafile: ", self.total)
124 | print("Total number of points in output will be at most: ", self.total_out)
125 | np.random.seed(self.numpy_rand_seed)
126 | r_target = np.arange(0, self.M - 1)
127 |
128 | time = np.arange(self.ts_length + 1, dtype=np.int32) / (self.ts_length + 1)
129 | # time = np.ones(self.ts_length + 1, dtype=np.int32)
130 |
131 | users = np.zeros((self.total_out, self.ts_length + 1), dtype="i4") # 4 byte int
132 | items = np.zeros((self.total_out, self.ts_length + 1), dtype="i4") # 4 byte int
133 | cats = np.zeros((self.total_out, self.ts_length + 1), dtype="i4") # 4 byte int
134 |         times = np.zeros((self.total_out, self.ts_length + 1), dtype=np.float64)
135 | y = np.zeros(self.total_out, dtype="i4") # 4 byte int
136 |
137 | # determine how many datapoints to take from each user based on the length of
138 | # user behavior sequence
139 | # ind=0, 1, 2, 3,... t < 10, 20, 30, 40, 50, 60, ...
140 | k = 20
141 |         regime = np.zeros(k, dtype=np.int64)
142 | regime[1], regime[2], regime[3] = 1, 3, 6
143 | for j in range(4, k):
144 | regime[j] = self.points_per_user
145 | if self.mode == "val":
146 | self.points_per_user = 1
147 | for j in range(k):
148 | regime[j] = np.min([regime[j], self.points_per_user])
149 | last = self.M - 1 # max index of last item
150 |
151 | # try to generate the desired number of points (time series) per each user.
152 | # if history is short it may not succeed to generate sufficiently different
153 | # time series for a particular user.
154 | t, t_pos, t_neg, t_short = 0, 0, 0, 0
155 | with open(str(raw_path)) as f:
156 | for i, line in enumerate(f):
157 | if i % 1000 == 0:
158 | print("processing line: ", i, t, t_pos, t_neg, t_short)
159 | if i >= self.total:
160 | break
161 | units = line.strip().split("\t")
162 | item_hist_list = units[4].split(",")
163 | cate_hist_list = units[5].split(",")
164 | neg_item_hist_list = units[6].split(",")
165 | neg_cate_hist_list = units[7].split(",")
166 | user = np.array(np.maximum(np.int32(units[0]) - self.Inum, 0), dtype=np.int32)
167 | # y[i] = np.int32(units[3])
168 | items_ = np.array(list(map(lambda x: np.maximum(np.int32(x), 0), item_hist_list)), dtype=np.int32)
169 | cats_ = np.array(list(map(lambda x: np.maximum(np.int32(x) - self.Inum - self.Unum, 0),
170 | cate_hist_list)),
171 | dtype=np.int32)
172 | neg_items_ = np.array(list(map(lambda x: np.maximum(np.int32(x), 0), neg_item_hist_list)),
173 | dtype=np.int32)
174 | neg_cats_ = np.array(list(
175 | map(lambda x: np.maximum(np.int32(x) - self.Inum - self.Unum, 0), neg_cate_hist_list)),
176 | dtype=np.int32)
177 |
178 | # select datapoints
179 | first = np.argmax(items_ > 0)
180 | ind = int((last - first) // 10) # index into regime array
181 | # pos
182 | for _ in range(regime[ind]):
183 | a1 = min(first + self.ts_length, last - 1)
184 | end = np.random.randint(a1, last)
185 | indices = np.arange(end - self.ts_length, end + 1)
186 | if items_[indices[0]] == 0:
187 | t_short += 1
188 | items[t] = items_[indices]
189 | cats[t] = cats_[indices]
190 | users[t] = np.full(self.ts_length + 1, user)
191 | times[t] = time
192 | y[t] = 1
193 | # check
194 | if np.any(users[t] < 0) or np.any(items[t] < 0) \
195 | or np.any(cats[t] < 0):
196 | sys.exit("Categorical feature less than zero after \
197 | processing. Aborting...")
198 | t += 1
199 | t_pos += 1
200 | # neg
201 | for _ in range(regime[ind]):
202 | a1 = min(first + self.ts_length - 1, last - 1)
203 | end = np.random.randint(a1, last)
204 | indices = np.arange(end - self.ts_length + 1, end + 1)
205 | if items_[indices[0]] == 0:
206 | t_short += 1
207 | items[t, :-1] = items_[indices]
208 | cats[t, :-1] = cats_[indices]
209 | neg_indices = np.random.choice(r_target, 1, replace=False) # random final item
210 | items[t, -1] = neg_items_[neg_indices]
211 | cats[t, -1] = neg_cats_[neg_indices]
212 | users[t] = np.full(self.ts_length + 1, user)
213 | times[t] = time
214 | y[t] = 0
215 | # check
216 | if np.any(users[t] < 0) or np.any(items[t] < 0) \
217 | or np.any(cats[t] < 0):
218 | sys.exit("Categorical feature less than zero after \
219 | processing. Aborting...")
220 | t += 1
221 | t_neg += 1
222 |
223 | print("total points, pos points, neg points: ", t, t_pos, t_neg)
224 |
225 | self.truncate_and_save(out_file, True, t, users, items, cats, times, y)
226 | return
227 |
228 |     # processes the raw test datafile into the npz format required by the
229 |     # inference step; produces one datapoint per user from the last ts_length items
230 | def build_taobao_test(self, raw_path, out_file):
231 |
232 | with open(str(raw_path)) as f:
233 | for i, _ in enumerate(f):
234 | if i % 50000 == 0:
235 | print("pre-processing line: ", i)
236 | self.total = i + 1
237 |
238 |         self.total_out = self.total # one datapoint per user
239 | print("ts_length: ", self.ts_length)
240 | print("Total number of points in raw datafile: ", self.total)
241 | print("Total number of points in output will be at most: ", self.total_out)
242 |
243 | time = np.arange(self.ts_length + 1, dtype=np.int32) / (self.ts_length + 1)
244 |
245 |         users = np.zeros((self.total_out, self.ts_length + 1), dtype="i4") # 4 byte int
246 | items = np.zeros((self.total_out, self.ts_length + 1), dtype="i4") # 4 byte int
247 | cats = np.zeros((self.total_out, self.ts_length + 1), dtype="i4") # 4 byte int
248 |         times = np.zeros((self.total_out, self.ts_length + 1), dtype=np.float64)
249 | y = np.zeros(self.total_out, dtype="i4") # 4 byte int
250 |
251 | # try to generate the desired number of points (time series) per each user.
252 | # if history is short it may not succeed to generate sufficiently different
253 | # time series for a particular user.
254 | t, t_pos, t_neg = 0, 0, 0
255 | with open(str(raw_path)) as f:
256 | for i, line in enumerate(f):
257 | if i % 1000 == 0:
258 | print("processing line: ", i, t, t_pos, t_neg)
259 | if i >= self.total:
260 | break
261 | units = line.strip().split("\t")
262 | item_hist_list = units[4].split(",")
263 | cate_hist_list = units[5].split(",")
264 |
265 | user = np.array(np.maximum(np.int32(units[0]) - self.Inum, 0), dtype=np.int32)
266 | y[t] = np.int32(units[3])
267 | items_ = np.array(list(map(lambda x: np.maximum(np.int32(x), 0), item_hist_list)), dtype=np.int32)
268 | cats_ = np.array(list(map(lambda x: np.maximum(np.int32(x) - self.Inum - self.Unum, 0),
269 | cate_hist_list)),
270 | dtype=np.int32)
271 |
272 | # get pts
273 | items[t] = items_[-(self.ts_length + 1):]
274 | cats[t] = cats_[-(self.ts_length + 1):]
275 | users[t] = np.full(self.ts_length + 1, user)
276 | times[t] = time
277 | # check
278 | if np.any(users[t] < 0) or np.any(items[t] < 0) \
279 | or np.any(cats[t] < 0):
280 | sys.exit("Categorical feature less than zero after \
281 | processing. Aborting...")
282 | if y[t] == 1:
283 | t_pos += 1
284 | else:
285 | t_neg += 1
286 | t += 1
287 |
288 | print("total points, pos points, neg points: ", t, t_pos, t_neg)
289 |
290 | self.truncate_and_save(out_file, False, t, users, items, cats, times, y)
291 | return
292 |
293 | # builds small synthetic data mimicking the structure of taobao data
294 | def build_synthetic_train_or_val(self, out_file):
295 |
296 | np.random.seed(123)
297 |         fea_sizes = np.array(self.spa_fea_sizes.split("-"), dtype=int)
298 | maxval = np.min(fea_sizes)
299 | num_s = len(fea_sizes)
300 | X_cat = np.random.randint(maxval, size=(num_s, self.total, self.ts_length + 1), dtype="i4") # 4 byte int
301 | X_int = np.random.uniform(0, 1, size=(1, self.total, self.ts_length + 1))
302 | y = np.random.randint(0, 2, self.total, dtype="i4") # 4 byte int
303 |
304 | # saving to compressed numpy file
305 | if not path.exists(out_file):
306 | np.savez_compressed(
307 | out_file,
308 | X_cat=X_cat,
309 | X_int=X_int,
310 | y=y,
311 | )
312 | return
313 |
314 |
315 | # builds the (train, val or test) data in the npz form consumed by the main
316 | # training loop or the inference step
317 | def make_tbsm_data_and_loader(args, mode):
318 | if mode == "train":
319 | raw = args.raw_train_file
320 | proc = args.pro_train_file
321 | numpts = args.num_train_pts
322 | elif mode == "val":
323 | raw = args.raw_train_file
324 | proc = args.pro_val_file
325 | numpts = args.num_val_pts
326 | else:
327 | raw = args.raw_test_file
328 | proc = args.pro_test_file
329 | numpts = 1
330 |
331 | TaobaoTxtToNpz(
332 | args.datatype,
333 | mode,
334 | args.ts_length,
335 | args.points_per_user,
336 | args.numpy_rand_seed,
337 | raw,
338 | proc,
339 | args.arch_embedding_size,
340 | numpts,
341 | )
342 |
343 |
344 | def main(args):
345 | make_tbsm_data_and_loader(args, 'train')
346 | make_tbsm_data_and_loader(args, 'val')
347 | make_tbsm_data_and_loader(args, 'test')
348 |
349 |
350 | if __name__ == '__main__':
351 | parser = argparse.ArgumentParser()
352 | parser.add_argument("--datatype", type=str, default="taobao")
353 | parser.add_argument("--raw-train-file", type=str, default="./input/train.txt")
354 | parser.add_argument("--pro-train-file", type=str, default="./output/train.npz")
355 | parser.add_argument("--raw-test-file", type=str, default="./input/test.txt")
356 | parser.add_argument("--pro-test-file", type=str, default="./output/test.npz")
357 | parser.add_argument("--pro-val-file", type=str, default="./output/val.npz")
358 | parser.add_argument("--ts-length", type=int, default=20)
359 | parser.add_argument("--num-train-pts", type=int, default=100)
360 | parser.add_argument("--num-val-pts", type=int, default=20)
361 | parser.add_argument("--points-per-user", type=int, default=10)
362 | parser.add_argument("--arch-embedding-size", type=str, default="4-3-2") # vectors
363 | parser.add_argument("--numpy-rand-seed", type=int, default=123)
364 | args = parser.parse_args()
365 | main(args)
366 |
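367 | # Loading the produced .npz files (illustrative sketch, using the default output
368 | # paths above):
369 | #
370 | #   with np.load("./output/train.npz") as data:
371 | #       X_cat, X_int, y = data["X_cat"], data["X_int"], data["y"]
372 | #
373 | # X_cat has shape (3, N, ts_length + 1) holding [users, items, cats] along axis 0;
374 | # X_int holds the normalized timestamps and y the 0/1 labels.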
--------------------------------------------------------------------------------
/scripts/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # The commented-out scripts run the Colossal-AI enabled recsys; the torchrec_* scripts below run the TorchRec baselines.
4 | # bash scripts/kaggle.sh
5 |
6 | bash scripts/torchrec_kaggle.sh
7 |
8 | # bash scripts/avazu.sh
9 |
10 | bash scripts/torchrec_avazu.sh
11 |
12 | # bash scripts/terabyte.sh
13 |
14 | bash scripts/torchrec_terabyte.sh
15 |
--------------------------------------------------------------------------------
/scripts/terabyte.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # For Colossalai enabled recsys
3 | # Log dir naming: w<gpu count>_p<cache percent>_<batch size>; --cache_sets below scales with the cache percent.
4 | # criteo terabyte
5 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p1_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
6 | --dataset_dir /data/criteo_preproc/ \
7 | --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
8 | --profile_dir "tensorboard_log/terabyte/w1_p1_16k" --buffer_size 0 --use_overlap --cache_sets 1779442
9 |
10 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w2_p1_16k dist.ddp -j 1x2 --script recsys/dlrm_main.py -- \
11 | --dataset_dir /data/criteo_preproc/ \
12 | --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
13 | --profile_dir "tensorboard_log/terabyte/w2_p1_16k" --buffer_size 0 --use_overlap --cache_sets 1779442
14 |
15 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w4_p1_16k dist.ddp -j 1x4 --script recsys/dlrm_main.py -- \
16 | --dataset_dir /data/criteo_preproc/ \
17 | --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
18 | --profile_dir "tensorboard_log/terabyte/w4_p1_16k" --buffer_size 0 --use_overlap --cache_sets 1779442
19 |
20 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w8_p1_16k dist.ddp -j 1x8 --script recsys/dlrm_main.py -- \
21 | --dataset_dir /data/criteo_preproc/ \
22 | --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
23 | --profile_dir "tensorboard_log/terabyte/w8_p1_16k" --buffer_size 0 --use_overlap --cache_sets 1779442
24 |
25 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p1_32k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
26 | --dataset_dir /data/criteo_preproc/ \
27 | --learning_rate 1. --batch_size 32768 --use_sparse_embed_grad --use_cache --use_freq \
28 | --profile_dir "tensorboard_log/terabyte/w1_p1_32k" --buffer_size 0 --use_overlap --cache_sets 1779442
29 |
30 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p1_8k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
31 | --dataset_dir /data/criteo_preproc/ \
32 | --learning_rate 1. --batch_size 8192 --use_sparse_embed_grad --use_cache --use_freq \
33 | --profile_dir "tensorboard_log/terabyte/w1_p1_8k" --buffer_size 0 --use_overlap --cache_sets 1779442
34 |
35 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p1_4k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
36 | --dataset_dir /data/criteo_preproc/ \
37 | --learning_rate 1. --batch_size 4096 --use_sparse_embed_grad --use_cache --use_freq \
38 | --profile_dir "tensorboard_log/terabyte/w1_p1_4k" --buffer_size 0 --use_overlap --cache_sets 1779442
39 |
40 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p1_2k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
41 | --dataset_dir /data/criteo_preproc/ \
42 | --learning_rate 1. --batch_size 2048 --use_sparse_embed_grad --use_cache --use_freq \
43 | --profile_dir "tensorboard_log/terabyte/w1_p1_2k" --buffer_size 0 --use_overlap --cache_sets 1779442
44 |
45 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p1_1k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
46 | --dataset_dir /data/criteo_preproc/ \
47 | --learning_rate 1. --batch_size 1024 --use_sparse_embed_grad --use_cache --use_freq \
48 | --profile_dir "tensorboard_log/terabyte/w1_p1_1k" --buffer_size 0 --use_overlap --cache_sets 1779442
49 |
50 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p10_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
51 | --dataset_dir /data/criteo_preproc/ \
52 | --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
53 | --profile_dir "tensorboard_log/terabyte/w1_p10_16k" --buffer_size 0 --use_overlap --cache_sets 17794427
54 |
55 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p5_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
56 | --dataset_dir /data/criteo_preproc/ \
57 | --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
58 | --profile_dir "tensorboard_log/terabyte/w1_p5_16k" --buffer_size 0 --use_overlap --cache_sets 8897213
59 |
60 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p2_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
61 | --dataset_dir /data/criteo_preproc/ \
62 | --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
63 | --profile_dir "tensorboard_log/terabyte/w1_p2_16k" --buffer_size 0 --use_overlap --cache_sets 3558885
64 |
65 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p0_1_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
66 | --dataset_dir /data/criteo_preproc/ \
67 | --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
68 | --profile_dir "tensorboard_log/terabyte/w1_p0_1_16k" --buffer_size 0 --use_overlap --cache_sets 177944
69 |
70 | torchx run -s local_cwd -cfg log_dir=log/terabyte/w1_p0_5_16k dist.ddp -j 1x1 --script recsys/dlrm_main.py -- \
71 | --dataset_dir /data/criteo_preproc/ \
72 | --learning_rate 1. --batch_size 16384 --use_sparse_embed_grad --use_cache --use_freq \
73 | --profile_dir "tensorboard_log/terabyte/w1_p0_5_16k" --buffer_size 0 --use_overlap --cache_sets 889721
74 |
75 |
--------------------------------------------------------------------------------
/scripts/torchrec_avazu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # For TorchRec baseline
4 | torchx run -s local_cwd -cfg log_dir=log/torchrec_avazu/w1_16k dist.ddp -j 1x1 --script baselines/dlrm_main.py -- \
5 | --in_memory_binary_criteo_path /data/avazu_sample --embedding_dim 128 --pin_memory \
6 | --over_arch_layer_sizes "1024,1024,512,256,1" --dense_arch_layer_sizes "512,256,128" --shuffle_batches \
7 | --learning_rate 1. --batch_size 16384 --profile_dir "tensorboard_log/torchrec_avazu/w1_16k"
8 |
9 | torchx run -s local_cwd -cfg log_dir=log/torchrec_avazu/w2_16k dist.ddp -j 1x2 --script baselines/dlrm_main.py -- \
10 | --in_memory_binary_criteo_path /data/avazu_sample --embedding_dim 128 --pin_memory \
11 | --over_arch_layer_sizes "1024,1024,512,256,1" --dense_arch_layer_sizes "512,256,128" --shuffle_batches \
12 | --learning_rate 1. --batch_size 8192 --profile_dir "tensorboard_log/torchrec_avazu/w2_16k"
13 |
14 | torchx run -s local_cwd -cfg log_dir=log/torchrec_avazu/w4_16k dist.ddp -j 1x4 --script baselines/dlrm_main.py -- \
15 | --in_memory_binary_criteo_path /data/avazu_sample --embedding_dim 128 --pin_memory \
16 | --over_arch_layer_sizes "1024,1024,512,256,1" --dense_arch_layer_sizes "512,256,128" --shuffle_batches \
17 | --learning_rate 1. --batch_size 4096 --profile_dir "tensorboard_log/torchrec_avazu/w4_16k"
18 |
19 | torchx run -s local_cwd -cfg log_dir=log/torchrec_avazu/w8_16k dist.ddp -j 1x8 --script baselines/dlrm_main.py -- \
20 | --in_memory_binary_criteo_path /data/avazu_sample --embedding_dim 128 --pin_memory \
21 | --over_arch_layer_sizes "1024,1024,512,256,1" --dense_arch_layer_sizes "512,256,128" --shuffle_batches \
22 | --learning_rate 1. --batch_size 2048 --profile_dir "tensorboard_log/torchrec_avazu/w8_16k"
--------------------------------------------------------------------------------
/scripts/torchrec_custom.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | export PYTHONPATH=$HOME/codes/torchrec:$PYTHONPATH
5 |
6 | set_n_least_used_CUDA_VISIBLE_DEVICES() {
7 | local n=${1:-"9999"}
8 | echo "GPU Memory Usage:"
9 | local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
10 | | tail -n +2 \
11 | | nl -v 0 \
12 | | tee /dev/tty \
13 | | sort -g -k 2 \
14 | | awk '{print $1}' \
15 | | head -n $n)
16 | export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
17 | echo "Now CUDA_VISIBLE_DEVICES is set to:"
18 | echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
19 | }
20 |
21 | # export DATAPATH=/data/scratch/RecSys/embedding_bag
22 | export DATAPATH=custom
23 | export EVAL_ACC=0
24 | export EMB_DIM=128
25 | export POOLING_FACTOR=8
26 |
27 | if [[ ${EVAL_ACC} == 1 ]]; then
28 | EVAL_ACC_FLAG="--eval_acc"
29 | else
30 | export EVAL_ACC_FLAG=""
31 | fi
32 |
33 |
34 | mkdir -p logs
35 | for PREFETCH_NUM in 1 # 32 4 8 16
36 | do
37 | for GPUNUM in 1 2 4 8 # 4 8 # 1 # 2
38 | do
39 | for BATCHSIZE in 8192 #2048 4096 1024 #8192 512 ##16384 8192 4096 2048 1024 512
40 | do
41 | for SHARDTYPE in "table" # "tablecolumn" "column" "row" "tablerow" "table"
42 | do
43 | for KERNELTYPE in "colossalai" # "fused" # "uvm_lfu" # "colossalai" # "uvm_lfu" # "colossalai"
44 | do
45 | # For TorchRec baseline
46 | set_n_least_used_CUDA_VISIBLE_DEVICES ${GPUNUM}
47 | export PLAN=g${GPUNUM}_bs_${BATCHSIZE}_pool_${POOLING_FACTOR}_${SHARDTYPE}_pf_${PREFETCH_NUM}_eb_${EMB_DIM}_${KERNELTYPE}_custom
48 | rm -rf ./tensorboard_log/torchrec_custom/
49 | # env CUDA_LAUNCH_BLOCKING=1
50 | # timeout -s SIGKILL 30m
51 | torchx run -s local_cwd -cfg log_dir=log/torchrec_custom/${PLAN} dist.ddp -j 1x${GPUNUM} --script baselines/dlrm_main.py -- \
52 | --in_memory_binary_criteo_path ${DATAPATH} --kaggle --embedding_dim ${EMB_DIM} --pin_memory --cache_ratio 0.01 \
53 | --over_arch_layer_sizes "1024,1024,512,256,1" --dense_arch_layer_sizes "512,256,${EMB_DIM}" --shuffle_batches \
54 | --learning_rate 1. --batch_size ${BATCHSIZE} --profile_dir "" --shard_type ${SHARDTYPE} --kernel_type ${KERNELTYPE} \
55 | --prefetch_num ${PREFETCH_NUM} --pooling_factor ${POOLING_FACTOR} ${EVAL_ACC_FLAG} 2>&1 | tee logs/torchrec_${PLAN}.txt
56 | done
57 | done
58 | done
59 | done
60 | done
61 |
--------------------------------------------------------------------------------
/scripts/torchrec_kaggle.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 | # export PYTHONPATH=$HOME/codes/torchrec:$PYTHONPATH
4 | # export DATAPATH=/data/scratch/RecSys/criteo_kaggle_data/
5 | export DATAPATH=/data/criteo_kaggle_data/
6 |
7 | set_n_least_used_CUDA_VISIBLE_DEVICES() {
8 | local n=${1:-"9999"}
9 | echo "GPU Memory Usage:"
10 | local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
11 | | tail -n +2 \
12 | | nl -v 0 \
13 | | tee /dev/tty \
14 | | sort -g -k 2 \
15 | | awk '{print $1}' \
16 | | head -n $n)
17 | export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
18 | echo "Now CUDA_VISIBLE_DEVICES is set to:"
19 | echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
20 | }
21 |
22 | export GPUNUM=1
23 |
24 | for EMB_DIM in 128 #64 96
25 | do
26 | for PREFETCH_NUM in 1 #1 8 16 32
27 | do
28 | for GPUNUM in 1 2 4 8
29 | do
30 | for KERNELTYPE in "colossalai" # "fused" # "colossalai"
31 | do
32 | for BATCHSIZE in 8192 #16384 8192 4096 2048 1024 512
33 | do
34 | for SHARDTYPE in "table" # "column" "row" "tablecolumn" "tablerow" "table"
35 | # for SHARDTYPE in "tablerow"
36 | do
37 | # For TorchRec baseline
38 | set_n_least_used_CUDA_VISIBLE_DEVICES ${GPUNUM}
39 | rm -rf ./tensorboard_log/torchrec_kaggle/w${GPUNUM}_${BATCHSIZE}_${SHARDTYPE}
40 |
41 | LOG_DIR=./logs/${KERNELTYPE}_${SHARDTYPE}_logs
42 | mkdir -p ${LOG_DIR}
43 | export PLAN=g${GPUNUM}_bs_${BATCHSIZE}_${SHARDTYPE}_pf_${PREFETCH_NUM}_eb_${EMB_DIM}_${KERNELTYPE} # used for log paths below
44 | torchx run -s local_cwd -cfg log_dir=log/torchrec_kaggle/${PLAN} dist.ddp -j 1x${GPUNUM} --script baselines/dlrm_main.py -- \
45 | --in_memory_binary_criteo_path ${DATAPATH} --kaggle --embedding_dim ${EMB_DIM} --pin_memory --cache_ratio 0.20 \
46 | --over_arch_layer_sizes "1024,1024,512,256,1" --dense_arch_layer_sizes "512,256,${EMB_DIM}" --shuffle_batches --eval_acc \
47 | --learning_rate 1. --batch_size ${BATCHSIZE} --profile_dir "" --shard_type ${SHARDTYPE} --kernel_type ${KERNELTYPE} --prefetch_num ${PREFETCH_NUM} ${EVAL_ACC_FLAG} 2>&1 | tee ./${LOG_DIR}/torchrec_${PLAN}.txt
48 | done
49 | done
50 | done
51 | done
52 | done
53 | done
54 |
--------------------------------------------------------------------------------
/scripts/torchrec_synth.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | export PYTHONPATH=$HOME/codes/torchrec:$PYTHONPATH
5 |
6 | set_n_least_used_CUDA_VISIBLE_DEVICES() {
7 | local n=${1:-"9999"}
8 | echo "GPU Memory Usage:"
9 | local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
10 | | tail -n +2 \
11 | | nl -v 0 \
12 | | tee /dev/tty \
13 | | sort -g -k 2 \
14 | | awk '{print $1}' \
15 | | head -n $n)
16 | export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
17 | echo "Now CUDA_VISIBLE_DEVICES is set to:"
18 | echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
19 | }
20 |
21 | export DATAPATH=/data/scratch/RecSys/embedding_bag
22 | # export DATAPATH=custom
23 | export SCALE="512M" # "4M ""52M" "512M"
24 | export EVAL_ACC=0
25 | export EMB_DIM=128
26 |
27 | if [[ ${EVAL_ACC} == 1 ]]; then
28 | EVAL_ACC_FLAG="--eval_acc"
29 | else
30 | export EVAL_ACC_FLAG=""
31 | fi
32 |
33 | # local batch size
34 | # 4
35 | # export BATCHSIZE=1024
36 | # export BATCHSIZE=1024
37 |
38 | # export BATCHSIZE=4096
39 | # 2
40 | # export BATCHSIZE=8192
41 | # 1
42 |
43 | mkdir -p logs
44 | for PREFETCH_NUM in 1 32 4 8 16 #8 16 32
45 | do
46 | for GPUNUM in 1 # 1 # 2
47 | do
48 | for BATCHSIZE in 256 #2048 4096 1024 #8192 512 ##16384 8192 4096 2048 1024 512
49 | do
50 | for SHARDTYPE in "table"
51 | do
52 | for KERNELTYPE in "colossalai" # "uvm_lfu" # "colossalai" # "uvm_lfu" # "colossalai"
53 | do
54 | # For TorchRec baseline
55 | set_n_least_used_CUDA_VISIBLE_DEVICES ${GPUNUM}
56 | export PLAN=g${GPUNUM}_bs_${BATCHSIZE}_${SHARDTYPE}_pf_${PREFETCH_NUM}_eb_${EMB_DIM}
57 | rm -rf ./tensorboard_log/torchrec_synth/
58 | # env CUDA_LAUNCH_BLOCKING=1
59 | # timeout -s SIGKILL 30m
60 | torchx run -s local_cwd -cfg log_dir=log/torchrec_synth/${PLAN} dist.ddp -j 1x${GPUNUM} --script baselines/dlrm_main.py -- \
61 | --in_memory_binary_criteo_path ${DATAPATH} --kaggle --embedding_dim ${EMB_DIM} --pin_memory \
62 | --over_arch_layer_sizes "1024,1024,512,256,1" --dense_arch_layer_sizes "512,256,${EMB_DIM}" --shuffle_batches \
63 | --learning_rate 1. --batch_size ${BATCHSIZE} --profile_dir "" --shard_type ${SHARDTYPE} --kernel_type ${KERNELTYPE} \
64 | --synth_size ${SCALE} \
65 | --prefetch_num ${PREFETCH_NUM} ${EVAL_ACC_FLAG} 2>&1 | tee logs/torchrec_${PLAN}.txt
66 | done
67 | done
68 | done
69 | done
70 | done
71 |
--------------------------------------------------------------------------------
/scripts/torchrec_terabyte.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # set -xsv
3 |
4 | export DATASETPATH=/data/scratch/RecSys/criteo_preproc/
5 | export PYTHONPATH=$HOME/codes/torchrec:$PYTHONPATH
6 |
7 | set_n_least_used_CUDA_VISIBLE_DEVICES() {
8 | local n=${1:-"9999"}
9 | echo "GPU Memory Usage:"
10 | local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
11 | | tail -n +2 \
12 | | nl -v 0 \
13 | | tee /dev/tty \
14 | | sort -g -k 2 \
15 | | awk '{print $1}' \
16 | | head -n $n)
17 | export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
18 | echo "Now CUDA_VISIBLE_DEVICES is set to:"
19 | echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
20 | }
21 |
22 | export LOG_DIR="/data2/users/lcfjr/logs_1tb/b"
23 | mkdir -p ${LOG_DIR}
24 |
25 | export GPUNUM=2
26 |
27 | # prefetch mini-batch number.
28 | export PREFETCH_NUM=4
29 | export EVAL_ACC=0
30 | export KERNELTYPE="fused"
31 | export SHARDTYPE="table"
32 | export BATCHSIZE=1024
33 | export EMB_DIM=128
34 | export CACHERATIO=0.1
35 |
36 | if [[ ${EVAL_ACC} == 1 ]]; then
37 | EVAL_ACC_FLAG="--eval_acc"
38 | else
39 | export EVAL_ACC_FLAG=""
40 | fi
41 |
42 |
43 |
44 | batch_size_list=(8192)
45 | gpu_num_list=(1)
46 |
47 |
48 | for ((i=0;i<${#batch_size_list[@]};i++)); do
49 |
50 | for KERNELTYPE in "colossalai"
51 | do
52 | for CACHERATIO in 0.05
53 | do
54 |
55 | export BATCHSIZE=${batch_size_list[i]}
56 | export GPUNUM=${gpu_num_list[i]}
57 |
58 | set_n_least_used_CUDA_VISIBLE_DEVICES ${GPUNUM}
59 |
60 | export PLAN=k_${KERNELTYPE}_g_${GPUNUM}_bs_${BATCHSIZE}_sd_${SHARDTYPE}_pf_${PREFETCH_NUM}_eb_${EMB_DIM}_cache_${CACHERATIO}
61 | echo "training batchsize" ${BATCHSIZE} "gpunum" ${GPUNUM}
62 |
63 |
64 | torchx run -s local_cwd -cfg log_dir=log/torchrec_terabyte/w1_16k dist.ddp -j 1x${GPUNUM} --script baselines/dlrm_main.py -- \
65 | --in_memory_binary_criteo_path ${DATASETPATH} --embedding_dim ${EMB_DIM} --pin_memory \
66 | --over_arch_layer_sizes "1024,1024,512,256,1" --dense_arch_layer_sizes "512,256,128" --shuffle_batches \
67 | --learning_rate 1. --batch_size ${BATCHSIZE} --shard_type ${SHARDTYPE} --kernel_type ${KERNELTYPE} --prefetch_num ${PREFETCH_NUM} ${EVAL_ACC_FLAG} \
68 | --profile_dir "" --limit_train_samples 102400000 --cache_ratio ${CACHERATIO} 2>&1 | tee ${LOG_DIR}/torchrec_${PLAN}.txt
69 |
70 | done
71 | done
72 | done
73 |
--------------------------------------------------------------------------------