├── .demo ├── step1_modeling.json ├── step2_process_group_manager.json ├── step3_dataloader.json ├── step4_tensor_parallel.json ├── step5_data_parallel_naive.json └── step6_data_parallel_bucket.json ├── .gitignore ├── .vscode └── settings.json ├── README.md ├── assets └── llama1B_sanity_check.png ├── requirements.txt ├── setup.py ├── step1_modeling ├── model.py ├── train.py └── utils.py ├── step2_process_group_manager ├── model.py ├── patch_step_2.diff ├── process_group_manager.py ├── train.py └── utils.py ├── step3_dataloader ├── dataloader.py ├── model.py ├── patch_step_3.diff ├── process_group_manager.py ├── train.py └── utils.py ├── step4_tensor_parallel ├── dataloader.py ├── model.py ├── patch_step_4.diff ├── process_group_manager.py ├── tensor_parallel.py ├── train.py └── utils.py ├── step5_data_parallel_naive ├── data_parallel.py ├── dataloader.py ├── model.py ├── patch_step_5.diff ├── process_group_manager.py ├── tensor_parallel.py ├── train.py └── utils.py ├── step6_data_parallel_bucket ├── data_parallel.py ├── dataloader.py ├── model.py ├── patch_step_6.diff ├── process_group_manager.py ├── tensor_parallel.py ├── train.py └── utils.py ├── step7_pipeline_parallel_afab ├── data_parallel.py ├── dataloader.py ├── model.py ├── patch_step_7.diff ├── pipeline_parallel.py ├── process_group_manager.py ├── tensor_parallel.py ├── train.py └── utils.py └── step8_pipeline_parallel_1f1b ├── data_parallel.py ├── dataloader.py ├── model.py ├── patch_step_8.diff ├── pipeline_parallel.py ├── process_group_manager.py ├── tensor_parallel.py ├── train.py └── utils.py /.demo/step1_modeling.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://elio.dev/demo-time.schema.json", 3 | "title": "Step 1: Modeling", 4 | "description": "", 5 | "demos": [ 6 | { 7 | "title": "model.py -> Llama", 8 | "description": "", 9 | "steps": [ 10 | { 11 | "action": "open", 12 | "path": "/step1_modeling/model.py" 13 | }, 14 | { 15 | "action": "highlight", 16 | "path": "/step1_modeling/model.py", 17 | "position": "135:148" 18 | } 19 | ] 20 | }, 21 | { 22 | "title": "model.py -> DecoderLayer", 23 | "description": "", 24 | "steps": [ 25 | { 26 | "action": "highlight", 27 | "path": "/step1_modeling/model.py", 28 | "position": "100:116" 29 | } 30 | ] 31 | }, 32 | { 33 | "title": "train.py -> args", 34 | "description": "", 35 | "steps": [ 36 | { 37 | "action": "highlight", 38 | "path": "/step1_modeling/train.py", 39 | "position": "17:39" 40 | } 41 | ] 42 | }, 43 | { 44 | "title": "train.py -> distributed setup", 45 | "description": "", 46 | "steps": [ 47 | { 48 | "action": "highlight", 49 | "path": "/step1_modeling/train.py", 50 | "position": "41:54" 51 | } 52 | ] 53 | }, 54 | { 55 | "title": "train.py -> model init", 56 | "description": "", 57 | "steps": [ 58 | { 59 | "action": "highlight", 60 | "path": "/step1_modeling/train.py", 61 | "position": "58:66" 62 | } 63 | ] 64 | }, 65 | { 66 | "title": "train.py -> optimizer init", 67 | "description": "", 68 | "steps": [ 69 | { 70 | "action": "highlight", 71 | "path": "/step1_modeling/train.py", 72 | "position": "70" 73 | } 74 | ] 75 | }, 76 | { 77 | "title": "train.py -> simple training loop", 78 | "description": "", 79 | "steps": [ 80 | { 81 | "action": "highlight", 82 | "path": "/step1_modeling/train.py", 83 | "position": "74:97" 84 | } 85 | ] 86 | } 87 | ] 88 | } -------------------------------------------------------------------------------- /.demo/step2_process_group_manager.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://elio.dev/demo-time.schema.json", 3 | "title": "Step 2: Process Group Manager", 4 | "description": "", 5 | "demos": [ 6 | { 7 | "title": "new changes for step 2", 8 | "description": "", 9 | "steps": [ 10 | { 11 | "action": "open", 12 | "path": "/step2_process_group_manager/patch_step_2.diff" 13 | } 14 | ] 15 | }, 16 | { 17 | "title": "process_group_manager.py -> class", 18 | "description": "", 19 | "steps": [ 20 | { 21 | "action": "open", 22 | "path": "/step2_process_group_manager/process_group_manager.py" 23 | }, 24 | { 25 | "action": "highlight", 26 | "path": "/step2_process_group_manager/process_group_manager.py", 27 | "position": "5:50" 28 | } 29 | ] 30 | }, 31 | { 32 | "title": "process_group_manager.py -> assert", 33 | "description": "", 34 | "steps": [ 35 | { 36 | "action": "highlight", 37 | "path": "/step2_process_group_manager/process_group_manager.py", 38 | "position": "7:11" 39 | } 40 | ] 41 | }, 42 | { 43 | "title": "process_group_manager.py -> grid", 44 | "description": "", 45 | "steps": [ 46 | { 47 | "action": "highlight", 48 | "path": "/step2_process_group_manager/process_group_manager.py", 49 | "position": "13" 50 | } 51 | ] 52 | }, 53 | { 54 | "title": "process_group_manager.py -> find 3D coordinates of current process", 55 | "description": "", 56 | "steps": [ 57 | { 58 | "action": "highlight", 59 | "path": "/step2_process_group_manager/process_group_manager.py", 60 | "position": "15" 61 | } 62 | ] 63 | }, 64 | { 65 | "title": "process_group_manager.py -> process group", 66 | "description": "", 67 | "steps": [ 68 | { 69 | "action": "highlight", 70 | "path": "/step2_process_group_manager/process_group_manager.py", 71 | "position": "18:28" 72 | } 73 | ] 74 | }, 75 | { 76 | "title": "process_group_manager.py -> tensor parallel", 77 | "description": "", 78 | "steps": [ 79 | { 80 | "action": "highlight", 81 | "path": "/step2_process_group_manager/process_group_manager.py", 82 | "position": "30:33" 83 | } 84 | ] 85 | }, 86 | { 87 | "title": "process_group_manager.py -> pipeline parallel", 88 | "description": "", 89 | "steps": [ 90 | { 91 | "action": "highlight", 92 | "path": "/step2_process_group_manager/process_group_manager.py", 93 | "position": "35:42" 94 | } 95 | ] 96 | }, 97 | { 98 | "title": "process_group_manager.py -> data parallel", 99 | "description": "", 100 | "steps": [ 101 | { 102 | "action": "highlight", 103 | "path": "/step2_process_group_manager/process_group_manager.py", 104 | "position": "44:47" 105 | } 106 | ] 107 | }, 108 | { 109 | "title": "train.py -> init process group", 110 | "description": "", 111 | "steps": [ 112 | { 113 | "action": "highlight", 114 | "path": "/step2_process_group_manager/train.py", 115 | "position": "63:64" 116 | } 117 | ] 118 | } 119 | ] 120 | } -------------------------------------------------------------------------------- /.demo/step3_dataloader.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://elio.dev/demo-time.schema.json", 3 | "title": "Step 3: Dataloader", 4 | "description": "", 5 | "demos": [ 6 | { 7 | "title": "new changes for step 3", 8 | "description": "", 9 | "steps": [ 10 | { 11 | "action": "open", 12 | "path": "/step3_dataloader/patch_step_3.diff" 13 | } 14 | ] 15 | }, 16 | { 17 | "title": "train.py -> main train loop", 18 | "description": "", 19 | "steps": [ 20 | { 21 | "action": "open", 22 | "path": "/step3_dataloader/train.py" 23 | }, 24 | { 
25 | "action": "highlight", 26 | "path": "/step3_dataloader/train.py", 27 | "position": "154:183" 28 | } 29 | ] 30 | }, 31 | { 32 | "title": "train.py -> train_step", 33 | "description": "", 34 | "steps": [ 35 | { 36 | "action": "open", 37 | "path": "/step3_dataloader/train.py" 38 | }, 39 | { 40 | "action": "highlight", 41 | "path": "/step3_dataloader/train.py", 42 | "position": "23:44" 43 | } 44 | ] 45 | }, 46 | { 47 | "title": "dataloader.py -> class (part 1)", 48 | "description": "", 49 | "steps": [ 50 | { 51 | "action": "open", 52 | "path": "/step3_dataloader/dataloader.py" 53 | }, 54 | { 55 | "action": "highlight", 56 | "path": "/step3_dataloader/dataloader.py", 57 | "position": "23" 58 | } 59 | ] 60 | }, 61 | { 62 | "title": "dataloader.py -> tokenize_dataset", 63 | "description": "", 64 | "steps": [ 65 | { 66 | "action": "highlight", 67 | "path": "/step3_dataloader/dataloader.py", 68 | "position": "59:80" 69 | } 70 | ] 71 | }, 72 | { 73 | "title": "dataloader.py -> tokenizer_group_text", 74 | "description": "", 75 | "steps": [ 76 | { 77 | "action": "highlight", 78 | "path": "/step3_dataloader/dataloader.py", 79 | "position": "37:57" 80 | } 81 | ] 82 | }, 83 | { 84 | "title": "dataloader.py -> class (part 2)", 85 | "description": "", 86 | "steps": [ 87 | { 88 | "action": "highlight", 89 | "path": "/step3_dataloader/dataloader.py", 90 | "position": "31" 91 | } 92 | ] 93 | }, 94 | { 95 | "title": "dataloader.py -> collate_batch", 96 | "description": "", 97 | "steps": [ 98 | { 99 | "action": "highlight", 100 | "path": "/step3_dataloader/dataloader.py", 101 | "position": "82:97" 102 | } 103 | ] 104 | } 105 | ] 106 | } -------------------------------------------------------------------------------- /.demo/step4_tensor_parallel.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://elio.dev/demo-time.schema.json", 3 | "title": "Step 4: Tensor Parallel", 4 | "description": "", 5 | "demos": [ 6 | { 7 | "title": "new changes for step 4", 8 | "description": "", 9 | "steps": [ 10 | { 11 | "action": "open", 12 | "path": "/step4_tensor_parallel/patch_step_4.diff" 13 | } 14 | ] 15 | }, 16 | { 17 | "title": "tensor_parallel.py -> apply_tensor_parallel", 18 | "description": "", 19 | "steps": [ 20 | { 21 | "action": "open", 22 | "path": "/step4_tensor_parallel/tensor_parallel.py" 23 | }, 24 | { 25 | "action": "highlight", 26 | "path": "/step4_tensor_parallel/tensor_parallel.py", 27 | "position": "68:111" 28 | } 29 | ] 30 | }, 31 | { 32 | "title": "tensor_parallel.py -> ColumnParallel (local partition)", 33 | "description": "", 34 | "steps": [ 35 | { 36 | "action": "highlight", 37 | "path": "/step4_tensor_parallel/tensor_parallel.py", 38 | "position": "128:129" 39 | } 40 | ] 41 | }, 42 | { 43 | "title": "tensor_parallel.py -> ColumnParallel (reset_parameters)", 44 | "description": "", 45 | "steps": [ 46 | { 47 | "action": "highlight", 48 | "path": "/step4_tensor_parallel/tensor_parallel.py", 49 | "position": "139:157" 50 | } 51 | ] 52 | }, 53 | { 54 | "title": "tensor_parallel.py -> RowParallel (weight init)", 55 | "description": "", 56 | "steps": [ 57 | { 58 | "action": "highlight", 59 | "path": "/step4_tensor_parallel/tensor_parallel.py", 60 | "position": "180" 61 | } 62 | ] 63 | }, 64 | { 65 | "title": "tensor_parallel.py -> RowParallel (reset_parameters)", 66 | "description": "", 67 | "steps": [ 68 | { 69 | "action": "highlight", 70 | "path": "/step4_tensor_parallel/tensor_parallel.py", 71 | "position": "191:209" 72 | } 73 | ] 74 | }, 
75 | { 76 | "title": "tensor_parallel.py -> ColumnParallel (forward) part 1", 77 | "description": "", 78 | "steps": [ 79 | { 80 | "action": "highlight", 81 | "path": "/step4_tensor_parallel/tensor_parallel.py", 82 | "position": "159:162" 83 | } 84 | ] 85 | }, 86 | { 87 | "title": "tensor_parallel.py -> ColumnParallel (forward) part 2", 88 | "description": "", 89 | "steps": [ 90 | { 91 | "action": "highlight", 92 | "path": "/step4_tensor_parallel/tensor_parallel.py", 93 | "position": "79" 94 | } 95 | ] 96 | }, 97 | { 98 | "title": "tensor_parallel.py -> RowParallel (forward)", 99 | "description": "", 100 | "steps": [ 101 | { 102 | "action": "highlight", 103 | "path": "/step4_tensor_parallel/tensor_parallel.py", 104 | "position": "211:216" 105 | } 106 | ] 107 | }, 108 | { 109 | "title": "tensor_parallel.py -> RowParallel (backward)", 110 | "description": "", 111 | "steps": [ 112 | { 113 | "action": "highlight", 114 | "path": "/step4_tensor_parallel/tensor_parallel.py", 115 | "position": "26:28" 116 | } 117 | ] 118 | }, 119 | { 120 | "title": "tensor_parallel.py -> ColumnParallel (backward)", 121 | "description": "", 122 | "steps": [ 123 | { 124 | "action": "highlight", 125 | "path": "/step4_tensor_parallel/tensor_parallel.py", 126 | "position": "59:64" 127 | } 128 | ] 129 | }, 130 | { 131 | "title": "tensor_parallel.py -> ColumnParallel Last Layer (forward)", 132 | "description": "", 133 | "steps": [ 134 | { 135 | "action": "highlight", 136 | "path": "/step4_tensor_parallel/tensor_parallel.py", 137 | "position": "163:165" 138 | } 139 | ] 140 | }, 141 | { 142 | "title": "tensor_parallel.py -> ColumnParallel Last Layer (backward)", 143 | "description": "", 144 | "steps": [ 145 | { 146 | "action": "highlight", 147 | "path": "/step4_tensor_parallel/tensor_parallel.py", 148 | "position": "46:51" 149 | } 150 | ] 151 | }, 152 | { 153 | "title": "tensor_parallel.py -> VocabParallel (split embedding)", 154 | "description": "", 155 | "steps": [ 156 | { 157 | "action": "highlight", 158 | "path": "/step4_tensor_parallel/tensor_parallel.py", 159 | "position": "242:245" 160 | } 161 | ] 162 | }, 163 | { 164 | "title": "tensor_parallel.py -> VocabParallel (compute embedding)", 165 | "description": "", 166 | "steps": [ 167 | { 168 | "action": "highlight", 169 | "path": "/step4_tensor_parallel/tensor_parallel.py", 170 | "position": "287:295" 171 | } 172 | ] 173 | }, 174 | { 175 | "title": "tensor_parallel.py -> VocabParallel (zero out)", 176 | "description": "", 177 | "steps": [ 178 | { 179 | "action": "highlight", 180 | "path": "/step4_tensor_parallel/tensor_parallel.py", 181 | "position": "296:297" 182 | } 183 | ] 184 | }, 185 | { 186 | "title": "tensor_parallel.py -> VocabParallel (all_reduce)", 187 | "description": "", 188 | "steps": [ 189 | { 190 | "action": "highlight", 191 | "path": "/step4_tensor_parallel/tensor_parallel.py", 192 | "position": "298" 193 | } 194 | ] 195 | } 196 | ] 197 | } -------------------------------------------------------------------------------- /.demo/step5_data_parallel_naive.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://elio.dev/demo-time.schema.json", 3 | "title": "Step 5: Data Parallel (naive)", 4 | "description": "", 5 | "demos": [ 6 | { 7 | "title": "new changes for step 5", 8 | "description": "", 9 | "steps": [ 10 | { 11 | "action": "open", 12 | "path": "/step5_data_parallel_naive/patch_step_5.diff" 13 | } 14 | ] 15 | }, 16 | { 17 | "title": "train.py -> reminder train loop (part 1)", 18 | 
"description": "", 19 | "steps": [ 20 | { 21 | "action": "highlight", 22 | "path": "/step5_data_parallel_naive/train.py", 23 | "position": "174" 24 | } 25 | ] 26 | }, 27 | { 28 | "title": "train.py -> reminder train loop (part 2)", 29 | "description": "", 30 | "steps": [ 31 | { 32 | "action": "highlight", 33 | "path": "/step5_data_parallel_naive/train.py", 34 | "position": "179" 35 | } 36 | ] 37 | }, 38 | { 39 | "title": "train.py -> reminder train loop (part 3)", 40 | "description": "", 41 | "steps": [ 42 | { 43 | "action": "highlight", 44 | "path": "/step5_data_parallel_naive/train.py", 45 | "position": "31:49" 46 | } 47 | ] 48 | }, 49 | { 50 | "title": "train.py -> reminder train loop (part 4)", 51 | "description": "", 52 | "steps": [ 53 | { 54 | "action": "highlight", 55 | "path": "/step5_data_parallel_naive/train.py", 56 | "position": "181" 57 | } 58 | ] 59 | }, 60 | { 61 | "title": "dataloader.py -> Distributed Sampler", 62 | "description": "", 63 | "steps": [ 64 | { 65 | "action": "highlight", 66 | "path": "/step5_data_parallel_naive/dataloader.py", 67 | "position": "28:34" 68 | } 69 | ] 70 | }, 71 | { 72 | "title": "data_parallel -> last batch", 73 | "description": "", 74 | "steps": [ 75 | { 76 | "action": "highlight", 77 | "path": "/step5_data_parallel_naive/train.py", 78 | "position": "37:39" 79 | } 80 | ] 81 | }, 82 | { 83 | "title": "data_parallel -> register_backward_hook (part 1)", 84 | "description": "", 85 | "steps": [ 86 | { 87 | "action": "highlight", 88 | "path": "/step5_data_parallel_naive/data_parallel.py", 89 | "position": "16" 90 | } 91 | ] 92 | }, 93 | { 94 | "title": "data_parallel -> register_backward_hook (part 2)", 95 | "description": "", 96 | "steps": [ 97 | { 98 | "action": "highlight", 99 | "path": "/step5_data_parallel_naive/data_parallel.py", 100 | "position": "21:33" 101 | } 102 | ] 103 | } 104 | ] 105 | } -------------------------------------------------------------------------------- /.demo/step6_data_parallel_bucket.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://elio.dev/demo-time.schema.json", 3 | "title": "Step 6: Data Parallel (bucket)", 4 | "description": "", 5 | "demos": [ 6 | { 7 | "title": "new changes for step 6", 8 | "description": "", 9 | "steps": [ 10 | { 11 | "action": "open", 12 | "path": "/step6_data_parallel_bucket/patch_step_6.diff" 13 | } 14 | ] 15 | }, 16 | { 17 | "title": "data_parallel.py -> DataParallelBucket (init)", 18 | "description": "", 19 | "steps": [ 20 | { 21 | "action": "highlight", 22 | "path": "/step6_data_parallel_bucket/data_parallel.py", 23 | "position": "173:174" 24 | } 25 | ] 26 | }, 27 | { 28 | "title": "data_parallel.py -> BucketManager (empty bucket)", 29 | "description": "", 30 | "steps": [ 31 | { 32 | "action": "highlight", 33 | "path": "/step6_data_parallel_bucket/data_parallel.py", 34 | "position": "111:115" 35 | } 36 | ] 37 | }, 38 | { 39 | "title": "data_parallel.py -> BucketManager (keep adding in bucket)", 40 | "description": "", 41 | "steps": [ 42 | { 43 | "action": "highlight", 44 | "path": "/step6_data_parallel_bucket/data_parallel.py", 45 | "position": "123:124" 46 | } 47 | ] 48 | }, 49 | { 50 | "title": "data_parallel.py -> BucketManager (full bucket)", 51 | "description": "", 52 | "steps": [ 53 | { 54 | "action": "highlight", 55 | "path": "/step6_data_parallel_bucket/data_parallel.py", 56 | "position": "117:121" 57 | } 58 | ] 59 | }, 60 | { 61 | "title": "data_parallel.py -> BucketManager (some infos on bucket)", 62 | 
"description": "", 63 | "steps": [ 64 | { 65 | "action": "highlight", 66 | "path": "/step6_data_parallel_bucket/data_parallel.py", 67 | "position": "126:131" 68 | } 69 | ] 70 | }, 71 | { 72 | "title": "data_parallel.py -> BucketManager (Create actual bucket)", 73 | "description": "", 74 | "steps": [ 75 | { 76 | "action": "highlight", 77 | "path": "/step6_data_parallel_bucket/data_parallel.py", 78 | "position": "133:136" 79 | } 80 | ] 81 | }, 82 | { 83 | "title": "data_parallel.py -> Bucket (Create actual bucket 2)", 84 | "description": "", 85 | "steps": [ 86 | { 87 | "action": "highlight", 88 | "path": "/step6_data_parallel_bucket/data_parallel.py", 89 | "position": "38:52" 90 | } 91 | ] 92 | }, 93 | { 94 | "title": "data_parallel.py -> BucketManager (maing_grad variable)", 95 | "description": "", 96 | "steps": [ 97 | { 98 | "action": "highlight", 99 | "path": "/step6_data_parallel_bucket/data_parallel.py", 100 | "position": "138:144" 101 | } 102 | ] 103 | }, 104 | { 105 | "title": "data_parallel.py -> DataParallelBucket (register_backward_hook)", 106 | "description": "", 107 | "steps": [ 108 | { 109 | "action": "highlight", 110 | "path": "/step6_data_parallel_bucket/data_parallel.py", 111 | "position": "188:211" 112 | } 113 | ] 114 | }, 115 | { 116 | "title": "data_parallel.py -> DataParallelBucket (_make_param_hook 1)", 117 | "description": "", 118 | "steps": [ 119 | { 120 | "action": "highlight", 121 | "path": "/step6_data_parallel_bucket/data_parallel.py", 122 | "position": "224:22" 123 | } 124 | ] 125 | }, 126 | { 127 | "title": "data_parallel.py -> DataParallelBucket (_make_param_hook 2)", 128 | "description": "", 129 | "steps": [ 130 | { 131 | "action": "highlight", 132 | "path": "/step6_data_parallel_bucket/data_parallel.py", 133 | "position": "229:238" 134 | } 135 | ] 136 | }, 137 | { 138 | "title": "data_parallel.py -> DataParallelBucket (autograd)", 139 | "description": "", 140 | "steps": [ 141 | { 142 | "action": "highlight", 143 | "path": "/step6_data_parallel_bucket/data_parallel.py", 144 | "position": "229:238" 145 | } 146 | ] 147 | }, 148 | { 149 | "title": "data_parallel.py -> BucketManager (mark_param_as_ready)", 150 | "description": "", 151 | "steps": [ 152 | { 153 | "action": "highlight", 154 | "path": "/step6_data_parallel_bucket/data_parallel.py", 155 | "position": "159:162" 156 | } 157 | ] 158 | }, 159 | { 160 | "title": "data_parallel.py -> Bucket (mark_param_as_ready)", 161 | "description": "", 162 | "steps": [ 163 | { 164 | "action": "highlight", 165 | "path": "/step6_data_parallel_bucket/data_parallel.py", 166 | "position": "74:80" 167 | } 168 | ] 169 | }, 170 | { 171 | "title": "data_parallel.py -> Bucket (sync_gradient)", 172 | "description": "", 173 | "steps": [ 174 | { 175 | "action": "highlight", 176 | "path": "/step6_data_parallel_bucket/data_parallel.py", 177 | "position": "54:58" 178 | } 179 | ] 180 | }, 181 | { 182 | "title": "data_parallel.py -> DataParallelBucket (_post_backward 1)", 183 | "description": "", 184 | "steps": [ 185 | { 186 | "action": "highlight", 187 | "path": "/step6_data_parallel_bucket/data_parallel.py", 188 | "position": "234" 189 | } 190 | ] 191 | }, 192 | { 193 | "title": "data_parallel.py -> DataParallelBucket (_post_backward 2)", 194 | "description": "", 195 | "steps": [ 196 | { 197 | "action": "highlight", 198 | "path": "/step6_data_parallel_bucket/data_parallel.py", 199 | "position": "241:253" 200 | } 201 | ] 202 | }, 203 | { 204 | "title": "data_parallel.py -> DataParallelBucket (bucket_manager.wait)", 205 | "description": "", 
206 | "steps": [ 207 | { 208 | "action": "highlight", 209 | "path": "/step6_data_parallel_bucket/data_parallel.py", 210 | "position": "241:253" 211 | } 212 | ] 213 | }, 214 | { 215 | "title": "data_parallel.py -> BucketManager (wait)", 216 | "description": "", 217 | "steps": [ 218 | { 219 | "action": "highlight", 220 | "path": "/step6_data_parallel_bucket/data_parallel.py", 221 | "position": "154:157" 222 | } 223 | ] 224 | }, 225 | { 226 | "title": "data_parallel.py -> Bucket (wait)", 227 | "description": "", 228 | "steps": [ 229 | { 230 | "action": "highlight", 231 | "path": "/step6_data_parallel_bucket/data_parallel.py", 232 | "position": "68:72" 233 | } 234 | ] 235 | } 236 | ] 237 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/** 2 | picotron_tutorial.egg-info/ 3 | **/launch.json -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "demoTime.highlightBorderColor": "rgb(237, 18, 146)", 3 | "demoTime.showClock": true, 4 | "demoTime.timer": 60, 5 | "demoTime.lineInsertionDelay": 0, 6 | "workbench.colorCustomizations": { 7 | "editor.selectionBackground": "#362a2a" 8 | } 9 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Picotron tutorial 2 | 3 | A step by step tutorial on how to build [Picotron](https://github.com/huggingface/picotron) distributed training framework form scratch 🔥 4 | 5 | ## Videos 6 | 7 | > More to come. Full playlist [here](https://www.youtube.com/playlist?list=PL-_armZiJvAnhcRr6yTJ0__f3Oi-LLi9S) 🎬 8 | 9 | - 🎬 [[Picotron tutorial] Part 1: Model, Process Group Manager, Dataloader](https://youtu.be/u2VSwDDpaBM) 10 | - 🎬 [[Picotron tutorial] Part 2: Tensor Parallel](https://www.youtube.com/watch?v=qUMPaSWi5HI&list=PL-_armZiJvAnhcRr6yTJ0__f3Oi-LLi9S&index=3) 11 | - 🎬 [[Picotron tutorial] Bonus: Debugging Distributed codebase](https://www.youtube.com/watch?v=_8xlRgFY_-g&list=PL-_armZiJvAnhcRr6yTJ0__f3Oi-LLi9S&index=4) 12 | - 🎬 [[Picotron tutorial] Part 3: Data Parallel (Naive & Bucket)](https://www.youtube.com/watch?v=k8EpWveM_t4&list=PL-_armZiJvAnhcRr6yTJ0__f3Oi-LLi9S&index=4) 13 | 14 | ## Setup 15 | 16 | ``` 17 | conda create -n env-picotron-tutorial python=3.10 --y 18 | conda activate env-picotron-tutorial 19 | pip install -e . 20 | ``` 21 | 22 | ## Sanity check 23 | 24 | - Convergence testing on a Llama 1B on 4096000 tokens to see if loss match. 
25 | 26 | ![](assets/llama1B_sanity_check.png) 27 | 28 | 29 | ```bash 30 | # Baseline 31 | cd step3_dataloader/ 32 | torchrun --nproc_per_node 1 train.py --micro_batch_size 4 --gradient_accumulation_steps 8 --seq_len 1024 --max_tokens 4096000 --num_proc 16 --model_name TinyLlama/TinyLlama_v1.1 --num_hidden_layers 22 --num_attention_heads 32 --num_key_value_heads 4 --run_name baseline_1B --use_wandb 33 | 34 | # Tensor Parallel 35 | cd step4_tensor_parallel/ 36 | torchrun --nproc_per_node 4 train.py --tp_size 4 --micro_batch_size 4 --gradient_accumulation_steps 8 --seq_len 1024 --max_tokens 4096000 --num_proc 16 --model_name TinyLlama/TinyLlama_v1.1 --num_hidden_layers 22 --num_attention_heads 32 --num_key_value_heads 4 --run_name tp_1B --use_wandb 37 | 38 | # Data Parallel 39 | cd step6_data_parallel_bucket/ 40 | torchrun --nproc_per_node 4 train.py --dp_size 4 --micro_batch_size 1 --gradient_accumulation_steps 8 --seq_len 1024 --max_tokens 4096000 --num_proc 16 --model_name TinyLlama/TinyLlama_v1.1 --num_hidden_layers 22 --num_attention_heads 32 --num_key_value_heads 4 --run_name dp_bucket_1B --use_wandb 41 | 42 | # Pipeline Parallel 43 | cd step8_pipeline_parallel_1f1b/ 44 | torchrun --nproc_per_node 4 train.py --pp_size 4 --pp_engine 1f1b --micro_batch_size 4 --gradient_accumulation_steps 8 --seq_len 1024 --max_tokens 4096000 --num_proc 16 --model_name TinyLlama/TinyLlama_v1.1 --num_hidden_layers 22 --num_attention_heads 32 --num_key_value_heads 4 --run_name pp_1f1b_1B --use_wandb 45 | 46 | # 3D parallelism (Tensor + Data + Pipeline parallel) 47 | torchrun --nproc_per_node 8 train.py --tp_size 2 --pp_size 2 --pp_engine 1f1b --dp_size 2 --micro_batch_size 2 --gradient_accumulation_steps 8 --seq_len 1024 --max_tokens 4096000 --num_proc 16 --model_name TinyLlama/TinyLlama_v1.1 --num_hidden_layers 22 --num_attention_heads 32 --num_key_value_heads 4 --run_name 3D_parallelism_1B --use_wandb 48 | ``` 49 | -------------------------------------------------------------------------------- /assets/llama1B_sanity_check.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/picotron_tutorial/8c14c925e2069ec2756cdf1510ed2b8bb56a098e/assets/llama1B_sanity_check.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.0 2 | numpy==1.26.4 3 | datasets==2.19.1 4 | transformers==4.41.1 5 | flash-attn==2.5.0 6 | lovely_tensors 7 | wandb -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | def read_requirements(): 4 | with open('requirements.txt') as req: 5 | return [line.strip() for line in req if line.strip() and not line.startswith('#')] 6 | 7 | setup( 8 | name="picotron_tutorial", 9 | version='0.1.0', 10 | packages=find_packages(), 11 | install_requires=read_requirements(), 12 | ) -------------------------------------------------------------------------------- /step1_modeling/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from flash_attn.flash_attn_interface import flash_attn_func 5 | from flash_attn.layers.rotary import apply_rotary_emb 6 | from flash_attn.ops.triton.layer_norm import
layer_norm_fn 7 | 8 | def flash_attention(q, k, v, causal = True): 9 | q = q.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 10 | k = k.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 11 | v = v.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 12 | return flash_attn_func(q, k, v, causal=causal) 13 | 14 | def get_cos_sin(seq_length, head_dim, base=500000.0): 15 | assert head_dim%2==0 16 | # Results on CUDA and CPU are different even with the same formula, To match transformers implementation. frequency should be computed on CPU 17 | theta = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float().to('cpu') / head_dim)) 18 | dtype = torch.bfloat16 19 | device = torch.device('cuda') 20 | position = torch.arange(seq_length).to(device).unsqueeze(1).float() # [seq_length, 1] 21 | # To match transformers implementation. m * theta should be computed on GPU 22 | theta = theta.to(device) 23 | return torch.cos(position.float()*theta.float()).to(dtype).repeat(1,2), torch.sin(position.float()*theta.float()).to(dtype).repeat(1,2) # [seq_length, head_dim], [seq_length, head_dim] 24 | 25 | class TritonRMSNorm(nn.Module): 26 | def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): 27 | super().__init__() 28 | self.eps = eps 29 | self.weight = nn.Parameter(torch.ones(hidden_size)) 30 | self.register_parameter("bias", None) 31 | 32 | def forward( 33 | self, hidden_states, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False 34 | ): 35 | return layer_norm_fn( 36 | hidden_states, 37 | self.weight, 38 | None, 39 | residual=residual, 40 | eps=self.eps, 41 | dropout_p=dropout_p, 42 | prenorm=prenorm, 43 | residual_in_fp32=residual_in_fp32, 44 | is_rms_norm=True, 45 | return_dropout_mask=return_dropout_mask, 46 | ) 47 | 48 | class Attention(nn.Module): 49 | def __init__(self, config, layer_idx): 50 | super().__init__() 51 | self.hidden_size = config.hidden_size 52 | self.num_heads = config.num_attention_heads 53 | self.num_key_values = config.num_key_value_heads 54 | self.head_dim = self.hidden_size//self.num_heads 55 | self.num_local_heads = config.num_attention_heads 56 | self.num_local_kv_heads = config.num_key_value_heads 57 | 58 | self.q_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim, bias=False) 59 | self.k_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False) 60 | self.v_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False) 61 | self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) 62 | self.layer_idx = layer_idx 63 | 64 | def forward(self, x, cos, sin, attention_mask=None, position_ids=None): 65 | batch_size, seq_length, hidden_dim = x.size() 66 | q = self.q_proj(x) # [batch_size, seq_length, num_heads*head_dim] 67 | k = self.k_proj(x) # [batch_size, seq_length, num_key_values*head_dim] 68 | v = self.v_proj(x) # [batch_size, seq_length, num_key_values*head_dim] 69 | 70 | q = q.view(batch_size, seq_length, self.num_local_heads, self.head_dim) # [batch_size, seq_length, num_heads, head_dim] 71 | k = k.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim) # [batch_size, seq_length, num_key_values, head_dim] 72 | q = apply_rotary_emb(q,cos[:, :self.head_dim // 2], sin[:, :self.head_dim // 2],interleaved=False) # [batch_size, seq_length, num_heads, head_dim] 73 | k = apply_rotary_emb(k,cos[:, :self.head_dim // 2], sin[:, :self.head_dim // 2],interleaved=False) # 
[batch_size, seq_length, num_key_values, head_dim] 74 | q = q.transpose(1, 2) # [batch_size, num_heads, seq_length, head_dim] 75 | k = k.transpose(1, 2) # [batch_size, num_key_values, seq_length, head_dim] 76 | v = v.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim).transpose(1,2) # [batch_size, num_key_values, seq_length, head_dim] 77 | 78 | k = k.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, head_dim] 79 | v = v.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, head_dim] 80 | 81 | causal = True if q.size(2) == k.size(2) else False # During decoding phase. The lenghth of q is usually 1. 82 | 83 | out = flash_attention(q, k, v, causal = causal) # [batch_size, seq_length, num_heads, head_dim] 84 | 85 | out = out.reshape(batch_size, seq_length, self.num_local_heads * self.head_dim) # [batch_size, seq_length, hidden_dim] 86 | out = self.out_proj(out) # [batch_size, seq_length, hidden_dim] 87 | return out 88 | 89 | class MLP(nn.Module): 90 | def __init__(self, config) -> None: 91 | super().__init__() 92 | self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) 93 | self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) 94 | self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) 95 | 96 | def forward(self, x): 97 | #TODO: dont do single line operations as it is harder to debug 98 | return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) 99 | 100 | class DecoderLayer(nn.Module): 101 | # TritonRMSNorm -> Attention -> Residual -> TritonRMSNorm -> MLP -> Residual 102 | def __init__(self, config, layer_idx): 103 | super().__init__() 104 | self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 105 | self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 106 | self.attention = Attention(config, layer_idx = layer_idx) 107 | self.mlp = MLP(config) 108 | self.layer_idx = layer_idx 109 | head_dim = config.hidden_size // config.num_attention_heads 110 | self.cos, self.sin = get_cos_sin(config.max_position_embeddings, head_dim=head_dim , base=config.rope_theta) # [max_position_embeddings, head_dim] 111 | 112 | def forward(self, x, attention_mask = None, position_ids = None): 113 | cos, sin = self.cos, self.sin 114 | x = x + self.attention(self.input_layernorm(x), cos, sin, attention_mask, position_ids) # Attention 115 | x = x + self.mlp(self.post_attention_layernorm(x)) # MLP 116 | return x 117 | 118 | class Llama(nn.Module): 119 | def __init__(self, config) -> None: 120 | super().__init__() 121 | # sanity check 122 | assert config.hidden_size % config.num_attention_heads==0 123 | assert config.num_attention_heads % config.num_key_value_heads==0 124 | 125 | # params 126 | self.vocab_size = config.vocab_size 127 | self.hidden_size = config.hidden_size 128 | self.num_heads = config.num_attention_heads 129 | self.num_key_values = config.num_key_value_heads 130 | self.head_dim = self.hidden_size//self.num_heads 131 | self.max_position_embeddings = config.max_position_embeddings 132 | self.num_layers = config.num_hidden_layers 133 | self.model_config = config 134 | 135 | # modules 136 | self.embedding = nn.Embedding(self.vocab_size, self.hidden_size) 137 | self.decoder_layers = nn.ModuleList([DecoderLayer(config,layer_idx = i) for i in range(self.num_layers)]) 138 | self.final_proj = 
nn.Linear(self.hidden_size, self.vocab_size, bias=False) 139 | self.final_norm = TritonRMSNorm(self.hidden_size, eps=config.rms_norm_eps) 140 | 141 | def forward(self, input_ids, attention_mask=None, position_ids: torch.Tensor = None): 142 | x = self.embedding(input_ids) 143 | for layer in self.decoder_layers: 144 | x = layer(x) # [batch_size, seq_length, hidden_dim] 145 | x = self.final_norm(x) 146 | logits = self.final_proj(x) 147 | 148 | return logits # [batch_size, seq_length, vocab_size] -------------------------------------------------------------------------------- /step1_modeling/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | torchrun --nproc_per_node 1 train.py 3 | """ 4 | import os 5 | import datetime 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.distributed as dist 9 | import argparse 10 | from torch.optim import AdamW 11 | from transformers import AutoConfig 12 | 13 | from model import Llama 14 | from utils import set_all_seed, print 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser(description="Training script for LLaMA model") 18 | 19 | # Environment arguments 20 | parser.add_argument("--omp_num_threads", type=str, default="1") 21 | parser.add_argument("--tokenizers_parallelism", type=str, default="false") 22 | 23 | # Model arguments 24 | parser.add_argument("--model_name", type=str, default="HuggingFaceTB/SmolLM-360M-Instruct") 25 | parser.add_argument("--num_hidden_layers", type=int, default=32) 26 | parser.add_argument("--num_attention_heads", type=int, default=16) 27 | parser.add_argument("--num_key_value_heads", type=int, default=4) 28 | 29 | # Training arguments 30 | parser.add_argument("--seed", type=int, default=42) 31 | parser.add_argument("--learning_rate", type=float, default=3e-4) 32 | parser.add_argument("--seq_len", type=int, default=32) 33 | parser.add_argument("--micro_batch_size", type=int, default=1) 34 | 35 | # Logging arguments 36 | parser.add_argument("--run_name", type=str, default="default_run") 37 | parser.add_argument("--use_wandb", action="store_true") 38 | 39 | args = parser.parse_args() 40 | 41 | # Set environment variables 42 | os.environ["OMP_NUM_THREADS"] = args.omp_num_threads 43 | os.environ["TOKENIZERS_PARALLELISM"] = args.tokenizers_parallelism 44 | os.environ["DEVICE"] = "cuda" 45 | 46 | local_rank = int(os.environ["LOCAL_RANK"]) 47 | global_rank = int(os.environ["RANK"]) 48 | world_size = int(os.environ["WORLD_SIZE"]) 49 | backend = "nccl" 50 | torch.cuda.set_device(local_rank) 51 | device = torch.device("cuda", local_rank) 52 | dtype = torch.bfloat16 53 | 54 | dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=2)) 55 | 56 | set_all_seed(args.seed) 57 | 58 | model_config = AutoConfig.from_pretrained(args.model_name) 59 | model_config.num_hidden_layers = args.num_hidden_layers 60 | model_config.num_attention_heads = args.num_attention_heads 61 | model_config.num_key_value_heads = args.num_key_value_heads 62 | model_config.max_position_embeddings = args.seq_len 63 | 64 | model = Llama(config=model_config) 65 | model.to(dtype).to(device) 66 | model.train() 67 | 68 | dist.barrier() 69 | 70 | optimizer = AdamW(model.parameters(), lr=args.learning_rate) 71 | 72 | dist.barrier() 73 | 74 | # Create dummy data 75 | input_ids = torch.randint(0, model_config.vocab_size, (args.micro_batch_size, args.seq_len), device=device) 76 | target_ids = torch.randint(0, 
model_config.vocab_size, (args.micro_batch_size, args.seq_len), device=device) 77 | 78 | # Training step 79 | optimizer.zero_grad() 80 | 81 | # Forward pass 82 | outputs = model(input_ids=input_ids) 83 | 84 | # Compute loss 85 | target_ids = target_ids.reshape(-1) 86 | outputs = outputs.view(-1, model_config.vocab_size) 87 | loss = F.cross_entropy(outputs, target_ids) 88 | 89 | # Backward pass 90 | loss.backward() 91 | 92 | # Optimizer step 93 | optimizer.step() 94 | 95 | print(f"Loss: {loss.item():.4f}", is_print_rank=(global_rank == 0)) 96 | 97 | dist.destroy_process_group() -------------------------------------------------------------------------------- /step1_modeling/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import builtins 5 | import fcntl 6 | 7 | def print(*args, is_print_rank=True, **kwargs): 8 | """ solves multi-process interleaved print problem """ 9 | if not is_print_rank: return 10 | with open(__file__, "r") as fh: 11 | fcntl.flock(fh, fcntl.LOCK_EX) 12 | try: 13 | builtins.print(*args, **kwargs) 14 | finally: 15 | fcntl.flock(fh, fcntl.LOCK_UN) 16 | 17 | def set_all_seed(seed): 18 | for module in [random, np.random]: module.seed(seed) 19 | torch.manual_seed(seed) 20 | if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) 21 | 22 | def to_readable_format(num, precision=3): 23 | num_str = str(num) 24 | length = len(num_str) 25 | 26 | def format_with_precision(main, decimal, suffix): 27 | if precision == 0: 28 | return f"{main}{suffix}" 29 | return f"{main}.{decimal[:precision]}{suffix}" 30 | 31 | if length > 12: # Trillions 32 | return format_with_precision(num_str[:-12], num_str[-12:], 'T') 33 | elif length > 9: # Billions 34 | return format_with_precision(num_str[:-9], num_str[-9:], 'B') 35 | elif length > 6: # Millions 36 | return format_with_precision(num_str[:-6], num_str[-6:], 'M') 37 | elif length > 3: # Thousands 38 | return format_with_precision(num_str[:-3], num_str[-3:], 'K') 39 | else: 40 | return num_str -------------------------------------------------------------------------------- /step2_process_group_manager/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from flash_attn.flash_attn_interface import flash_attn_func 5 | from flash_attn.layers.rotary import apply_rotary_emb 6 | from flash_attn.ops.triton.layer_norm import layer_norm_fn 7 | 8 | def flash_attention(q, k, v, causal = True): 9 | q = q.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 10 | k = k.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 11 | v = v.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 12 | return flash_attn_func(q, k, v, causal=causal) 13 | 14 | def get_cos_sin(seq_length, head_dim, base=500000.0): 15 | assert head_dim%2==0 16 | # Results on CUDA and CPU are different even with the same formula, To match transformers implementation. frequency should be computed on CPU 17 | theta = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float().to('cpu') / head_dim)) 18 | dtype = torch.bfloat16 19 | device = torch.device('cuda') 20 | position = torch.arange(seq_length).to(device).unsqueeze(1).float() # [seq_length, 1] 21 | # To match transformers implementation. 
m * theta should be computed on GPU 22 | theta = theta.to(device) 23 | return torch.cos(position.float()*theta.float()).to(dtype).repeat(1,2), torch.sin(position.float()*theta.float()).to(dtype).repeat(1,2) # [seq_length, head_dim], [seq_length, head_dim] 24 | 25 | class TritonRMSNorm(nn.Module): 26 | def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): 27 | super().__init__() 28 | self.eps = eps 29 | self.weight = nn.Parameter(torch.ones(hidden_size)) 30 | self.register_parameter("bias", None) 31 | 32 | def forward( 33 | self, hidden_states, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False 34 | ): 35 | return layer_norm_fn( 36 | hidden_states, 37 | self.weight, 38 | None, 39 | residual=residual, 40 | eps=self.eps, 41 | dropout_p=dropout_p, 42 | prenorm=prenorm, 43 | residual_in_fp32=residual_in_fp32, 44 | is_rms_norm=True, 45 | return_dropout_mask=return_dropout_mask, 46 | ) 47 | 48 | class Attention(nn.Module): 49 | def __init__(self, config, layer_idx): 50 | super().__init__() 51 | self.hidden_size = config.hidden_size 52 | self.num_heads = config.num_attention_heads 53 | self.num_key_values = config.num_key_value_heads 54 | self.head_dim = self.hidden_size//self.num_heads 55 | self.num_local_heads = config.num_attention_heads 56 | self.num_local_kv_heads = config.num_key_value_heads 57 | 58 | self.q_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim, bias=False) 59 | self.k_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False) 60 | self.v_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False) 61 | self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) 62 | self.layer_idx = layer_idx 63 | 64 | def forward(self, x, cos, sin, attention_mask=None, position_ids=None): 65 | batch_size, seq_length, hidden_dim = x.size() 66 | q = self.q_proj(x) # [batch_size, seq_length, num_heads*head_dim] 67 | k = self.k_proj(x) # [batch_size, seq_length, num_key_values*head_dim] 68 | v = self.v_proj(x) # [batch_size, seq_length, num_key_values*head_dim] 69 | 70 | q = q.view(batch_size, seq_length, self.num_local_heads, self.head_dim) # [batch_size, seq_length, num_heads, head_dim] 71 | k = k.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim) # [batch_size, seq_length, num_key_values, head_dim] 72 | q = apply_rotary_emb(q,cos[:, :self.head_dim // 2], sin[:, :self.head_dim // 2],interleaved=False) # [batch_size, seq_length, num_heads, head_dim] 73 | k = apply_rotary_emb(k,cos[:, :self.head_dim // 2], sin[:, :self.head_dim // 2],interleaved=False) # [batch_size, seq_length, num_key_values, head_dim] 74 | q = q.transpose(1, 2) # [batch_size, num_heads, seq_length, head_dim] 75 | k = k.transpose(1, 2) # [batch_size, num_key_values, seq_length, head_dim] 76 | v = v.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim).transpose(1,2) # [batch_size, num_key_values, seq_length, head_dim] 77 | 78 | k = k.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, head_dim] 79 | v = v.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, head_dim] 80 | 81 | causal = True if q.size(2) == k.size(2) else False # During decoding phase. The lenghth of q is usually 1. 
82 | 83 | out = flash_attention(q, k, v, causal = causal) # [batch_size, seq_length, num_heads, head_dim] 84 | 85 | out = out.reshape(batch_size, seq_length, self.num_local_heads * self.head_dim) # [batch_size, seq_length, hidden_dim] 86 | out = self.out_proj(out) # [batch_size, seq_length, hidden_dim] 87 | return out 88 | 89 | class MLP(nn.Module): 90 | def __init__(self, config) -> None: 91 | super().__init__() 92 | self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) 93 | self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) 94 | self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) 95 | 96 | def forward(self, x): 97 | #TODO: dont do single line operations as it is harder to debug 98 | return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) 99 | 100 | class DecoderLayer(nn.Module): 101 | # TritonRMSNorm -> Attention -> Residual -> TritonRMSNorm -> MLP -> Residual 102 | def __init__(self, config, layer_idx): 103 | super().__init__() 104 | self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 105 | self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 106 | self.attention = Attention(config, layer_idx = layer_idx) 107 | self.mlp = MLP(config) 108 | self.layer_idx = layer_idx 109 | head_dim = config.hidden_size // config.num_attention_heads 110 | self.cos, self.sin = get_cos_sin(config.max_position_embeddings, head_dim=head_dim , base=config.rope_theta) # [max_position_embeddings, head_dim] 111 | 112 | def forward(self, x, attention_mask = None, position_ids = None): 113 | cos, sin = self.cos, self.sin 114 | x = x + self.attention(self.input_layernorm(x), cos, sin, attention_mask, position_ids) # Attention 115 | x = x + self.mlp(self.post_attention_layernorm(x)) # MLP 116 | return x 117 | 118 | class Llama(nn.Module): 119 | def __init__(self, config) -> None: 120 | super().__init__() 121 | # sanity check 122 | assert config.hidden_size % config.num_attention_heads==0 123 | assert config.num_attention_heads % config.num_key_value_heads==0 124 | 125 | # params 126 | self.vocab_size = config.vocab_size 127 | self.hidden_size = config.hidden_size 128 | self.num_heads = config.num_attention_heads 129 | self.num_key_values = config.num_key_value_heads 130 | self.head_dim = self.hidden_size//self.num_heads 131 | self.max_position_embeddings = config.max_position_embeddings 132 | self.num_layers = config.num_hidden_layers 133 | self.model_config = config 134 | 135 | # modules 136 | self.embedding = nn.Embedding(self.vocab_size, self.hidden_size) 137 | self.decoder_layers = nn.ModuleList([DecoderLayer(config,layer_idx = i) for i in range(self.num_layers)]) 138 | self.final_proj = nn.Linear(self.hidden_size, self.vocab_size, bias=False) 139 | self.final_norm = TritonRMSNorm(self.hidden_size, eps=config.rms_norm_eps) 140 | 141 | def forward(self, input_ids, attention_mask=None, position_ids: torch.Tensor = None): 142 | x = self.embedding(input_ids) 143 | for layer in self.decoder_layers: 144 | x = layer(x) # [batch_size, seq_length, hidden_dim] 145 | x = self.final_norm(x) 146 | logits = self.final_proj(x) 147 | 148 | return logits # [batch_size, seq_length, vocab_size] -------------------------------------------------------------------------------- /step2_process_group_manager/patch_step_2.diff: -------------------------------------------------------------------------------- 1 | diff -x '*.diff' --new-file -ur 
step1_modeling/process_group_manager.py step2_process_group_manager/process_group_manager.py 2 | --- step1_modeling/process_group_manager.py 1970-01-01 00:00:00.000000000 +0000 3 | +++ step2_process_group_manager/process_group_manager.py 2024-11-17 15:40:02.000000000 +0000 4 | @@ -0,0 +1,54 @@ 5 | +import os 6 | +import torch 7 | +import torch.distributed as dist 8 | + 9 | +class ProcessGroupManager: 10 | + def __init__(self, dp_size, pp_size, tp_size): 11 | + self.global_rank = dist.get_rank() 12 | + self.world_size = dist.get_world_size() 13 | + self.local_rank = int(os.environ.get("LOCAL_RANK", self.global_rank % self.world_size)) 14 | + 15 | + assert self.world_size == dp_size * pp_size * tp_size, f"World size ({self.world_size}) != DP ({self.dp_size}) * PP ({self.pp_size}) * TP ({self.tp_size})" 16 | + 17 | + self.grid = torch.arange(self.world_size).view(dp_size, pp_size, tp_size) # DP * PP * TP grid 18 | + # Find the position of the current process in the grid 19 | + self.dp_rank, self.pp_rank, self.tp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist() 20 | + 21 | + # Process group creation - Update indexing to match new grid order 22 | + self.tp_group = dist.new_subgroups_by_enumeration([self.grid[d, p, :].tolist() for d in range(dp_size) for p in range(pp_size)])[0] 23 | + self.pp_group = dist.new_subgroups_by_enumeration([self.grid[d, :, t].tolist() for d in range(dp_size) for t in range(tp_size)])[0] 24 | + self.dp_group = dist.new_subgroups_by_enumeration([self.grid[:, p, t].tolist() for p in range(pp_size) for t in range(tp_size)])[0] 25 | + self.pp_dp_group = dist.new_subgroups_by_enumeration([self.grid[:, :, t].flatten().tolist() for t in range(tp_size)])[0] 26 | + 27 | + self.world_group = dist.group.WORLD 28 | + 29 | + # Update group IDs with new grid ordering 30 | + self.tp_group_ids = self.grid[self.dp_rank, self.pp_rank, :].tolist() 31 | + self.pp_group_ids = self.grid[self.dp_rank, :, self.tp_rank].tolist() 32 | + self.dp_group_ids = self.grid[:, self.pp_rank, self.tp_rank].tolist() 33 | + 34 | + # Tensor parallelism 35 | + self.tp_world_size = dist.get_world_size(group=self.tp_group) 36 | + self.tp_first_rank = self.tp_group_ids[0] 37 | + self.tp_last_rank = self.tp_group_ids[-1] 38 | + 39 | + # Pipeline parallelism 40 | + self.pp_world_size = dist.get_world_size(group=self.pp_group) 41 | + self.pp_first_rank = self.pp_group_ids[0] 42 | + self.pp_last_rank = self.pp_group_ids[-1] 43 | + self.pp_is_first_stage = self.pp_rank == 0 44 | + self.pp_is_last_stage = self.pp_rank == self.pp_world_size - 1 45 | + self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else int(self.grid[self.dp_rank, self.pp_rank + 1, self.tp_rank].item()) 46 | + self.pp_prev_rank = None if self.pp_rank == 0 else int(self.grid[self.dp_rank, self.pp_rank - 1, self.tp_rank].item()) 47 | + 48 | + # Data parallelism 49 | + self.dp_world_size = dist.get_world_size(group=self.dp_group) 50 | + self.dp_first_rank = self.dp_group_ids[0] 51 | + self.dp_last_rank = self.dp_group_ids[-1] 52 | + 53 | + def __str__(self): 54 | + return f"DP({self.dp_world_size})-PP({self.pp_world_size})-TP({self.tp_world_size})-Rank({self.global_rank})" 55 | + 56 | +def setup_process_group_manager(dp_size, pp_size, tp_size): 57 | + global process_group_manager 58 | + process_group_manager = ProcessGroupManager(dp_size, pp_size, tp_size) 59 | \ No newline at end of file 60 | diff -x '*.diff' --new-file -ur step1_modeling/train.py step2_process_group_manager/train.py 61 | --- 
step1_modeling/train.py 2024-11-17 15:46:52.000000000 +0000 62 | +++ step2_process_group_manager/train.py 2024-11-17 15:43:28.000000000 +0000 63 | @@ -1,7 +1,8 @@ 64 | """ 65 | -torchrun --nproc_per_node 1 train.py 66 | +torchrun --nproc_per_node 2 train.py --tp_size 2 --run_name process_group_manager --use_wandb 67 | """ 68 | import os 69 | +import wandb 70 | import datetime 71 | import torch 72 | import torch.nn.functional as F 73 | @@ -11,6 +12,8 @@ 74 | from transformers import AutoConfig 75 | 76 | from model import Llama 77 | +import process_group_manager as pgm 78 | +from process_group_manager import setup_process_group_manager 79 | from utils import set_all_seed, print 80 | 81 | if __name__ == "__main__": 82 | @@ -31,6 +34,12 @@ 83 | parser.add_argument("--learning_rate", type=float, default=3e-4) 84 | parser.add_argument("--seq_len", type=int, default=32) 85 | parser.add_argument("--micro_batch_size", type=int, default=1) 86 | + 87 | + # Distributed training arguments 88 | + parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size") 89 | + parser.add_argument("--dp_size", type=int, default=1, help="Data Parallel size") 90 | + parser.add_argument("--pp_size", type=int, default=1, help="Pipeline Parallel size") 91 | + parser.add_argument("--pp_engine", type=str, default="afab", choices=["1f1b", "afab"]) 92 | 93 | # Logging arguments 94 | parser.add_argument("--run_name", type=str, default="default_run") 95 | @@ -52,9 +61,25 @@ 96 | dtype = torch.bfloat16 97 | 98 | dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=2)) 99 | + setup_process_group_manager(dp_size=args.dp_size, pp_size=args.pp_size, tp_size=args.tp_size) 100 | 101 | + is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.pp_is_last_stage 102 | set_all_seed(args.seed) 103 | 104 | + if is_wandb_rank and args.use_wandb: 105 | + wandb.init( 106 | + project="picotron_tutorial", 107 | + name=f"{args.run_name}_{pgm.process_group_manager}", 108 | + config={ 109 | + "tensor_parallel_size": pgm.process_group_manager.tp_world_size, 110 | + "pipeline_parallel_size": pgm.process_group_manager.pp_world_size, 111 | + "data_parallel_size": pgm.process_group_manager.dp_world_size, 112 | + "model": args.model_name, 113 | + "learning_rate": args.learning_rate, 114 | + "seed": args.seed, 115 | + }, 116 | + ) 117 | + 118 | model_config = AutoConfig.from_pretrained(args.model_name) 119 | model_config.num_hidden_layers = args.num_hidden_layers 120 | model_config.num_attention_heads = args.num_attention_heads 121 | @@ -92,6 +117,12 @@ 122 | # Optimizer step 123 | optimizer.step() 124 | 125 | - print(f"Loss: {loss.item():.4f}", is_print_rank=(global_rank == 0)) 126 | + print(f"[rank {pgm.process_group_manager.global_rank}], Loss: {loss:.4f}") 127 | + 128 | + if is_wandb_rank and args.use_wandb: 129 | + wandb.log({"loss": loss.item()}) 130 | + 131 | + if is_wandb_rank and args.use_wandb: 132 | + wandb.finish() 133 | 134 | dist.destroy_process_group() 135 | \ No newline at end of file 136 | -------------------------------------------------------------------------------- /step2_process_group_manager/process_group_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | 5 | class ProcessGroupManager: 6 | def __init__(self, dp_size, pp_size, tp_size): 7 | 
self.global_rank = dist.get_rank() 8 | self.world_size = dist.get_world_size() 9 | self.local_rank = int(os.environ.get("LOCAL_RANK", self.global_rank % self.world_size)) 10 | 11 | assert self.world_size == dp_size * pp_size * tp_size, f"World size ({self.world_size}) != DP ({self.dp_size}) * PP ({self.pp_size}) * TP ({self.tp_size})" 12 | 13 | self.grid = torch.arange(self.world_size).view(dp_size, pp_size, tp_size) # DP * PP * TP grid 14 | # Find the position of the current process in the grid 15 | self.dp_rank, self.pp_rank, self.tp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist() 16 | 17 | # Process group creation - Update indexing to match new grid order 18 | self.tp_group = dist.new_subgroups_by_enumeration([self.grid[d, p, :].tolist() for d in range(dp_size) for p in range(pp_size)])[0] 19 | self.pp_group = dist.new_subgroups_by_enumeration([self.grid[d, :, t].tolist() for d in range(dp_size) for t in range(tp_size)])[0] 20 | self.dp_group = dist.new_subgroups_by_enumeration([self.grid[:, p, t].tolist() for p in range(pp_size) for t in range(tp_size)])[0] 21 | self.pp_dp_group = dist.new_subgroups_by_enumeration([self.grid[:, :, t].flatten().tolist() for t in range(tp_size)])[0] 22 | 23 | self.world_group = dist.group.WORLD 24 | 25 | # Update group IDs with new grid ordering 26 | self.tp_group_ids = self.grid[self.dp_rank, self.pp_rank, :].tolist() 27 | self.pp_group_ids = self.grid[self.dp_rank, :, self.tp_rank].tolist() 28 | self.dp_group_ids = self.grid[:, self.pp_rank, self.tp_rank].tolist() 29 | 30 | # Tensor parallelism 31 | self.tp_world_size = dist.get_world_size(group=self.tp_group) 32 | self.tp_first_rank = self.tp_group_ids[0] 33 | self.tp_last_rank = self.tp_group_ids[-1] 34 | 35 | # Pipeline parallelism 36 | self.pp_world_size = dist.get_world_size(group=self.pp_group) 37 | self.pp_first_rank = self.pp_group_ids[0] 38 | self.pp_last_rank = self.pp_group_ids[-1] 39 | self.pp_is_first_stage = self.pp_rank == 0 40 | self.pp_is_last_stage = self.pp_rank == self.pp_world_size - 1 41 | self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else int(self.grid[self.dp_rank, self.pp_rank + 1, self.tp_rank].item()) 42 | self.pp_prev_rank = None if self.pp_rank == 0 else int(self.grid[self.dp_rank, self.pp_rank - 1, self.tp_rank].item()) 43 | 44 | # Data parallelism 45 | self.dp_world_size = dist.get_world_size(group=self.dp_group) 46 | self.dp_first_rank = self.dp_group_ids[0] 47 | self.dp_last_rank = self.dp_group_ids[-1] 48 | 49 | def __str__(self): 50 | return f"DP({self.dp_world_size})-PP({self.pp_world_size})-TP({self.tp_world_size})-Rank({self.global_rank})" 51 | 52 | def setup_process_group_manager(dp_size, pp_size, tp_size): 53 | global process_group_manager 54 | process_group_manager = ProcessGroupManager(dp_size, pp_size, tp_size) -------------------------------------------------------------------------------- /step2_process_group_manager/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | torchrun --nproc_per_node 2 train.py --tp_size 2 --run_name process_group_manager --use_wandb 3 | """ 4 | import os 5 | import wandb 6 | import datetime 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.distributed as dist 10 | import argparse 11 | from torch.optim import AdamW 12 | from transformers import AutoConfig 13 | 14 | from model import Llama 15 | import process_group_manager as pgm 16 | from process_group_manager import setup_process_group_manager 17 | from utils import 
set_all_seed, print 18 | 19 | if __name__ == "__main__": 20 | parser = argparse.ArgumentParser(description="Training script for LLaMA model") 21 | 22 | # Environment arguments 23 | parser.add_argument("--omp_num_threads", type=str, default="1") 24 | parser.add_argument("--tokenizers_parallelism", type=str, default="false") 25 | 26 | # Model arguments 27 | parser.add_argument("--model_name", type=str, default="HuggingFaceTB/SmolLM-360M-Instruct") 28 | parser.add_argument("--num_hidden_layers", type=int, default=32) 29 | parser.add_argument("--num_attention_heads", type=int, default=16) 30 | parser.add_argument("--num_key_value_heads", type=int, default=4) 31 | 32 | # Training arguments 33 | parser.add_argument("--seed", type=int, default=42) 34 | parser.add_argument("--learning_rate", type=float, default=3e-4) 35 | parser.add_argument("--seq_len", type=int, default=32) 36 | parser.add_argument("--micro_batch_size", type=int, default=1) 37 | 38 | # Distributed training arguments 39 | parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size") 40 | parser.add_argument("--dp_size", type=int, default=1, help="Data Parallel size") 41 | parser.add_argument("--pp_size", type=int, default=1, help="Pipeline Parallel size") 42 | parser.add_argument("--pp_engine", type=str, default="afab", choices=["1f1b", "afab"]) 43 | 44 | # Logging arguments 45 | parser.add_argument("--run_name", type=str, default="default_run") 46 | parser.add_argument("--use_wandb", action="store_true") 47 | 48 | args = parser.parse_args() 49 | 50 | # Set environment variables 51 | os.environ["OMP_NUM_THREADS"] = args.omp_num_threads 52 | os.environ["TOKENIZERS_PARALLELISM"] = args.tokenizers_parallelism 53 | os.environ["DEVICE"] = "cuda" 54 | 55 | local_rank = int(os.environ["LOCAL_RANK"]) 56 | global_rank = int(os.environ["RANK"]) 57 | world_size = int(os.environ["WORLD_SIZE"]) 58 | backend = "nccl" 59 | torch.cuda.set_device(local_rank) 60 | device = torch.device("cuda", local_rank) 61 | dtype = torch.bfloat16 62 | 63 | dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=2)) 64 | setup_process_group_manager(dp_size=args.dp_size, pp_size=args.pp_size, tp_size=args.tp_size) 65 | 66 | is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.pp_is_last_stage 67 | set_all_seed(args.seed) 68 | 69 | if is_wandb_rank and args.use_wandb: 70 | wandb.init( 71 | project="picotron_tutorial", 72 | name=f"{args.run_name}_{pgm.process_group_manager}", 73 | config={ 74 | "tensor_parallel_size": pgm.process_group_manager.tp_world_size, 75 | "pipeline_parallel_size": pgm.process_group_manager.pp_world_size, 76 | "data_parallel_size": pgm.process_group_manager.dp_world_size, 77 | "model": args.model_name, 78 | "learning_rate": args.learning_rate, 79 | "seed": args.seed, 80 | }, 81 | ) 82 | 83 | model_config = AutoConfig.from_pretrained(args.model_name) 84 | model_config.num_hidden_layers = args.num_hidden_layers 85 | model_config.num_attention_heads = args.num_attention_heads 86 | model_config.num_key_value_heads = args.num_key_value_heads 87 | model_config.max_position_embeddings = args.seq_len 88 | 89 | model = Llama(config=model_config) 90 | model.to(dtype).to(device) 91 | model.train() 92 | 93 | dist.barrier() 94 | 95 | optimizer = AdamW(model.parameters(), lr=args.learning_rate) 96 | 97 | dist.barrier() 98 | 99 | # Create dummy data 100 | input_ids = 
torch.randint(0, model_config.vocab_size, (args.micro_batch_size, args.seq_len), device=device) 101 | target_ids = torch.randint(0, model_config.vocab_size, (args.micro_batch_size, args.seq_len), device=device) 102 | 103 | # Training step 104 | optimizer.zero_grad() 105 | 106 | # Forward pass 107 | outputs = model(input_ids=input_ids) 108 | 109 | # Compute loss 110 | target_ids = target_ids.reshape(-1) 111 | outputs = outputs.view(-1, model_config.vocab_size) 112 | loss = F.cross_entropy(outputs, target_ids) 113 | 114 | # Backward pass 115 | loss.backward() 116 | 117 | # Optimizer step 118 | optimizer.step() 119 | 120 | print(f"[rank {pgm.process_group_manager.global_rank}], Loss: {loss:.4f}") 121 | 122 | if is_wandb_rank and args.use_wandb: 123 | wandb.log({"loss": loss.item()}) 124 | 125 | if is_wandb_rank and args.use_wandb: 126 | wandb.finish() 127 | 128 | dist.destroy_process_group() -------------------------------------------------------------------------------- /step2_process_group_manager/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import builtins 5 | import fcntl 6 | 7 | def print(*args, is_print_rank=True, **kwargs): 8 | """ solves multi-process interleaved print problem """ 9 | if not is_print_rank: return 10 | with open(__file__, "r") as fh: 11 | fcntl.flock(fh, fcntl.LOCK_EX) 12 | try: 13 | builtins.print(*args, **kwargs) 14 | finally: 15 | fcntl.flock(fh, fcntl.LOCK_UN) 16 | 17 | def set_all_seed(seed): 18 | for module in [random, np.random]: module.seed(seed) 19 | torch.manual_seed(seed) 20 | if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) 21 | 22 | def to_readable_format(num, precision=3): 23 | num_str = str(num) 24 | length = len(num_str) 25 | 26 | def format_with_precision(main, decimal, suffix): 27 | if precision == 0: 28 | return f"{main}{suffix}" 29 | return f"{main}.{decimal[:precision]}{suffix}" 30 | 31 | if length > 12: # Trillions 32 | return format_with_precision(num_str[:-12], num_str[-12:], 'T') 33 | elif length > 9: # Billions 34 | return format_with_precision(num_str[:-9], num_str[-9:], 'B') 35 | elif length > 6: # Millions 36 | return format_with_precision(num_str[:-6], num_str[-6:], 'M') 37 | elif length > 3: # Thousands 38 | return format_with_precision(num_str[:-3], num_str[-3:], 'K') 39 | else: 40 | return num_str -------------------------------------------------------------------------------- /step3_dataloader/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import numpy as np 4 | from functools import partial 5 | from datasets import Features, Sequence, Value, load_dataset 6 | from transformers import AutoTokenizer 7 | 8 | import process_group_manager as pgm 9 | 10 | class MicroBatchDataLoader(DataLoader): 11 | def __init__(self, seq_len, micro_batch_size, grad_acc_steps, dataset_name, tokenizer_name, max_tokens, num_workers, num_proc, split="train"): 12 | 13 | self.micro_batch_size = micro_batch_size 14 | self.grad_acc_steps = grad_acc_steps 15 | self.seq_len = seq_len 16 | 17 | self.global_batch_size = micro_batch_size * grad_acc_steps * pgm.process_group_manager.dp_world_size 18 | 19 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 20 | self.dataset = load_dataset(dataset_name, split=split) 21 | 22 | # Tokenize and chunk the dataset 23 | self.tokenized_dataset = self.tokenize_dataset(self.dataset, 
"text", self.seq_len, num_proc) 24 | 25 | total_tokens = self.tokenized_dataset.num_rows * (self.seq_len + 1) 26 | assert total_tokens >= max_tokens, f"Not enough tokens. Have {total_tokens} tokens but need {max_tokens} tokens" 27 | 28 | super().__init__( 29 | self.tokenized_dataset, 30 | batch_size=micro_batch_size, 31 | collate_fn=self.collate_batch, 32 | pin_memory=True, 33 | num_workers=num_workers, 34 | shuffle=False, 35 | ) 36 | 37 | def tokenizer_group_text(self, examples, tokenizer, sequence_length): 38 | """Tokenize a list of texts and group them in chunks of sequence_length + 1""" 39 | tokenized_text_batch = tokenizer.batch_encode_plus( 40 | examples, 41 | return_attention_mask=False, 42 | return_token_type_ids=False, 43 | return_tensors='np' 44 | ) 45 | concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])} 46 | total_length = len(concatenated_tokens['input_ids']) 47 | 48 | if total_length >= sequence_length + 1: 49 | total_length = ((total_length - 1) // sequence_length) * sequence_length + 1 50 | 51 | result = { 52 | 'input_ids': [ 53 | concatenated_tokens['input_ids'][i : i + sequence_length + 1] 54 | for i in range(0, total_length - sequence_length, sequence_length) 55 | ] 56 | } 57 | return result 58 | 59 | def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc): 60 | """Tokenize the dataset and group texts in chunks of sequence_length + 1""" 61 | tokenizer_func = partial( 62 | self.tokenizer_group_text, 63 | tokenizer=self.tokenizer, 64 | sequence_length=sequence_length 65 | ) 66 | 67 | tokenized_dataset = dataset.map( 68 | tokenizer_func, 69 | input_columns=text_column_name, 70 | remove_columns=dataset.column_names, 71 | features=Features({ 72 | "input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1) 73 | }), 74 | batched=True, 75 | num_proc=num_proc, 76 | load_from_cache_file=True, # Preprocess dataset only once and cache it 77 | desc=f"Grouping texts in chunks of {sequence_length+1}", 78 | ) 79 | 80 | return tokenized_dataset 81 | 82 | def collate_batch(self, batch): 83 | batch_input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch]) 84 | batch_size = batch_input_ids.size(0) 85 | input_ids = batch_input_ids[:, :-1].contiguous() 86 | target_ids = batch_input_ids[:, 1:].contiguous() 87 | position_ids = torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous() 88 | attn_mask = torch.tril(torch.ones((self.seq_len, self.seq_len), dtype=torch.bool)) 89 | attn_mask = attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous() 90 | 91 | return { 92 | "input_ids": input_ids, 93 | "target_ids": target_ids, 94 | "position_ids": position_ids, 95 | "attn_mask": attn_mask, 96 | "hidden_states": None 97 | } 98 | 99 | def __iter__(self): 100 | if self._iterator is None: 101 | self._iterator = super().__iter__() 102 | return self 103 | 104 | def __next__(self): 105 | if self._iterator is None: 106 | self._iterator = super().__iter__() 107 | try: 108 | batch = next(self._iterator) 109 | except StopIteration: 110 | self._iterator = None 111 | raise StopIteration 112 | return batch -------------------------------------------------------------------------------- /step3_dataloader/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from flash_attn.flash_attn_interface import flash_attn_func 5 | from flash_attn.layers.rotary import 
apply_rotary_emb 6 | from flash_attn.ops.triton.layer_norm import layer_norm_fn 7 | 8 | def flash_attention(q, k, v, causal = True): 9 | q = q.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 10 | k = k.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 11 | v = v.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 12 | return flash_attn_func(q, k, v, causal=causal) 13 | 14 | def get_cos_sin(seq_length, head_dim, base=500000.0): 15 | assert head_dim%2==0 16 | # Results on CUDA and CPU are different even with the same formula, To match transformers implementation. frequency should be computed on CPU 17 | theta = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float().to('cpu') / head_dim)) 18 | dtype = torch.bfloat16 19 | device = torch.device('cuda') 20 | position = torch.arange(seq_length).to(device).unsqueeze(1).float() # [seq_length, 1] 21 | # To match transformers implementation. m * theta should be computed on GPU 22 | theta = theta.to(device) 23 | return torch.cos(position.float()*theta.float()).to(dtype).repeat(1,2), torch.sin(position.float()*theta.float()).to(dtype).repeat(1,2) # [seq_length, head_dim], [seq_length, head_dim] 24 | 25 | class TritonRMSNorm(nn.Module): 26 | def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): 27 | super().__init__() 28 | self.eps = eps 29 | self.weight = nn.Parameter(torch.ones(hidden_size)) 30 | self.register_parameter("bias", None) 31 | 32 | def forward( 33 | self, hidden_states, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False 34 | ): 35 | return layer_norm_fn( 36 | hidden_states, 37 | self.weight, 38 | None, 39 | residual=residual, 40 | eps=self.eps, 41 | dropout_p=dropout_p, 42 | prenorm=prenorm, 43 | residual_in_fp32=residual_in_fp32, 44 | is_rms_norm=True, 45 | return_dropout_mask=return_dropout_mask, 46 | ) 47 | 48 | class Attention(nn.Module): 49 | def __init__(self, config, layer_idx): 50 | super().__init__() 51 | self.hidden_size = config.hidden_size 52 | self.num_heads = config.num_attention_heads 53 | self.num_key_values = config.num_key_value_heads 54 | self.head_dim = self.hidden_size//self.num_heads 55 | self.num_local_heads = config.num_attention_heads 56 | self.num_local_kv_heads = config.num_key_value_heads 57 | 58 | self.q_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim, bias=False) 59 | self.k_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False) 60 | self.v_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False) 61 | self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) 62 | self.layer_idx = layer_idx 63 | 64 | def forward(self, x, cos, sin, attention_mask=None, position_ids=None): 65 | batch_size, seq_length, hidden_dim = x.size() 66 | q = self.q_proj(x) # [batch_size, seq_length, num_heads*head_dim] 67 | k = self.k_proj(x) # [batch_size, seq_length, num_key_values*head_dim] 68 | v = self.v_proj(x) # [batch_size, seq_length, num_key_values*head_dim] 69 | 70 | q = q.view(batch_size, seq_length, self.num_local_heads, self.head_dim) # [batch_size, seq_length, num_heads, head_dim] 71 | k = k.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim) # [batch_size, seq_length, num_key_values, head_dim] 72 | q = apply_rotary_emb(q,cos[:, :self.head_dim // 2], sin[:, :self.head_dim // 2],interleaved=False) # [batch_size, seq_length, num_heads, head_dim] 73 | k = apply_rotary_emb(k,cos[:, 
:self.head_dim // 2], sin[:, :self.head_dim // 2],interleaved=False) # [batch_size, seq_length, num_key_values, head_dim] 74 | q = q.transpose(1, 2) # [batch_size, num_heads, seq_length, head_dim] 75 | k = k.transpose(1, 2) # [batch_size, num_key_values, seq_length, head_dim] 76 | v = v.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim).transpose(1,2) # [batch_size, num_key_values, seq_length, head_dim] 77 | 78 | k = k.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, head_dim] 79 | v = v.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, head_dim] 80 | 81 | causal = True if q.size(2) == k.size(2) else False # During decoding phase. The lenghth of q is usually 1. 82 | 83 | out = flash_attention(q, k, v, causal = causal) # [batch_size, seq_length, num_heads, head_dim] 84 | 85 | out = out.reshape(batch_size, seq_length, self.num_local_heads * self.head_dim) # [batch_size, seq_length, hidden_dim] 86 | out = self.out_proj(out) # [batch_size, seq_length, hidden_dim] 87 | return out 88 | 89 | class MLP(nn.Module): 90 | def __init__(self, config) -> None: 91 | super().__init__() 92 | self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) 93 | self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) 94 | self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) 95 | 96 | def forward(self, x): 97 | #TODO: dont do single line operations as it is harder to debug 98 | return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) 99 | 100 | class DecoderLayer(nn.Module): 101 | # TritonRMSNorm -> Attention -> Residual -> TritonRMSNorm -> MLP -> Residual 102 | def __init__(self, config, layer_idx): 103 | super().__init__() 104 | self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 105 | self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 106 | self.attention = Attention(config, layer_idx = layer_idx) 107 | self.mlp = MLP(config) 108 | self.layer_idx = layer_idx 109 | head_dim = config.hidden_size // config.num_attention_heads 110 | self.cos, self.sin = get_cos_sin(config.max_position_embeddings, head_dim=head_dim , base=config.rope_theta) # [max_position_embeddings, head_dim] 111 | 112 | def forward(self, x, attention_mask = None, position_ids = None): 113 | cos, sin = self.cos, self.sin 114 | x = x + self.attention(self.input_layernorm(x), cos, sin, attention_mask, position_ids) # Attention 115 | x = x + self.mlp(self.post_attention_layernorm(x)) # MLP 116 | return x 117 | 118 | class Llama(nn.Module): 119 | def __init__(self, config) -> None: 120 | super().__init__() 121 | # sanity check 122 | assert config.hidden_size % config.num_attention_heads==0 123 | assert config.num_attention_heads % config.num_key_value_heads==0 124 | 125 | # params 126 | self.vocab_size = config.vocab_size 127 | self.hidden_size = config.hidden_size 128 | self.num_heads = config.num_attention_heads 129 | self.num_key_values = config.num_key_value_heads 130 | self.head_dim = self.hidden_size//self.num_heads 131 | self.max_position_embeddings = config.max_position_embeddings 132 | self.num_layers = config.num_hidden_layers 133 | self.model_config = config 134 | 135 | # modules 136 | self.embedding = nn.Embedding(self.vocab_size, self.hidden_size) 137 | self.decoder_layers = nn.ModuleList([DecoderLayer(config,layer_idx = i) for i in 
range(self.num_layers)]) 138 | self.final_proj = nn.Linear(self.hidden_size, self.vocab_size, bias=False) 139 | self.final_norm = TritonRMSNorm(self.hidden_size, eps=config.rms_norm_eps) 140 | 141 | def forward(self, input_ids, attention_mask=None, position_ids: torch.Tensor = None): 142 | x = self.embedding(input_ids) 143 | for layer in self.decoder_layers: 144 | x = layer(x) # [batch_size, seq_length, hidden_dim] 145 | x = self.final_norm(x) 146 | logits = self.final_proj(x) 147 | 148 | return logits # [batch_size, seq_length, vocab_size] -------------------------------------------------------------------------------- /step3_dataloader/process_group_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | 5 | class ProcessGroupManager: 6 | def __init__(self, dp_size, pp_size, tp_size): 7 | self.global_rank = dist.get_rank() 8 | self.world_size = dist.get_world_size() 9 | self.local_rank = int(os.environ.get("LOCAL_RANK", self.global_rank % self.world_size)) 10 | 11 | assert self.world_size == dp_size * pp_size * tp_size, f"World size ({self.world_size}) != DP ({self.dp_size}) * PP ({self.pp_size}) * TP ({self.tp_size})" 12 | 13 | self.grid = torch.arange(self.world_size).view(dp_size, pp_size, tp_size) # DP * PP * TP grid 14 | # Find the position of the current process in the grid 15 | self.dp_rank, self.pp_rank, self.tp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist() 16 | 17 | # Process group creation - Update indexing to match new grid order 18 | self.tp_group = dist.new_subgroups_by_enumeration([self.grid[d, p, :].tolist() for d in range(dp_size) for p in range(pp_size)])[0] 19 | self.pp_group = dist.new_subgroups_by_enumeration([self.grid[d, :, t].tolist() for d in range(dp_size) for t in range(tp_size)])[0] 20 | self.dp_group = dist.new_subgroups_by_enumeration([self.grid[:, p, t].tolist() for p in range(pp_size) for t in range(tp_size)])[0] 21 | self.pp_dp_group = dist.new_subgroups_by_enumeration([self.grid[:, :, t].flatten().tolist() for t in range(tp_size)])[0] 22 | 23 | self.world_group = dist.group.WORLD 24 | 25 | # Update group IDs with new grid ordering 26 | self.tp_group_ids = self.grid[self.dp_rank, self.pp_rank, :].tolist() 27 | self.pp_group_ids = self.grid[self.dp_rank, :, self.tp_rank].tolist() 28 | self.dp_group_ids = self.grid[:, self.pp_rank, self.tp_rank].tolist() 29 | 30 | # Tensor parallelism 31 | self.tp_world_size = dist.get_world_size(group=self.tp_group) 32 | self.tp_first_rank = self.tp_group_ids[0] 33 | self.tp_last_rank = self.tp_group_ids[-1] 34 | 35 | # Pipeline parallelism 36 | self.pp_world_size = dist.get_world_size(group=self.pp_group) 37 | self.pp_first_rank = self.pp_group_ids[0] 38 | self.pp_last_rank = self.pp_group_ids[-1] 39 | self.pp_is_first_stage = self.pp_rank == 0 40 | self.pp_is_last_stage = self.pp_rank == self.pp_world_size - 1 41 | self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else int(self.grid[self.dp_rank, self.pp_rank + 1, self.tp_rank].item()) 42 | self.pp_prev_rank = None if self.pp_rank == 0 else int(self.grid[self.dp_rank, self.pp_rank - 1, self.tp_rank].item()) 43 | 44 | # Data parallelism 45 | self.dp_world_size = dist.get_world_size(group=self.dp_group) 46 | self.dp_first_rank = self.dp_group_ids[0] 47 | self.dp_last_rank = self.dp_group_ids[-1] 48 | 49 | def __str__(self): 50 | return 
f"DP({self.dp_world_size})-PP({self.pp_world_size})-TP({self.tp_world_size})-Rank({self.global_rank})" 51 | 52 | def setup_process_group_manager(dp_size, pp_size, tp_size): 53 | global process_group_manager 54 | process_group_manager = ProcessGroupManager(dp_size, pp_size, tp_size) -------------------------------------------------------------------------------- /step3_dataloader/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | torchrun --nproc_per_node 1 train.py --micro_batch_size 4 --gradient_accumulation_steps 8 --seq_len 128 --max_tokens 40960 --num_proc 16 --run_name dataloader --use_wandb 3 | """ 4 | import os 5 | import time 6 | import wandb 7 | import datetime 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.distributed as dist 11 | import argparse 12 | from torch.optim import AdamW 13 | from transformers import AutoConfig 14 | 15 | import lovely_tensors as lt; lt.monkey_patch() 16 | 17 | from model import Llama 18 | from dataloader import MicroBatchDataLoader 19 | import process_group_manager as pgm 20 | from process_group_manager import setup_process_group_manager 21 | from utils import set_all_seed, print, to_readable_format 22 | 23 | def train_step(model, dataloader, device): 24 | acc_loss = 0.0 25 | 26 | for i in range(dataloader.grad_acc_steps): 27 | # get the next batch 28 | batch = next(dataloader) 29 | input_ids = batch["input_ids"].to(device) 30 | target_ids = batch["target_ids"].to(device) 31 | 32 | outputs = model(input_ids=input_ids) 33 | 34 | # compute the loss 35 | batch_size, seq_len = input_ids.shape 36 | target_ids = target_ids.reshape(-1) 37 | outputs = outputs.view(seq_len*batch_size, -1) 38 | loss = F.cross_entropy(outputs, target_ids, reduction='mean') / dataloader.grad_acc_steps 39 | 40 | loss.backward() 41 | 42 | acc_loss += loss.item() 43 | 44 | return acc_loss 45 | 46 | if __name__ == "__main__": 47 | parser = argparse.ArgumentParser(description="Training script for LLaMA model") 48 | 49 | # Environment arguments 50 | parser.add_argument("--omp_num_threads", type=str, default="1") 51 | parser.add_argument("--tokenizers_parallelism", type=str, default="false") 52 | 53 | # Model arguments 54 | parser.add_argument("--model_name", type=str, default="HuggingFaceTB/SmolLM-360M-Instruct") 55 | parser.add_argument("--num_hidden_layers", type=int, default=32) 56 | parser.add_argument("--num_attention_heads", type=int, default=16) 57 | parser.add_argument("--num_key_value_heads", type=int, default=4) 58 | 59 | # Dataset arguments 60 | parser.add_argument("--dataset_name", type=str, default="roneneldan/TinyStories") 61 | parser.add_argument("--num_workers", type=int, default=1) 62 | parser.add_argument("--num_proc", type=int, default=4) 63 | 64 | # Training arguments 65 | parser.add_argument("--seed", type=int, default=42) 66 | parser.add_argument("--learning_rate", type=float, default=3e-4) 67 | parser.add_argument("--seq_len", type=int, default=32) 68 | parser.add_argument("--micro_batch_size", type=int, default=1) 69 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 70 | parser.add_argument("--max_tokens", type=int, default=1e6) 71 | 72 | # Distributed training arguments 73 | parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size") 74 | parser.add_argument("--dp_size", type=int, default=1, help="Data Parallel size") 75 | parser.add_argument("--pp_size", type=int, default=1, help="Pipeline Parallel size") 76 | 
parser.add_argument("--pp_engine", type=str, default="afab", choices=["1f1b", "afab"]) 77 | 78 | # Logging arguments 79 | parser.add_argument("--run_name", type=str, default="default_run") 80 | parser.add_argument("--use_wandb", action="store_true") 81 | 82 | args = parser.parse_args() 83 | 84 | # Set environment variables 85 | os.environ["OMP_NUM_THREADS"] = args.omp_num_threads 86 | os.environ["TOKENIZERS_PARALLELISM"] = args.tokenizers_parallelism 87 | os.environ["DEVICE"] = "cuda" 88 | 89 | local_rank = int(os.environ["LOCAL_RANK"]) 90 | global_rank = int(os.environ["RANK"]) 91 | world_size = int(os.environ["WORLD_SIZE"]) 92 | backend = "nccl" 93 | torch.cuda.set_device(local_rank) 94 | device = torch.device("cuda", local_rank) 95 | dtype = torch.bfloat16 96 | 97 | dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=2)) 98 | setup_process_group_manager(dp_size=args.dp_size, pp_size=args.pp_size, tp_size=args.tp_size) 99 | 100 | is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.pp_is_last_stage 101 | set_all_seed(args.seed) 102 | 103 | if is_wandb_rank and args.use_wandb: 104 | wandb.init( 105 | project="picotron_tutorial", 106 | name=f"{args.run_name}_{pgm.process_group_manager}", 107 | config={ 108 | "tensor_parallel_size": pgm.process_group_manager.tp_world_size, 109 | "pipeline_parallel_size": pgm.process_group_manager.pp_world_size, 110 | "data_parallel_size": pgm.process_group_manager.dp_world_size, 111 | "model": args.model_name, 112 | "learning_rate": args.learning_rate, 113 | "seed": args.seed, 114 | }, 115 | ) 116 | 117 | model_config = AutoConfig.from_pretrained(args.model_name) 118 | model_config.num_hidden_layers = args.num_hidden_layers 119 | model_config.num_attention_heads = args.num_attention_heads 120 | model_config.num_key_value_heads = args.num_key_value_heads 121 | model_config.max_position_embeddings = args.seq_len 122 | 123 | model = Llama(config=model_config) 124 | model.to(dtype).to(device) 125 | model.train() 126 | 127 | dist.barrier() 128 | 129 | optimizer = AdamW(model.parameters(), lr=args.learning_rate) 130 | 131 | dist.barrier() 132 | 133 | # Create dataloader 134 | dataloader = MicroBatchDataLoader( 135 | seq_len=args.seq_len, 136 | micro_batch_size=args.micro_batch_size, 137 | grad_acc_steps=args.gradient_accumulation_steps, 138 | dataset_name=args.dataset_name, 139 | tokenizer_name=args.model_name, 140 | max_tokens=args.max_tokens, 141 | num_workers=args.num_workers, 142 | num_proc=args.num_proc, 143 | ) 144 | 145 | tokens_per_step = dataloader.global_batch_size * args.seq_len 146 | if pgm.process_group_manager.global_rank == 0: 147 | print("Tokens per step:", to_readable_format(tokens_per_step), is_print_rank=is_wandb_rank) 148 | 149 | trained_token, step = 0, 0 150 | 151 | dist.barrier() 152 | 153 | # Training loop 154 | while trained_token < args.max_tokens: 155 | 156 | step_start_time = time.time() 157 | optimizer.zero_grad() 158 | 159 | loss = train_step(model, dataloader, device) 160 | 161 | optimizer.step() 162 | 163 | step_duration = time.time() - step_start_time 164 | trained_token += tokens_per_step 165 | step += 1 166 | 167 | print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, " 168 | f"Global batch size (with seq_len): {to_readable_format(tokens_per_step)}, " 169 | f"Tokens/s: {to_readable_format(tokens_per_step / step_duration)}, " 170 | 
f"Tokens/s/GPU: {to_readable_format(tokens_per_step / step_duration / world_size)}, " 171 | f"Tokens: {to_readable_format(trained_token)}{('/' + to_readable_format(args.max_tokens))}, " 172 | f"Memory usage: {torch.cuda.memory_reserved() / 1e9:.2f}GB" 173 | , is_print_rank=is_wandb_rank 174 | ) 175 | 176 | if is_wandb_rank and args.use_wandb: 177 | wandb.log({"loss": loss, "tokens_per_step": tokens_per_step, "tokens_per_second": tokens_per_step / step_duration,\ 178 | "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": tokens_per_step}) 179 | 180 | if is_wandb_rank and args.use_wandb: 181 | wandb.finish() 182 | 183 | dist.destroy_process_group() -------------------------------------------------------------------------------- /step3_dataloader/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import builtins 5 | import fcntl 6 | 7 | def print(*args, is_print_rank=True, **kwargs): 8 | """ solves multi-process interleaved print problem """ 9 | if not is_print_rank: return 10 | with open(__file__, "r") as fh: 11 | fcntl.flock(fh, fcntl.LOCK_EX) 12 | try: 13 | builtins.print(*args, **kwargs) 14 | finally: 15 | fcntl.flock(fh, fcntl.LOCK_UN) 16 | 17 | def set_all_seed(seed): 18 | for module in [random, np.random]: module.seed(seed) 19 | torch.manual_seed(seed) 20 | if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) 21 | 22 | def to_readable_format(num, precision=3): 23 | num_str = str(num) 24 | length = len(num_str) 25 | 26 | def format_with_precision(main, decimal, suffix): 27 | if precision == 0: 28 | return f"{main}{suffix}" 29 | return f"{main}.{decimal[:precision]}{suffix}" 30 | 31 | if length > 12: # Trillions 32 | return format_with_precision(num_str[:-12], num_str[-12:], 'T') 33 | elif length > 9: # Billions 34 | return format_with_precision(num_str[:-9], num_str[-9:], 'B') 35 | elif length > 6: # Millions 36 | return format_with_precision(num_str[:-6], num_str[-6:], 'M') 37 | elif length > 3: # Thousands 38 | return format_with_precision(num_str[:-3], num_str[-3:], 'K') 39 | else: 40 | return num_str -------------------------------------------------------------------------------- /step4_tensor_parallel/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import numpy as np 4 | from functools import partial 5 | from datasets import Features, Sequence, Value, load_dataset 6 | from transformers import AutoTokenizer 7 | 8 | import process_group_manager as pgm 9 | 10 | class MicroBatchDataLoader(DataLoader): 11 | def __init__(self, seq_len, micro_batch_size, grad_acc_steps, dataset_name, tokenizer_name, max_tokens, num_workers, num_proc, split="train"): 12 | 13 | self.micro_batch_size = micro_batch_size 14 | self.grad_acc_steps = grad_acc_steps 15 | self.seq_len = seq_len 16 | 17 | self.global_batch_size = micro_batch_size * grad_acc_steps * pgm.process_group_manager.dp_world_size 18 | 19 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 20 | self.dataset = load_dataset(dataset_name, split=split) 21 | 22 | # Tokenize and chunk the dataset 23 | self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_len, num_proc) 24 | 25 | total_tokens = self.tokenized_dataset.num_rows * (self.seq_len + 1) 26 | assert total_tokens >= max_tokens, f"Not enough tokens. 
Have {total_tokens} tokens but need {max_tokens} tokens" 27 | 28 | super().__init__( 29 | self.tokenized_dataset, 30 | batch_size=micro_batch_size, 31 | collate_fn=self.collate_batch, 32 | pin_memory=True, 33 | num_workers=num_workers, 34 | shuffle=False, 35 | ) 36 | 37 | def tokenizer_group_text(self, examples, tokenizer, sequence_length): 38 | """Tokenize a list of texts and group them in chunks of sequence_length + 1""" 39 | tokenized_text_batch = tokenizer.batch_encode_plus( 40 | examples, 41 | return_attention_mask=False, 42 | return_token_type_ids=False, 43 | return_tensors='np' 44 | ) 45 | concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])} 46 | total_length = len(concatenated_tokens['input_ids']) 47 | 48 | if total_length >= sequence_length + 1: 49 | total_length = ((total_length - 1) // sequence_length) * sequence_length + 1 50 | 51 | result = { 52 | 'input_ids': [ 53 | concatenated_tokens['input_ids'][i : i + sequence_length + 1] 54 | for i in range(0, total_length - sequence_length, sequence_length) 55 | ] 56 | } 57 | return result 58 | 59 | def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc): 60 | """Tokenize the dataset and group texts in chunks of sequence_length + 1""" 61 | tokenizer_func = partial( 62 | self.tokenizer_group_text, 63 | tokenizer=self.tokenizer, 64 | sequence_length=sequence_length 65 | ) 66 | 67 | tokenized_dataset = dataset.map( 68 | tokenizer_func, 69 | input_columns=text_column_name, 70 | remove_columns=dataset.column_names, 71 | features=Features({ 72 | "input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1) 73 | }), 74 | batched=True, 75 | num_proc=num_proc, 76 | load_from_cache_file=True, # Preprocess dataset only once and cache it 77 | desc=f"Grouping texts in chunks of {sequence_length+1}", 78 | ) 79 | 80 | return tokenized_dataset 81 | 82 | def collate_batch(self, batch): 83 | batch_input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch]) 84 | batch_size = batch_input_ids.size(0) 85 | input_ids = batch_input_ids[:, :-1].contiguous() 86 | target_ids = batch_input_ids[:, 1:].contiguous() 87 | position_ids = torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous() 88 | attn_mask = torch.tril(torch.ones((self.seq_len, self.seq_len), dtype=torch.bool)) 89 | attn_mask = attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous() 90 | 91 | return { 92 | "input_ids": input_ids, 93 | "target_ids": target_ids, 94 | "position_ids": position_ids, 95 | "attn_mask": attn_mask, 96 | "hidden_states": None 97 | } 98 | 99 | def __iter__(self): 100 | if self._iterator is None: 101 | self._iterator = super().__iter__() 102 | return self 103 | 104 | def __next__(self): 105 | if self._iterator is None: 106 | self._iterator = super().__iter__() 107 | try: 108 | batch = next(self._iterator) 109 | except StopIteration: 110 | self._iterator = None 111 | raise StopIteration 112 | return batch -------------------------------------------------------------------------------- /step4_tensor_parallel/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from flash_attn.flash_attn_interface import flash_attn_func 5 | from flash_attn.layers.rotary import apply_rotary_emb 6 | from flash_attn.ops.triton.layer_norm import layer_norm_fn 7 | import process_group_manager as pgm 8 | 9 | def flash_attention(q, k, v, causal = 
True): 10 | q = q.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 11 | k = k.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 12 | v = v.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 13 | return flash_attn_func(q, k, v, causal=causal) 14 | 15 | def get_cos_sin(seq_length, head_dim, base=500000.0): 16 | assert head_dim%2==0 17 | # Results on CUDA and CPU are different even with the same formula, To match transformers implementation. frequency should be computed on CPU 18 | theta = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float().to('cpu') / head_dim)) 19 | dtype = torch.bfloat16 20 | device = torch.device('cuda') 21 | position = torch.arange(seq_length).to(device).unsqueeze(1).float() # [seq_length, 1] 22 | # To match transformers implementation. m * theta should be computed on GPU 23 | theta = theta.to(device) 24 | return torch.cos(position.float()*theta.float()).to(dtype).repeat(1,2), torch.sin(position.float()*theta.float()).to(dtype).repeat(1,2) # [seq_length, head_dim], [seq_length, head_dim] 25 | 26 | class TritonRMSNorm(nn.Module): 27 | def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): 28 | super().__init__() 29 | self.eps = eps 30 | self.weight = nn.Parameter(torch.ones(hidden_size)) 31 | self.register_parameter("bias", None) 32 | 33 | def forward( 34 | self, hidden_states, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False 35 | ): 36 | return layer_norm_fn( 37 | hidden_states, 38 | self.weight, 39 | None, 40 | residual=residual, 41 | eps=self.eps, 42 | dropout_p=dropout_p, 43 | prenorm=prenorm, 44 | residual_in_fp32=residual_in_fp32, 45 | is_rms_norm=True, 46 | return_dropout_mask=return_dropout_mask, 47 | ) 48 | 49 | class Attention(nn.Module): 50 | def __init__(self, config, layer_idx): 51 | super().__init__() 52 | self.hidden_size = config.hidden_size 53 | self.num_heads = config.num_attention_heads 54 | self.num_key_values = config.num_key_value_heads 55 | self.head_dim = self.hidden_size//self.num_heads 56 | assert config.num_attention_heads % pgm.process_group_manager.tp_world_size == 0, "num_attention_heads should be divisible by tp world size" 57 | assert config.num_key_value_heads % pgm.process_group_manager.tp_world_size == 0, "num_key_value_heads should be divisible by tp world size" 58 | self.num_local_heads = config.num_attention_heads // pgm.process_group_manager.tp_world_size # TP parallelism 59 | self.num_local_kv_heads = config.num_key_value_heads // pgm.process_group_manager.tp_world_size # TP parallelism 60 | 61 | 62 | self.q_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim, bias=False) 63 | self.k_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False) 64 | self.v_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False) 65 | self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) 66 | self.layer_idx = layer_idx 67 | 68 | def forward(self, x, cos, sin, attention_mask=None, position_ids=None): 69 | batch_size, seq_length, hidden_dim = x.size() 70 | q = self.q_proj(x) # [batch_size, seq_length, num_heads*head_dim] 71 | k = self.k_proj(x) # [batch_size, seq_length, num_key_values*head_dim] 72 | v = self.v_proj(x) # [batch_size, seq_length, num_key_values*head_dim] 73 | 74 | q = q.view(batch_size, seq_length, self.num_local_heads, self.head_dim) # [batch_size, seq_length, num_heads, head_dim] 75 | k = k.view(batch_size, seq_length, 
self.num_local_kv_heads, self.head_dim) # [batch_size, seq_length, num_key_values, head_dim] 76 | q = apply_rotary_emb(q,cos[:, :self.head_dim // 2], sin[:, :self.head_dim // 2],interleaved=False) # [batch_size, seq_length, num_heads, head_dim] 77 | k = apply_rotary_emb(k,cos[:, :self.head_dim // 2], sin[:, :self.head_dim // 2],interleaved=False) # [batch_size, seq_length, num_key_values, head_dim] 78 | q = q.transpose(1, 2) # [batch_size, num_heads, seq_length, head_dim] 79 | k = k.transpose(1, 2) # [batch_size, num_key_values, seq_length, head_dim] 80 | v = v.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim).transpose(1,2) # [batch_size, num_key_values, seq_length, head_dim] 81 | 82 | k = k.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, head_dim] 83 | v = v.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, head_dim] 84 | 85 | causal = True if q.size(2) == k.size(2) else False # During decoding phase. The lenghth of q is usually 1. 86 | 87 | out = flash_attention(q, k, v, causal = causal) # [batch_size, seq_length, num_heads, head_dim] 88 | 89 | out = out.reshape(batch_size, seq_length, self.num_local_heads * self.head_dim) # [batch_size, seq_length, hidden_dim] 90 | out = self.out_proj(out) # [batch_size, seq_length, hidden_dim] 91 | return out 92 | 93 | class MLP(nn.Module): 94 | def __init__(self, config) -> None: 95 | super().__init__() 96 | self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) 97 | self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) 98 | self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) 99 | 100 | def forward(self, x): 101 | #TODO: dont do single line operations as it is harder to debug 102 | return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) 103 | 104 | class DecoderLayer(nn.Module): 105 | # TritonRMSNorm -> Attention -> Residual -> TritonRMSNorm -> MLP -> Residual 106 | def __init__(self, config, layer_idx): 107 | super().__init__() 108 | self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 109 | self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 110 | self.attention = Attention(config, layer_idx = layer_idx) 111 | self.mlp = MLP(config) 112 | self.layer_idx = layer_idx 113 | head_dim = config.hidden_size // config.num_attention_heads 114 | self.cos, self.sin = get_cos_sin(config.max_position_embeddings, head_dim=head_dim , base=config.rope_theta) # [max_position_embeddings, head_dim] 115 | 116 | def forward(self, x, attention_mask = None, position_ids = None): 117 | cos, sin = self.cos, self.sin 118 | x = x + self.attention(self.input_layernorm(x), cos, sin, attention_mask, position_ids) # Attention 119 | x = x + self.mlp(self.post_attention_layernorm(x)) # MLP 120 | return x 121 | 122 | class Llama(nn.Module): 123 | def __init__(self, config) -> None: 124 | super().__init__() 125 | # sanity check 126 | assert config.hidden_size % config.num_attention_heads==0 127 | assert config.num_attention_heads % config.num_key_value_heads==0 128 | 129 | # params 130 | self.vocab_size = config.vocab_size 131 | self.hidden_size = config.hidden_size 132 | self.num_heads = config.num_attention_heads 133 | self.num_key_values = config.num_key_value_heads 134 | self.head_dim = self.hidden_size//self.num_heads 135 | self.max_position_embeddings = 
config.max_position_embeddings 136 | self.num_layers = config.num_hidden_layers 137 | self.model_config = config 138 | 139 | # modules 140 | self.embedding = nn.Embedding(self.vocab_size, self.hidden_size) 141 | self.decoder_layers = nn.ModuleList([DecoderLayer(config,layer_idx = i) for i in range(self.num_layers)]) 142 | self.final_proj = nn.Linear(self.hidden_size, self.vocab_size, bias=False) 143 | self.final_norm = TritonRMSNorm(self.hidden_size, eps=config.rms_norm_eps) 144 | 145 | def forward(self, input_ids, attention_mask=None, position_ids: torch.Tensor = None): 146 | x = self.embedding(input_ids) 147 | for layer in self.decoder_layers: 148 | x = layer(x) # [batch_size, seq_length, hidden_dim] 149 | x = self.final_norm(x) 150 | logits = self.final_proj(x) 151 | 152 | return logits # [batch_size, seq_length, vocab_size] -------------------------------------------------------------------------------- /step4_tensor_parallel/process_group_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | 5 | class ProcessGroupManager: 6 | def __init__(self, dp_size, pp_size, tp_size): 7 | self.global_rank = dist.get_rank() 8 | self.world_size = dist.get_world_size() 9 | self.local_rank = int(os.environ.get("LOCAL_RANK", self.global_rank % self.world_size)) 10 | 11 | assert self.world_size == dp_size * pp_size * tp_size, f"World size ({self.world_size}) != DP ({self.dp_size}) * PP ({self.pp_size}) * TP ({self.tp_size})" 12 | 13 | self.grid = torch.arange(self.world_size).view(dp_size, pp_size, tp_size) # DP * PP * TP grid 14 | # Find the position of the current process in the grid 15 | self.dp_rank, self.pp_rank, self.tp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist() 16 | 17 | # Process group creation - Update indexing to match new grid order 18 | self.tp_group = dist.new_subgroups_by_enumeration([self.grid[d, p, :].tolist() for d in range(dp_size) for p in range(pp_size)])[0] 19 | self.pp_group = dist.new_subgroups_by_enumeration([self.grid[d, :, t].tolist() for d in range(dp_size) for t in range(tp_size)])[0] 20 | self.dp_group = dist.new_subgroups_by_enumeration([self.grid[:, p, t].tolist() for p in range(pp_size) for t in range(tp_size)])[0] 21 | self.pp_dp_group = dist.new_subgroups_by_enumeration([self.grid[:, :, t].flatten().tolist() for t in range(tp_size)])[0] 22 | 23 | self.world_group = dist.group.WORLD 24 | 25 | # Update group IDs with new grid ordering 26 | self.tp_group_ids = self.grid[self.dp_rank, self.pp_rank, :].tolist() 27 | self.pp_group_ids = self.grid[self.dp_rank, :, self.tp_rank].tolist() 28 | self.dp_group_ids = self.grid[:, self.pp_rank, self.tp_rank].tolist() 29 | 30 | # Tensor parallelism 31 | self.tp_world_size = dist.get_world_size(group=self.tp_group) 32 | self.tp_first_rank = self.tp_group_ids[0] 33 | self.tp_last_rank = self.tp_group_ids[-1] 34 | 35 | # Pipeline parallelism 36 | self.pp_world_size = dist.get_world_size(group=self.pp_group) 37 | self.pp_first_rank = self.pp_group_ids[0] 38 | self.pp_last_rank = self.pp_group_ids[-1] 39 | self.pp_is_first_stage = self.pp_rank == 0 40 | self.pp_is_last_stage = self.pp_rank == self.pp_world_size - 1 41 | self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else int(self.grid[self.dp_rank, self.pp_rank + 1, self.tp_rank].item()) 42 | self.pp_prev_rank = None if self.pp_rank == 0 else int(self.grid[self.dp_rank, self.pp_rank - 1, self.tp_rank].item()) 43 | 44 | # Data 
parallelism 45 | self.dp_world_size = dist.get_world_size(group=self.dp_group) 46 | self.dp_first_rank = self.dp_group_ids[0] 47 | self.dp_last_rank = self.dp_group_ids[-1] 48 | 49 | def __str__(self): 50 | return f"DP({self.dp_world_size})-PP({self.pp_world_size})-TP({self.tp_world_size})-Rank({self.global_rank})" 51 | 52 | def setup_process_group_manager(dp_size, pp_size, tp_size): 53 | global process_group_manager 54 | process_group_manager = ProcessGroupManager(dp_size, pp_size, tp_size) -------------------------------------------------------------------------------- /step4_tensor_parallel/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | torchrun --nproc_per_node 4 train.py --tp_size 4 --micro_batch_size 4 --gradient_accumulation_steps 8 --seq_len 128 --max_tokens 40960 --num_proc 16 --run_name tp_naive --use_wandb 3 | """ 4 | import os 5 | import time 6 | import wandb 7 | import datetime 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.distributed as dist 11 | import argparse 12 | from torch.optim import AdamW 13 | from transformers import AutoConfig 14 | 15 | import lovely_tensors as lt; lt.monkey_patch() 16 | 17 | from model import Llama 18 | from dataloader import MicroBatchDataLoader 19 | import process_group_manager as pgm 20 | from process_group_manager import setup_process_group_manager 21 | from utils import set_all_seed, print, to_readable_format 22 | 23 | from tensor_parallel import apply_tensor_parallel 24 | 25 | def train_step(model, dataloader, device): 26 | acc_loss = 0.0 27 | 28 | for i in range(dataloader.grad_acc_steps): 29 | # get the next batch 30 | batch = next(dataloader) 31 | input_ids = batch["input_ids"].to(device) 32 | target_ids = batch["target_ids"].to(device) 33 | 34 | outputs = model(input_ids=input_ids) 35 | 36 | # compute the loss 37 | batch_size, seq_len = input_ids.shape 38 | target_ids = target_ids.reshape(-1) 39 | outputs = outputs.view(seq_len*batch_size, -1) 40 | loss = F.cross_entropy(outputs, target_ids, reduction='mean') / dataloader.grad_acc_steps 41 | 42 | loss.backward() 43 | 44 | acc_loss += loss.item() 45 | 46 | return acc_loss 47 | 48 | if __name__ == "__main__": 49 | parser = argparse.ArgumentParser(description="Training script for LLaMA model") 50 | 51 | # Environment arguments 52 | parser.add_argument("--omp_num_threads", type=str, default="1") 53 | parser.add_argument("--tokenizers_parallelism", type=str, default="false") 54 | 55 | # Model arguments 56 | parser.add_argument("--model_name", type=str, default="HuggingFaceTB/SmolLM-360M-Instruct") 57 | parser.add_argument("--num_hidden_layers", type=int, default=32) 58 | parser.add_argument("--num_attention_heads", type=int, default=16) 59 | parser.add_argument("--num_key_value_heads", type=int, default=4) 60 | 61 | # Dataset arguments 62 | parser.add_argument("--dataset_name", type=str, default="roneneldan/TinyStories") 63 | parser.add_argument("--num_workers", type=int, default=1) 64 | parser.add_argument("--num_proc", type=int, default=4) 65 | 66 | # Training arguments 67 | parser.add_argument("--seed", type=int, default=42) 68 | parser.add_argument("--learning_rate", type=float, default=3e-4) 69 | parser.add_argument("--seq_len", type=int, default=32) 70 | parser.add_argument("--micro_batch_size", type=int, default=1) 71 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 72 | parser.add_argument("--max_tokens", type=int, default=1e6) 73 | 74 | # Distributed training arguments 75 | 
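    # Note: tp_size must divide both num_attention_heads and num_key_value_heads; the
    # Attention module in model.py asserts this before sharding the heads across TP ranks.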
parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size") 76 | parser.add_argument("--dp_size", type=int, default=1, help="Data Parallel size") 77 | parser.add_argument("--pp_size", type=int, default=1, help="Pipeline Parallel size") 78 | parser.add_argument("--pp_engine", type=str, default="afab", choices=["1f1b", "afab"]) 79 | 80 | # Logging arguments 81 | parser.add_argument("--run_name", type=str, default="default_run") 82 | parser.add_argument("--use_wandb", action="store_true") 83 | 84 | args = parser.parse_args() 85 | 86 | # Set environment variables 87 | os.environ["OMP_NUM_THREADS"] = args.omp_num_threads 88 | os.environ["TOKENIZERS_PARALLELISM"] = args.tokenizers_parallelism 89 | os.environ["DEVICE"] = "cuda" 90 | 91 | local_rank = int(os.environ["LOCAL_RANK"]) 92 | global_rank = int(os.environ["RANK"]) 93 | world_size = int(os.environ["WORLD_SIZE"]) 94 | backend = "nccl" 95 | torch.cuda.set_device(local_rank) 96 | device = torch.device("cuda", local_rank) 97 | dtype = torch.bfloat16 98 | 99 | dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=2)) 100 | setup_process_group_manager(dp_size=args.dp_size, pp_size=args.pp_size, tp_size=args.tp_size) 101 | 102 | is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.pp_is_last_stage 103 | set_all_seed(args.seed) 104 | 105 | if is_wandb_rank and args.use_wandb: 106 | wandb.init( 107 | project="picotron_tutorial", 108 | name=f"{args.run_name}_{pgm.process_group_manager}", 109 | config={ 110 | "tensor_parallel_size": pgm.process_group_manager.tp_world_size, 111 | "pipeline_parallel_size": pgm.process_group_manager.pp_world_size, 112 | "data_parallel_size": pgm.process_group_manager.dp_world_size, 113 | "model": args.model_name, 114 | "learning_rate": args.learning_rate, 115 | "seed": args.seed, 116 | }, 117 | ) 118 | 119 | model_config = AutoConfig.from_pretrained(args.model_name) 120 | model_config.num_hidden_layers = args.num_hidden_layers 121 | model_config.num_attention_heads = args.num_attention_heads 122 | model_config.num_key_value_heads = args.num_key_value_heads 123 | model_config.max_position_embeddings = args.seq_len 124 | 125 | model = Llama(config=model_config) 126 | 127 | if pgm.process_group_manager.tp_world_size > 1: 128 | model = apply_tensor_parallel(model) 129 | 130 | model.to(dtype).to(device) 131 | model.train() 132 | 133 | dist.barrier() 134 | 135 | optimizer = AdamW(model.parameters(), lr=args.learning_rate) 136 | 137 | dist.barrier() 138 | 139 | # Create dataloader 140 | dataloader = MicroBatchDataLoader( 141 | seq_len=args.seq_len, 142 | micro_batch_size=args.micro_batch_size, 143 | grad_acc_steps=args.gradient_accumulation_steps, 144 | dataset_name=args.dataset_name, 145 | tokenizer_name=args.model_name, 146 | max_tokens=args.max_tokens, 147 | num_workers=args.num_workers, 148 | num_proc=args.num_proc, 149 | ) 150 | 151 | tokens_per_step = dataloader.global_batch_size * args.seq_len 152 | if pgm.process_group_manager.global_rank == 0: 153 | print("Tokens per step:", to_readable_format(tokens_per_step), is_print_rank=is_wandb_rank) 154 | 155 | trained_token, step = 0, 0 156 | 157 | dist.barrier() 158 | 159 | # Training loop 160 | while trained_token < args.max_tokens: 161 | 162 | step_start_time = time.time() 163 | optimizer.zero_grad() 164 | 165 | loss = train_step(model, dataloader, device) 166 | 167 | optimizer.step() 168 | 
169 | step_duration = time.time() - step_start_time 170 | trained_token += tokens_per_step 171 | step += 1 172 | 173 | print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, " 174 | f"Global batch size (with seq_len): {to_readable_format(tokens_per_step)}, " 175 | f"Tokens/s: {to_readable_format(tokens_per_step / step_duration)}, " 176 | f"Tokens/s/GPU: {to_readable_format(tokens_per_step / step_duration / world_size)}, " 177 | f"Tokens: {to_readable_format(trained_token)}{('/' + to_readable_format(args.max_tokens))}, " 178 | f"Memory usage: {torch.cuda.memory_reserved() / 1e9:.2f}GB" 179 | , is_print_rank=is_wandb_rank 180 | ) 181 | 182 | if is_wandb_rank and args.use_wandb: 183 | wandb.log({"loss": loss, "tokens_per_step": tokens_per_step, "tokens_per_second": tokens_per_step / step_duration,\ 184 | "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": tokens_per_step}) 185 | 186 | if is_wandb_rank and args.use_wandb: 187 | wandb.finish() 188 | 189 | dist.destroy_process_group() -------------------------------------------------------------------------------- /step4_tensor_parallel/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import builtins 5 | import fcntl 6 | 7 | def print(*args, is_print_rank=True, **kwargs): 8 | """ solves multi-process interleaved print problem """ 9 | if not is_print_rank: return 10 | with open(__file__, "r") as fh: 11 | fcntl.flock(fh, fcntl.LOCK_EX) 12 | try: 13 | builtins.print(*args, **kwargs) 14 | finally: 15 | fcntl.flock(fh, fcntl.LOCK_UN) 16 | 17 | def set_all_seed(seed): 18 | for module in [random, np.random]: module.seed(seed) 19 | torch.manual_seed(seed) 20 | if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) 21 | 22 | def to_readable_format(num, precision=3): 23 | num_str = str(num) 24 | length = len(num_str) 25 | 26 | def format_with_precision(main, decimal, suffix): 27 | if precision == 0: 28 | return f"{main}{suffix}" 29 | return f"{main}.{decimal[:precision]}{suffix}" 30 | 31 | if length > 12: # Trillions 32 | return format_with_precision(num_str[:-12], num_str[-12:], 'T') 33 | elif length > 9: # Billions 34 | return format_with_precision(num_str[:-9], num_str[-9:], 'B') 35 | elif length > 6: # Millions 36 | return format_with_precision(num_str[:-6], num_str[-6:], 'M') 37 | elif length > 3: # Thousands 38 | return format_with_precision(num_str[:-3], num_str[-3:], 'K') 39 | else: 40 | return num_str -------------------------------------------------------------------------------- /step5_data_parallel_naive/data_parallel.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from typing import List 3 | import torch 4 | import torch.distributed as dist 5 | from torch import nn 6 | 7 | import process_group_manager as pgm 8 | 9 | ### begin Data Parallel (naive) 10 | class DataParallelNaive(nn.Module): 11 | def __init__(self, module): 12 | super().__init__() 13 | self.module = module 14 | # whether to synchronize gradients during backward pass. 
Set to False when using gradient accumulation 15 | self.require_backward_grad_sync = True 16 | self.register_backward_hook(self._allreduce_grads) 17 | 18 | def forward(self, *inputs, **kwargs): 19 | return self.module(*inputs, **kwargs) 20 | 21 | def register_backward_hook(self, hook): 22 | """Registers a backward hook for all parameters of the model that require gradients.""" 23 | for p in self.module.parameters(): 24 | if p.requires_grad is True: 25 | p.register_hook(hook) 26 | 27 | def _allreduce_grads(self, grad): 28 | """Performs an all-reduce operation to synchronize gradients across multiple processes.""" 29 | # No synchronization needed during gradient accumulation, except at the final accumulation step. 30 | if self.require_backward_grad_sync: 31 | dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.dp_group) 32 | grad /= pgm.process_group_manager.dp_world_size 33 | return grad 34 | ### end Data Parallel (naive) -------------------------------------------------------------------------------- /step5_data_parallel_naive/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, DistributedSampler 3 | import numpy as np 4 | from functools import partial 5 | from datasets import Features, Sequence, Value, load_dataset 6 | from transformers import AutoTokenizer 7 | 8 | import process_group_manager as pgm 9 | 10 | class MicroBatchDataLoader(DataLoader): 11 | def __init__(self, seq_len, micro_batch_size, grad_acc_steps, dataset_name, tokenizer_name, max_tokens, num_workers, num_proc, seed, split="train"): 12 | 13 | self.micro_batch_size = micro_batch_size 14 | self.grad_acc_steps = grad_acc_steps 15 | self.seq_len = seq_len 16 | 17 | self.global_batch_size = micro_batch_size * grad_acc_steps * pgm.process_group_manager.dp_world_size 18 | 19 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 20 | self.dataset = load_dataset(dataset_name, split=split) 21 | 22 | # Tokenize and chunk the dataset 23 | self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_len, num_proc) 24 | 25 | total_tokens = self.tokenized_dataset.num_rows * (self.seq_len + 1) 26 | assert total_tokens >= max_tokens, f"Not enough tokens. 
Have {total_tokens} tokens but need {max_tokens} tokens" 27 | 28 | self.sampler = DistributedSampler( 29 | self.tokenized_dataset, 30 | num_replicas=pgm.process_group_manager.dp_world_size, 31 | rank=pgm.process_group_manager.dp_rank, 32 | seed=seed, 33 | shuffle=False 34 | ) 35 | 36 | super().__init__( 37 | self.tokenized_dataset, 38 | batch_size=micro_batch_size, 39 | collate_fn=self.collate_batch, 40 | pin_memory=True, 41 | num_workers=num_workers, 42 | sampler=self.sampler, 43 | shuffle=False, 44 | ) 45 | 46 | def tokenizer_group_text(self, examples, tokenizer, sequence_length): 47 | """Tokenize a list of texts and group them in chunks of sequence_length + 1""" 48 | tokenized_text_batch = tokenizer.batch_encode_plus( 49 | examples, 50 | return_attention_mask=False, 51 | return_token_type_ids=False, 52 | return_tensors='np' 53 | ) 54 | concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])} 55 | total_length = len(concatenated_tokens['input_ids']) 56 | 57 | if total_length >= sequence_length + 1: 58 | total_length = ((total_length - 1) // sequence_length) * sequence_length + 1 59 | 60 | result = { 61 | 'input_ids': [ 62 | concatenated_tokens['input_ids'][i : i + sequence_length + 1] 63 | for i in range(0, total_length - sequence_length, sequence_length) 64 | ] 65 | } 66 | return result 67 | 68 | def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc): 69 | """Tokenize the dataset and group texts in chunks of sequence_length + 1""" 70 | tokenizer_func = partial( 71 | self.tokenizer_group_text, 72 | tokenizer=self.tokenizer, 73 | sequence_length=sequence_length 74 | ) 75 | 76 | tokenized_dataset = dataset.map( 77 | tokenizer_func, 78 | input_columns=text_column_name, 79 | remove_columns=dataset.column_names, 80 | features=Features({ 81 | "input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1) 82 | }), 83 | batched=True, 84 | num_proc=num_proc, 85 | load_from_cache_file=True, # Preprocess dataset only once and cache it 86 | desc=f"Grouping texts in chunks of {sequence_length+1}", 87 | ) 88 | 89 | return tokenized_dataset 90 | 91 | def collate_batch(self, batch): 92 | batch_input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch]) 93 | batch_size = batch_input_ids.size(0) 94 | input_ids = batch_input_ids[:, :-1].contiguous() 95 | target_ids = batch_input_ids[:, 1:].contiguous() 96 | position_ids = torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous() 97 | attn_mask = torch.tril(torch.ones((self.seq_len, self.seq_len), dtype=torch.bool)) 98 | attn_mask = attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous() 99 | 100 | return { 101 | "input_ids": input_ids, 102 | "target_ids": target_ids, 103 | "position_ids": position_ids, 104 | "attn_mask": attn_mask, 105 | "hidden_states": None 106 | } 107 | 108 | def __iter__(self): 109 | if self._iterator is None: 110 | self._iterator = super().__iter__() 111 | return self 112 | 113 | def __next__(self): 114 | if self._iterator is None: 115 | self._iterator = super().__iter__() 116 | try: 117 | batch = next(self._iterator) 118 | except StopIteration: 119 | self._iterator = None 120 | raise StopIteration 121 | return batch -------------------------------------------------------------------------------- /step5_data_parallel_naive/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from 
flash_attn.flash_attn_interface import flash_attn_func 5 | from flash_attn.layers.rotary import apply_rotary_emb 6 | from flash_attn.ops.triton.layer_norm import layer_norm_fn 7 | import process_group_manager as pgm 8 | 9 | def flash_attention(q, k, v, causal = True): 10 | q = q.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 11 | k = k.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 12 | v = v.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 13 | return flash_attn_func(q, k, v, causal=causal) 14 | 15 | def get_cos_sin(seq_length, head_dim, base=500000.0): 16 | assert head_dim%2==0 17 | # Results on CUDA and CPU are different even with the same formula, To match transformers implementation. frequency should be computed on CPU 18 | theta = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float().to('cpu') / head_dim)) 19 | dtype = torch.bfloat16 20 | device = torch.device('cuda') 21 | position = torch.arange(seq_length).to(device).unsqueeze(1).float() # [seq_length, 1] 22 | # To match transformers implementation. m * theta should be computed on GPU 23 | theta = theta.to(device) 24 | return torch.cos(position.float()*theta.float()).to(dtype).repeat(1,2), torch.sin(position.float()*theta.float()).to(dtype).repeat(1,2) # [seq_length, head_dim], [seq_length, head_dim] 25 | 26 | class TritonRMSNorm(nn.Module): 27 | def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): 28 | super().__init__() 29 | self.eps = eps 30 | self.weight = nn.Parameter(torch.ones(hidden_size)) 31 | self.register_parameter("bias", None) 32 | 33 | def forward( 34 | self, hidden_states, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False 35 | ): 36 | return layer_norm_fn( 37 | hidden_states, 38 | self.weight, 39 | None, 40 | residual=residual, 41 | eps=self.eps, 42 | dropout_p=dropout_p, 43 | prenorm=prenorm, 44 | residual_in_fp32=residual_in_fp32, 45 | is_rms_norm=True, 46 | return_dropout_mask=return_dropout_mask, 47 | ) 48 | 49 | class Attention(nn.Module): 50 | def __init__(self, config, layer_idx): 51 | super().__init__() 52 | self.hidden_size = config.hidden_size 53 | self.num_heads = config.num_attention_heads 54 | self.num_key_values = config.num_key_value_heads 55 | self.head_dim = self.hidden_size//self.num_heads 56 | assert config.num_attention_heads % pgm.process_group_manager.tp_world_size == 0, "num_attention_heads should be divisible by tp world size" 57 | assert config.num_key_value_heads % pgm.process_group_manager.tp_world_size == 0, "num_key_value_heads should be divisible by tp world size" 58 | self.num_local_heads = config.num_attention_heads // pgm.process_group_manager.tp_world_size # TP parallelism 59 | self.num_local_kv_heads = config.num_key_value_heads // pgm.process_group_manager.tp_world_size # TP parallelism 60 | 61 | 62 | self.q_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim, bias=False) 63 | self.k_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False) 64 | self.v_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False) 65 | self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) 66 | self.layer_idx = layer_idx 67 | 68 | def forward(self, x, cos, sin, attention_mask=None, position_ids=None): 69 | batch_size, seq_length, hidden_dim = x.size() 70 | q = self.q_proj(x) # [batch_size, seq_length, num_heads*head_dim] 71 | k = self.k_proj(x) # [batch_size, seq_length, 
num_key_values*head_dim] 72 | v = self.v_proj(x) # [batch_size, seq_length, num_key_values*head_dim] 73 | 74 | q = q.view(batch_size, seq_length, self.num_local_heads, self.head_dim) # [batch_size, seq_length, num_heads, head_dim] 75 | k = k.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim) # [batch_size, seq_length, num_key_values, head_dim] 76 | q = apply_rotary_emb(q,cos[:, :self.head_dim // 2], sin[:, :self.head_dim // 2],interleaved=False) # [batch_size, seq_length, num_heads, head_dim] 77 | k = apply_rotary_emb(k,cos[:, :self.head_dim // 2], sin[:, :self.head_dim // 2],interleaved=False) # [batch_size, seq_length, num_key_values, head_dim] 78 | q = q.transpose(1, 2) # [batch_size, num_heads, seq_length, head_dim] 79 | k = k.transpose(1, 2) # [batch_size, num_key_values, seq_length, head_dim] 80 | v = v.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim).transpose(1,2) # [batch_size, num_key_values, seq_length, head_dim] 81 | 82 | k = k.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, head_dim] 83 | v = v.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, head_dim] 84 | 85 | causal = True if q.size(2) == k.size(2) else False # During decoding phase. The lenghth of q is usually 1. 86 | 87 | out = flash_attention(q, k, v, causal = causal) # [batch_size, seq_length, num_heads, head_dim] 88 | 89 | out = out.reshape(batch_size, seq_length, self.num_local_heads * self.head_dim) # [batch_size, seq_length, hidden_dim] 90 | out = self.out_proj(out) # [batch_size, seq_length, hidden_dim] 91 | return out 92 | 93 | class MLP(nn.Module): 94 | def __init__(self, config) -> None: 95 | super().__init__() 96 | self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) 97 | self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) 98 | self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) 99 | 100 | def forward(self, x): 101 | #TODO: dont do single line operations as it is harder to debug 102 | return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) 103 | 104 | class DecoderLayer(nn.Module): 105 | # TritonRMSNorm -> Attention -> Residual -> TritonRMSNorm -> MLP -> Residual 106 | def __init__(self, config, layer_idx): 107 | super().__init__() 108 | self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 109 | self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 110 | self.attention = Attention(config, layer_idx = layer_idx) 111 | self.mlp = MLP(config) 112 | self.layer_idx = layer_idx 113 | head_dim = config.hidden_size // config.num_attention_heads 114 | self.cos, self.sin = get_cos_sin(config.max_position_embeddings, head_dim=head_dim , base=config.rope_theta) # [max_position_embeddings, head_dim] 115 | 116 | def forward(self, x, attention_mask = None, position_ids = None): 117 | cos, sin = self.cos, self.sin 118 | x = x + self.attention(self.input_layernorm(x), cos, sin, attention_mask, position_ids) # Attention 119 | x = x + self.mlp(self.post_attention_layernorm(x)) # MLP 120 | return x 121 | 122 | class Llama(nn.Module): 123 | def __init__(self, config) -> None: 124 | super().__init__() 125 | # sanity check 126 | assert config.hidden_size % config.num_attention_heads==0 127 | assert config.num_attention_heads % config.num_key_value_heads==0 128 | 129 | # params 130 | self.vocab_size = 
config.vocab_size 131 | self.hidden_size = config.hidden_size 132 | self.num_heads = config.num_attention_heads 133 | self.num_key_values = config.num_key_value_heads 134 | self.head_dim = self.hidden_size//self.num_heads 135 | self.max_position_embeddings = config.max_position_embeddings 136 | self.num_layers = config.num_hidden_layers 137 | self.model_config = config 138 | 139 | # modules 140 | self.embedding = nn.Embedding(self.vocab_size, self.hidden_size) 141 | self.decoder_layers = nn.ModuleList([DecoderLayer(config,layer_idx = i) for i in range(self.num_layers)]) 142 | self.final_proj = nn.Linear(self.hidden_size, self.vocab_size, bias=False) 143 | self.final_norm = TritonRMSNorm(self.hidden_size, eps=config.rms_norm_eps) 144 | 145 | def forward(self, input_ids, attention_mask=None, position_ids: torch.Tensor = None): 146 | x = self.embedding(input_ids) 147 | for layer in self.decoder_layers: 148 | x = layer(x) # [batch_size, seq_length, hidden_dim] 149 | x = self.final_norm(x) 150 | logits = self.final_proj(x) 151 | 152 | return logits # [batch_size, seq_length, vocab_size] -------------------------------------------------------------------------------- /step5_data_parallel_naive/patch_step_5.diff: -------------------------------------------------------------------------------- 1 | Binary files step4_tensor_parallel/__pycache__/data.cpython-39.pyc and step5_data_parallel_naive/__pycache__/data.cpython-39.pyc differ 2 | Binary files step4_tensor_parallel/__pycache__/data_parallel.cpython-39.pyc and step5_data_parallel_naive/__pycache__/data_parallel.cpython-39.pyc differ 3 | Binary files step4_tensor_parallel/__pycache__/dataloader.cpython-39.pyc and step5_data_parallel_naive/__pycache__/dataloader.cpython-39.pyc differ 4 | Binary files step4_tensor_parallel/__pycache__/model.cpython-39.pyc and step5_data_parallel_naive/__pycache__/model.cpython-39.pyc differ 5 | Binary files step4_tensor_parallel/__pycache__/process_group_manager.cpython-39.pyc and step5_data_parallel_naive/__pycache__/process_group_manager.cpython-39.pyc differ 6 | Binary files step4_tensor_parallel/__pycache__/tensor_parallel.cpython-39.pyc and step5_data_parallel_naive/__pycache__/tensor_parallel.cpython-39.pyc differ 7 | Binary files step4_tensor_parallel/__pycache__/utils.cpython-39.pyc and step5_data_parallel_naive/__pycache__/utils.cpython-39.pyc differ 8 | diff -x '*.diff' --new-file -ur step4_tensor_parallel/data_parallel.py step5_data_parallel_naive/data_parallel.py 9 | --- step4_tensor_parallel/data_parallel.py 1970-01-01 00:00:00.000000000 +0000 10 | +++ step5_data_parallel_naive/data_parallel.py 2024-11-19 14:30:21.000000000 +0000 11 | @@ -0,0 +1,34 @@ 12 | +import contextlib 13 | +from typing import List 14 | +import torch 15 | +import torch.distributed as dist 16 | +from torch import nn 17 | + 18 | +import process_group_manager as pgm 19 | + 20 | +### begin Data Parallel (naive) 21 | +class DataParallelNaive(nn.Module): 22 | + def __init__(self, module): 23 | + super().__init__() 24 | + self.module = module 25 | + # whether to synchronize gradients during backward pass. 
Set to False when using gradient accumulation 26 | + self.require_backward_grad_sync = True 27 | + self.register_backward_hook(self._allreduce_grads) 28 | + 29 | + def forward(self, *inputs, **kwargs): 30 | + return self.module(*inputs, **kwargs) 31 | + 32 | + def register_backward_hook(self, hook): 33 | + """Registers a backward hook for all parameters of the model that require gradients.""" 34 | + for p in self.module.parameters(): 35 | + if p.requires_grad is True: 36 | + p.register_hook(hook) 37 | + 38 | + def _allreduce_grads(self, grad): 39 | + """Performs an all-reduce operation to synchronize gradients across multiple processes.""" 40 | + # No synchronization needed during gradient accumulation, except at the final accumulation step. 41 | + if self.require_backward_grad_sync: 42 | + dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.dp_group) 43 | + grad /= pgm.process_group_manager.dp_world_size 44 | + return grad 45 | +### end Data Parallel (naive) 46 | \ No newline at end of file 47 | diff -x '*.diff' --new-file -ur step4_tensor_parallel/dataloader.py step5_data_parallel_naive/dataloader.py 48 | --- step4_tensor_parallel/dataloader.py 2024-11-17 13:14:18.000000000 +0000 49 | +++ step5_data_parallel_naive/dataloader.py 2024-11-17 15:10:41.000000000 +0000 50 | @@ -1,5 +1,5 @@ 51 | import torch 52 | -from torch.utils.data import DataLoader 53 | +from torch.utils.data import DataLoader, DistributedSampler 54 | import numpy as np 55 | from functools import partial 56 | from datasets import Features, Sequence, Value, load_dataset 57 | @@ -8,7 +8,7 @@ 58 | import process_group_manager as pgm 59 | 60 | class MicroBatchDataLoader(DataLoader): 61 | - def __init__(self, seq_len, micro_batch_size, grad_acc_steps, dataset_name, tokenizer_name, max_tokens, num_workers, num_proc, split="train"): 62 | + def __init__(self, seq_len, micro_batch_size, grad_acc_steps, dataset_name, tokenizer_name, max_tokens, num_workers, num_proc, seed, split="train"): 63 | 64 | self.micro_batch_size = micro_batch_size 65 | self.grad_acc_steps = grad_acc_steps 66 | @@ -25,12 +25,21 @@ 67 | total_tokens = self.tokenized_dataset.num_rows * (self.seq_len + 1) 68 | assert total_tokens >= max_tokens, f"Not enough tokens. 
Have {total_tokens} tokens but need {max_tokens} tokens" 69 | 70 | + self.sampler = DistributedSampler( 71 | + self.tokenized_dataset, 72 | + num_replicas=pgm.process_group_manager.dp_world_size, 73 | + rank=pgm.process_group_manager.dp_rank, 74 | + seed=seed, 75 | + shuffle=False 76 | + ) 77 | + 78 | super().__init__( 79 | self.tokenized_dataset, 80 | batch_size=micro_batch_size, 81 | collate_fn=self.collate_batch, 82 | pin_memory=True, 83 | num_workers=num_workers, 84 | + sampler=self.sampler, 85 | shuffle=False, 86 | ) 87 | 88 | diff -x '*.diff' --new-file -ur step4_tensor_parallel/train.py step5_data_parallel_naive/train.py 89 | --- step4_tensor_parallel/train.py 2024-11-17 15:05:11.000000000 +0000 90 | +++ step5_data_parallel_naive/train.py 2024-11-19 14:16:50.000000000 +0000 91 | @@ -1,5 +1,5 @@ 92 | """ 93 | -torchrun --nproc_per_node 4 train.py --tp_size 4 --micro_batch_size 4 --gradient_accumulation_steps 8 --seq_len 128 --max_tokens 40960 --num_proc 16 --run_name tp_naive --use_wandb 94 | +torchrun --nproc_per_node 4 train.py --dp_size 4 --micro_batch_size 1 --gradient_accumulation_steps 8 --seq_len 128 --max_tokens 40960 --num_proc 16 --run_name dp_naive --use_wandb 95 | """ 96 | import os 97 | import time 98 | @@ -21,16 +21,23 @@ 99 | from utils import set_all_seed, print, to_readable_format 100 | 101 | from tensor_parallel import apply_tensor_parallel 102 | +from data_parallel import DataParallelNaive 103 | 104 | def train_step(model, dataloader, device): 105 | acc_loss = 0.0 106 | 107 | + requires_grad_sync = pgm.process_group_manager.dp_world_size > 1 108 | + 109 | for i in range(dataloader.grad_acc_steps): 110 | # get the next batch 111 | batch = next(dataloader) 112 | input_ids = batch["input_ids"].to(device) 113 | target_ids = batch["target_ids"].to(device) 114 | 115 | + # enable gradient synchronization for the last micro-batch only 116 | + if requires_grad_sync: 117 | + model.require_backward_grad_sync = (i == dataloader.grad_acc_steps - 1) 118 | + 119 | outputs = model(input_ids=input_ids) 120 | 121 | # compute the loss 122 | @@ -127,7 +134,13 @@ 123 | if pgm.process_group_manager.tp_world_size > 1: 124 | model = apply_tensor_parallel(model) 125 | 126 | + # Need to move the model to the device before wrapping it with DataParallel. 127 | + # Otherwise, the hook will get attached to the CPU model and not the GPU model. 
128 | model.to(dtype).to(device) 129 | + 130 | + if pgm.process_group_manager.dp_world_size > 1: 131 | + model = DataParallelNaive(model) 132 | + 133 | model.train() 134 | 135 | dist.barrier() 136 | @@ -146,6 +159,7 @@ 137 | max_tokens=args.max_tokens, 138 | num_workers=args.num_workers, 139 | num_proc=args.num_proc, 140 | + seed=args.seed, 141 | ) 142 | 143 | tokens_per_step = dataloader.global_batch_size * args.seq_len 144 | @@ -169,6 +183,10 @@ 145 | step_duration = time.time() - step_start_time 146 | trained_token += tokens_per_step 147 | step += 1 148 | + 149 | + # In DDP implementation, we need to reset the gradient buffers 150 | + if hasattr(model, 'reset'): 151 | + model.reset() 152 | 153 | print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, " 154 | f"Global batch size (with seq_len): {to_readable_format(tokens_per_step)}, " 155 | -------------------------------------------------------------------------------- /step5_data_parallel_naive/process_group_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | 5 | class ProcessGroupManager: 6 | def __init__(self, dp_size, pp_size, tp_size): 7 | self.global_rank = dist.get_rank() 8 | self.world_size = dist.get_world_size() 9 | self.local_rank = int(os.environ.get("LOCAL_RANK", self.global_rank % self.world_size)) 10 | 11 | assert self.world_size == dp_size * pp_size * tp_size, f"World size ({self.world_size}) != DP ({self.dp_size}) * PP ({self.pp_size}) * TP ({self.tp_size})" 12 | 13 | self.grid = torch.arange(self.world_size).view(dp_size, pp_size, tp_size) # DP * PP * TP grid 14 | # Find the position of the current process in the grid 15 | self.dp_rank, self.pp_rank, self.tp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist() 16 | 17 | # Process group creation - Update indexing to match new grid order 18 | self.tp_group = dist.new_subgroups_by_enumeration([self.grid[d, p, :].tolist() for d in range(dp_size) for p in range(pp_size)])[0] 19 | self.pp_group = dist.new_subgroups_by_enumeration([self.grid[d, :, t].tolist() for d in range(dp_size) for t in range(tp_size)])[0] 20 | self.dp_group = dist.new_subgroups_by_enumeration([self.grid[:, p, t].tolist() for p in range(pp_size) for t in range(tp_size)])[0] 21 | self.pp_dp_group = dist.new_subgroups_by_enumeration([self.grid[:, :, t].flatten().tolist() for t in range(tp_size)])[0] 22 | 23 | self.world_group = dist.group.WORLD 24 | 25 | # Update group IDs with new grid ordering 26 | self.tp_group_ids = self.grid[self.dp_rank, self.pp_rank, :].tolist() 27 | self.pp_group_ids = self.grid[self.dp_rank, :, self.tp_rank].tolist() 28 | self.dp_group_ids = self.grid[:, self.pp_rank, self.tp_rank].tolist() 29 | 30 | # Tensor parallelism 31 | self.tp_world_size = dist.get_world_size(group=self.tp_group) 32 | self.tp_first_rank = self.tp_group_ids[0] 33 | self.tp_last_rank = self.tp_group_ids[-1] 34 | 35 | # Pipeline parallelism 36 | self.pp_world_size = dist.get_world_size(group=self.pp_group) 37 | self.pp_first_rank = self.pp_group_ids[0] 38 | self.pp_last_rank = self.pp_group_ids[-1] 39 | self.pp_is_first_stage = self.pp_rank == 0 40 | self.pp_is_last_stage = self.pp_rank == self.pp_world_size - 1 41 | self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else int(self.grid[self.dp_rank, self.pp_rank + 1, self.tp_rank].item()) 42 | self.pp_prev_rank = None if self.pp_rank == 0 else int(self.grid[self.dp_rank, self.pp_rank - 1, 
self.tp_rank].item()) 43 | 44 | # Data parallelism 45 | self.dp_world_size = dist.get_world_size(group=self.dp_group) 46 | self.dp_first_rank = self.dp_group_ids[0] 47 | self.dp_last_rank = self.dp_group_ids[-1] 48 | 49 | def __str__(self): 50 | return f"DP({self.dp_world_size})-PP({self.pp_world_size})-TP({self.tp_world_size})-Rank({self.global_rank})" 51 | 52 | def setup_process_group_manager(dp_size, pp_size, tp_size): 53 | global process_group_manager 54 | process_group_manager = ProcessGroupManager(dp_size, pp_size, tp_size) -------------------------------------------------------------------------------- /step5_data_parallel_naive/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | torchrun --nproc_per_node 4 train.py --dp_size 4 --micro_batch_size 1 --gradient_accumulation_steps 8 --seq_len 128 --max_tokens 40960 --num_proc 16 --run_name dp_naive --use_wandb 3 | """ 4 | import os 5 | import time 6 | import wandb 7 | import datetime 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.distributed as dist 11 | import argparse 12 | from torch.optim import AdamW 13 | from transformers import AutoConfig 14 | 15 | import lovely_tensors as lt; lt.monkey_patch() 16 | 17 | from model import Llama 18 | from dataloader import MicroBatchDataLoader 19 | import process_group_manager as pgm 20 | from process_group_manager import setup_process_group_manager 21 | from utils import set_all_seed, print, to_readable_format 22 | 23 | from tensor_parallel import apply_tensor_parallel 24 | from data_parallel import DataParallelNaive 25 | 26 | def train_step(model, dataloader, device): 27 | acc_loss = 0.0 28 | 29 | requires_grad_sync = pgm.process_group_manager.dp_world_size > 1 30 | 31 | for i in range(dataloader.grad_acc_steps): 32 | # get the next batch 33 | batch = next(dataloader) 34 | input_ids = batch["input_ids"].to(device) 35 | target_ids = batch["target_ids"].to(device) 36 | 37 | # enable gradient synchronization for the last micro-batch only 38 | if requires_grad_sync: 39 | model.require_backward_grad_sync = (i == dataloader.grad_acc_steps - 1) 40 | 41 | outputs = model(input_ids=input_ids) 42 | 43 | # compute the loss 44 | batch_size, seq_len = input_ids.shape 45 | target_ids = target_ids.reshape(-1) 46 | outputs = outputs.view(seq_len*batch_size, -1) 47 | loss = F.cross_entropy(outputs, target_ids, reduction='mean') / dataloader.grad_acc_steps 48 | 49 | loss.backward() 50 | 51 | acc_loss += loss.item() 52 | 53 | return acc_loss 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser(description="Training script for LLaMA model") 57 | 58 | # Environment arguments 59 | parser.add_argument("--omp_num_threads", type=str, default="1") 60 | parser.add_argument("--tokenizers_parallelism", type=str, default="false") 61 | 62 | # Model arguments 63 | parser.add_argument("--model_name", type=str, default="HuggingFaceTB/SmolLM-360M-Instruct") 64 | parser.add_argument("--num_hidden_layers", type=int, default=32) 65 | parser.add_argument("--num_attention_heads", type=int, default=16) 66 | parser.add_argument("--num_key_value_heads", type=int, default=4) 67 | 68 | # Dataset arguments 69 | parser.add_argument("--dataset_name", type=str, default="roneneldan/TinyStories") 70 | parser.add_argument("--num_workers", type=int, default=1) 71 | parser.add_argument("--num_proc", type=int, default=4) 72 | 73 | # Training arguments 74 | parser.add_argument("--seed", type=int, default=42) 75 | parser.add_argument("--learning_rate", 
type=float, default=3e-4) 76 | parser.add_argument("--seq_len", type=int, default=32) 77 | parser.add_argument("--micro_batch_size", type=int, default=1) 78 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 79 | parser.add_argument("--max_tokens", type=int, default=1e6) 80 | 81 | # Distributed training arguments 82 | parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size") 83 | parser.add_argument("--dp_size", type=int, default=1, help="Data Parallel size") 84 | parser.add_argument("--pp_size", type=int, default=1, help="Pipeline Parallel size") 85 | parser.add_argument("--pp_engine", type=str, default="afab", choices=["1f1b", "afab"]) 86 | 87 | # Logging arguments 88 | parser.add_argument("--run_name", type=str, default="default_run") 89 | parser.add_argument("--use_wandb", action="store_true") 90 | 91 | args = parser.parse_args() 92 | 93 | # Set environment variables 94 | os.environ["OMP_NUM_THREADS"] = args.omp_num_threads 95 | os.environ["TOKENIZERS_PARALLELISM"] = args.tokenizers_parallelism 96 | os.environ["DEVICE"] = "cuda" 97 | 98 | local_rank = int(os.environ["LOCAL_RANK"]) 99 | global_rank = int(os.environ["RANK"]) 100 | world_size = int(os.environ["WORLD_SIZE"]) 101 | backend = "nccl" 102 | torch.cuda.set_device(local_rank) 103 | device = torch.device("cuda", local_rank) 104 | dtype = torch.bfloat16 105 | 106 | dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=2)) 107 | setup_process_group_manager(dp_size=args.dp_size, pp_size=args.pp_size, tp_size=args.tp_size) 108 | 109 | is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.pp_is_last_stage 110 | set_all_seed(args.seed) 111 | 112 | if is_wandb_rank and args.use_wandb: 113 | wandb.init( 114 | project="picotron_tutorial", 115 | name=f"{args.run_name}_{pgm.process_group_manager}", 116 | config={ 117 | "tensor_parallel_size": pgm.process_group_manager.tp_world_size, 118 | "pipeline_parallel_size": pgm.process_group_manager.pp_world_size, 119 | "data_parallel_size": pgm.process_group_manager.dp_world_size, 120 | "model": args.model_name, 121 | "learning_rate": args.learning_rate, 122 | "seed": args.seed, 123 | }, 124 | ) 125 | 126 | model_config = AutoConfig.from_pretrained(args.model_name) 127 | model_config.num_hidden_layers = args.num_hidden_layers 128 | model_config.num_attention_heads = args.num_attention_heads 129 | model_config.num_key_value_heads = args.num_key_value_heads 130 | model_config.max_position_embeddings = args.seq_len 131 | 132 | model = Llama(config=model_config) 133 | 134 | if pgm.process_group_manager.tp_world_size > 1: 135 | model = apply_tensor_parallel(model) 136 | 137 | # Need to move the model to the device before wrapping it with DataParallel. 138 | # Otherwise, the hook will get attached to the CPU model and not the GPU model. 
139 | model.to(dtype).to(device) 140 | 141 | if pgm.process_group_manager.dp_world_size > 1: 142 | model = DataParallelNaive(model) 143 | 144 | model.train() 145 | 146 | dist.barrier() 147 | 148 | optimizer = AdamW(model.parameters(), lr=args.learning_rate) 149 | 150 | dist.barrier() 151 | 152 | # Create dataloader 153 | dataloader = MicroBatchDataLoader( 154 | seq_len=args.seq_len, 155 | micro_batch_size=args.micro_batch_size, 156 | grad_acc_steps=args.gradient_accumulation_steps, 157 | dataset_name=args.dataset_name, 158 | tokenizer_name=args.model_name, 159 | max_tokens=args.max_tokens, 160 | num_workers=args.num_workers, 161 | num_proc=args.num_proc, 162 | seed=args.seed, 163 | ) 164 | 165 | tokens_per_step = dataloader.global_batch_size * args.seq_len 166 | if pgm.process_group_manager.global_rank == 0: 167 | print("Tokens per step:", to_readable_format(tokens_per_step), is_print_rank=is_wandb_rank) 168 | 169 | trained_token, step = 0, 0 170 | 171 | dist.barrier() 172 | 173 | # Training loop 174 | while trained_token < args.max_tokens: 175 | 176 | step_start_time = time.time() 177 | optimizer.zero_grad() 178 | 179 | loss = train_step(model, dataloader, device) 180 | 181 | optimizer.step() 182 | 183 | step_duration = time.time() - step_start_time 184 | trained_token += tokens_per_step 185 | step += 1 186 | 187 | # In DDP implementation, we need to reset the gradient buffers 188 | if hasattr(model, 'reset'): 189 | model.reset() 190 | 191 | print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, " 192 | f"Global batch size (with seq_len): {to_readable_format(tokens_per_step)}, " 193 | f"Tokens/s: {to_readable_format(tokens_per_step / step_duration)}, " 194 | f"Tokens/s/GPU: {to_readable_format(tokens_per_step / step_duration / world_size)}, " 195 | f"Tokens: {to_readable_format(trained_token)}{('/' + to_readable_format(args.max_tokens))}, " 196 | f"Memory usage: {torch.cuda.memory_reserved() / 1e9:.2f}GB" 197 | , is_print_rank=is_wandb_rank 198 | ) 199 | 200 | if is_wandb_rank and args.use_wandb: 201 | wandb.log({"loss": loss, "tokens_per_step": tokens_per_step, "tokens_per_second": tokens_per_step / step_duration,\ 202 | "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": tokens_per_step}) 203 | 204 | if is_wandb_rank and args.use_wandb: 205 | wandb.finish() 206 | 207 | dist.destroy_process_group() -------------------------------------------------------------------------------- /step5_data_parallel_naive/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import builtins 5 | import fcntl 6 | 7 | def print(*args, is_print_rank=True, **kwargs): 8 | """ solves multi-process interleaved print problem """ 9 | if not is_print_rank: return 10 | with open(__file__, "r") as fh: 11 | fcntl.flock(fh, fcntl.LOCK_EX) 12 | try: 13 | builtins.print(*args, **kwargs) 14 | finally: 15 | fcntl.flock(fh, fcntl.LOCK_UN) 16 | 17 | def set_all_seed(seed): 18 | for module in [random, np.random]: module.seed(seed) 19 | torch.manual_seed(seed) 20 | if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) 21 | 22 | def to_readable_format(num, precision=3): 23 | num_str = str(num) 24 | length = len(num_str) 25 | 26 | def format_with_precision(main, decimal, suffix): 27 | if precision == 0: 28 | return f"{main}{suffix}" 29 | return f"{main}.{decimal[:precision]}{suffix}" 30 | 31 | if length > 12: # Trillions 32 | return format_with_precision(num_str[:-12], 
num_str[-12:], 'T') 33 | elif length > 9: # Billions 34 | return format_with_precision(num_str[:-9], num_str[-9:], 'B') 35 | elif length > 6: # Millions 36 | return format_with_precision(num_str[:-6], num_str[-6:], 'M') 37 | elif length > 3: # Thousands 38 | return format_with_precision(num_str[:-3], num_str[-3:], 'K') 39 | else: 40 | return num_str -------------------------------------------------------------------------------- /step6_data_parallel_bucket/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, DistributedSampler 3 | import numpy as np 4 | from functools import partial 5 | from datasets import Features, Sequence, Value, load_dataset 6 | from transformers import AutoTokenizer 7 | 8 | import process_group_manager as pgm 9 | 10 | class MicroBatchDataLoader(DataLoader): 11 | def __init__(self, seq_len, micro_batch_size, grad_acc_steps, dataset_name, tokenizer_name, max_tokens, num_workers, num_proc, seed, split="train"): 12 | 13 | self.micro_batch_size = micro_batch_size 14 | self.grad_acc_steps = grad_acc_steps 15 | self.seq_len = seq_len 16 | 17 | self.global_batch_size = micro_batch_size * grad_acc_steps * pgm.process_group_manager.dp_world_size 18 | 19 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 20 | self.dataset = load_dataset(dataset_name, split=split) 21 | 22 | # Tokenize and chunk the dataset 23 | self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_len, num_proc) 24 | 25 | total_tokens = self.tokenized_dataset.num_rows * (self.seq_len + 1) 26 | assert total_tokens >= max_tokens, f"Not enough tokens. Have {total_tokens} tokens but need {max_tokens} tokens" 27 | 28 | self.sampler = DistributedSampler( 29 | self.tokenized_dataset, 30 | num_replicas=pgm.process_group_manager.dp_world_size, 31 | rank=pgm.process_group_manager.dp_rank, 32 | seed=seed, 33 | shuffle=False 34 | ) 35 | 36 | super().__init__( 37 | self.tokenized_dataset, 38 | batch_size=micro_batch_size, 39 | collate_fn=self.collate_batch, 40 | pin_memory=True, 41 | num_workers=num_workers, 42 | sampler=self.sampler, 43 | shuffle=False, 44 | ) 45 | 46 | def tokenizer_group_text(self, examples, tokenizer, sequence_length): 47 | """Tokenize a list of texts and group them in chunks of sequence_length + 1""" 48 | tokenized_text_batch = tokenizer.batch_encode_plus( 49 | examples, 50 | return_attention_mask=False, 51 | return_token_type_ids=False, 52 | return_tensors='np' 53 | ) 54 | concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])} 55 | total_length = len(concatenated_tokens['input_ids']) 56 | 57 | if total_length >= sequence_length + 1: 58 | total_length = ((total_length - 1) // sequence_length) * sequence_length + 1 59 | 60 | result = { 61 | 'input_ids': [ 62 | concatenated_tokens['input_ids'][i : i + sequence_length + 1] 63 | for i in range(0, total_length - sequence_length, sequence_length) 64 | ] 65 | } 66 | return result 67 | 68 | def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc): 69 | """Tokenize the dataset and group texts in chunks of sequence_length + 1""" 70 | tokenizer_func = partial( 71 | self.tokenizer_group_text, 72 | tokenizer=self.tokenizer, 73 | sequence_length=sequence_length 74 | ) 75 | 76 | tokenized_dataset = dataset.map( 77 | tokenizer_func, 78 | input_columns=text_column_name, 79 | remove_columns=dataset.column_names, 80 | features=Features({ 81 | "input_ids": 
Sequence(feature=Value(dtype="int64"), length=sequence_length + 1) 82 | }), 83 | batched=True, 84 | num_proc=num_proc, 85 | load_from_cache_file=True, # Preprocess dataset only once and cache it 86 | desc=f"Grouping texts in chunks of {sequence_length+1}", 87 | ) 88 | 89 | return tokenized_dataset 90 | 91 | def collate_batch(self, batch): 92 | batch_input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch]) 93 | batch_size = batch_input_ids.size(0) 94 | input_ids = batch_input_ids[:, :-1].contiguous() 95 | target_ids = batch_input_ids[:, 1:].contiguous() 96 | position_ids = torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous() 97 | attn_mask = torch.tril(torch.ones((self.seq_len, self.seq_len), dtype=torch.bool)) 98 | attn_mask = attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous() 99 | 100 | return { 101 | "input_ids": input_ids, 102 | "target_ids": target_ids, 103 | "position_ids": position_ids, 104 | "attn_mask": attn_mask, 105 | "hidden_states": None 106 | } 107 | 108 | def __iter__(self): 109 | if self._iterator is None: 110 | self._iterator = super().__iter__() 111 | return self 112 | 113 | def __next__(self): 114 | if self._iterator is None: 115 | self._iterator = super().__iter__() 116 | try: 117 | batch = next(self._iterator) 118 | except StopIteration: 119 | self._iterator = None 120 | raise StopIteration 121 | return batch -------------------------------------------------------------------------------- /step6_data_parallel_bucket/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from flash_attn.flash_attn_interface import flash_attn_func 5 | from flash_attn.layers.rotary import apply_rotary_emb 6 | from flash_attn.ops.triton.layer_norm import layer_norm_fn 7 | import process_group_manager as pgm 8 | 9 | def flash_attention(q, k, v, causal = True): 10 | q = q.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 11 | k = k.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 12 | v = v.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 13 | return flash_attn_func(q, k, v, causal=causal) 14 | 15 | def get_cos_sin(seq_length, head_dim, base=500000.0): 16 | assert head_dim%2==0 17 | # Results on CUDA and CPU are different even with the same formula, To match transformers implementation. frequency should be computed on CPU 18 | theta = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float().to('cpu') / head_dim)) 19 | dtype = torch.bfloat16 20 | device = torch.device('cuda') 21 | position = torch.arange(seq_length).to(device).unsqueeze(1).float() # [seq_length, 1] 22 | # To match transformers implementation. 
m * theta should be computed on GPU 23 | theta = theta.to(device) 24 | return torch.cos(position.float()*theta.float()).to(dtype).repeat(1,2), torch.sin(position.float()*theta.float()).to(dtype).repeat(1,2) # [seq_length, head_dim], [seq_length, head_dim] 25 | 26 | class TritonRMSNorm(nn.Module): 27 | def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): 28 | super().__init__() 29 | self.eps = eps 30 | self.weight = nn.Parameter(torch.ones(hidden_size)) 31 | self.register_parameter("bias", None) 32 | 33 | def forward( 34 | self, hidden_states, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False 35 | ): 36 | return layer_norm_fn( 37 | hidden_states, 38 | self.weight, 39 | None, 40 | residual=residual, 41 | eps=self.eps, 42 | dropout_p=dropout_p, 43 | prenorm=prenorm, 44 | residual_in_fp32=residual_in_fp32, 45 | is_rms_norm=True, 46 | return_dropout_mask=return_dropout_mask, 47 | ) 48 | 49 | class Attention(nn.Module): 50 | def __init__(self, config, layer_idx): 51 | super().__init__() 52 | self.hidden_size = config.hidden_size 53 | self.num_heads = config.num_attention_heads 54 | self.num_key_values = config.num_key_value_heads 55 | self.head_dim = self.hidden_size//self.num_heads 56 | assert config.num_attention_heads % pgm.process_group_manager.tp_world_size == 0, "num_attention_heads should be divisible by tp world size" 57 | assert config.num_key_value_heads % pgm.process_group_manager.tp_world_size == 0, "num_key_value_heads should be divisible by tp world size" 58 | self.num_local_heads = config.num_attention_heads // pgm.process_group_manager.tp_world_size # TP parallelism 59 | self.num_local_kv_heads = config.num_key_value_heads // pgm.process_group_manager.tp_world_size # TP parallelism 60 | 61 | 62 | self.q_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim, bias=False) 63 | self.k_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False) 64 | self.v_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False) 65 | self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) 66 | self.layer_idx = layer_idx 67 | 68 | def forward(self, x, cos, sin, attention_mask=None, position_ids=None): 69 | batch_size, seq_length, hidden_dim = x.size() 70 | q = self.q_proj(x) # [batch_size, seq_length, num_heads*head_dim] 71 | k = self.k_proj(x) # [batch_size, seq_length, num_key_values*head_dim] 72 | v = self.v_proj(x) # [batch_size, seq_length, num_key_values*head_dim] 73 | 74 | q = q.view(batch_size, seq_length, self.num_local_heads, self.head_dim) # [batch_size, seq_length, num_heads, head_dim] 75 | k = k.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim) # [batch_size, seq_length, num_key_values, head_dim] 76 | q = apply_rotary_emb(q,cos[:, :self.head_dim // 2], sin[:, :self.head_dim // 2],interleaved=False) # [batch_size, seq_length, num_heads, head_dim] 77 | k = apply_rotary_emb(k,cos[:, :self.head_dim // 2], sin[:, :self.head_dim // 2],interleaved=False) # [batch_size, seq_length, num_key_values, head_dim] 78 | q = q.transpose(1, 2) # [batch_size, num_heads, seq_length, head_dim] 79 | k = k.transpose(1, 2) # [batch_size, num_key_values, seq_length, head_dim] 80 | v = v.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim).transpose(1,2) # [batch_size, num_key_values, seq_length, head_dim] 81 | 82 | k = k.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, 
head_dim] 83 | v = v.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, head_dim] 84 | 85 | causal = True if q.size(2) == k.size(2) else False # During decoding phase. The lenghth of q is usually 1. 86 | 87 | out = flash_attention(q, k, v, causal = causal) # [batch_size, seq_length, num_heads, head_dim] 88 | 89 | out = out.reshape(batch_size, seq_length, self.num_local_heads * self.head_dim) # [batch_size, seq_length, hidden_dim] 90 | out = self.out_proj(out) # [batch_size, seq_length, hidden_dim] 91 | return out 92 | 93 | class MLP(nn.Module): 94 | def __init__(self, config) -> None: 95 | super().__init__() 96 | self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) 97 | self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) 98 | self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) 99 | 100 | def forward(self, x): 101 | #TODO: dont do single line operations as it is harder to debug 102 | return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) 103 | 104 | class DecoderLayer(nn.Module): 105 | # TritonRMSNorm -> Attention -> Residual -> TritonRMSNorm -> MLP -> Residual 106 | def __init__(self, config, layer_idx): 107 | super().__init__() 108 | self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 109 | self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 110 | self.attention = Attention(config, layer_idx = layer_idx) 111 | self.mlp = MLP(config) 112 | self.layer_idx = layer_idx 113 | head_dim = config.hidden_size // config.num_attention_heads 114 | self.cos, self.sin = get_cos_sin(config.max_position_embeddings, head_dim=head_dim , base=config.rope_theta) # [max_position_embeddings, head_dim] 115 | 116 | def forward(self, x, attention_mask = None, position_ids = None): 117 | cos, sin = self.cos, self.sin 118 | x = x + self.attention(self.input_layernorm(x), cos, sin, attention_mask, position_ids) # Attention 119 | x = x + self.mlp(self.post_attention_layernorm(x)) # MLP 120 | return x 121 | 122 | class Llama(nn.Module): 123 | def __init__(self, config) -> None: 124 | super().__init__() 125 | # sanity check 126 | assert config.hidden_size % config.num_attention_heads==0 127 | assert config.num_attention_heads % config.num_key_value_heads==0 128 | 129 | # params 130 | self.vocab_size = config.vocab_size 131 | self.hidden_size = config.hidden_size 132 | self.num_heads = config.num_attention_heads 133 | self.num_key_values = config.num_key_value_heads 134 | self.head_dim = self.hidden_size//self.num_heads 135 | self.max_position_embeddings = config.max_position_embeddings 136 | self.num_layers = config.num_hidden_layers 137 | self.model_config = config 138 | 139 | # modules 140 | self.embedding = nn.Embedding(self.vocab_size, self.hidden_size) 141 | self.decoder_layers = nn.ModuleList([DecoderLayer(config,layer_idx = i) for i in range(self.num_layers)]) 142 | self.final_proj = nn.Linear(self.hidden_size, self.vocab_size, bias=False) 143 | self.final_norm = TritonRMSNorm(self.hidden_size, eps=config.rms_norm_eps) 144 | 145 | def forward(self, input_ids, attention_mask=None, position_ids: torch.Tensor = None): 146 | x = self.embedding(input_ids) 147 | for layer in self.decoder_layers: 148 | x = layer(x) # [batch_size, seq_length, hidden_dim] 149 | x = self.final_norm(x) 150 | logits = self.final_proj(x) 151 | 152 | return logits # [batch_size, seq_length, vocab_size] 
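
Note: step6_data_parallel_bucket/train.py further down wraps the model in DataParallelBucket and calls model.reset() after each optimizer step, but the corresponding data_parallel.py is not reproduced in this listing. The sketch below shows only the bucketed gradient all-reduce idea, assuming the same process_group_manager module and PyTorch >= 2.1 (for register_post_accumulate_grad_hook); the class name and helpers are hypothetical and may differ from the actual file.

### begin Data Parallel (bucket) -- illustrative sketch only, not the repo's implementation
import torch
import torch.distributed as dist
from torch import nn

import process_group_manager as pgm

class DataParallelBucketSketch(nn.Module):
    """Group parameters into fixed-size buckets and all-reduce each bucket's
    gradients in a single call, instead of one all-reduce per parameter."""
    def __init__(self, module, bucket_cap_mb=25):
        super().__init__()
        self.module = module
        # Same flag that train_step() toggles: sync only on the last micro-batch.
        self.require_backward_grad_sync = True
        self.dp_group = pgm.process_group_manager.dp_group
        self.dp_world_size = pgm.process_group_manager.dp_world_size
        self.buckets = self._build_buckets(bucket_cap_mb * 1024 * 1024)
        self._ready = [0] * len(self.buckets)  # grads received per bucket
        for idx, bucket in enumerate(self.buckets):
            for p in bucket:
                # Fires once .grad has been fully accumulated for this backward pass.
                p.register_post_accumulate_grad_hook(self._make_hook(idx))

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def _build_buckets(self, cap_bytes):
        """Greedily pack parameters into buckets of roughly cap_bytes each."""
        buckets, current, size = [], [], 0
        for p in self.module.parameters():
            if not p.requires_grad:
                continue
            current.append(p)
            size += p.numel() * p.element_size()
            if size >= cap_bytes:
                buckets.append(current)
                current, size = [], 0
        if current:
            buckets.append(current)
        return buckets

    def _make_hook(self, idx):
        def hook(param):
            if not self.require_backward_grad_sync:
                return  # gradient-accumulation micro-batch: no communication
            self._ready[idx] += 1
            if self._ready[idx] == len(self.buckets[idx]):
                self._sync_bucket(idx)
        return hook

    def _sync_bucket(self, idx):
        # One all-reduce for the whole bucket, issued during backward of the
        # last micro-batch, so gradients are already averaged by optimizer.step().
        bucket = self.buckets[idx]
        flat = torch.cat([p.grad.flatten() for p in bucket])
        dist.all_reduce(flat, op=dist.ReduceOp.SUM, group=self.dp_group)
        flat /= self.dp_world_size
        offset = 0
        for p in bucket:
            n = p.grad.numel()
            p.grad.copy_(flat[offset:offset + n].view_as(p.grad))
            offset += n

    def reset(self):
        # Called by train.py after optimizer.step(): clear per-bucket counters.
        self._ready = [0] * len(self.buckets)
### end Data Parallel (bucket) -- illustrative sketch only

Launching each bucket's all-reduce with async_op=True would let communication overlap with the rest of the backward pass; the synchronous call above just keeps the sketch short.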
-------------------------------------------------------------------------------- /step6_data_parallel_bucket/process_group_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | 5 | class ProcessGroupManager: 6 | def __init__(self, dp_size, pp_size, tp_size): 7 | self.global_rank = dist.get_rank() 8 | self.world_size = dist.get_world_size() 9 | self.local_rank = int(os.environ.get("LOCAL_RANK", self.global_rank % self.world_size)) 10 | 11 | assert self.world_size == dp_size * pp_size * tp_size, f"World size ({self.world_size}) != DP ({self.dp_size}) * PP ({self.pp_size}) * TP ({self.tp_size})" 12 | 13 | self.grid = torch.arange(self.world_size).view(dp_size, pp_size, tp_size) # DP * PP * TP grid 14 | # Find the position of the current process in the grid 15 | self.dp_rank, self.pp_rank, self.tp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist() 16 | 17 | # Process group creation - Update indexing to match new grid order 18 | self.tp_group = dist.new_subgroups_by_enumeration([self.grid[d, p, :].tolist() for d in range(dp_size) for p in range(pp_size)])[0] 19 | self.pp_group = dist.new_subgroups_by_enumeration([self.grid[d, :, t].tolist() for d in range(dp_size) for t in range(tp_size)])[0] 20 | self.dp_group = dist.new_subgroups_by_enumeration([self.grid[:, p, t].tolist() for p in range(pp_size) for t in range(tp_size)])[0] 21 | self.pp_dp_group = dist.new_subgroups_by_enumeration([self.grid[:, :, t].flatten().tolist() for t in range(tp_size)])[0] 22 | 23 | self.world_group = dist.group.WORLD 24 | 25 | # Update group IDs with new grid ordering 26 | self.tp_group_ids = self.grid[self.dp_rank, self.pp_rank, :].tolist() 27 | self.pp_group_ids = self.grid[self.dp_rank, :, self.tp_rank].tolist() 28 | self.dp_group_ids = self.grid[:, self.pp_rank, self.tp_rank].tolist() 29 | 30 | # Tensor parallelism 31 | self.tp_world_size = dist.get_world_size(group=self.tp_group) 32 | self.tp_first_rank = self.tp_group_ids[0] 33 | self.tp_last_rank = self.tp_group_ids[-1] 34 | 35 | # Pipeline parallelism 36 | self.pp_world_size = dist.get_world_size(group=self.pp_group) 37 | self.pp_first_rank = self.pp_group_ids[0] 38 | self.pp_last_rank = self.pp_group_ids[-1] 39 | self.pp_is_first_stage = self.pp_rank == 0 40 | self.pp_is_last_stage = self.pp_rank == self.pp_world_size - 1 41 | self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else int(self.grid[self.dp_rank, self.pp_rank + 1, self.tp_rank].item()) 42 | self.pp_prev_rank = None if self.pp_rank == 0 else int(self.grid[self.dp_rank, self.pp_rank - 1, self.tp_rank].item()) 43 | 44 | # Data parallelism 45 | self.dp_world_size = dist.get_world_size(group=self.dp_group) 46 | self.dp_first_rank = self.dp_group_ids[0] 47 | self.dp_last_rank = self.dp_group_ids[-1] 48 | 49 | def __str__(self): 50 | return f"DP({self.dp_world_size})-PP({self.pp_world_size})-TP({self.tp_world_size})-Rank({self.global_rank})" 51 | 52 | def setup_process_group_manager(dp_size, pp_size, tp_size): 53 | global process_group_manager 54 | process_group_manager = ProcessGroupManager(dp_size, pp_size, tp_size) -------------------------------------------------------------------------------- /step6_data_parallel_bucket/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | torchrun --nproc_per_node 4 train.py --dp_size 4 --micro_batch_size 1 --gradient_accumulation_steps 8 --seq_len 128 --max_tokens 40960 
--num_proc 16 --run_name dp_bucket --use_wandb 3 | """ 4 | import os 5 | import time 6 | import wandb 7 | import datetime 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.distributed as dist 11 | import argparse 12 | from torch.optim import AdamW 13 | from transformers import AutoConfig 14 | 15 | import lovely_tensors as lt; lt.monkey_patch() 16 | 17 | from model import Llama 18 | from dataloader import MicroBatchDataLoader 19 | import process_group_manager as pgm 20 | from process_group_manager import setup_process_group_manager 21 | from utils import set_all_seed, print, to_readable_format 22 | 23 | from tensor_parallel import apply_tensor_parallel 24 | from data_parallel import DataParallelBucket 25 | 26 | def train_step(model, dataloader, device): 27 | acc_loss = 0.0 28 | 29 | requires_grad_sync = pgm.process_group_manager.dp_world_size > 1 30 | 31 | for i in range(dataloader.grad_acc_steps): 32 | # get the next batch 33 | batch = next(dataloader) 34 | input_ids = batch["input_ids"].to(device) 35 | target_ids = batch["target_ids"].to(device) 36 | 37 | # enable gradient synchronization for the last micro-batch only 38 | if requires_grad_sync: 39 | model.require_backward_grad_sync = (i == dataloader.grad_acc_steps - 1) 40 | 41 | outputs = model(input_ids=input_ids) 42 | 43 | # compute the loss 44 | batch_size, seq_len = input_ids.shape 45 | target_ids = target_ids.reshape(-1) 46 | outputs = outputs.view(seq_len*batch_size, -1) 47 | loss = F.cross_entropy(outputs, target_ids, reduction='mean') / dataloader.grad_acc_steps 48 | 49 | loss.backward() 50 | 51 | acc_loss += loss.item() 52 | 53 | return acc_loss 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser(description="Training script for LLaMA model") 57 | 58 | # Environment arguments 59 | parser.add_argument("--omp_num_threads", type=str, default="1") 60 | parser.add_argument("--tokenizers_parallelism", type=str, default="false") 61 | 62 | # Model arguments 63 | parser.add_argument("--model_name", type=str, default="HuggingFaceTB/SmolLM-360M-Instruct") 64 | parser.add_argument("--num_hidden_layers", type=int, default=32) 65 | parser.add_argument("--num_attention_heads", type=int, default=16) 66 | parser.add_argument("--num_key_value_heads", type=int, default=4) 67 | 68 | # Dataset arguments 69 | parser.add_argument("--dataset_name", type=str, default="roneneldan/TinyStories") 70 | parser.add_argument("--num_workers", type=int, default=1) 71 | parser.add_argument("--num_proc", type=int, default=4) 72 | 73 | # Training arguments 74 | parser.add_argument("--seed", type=int, default=42) 75 | parser.add_argument("--learning_rate", type=float, default=3e-4) 76 | parser.add_argument("--seq_len", type=int, default=32) 77 | parser.add_argument("--micro_batch_size", type=int, default=1) 78 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 79 | parser.add_argument("--max_tokens", type=int, default=1e6) 80 | 81 | # Distributed training arguments 82 | parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size") 83 | parser.add_argument("--dp_size", type=int, default=1, help="Data Parallel size") 84 | parser.add_argument("--pp_size", type=int, default=1, help="Pipeline Parallel size") 85 | parser.add_argument("--pp_engine", type=str, default="afab", choices=["1f1b", "afab"]) 86 | 87 | # Logging arguments 88 | parser.add_argument("--run_name", type=str, default="default_run") 89 | parser.add_argument("--use_wandb", action="store_true") 90 | 91 | args = 
parser.parse_args() 92 | 93 | # Set environment variables 94 | os.environ["OMP_NUM_THREADS"] = args.omp_num_threads 95 | os.environ["TOKENIZERS_PARALLELISM"] = args.tokenizers_parallelism 96 | os.environ["DEVICE"] = "cuda" 97 | 98 | local_rank = int(os.environ["LOCAL_RANK"]) 99 | global_rank = int(os.environ["RANK"]) 100 | world_size = int(os.environ["WORLD_SIZE"]) 101 | backend = "nccl" 102 | torch.cuda.set_device(local_rank) 103 | device = torch.device("cuda", local_rank) 104 | dtype = torch.bfloat16 105 | 106 | dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=2)) 107 | setup_process_group_manager(dp_size=args.dp_size, pp_size=args.pp_size, tp_size=args.tp_size) 108 | 109 | is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.pp_is_last_stage 110 | set_all_seed(args.seed) 111 | 112 | if is_wandb_rank and args.use_wandb: 113 | wandb.init( 114 | project="picotron_tutorial", 115 | name=f"{args.run_name}_{pgm.process_group_manager}", 116 | config={ 117 | "tensor_parallel_size": pgm.process_group_manager.tp_world_size, 118 | "pipeline_parallel_size": pgm.process_group_manager.pp_world_size, 119 | "data_parallel_size": pgm.process_group_manager.dp_world_size, 120 | "model": args.model_name, 121 | "learning_rate": args.learning_rate, 122 | "seed": args.seed, 123 | }, 124 | ) 125 | 126 | model_config = AutoConfig.from_pretrained(args.model_name) 127 | model_config.num_hidden_layers = args.num_hidden_layers 128 | model_config.num_attention_heads = args.num_attention_heads 129 | model_config.num_key_value_heads = args.num_key_value_heads 130 | model_config.max_position_embeddings = args.seq_len 131 | 132 | model = Llama(config=model_config) 133 | 134 | if pgm.process_group_manager.tp_world_size > 1: 135 | model = apply_tensor_parallel(model) 136 | 137 | # Need to move the model to the device before wrapping it with DataParallel. 138 | # Otherwise, the hook will get attached to the CPU model and not the GPU model. 
139 | model.to(dtype).to(device) 140 | 141 | if pgm.process_group_manager.dp_world_size > 1: 142 | model = DataParallelBucket(model) 143 | 144 | model.train() 145 | 146 | dist.barrier() 147 | 148 | optimizer = AdamW(model.parameters(), lr=args.learning_rate) 149 | 150 | dist.barrier() 151 | 152 | # Create dataloader 153 | dataloader = MicroBatchDataLoader( 154 | seq_len=args.seq_len, 155 | micro_batch_size=args.micro_batch_size, 156 | grad_acc_steps=args.gradient_accumulation_steps, 157 | dataset_name=args.dataset_name, 158 | tokenizer_name=args.model_name, 159 | max_tokens=args.max_tokens, 160 | num_workers=args.num_workers, 161 | num_proc=args.num_proc, 162 | seed=args.seed, 163 | ) 164 | 165 | tokens_per_step = dataloader.global_batch_size * args.seq_len 166 | if pgm.process_group_manager.global_rank == 0: 167 | print("Tokens per step:", to_readable_format(tokens_per_step), is_print_rank=is_wandb_rank) 168 | 169 | trained_token, step = 0, 0 170 | 171 | dist.barrier() 172 | 173 | # Training loop 174 | while trained_token < args.max_tokens: 175 | 176 | step_start_time = time.time() 177 | optimizer.zero_grad() 178 | 179 | loss = train_step(model, dataloader, device) 180 | 181 | optimizer.step() 182 | 183 | step_duration = time.time() - step_start_time 184 | trained_token += tokens_per_step 185 | step += 1 186 | 187 | # In DDP implementation, we need to reset the gradient buffers 188 | if hasattr(model, 'reset'): 189 | model.reset() 190 | 191 | print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, " 192 | f"Global batch size (with seq_len): {to_readable_format(tokens_per_step)}, " 193 | f"Tokens/s: {to_readable_format(tokens_per_step / step_duration)}, " 194 | f"Tokens/s/GPU: {to_readable_format(tokens_per_step / step_duration / world_size)}, " 195 | f"Tokens: {to_readable_format(trained_token)}{('/' + to_readable_format(args.max_tokens))}, " 196 | f"Memory usage: {torch.cuda.memory_reserved() / 1e9:.2f}GB" 197 | , is_print_rank=is_wandb_rank 198 | ) 199 | 200 | if is_wandb_rank and args.use_wandb: 201 | wandb.log({"loss": loss, "tokens_per_step": tokens_per_step, "tokens_per_second": tokens_per_step / step_duration,\ 202 | "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": tokens_per_step}) 203 | 204 | if is_wandb_rank and args.use_wandb: 205 | wandb.finish() 206 | 207 | dist.destroy_process_group() -------------------------------------------------------------------------------- /step6_data_parallel_bucket/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import builtins 5 | import fcntl 6 | 7 | def print(*args, is_print_rank=True, **kwargs): 8 | """ solves multi-process interleaved print problem """ 9 | if not is_print_rank: return 10 | with open(__file__, "r") as fh: 11 | fcntl.flock(fh, fcntl.LOCK_EX) 12 | try: 13 | builtins.print(*args, **kwargs) 14 | finally: 15 | fcntl.flock(fh, fcntl.LOCK_UN) 16 | 17 | def set_all_seed(seed): 18 | for module in [random, np.random]: module.seed(seed) 19 | torch.manual_seed(seed) 20 | if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) 21 | 22 | def to_readable_format(num, precision=3): 23 | num_str = str(num) 24 | length = len(num_str) 25 | 26 | def format_with_precision(main, decimal, suffix): 27 | if precision == 0: 28 | return f"{main}{suffix}" 29 | return f"{main}.{decimal[:precision]}{suffix}" 30 | 31 | if length > 12: # Trillions 32 | return format_with_precision(num_str[:-12], 
num_str[-12:], 'T') 33 | elif length > 9: # Billions 34 | return format_with_precision(num_str[:-9], num_str[-9:], 'B') 35 | elif length > 6: # Millions 36 | return format_with_precision(num_str[:-6], num_str[-6:], 'M') 37 | elif length > 3: # Thousands 38 | return format_with_precision(num_str[:-3], num_str[-3:], 'K') 39 | else: 40 | return num_str -------------------------------------------------------------------------------- /step7_pipeline_parallel_afab/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, DistributedSampler 3 | import numpy as np 4 | from functools import partial 5 | from datasets import Features, Sequence, Value, load_dataset 6 | from transformers import AutoTokenizer 7 | 8 | import process_group_manager as pgm 9 | 10 | class MicroBatchDataLoader(DataLoader): 11 | def __init__(self, seq_len, micro_batch_size, grad_acc_steps, dataset_name, tokenizer_name, max_tokens, num_workers, num_proc, seed, split="train"): 12 | 13 | self.micro_batch_size = micro_batch_size 14 | self.grad_acc_steps = grad_acc_steps 15 | self.seq_len = seq_len 16 | 17 | self.global_batch_size = micro_batch_size * grad_acc_steps * pgm.process_group_manager.dp_world_size 18 | 19 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 20 | self.dataset = load_dataset(dataset_name, split=split) 21 | 22 | # Tokenize and chunk the dataset 23 | self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_len, num_proc) 24 | 25 | total_tokens = self.tokenized_dataset.num_rows * (self.seq_len + 1) 26 | assert total_tokens >= max_tokens, f"Not enough tokens. Have {total_tokens} tokens but need {max_tokens} tokens" 27 | 28 | self.sampler = DistributedSampler( 29 | self.tokenized_dataset, 30 | num_replicas=pgm.process_group_manager.dp_world_size, 31 | rank=pgm.process_group_manager.dp_rank, 32 | seed=seed, 33 | shuffle=False 34 | ) 35 | 36 | super().__init__( 37 | self.tokenized_dataset, 38 | batch_size=micro_batch_size, 39 | collate_fn=self.collate_batch, 40 | pin_memory=True, 41 | num_workers=num_workers, 42 | sampler=self.sampler, 43 | shuffle=False, 44 | ) 45 | 46 | def tokenizer_group_text(self, examples, tokenizer, sequence_length): 47 | """Tokenize a list of texts and group them in chunks of sequence_length + 1""" 48 | tokenized_text_batch = tokenizer.batch_encode_plus( 49 | examples, 50 | return_attention_mask=False, 51 | return_token_type_ids=False, 52 | return_tensors='np' 53 | ) 54 | concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])} 55 | total_length = len(concatenated_tokens['input_ids']) 56 | 57 | if total_length >= sequence_length + 1: 58 | total_length = ((total_length - 1) // sequence_length) * sequence_length + 1 59 | 60 | result = { 61 | 'input_ids': [ 62 | concatenated_tokens['input_ids'][i : i + sequence_length + 1] 63 | for i in range(0, total_length - sequence_length, sequence_length) 64 | ] 65 | } 66 | return result 67 | 68 | def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc): 69 | """Tokenize the dataset and group texts in chunks of sequence_length + 1""" 70 | tokenizer_func = partial( 71 | self.tokenizer_group_text, 72 | tokenizer=self.tokenizer, 73 | sequence_length=sequence_length 74 | ) 75 | 76 | tokenized_dataset = dataset.map( 77 | tokenizer_func, 78 | input_columns=text_column_name, 79 | remove_columns=dataset.column_names, 80 | features=Features({ 81 | "input_ids": 
Sequence(feature=Value(dtype="int64"), length=sequence_length + 1) 82 | }), 83 | batched=True, 84 | num_proc=num_proc, 85 | load_from_cache_file=True, # Preprocess dataset only once and cache it 86 | desc=f"Grouping texts in chunks of {sequence_length+1}", 87 | ) 88 | 89 | return tokenized_dataset 90 | 91 | def collate_batch(self, batch): 92 | batch_input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch]) 93 | batch_size = batch_input_ids.size(0) 94 | input_ids = batch_input_ids[:, :-1].contiguous() 95 | target_ids = batch_input_ids[:, 1:].contiguous() 96 | position_ids = torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous() 97 | attn_mask = torch.tril(torch.ones((self.seq_len, self.seq_len), dtype=torch.bool)) 98 | attn_mask = attn_mask.unsqueeze(0).expand(batch_size, -1, -1).contiguous() 99 | 100 | return { 101 | "input_ids": input_ids, 102 | "target_ids": target_ids, 103 | "position_ids": position_ids, 104 | "attn_mask": attn_mask, 105 | "hidden_states": None 106 | } 107 | 108 | def __iter__(self): 109 | if self._iterator is None: 110 | self._iterator = super().__iter__() 111 | return self 112 | 113 | def __next__(self): 114 | if self._iterator is None: 115 | self._iterator = super().__iter__() 116 | try: 117 | batch = next(self._iterator) 118 | except StopIteration: 119 | self._iterator = None 120 | raise StopIteration 121 | return batch -------------------------------------------------------------------------------- /step7_pipeline_parallel_afab/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from flash_attn.flash_attn_interface import flash_attn_func 5 | from flash_attn.layers.rotary import apply_rotary_emb 6 | from flash_attn.ops.triton.layer_norm import layer_norm_fn 7 | import process_group_manager as pgm 8 | 9 | def flash_attention(q, k, v, causal = True): 10 | q = q.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 11 | k = k.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 12 | v = v.permute(0, 2, 1, 3) # [batch_size, seq_length, num_head , head_dim] 13 | return flash_attn_func(q, k, v, causal=causal) 14 | 15 | def get_cos_sin(seq_length, head_dim, base=500000.0): 16 | assert head_dim%2==0 17 | # Results on CUDA and CPU are different even with the same formula, To match transformers implementation. frequency should be computed on CPU 18 | theta = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float().to('cpu') / head_dim)) 19 | dtype = torch.bfloat16 20 | device = torch.device('cuda') 21 | position = torch.arange(seq_length).to(device).unsqueeze(1).float() # [seq_length, 1] 22 | # To match transformers implementation. 
m * theta should be computed on GPU 23 | theta = theta.to(device) 24 | return torch.cos(position.float()*theta.float()).to(dtype).repeat(1,2), torch.sin(position.float()*theta.float()).to(dtype).repeat(1,2) # [seq_length, head_dim], [seq_length, head_dim] 25 | 26 | class TritonRMSNorm(nn.Module): 27 | def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): 28 | super().__init__() 29 | self.eps = eps 30 | self.weight = nn.Parameter(torch.ones(hidden_size)) 31 | self.register_parameter("bias", None) 32 | 33 | def forward( 34 | self, hidden_states, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False 35 | ): 36 | return layer_norm_fn( 37 | hidden_states, 38 | self.weight, 39 | None, 40 | residual=residual, 41 | eps=self.eps, 42 | dropout_p=dropout_p, 43 | prenorm=prenorm, 44 | residual_in_fp32=residual_in_fp32, 45 | is_rms_norm=True, 46 | return_dropout_mask=return_dropout_mask, 47 | ) 48 | 49 | class Attention(nn.Module): 50 | def __init__(self, config, layer_idx): 51 | super().__init__() 52 | self.hidden_size = config.hidden_size 53 | self.num_heads = config.num_attention_heads 54 | self.num_key_values = config.num_key_value_heads 55 | self.head_dim = self.hidden_size//self.num_heads 56 | assert config.num_attention_heads % pgm.process_group_manager.tp_world_size == 0, "num_attention_heads should be divisible by tp world size" 57 | assert config.num_key_value_heads % pgm.process_group_manager.tp_world_size == 0, "num_key_value_heads should be divisible by tp world size" 58 | self.num_local_heads = config.num_attention_heads // pgm.process_group_manager.tp_world_size # TP parallelism 59 | self.num_local_kv_heads = config.num_key_value_heads // pgm.process_group_manager.tp_world_size # TP parallelism 60 | 61 | 62 | self.q_proj = nn.Linear(config.hidden_size, self.num_heads*self.head_dim, bias=False) 63 | self.k_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False) 64 | self.v_proj = nn.Linear(config.hidden_size, self.num_key_values*self.head_dim, bias=False) 65 | self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) 66 | self.layer_idx = layer_idx 67 | 68 | def forward(self, x, cos, sin, attention_mask=None, position_ids=None): 69 | batch_size, seq_length, hidden_dim = x.size() 70 | q = self.q_proj(x) # [batch_size, seq_length, num_heads*head_dim] 71 | k = self.k_proj(x) # [batch_size, seq_length, num_key_values*head_dim] 72 | v = self.v_proj(x) # [batch_size, seq_length, num_key_values*head_dim] 73 | 74 | q = q.view(batch_size, seq_length, self.num_local_heads, self.head_dim) # [batch_size, seq_length, num_heads, head_dim] 75 | k = k.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim) # [batch_size, seq_length, num_key_values, head_dim] 76 | q = apply_rotary_emb(q,cos[:, :self.head_dim // 2], sin[:, :self.head_dim // 2],interleaved=False) # [batch_size, seq_length, num_heads, head_dim] 77 | k = apply_rotary_emb(k,cos[:, :self.head_dim // 2], sin[:, :self.head_dim // 2],interleaved=False) # [batch_size, seq_length, num_key_values, head_dim] 78 | q = q.transpose(1, 2) # [batch_size, num_heads, seq_length, head_dim] 79 | k = k.transpose(1, 2) # [batch_size, num_key_values, seq_length, head_dim] 80 | v = v.view(batch_size, seq_length, self.num_local_kv_heads, self.head_dim).transpose(1,2) # [batch_size, num_key_values, seq_length, head_dim] 81 | 82 | k = k.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, 
head_dim] 83 | v = v.repeat_interleave(self.num_local_heads // self.num_local_kv_heads, dim=1) # [batch_size, num_heads, seq_length, head_dim] 84 | 85 | causal = True if q.size(2) == k.size(2) else False # During the decoding phase, the length of q is usually 1. 86 | 87 | out = flash_attention(q, k, v, causal = causal) # [batch_size, seq_length, num_heads, head_dim] 88 | 89 | out = out.reshape(batch_size, seq_length, self.num_local_heads * self.head_dim) # [batch_size, seq_length, hidden_dim] 90 | out = self.out_proj(out) # [batch_size, seq_length, hidden_dim] 91 | return out 92 | 93 | class MLP(nn.Module): 94 | def __init__(self, config) -> None: 95 | super().__init__() 96 | self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) 97 | self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) 98 | self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) 99 | 100 | def forward(self, x): 101 | # TODO: avoid single-line operations; they are harder to debug 102 | return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) 103 | 104 | class DecoderLayer(nn.Module): 105 | # TritonRMSNorm -> Attention -> Residual -> TritonRMSNorm -> MLP -> Residual 106 | def __init__(self, config, layer_idx): 107 | super().__init__() 108 | self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 109 | self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 110 | self.attention = Attention(config, layer_idx = layer_idx) 111 | self.mlp = MLP(config) 112 | self.layer_idx = layer_idx 113 | head_dim = config.hidden_size // config.num_attention_heads 114 | self.cos, self.sin = get_cos_sin(config.max_position_embeddings, head_dim=head_dim , base=config.rope_theta) # [max_position_embeddings, head_dim] 115 | 116 | def forward(self, x, attention_mask = None, position_ids = None): 117 | cos, sin = self.cos, self.sin 118 | x = x + self.attention(self.input_layernorm(x), cos, sin, attention_mask, position_ids) # Attention 119 | x = x + self.mlp(self.post_attention_layernorm(x)) # MLP 120 | return x 121 | 122 | class Llama(nn.Module): 123 | def __init__(self, config) -> None: 124 | super().__init__() 125 | # sanity check 126 | assert config.hidden_size % config.num_attention_heads==0 127 | assert config.num_attention_heads % config.num_key_value_heads==0 128 | 129 | # params 130 | self.vocab_size = config.vocab_size 131 | self.hidden_size = config.hidden_size 132 | self.num_heads = config.num_attention_heads 133 | self.num_key_values = config.num_key_value_heads 134 | self.head_dim = self.hidden_size//self.num_heads 135 | self.max_position_embeddings = config.max_position_embeddings 136 | self.num_layers = config.num_hidden_layers 137 | self.model_config = config 138 | 139 | # modules 140 | self.embedding = nn.Embedding(self.vocab_size, self.hidden_size) 141 | self.decoder_layers = nn.ModuleList([DecoderLayer(config,layer_idx = i) for i in range(self.num_layers)]) 142 | self.final_proj = nn.Linear(self.hidden_size, self.vocab_size, bias=False) 143 | self.final_norm = TritonRMSNorm(self.hidden_size, eps=config.rms_norm_eps) 144 | 145 | def forward(self, input_ids, attention_mask=None, position_ids: torch.Tensor = None): 146 | x = self.embedding(input_ids) 147 | for layer in self.decoder_layers: 148 | x = layer(x) # [batch_size, seq_length, hidden_dim] 149 | x = self.final_norm(x) 150 | logits = self.final_proj(x) 151 | 152 | return logits # [batch_size, seq_length, vocab_size]
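The grouped-query attention bookkeeping in Attention.forward above (fewer key/value heads than query heads, expanded with repeat_interleave) is easiest to follow with concrete shapes. The snippet below is an illustrative, self-contained sketch in plain PyTorch with made-up sizes; it uses an eager-mode attention product instead of flash_attn and is not part of the repository:

import torch

batch_size, seq_len, num_heads, num_kv_heads, head_dim = 2, 8, 4, 2, 16

q = torch.randn(batch_size, num_heads, seq_len, head_dim)     # query heads
k = torch.randn(batch_size, num_kv_heads, seq_len, head_dim)  # fewer key heads
v = torch.randn(batch_size, num_kv_heads, seq_len, head_dim)  # fewer value heads

# Each key/value head serves num_heads // num_kv_heads query heads, so expand them
# along the head dimension, exactly like the repeat_interleave calls above.
k = k.repeat_interleave(num_heads // num_kv_heads, dim=1)     # [2, 4, 8, 16]
v = v.repeat_interleave(num_heads // num_kv_heads, dim=1)     # [2, 4, 8, 16]

scores = (q @ k.transpose(-2, -1)) / head_dim ** 0.5          # [2, 4, 8, 8]
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
scores = scores.masked_fill(~causal_mask, float("-inf"))      # causal masking
out = scores.softmax(dim=-1) @ v                              # [2, 4, 8, 16]
print(out.shape)                                              # torch.Size([2, 4, 8, 16])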
-------------------------------------------------------------------------------- /step7_pipeline_parallel_afab/pipeline_parallel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.distributed as dist 6 | 7 | import process_group_manager as pgm 8 | 9 | ### begin PP communications 10 | STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1" 11 | def pipeline_communicate(operation, device, dtype, tensor=None, shapes=None): 12 | """ 13 | Handles point-to-point communication between pipeline stages for forward and backward passes. 14 | 15 | Args: 16 | operation (str): Type of communication operation ('recv_forward', 'send_forward', 17 | 'recv_backward', 'send_backward') 18 | device: Target device for tensor operations (e.g., CPU, GPU) 19 | dtype: Data type for tensors 20 | tensor: Input tensor for send operations (default: None) 21 | shapes: Shape specifications for receiving tensors (default: None) 22 | 23 | Returns: 24 | torch.Tensor or None: Received tensor for receive operations, None for send operations 25 | """ 26 | global STEP 27 | global VERBOSE 28 | 29 | if operation == 'recv_forward': 30 | # Skip if this is the first pipeline stage (nothing to receive) 31 | if pgm.process_group_manager.pp_is_first_stage: return None 32 | # Create empty tensor to receive data 33 | tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype) 34 | src = pgm.process_group_manager.pp_prev_rank 35 | 36 | elif operation == 'send_forward': 37 | # Skip if this is the last pipeline stage (nothing to send forward) 38 | if pgm.process_group_manager.pp_is_last_stage: return 39 | dest = pgm.process_group_manager.pp_next_rank 40 | 41 | elif operation == 'recv_backward': 42 | # Skip if this is the last pipeline stage (nothing to receive from backward) 43 | if pgm.process_group_manager.pp_is_last_stage: return None 44 | tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype) 45 | src = pgm.process_group_manager.pp_next_rank 46 | 47 | elif operation == 'send_backward': 48 | # Skip if this is the first pipeline stage (nothing to send backward) 49 | if pgm.process_group_manager.pp_is_first_stage: return 50 | dest = pgm.process_group_manager.pp_prev_rank 51 | 52 | # Determine if this is a send operation and set peer rank 53 | is_send = operation.startswith('send') 54 | peer_rank = dest if is_send else src 55 | 56 | # Create P2P operation (send or receive) 57 | op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank) 58 | 59 | if VERBOSE: 60 | print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} " 61 | f"{pgm.process_group_manager.pp_rank} {'→' if is_send else '←'} {peer_rank} | " 62 | f"STEP:{STEP} | RANK:{pgm.process_group_manager.pp_rank}", flush=True) 63 | 64 | # Execute communication operation and wait for completion 65 | [req.wait() for req in dist.batch_isend_irecv([op])] 66 | torch.cuda.synchronize() 67 | 68 | if VERBOSE: STEP += 1 69 | 70 | # Return received tensor for receive operations, None for send operations 71 | return tensor if not is_send else None 72 | ### end PP communications 73 | 74 | ### begin Pipeline Parallel 75 | class PipelineParallel(nn.Module): 76 | def __init__(self, model, config): 77 | super().__init__() 78 | layer_distribution = self.distribute_layers(config.num_hidden_layers) 79 | self.embedding = model.embedding if pgm.process_group_manager.pp_is_first_stage 
else nn.Identity() 80 | self.decoder_layers = nn.ModuleDict({str(i): model.decoder_layers[i] for i in layer_distribution}) 81 | self.final_norm = model.final_norm if pgm.process_group_manager.pp_is_last_stage else nn.Identity() 82 | self.final_proj = model.final_proj if pgm.process_group_manager.pp_is_last_stage else nn.Identity() 83 | 84 | def distribute_layers(self, num_layers): 85 | layers_per_gpu = [num_layers // pgm.process_group_manager.pp_world_size + (1 if i < num_layers % pgm.process_group_manager.pp_world_size else 0) for i in range(pgm.process_group_manager.pp_world_size)] 86 | start_layer = sum(layers_per_gpu[:pgm.process_group_manager.pp_rank]) 87 | return list(range(start_layer, start_layer + layers_per_gpu[pgm.process_group_manager.pp_rank])) 88 | 89 | def forward(self, input_ids, position_ids, hidden_states): 90 | x = hidden_states if hidden_states is not None else input_ids 91 | x = self.embedding(x) 92 | for layer in self.decoder_layers.values(): 93 | x = layer(x, position_ids=position_ids) 94 | x = self.final_norm(x) 95 | return self.final_proj(x) 96 | 97 | def backward(self, input_tensor, output_tensor, output_tensor_grad): 98 | if input_tensor is not None: input_tensor.retain_grad() 99 | if output_tensor_grad is None: 100 | output_tensor_grad = torch.ones_like(output_tensor, memory_format=torch.preserve_format) 101 | # torch.autograd.backward will automatically accumulate gradients in the leaf tensors (cf: https://pytorch.org/docs/stable/generated/torch.autograd.backward.html) 102 | torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad, retain_graph=False, create_graph=False) 103 | return input_tensor.grad if input_tensor is not None else None 104 | 105 | def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype): 106 | """ 107 | Executes a training step using All-Forward-All-Backward (AFAB) pipeline parallelism. 108 | Runs the forward passes for all micro-batches first, then all the backward passes; this keeps the schedule simple but requires keeping the activations of every micro-batch alive until its backward pass.
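    For example, with 2 stages and 3 micro-batches, and assuming forward and backward each take
    one unit of time, the schedule looks roughly like this (F = forward, B = backward, . = bubble):

        stage 0: F0 F1 F2 .  .  B0 B1 B2
        stage 1: .  F0 F1 F2 B0 B1 B2 .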
109 | """ 110 | logging_loss: torch.float32 = 0.0 111 | input_tensors, output_tensors = [], [] 112 | requires_grad_sync = pgm.process_group_manager.dp_world_size > 1 113 | 114 | # === All Forward Pass Phase === 115 | for _ in range(data_loader.grad_acc_steps): 116 | input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype) 117 | batch = next(data_loader) 118 | batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor 119 | output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"]) 120 | pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=dtype) 121 | 122 | # calculate loss on the last stage 123 | if pgm.process_group_manager.pp_is_last_stage: 124 | output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean') 125 | logging_loss += output_tensor.item() / data_loader.grad_acc_steps 126 | 127 | input_tensors.append(input_tensor) 128 | output_tensors.append(output_tensor) 129 | 130 | # === All Backward Pass Phase === 131 | for ith_microbatch in range(data_loader.grad_acc_steps): 132 | if requires_grad_sync: 133 | is_last_iteration = (ith_microbatch == data_loader.grad_acc_steps - 1) 134 | model.require_backward_grad_sync = is_last_iteration 135 | output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype) 136 | input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0) 137 | input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad) 138 | pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=dtype) 139 | 140 | return logging_loss 141 | 142 | ### end Pipeline Parallel -------------------------------------------------------------------------------- /step7_pipeline_parallel_afab/process_group_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | 5 | class ProcessGroupManager: 6 | def __init__(self, dp_size, pp_size, tp_size): 7 | self.global_rank = dist.get_rank() 8 | self.world_size = dist.get_world_size() 9 | self.local_rank = int(os.environ.get("LOCAL_RANK", self.global_rank % self.world_size)) 10 | 11 | assert self.world_size == dp_size * pp_size * tp_size, f"World size ({self.world_size}) != DP ({self.dp_size}) * PP ({self.pp_size}) * TP ({self.tp_size})" 12 | 13 | self.grid = torch.arange(self.world_size).view(dp_size, pp_size, tp_size) # DP * PP * TP grid 14 | # Find the position of the current process in the grid 15 | self.dp_rank, self.pp_rank, self.tp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist() 16 | 17 | # Process group creation - Update indexing to match new grid order 18 | self.tp_group = dist.new_subgroups_by_enumeration([self.grid[d, p, :].tolist() for d in range(dp_size) for p in range(pp_size)])[0] 19 | self.pp_group = dist.new_subgroups_by_enumeration([self.grid[d, :, t].tolist() for d in range(dp_size) for t in range(tp_size)])[0] 20 | self.dp_group = dist.new_subgroups_by_enumeration([self.grid[:, p, t].tolist() for p in range(pp_size) for t in range(tp_size)])[0] 21 | self.pp_dp_group = dist.new_subgroups_by_enumeration([self.grid[:, :, t].flatten().tolist() for t in range(tp_size)])[0] 22 | 23 | self.world_group = dist.group.WORLD 24 
| 25 | # Update group IDs with new grid ordering 26 | self.tp_group_ids = self.grid[self.dp_rank, self.pp_rank, :].tolist() 27 | self.pp_group_ids = self.grid[self.dp_rank, :, self.tp_rank].tolist() 28 | self.dp_group_ids = self.grid[:, self.pp_rank, self.tp_rank].tolist() 29 | 30 | # Tensor parallelism 31 | self.tp_world_size = dist.get_world_size(group=self.tp_group) 32 | self.tp_first_rank = self.tp_group_ids[0] 33 | self.tp_last_rank = self.tp_group_ids[-1] 34 | 35 | # Pipeline parallelism 36 | self.pp_world_size = dist.get_world_size(group=self.pp_group) 37 | self.pp_first_rank = self.pp_group_ids[0] 38 | self.pp_last_rank = self.pp_group_ids[-1] 39 | self.pp_is_first_stage = self.pp_rank == 0 40 | self.pp_is_last_stage = self.pp_rank == self.pp_world_size - 1 41 | self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else int(self.grid[self.dp_rank, self.pp_rank + 1, self.tp_rank].item()) 42 | self.pp_prev_rank = None if self.pp_rank == 0 else int(self.grid[self.dp_rank, self.pp_rank - 1, self.tp_rank].item()) 43 | 44 | # Data parallelism 45 | self.dp_world_size = dist.get_world_size(group=self.dp_group) 46 | self.dp_first_rank = self.dp_group_ids[0] 47 | self.dp_last_rank = self.dp_group_ids[-1] 48 | 49 | def __str__(self): 50 | return f"DP({self.dp_world_size})-PP({self.pp_world_size})-TP({self.tp_world_size})-Rank({self.global_rank})" 51 | 52 | def setup_process_group_manager(dp_size, pp_size, tp_size): 53 | global process_group_manager 54 | process_group_manager = ProcessGroupManager(dp_size, pp_size, tp_size) -------------------------------------------------------------------------------- /step7_pipeline_parallel_afab/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import builtins 5 | import fcntl 6 | 7 | def print(*args, is_print_rank=True, **kwargs): 8 | """ solves multi-process interleaved print problem """ 9 | if not is_print_rank: return 10 | with open(__file__, "r") as fh: 11 | fcntl.flock(fh, fcntl.LOCK_EX) 12 | try: 13 | builtins.print(*args, **kwargs) 14 | finally: 15 | fcntl.flock(fh, fcntl.LOCK_UN) 16 | 17 | def set_all_seed(seed): 18 | for module in [random, np.random]: module.seed(seed) 19 | torch.manual_seed(seed) 20 | if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) 21 | 22 | def to_readable_format(num, precision=3): 23 | num_str = str(num) 24 | length = len(num_str) 25 | 26 | def format_with_precision(main, decimal, suffix): 27 | if precision == 0: 28 | return f"{main}{suffix}" 29 | return f"{main}.{decimal[:precision]}{suffix}" 30 | 31 | if length > 12: # Trillions 32 | return format_with_precision(num_str[:-12], num_str[-12:], 'T') 33 | elif length > 9: # Billions 34 | return format_with_precision(num_str[:-9], num_str[-9:], 'B') 35 | elif length > 6: # Millions 36 | return format_with_precision(num_str[:-6], num_str[-6:], 'M') 37 | elif length > 3: # Thousands 38 | return format_with_precision(num_str[:-3], num_str[-3:], 'K') 39 | else: 40 | return num_str -------------------------------------------------------------------------------- /step8_pipeline_parallel_1f1b/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, DistributedSampler 3 | import numpy as np 4 | from functools import partial 5 | from datasets import Features, Sequence, Value, load_dataset 6 | from transformers import AutoTokenizer 7 | 8 | import 
process_group_manager as pgm 9 | 10 | class MicroBatchDataLoader(DataLoader): 11 | def __init__(self, seq_len, micro_batch_size, grad_acc_steps, dataset_name, tokenizer_name, max_tokens, num_workers, num_proc, seed, split="train"): 12 | 13 | self.micro_batch_size = micro_batch_size 14 | self.grad_acc_steps = grad_acc_steps 15 | self.seq_len = seq_len 16 | 17 | self.global_batch_size = micro_batch_size * grad_acc_steps * pgm.process_group_manager.dp_world_size 18 | 19 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 20 | self.dataset = load_dataset(dataset_name, split=split) 21 | 22 | # Tokenize and chunk the dataset 23 | self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_len, num_proc) 24 | 25 | total_tokens = self.tokenized_dataset.num_rows * (self.seq_len + 1) 26 | assert total_tokens >= max_tokens, f"Not enough tokens. Have {total_tokens} tokens but need {max_tokens} tokens" 27 | 28 | self.sampler = DistributedSampler( 29 | self.tokenized_dataset, 30 | num_replicas=pgm.process_group_manager.dp_world_size, 31 | rank=pgm.process_group_manager.dp_rank, 32 | seed=seed, 33 | shuffle=False 34 | ) 35 | 36 | super().__init__( 37 | self.tokenized_dataset, 38 | batch_size=micro_batch_size, 39 | collate_fn=self.collate_batch, 40 | pin_memory=True, 41 | num_workers=num_workers, 42 | sampler=self.sampler, 43 | shuffle=False, 44 | ) 45 | 46 | def tokenizer_group_text(self, examples, tokenizer, sequence_length): 47 | """Tokenize a list of texts and group them in chunks of sequence_length + 1""" 48 | tokenized_text_batch = tokenizer.batch_encode_plus( 49 | examples, 50 | return_attention_mask=False, 51 | return_token_type_ids=False, 52 | return_tensors='np' 53 | ) 54 | concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])} 55 | total_length = len(concatenated_tokens['input_ids']) 56 | 57 | if total_length >= sequence_length + 1: 58 | total_length = ((total_length - 1) // sequence_length) * sequence_length + 1 59 | 60 | result = { 61 | 'input_ids': [ 62 | concatenated_tokens['input_ids'][i : i + sequence_length + 1] 63 | for i in range(0, total_length - sequence_length, sequence_length) 64 | ] 65 | } 66 | return result 67 | 68 | def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc): 69 | """Tokenize the dataset and group texts in chunks of sequence_length + 1""" 70 | tokenizer_func = partial( 71 | self.tokenizer_group_text, 72 | tokenizer=self.tokenizer, 73 | sequence_length=sequence_length 74 | ) 75 | 76 | tokenized_dataset = dataset.map( 77 | tokenizer_func, 78 | input_columns=text_column_name, 79 | remove_columns=dataset.column_names, 80 | features=Features({ 81 | "input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1) 82 | }), 83 | batched=True, 84 | num_proc=num_proc, 85 | load_from_cache_file=True, # Preprocess dataset only once and cache it 86 | desc=f"Grouping texts in chunks of {sequence_length+1}", 87 | ) 88 | 89 | return tokenized_dataset 90 | 91 | def collate_batch(self, batch): 92 | batch_input_ids = torch.stack([torch.tensor(item['input_ids']) for item in batch]) 93 | batch_size = batch_input_ids.size(0) 94 | input_ids = batch_input_ids[:, :-1].contiguous() 95 | target_ids = batch_input_ids[:, 1:].contiguous() 96 | position_ids = torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0).expand(batch_size, -1).contiguous() 97 | attn_mask = torch.tril(torch.ones((self.seq_len, self.seq_len), dtype=torch.bool)) 98 | attn_mask = attn_mask.unsqueeze(0).expand(batch_size, 
-1, -1).contiguous() 99 | 100 | return { 101 | "input_ids": input_ids, 102 | "target_ids": target_ids, 103 | "position_ids": position_ids, 104 | "attn_mask": attn_mask, 105 | "hidden_states": None 106 | } 107 | 108 | def __iter__(self): 109 | if self._iterator is None: 110 | self._iterator = super().__iter__() 111 | return self 112 | 113 | def __next__(self): 114 | if self._iterator is None: 115 | self._iterator = super().__iter__() 116 | try: 117 | batch = next(self._iterator) 118 | except StopIteration: 119 | self._iterator = None 120 | raise StopIteration 121 | return batch -------------------------------------------------------------------------------- /step8_pipeline_parallel_1f1b/process_group_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | 5 | class ProcessGroupManager: 6 | def __init__(self, dp_size, pp_size, tp_size): 7 | self.global_rank = dist.get_rank() 8 | self.world_size = dist.get_world_size() 9 | self.local_rank = int(os.environ.get("LOCAL_RANK", self.global_rank % self.world_size)) 10 | 11 | assert self.world_size == dp_size * pp_size * tp_size, f"World size ({self.world_size}) != DP ({dp_size}) * PP ({pp_size}) * TP ({tp_size})" 12 | 13 | self.grid = torch.arange(self.world_size).view(dp_size, pp_size, tp_size) # DP * PP * TP grid 14 | # Find the position of the current process in the grid 15 | self.dp_rank, self.pp_rank, self.tp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist() 16 | 17 | # Process group creation - Update indexing to match new grid order 18 | self.tp_group = dist.new_subgroups_by_enumeration([self.grid[d, p, :].tolist() for d in range(dp_size) for p in range(pp_size)])[0] 19 | self.pp_group = dist.new_subgroups_by_enumeration([self.grid[d, :, t].tolist() for d in range(dp_size) for t in range(tp_size)])[0] 20 | self.dp_group = dist.new_subgroups_by_enumeration([self.grid[:, p, t].tolist() for p in range(pp_size) for t in range(tp_size)])[0] 21 | self.pp_dp_group = dist.new_subgroups_by_enumeration([self.grid[:, :, t].flatten().tolist() for t in range(tp_size)])[0] 22 | 23 | self.world_group = dist.group.WORLD 24 | 25 | # Update group IDs with new grid ordering 26 | self.tp_group_ids = self.grid[self.dp_rank, self.pp_rank, :].tolist() 27 | self.pp_group_ids = self.grid[self.dp_rank, :, self.tp_rank].tolist() 28 | self.dp_group_ids = self.grid[:, self.pp_rank, self.tp_rank].tolist() 29 | 30 | # Tensor parallelism 31 | self.tp_world_size = dist.get_world_size(group=self.tp_group) 32 | self.tp_first_rank = self.tp_group_ids[0] 33 | self.tp_last_rank = self.tp_group_ids[-1] 34 | 35 | # Pipeline parallelism 36 | self.pp_world_size = dist.get_world_size(group=self.pp_group) 37 | self.pp_first_rank = self.pp_group_ids[0] 38 | self.pp_last_rank = self.pp_group_ids[-1] 39 | self.pp_is_first_stage = self.pp_rank == 0 40 | self.pp_is_last_stage = self.pp_rank == self.pp_world_size - 1 41 | self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else int(self.grid[self.dp_rank, self.pp_rank + 1, self.tp_rank].item()) 42 | self.pp_prev_rank = None if self.pp_rank == 0 else int(self.grid[self.dp_rank, self.pp_rank - 1, self.tp_rank].item()) 43 | 44 | # Data parallelism 45 | self.dp_world_size = dist.get_world_size(group=self.dp_group) 46 | self.dp_first_rank = self.dp_group_ids[0] 47 | self.dp_last_rank = self.dp_group_ids[-1] 48 | 49 | def __str__(self): 50 | return
f"DP({self.dp_world_size})-PP({self.pp_world_size})-TP({self.tp_world_size})-Rank({self.global_rank})" 51 | 52 | def setup_process_group_manager(dp_size, pp_size, tp_size): 53 | global process_group_manager 54 | process_group_manager = ProcessGroupManager(dp_size, pp_size, tp_size) -------------------------------------------------------------------------------- /step8_pipeline_parallel_1f1b/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import builtins 5 | import fcntl 6 | 7 | def print(*args, is_print_rank=True, **kwargs): 8 | """ solves multi-process interleaved print problem """ 9 | if not is_print_rank: return 10 | with open(__file__, "r") as fh: 11 | fcntl.flock(fh, fcntl.LOCK_EX) 12 | try: 13 | builtins.print(*args, **kwargs) 14 | finally: 15 | fcntl.flock(fh, fcntl.LOCK_UN) 16 | 17 | def set_all_seed(seed): 18 | for module in [random, np.random]: module.seed(seed) 19 | torch.manual_seed(seed) 20 | if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) 21 | 22 | def to_readable_format(num, precision=3): 23 | num_str = str(num) 24 | length = len(num_str) 25 | 26 | def format_with_precision(main, decimal, suffix): 27 | if precision == 0: 28 | return f"{main}{suffix}" 29 | return f"{main}.{decimal[:precision]}{suffix}" 30 | 31 | if length > 12: # Trillions 32 | return format_with_precision(num_str[:-12], num_str[-12:], 'T') 33 | elif length > 9: # Billions 34 | return format_with_precision(num_str[:-9], num_str[-9:], 'B') 35 | elif length > 6: # Millions 36 | return format_with_precision(num_str[:-6], num_str[-6:], 'M') 37 | elif length > 3: # Thousands 38 | return format_with_precision(num_str[:-3], num_str[-3:], 'K') 39 | else: 40 | return num_str --------------------------------------------------------------------------------