├── 01_Introduction_to_LUMI └── README.md ├── 02_Using_the_LUMI_web_interface ├── Clone_with_JupyterLab.md ├── GPT-neo-IMDB-introduction.ipynb ├── README.md └── images │ ├── step0.png │ ├── step1.png │ └── step2.png ├── 03_Your_first_AI_training_job_on_LUMI ├── GPT-neo-IMDB-finetuning.py ├── README.md ├── images │ └── lumi_web_interface_edit_file.png ├── reference_solution │ ├── GPT-neo-IMDB-finetuning.py │ ├── resume_from_checkpoint │ │ ├── GPT-neo-IMDB-finetuning.py │ │ ├── run.sh │ │ └── util.py │ ├── run.sh │ └── util.py ├── run.sh └── util.py ├── 04_Understanding_GPU_activity_and_checking_jobs ├── README.md ├── images │ └── profile.png └── reference_solution │ └── GPT-neo-IMDB-finetuning-profile.py ├── 05_Running_containers_on_LUMI ├── Hello_LUMI_GPU_World.py ├── README.md ├── examples │ ├── build_ubuntu_tree.sh │ ├── print_directory_trees.sh │ └── ubuntu_tree.def └── reference_solution │ └── reference_solution.md ├── 06_Bulding_containers_from_conda_pip_environments ├── README.md ├── examples │ ├── PandasAI.yml │ ├── minimal_pytorch.yml │ └── python312.yml └── reference_solution │ ├── panopticapi.yml │ ├── python312_extra.yml │ └── reference_solution.md ├── 07_Extending_containers_with_virtual_environments_for_faster_testing ├── README.md └── examples │ └── extending_containers_with_venv.md ├── 08_Scaling_to_multiple_GPUs ├── GPT-neo-IMDB-finetuning.py ├── README.md ├── reference_solution │ ├── GPT-neo-IMDB-finetuning.py │ ├── prints_only_from_single_process │ │ ├── GPT-neo-IMDB-finetuning.py │ │ ├── run_no_torchrun.sh │ │ ├── run_torchrun.sh │ │ ├── slurm-9304946.out │ │ ├── slurm-9304949.out │ │ └── util.py │ ├── run_no_torchrun.sh │ ├── run_torchrun.sh │ ├── util.py │ └── with_cpu_bindings │ │ ├── GPT-neo-IMDB-finetuning.py │ │ ├── run_no_torchrun.sh │ │ ├── run_torchrun.sh │ │ └── util.py ├── run.sh └── util.py ├── 09_Extreme_scale_AI ├── README.md ├── images │ ├── profile-detail.png │ └── profile.png └── reference_solution │ └── README.md ├── 10_Coupling_AI_and_HPC └── README.md ├── LICENSE ├── README.md └── bonus_material ├── README.md └── exercise_container_recipes ├── README.md ├── build_pytorch_transformers.sh └── pytorch_transformers.yml /01_Introduction_to_LUMI/README.md: -------------------------------------------------------------------------------- 1 | # 01 Introduction to LUMI 2 | 3 | No other material than slides exists for this lecture. 4 | -------------------------------------------------------------------------------- /02_Using_the_LUMI_web_interface/Clone_with_JupyterLab.md: -------------------------------------------------------------------------------- 1 | # Cloning the course git repository using JupyterLab UI 2 | 3 | 1. Open a JupyterLab session using the Jupyter app on the LUMI web interface [www.lumi.csc.fi](https://www.lumi.csc.fi) 4 | 5 | Follow the instructions in the second part of the exercise for this session. You can then keep using the session 6 | for the rest of the exercise. 7 | 8 | 2. Once you have opened JupyterLab and opened your own folder in the navigation panel to the left, your browser should present a view like this (in this case for user `lukaspre`): 9 | 10 | ![After starting JupyterLab and opening your own folder, the navigation panel shows an empty list and the main screen a selection of apps to use in JupyterLab.](images/step0.png) 11 | 12 | 4. 
Use the highlighted button to open the UI popup for cloning a git repository: 13 | 14 | ![The button for cloning a git repository is in the top-left corner, just above the file search input.](images/step1.png) 15 | 16 | 5. Enter the repository URL ( [https://github.com/Lumi-supercomputer/Getting_Started_with_AI_workshop](https://github.com/Lumi-supercomputer/Getting_Started_with_AI_workshop) ) and press the "Clone" button. 17 | 18 | ![The repository URL should be entered in the opening popup.](images/step2.png) 19 | 20 | This will clone the respository in a new folder "Getting_Started_with_AI_workshop" in your directory on the course project scratch filesystem. 21 | -------------------------------------------------------------------------------- /02_Using_the_LUMI_web_interface/GPT-neo-IMDB-introduction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "editable": true, 7 | "slideshow": { 8 | "slide_type": "" 9 | }, 10 | "tags": [] 11 | }, 12 | "source": [ 13 | "# Introduction to the Ongoing Example\n", 14 | "\n", 15 | "In this notebook you will get to know the example machine learning task we will consider for most of the exercises throughout the course: We will finetune the GPT-neo language model by EleutherAI on the Stanford IMDb movie review data set to obtain a model specialised in generating movie reviews.\n", 16 | "\n", 17 | "Since both the model and the data set are availabe from huggingface.co, we will use the libraries provided by HuggingFace, which present a slightly higher level abstraction of training with PyTorch.\n", 18 | "\n", 19 | "This notebook does not yet perform any training but demonstrates loading the model and allows you to perform inference, i.e., generating some text with it. It also loads the training data set for you to explore." 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "editable": true, 26 | "slideshow": { 27 | "slide_type": "" 28 | }, 29 | "tags": [] 30 | }, 31 | "source": [ 32 | "We begin by loading the required Python modules, but before that we first need to set environment variable to point to a shared cache directory which the `transformers` library uses when loading the model, so it does not have to download the same model repeatedly:" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "import os\n", 42 | "os.environ[\"HF_HOME\"] = \"/flash/project_465001958/hf-cache\"" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "import torch\n", 52 | "from datasets import load_dataset\n", 53 | "from transformers import AutoModelForCausalLM, AutoTokenizer" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "Next we determine the device on which to run the model. Even though LUMI uses AMD MI250x GPUs, PyTorch still uses `cuda` when we mean \"GPU\".\n", 61 | "The following should print: \"Using device: cuda\".\n", 62 | "If this is not the case, then we have made a mistake in allocating resources for the job or loading the proper software environment." 
63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", 72 | "print(f\"Using device {device}\")\n", 73 | "if device.type == 'cuda':\n", 74 | " print(f\"Device name is {torch.cuda.get_device_name(device)}\")" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "## Meet the Pre-Trained Base Model\n", 82 | "\n", 83 | "Now we can load the actual model. We use the 1.3 billion parameter variant of the GPT-neo model, which takes about 5.4 GiB of VRAM in its native 32-bit float form. A single Graphics Compute Die (GCD (i.e., a GPU)) on LUMI has 64 GiB of VRAM, so we do not need to worry about our memory footprint at this point. We also set up the corresponding tokenizer the model was trained with." 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "pretrained_model = \"EleutherAI/gpt-neo-1.3B\"\n", 93 | "tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True)\n", 94 | "tokenizer.pad_token = tokenizer.eos_token\n", 95 | "model = AutoModelForCausalLM.from_pretrained(pretrained_model)\n", 96 | "model.to(device)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "With the tokenizer and model set up and loaded to the GPU, we can now use the model to generate some text. Since we ultimately want to generate movie reviews (after finetuning in the upcoming exercises), let's see how well the GPT-neo model does in generating reviews prior to finetuning." 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "with torch.no_grad():\n", 113 | " prompt = \"The movie 'How to run ML on LUMI - A documentation' was great because\"\n", 114 | " inputs = tokenizer(prompt, return_tensors='pt').to(device)\n", 115 | " outputs = model.generate(**inputs, do_sample=True, max_length=80, num_return_sequences=4)\n", 116 | " decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", 117 | "\n", 118 | " print('Sample generated reviews:')\n", 119 | " for i, txt in enumerate(decoded_outputs):\n", 120 | " print(\"#######################\")\n", 121 | " print(f\"{i+1}: {txt}\")" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "These do probably not all look like reviews for movies (although some probably start of somewhat promising, then deviate into something that looks more like a blog post or similar). \n", 129 | "In the next exercises we will train the model on the IMDb data set to make it generate better movie reviews.\n", 130 | "\n", 131 | "At this point, you can experiment with the text generation if you wish. Text generation strategies are discussed here: https://huggingface.co/docs/transformers/generation_strategies . You can also change the input prompt. 
Alternatively, skip ahead to the next part.\n", 132 | "\n", 133 | "In particular, these parameters fo `model.generate` might be interesting:\n", 134 | "\n", 135 | " - `max_new_tokens`: the maximum number of tokens to generate,\n", 136 | " - `num_beams`: activate Beam search by setting this > 1,\n", 137 | " - `do_sample`: activate multinomial sampling if set to `True`,\n", 138 | " - `num_return_sequences`: the number of candidate sentences to return (available only for beam search and sampling).\n", 139 | "\n", 140 | "For a more detailed description of how to perform generation with different decoding methods / search strategies with the `transformers` module, you may want to read this blog post: https://huggingface.co/blog/how-to-generate" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "## Meet the Training Data\n", 148 | "\n", 149 | "Finally, let us have a look at the training data. The Standford IMDb movie data set was primarily set up for sentiment analysis tasks and consists of 100'000 movie reviews, 50'000 of which are annotated with a sentiment label while the remainder are unlabelled (\"unsupervised\"). Of the labelled reviews, 25'000 are designated for testing.\n", 150 | "\n", 151 | "The `datasets` module makes it easy to load from huggingface.co . For our purposes we use both the labelled and unlabelled training splits (`train` and `unsupervised`).\n", 152 | "\n", 153 | "Since the data set is relatively small (only a couple hundred MB), we can keep it entirely in memory and not have to worry about filesystem IO." 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "train_dataset = load_dataset(\"imdb\", split=\"train+unsupervised\", trust_remote_code=False, keep_in_memory=True)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "Let's have a look at an example from the training data:" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "train_dataset[200]" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": { 184 | "editable": true, 185 | "slideshow": { 186 | "slide_type": "" 187 | }, 188 | "tags": [] 189 | }, 190 | "source": [ 191 | "We can see that each element has the review text as well as a sentiment label. We will ignore the label in the following exercises as we are only interested in fine-tuning the model to generate texts that look like IMDB movie reviews." 192 | ] 193 | } 194 | ], 195 | "metadata": { 196 | "kernelspec": { 197 | "display_name": "Python 3 (ipykernel)", 198 | "language": "python", 199 | "name": "python3" 200 | }, 201 | "language_info": { 202 | "codemirror_mode": { 203 | "name": "ipython", 204 | "version": 3 205 | }, 206 | "file_extension": ".py", 207 | "mimetype": "text/x-python", 208 | "name": "python", 209 | "nbconvert_exporter": "python", 210 | "pygments_lexer": "ipython3", 211 | "version": "3.10.12" 212 | } 213 | }, 214 | "nbformat": 4, 215 | "nbformat_minor": 4 216 | } 217 | -------------------------------------------------------------------------------- /02_Using_the_LUMI_web_interface/README.md: -------------------------------------------------------------------------------- 1 | # 02 The LUMI web-interface 2 | 3 | ## Hands-on exercises 4 | 5 | 1. Get started with the LUMI web interface and set up your own copy of the exercises. 
6 | 7 | In this exercise you will gain first experience with using the LUMI web interface to navigate files and directories on the LUMI supercomputer. You will also set up your own copy of the exercise repository on the system, so that you can work on them without interfering with the other course participants. 8 | 9 | 1. Log in to the LUMI web interface: https://www.lumi.csc.fi 10 | 2. Create your own subdirectory in `/project/project_465001958/` and `/scratch/project_465001958/`. Use your username for the directory name. You can either 11 | - Use the built-in file explorer ("Home Directory"), or 12 | - Use the login node shell app in the webinterface 13 | 3. Clone the [exercise repository](https://github.com/Lumi-supercomputer/Getting_Started_with_AI_workshop) to your folder in `/project/project_465001958/`. You can either 14 | - use the login node shell app in the webinterface, or 15 | - start a Jupyter lab job and use the Jupyter lab UI for cloning Git repositories, see [Clone_with_JupyterLab.md](./Clone_with_JupyterLab.md) for an illustrated step-by-step guide for this. 16 | 4. Get familiar with the exercise repository layout. 17 | 18 | 2. Start an interactive Jupyter lab job and run inference with GPT-neo. 19 | 20 | In this exercise you will learn how to reserve resources for and start an interactive job to run a Jupyter notebook via the LUMI web interface. The notebook itself introduces you to our running example of finetuning a language model using PyTorch and the training libraries provided by Huggingface. In this exercise you will not do any training, but familiarise yourself a bit with the software and the base model. 21 | 22 | 1. Start an interactive Jupyter session: Open the Jupyter app (! not "Jupyter for Courses" !) in the LUMI webinterface and set the following settings before pressing `Launch` 23 | - Project: `project_465001958` 24 | - Reservation: Use the course reservation `AI_workshop_1` (there should only be one available option) 25 | - Partition: `small-g` 26 | - Number of CPU cores: `7` 27 | - Memory (GB): `16` 28 | - Time: `0:30:00` 29 | - Working directory: `/project/project_465001958/` 30 | - Python: `pytorch (Via CSC stack, limited support available)` 31 | - Virtual environment path: leave empty 32 | 2. Wait for the session to start, then press `Connect to Jupyter` 33 | 34 | > **Note** 35 | > 36 | > Jupyter will open in a new tab. Note that your interactive job will not stop if you close the tab, so you can always reconnect to Jupyter via the `My Interactive Session` page of the LUMI web interface. You also have to explicitly `Cancel` the running job from there in order to stop Jupyter when you are done - otherwise your job will continue to consume the allocated resources until the time limit you gave is reached. 37 | 38 | 3. Open and run the `/Getting_Started_with_AI_workshop/02_Using_the_LUMI_web_interface/GPT-neo-IMDB-introduction.ipynb` notebook, which introduces our ongoing example for the remaining exercises. Familiarise yourself with the code. You can try to 39 | - explore the effect of adjusting the parameters for the `tokenizer` and `model.generate` calls 40 | - try different input prompts 41 | - explore the contents of the training data set 42 | 43 | 3. (OPTIONAL) Explore other apps available in the LUMI webinterface 44 | 45 | If you want, you can explore the remaining apps in the LUMI webinterface a bit. Interesting could be 46 | 47 | - Disk quotas: Shows how much storage is available to your projects and how much is currently in use. 
48 | - Project view: Shows how much compute and storage billing units are left for each of your projects. 49 | - Active jobs: Shows a list of all your current compute jobs. 50 | - Desktop: Starts an interactive session with a Desktop-like user interface / window manager. 51 | - Cloud storage configuration: Allows you to configure access tokens to transfer files between LUMI object storage and the LUMI compute cluster. 52 | -------------------------------------------------------------------------------- /02_Using_the_LUMI_web_interface/images/step0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lumi-supercomputer/Getting_Started_with_AI_workshop/ce4d7cc37162350ff0d620d0301b8f9e7f0b5c98/02_Using_the_LUMI_web_interface/images/step0.png -------------------------------------------------------------------------------- /02_Using_the_LUMI_web_interface/images/step1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lumi-supercomputer/Getting_Started_with_AI_workshop/ce4d7cc37162350ff0d620d0301b8f9e7f0b5c98/02_Using_the_LUMI_web_interface/images/step1.png -------------------------------------------------------------------------------- /02_Using_the_LUMI_web_interface/images/step2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lumi-supercomputer/Getting_Started_with_AI_workshop/ce4d7cc37162350ff0d620d0301b8f9e7f0b5c98/02_Using_the_LUMI_web_interface/images/step2.png -------------------------------------------------------------------------------- /03_Your_first_AI_training_job_on_LUMI/GPT-neo-IMDB-finetuning.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # IMDB movie review text generation 5 | # 6 | # In this script, we'll fine-tune a GPT2-like model to generate more 7 | # movie reviews based on a prompt. 8 | # 9 | # Partly based on this tutorial: 10 | # https://github.com/omidiu/GPT-2-Fine-Tuning/ 11 | 12 | 13 | import argparse 14 | import math 15 | import time 16 | from pprint import pprint 17 | 18 | import torch 19 | from datasets import load_dataset 20 | from util import preprocess_data, get_output_paths 21 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 22 | DataCollatorForLanguageModeling, Trainer, 23 | TrainingArguments) 24 | 25 | if __name__ == "__main__": 26 | 27 | # First we set up some command line arguments to allow us to specify data/output paths 28 | # and the number of worker processes without changing the code. 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument( 31 | "--model-name", 32 | type=str, 33 | default="gpt-imdb-model", 34 | help="A name for the trained model under. A subdirectory with the given name will be created under the `output-path`.", 35 | ) 36 | parser.add_argument( 37 | "--output-path", 38 | type=str, 39 | help="The root directory under which model checkpoints are stored.", 40 | ) 41 | parser.add_argument( 42 | "--logging-path", 43 | type=str, 44 | help="The root directory under which logging data (for tensorboard) are stored.", 45 | ) 46 | parser.add_argument( 47 | "--num-workers", 48 | type=int, 49 | default=1, 50 | help="The number of CPU worker processes to use.", 51 | ) 52 | args, _ = parser.parse_known_args() 53 | 54 | # Then we determine the device on which to train the model. 
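    # Note: LUMI's GPUs are AMD MI250x GCDs, but the ROCm build of PyTorch exposes them
    # through the usual torch.cuda API, so "cuda" below simply means "use the GPU".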
55 | print("Using PyTorch version:", torch.__version__) 56 | if torch.cuda.is_available(): 57 | device = torch.device("cuda") 58 | print("Using GPU, device name:", torch.cuda.get_device_name(device)) 59 | else: 60 | print("No GPU found, using CPU instead.") 61 | device = torch.device("cpu") 62 | 63 | # We also ensure that output paths exist 64 | output_dir, logging_dir = get_output_paths(args) 65 | 66 | # #### Loading the GPT-neo model 67 | # 68 | # We'll use the gpt-neo-1.3B model from the Hugging Face library: 69 | # https://huggingface.co/EleutherAI/gpt-neo-1.3B 70 | # Let's start with getting the appropriate tokenizer. 71 | pretrained_model = "EleutherAI/gpt-neo-1.3B" 72 | 73 | print("Loading model and tokenizer") 74 | start = time.time() 75 | tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True) 76 | tokenizer.pad_token = tokenizer.eos_token 77 | 78 | # Load the actual base model from Hugging Face 79 | model = AutoModelForCausalLM.from_pretrained(pretrained_model) 80 | model.to(device) 81 | stop = time.time() 82 | print(f"Loading model and tokenizer took: {stop-start:.2f} seconds") 83 | 84 | # #### Loading the IMDb data set 85 | # 86 | # Next we'll load the IMDb data set: https://huggingface.co/docs/datasets/index. 87 | # 88 | # The data set contains 100,000 movies reviews from the Internet Movie 89 | # Database, split into 25,000 reviews for training and 25,000 reviews 90 | # for testing and 50,000 without labels (unsupervised) that we also use for training. 91 | 92 | train_dataset = load_dataset( 93 | "imdb", split="train+unsupervised", trust_remote_code=False, keep_in_memory=True 94 | ) 95 | eval_dataset = load_dataset( 96 | "imdb", split="test", trust_remote_code=False, keep_in_memory=True 97 | ) 98 | 99 | # Let's print one sample from the dataset. 100 | print("Sample from dataset") 101 | pprint(train_dataset[200]) 102 | 103 | # #### Setting up the training configuration 104 | train_batch_size = 32 # This just about fits into the VRAM of a single MI250x GCD with 16-bit floats 105 | eval_batch_size = 128 # No optimizer state during evaluation, so can use bigger batches for increased throughput 106 | 107 | training_args = TrainingArguments( 108 | output_dir=output_dir, 109 | save_strategy="steps", 110 | save_steps=100, 111 | save_total_limit=4, 112 | logging_dir=logging_dir, 113 | eval_strategy="steps", 114 | eval_steps=200, # compute validation loss every 200 steps 115 | learning_rate=2e-5, 116 | weight_decay=0.01, 117 | bf16=True, # use 16-bit floating point precision 118 | per_device_train_batch_size=train_batch_size, 119 | per_device_eval_batch_size=eval_batch_size, 120 | max_steps=1000, 121 | dataloader_num_workers=args.num_workers, 122 | dataloader_pin_memory=True, 123 | report_to=["tensorboard"], # log statistics for tensorboard 124 | ) 125 | 126 | # #### Preprocessing of training data 127 | # We tokenize the data into torch tensors, split training into training and validation and set up a collator that 128 | # is able to arrange single data samples into batches. 129 | 130 | train_dataset_tokenized, validate_dataset_tokenized, eval_dataset_tokenized = preprocess_data(train_dataset, eval_dataset, tokenizer, training_args) 131 | 132 | collator = DataCollatorForLanguageModeling( 133 | tokenizer, mlm=False, return_tensors="pt" 134 | ) 135 | 136 | # Sanity check: How does the training data look like after preprocessing? 
137 | print("Sample of tokenized data") 138 | for b in train_dataset_tokenized: 139 | pprint(b, compact=True) 140 | print("Length of input_ids:", len(b["input_ids"])) 141 | break 142 | print("Length of dataset (tokenized)", len(train_dataset_tokenized)) 143 | 144 | # #### Training 145 | # We use the Hugging Face trainer instead of a manual training loop. 146 | # 147 | # You can read about the many, many different parameters to the 148 | # Hugging Face trainer here: 149 | # https://huggingface.co/docs/transformers/v4.37.0/en/main_classes/trainer#transformers.TrainingArguments 150 | # 151 | 152 | collator = DataCollatorForLanguageModeling( 153 | tokenizer, mlm=False, return_tensors="pt" 154 | ) 155 | 156 | trainer = Trainer( 157 | model=model, 158 | args=training_args, 159 | tokenizer=tokenizer, 160 | data_collator=collator, 161 | train_dataset=train_dataset_tokenized, 162 | eval_dataset=validate_dataset_tokenized, 163 | ) 164 | 165 | # With 1000 steps, batch size 32 and a single GCD, this should take just under 30 minutes. 166 | trainer.train() 167 | 168 | print() 169 | print("Training done, you can find all the model checkpoints in", output_dir) 170 | 171 | # #### Evaluating the finetuned model 172 | with torch.no_grad(): 173 | model.eval() 174 | # Calculate perplexity 175 | eval_results = trainer.evaluate() 176 | test_results = trainer.evaluate(eval_dataset_tokenized) 177 | 178 | print(f'Perplexity on validation: {math.exp(eval_results["eval_loss"]):.2f}') 179 | print(f'Perplexity on test: {math.exp(test_results["eval_loss"]):.2f}') 180 | 181 | # Let's print a few sample generated reviews; this is the same as in the previous exercise 182 | # but now we use the finetuned model 183 | prompt = "The movie 'How to run ML on LUMI - A documentation' was great because" 184 | inputs = tokenizer(prompt, return_tensors="pt").to(device) 185 | outputs = model.generate( 186 | **inputs, do_sample=True, max_length=80, num_return_sequences=4 187 | ) 188 | decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) 189 | 190 | print("Sample generated review:") 191 | for txt in decoded_outputs: 192 | print("-", txt) 193 | -------------------------------------------------------------------------------- /03_Your_first_AI_training_job_on_LUMI/README.md: -------------------------------------------------------------------------------- 1 | # 03 Your first AI training job on LUMI 2 | 3 | ## Hands-on exercises 4 | 5 | 1. Familiarise yourself with the training script. 6 | 7 | As a first exercise, familiarise yourself with the training script and understand how it is working. You do not need to make any changes to the file at this point. 8 | 9 | 1. Check the training script [03_Your_first_AI_training_job_on_LUMI/GPT-neo-IMDB-finetuning.py](GPT-neo-IMDB-finetuning.py). 10 | 11 | You can open it via 12 | - the built-in editor of the [LUMI web interface](https://lumi.csc.fi) file explorer: ![Open the LUMI web interface file editor by navigating to a file, clicking the "three dots" menu button and then selecting "Edit"](images/lumi_web_interface_edit_file.png) 13 | - any command line editor from a login node shell, either via the [LUMI web interface](https://lumi.csc.fi) or an [SSH connection](https://docs.lumi-supercomputer.eu/firststeps/loggingin/). 14 | - the Visual Studio Code app in the LUMI web interface (use the `interactive` partition) 15 | 16 | 2. Create a slurm batch file and start a training run. 
17 | 18 | Next you will prepare the slurm batch file that specifies the resources required for the training, sets up the software environment and, finally, executes the training script on a compute node in a singularity container environment. 19 | 20 | In the same directory as the script you can find [03_Your_first_AI_training_job_on_LUMI/run.sh](run.sh), an incomplete slurm batch job file. 21 | 22 | 1. Fill in the missing pieces (marked with ``). 23 | 24 | You should specify at least the following: 25 | - the correct slurm partition 26 | - number of CPUs requested 27 | - number of GPUs requested (1) 28 | - RAM requested 29 | - requested runtime (recommended: 15 minutes, so you can continue with sub-exercise 4 below) 30 | 31 | It can also be helpful to specify a name for the slurm logfile that contains the command line outputs of the script. 32 | 33 | The Python command needs to be run in a singularity container with the required software packages installed. The slurm batch file sets up a variable `CONTAINER` 34 | with the path to the container you should use. 35 | 36 | For the Python script itself you will need to provide the following command line arguments: 37 | - `--output-path` (for the trained model and checkpoints) 38 | - `--logging-path` (for tensorboard logging data) 39 | - `--model-name` (a name under which the model produced by the run will be stored; optional) 40 | - `--num-workers` (optional, is used to set the number of PyTorch dataloader processes) 41 | 42 | Please set the paths to some destination of your choice within your `/scratch/project_465001958/` directory. 43 | 44 | > **Tip** 45 | > 46 | > The script sets environment variables `OUTPUT_DIR` and `LOGGING_DIR` with suggested paths you can use 47 | > for storing the trained model as well as tensorboard logging data. 48 | 49 | > **Tip** 50 | > 51 | > Slurm sets the environment variable `SLURM_CPUS_PER_TASK` to the value of allocated CPU cores per task, which 52 | > can be useful for setting `--num-workers`. 53 | 54 | 2. Start the training using `sbatch run.sh` from a login node shell (LUMI web interface or SSH). 55 | 56 | You can continue with the remaining exercises below while the training proceeds. 57 | 58 | 3. Check your job. 59 | 60 | 1. From a login node shell, use the slurm command `squeue --me` to check that your job is running. You can use the `tail -f` command to check the outputs of the job from the slurm log file `/03_Your_first_AI_training_job_on_LUMI/slurm-.out` that is created once the job is running. 61 | 62 | You can also check your active jobs from the LUMI web interface: Navigate to Jobs > Active Jobs. 63 | 64 | > **Note** 65 | > 66 | > We will cover more details about checking the status and progress of your job in a later exercise. 67 | 68 | 4. Modify the script to enable it to continue from a checkpoint. 69 | 70 | Using checkpoints can be helpful to recover from failing runs (either due to errors encountered in your code or due to hardware or scheduler issues) without losing all your progress. It also allows to split your total training into smaller pieces that are more likely to pass through the scheduler queue quickly. 71 | 72 | The script currently already writes checkpoints but always starts training with from the basic GPT-neo model (i.e., it ignores previously written checkpoints). 73 | 74 | 1. Change the training script so it can load a checkpoint from a previously interrupted training run and resume training. 
Check the [documentation about HuggingFace Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) for guidance on this. 75 | 76 | You may want to add additional command line arguments to control this behaviour from the slurm batch script. 77 | 78 | 2. Your earlier training job should by now have timed out without completing the training. Use your modified script to resume training from the last checkpoint. 79 | 80 | > **Important** 81 | > 82 | > Make sure to use the same `MODEL_NAME` (and paths) when resuming training from the checkpoint. 83 | 84 | > **Note** 85 | > 86 | > If your earlier training job is still running, you can stop it using the `scancel` command. 87 | 88 | ## Solutions 89 | 90 | The folder `reference_solution/` contains an example solution for this exercise parts 1 and 2. `reference_solution/resume_from_checkpoint/` additionally contains the changes for part 4 91 | -------------------------------------------------------------------------------- /03_Your_first_AI_training_job_on_LUMI/images/lumi_web_interface_edit_file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lumi-supercomputer/Getting_Started_with_AI_workshop/ce4d7cc37162350ff0d620d0301b8f9e7f0b5c98/03_Your_first_AI_training_job_on_LUMI/images/lumi_web_interface_edit_file.png -------------------------------------------------------------------------------- /03_Your_first_AI_training_job_on_LUMI/reference_solution/GPT-neo-IMDB-finetuning.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # IMDB movie review text generation 5 | # 6 | # In this script, we'll fine-tune a GPT2-like model to generate more 7 | # movie reviews based on a prompt. 8 | # 9 | # Partly based on this tutorial: 10 | # https://github.com/omidiu/GPT-2-Fine-Tuning/ 11 | 12 | 13 | import argparse 14 | import math 15 | import time 16 | from pprint import pprint 17 | 18 | import torch 19 | from datasets import load_dataset 20 | from util import preprocess_data, get_output_paths 21 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 22 | DataCollatorForLanguageModeling, Trainer, 23 | TrainingArguments) 24 | 25 | if __name__ == "__main__": 26 | 27 | # First we set up some command line arguments to allow us to specify data/output paths 28 | # and the number of worker processes without changing the code. 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument( 31 | "--model-name", 32 | type=str, 33 | default="gpt-imdb-model", 34 | help="A name for the trained model under. A subdirectory with the given name will be created under the `output-path`.", 35 | ) 36 | parser.add_argument( 37 | "--output-path", 38 | type=str, 39 | help="The root directory under which model checkpoints are stored.", 40 | ) 41 | parser.add_argument( 42 | "--logging-path", 43 | type=str, 44 | help="The root directory under which logging data (for tensorboard) are stored.", 45 | ) 46 | parser.add_argument( 47 | "--num-workers", 48 | type=int, 49 | default=1, 50 | help="The number of CPU worker processes to use.", 51 | ) 52 | args, _ = parser.parse_known_args() 53 | 54 | # Then we determine the device on which to train the model. 
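    # Note: LUMI's GPUs are AMD MI250x GCDs, but the ROCm build of PyTorch exposes them
    # through the usual torch.cuda API, so "cuda" below simply means "use the GPU".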
55 | print("Using PyTorch version:", torch.__version__) 56 | if torch.cuda.is_available(): 57 | device = torch.device("cuda") 58 | print("Using GPU, device name:", torch.cuda.get_device_name(device)) 59 | else: 60 | print("No GPU found, using CPU instead.") 61 | device = torch.device("cpu") 62 | 63 | # We also ensure that output paths exist 64 | output_dir, logging_dir = get_output_paths(args) 65 | 66 | # #### Loading the GPT-neo model 67 | # 68 | # We'll use the gpt-neo-1.3B model from the Hugging Face library: 69 | # https://huggingface.co/EleutherAI/gpt-neo-1.3B 70 | # Let's start with getting the appropriate tokenizer. 71 | pretrained_model = "EleutherAI/gpt-neo-1.3B" 72 | 73 | print("Loading model and tokenizer") 74 | start = time.time() 75 | tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True) 76 | tokenizer.pad_token = tokenizer.eos_token 77 | 78 | # Load the actual base model from Hugging Face 79 | model = AutoModelForCausalLM.from_pretrained(pretrained_model) 80 | model.to(device) 81 | stop = time.time() 82 | print(f"Loading model and tokenizer took: {stop-start:.2f} seconds") 83 | 84 | # #### Loading the IMDb data set 85 | # 86 | # Next we'll load the IMDb data set: https://huggingface.co/docs/datasets/index. 87 | # 88 | # The data set contains 100,000 movies reviews from the Internet Movie 89 | # Database, split into 25,000 reviews for training and 25,000 reviews 90 | # for testing and 50,000 without labels (unsupervised) that we also use for training. 91 | 92 | train_dataset = load_dataset( 93 | "imdb", split="train+unsupervised", trust_remote_code=False, keep_in_memory=True 94 | ) 95 | eval_dataset = load_dataset( 96 | "imdb", split="test", trust_remote_code=False, keep_in_memory=True 97 | ) 98 | 99 | # Let's print one sample from the dataset. 100 | print("Sample from dataset") 101 | pprint(train_dataset[200]) 102 | 103 | # #### Setting up the training configuration 104 | train_batch_size = 32 # This just about fits into the VRAM of a single MI250x GCD with 16-bit floats 105 | eval_batch_size = 128 # No optimizer state during evaluation, so can use bigger batches for increased throughput 106 | 107 | training_args = TrainingArguments( 108 | output_dir=output_dir, 109 | save_strategy="steps", 110 | save_steps=100, 111 | save_total_limit=4, 112 | logging_dir=logging_dir, 113 | eval_strategy="steps", 114 | eval_steps=200, # compute validation loss every 200 steps 115 | learning_rate=2e-5, 116 | weight_decay=0.01, 117 | bf16=True, # use 16-bit floating point precision 118 | per_device_train_batch_size=train_batch_size, 119 | per_device_eval_batch_size=eval_batch_size, 120 | max_steps=1000, 121 | dataloader_num_workers=args.num_workers, 122 | dataloader_pin_memory=True, 123 | report_to=["tensorboard"], # log statistics for tensorboard 124 | ) 125 | 126 | # #### Preprocessing of training data 127 | # We tokenize the data into torch tensors, split training into training and validation and set up a collator that 128 | # is able to arrange single data samples into batches. 129 | 130 | train_dataset_tokenized, validate_dataset_tokenized, eval_dataset_tokenized = preprocess_data(train_dataset, eval_dataset, tokenizer, training_args) 131 | 132 | collator = DataCollatorForLanguageModeling( 133 | tokenizer, mlm=False, return_tensors="pt" 134 | ) 135 | 136 | # Sanity check: How does the training data look like after preprocessing? 
137 | print("Sample of tokenized data") 138 | for b in train_dataset_tokenized: 139 | pprint(b, compact=True) 140 | print("Length of input_ids:", len(b["input_ids"])) 141 | break 142 | print("Length of dataset (tokenized)", len(train_dataset_tokenized)) 143 | 144 | # #### Training 145 | # We use the Hugging Face trainer instead of a manual training loop. 146 | # 147 | # You can read about the many, many different parameters to the 148 | # Hugging Face trainer here: 149 | # https://huggingface.co/docs/transformers/v4.37.0/en/main_classes/trainer#transformers.TrainingArguments 150 | # 151 | 152 | collator = DataCollatorForLanguageModeling( 153 | tokenizer, mlm=False, return_tensors="pt" 154 | ) 155 | 156 | trainer = Trainer( 157 | model=model, 158 | args=training_args, 159 | tokenizer=tokenizer, 160 | data_collator=collator, 161 | train_dataset=train_dataset_tokenized, 162 | eval_dataset=validate_dataset_tokenized, 163 | ) 164 | 165 | # With 1000 steps, batch size 32 and a single GCD, this should take just under 30 minutes. 166 | trainer.train() 167 | 168 | print() 169 | print("Training done, you can find all the model checkpoints in", output_dir) 170 | 171 | # #### Evaluating the finetuned model 172 | with torch.no_grad(): 173 | model.eval() 174 | # Calculate perplexity 175 | eval_results = trainer.evaluate() 176 | test_results = trainer.evaluate(eval_dataset_tokenized) 177 | 178 | print(f'Perplexity on validation: {math.exp(eval_results["eval_loss"]):.2f}') 179 | print(f'Perplexity on test: {math.exp(test_results["eval_loss"]):.2f}') 180 | 181 | # Let's print a few sample generated reviews; this is the same as in the previous exercise 182 | # but now we use the finetuned model 183 | prompt = "The movie 'How to run ML on LUMI - A documentation' was great because" 184 | inputs = tokenizer(prompt, return_tensors="pt").to(device) 185 | outputs = model.generate( 186 | **inputs, do_sample=True, max_length=80, num_return_sequences=4 187 | ) 188 | decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) 189 | 190 | print("Sample generated review:") 191 | for txt in decoded_outputs: 192 | print("-", txt) 193 | -------------------------------------------------------------------------------- /03_Your_first_AI_training_job_on_LUMI/reference_solution/resume_from_checkpoint/GPT-neo-IMDB-finetuning.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # IMDB movie review text generation 5 | # 6 | # In this script, we'll fine-tune a GPT2-like model to generate more 7 | # movie reviews based on a prompt. 8 | # 9 | # Partly based on this tutorial: 10 | # https://github.com/omidiu/GPT-2-Fine-Tuning/ 11 | 12 | 13 | import argparse 14 | import math 15 | import time 16 | from pprint import pprint 17 | 18 | import torch 19 | from datasets import load_dataset 20 | from util import preprocess_data, get_output_paths 21 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 22 | DataCollatorForLanguageModeling, Trainer, 23 | TrainingArguments) 24 | 25 | if __name__ == "__main__": 26 | 27 | # First we set up some command line arguments to allow us to specify data/output paths 28 | # and the number of worker processes without changing the code. 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument( 31 | "--model-name", 32 | type=str, 33 | default="gpt-imdb-model", 34 | help="A name for the trained model under. 
A subdirectory with the given name will be created under the `output-path`.", 35 | ) 36 | parser.add_argument( 37 | "--output-path", 38 | type=str, 39 | help="The root directory under which model checkpoints are stored.", 40 | ) 41 | parser.add_argument( 42 | "--logging-path", 43 | type=str, 44 | help="The root directory under which logging data (for tensorboard) are stored.", 45 | ) 46 | parser.add_argument( 47 | "--num-workers", 48 | type=int, 49 | default=1, 50 | help="The number of CPU worker processes to use.", 51 | ) 52 | parser.add_argument( 53 | "--resume", 54 | default=False, 55 | action="store_true", 56 | help="If set, continue from a previously interrupted run. Otherwise, overwrite existing checkpoints.", 57 | ) 58 | args, _ = parser.parse_known_args() 59 | 60 | # Then we determine the device on which to train the model. 61 | print("Using PyTorch version:", torch.__version__) 62 | if torch.cuda.is_available(): 63 | device = torch.device("cuda") 64 | print("Using GPU, device name:", torch.cuda.get_device_name(device)) 65 | else: 66 | print("No GPU found, using CPU instead.") 67 | device = torch.device("cpu") 68 | 69 | # We also ensure that output paths exist 70 | output_dir, logging_dir = get_output_paths(args) 71 | 72 | # #### Loading the GPT-neo model 73 | # 74 | # We'll use the gpt-neo-1.3B model from the Hugging Face library: 75 | # https://huggingface.co/EleutherAI/gpt-neo-1.3B 76 | # Let's start with getting the appropriate tokenizer. 77 | pretrained_model = "EleutherAI/gpt-neo-1.3B" 78 | 79 | print("Loading model and tokenizer") 80 | start = time.time() 81 | tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True) 82 | tokenizer.pad_token = tokenizer.eos_token 83 | 84 | # Load the actual base model from Hugging Face 85 | model = AutoModelForCausalLM.from_pretrained(pretrained_model) 86 | model.to(device) 87 | stop = time.time() 88 | print(f"Loading model and tokenizer took: {stop-start:.2f} seconds") 89 | 90 | # #### Loading the IMDb data set 91 | # 92 | # Next we'll load the IMDb data set: https://huggingface.co/docs/datasets/index. 93 | # 94 | # The data set contains 100,000 movies reviews from the Internet Movie 95 | # Database, split into 25,000 reviews for training and 25,000 reviews 96 | # for testing and 50,000 without labels (unsupervised) that we also use for training. 97 | 98 | train_dataset = load_dataset( 99 | "imdb", split="train+unsupervised", trust_remote_code=False, keep_in_memory=True 100 | ) 101 | eval_dataset = load_dataset( 102 | "imdb", split="test", trust_remote_code=False, keep_in_memory=True 103 | ) 104 | 105 | # Let's print one sample from the dataset. 
106 | print("Sample from dataset") 107 | pprint(train_dataset[200]) 108 | 109 | # #### Setting up the training configuration 110 | train_batch_size = 32 # This just about fits into the VRAM of a single MI250x GCD with 16-bit floats 111 | eval_batch_size = 128 # No optimizer state during evaluation, so can use bigger batches for increased throughput 112 | 113 | training_args = TrainingArguments( 114 | output_dir=output_dir, 115 | overwrite_output_dir=not args.resume, 116 | save_strategy="steps", 117 | save_steps=100, 118 | save_total_limit=4, 119 | logging_dir=logging_dir, 120 | eval_strategy="steps", 121 | eval_steps=200, # compute validation loss every 200 steps 122 | learning_rate=2e-5, 123 | weight_decay=0.01, 124 | bf16=True, # use 16-bit floating point precision 125 | per_device_train_batch_size=train_batch_size, 126 | per_device_eval_batch_size=eval_batch_size, 127 | max_steps=1000, 128 | dataloader_num_workers=args.num_workers, 129 | dataloader_pin_memory=True, 130 | report_to=["tensorboard"], # log statistics for tensorboard 131 | ) 132 | 133 | # #### Preprocessing of training data 134 | # We tokenize the data into torch tensors, split training into training and validation and set up a collator that 135 | # is able to arrange single data samples into batches. 136 | 137 | train_dataset_tokenized, validate_dataset_tokenized, eval_dataset_tokenized = preprocess_data(train_dataset, eval_dataset, tokenizer, training_args) 138 | 139 | collator = DataCollatorForLanguageModeling( 140 | tokenizer, mlm=False, return_tensors="pt" 141 | ) 142 | 143 | # Sanity check: How does the training data look like after preprocessing? 144 | print("Sample of tokenized data") 145 | for b in train_dataset_tokenized: 146 | pprint(b, compact=True) 147 | print("Length of input_ids:", len(b["input_ids"])) 148 | break 149 | print("Length of dataset (tokenized)", len(train_dataset_tokenized)) 150 | 151 | # #### Training 152 | # We use the Hugging Face trainer instead of a manual training loop. 153 | # 154 | # You can read about the many, many different parameters to the 155 | # Hugging Face trainer here: 156 | # https://huggingface.co/docs/transformers/v4.37.0/en/main_classes/trainer#transformers.TrainingArguments 157 | # 158 | 159 | collator = DataCollatorForLanguageModeling( 160 | tokenizer, mlm=False, return_tensors="pt" 161 | ) 162 | 163 | trainer = Trainer( 164 | model=model, 165 | args=training_args, 166 | tokenizer=tokenizer, 167 | data_collator=collator, 168 | train_dataset=train_dataset_tokenized, 169 | eval_dataset=validate_dataset_tokenized, 170 | ) 171 | 172 | # With 1000 steps, batch size 32 and a single GCD, this should take just under 30 minutes. 
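    # Passing resume_from_checkpoint=True makes the Trainer load the most recent checkpoint
    # found under output_dir (including optimizer/scheduler state and the step counter) and
    # continue from there; with False, training starts from the freshly loaded base model.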
173 | trainer.train(resume_from_checkpoint=args.resume) 174 | 175 | print() 176 | print("Training done, you can find all the model checkpoints in", output_dir) 177 | 178 | # #### Evaluating the finetuned model 179 | with torch.no_grad(): 180 | model.eval() 181 | # Calculate perplexity 182 | eval_results = trainer.evaluate() 183 | test_results = trainer.evaluate(eval_dataset_tokenized) 184 | 185 | print(f'Perplexity on validation: {math.exp(eval_results["eval_loss"]):.2f}') 186 | print(f'Perplexity on test: {math.exp(test_results["eval_loss"]):.2f}') 187 | 188 | # Let's print a few sample generated reviews; this is the same as in the previous exercise 189 | # but now we use the finetuned model 190 | prompt = "The movie 'How to run ML on LUMI - A documentation' was great because" 191 | inputs = tokenizer(prompt, return_tensors="pt").to(device) 192 | outputs = model.generate( 193 | **inputs, do_sample=True, max_length=80, num_return_sequences=4 194 | ) 195 | decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) 196 | 197 | print("Sample generated review:") 198 | for txt in decoded_outputs: 199 | print("-", txt) 200 | -------------------------------------------------------------------------------- /03_Your_first_AI_training_job_on_LUMI/reference_solution/resume_from_checkpoint/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_465001958 3 | #SBATCH --reservation=AI_workshop_1 # comment this out if the reservation is no longer available 4 | #SBATCH --partition=small-g 5 | #SBATCH --gpus-per-node=1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=7 8 | #SBATCH --mem-per-gpu=60G 9 | #SBATCH --time=0:15:00 10 | 11 | # Set up the software environment 12 | # NOTE: the loaded module makes relevant filesystem locations available inside the singularity container 13 | # (/scratch, /project, etc) as well as mounts some important system libraries that are optimized for LUMI 14 | # If you are interested, you can check the exact paths being mounted from 15 | # /appl/local/containers/ai-modules/singularity-AI-bindings/24.03.lua 16 | module purge 17 | module use /appl/local/containers/ai-modules 18 | module load singularity-AI-bindings 19 | 20 | CONTAINER=/project/project_465001958/containers/pytorch_transformers.sif 21 | 22 | # Some environment variables to set up cache directories 23 | SCRATCH="/scratch/${SLURM_JOB_ACCOUNT}" 24 | FLASH="/flash/${SLURM_JOB_ACCOUNT}" 25 | export TORCH_HOME=$SCRATCH/torch-cache 26 | export HF_HOME=$FLASH/hf-cache 27 | mkdir -p $TORCH_HOME $HF_HOME 28 | 29 | # Disable internal parallelism of huggingface's tokenizer since we 30 | # want to retain direct control of parallelism options. 
31 | export TOKENIZERS_PARALLELISM=false 32 | 33 | # Path to where the trained model and logging data will go 34 | export OUTPUT_DIR=$SCRATCH/$USER/data/ 35 | export LOGGING_DIR=$SCRATCH/$USER/runs/ 36 | 37 | set -xv # print the command so that we can verify setting arguments correctly from the logs 38 | srun singularity exec $CONTAINER \ 39 | python GPT-neo-IMDB-finetuning.py \ 40 | --model-name gpt-imdb-model \ 41 | --output-path $OUTPUT_DIR \ 42 | --logging-path $LOGGING_DIR \ 43 | --num-workers ${SLURM_CPUS_PER_TASK} \ 44 | --resume # Comment this for the first run, uncomment to resume in later runs 45 | -------------------------------------------------------------------------------- /03_Your_first_AI_training_job_on_LUMI/reference_solution/resume_from_checkpoint/util.py: -------------------------------------------------------------------------------- 1 | ../util.py -------------------------------------------------------------------------------- /03_Your_first_AI_training_job_on_LUMI/reference_solution/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_465001958 3 | #SBATCH --reservation=AI_workshop_1 # comment this out if the reservation is no longer available 4 | #SBATCH --partition=small-g 5 | #SBATCH --gpus-per-node=1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --cpus-per-task=7 8 | #SBATCH --mem-per-gpu=60G 9 | #SBATCH --time=0:15:00 10 | 11 | # Set up the software environment 12 | # NOTE: the loaded module makes relevant filesystem locations available inside the singularity container 13 | # (/scratch, /project, etc) as well as mounts some important system libraries that are optimized for LUMI 14 | # If you are interested, you can check the exact paths being mounted from 15 | # /appl/local/containers/ai-modules/singularity-AI-bindings/24.03.lua 16 | module purge 17 | module use /appl/local/containers/ai-modules 18 | module load singularity-AI-bindings 19 | 20 | CONTAINER=/project/project_465001958/containers/pytorch_transformers.sif 21 | 22 | # Some environment variables to set up cache directories 23 | SCRATCH="/scratch/${SLURM_JOB_ACCOUNT}" 24 | FLASH="/flash/${SLURM_JOB_ACCOUNT}" 25 | export TORCH_HOME=$SCRATCH/torch-cache 26 | export HF_HOME=$FLASH/hf-cache 27 | mkdir -p $TORCH_HOME $HF_HOME 28 | 29 | # Disable internal parallelism of huggingface's tokenizer since we 30 | # want to retain direct control of parallelism options. 
31 | export TOKENIZERS_PARALLELISM=false 32 | 33 | # Path to where the trained model and logging data will go 34 | export OUTPUT_DIR=$SCRATCH/$USER/data/ 35 | export LOGGING_DIR=$SCRATCH/$USER/runs/ 36 | 37 | set -xv # print the command so that we can verify setting arguments correctly from the logs 38 | srun singularity exec $CONTAINER \ 39 | python GPT-neo-IMDB-finetuning.py \ 40 | --model-name gpt-imdb-model \ 41 | --output-path $OUTPUT_DIR \ 42 | --logging-path $LOGGING_DIR \ 43 | --num-workers ${SLURM_CPUS_PER_TASK} 44 | -------------------------------------------------------------------------------- /03_Your_first_AI_training_job_on_LUMI/reference_solution/util.py: -------------------------------------------------------------------------------- 1 | ../util.py -------------------------------------------------------------------------------- /03_Your_first_AI_training_job_on_LUMI/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_465001958 3 | #SBATCH --reservation=AI_workshop_1 # comment this out if the reservation is no longer available 4 | #SBATCH --partition=... 5 | ## 6 | 7 | # Set up the software environment 8 | # NOTE: the loaded module makes relevant filesystem locations available inside the singularity container 9 | # (/scratch, /project, etc) as well as mounts some important system libraries that are optimized for LUMI 10 | # If you are interested, you can check the exact paths being mounted from 11 | # /appl/local/containers/ai-modules/singularity-AI-bindings/24.03.lua 12 | module purge 13 | module use /appl/local/containers/ai-modules 14 | module load singularity-AI-bindings 15 | 16 | CONTAINER=/project/project_465001958/containers/pytorch_transformers.sif 17 | 18 | # Some environment variables to set up cache directories 19 | SCRATCH="/scratch/${SLURM_JOB_ACCOUNT}" 20 | FLASH="/flash/${SLURM_JOB_ACCOUNT}" 21 | export TORCH_HOME=$SCRATCH/torch-cache 22 | export HF_HOME=$FLASH/hf-cache 23 | mkdir -p $TORCH_HOME $HF_HOME 24 | 25 | # Disable internal parallelism of huggingface's tokenizer since we 26 | # want to retain direct control of parallelism options. 27 | export TOKENIZERS_PARALLELISM=false 28 | 29 | # Path to where the trained model and logging data will go 30 | export OUTPUT_DIR=$SCRATCH/$USER/data/ 31 | export LOGGING_DIR=$SCRATCH/$USER/runs/ 32 | 33 | ## 34 | -------------------------------------------------------------------------------- /03_Your_first_AI_training_job_on_LUMI/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers import PreTrainedTokenizerFast, TrainingArguments 3 | import argparse 4 | 5 | 6 | def preprocess_data(train_dataset, eval_dataset, tokenizer: PreTrainedTokenizerFast, training_args: TrainingArguments): 7 | """ Transforms the labelled IMDb data into tokenized tensors for LLM training; further splits train_dataset into training and validation sets. 8 | 9 | Arguments: 10 | - train_dataset: IMDb training data set split, as loaded by load_dataset. 11 | - eval_dataset: IMDb testing data set split, as loaded by load_dataset. 12 | - tokenizer: The tokenizer used with the model to be trained. 13 | - training_args: The TrainingArguments used for training, to get batch_size and number of workers. 
14 | Returns: 15 | tuple (train_dataset_tokenized, validate_dataset_tokenized, eval_dataset_tokenized) where 16 | - train_dataset_tokenized and validate_dataset_tokenized are the tokenized version of train_dataset with an additional subdivision, 17 | - eval_dataset_tokenized is the tokenized version of eval_dataset. 18 | """ 19 | 20 | # IMDb examples are presented as a dictionary: 21 | # { 22 | # 'text': the review text as a string, 23 | # 'label': a sentiment label as an integer, 24 | # }. 25 | # We tokenize the text and add the special token for indicating the end of the 26 | # text at the end of each review. We also truncate reviews to a maximum 27 | # length to avoid excessively long sequences during training. 28 | # As we have no use for the label, we discard it. 29 | max_length = 256 30 | 31 | def tokenize(x): 32 | texts = [example + tokenizer.eos_token for example in x["text"]] 33 | return tokenizer( 34 | texts, 35 | max_length=max_length, 36 | truncation=True, 37 | add_special_tokens=True, 38 | return_overflowing_tokens=True, 39 | return_length=False, 40 | ) 41 | 42 | train_dataset_tokenized = train_dataset.map( 43 | tokenize, 44 | remove_columns=["text", "label"], 45 | batched=True, 46 | batch_size=training_args.train_batch_size, 47 | num_proc=training_args.dataloader_num_workers, 48 | ) 49 | 50 | eval_dataset_tokenized = eval_dataset.map( 51 | tokenize, 52 | remove_columns=["text", "label"], 53 | batched=True, 54 | num_proc=training_args.dataloader_num_workers, 55 | ) 56 | 57 | # We split a small amount of training data as "validation" test set to keep track of evaluation 58 | # of the loss on non-training data during training. 59 | # This is purely because computing the loss on the full evaluation dataset takes much longer. 60 | train_validate_splits = train_dataset_tokenized.train_test_split( 61 | test_size=1000, seed=42, keep_in_memory=True 62 | ) 63 | train_dataset_tokenized = train_validate_splits["train"] 64 | validate_dataset_tokenized = train_validate_splits["test"] 65 | 66 | return train_dataset_tokenized, validate_dataset_tokenized, eval_dataset_tokenized 67 | 68 | def get_output_paths(args: argparse.Namespace): 69 | """ Creates the final output and logging paths from command line arguments and creates the folders, if needed. 70 | 71 | Arguments: 72 | - args: Namespace object of parsed command line arguments as returned by argparse.ArgumentParser().parse_args() 73 | 74 | Returns: 75 | tuple (output_dir, logging_dir) where 76 | - output_dir: the path of the directory in which model checkpoints are to be stored, 77 | - logging_dir: the path of the directory in which tensorboard logging data are to be stored. 
78 | """ 79 | # this is where the trained model and checkpoints will go 80 | output_dir = os.path.join(args.output_path, args.model_name) 81 | os.makedirs(output_dir, exist_ok=True) 82 | 83 | # this is where the tensorboard logging outputs will go 84 | logging_dir = os.path.join(args.logging_path, args.model_name) 85 | os.makedirs(logging_dir, exist_ok=True) 86 | 87 | return output_dir, logging_dir 88 | 89 | -------------------------------------------------------------------------------- /04_Understanding_GPU_activity_and_checking_jobs/README.md: -------------------------------------------------------------------------------- 1 | # 04 Understanding GPU activity & checking jobs 2 | 3 | These examples are based on the ROCm container provided to you at: 4 | ``` 5 | /appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif 6 | ``` 7 | 8 | To avoid running into any storage issues, we recommend running the examples from a folder you create in the scratch file system, e.g.: 9 | ``` 10 | mkdir -p /scratch/project_465001958/$(whoami) 11 | cd /scratch/project_465001958/$(whoami) 12 | ``` 13 | 14 | The examples also assume there is an allocation in place to be used for one or more nodes. That could be accomplished with, e.g.: 15 | ``` 16 | salloc -p small-g --account=project_465001958 --reservation=AI_workshop_1 --gpus-per-node=2 --ntasks-per-node=1 --cpus-per-task=14 --mem-per-gpu=60G --time=0:30:00 17 | ``` 18 | This is very similar to what you have been doing with `sbatch` if you have been using a run script with: 19 | ``` 20 | #SBATCH --account=project_465001958 21 | #SBATCH --reservation=AI_workshop_1 22 | #SBATCH --partition=small-g 23 | #SBATCH --gpus-per-node=1 24 | #SBATCH --ntasks-per-node=1 25 | #SBATCH --cpus-per-task=7 26 | #SBATCH --mem-per-gpu=60G 27 | #SBATCH --time=0:30:00 28 | ``` 29 | The difference is that `salloc` gives you a mechanism to just allocate the nodes without running anything. You can then issue `srun` commands interactively, which makes it easier to experiment. You are always welcome to switch to `sbatch` if you prefer. 30 | 31 | 39 | 40 | With the allocation and container set, we can do a quick smoke test to make sure PyTorch can detect the GPUs available in a node: 41 | ``` 42 | srun singularity exec \ 43 | /appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif \ 44 | bash -c '$WITH_CONDA ; \ 45 | python -c "import torch; print(torch.cuda.device_count())"' 46 | ``` 47 | 48 | It should yield `2` given that only two GPUs were requested. Note that each time a node is used for the first time, there is some latency while the container is loaded. Running the command above again on the same allocation should complete faster. 49 |
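If you want a slightly more informative smoke test, the variant below is a minimal sketch using only standard PyTorch calls; it additionally prints the PyTorch version, the HIP version the container's PyTorch was built against (`torch.version.hip`), and the name of each visible device:

```
srun singularity exec \
    /appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif \
    bash -c '$WITH_CONDA ; \
    python -c "import torch; print(torch.__version__, torch.version.hip); [print(i, torch.cuda.get_device_name(i)) for i in range(torch.cuda.device_count())]"'
```

On the two-GCD allocation above, it should list two devices.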
50 | ## Hands-on exercise 51 | 52 | We will leverage the same LLM example as before, with small adaptations. No extra files are needed. You might want to collate the different steps into a batch script or run them interactively as presented. 53 | 54 | ### 1. Let's recover our LLM example. 55 | Here we'll recover our fine-tuning example for IMDB movie review generation: 56 | 57 | ``` 58 | curl -o GPT-neo-IMDB-finetuning.py -L https://github.com/Lumi-supercomputer/Getting_Started_with_AI_workshop/raw/main/03_Your_first_AI_training_job_on_LUMI/reference_solution/GPT-neo-IMDB-finetuning.py 59 | curl -o util.py -L https://github.com/Lumi-supercomputer/Getting_Started_with_AI_workshop/raw/main/03_Your_first_AI_training_job_on_LUMI/util.py 60 | ``` 61 | 62 | ### 2. Spin up the training job 63 | We can now run our training as: 64 | 65 | ``` 66 | mkdir -p torch-cache hf-cache 67 | 68 | srun -n1 singularity exec \ 69 | -B .:/workdir \ 70 | /appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif \ 71 | bash -c '$WITH_CONDA ; cd /workdir ; \ 72 | HIP_VISIBLE_DEVICES=0 \ 73 | TORCH_HOME=/workdir/torch-cache \ 74 | HF_HOME=/workdir/hf-cache \ 75 | TOKENIZERS_PARALLELISM=false \ 76 | python -u /workdir/GPT-neo-IMDB-finetuning.py \ 77 | --model-name gpt-imdb-model \ 78 | --output-path /workdir/train-output \ 79 | --logging-path /workdir/train-logging \ 80 | --num-workers 7' 81 | ``` 82 | 83 | While the training runs, let's look at the CPU/GPU activity. Note that we are using an allocation with 2 logical GPUs, so we limit visibility with the variable `HIP_VISIBLE_DEVICES`. Given that the actual GPU has two GCDs (logical GPUs), it is better to monitor the whole GPU, and not just half of it. 84 | 85 | ### 3. Monitoring GPU activity 86 | 87 | Monitoring in a separate tab can be done by checking your job ID and connecting to the first node of the allocation. E.g.: 88 | 89 | * Get the job ID - in this case `7100665`: 90 | ``` 91 | squeue --me 92 | JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON) 93 | 7100665 small-g interact samantao R 1:03:21 1 nid005021 94 | ... 95 | ``` 96 | * Start an interactive parallel session: 97 | ``` 98 | srun --jobid 7100665 --interactive --pty /bin/bash 99 | ``` 100 | * Use `rocm-smi` to monitor GPU activity: 101 | ``` 102 | watch -n1 rocm-smi 103 | ``` 104 | This will give a snapshot of the GPU utilization captured by the driver every second: 105 | ``` 106 | ======================= ROCm System Management Interface ======================= 107 | ================================= Concise Info ================================= 108 | GPU Temp AvgPwr SCLK MCLK Fan Perf PwrCap VRAM% GPU% 109 | 0 58.0c 324.0W 1650Mhz 1600Mhz 0% manual 500.0W 98% 100% 110 | 1 49.0c N/A 800Mhz 1600Mhz 0% manual 0.0W 0% 0% 111 | ================================================================================ 112 | ============================= End of ROCm SMI Log ============================== 113 | ``` 114 | As expected, we only have activity on one GCD, but the power metrics are per GPU. Note that these numbers need to be interpreted carefully. For example, if `GPU%` shows `100%`, that does NOT necessarily mean the GPU is being well utilized. A better metric is the drawn power `AvgPwr`: oscillating around `500.0W` is an indication of significant compute activity on the full GPU. 115 | 116 | Here we see the drawn power oscillate around `300.0W` while a single GCD is being used, which is an indication that we might be compute bound. 117 |
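If you prefer a record that you can inspect after the run, a simple approach is to periodically append `rocm-smi` snapshots to a file from the interactive session. The loop below is a minimal sketch; the 5-second interval and the `gpu_activity.log` file name are arbitrary choices:

```
# Append a timestamped rocm-smi snapshot every 5 seconds; stop with Ctrl-C
# and inspect gpu_activity.log afterwards.
while true; do
  date
  rocm-smi
  sleep 5
done >> gpu_activity.log 2>&1
```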
118 | ### 4. Activate logging to report GPU activity 119 | 120 | Another way to understand the activity of GPU-enabled libraries is to enable their logging messages. Here are some examples: 121 | 122 | * `AMD_LOG_LEVEL=4` - this captures the HIP runtime activity used to copy data and launch kernels on the GPU. 123 | 124 | * `MIOPEN_ENABLE_LOGGING=1` - this captures API activity for the MIOpen library that provides optimized kernels for AI applications. Your application might not use MIOpen, though. 125 | 126 | So, running the following: 127 | ``` 128 | srun -n1 singularity exec \ 129 | -B .:/workdir \ 130 | /appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif \ 131 | bash -c '$WITH_CONDA ; cd /workdir ; \ 132 | HIP_VISIBLE_DEVICES=0 \ 133 | AMD_LOG_LEVEL=4 \ 134 | TORCH_HOME=/workdir/torch-cache \ 135 | HF_HOME=/workdir/hf-cache \ 136 | TOKENIZERS_PARALLELISM=false \ 137 | python -u /workdir/GPT-neo-IMDB-finetuning.py \ 138 | --model-name gpt-imdb-model \ 139 | --output-path /workdir/train-output \ 140 | --logging-path /workdir/train-logging \ 141 | --num-workers 7' 142 | ``` 143 | would return something like the following for a given kernel and its dispatch configuration: 144 | ``` 145 | :3:hip_module.cpp :662 : 117659918626 us: 8088 : [tid:0x14b2015e9700] hipLaunchKernel ( 0x14b5ec183ed0, {32768,1,1}, {512,1,1}, 0x14b2015e71b0, 0, stream: ) :4:command.cpp :349 : 117659918630 us: 8088 : [tid:0x14b2015e9700] Command (KernelExecution) enqueued: 0x14b151fe3b00 :3:rocvirtual.cpp :786 : 117659918634 us: 8088 : [tid:0x14b2015e9700] Arg0: = val:16777216 146 | :3:rocvirtual.cpp :786 : 117659918636 us: 8088 : [tid:0x14b2015e9700] Arg1: = val:22689590804480 :3:rocvirtual.cpp :2853: 117659918639 us: 8088 : [tid:0x14b2015e9700] ShaderName : _ZN2at6native6legacy18elementwise_kernelILi512ELi1EZNS0_15gpu_kernel_implIZZZNS0_23direct_copy_kernel_cudaERNS_18TensorIteratorBaseEENK 147 | UlvE0_clEvENKUlvE5_clEvEUlfE_EEvS5_RKT_EUliE_EEviT1_ 148 | :4:rocvirtual.cpp :891 : 117659918644 us: 8088 : [tid:0x14b2015e9700] HWq=0x14b30ee00000, Dispatch Header = 0xb02 (type=2, barrier=1, acquire=1, release=1), setup=3, grid=[16777216, 1, 1], workgroup=[512, 1, 1], privat 149 | e_seg_size=0, group_seg_size=0, kernel_obj=0x14b4a5220000, kernarg_address=0x14b30ec73780, completion_signal=0x0 150 | :3:hip_module.cpp :663 : 117659918649 us: 8088 : [tid:0x14b2015e9700] hipLaunchKernel: Returned hipSuccess : 151 | ``` 152 | Try to interpret the different kinds of activity. 153 | 154 | ### 5. Using a profiler to assess GPU activity. 155 | 156 | Another way to check for GPU activity is to use a profiler. There is a GPU profiler included in any ROCm installation: `ROCprofiler`. This profiler is also available inside the containers, so no extra installation is required. It has a command-line driver called `rocprof` and you can see the available options with: 157 | ``` 158 | srun -n1 singularity exec \ 159 | -B .:/workdir \ 160 | /appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif \ 161 | rocprof --help 162 | ``` 163 | Given that PyTorch uses the HIP runtime in its implementation, one of the most relevant options is `--hip-trace`, which instructs the profiler to collect the HIP runtime activity. Another convenient option is `--stats`, which generates some statistics on the usage of the GPU. 164 | 165 | To allow a quicker completion time, let's focus on just a few training steps. For that, open the file `GPT-neo-IMDB-finetuning.py` and replace: 166 | ``` 167 | max_steps=1000, 168 | ``` 169 | with: 170 | ``` 171 | max_steps=10, 172 | ``` 173 | and place an `import sys ; sys.exit(0)` statement after: 174 | ``` 175 | trainer.train() 176 | ``` 177 |
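After these edits, the relevant part of the script should look roughly like the sketch below (only the two changed spots are shown; all surrounding arguments and code stay as they are):

```
    training_args = TrainingArguments(
        ...                    # all other arguments unchanged
        max_steps=10,          # was: max_steps=1000,
        ...
    )

    ...

    trainer.train()

    # Stop right after the few profiled training steps so the run finishes quickly
    import sys ; sys.exit(0)
```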
178 | Now we can just run the profiler by preceding our original command with `rocprof`. 179 | 180 | ``` 181 | srun -n1 singularity exec \ 182 | -B .:/workdir \ 183 | /appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif \ 184 | bash -c '$WITH_CONDA ; cd /workdir ; \ 185 | HIP_VISIBLE_DEVICES=0 \ 186 | TORCH_HOME=/workdir/torch-cache \ 187 | HF_HOME=/workdir/hf-cache \ 188 | TOKENIZERS_PARALLELISM=false \ 189 | rocprof --hip-trace --stats python -u /workdir/GPT-neo-IMDB-finetuning.py \ 190 | --model-name gpt-imdb-model \ 191 | --output-path /workdir/train-output \ 192 | --logging-path /workdir/train-logging \ 193 | --num-workers 7' 194 | ``` 195 | This will generate a few files named `results.*`. For example, `results.stats.csv` will provide the stats of the kernels that were executed on the GPU, in descending order of combined execution time. These can sometimes be easier to read if imported into a spreadsheet. 196 | 197 | ### 6. Visualizing a profile trace 198 | Another file that might be interesting to look at is `results.json`. This can be loaded into the web app `https://ui.perfetto.dev/v46.0-35b3d9845/#/` and will allow you to visualize the GPU execution. Here is a snapshot of the 10 steps executed: 199 | 200 | ![image](https://github.com/Lumi-supercomputer/Getting_Started_with_AI_workshop/raw/main/04_Understanding_GPU_activity_and_checking_jobs/images/profile.png) 201 | 202 | ### 7. Using the PyTorch profiling infrastructure. 203 | 204 | PyTorch already provides profiling infrastructure that captures GPU activity as well as ranges for the CPU activities. It can be loaded with: 205 | ``` 206 | from torch.profiler import profile, ProfilerActivity 207 | ``` 208 | Then, you can identify the part of the code to profile, e.g. a given epoch. At the start of that part you can create and start the `profile` object: 209 | ``` 210 | prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) 211 | prof.start() 212 | ``` 213 | and at the end you can stop it and create the profile file to be loaded into the Perfetto UI tool mentioned above: 214 | ``` 215 | prof.stop() 216 | prof.export_chrome_trace("trace.json") 217 | ``` 218 | 219 | Let's get our example: 220 | ``` 221 | curl -o GPT-neo-IMDB-finetuning-profile.py -L https://github.com/Lumi-supercomputer/Getting_Started_with_AI_workshop/raw/main/03_Your_first_AI_training_job_on_LUMI/reference_solution/GPT-neo-IMDB-finetuning.py 222 | ``` 223 | Use `max_steps=10` and place the profiler start and end around: 224 | ``` 225 | trainer.train(resume_from_checkpoint=args.resume) 226 | ``` 227 | Run as before: 228 | ``` 229 | srun -n1 singularity exec \ 230 | -B .:/workdir \ 231 | /appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif \ 232 | bash -c '$WITH_CONDA ; cd /workdir ; \ 233 | HIP_VISIBLE_DEVICES=0 \ 234 | TORCH_HOME=/workdir/torch-cache \ 235 | HF_HOME=/workdir/hf-cache \ 236 | TOKENIZERS_PARALLELISM=false \ 237 | python -u /workdir/GPT-neo-IMDB-finetuning-profile.py \ 238 | --model-name gpt-imdb-model \ 239 | --output-path /workdir/train-output \ 240 | --logging-path /workdir/train-logging \ 241 | --num-workers 7' 242 | ``` 243 | Then you can visualize the file `trace.json`. 244 | 245 | A solution `GPT-neo-IMDB-finetuning-profile.py` is available [here](reference_solution/GPT-neo-IMDB-finetuning-profile.py).
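Putting the pieces from this section together, the wrapping around the training call ends up looking roughly like the sketch below. It assumes the `trainer` and `args` objects from the fine-tuning script; the `key_averages()` line is an optional extra that prints an on-screen summary of the most expensive operations in addition to the Perfetto trace:

```
from torch.profiler import profile, ProfilerActivity

# Profile both CPU ranges and GPU (HIP, exposed as "CUDA") activity
prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA])
prof.start()
trainer.train(resume_from_checkpoint=args.resume)
prof.stop()

# Trace file to load into the Perfetto UI
prof.export_chrome_trace("trace.json")
# Optional: print a summary table sorted by accumulated GPU time
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```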
246 | -------------------------------------------------------------------------------- /04_Understanding_GPU_activity_and_checking_jobs/images/profile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lumi-supercomputer/Getting_Started_with_AI_workshop/ce4d7cc37162350ff0d620d0301b8f9e7f0b5c98/04_Understanding_GPU_activity_and_checking_jobs/images/profile.png -------------------------------------------------------------------------------- /04_Understanding_GPU_activity_and_checking_jobs/reference_solution/GPT-neo-IMDB-finetuning-profile.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # IMDB movie review text generation 5 | # 6 | # In this script, we'll fine-tune a GPT2-like model to generate more 7 | # movie reviews based on a prompt. 8 | # 9 | # Partly based on this tutorial: 10 | # https://github.com/omidiu/GPT-2-Fine-Tuning/ 11 | 12 | 13 | import argparse 14 | import math 15 | import os 16 | import time 17 | from pprint import pprint 18 | 19 | import torch 20 | from datasets import load_dataset 21 | from util import preprocess_data, get_output_paths 22 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 23 | DataCollatorForLanguageModeling, Trainer, 24 | TrainingArguments) 25 | from torch.profiler import profile, ProfilerActivity 26 | 27 | if __name__ == "__main__": 28 | 29 | # First we set up some command line arguments to allow us to specify data/output paths 30 | # and the number of worker processes without changing the code. 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument( 33 | "--model-name", 34 | type=str, 35 | default="gpt-imdb-model", 36 | help="A name for the trained model under. A subdirectory with the given name will be created under the `output-path`.", 37 | ) 38 | parser.add_argument( 39 | "--output-path", 40 | type=str, 41 | help="The root directory under which model checkpoints are stored.", 42 | ) 43 | parser.add_argument( 44 | "--logging-path", 45 | type=str, 46 | help="The root directory under which logging data (for tensorboard) are stored.", 47 | ) 48 | parser.add_argument( 49 | "--num-workers", 50 | type=int, 51 | default=1, 52 | help="The number of CPU worker processes to use.", 53 | ) 54 | parser.add_argument( 55 | "--resume", 56 | default=False, 57 | action="store_true", 58 | help="If set, continue from a previously interrupted run. Otherwise, overwrite existing checkpoints.", 59 | ) 60 | args, _ = parser.parse_known_args() 61 | 62 | # Then we determine the device on which to train the model. 63 | print("Using PyTorch version:", torch.__version__) 64 | if torch.cuda.is_available(): 65 | device = torch.device("cuda") 66 | print("Using GPU, device name:", torch.cuda.get_device_name(device)) 67 | else: 68 | print("No GPU found, using CPU instead.") 69 | device = torch.device("cpu") 70 | 71 | # We also ensure that output paths exist 72 | output_dir, logging_dir = get_output_paths(args) 73 | 74 | # #### Loading the GPT-neo model 75 | # 76 | # We'll use the gpt-neo-1.3B model from the Hugging Face library: 77 | # https://huggingface.co/EleutherAI/gpt-neo-1.3B 78 | # Let's start with getting the appropriate tokenizer. 
79 | pretrained_model = "EleutherAI/gpt-neo-1.3B" 80 | 81 | print("Loading model and tokenizer") 82 | start = time.time() 83 | tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True) 84 | tokenizer.pad_token = tokenizer.eos_token 85 | 86 | # Load the actual base model from Hugging Face 87 | model = AutoModelForCausalLM.from_pretrained(pretrained_model) 88 | model.to(device) 89 | stop = time.time() 90 | print(f"Loading model and tokenizer took: {stop-start:.2f} seconds") 91 | 92 | # #### Loading the IMDb data set 93 | # 94 | # Next we'll load the IMDb data set: https://huggingface.co/docs/datasets/index. 95 | # 96 | # The data set contains 100,000 movies reviews from the Internet Movie 97 | # Database, split into 25,000 reviews for training and 25,000 reviews 98 | # for testing and 50,000 without labels (unsupervised) that we also use for training. 99 | 100 | train_dataset = load_dataset( 101 | "imdb", split="train+unsupervised", trust_remote_code=False, keep_in_memory=True 102 | ) 103 | eval_dataset = load_dataset( 104 | "imdb", split="test", trust_remote_code=False, keep_in_memory=True 105 | ) 106 | 107 | # Let's print one sample from the dataset. 108 | print("Sample from dataset") 109 | for b in train_dataset: 110 | pprint(b) 111 | break 112 | 113 | # #### Setting up the training configuration 114 | train_batch_size = 32 # This just about fits into the VRAM of a single MI250x GCD with 16-bit floats 115 | eval_batch_size = 128 # No optimizer state during evaluation, so can use bigger batches for increased throughput 116 | 117 | training_args = TrainingArguments( 118 | output_dir=output_dir, 119 | overwrite_output_dir=not args.resume, 120 | save_strategy="steps", 121 | save_steps=100, 122 | save_total_limit=4, 123 | logging_dir=logging_dir, 124 | evaluation_strategy="steps", 125 | eval_steps=200, # compute validation loss every 200 steps 126 | learning_rate=2e-5, 127 | weight_decay=0.01, 128 | bf16=True, # use 16-bit floating point precision 129 | per_device_train_batch_size=train_batch_size, 130 | per_device_eval_batch_size=eval_batch_size, 131 | max_steps=10, 132 | dataloader_num_workers=args.num_workers, 133 | dataloader_pin_memory=True, 134 | report_to=["tensorboard"], # log statistics for tensorboard 135 | ) 136 | 137 | # #### Preprocessing of training data 138 | # We tokenize the data into torch tensors, split training into training and validation and set up a collator that 139 | # is able to arrange single data samples into batches. 140 | 141 | train_dataset_tokenized, validate_dataset_tokenized, eval_dataset_tokenized = preprocess_data(train_dataset, eval_dataset, tokenizer, training_args) 142 | 143 | collator = DataCollatorForLanguageModeling( 144 | tokenizer, mlm=False, return_tensors="pt" 145 | ) 146 | 147 | # Sanity check: How does the training data look like after preprocessing? 148 | print("Sample of tokenized data") 149 | for b in train_dataset_tokenized: 150 | pprint(b, compact=True) 151 | print("Length of input_ids:", len(b["input_ids"])) 152 | break 153 | print("Length of dataset (tokenized)", len(train_dataset_tokenized)) 154 | 155 | # #### Training 156 | # We use the Hugging Face trainer instead of a manual training loop. 
157 | # 158 | # You can read about the many, many different parameters to the 159 | # Hugging Face trainer here: 160 | # https://huggingface.co/docs/transformers/v4.37.0/en/main_classes/trainer#transformers.TrainingArguments 161 | # 162 | 163 | collator = DataCollatorForLanguageModeling( 164 | tokenizer, mlm=False, return_tensors="pt" 165 | ) 166 | 167 | trainer = Trainer( 168 | model=model, 169 | args=training_args, 170 | tokenizer=tokenizer, 171 | data_collator=collator, 172 | train_dataset=train_dataset_tokenized, 173 | eval_dataset=validate_dataset_tokenized, 174 | ) 175 | 176 | # With 1000 steps, batch size 32 and a single GCD, this should take just under 30 minutes. 177 | 178 | prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) 179 | prof.start() 180 | trainer.train(resume_from_checkpoint=args.resume) 181 | prof.stop() 182 | prof.export_chrome_trace("trace.json") 183 | 184 | print() 185 | print("Training done, you can find all the model checkpoints in", output_dir) 186 | 187 | import sys 188 | sys.exit(0) 189 | 190 | # #### Evaluating the finetuned model 191 | with torch.no_grad(): 192 | model.eval() 193 | # Calculate perplexity 194 | eval_results = trainer.evaluate() 195 | test_results = trainer.evaluate(eval_dataset_tokenized) 196 | 197 | print(f'Perplexity on validation: {math.exp(eval_results["eval_loss"]):.2f}') 198 | print(f'Perplexity on test: {math.exp(test_results["eval_loss"]):.2f}') 199 | 200 | # Let's print a few sample generated reviews; this is the same as in the previous exercise 201 | # but now we use the finetuned model 202 | prompt = "The movie 'How to run ML on LUMI - A documentation' was great because" 203 | inputs = tokenizer(prompt, return_tensors="pt").to(device) 204 | outputs = model.generate( 205 | **inputs, do_sample=True, max_length=80, num_return_sequences=4 206 | ) 207 | decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) 208 | 209 | print("Sample generated review:") 210 | for txt in decoded_outputs: 211 | print("-", txt) 212 | -------------------------------------------------------------------------------- /05_Running_containers_on_LUMI/Hello_LUMI_GPU_World.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import socket 4 | 5 | import torch 6 | 7 | print(f"Hello LUMI GPU World from {socket.gethostname()}") 8 | print("*" * 80) 9 | print(f" - We are running in the Singularity container {os.environ.get('SINGULARITY_CONTAINER', 'N/A')}") 10 | print(f" - We are running Python version {sys.version} from {sys.executable}") 11 | print(f" - The number of GPUs (really GCDs) available to PyTorch is {torch.cuda.device_count()}") 12 | print(f" - Our SLURM job ID is {os.environ.get('SLURM_JOB_ID', 'N/A')}") 13 | print("*" * 80) 14 | -------------------------------------------------------------------------------- /05_Running_containers_on_LUMI/README.md: -------------------------------------------------------------------------------- 1 | # 05 Running containers on LUMI 2 | 3 | ## Examples 4 | 5 | - **Building a container using SingularityCE+proot**: The [build_ubuntu_tree.sh](examples/build_ubuntu_tree.sh) script may be used to build the container `ubuntu_tree.sif` defined in the Singularity definition file [ubutnu_tree.def](examples/ubuntu_tree.def). This new container is simply the latest Ubuntu 22.04 Docker Hub image with the `tree` package installed using the Ubuntu Apt package manager. 
Note that you only need to load the `systools` module on LUMI to have Singularity automatically pick up `proot` from it to do an unpriviledged build on LUMI, i.e. no need for `root`, `sudo`, `--fakeroot`, etc. 6 | - **Showing the top directory trees on LUMI and in an Ubuntu container**: The [print_directory_trees.sh](examples/print_directory_trees.sh) script may be used on LUMI to print: 7 | 1. The `tree -L 1 /` of the `ubuntu_tree.sif` container 8 | 2. The `tree -L 1 /` of LUMI 9 | 3. The `tree -L 1 /` of `ubuntu_tree.sif` with the `/project/project_465001958` folder from LUMI bind mounted 10 | 11 | ## Hands-on exercises 12 | 13 | 1. Hello (LUMI GPU) world in a container 14 | 15 | In this exercise you get to practice running a Python script inside an official LUMI container on both a LUMI login node and a LUMI-G compute node. 16 | 17 | 1. Select one of the PyTorch containers found in /appl/local/containers/sif-images/ on LUMI. 18 | 2. Run the `Hello_LUMI_GPU_World.py` Python script inside the container on: 19 | - A LUMI login node 20 | - A LUMI-G compute node 21 | 22 | 2. Pulling a Docker container and using it on LUMI 23 | 24 | In this exercise you will learn how to pull and run an existing Docker container on LUMI. 25 | 26 | 1. Pick a container from [Docker Hub](https://hub.docker.com/), e.g. [the official Alpine Docker image](https://hub.docker.com/_/alpine), and pull it to LUMI using Singularity. 27 | - Make sure the Singularity cache is not filling up your home folder (hint: see the [LUMI Docs container page](https://docs.lumi-supercomputer.eu/software/containers/singularity/#pulling-container-images-from-a-registry)) 28 | - Once Singularity has created the SIF file, you can use it like any other container on LUMI. 29 | 30 | 3. Correctly running the official LUMI containers 31 | 32 | In this exercise you will learn to correctly bind mount the necessary CPE bits from LUMI and activate the conda environment in the official LUMI containers. 33 | 34 | 1. Open an interactive shell in the LUMI TensorFlow+Horovod container 35 | 2. 
Open an interactive Python interpreter in the interactive container shell and (successfully) `import horovod.tensorflow` 36 | -------------------------------------------------------------------------------- /05_Running_containers_on_LUMI/examples/build_ubuntu_tree.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | module load LUMI/24.03 systools # gets us access to proot 3 | singularity build ubuntu_tree.sif ubuntu_tree.def 4 | -------------------------------------------------------------------------------- /05_Running_containers_on_LUMI/examples/print_directory_trees.sh: -------------------------------------------------------------------------------- 1 | # Container namespace 2 | echo "Ubuntu container directory tree" 3 | singularity exec ubuntu_tree.sif tree -L 1 / 4 | 5 | # LUMI namespace 6 | echo "LUMI directory tree" 7 | module load LUMI/24.03 systools 8 | tree -L 1 / 9 | 10 | # Container namespace with /project bind-mounted 11 | echo "Ubuntu container directory tree with /project bind-mounted" 12 | singularity exec --bind /project/project_465001707 ubuntu_tree.sif tree -L 1 / 13 | -------------------------------------------------------------------------------- /05_Running_containers_on_LUMI/examples/ubuntu_tree.def: -------------------------------------------------------------------------------- 1 | Bootstrap: docker 2 | From: ubuntu:22.04 3 | 4 | %post 5 | apt update 6 | apt install -y tree 7 | -------------------------------------------------------------------------------- /05_Running_containers_on_LUMI/reference_solution/reference_solution.md: -------------------------------------------------------------------------------- 1 | # Reference solutions to the hands-on exercises for 05 Running containers on LUMI 2 | 3 | ## Exercise 1 4 | 5 | > 1. Select one of the PyTorch containers found in /appl/local/containers/sif-images/ on LUMI. 6 | > 2. Run the `Hello_LUMI_GPU_World.py` Python script inside the container on: 7 | > - A LUMI login node 8 | > - A LUMI-G compute node 9 | 10 | For this exercise, we may use e.g. the lumi-pytorch-rocm-6.2.4-python-3.12-pytorch-v2.6.0.sif container found in /appl/local/containers/sif-images/. 11 | 12 | To run the `Hello_LUMI_GPU_World.py` Python script using one of the LUMI PyTorch container, we must remember to: 13 | 14 | 1. Bind mount the folder in which the `Hello_LUMI_GPU_World.py` script is placed (if not it's not your home folder) 15 | 2. Run the container using `singularity exec` 16 | 3. Activate the conda environment in the container by running `$WITH_CONDA` in the container 17 | 4. Submit the job using `srun` when using a LUMI-G compute node 18 | 19 | On a LUMI login node, it may be done by: 20 | 21 | ```bash 22 | $ module use /appl/local/containers/ai-modules 23 | $ module load singularity-AI-bindings 24 | $ singularity exec /appl/local/containers/sif-images/lumi-pytorch-rocm-6.2.4-python-3.12-pytorch-v2.6.0.sif bash -c "\$WITH_CONDA; python3 Hello_LUMI_GPU_World.py" 25 | Hello LUMI GPU World from uan03 26 | ******************************************************************************** 27 | - We are running in the Singularity container /appl/local/containers/sif-images/lumi-pytorch-rocm-6.2.4-python-3.12-pytorch-v2.6.0.sif 28 | - We are running Python version 3.12.9 | packaged by Anaconda, Inc. 
| (main, Feb 6 2025, 18:56:27) [GCC 11.2.0] from /opt/miniconda3/envs/pytorch/bin/python3 29 | /opt/miniconda3/envs/pytorch/lib/python3.12/site-packages/torch/cuda/__init__.py:721: UserWarning: Can't initialize amdsmi - Error code: 34 30 | warnings.warn(f"Can't initialize amdsmi - Error code: {e.err_code}") 31 | - The number of GPUs (really GCDs) available to PyTorch is 0 32 | - Our SLURM job ID is N/A 33 | ******************************************************************************** 34 | $ 35 | ``` 36 | 37 | On a LUMI-G node, it may be done by: 38 | 39 | ```bash 40 | $ module use /appl/local/containers/ai-modules 41 | $ module load singularity-AI-bindings 42 | $ srun --account=project_465001958 --partition=small-g --time=00:00:30 --nodes=1 --gpus=4 singularity exec /project/project_465001958/containers/lumi-pytorch-rocm-6.2.4-python-3.12-pytorch-v2.6.0.sif bash -c "\$WITH_CONDA; python3 Hello_LUMI_GPU_World.py" 43 | 44 | srun: job 11170342 queued and waiting for resources 45 | srun: job 11170342 has been allocated resources 46 | Hello LUMI GPU World from nid007856 47 | ******************************************************************************** 48 | - We are running in the Singularity container /appl/local/containers/sif-images/lumi-pytorch-rocm-6.2.4-python-3.12-pytorch-v2.6.0.sif 49 | - We are running Python version 3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0] from /opt/miniconda3/envs/pytorch/bin/python3 50 | - The number of GPUs (really GCDs) available to PyTorch is 4 51 | - Our SLURM job ID is 11170342 52 | ******************************************************************************** 53 | $ 54 | ``` 55 | 56 | > [!IMPORTANT] 57 | > The number of GPUs/GCDs available to PyTorch is based on how many you request from SLURM. The default is 0! 58 | 59 | > [!NOTE] 60 | > It is a good idea to copy the `lumi-pytorch-rocm-6.2.4-python-3.12-pytorch-v2.6.0.sif` container to your project folder and run it from there to enable you to reproduce your results. We may remove or replace `lumi-pytorch-rocm-6.2.4-python-3.12-pytorch-v2.6.0.sif` at any point in time! 61 | 62 | ## Exercise 2 63 | 64 | > 1. Pick a container from [Docker Hub](https://hub.docker.com/), e.g. [the official Alpine Docker image](https://hub.docker.com/_/alpine), and pull it to LUMI using Singularity. 65 | > - Make sure the Singularity cache is not filling up your home folder (hint: see the [LUMI Docs container page](https://docs.lumi-supercomputer.eu/software/containers/singularity/#pulling-container-images-from-a-registry)) 66 | > - Once Singularity has created the SIF file, you can use it like any other container on LUMI. 67 | 68 | To pull containers from Docker Hub without filling up our home folder with Singularity temp/cache files, we must remember to: 69 | 70 | 1. Use the `docker:///:` URI specifier with `singularity pull` 71 | 2. Set `SINGULARITY_TMPDIR` and `SINGULARITY_CACHEDIR` environment variables to another location than our home folder 72 | 73 | Pulling version/tag 3.19.1 of the alpine container on a LUMI login node may be done by: 74 | 75 | ```bash 76 | $ export SINGULARITY_TMPDIR=/tmp/$USER 77 | $ export SINGULARITY_CACHEDIR=/tmp/$USER 78 | $ singularity pull docker://alpine:3.19.1 79 | INFO: Converting OCI blobs to SIF format 80 | INFO: Starting build... 81 | INFO: Fetching OCI image... 82 | 3.3MiB / 3.3MiB [===========================================================] 100 % 24.5 KiB/s 0s 83 | INFO: Extracting OCI image... 
84 | INFO: Inserting Singularity configuration... 85 | INFO: Creating SIF file... 86 | $ 87 | $ ls -al alpine_3.19.1.sif 88 | -rwxrwx--- 1 javicher pepr_javicher 3379200 Nov 18 00:22 alpine_3.19.1.sif 89 | $ 90 | ``` 91 | 92 | which generates the `alpine_3.19.1.sif` container. 93 | 94 | > [!IMPORTANT] 95 | > There is no automatic cleaning of `/tmp` on the LUMI login nodes. You have to delete the Singularity temp/cache files under `/tmp/$USER` yourself when you are done pull/building containers! 96 | 97 | ## Exercise 3 98 | 99 | > 1. Open an interactive shell in the LUMI TensorFlow+Horovod container 100 | > 2. Open an interactive Python interpreter in the interactive container shell and (successfully) `import horovod.tensorflow` 101 | 102 | To successfully import Horovod+Tensorflow in the container, we must remember to: 103 | 104 | 1. Open an interactive shell in the container using `singularity shell` 105 | 2. Bind mount the CPE bits when opening the container shell, as Horovod uses MPI that requires parts of the CPE from LUMI 106 | 3. Activate the conda environment in the container by running `$WITH_CONDA` in the container shell 107 | 108 | On a LUMI login node, it may be done by: 109 | 110 | ```bash 111 | $ singularity shell --bind /var/spool/slurmd,/opt/cray,/usr/lib64/libcxi.so.1,/usr/lib64/libjansson.so.4 /appl/local/containers/sif-images/lumi-tensorflow-rocm-6.2.0-python-3.10-tensorflow-2.16.1-horovod-0.28.1.sif 112 | Singularity> $WITH_CONDA 113 | (tensorflow) Singularity> python3 114 | Python 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] on linux 115 | Type "help", "copyright", "credits" or "license" for more information. 116 | >>> import horovod.tensorflow 117 | 2024-11-25 23:13:07.472850: E external/local_xla/xla/stream_executor/plugin_registry.cc:91] Invalid plugin kind specified: FFT 118 | 2024-11-25 23:13:09.704705: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. 119 | To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. 120 | 2024-11-25 23:13:11.553455: E external/local_xla/xla/stream_executor/plugin_registry.cc:91] Invalid plugin kind specified: DNN 121 | >>> exit() 122 | (tensorflow) Singularity> exit 123 | exit 124 | ``` 125 | 126 | Remember that instead of manually specifying the bind mounts, you may load the `singularity-AI-bindings` module: 127 | 128 | ```bash 129 | $ module use /appl/local/containers/ai-modules 130 | $ module load singularity-AI-bindings 131 | $ singularity shell /appl/local/containers/sif-images/lumi-tensorflow-rocm-6.2.0-python-3.10-tensorflow-2.16.1-horovod-0.28.1.sif 132 | Singularity> $WITH_CONDA 133 | (tensorflow) Singularity> python3 134 | Python 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] on linux 135 | Type "help", "copyright", "credits" or "license" for more information. 136 | >>> import horovod.tensorflow 137 | 2024-11-25 23:27:34.216421: E external/local_xla/xla/stream_executor/plugin_registry.cc:91] Invalid plugin kind specified: FFT 138 | 2024-11-25 23:27:46.333772: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. 139 | To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. 
140 | 2024-11-25 23:27:52.266194: E external/local_xla/xla/stream_executor/plugin_registry.cc:91] Invalid plugin kind specified: DNN 141 | >>> exit() 142 | (tensorflow) Singularity> exit 143 | exit 144 | ``` 145 | 146 | If you don't bind mount the CPE bits, you will get an error about `libmpi_cray.so.12` not being available: 147 | 148 | ```bash 149 | $ singularity shell /appl/local/containers/sif-images/lumi-tensorflow-rocm-6.2.0-python-3.10-tensorflow-2.16.1-horovod-0.28.1.sif 150 | Singularity> $WITH_CONDA 151 | (tensorflow) Singularity> python3 152 | Python 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] on linux 153 | Type "help", "copyright", "credits" or "license" for more information. 154 | >>> import horovod.tensorflow 155 | 2024-11-18 00:30:03.593024: E external/local_xla/xla/stream_executor/plugin_registry.cc:91] Invalid plugin kind specified: FFT 156 | 2024-11-18 00:30:05.430617: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. 157 | To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. 158 | 2024-11-18 00:30:07.082782: E external/local_xla/xla/stream_executor/plugin_registry.cc:91] Invalid plugin kind specified: DNN 159 | Traceback (most recent call last): 160 | File "", line 1, in 161 | File "/opt/miniconda3/envs/tensorflow/lib/python3.10/site-packages/horovod/tensorflow/__init__.py", line 27, in 162 | from horovod.tensorflow import elastic 163 | File "/opt/miniconda3/envs/tensorflow/lib/python3.10/site-packages/horovod/tensorflow/elastic.py", line 24, in 164 | from horovod.tensorflow.functions import broadcast_object, broadcast_object_fn, broadcast_variables 165 | File "/opt/miniconda3/envs/tensorflow/lib/python3.10/site-packages/horovod/tensorflow/functions.py", line 24, in 166 | from horovod.tensorflow.mpi_ops import allgather, broadcast, broadcast_ 167 | File "/opt/miniconda3/envs/tensorflow/lib/python3.10/site-packages/horovod/tensorflow/mpi_ops.py", line 53, in 168 | raise e 169 | File "/opt/miniconda3/envs/tensorflow/lib/python3.10/site-packages/horovod/tensorflow/mpi_ops.py", line 50, in 170 | MPI_LIB = _load_library('mpi_lib' + get_ext_suffix()) 171 | File "/opt/miniconda3/envs/tensorflow/lib/python3.10/site-packages/horovod/tensorflow/mpi_ops.py", line 45, in _load_library 172 | library = load_library.load_op_library(filename) 173 | File "/opt/miniconda3/envs/tensorflow/lib/python3.10/site-packages/tensorflow/python/framework/load_library.py", line 54, in load_op_library 174 | lib_handle = py_tf.TF_LoadLibrary(library_filename) 175 | tensorflow.python.framework.errors_impl.NotFoundError: libmpi_cray.so.12: cannot open shared object file: No such file or directory 176 | >>> exit() 177 | Singularity> exit 178 | ``` 179 | 180 | If you don't activate the conda environment, it will use the container default Python, which does not have TensorFlow and Horovod installed: 181 | 182 | ```bash 183 | $ singularity shell --bind /var/spool/slurmd,/opt/cray,/usr/lib64/libcxi.so.1,/usr/lib64/libjansson.so.4 /appl/local/containers/sif-images/lumi-tensorflow-rocm-6.2.0-python-3.10-tensorflow-2.16.1-horovod-0.28.1.sif 184 | Singularity> python3 185 | Python 3.6.15 (default, Sep 23 2021, 15:41:43) [GCC] on linux 186 | Type "help", "copyright", "credits" or "license" for more information. 
187 | >>> import horovod.tensorflow 188 | Traceback (most recent call last): 189 | File "", line 1, in 190 | ModuleNotFoundError: No module named 'horovod' 191 | >>> exit() 192 | Singularity> exit 193 | ``` 194 | -------------------------------------------------------------------------------- /06_Bulding_containers_from_conda_pip_environments/README.md: -------------------------------------------------------------------------------- 1 | # 06 Building containers from Conda/pip environments 2 | 3 | ## Examples 4 | 5 | - An example of a complete PandasAI conda environment specification is provided in [PandasAI.yml](examples/PandasAI.yml). 6 | - An absolute minimal conda environment specification only including Python 3.12 is provided in [python312.yml](examples/python312.yml). 7 | - The minimal conda environment PyTorch recipe for LUMI-G is provided in [minimal_pytorch.yml](examples/minimal_pytorch.yml). This environment file can also be used with the `/appl/local/containers/sif-images/lumi-rocm-rocm-6.0.3.sif` base image. 8 | 9 | ## Hands-on exercises 10 | 11 | 1. The basics of using cotainr to build containers on LUMI 12 | 13 | In this exercise you get to practice building containers both interactively and non-interactively on LUMI using cotainr. 14 | 15 | 1. Using the example [python312.yml](examples/python312.yml) conda environment, use cotainr to build a container: 16 | - Interactively on a login node 17 | - Non-interactively on a compute node 18 | 2. Compare the output of running `python3 -c "import sys; print(sys.executable); print(sys.version)"` on a login node: 19 | - In the container you built 20 | - Directly on LUMI 21 | 22 | 2. Making changes to the software environment in the container 23 | 24 | In this exercise you will learn how to add additional packages to your containerized environment using cotainr. 25 | 26 | 1. Using cotainr, update the container you built using the `python312.yml` conda environment to contain a few extra packages of your choice, e.g. pandas and scikit-learn. 27 | 2. Open an interactive Python interpreter in the container and import your newly added packages. 28 | 29 | 3. Creative pip installs using cotainr 30 | 31 | In this exercise you will learn how to install Python packages in a container using cotainr when no conda package or pip wheel exists for the package. 32 | 33 | 1. Check the [panopticapi](https://github.com/cocodataset/panopticapi) GitHub repo for ways to install it from source. Also check the [setup.py](https://github.com/cocodataset/panopticapi/blob/master/setup.py) for hints about the dependencies needed by panopticapi 34 | 2. Create a conda environment file for installing panopticapi 35 | 3. 
Use the conda environment file to build a container for LUMI-C using cotainr 36 | -------------------------------------------------------------------------------- /06_Bulding_containers_from_conda_pip_environments/examples/PandasAI.yml: -------------------------------------------------------------------------------- 1 | name: PandasAI 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - annotated-types=0.6.0 6 | - anyio=4.3.0 7 | - astor=0.8.1 8 | - certifi=2024.2.2 9 | - charset-normalizer=3.3.2 10 | - contourpy=1.2.1 11 | - cycler=0.12.1 12 | - distro=1.9.0 13 | - duckdb=0.7.1 14 | - faker=19.13.0 15 | - fonttools=4.51.0 16 | - greenlet=3.0.3 17 | - h11=0.14.0 18 | - httpcore=1.0.5 19 | - httpx=0.27.0 20 | - idna=3.7 21 | - jinja2=3.1.3 22 | - kiwisolver=1.4.5 23 | - markupsafe=2.1.5 24 | - matplotlib=3.8.4 25 | - numpy=1.26.4 26 | - openai=1.23.5 27 | - packaging=24.0 28 | - pandas=1.5.3 29 | - pillow=10.3.0 30 | - pip=24.0 31 | - pydantic=2.7.1 32 | - pydantic-core=2.18.2 33 | - pyparsing=3.1.2 34 | - python=3.11.9 35 | - python-dotenv=1.0.1 36 | - requests=2.31.0 37 | - scipy=1.13.0 38 | - sniffio=1.3.1 39 | - sqlalchemy=2.0.29 40 | - tqdm=4.66.2 41 | - typing-extensions=4.11.0 42 | - urllib3=2.2.1 43 | - pip: 44 | - pandasai==2.0.35 45 | -------------------------------------------------------------------------------- /06_Bulding_containers_from_conda_pip_environments/examples/minimal_pytorch.yml: -------------------------------------------------------------------------------- 1 | name: minimal_pytorch 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - filelock=3.15.4 6 | - fsspec=2024.9.0 7 | - jinja2=3.1.4 8 | - markupsafe=2.1.5 9 | - mpmath=1.3.0 10 | - networkx=3.3 11 | - numpy=2.1.1 12 | - pillow=10.4.0 13 | - pip=24.0 14 | - python=3.12.3 15 | - sympy=1.13.2 16 | - typing-extensions=4.12.2 17 | - pip: 18 | - --extra-index-url https://download.pytorch.org/whl/rocm6.0/ 19 | - pytorch-triton-rocm==3.0.0 20 | - torch==2.4.1+rocm6.0 21 | - torchaudio==2.4.1+rocm6.0 22 | - torchvision==0.19.1+rocm6.0 23 | 24 | -------------------------------------------------------------------------------- /06_Bulding_containers_from_conda_pip_environments/examples/python312.yml: -------------------------------------------------------------------------------- 1 | name: python312 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.12 6 | -------------------------------------------------------------------------------- /06_Bulding_containers_from_conda_pip_environments/reference_solution/panopticapi.yml: -------------------------------------------------------------------------------- 1 | name: panopticapi 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - git=2.45.1 6 | - numpy=1.26.4 7 | - pillow=10.3.0 8 | - pip=24.0 9 | - python=3.9 10 | - pip: 11 | - git+https://github.com/cocodataset/panopticapi.git 12 | -------------------------------------------------------------------------------- /06_Bulding_containers_from_conda_pip_environments/reference_solution/python312_extra.yml: -------------------------------------------------------------------------------- 1 | name: python312_extra 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.12.3 6 | - pip=24.0 7 | - python=3.12.3 8 | - pandas=2.2.0 9 | - scikit-learn=1.4.2 10 | - pip: 11 | - env-var==1.0.1 -------------------------------------------------------------------------------- /06_Bulding_containers_from_conda_pip_environments/reference_solution/reference_solution.md: 
-------------------------------------------------------------------------------- 1 | # Reference solutions to the hands-on exercises for 06 Building containers from conda/pip environments 2 | 3 | ## Exercise 1 4 | 5 | > 1. Using the example [python312.yml](examples/python312.yml) conda environment, use cotainr to build a container: 6 | > - Interactively on a login node 7 | > - Non-interactively on a compute node 8 | > 2. Compare the output of running `python3 -c "import sys; print(sys.executable); print(sys.version)"` on a login node: 9 | > - In the container you built 10 | > - Directly on LUMI 11 | 12 | To build a container using cotainr on LUMI, we must remember to: 13 | 14 | 1. Load the cotainr module on LUMI 15 | 2. Determine a suitable base image. For this exercise, we use the `lumi-rocm-rocm-6.0.3.sif` container found in `/appl/local/containers/sif-images/`. 16 | 3. Run cotainr using `srun`, redirect stdout/stderr, and accept all licenses up-front when building non-interactively on a compute node 17 | 18 | Since the `python312.yml` environment only contains Python 3.12, we don't need ROCm or other special system libraries. 19 | Thus, using `--system=lumi-c` instead of `--base-image=...` with cotainr would be sufficient for getting a fairly minimal base image. 20 | However, for sake of consistency we will use the ROCm base image. (Feel free to experiment with the `--system=lumi-c` or `--system=lumi-g` options!) 21 | 22 | On a login node, we may build the container interactively by: 23 | 24 | ```bash 25 | $ module purge 26 | $ module use /appl/local/training/modules/AI-20241126 27 | $ module load cotainr 28 | $ cotainr build python312.sif --base-image=/appl/local/containers/sif-images/lumi-rocm-rocm-6.0.3.sif --conda-env=examples/python312.yml 29 | ``` 30 | 31 | > [!NOTE] 32 | > The `module use /appl/local/training/modules/AI-20241126` provides the most recent version of cotainr installed in the AI workshop training project. If you don't include this, you get an older version of cotainr installed in the default LUMI software stack. 33 | 34 | On a LUMI-C compute node, we may build the container non-interactively by: 35 | 36 | ```bash 37 | $ module purge 38 | $ module use /appl/local/training/modules/AI-20241126 39 | $ module load cotainr 40 | $ srun --output=cotainr.out --error=cotainr.err --account=project_465001958 --time=00:15:00 --mem=100G --cpus-per-task=8 --partition=dev-g cotainr build python312.sif --base-image=/appl/local/containers/sif-images/lumi-rocm-rocm-6.0.3.sif --conda-env=examples/python312.yml --accept-licenses 41 | ``` 42 | 43 | > [!WARNING] 44 | > Cotainr will ask for permission to overwrite the `python312.sif` container if it already exists. Since cotainr currently does not provide a way to non-interactively accept this, when building non-interactively, it will get stuck until it is terminated by SLURM due to the time limit, if the `python312.sif` container already exists. 45 | 46 | > [!TIP] 47 | > As an alternative to directly calling `srun`, you may consider creating a SLURM batch script to setup your cotainr build on a compute node. 
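A minimal batch script corresponding to the TIP above could look like the following sketch; the resource values mirror the `srun` example, and the file name `build_container.sh` is just an example:

```bash
#!/bin/bash
#SBATCH --account=project_465001958
#SBATCH --partition=dev-g
#SBATCH --time=00:15:00
#SBATCH --mem=100G
#SBATCH --cpus-per-task=8
#SBATCH --output=cotainr.out
#SBATCH --error=cotainr.err

# Load the workshop cotainr module
module purge
module use /appl/local/training/modules/AI-20241126
module load cotainr

# Build the container non-interactively; --accept-licenses avoids interactive prompts
cotainr build python312.sif \
    --base-image=/appl/local/containers/sif-images/lumi-rocm-rocm-6.0.3.sif \
    --conda-env=examples/python312.yml \
    --accept-licenses
```

Submit it with `sbatch build_container.sh`; the warning above about an already existing `python312.sif` applies here as well.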
48 | 49 | Now, if we run the `python3 -c "import sys; print(sys.executable); print(sys.version)"` command to show which version of Python we are using, when running in the container, we get: 50 | 51 | ```bash 52 | $ singularity exec python312.sif python3 -c "import sys; print(sys.executable); print(sys.version)" 53 | /opt/conda/envs/conda_container_env/bin/python3 54 | 3.12.10 | packaged by conda-forge | (main, Apr 10 2025, 22:21:13) [GCC 13.3.0] 55 | ``` 56 | 57 | whereas directly on LUMI we get: 58 | 59 | ```bash 60 | $ python3 -c "import sys; print(sys.executable); print(sys.version)" 61 | /usr/bin/python3 62 | 3.6.15 (default, Sep 23 2021, 15:41:43) [GCC] 63 | ``` 64 | 65 | which shows that within the container we directly have access to the Python 3.12 we installed as part of our conda environment instead of the Python 3.6 provided by the OS. Note that if you run `python3 -c "import sys; print(sys.executable); print(sys.version)"` after having run `module load cotainr`, you will get 66 | 67 | ```bash 68 | $ python3 -c "import sys; print(sys.executable); print(sys.version)" 69 | /opt/cray/pe/python/3.11.7/bin/python3 70 | 3.11.7 (main, Feb 8 2024, 20:49:32) [GCC 12.3.0] 71 | ``` 72 | 73 | since the cotainr module loads the cray-python module to get a Python >= 3.8 which is needed for running cotainr. 74 | 75 | ## Exercise 2 76 | 77 | > 1. Using cotainr, update the container you built using the `python312.yml` conda environment to contain a few extra packages of your choice, e.g. pandas and scikit-learn. 78 | > 2. Open an interactive Python interpreter in the container and import your newly added packages. 79 | 80 | To update our container with extra packages, we must remember to: 81 | 82 | 1. Update the conda environment yaml file and rebuild the container - by design cotainr does not offer a way to change/update an existing container in order to maximize the reproducibility of the software environment in the container and [minimize the risk of ending up with a broken conda environment](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#using-pip-in-an-environment). 83 | 2. Pin versions of the packages we add when updating the conda environment yaml file to maximize reproducibility. 84 | 85 | Assuming we would like to add the `pandas`, `scikit-learn`, and `env-var` Python packages to the container, we may create an updated `python312_extra.yml` containing: 86 | 87 | ```yaml 88 | name: python312_extra 89 | channels: 90 | - conda-forge 91 | dependencies: 92 | - pip=24.0 93 | - python=3.12.3 94 | - pandas=2.2.0 95 | - scikit-learn=1.4.2 96 | - pip: 97 | - env-var==1.0.1 98 | ``` 99 | 100 | where we have added `pandas`and `scikit-learn` as Conda packages and `env-var` as a pip package, since no conda package exists for it (at least not on conda-forge). 101 | 102 | Now we can build the updated container: 103 | 104 | ```bash 105 | $ module use /appl/local/training/modules/AI-20241126 106 | $ module load cotainr 107 | $ cotainr build python312_extra.sif --base-image=/appl/local/containers/sif-images/lumi-rocm-rocm-6.0.3.sif --conda-env=python312_extra.yml 108 | ``` 109 | 110 | and open an interactive shell and an interactive Python interpreter in it, and import our added packages: 111 | 112 | ```bash 113 | $ singularity shell python312_extra.sif 114 | Singularity> python3 115 | Python 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:38:13) [GCC 12.3.0] on linux 116 | Type "help", "copyright", "credits" or "license" for more information. 
117 | >>> import pandas, sklearn, env_var 118 | ``` 119 | 120 | > [!NOTE] 121 | > We don't need to activate the conda environment in the container as this is done automatically. 122 | 123 | > [!NOTE] 124 | > Even though we pin the versions of the added packages, their dependencies are not pinned and may change if building the container again at a later point in time. To be able to build a new container with the exact same set of all packages (including dependencies), you need to use the output of `conda env export` (in the container) as the conda environment file provided to cotainr (or specify all dependencies manually). The output of `conda env export` in the container looks something like: 125 | > 126 | > ```bash 127 | > Singularity> conda env export 128 | > name: conda_container_env 129 | > channels: 130 | > - conda-forge 131 | > dependencies: 132 | > - _libgcc_mutex=0.1=conda_forge 133 | > - _openmp_mutex=4.5=2_gnu 134 | > - bzip2=1.0.8=hd590300_5 135 | > - ca-certificates=2024.2.2=hbcca054_0 136 | > - joblib=1.4.2=pyhd8ed1ab_0 137 | > - ld_impl_linux-64=2.40=h55db66e_0 138 | > - libblas=3.9.0=22_linux64_openblas 139 | > - libcblas=3.9.0=22_linux64_openblas 140 | > - libexpat=2.6.2=h59595ed_0 141 | > - libffi=3.4.2=h7f98852_5 142 | > - libgcc-ng=13.2.0=h77fa898_7 143 | > - libgfortran-ng=13.2.0=h69a702a_7 144 | > - libgfortran5=13.2.0=hca663fb_7 145 | > - libgomp=13.2.0=h77fa898_7 146 | > - liblapack=3.9.0=22_linux64_openblas 147 | > - libnsl=2.0.1=hd590300_0 148 | > - libopenblas=0.3.27=pthreads_h413a1c8_0 149 | > - libsqlite=3.45.3=h2797004_0 150 | > - libstdcxx-ng=13.2.0=hc0a3c3a_7 151 | > - libuuid=2.38.1=h0b41bf4_0 152 | > - libxcrypt=4.4.36=hd590300_1 153 | > - libzlib=1.2.13=hd590300_5 154 | > - ncurses=6.5=h59595ed_0 155 | > - numpy=1.26.4=py312heda63a1_0 156 | > - openssl=3.3.0=hd590300_0 157 | > - pandas=2.2.0=py312hfb8ada1_0 158 | > - pip=24.0=pyhd8ed1ab_0 159 | > - python=3.12.3=hab00c5b_0_cpython 160 | > - python-dateutil=2.9.0=pyhd8ed1ab_0 161 | > - python-tzdata=2024.1=pyhd8ed1ab_0 162 | > - python_abi=3.12=4_cp312 163 | > - pytz=2024.1=pyhd8ed1ab_0 164 | > - readline=8.2=h8228510_1 165 | > - scikit-learn=1.4.2=py312h394d371_0 166 | > - scipy=1.13.0=py312hc2bc53b_1 167 | > - setuptools=69.5.1=pyhd8ed1ab_0 168 | > - six=1.16.0=pyh6c4a22f_0 169 | > - threadpoolctl=3.5.0=pyhc1e730c_0 170 | > - tk=8.6.13=noxft_h4845f30_101 171 | > - tzdata=2024a=h0c530f3_0 172 | > - wheel=0.43.0=pyhd8ed1ab_1 173 | > - xz=5.2.6=h166bdaf_0 174 | > - pip: 175 | > - arrow==1.3.0 176 | > - decorator==5.1.1 177 | > - env-var==1.0.1 178 | > - isoduration==20.11.0 179 | > - rfc3339-validator==0.1.4 180 | > - rfc3986-validator==0.1.1 181 | > - types-python-dateutil==2.9.0.20240316 182 | > - validators==0.18.2 183 | > prefix: /opt/conda/envs/conda_container_env 184 | > ``` 185 | 186 | ## Exercise 3 187 | 188 | > 1. Create a conda environment file for installing [panopticapi](https://github.com/cocodataset/panopticapi) 189 | > 2. Use the conda environment file to build a container for LUMI-C using cotainr 190 | 191 | To build a container for panopticapi using cotainr on LUMI, we must remember to: 192 | 193 | 1. Include as many of panopticapi's dependencies (in this case listed in the `setup.py` file) as possible as Conda packages in our conda environment file. 194 | 2. List the panopticapi GitHub repo master branch as a pip dependency since no conda/pip packages exist for panopticapi. 195 | 3. Add git as a conda dependency for pip to be able to pull the panopticapi GitHub repo. 
196 | 197 | A `panopticapi.yml` conda environment file may look like: 198 | 199 | ```yaml 200 | name: panopticapi 201 | channels: 202 | - conda-forge 203 | dependencies: 204 | - git=2.45.1 205 | - numpy=1.26.4 206 | - pillow=10.3.0 207 | - pip=24.0 208 | - python=3.9 209 | - pip: 210 | - git+https://github.com/cocodataset/panopticapi.git 211 | ``` 212 | 213 | > [!NOTE] 214 | > panopticapi is a somewhat old package that does not have any specific release and no pinned versions of dependencies. Thus, one may need some trial & error to install versions of numpy and pillow that are compatible with the current master branch of panopticapi. Maybe the above works... 215 | 216 | Now we can build a container the usual way: 217 | 218 | ```bash 219 | $ module use /appl/local/training/modules/AI-20241126 220 | $ module load cotainr 221 | $ cotainr build panopticapi.sif --base-image=/appl/local/containers/sif-images/lumi-rocm-rocm-6.0.3.sif --conda-env=panopticapi.yml 222 | ``` 223 | 224 | > [!NOTE] 225 | > When you install directly from (private) Git(Hub) repos, you may need to install extra Conda packages needed by pip to connect to the repo, e.g. git and openssh. Alternatively, you can also install directly from a zip archive of the repo, e.g. specifying https://github.com/cocodataset/panopticapi/archive/master.zip instead of git+https://github.com/cocodataset/panopticapi.git. See the [cotainr conda environment documentation](https://cotainr.readthedocs.io/en/latest/user_guide/conda_env.html#pip-packages-from-private-repositories) for a more elaborate example. 226 | -------------------------------------------------------------------------------- /07_Extending_containers_with_virtual_environments_for_faster_testing/README.md: -------------------------------------------------------------------------------- 1 | # 07 Extending containers with virtual environments for faster testing 2 | 3 | ## Examples 4 | 5 | - A full example of how to extend a container with a virtual environment can be found in the Markdown file [extending_containers_with_venv.md](examples/extending_containers_with_venv.md). 6 | -------------------------------------------------------------------------------- /07_Extending_containers_with_virtual_environments_for_faster_testing/examples/extending_containers_with_venv.md: -------------------------------------------------------------------------------- 1 | # Extending containers with virtual environments for faster testing 2 | 3 | This is a short example of how to extend the containers built via `cotainr` via virtual environments. This approach can be useful for developing and testing as it doesn't require rebuilding a container from scratch every time a new package is added. 
4 | 5 | ## Requirements 6 | 7 | We assume you have built a container from a `conda` environment file via something like: 8 | ```bash 9 | module load LUMI/24.03 cotainr 10 | cotainr build minimal_pytorch.sif --base-image=/appl/local/containers/sif-images/lumi-rocm-rocm-6.0.3.sif --conda-env=minimal_pytorch.yml --accept-license 11 | ``` 12 | 13 | ## Set up a virtual environment 14 | 15 | First, we run a shell inside the container 16 | ```bash 17 | singularity shell --bind /pfs,/scratch,/projappl,/project,/flash,/appl minimal_pytorch.sif 18 | ``` 19 | Note that setting `--bind` is optional, you achieve the same by 20 | ```bash 21 | module use /appl/local/containers/ai-modules 22 | module load singularity-AI-bindings 23 | singularity shell minimal_pytorch.sif 24 | ``` 25 | 26 | In order to install additional packages, we create a virtual environment via `venv` and activate it inside the container 27 | ```bash 28 | python -m venv myenv --system-site-packages 29 | source myenv/bin/activate 30 | ``` 31 | The `--system-site-packages` flag gives the virtual environment access to the packages from the container. 32 | 33 | ## Install custom packages 34 | 35 | After activating the virtual environment, we can now install custom packages via pip, for example: 36 | ```bash 37 | pip install torchmetrics 38 | ``` 39 | 40 | ## Run container with `venv` packages 41 | If we want to run the container with the freshly installed packages in a batch script, we need to first source the `venv` before executing the python script: 42 | ```bash 43 | singularity exec minimal_pytorch.sif bash -c "source myenv/bin/activate && python my_script.py" 44 | ``` 45 | 46 | > [!WARNING] 47 | > You should not stop here, as this way of installing python packages creates typically thousands of small files. This puts a lot of strain on the Lustre file system and might exceed your file quota. Choose one of the following options next: 48 | 49 | 50 | ## Option 1: Create a new container with `cotainr` 51 | After having found all packages needed for our project, we should create a new container with an updated `conda` environment file. The virtual environment should then be deleted 52 | ```bash 53 | cotainr build updated_pytorch.sif --base-image=/appl/local/containers/sif-images/lumi-rocm-rocm-6.0.3.sif --conda-env=updated_pytorch.yml --accept-license 54 | rm -rf myenv 55 | ``` 56 | 57 | 58 | ## Option 2: Turn `myenv` into a SquashFS file 59 | Alternatively, we can turn the `myenv` directory into a SquashFS file and bind mount it to the container: 60 | ```bash 61 | mksquashfs myenv myenv.sqsh 62 | rm -rf myenv # the myenv directory can be deleted 63 | export SINGULARITYENV_PREPEND_PATH=/user-software/bin # gives access to packages inside the container 64 | singularity exec -B myenv.sqsh:/user-software:image-src=/ minimal_pytorch.sif python my_script.py 65 | ``` 66 | -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/GPT-neo-IMDB-finetuning.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # IMDB movie review text generation 5 | # 6 | # In this script, we'll fine-tune a GPT2-like model to generate more 7 | # movie reviews based on a prompt. 
8 | # 9 | # Partly based on this tutorial: 10 | # https://github.com/omidiu/GPT-2-Fine-Tuning/ 11 | 12 | 13 | import argparse 14 | import math 15 | import time 16 | from pprint import pprint 17 | 18 | import torch 19 | from datasets import load_dataset 20 | from util import preprocess_data, get_output_paths 21 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 22 | DataCollatorForLanguageModeling, Trainer, 23 | TrainingArguments) 24 | 25 | 26 | if __name__ == "__main__": 27 | 28 | # First we set up some command line arguments to allow us to specify data/output paths 29 | # and the number of worker processes without changing the code. 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument( 32 | "--model-name", 33 | type=str, 34 | default="gpt-imdb-model", 35 | help="A name for the trained model under. A subdirectory with the given name will be created under the `output-path`.", 36 | ) 37 | parser.add_argument( 38 | "--output-path", 39 | type=str, 40 | help="The root directory under which model checkpoints are stored.", 41 | ) 42 | parser.add_argument( 43 | "--logging-path", 44 | type=str, 45 | help="The root directory under which logging data (for tensorboard) are stored.", 46 | ) 47 | parser.add_argument( 48 | "--num-workers", 49 | type=int, 50 | default=1, 51 | help="The number of CPU worker processes to use.", 52 | ) 53 | args, _ = parser.parse_known_args() 54 | 55 | # Then we determine the device on which to train the model. 56 | print("Using PyTorch version:", torch.__version__) 57 | if torch.cuda.is_available(): 58 | # 59 | device = torch.device("cuda") 60 | print("Using GPU, device name:", torch.cuda.get_device_name(device)) 61 | else: 62 | print("No GPU found, using CPU instead.") 63 | device = torch.device("cpu") 64 | 65 | # We also ensure that output paths exist 66 | output_dir, logging_dir = get_output_paths(args) 67 | 68 | # #### Loading the GPT-neo model 69 | # 70 | # We'll use the gpt-neo-1.3B model from the Hugging Face library: 71 | # https://huggingface.co/EleutherAI/gpt-neo-1.3B 72 | # Let's start with getting the appropriate tokenizer. 73 | pretrained_model = "EleutherAI/gpt-neo-1.3B" 74 | 75 | print("Loading model and tokenizer") 76 | start = time.time() 77 | tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True) 78 | tokenizer.pad_token = tokenizer.eos_token 79 | 80 | # Load the actual base model from Hugging Face 81 | model = AutoModelForCausalLM.from_pretrained(pretrained_model) 82 | model.to(device) 83 | stop = time.time() 84 | print(f"Loading model and tokenizer took: {stop-start:.2f} seconds") 85 | 86 | # #### Loading the IMDb data set 87 | # 88 | # Next we'll load the IMDb data set: https://huggingface.co/docs/datasets/index. 89 | # 90 | # The data set contains 100,000 movies reviews from the Internet Movie 91 | # Database, split into 25,000 reviews for training and 25,000 reviews 92 | # for testing and 50,000 without labels (unsupervised) that we also use for training. 93 | 94 | train_dataset = load_dataset( 95 | "imdb", split="train+unsupervised", trust_remote_code=False, keep_in_memory=True 96 | ) 97 | eval_dataset = load_dataset( 98 | "imdb", split="test", trust_remote_code=False, keep_in_memory=True 99 | ) 100 | 101 | # Let's print one sample from the dataset. 
102 | print("Sample from dataset") 103 | pprint(train_dataset[200]) 104 | 105 | # #### Setting up the training configuration 106 | # 107 | train_batch_size = 32 # This just about fits into the VRAM of a single MI250x GCD with 16-bit floats 108 | eval_batch_size = 128 # No optimizer state during evaluation, so can use bigger batches for increased throughput 109 | 110 | training_args = TrainingArguments( 111 | output_dir=output_dir, 112 | save_strategy="steps", 113 | save_steps=100, 114 | save_total_limit=4, 115 | logging_dir=logging_dir, 116 | eval_strategy="steps", 117 | eval_steps=200, # compute validation loss every 200 steps 118 | learning_rate=2e-5, 119 | weight_decay=0.01, 120 | bf16=True, # use 16-bit floating point precision 121 | per_device_train_batch_size=train_batch_size, 122 | per_device_eval_batch_size=eval_batch_size, 123 | max_steps=1000, 124 | dataloader_num_workers=args.num_workers, 125 | dataloader_pin_memory=True, 126 | report_to=["tensorboard"], # log statistics for tensorboard 127 | ) 128 | 129 | # #### Preprocessing of training data 130 | # We tokenize the data into torch tensors, split training into training and validation and set up a collator that 131 | # is able to arrange single data samples into batches. 132 | 133 | train_dataset_tokenized, validate_dataset_tokenized, eval_dataset_tokenized = preprocess_data(train_dataset, eval_dataset, tokenizer, training_args) 134 | 135 | collator = DataCollatorForLanguageModeling( 136 | tokenizer, mlm=False, return_tensors="pt" 137 | ) 138 | 139 | # Sanity check: How does the training data look like after preprocessing? 140 | print("Sample of tokenized data") 141 | for b in train_dataset_tokenized: 142 | pprint(b, compact=True) 143 | print("Length of input_ids:", len(b["input_ids"])) 144 | break 145 | print("Length of dataset (tokenized)", len(train_dataset_tokenized)) 146 | 147 | # #### Training 148 | # We use the Hugging Face trainer instead of a manual training loop. 149 | # 150 | # You can read about the many, many different parameters to the 151 | # Hugging Face trainer here: 152 | # https://huggingface.co/docs/transformers/v4.37.0/en/main_classes/trainer#transformers.TrainingArguments 153 | # 154 | 155 | trainer = Trainer( 156 | model=model, 157 | args=training_args, 158 | tokenizer=tokenizer, 159 | data_collator=collator, 160 | train_dataset=train_dataset_tokenized, 161 | eval_dataset=validate_dataset_tokenized, 162 | ) 163 | 164 | # With 1000 steps, batch size 32 and a single GCD, this should take just under 30 minutes. 
165 | trainer.train() 166 | 167 | print() 168 | print("Training done, you can find all the model checkpoints in", output_dir) 169 | 170 | # #### Evaluating the finetuned model 171 | with torch.no_grad(): 172 | model.eval() 173 | # Calculate perplexity 174 | eval_results = trainer.evaluate() 175 | test_results = trainer.evaluate(eval_dataset_tokenized) 176 | 177 | print(f'Perplexity on validation: {math.exp(eval_results["eval_loss"]):.2f}') 178 | print(f'Perplexity on test: {math.exp(test_results["eval_loss"]):.2f}') 179 | 180 | # Let's print a few sample generated reviews; this is the same as in the previous exercise 181 | # but now we use the finetuned model 182 | prompt = "The movie 'How to run ML on LUMI - A documentation' was great because" 183 | inputs = tokenizer(prompt, return_tensors="pt").to(device) 184 | outputs = model.generate( 185 | **inputs, do_sample=True, max_length=80, num_return_sequences=4 186 | ) 187 | decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) 188 | 189 | print("Sample generated review:") 190 | for txt in decoded_outputs: 191 | print("-", txt) 192 | -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/README.md: -------------------------------------------------------------------------------- 1 | # 08 Scaling to multiple GPUs 2 | 3 | ## Hands-on exercises 4 | 5 | 1. Adjust the training script to run with torchrun for multiple GCDs on a single node. 6 | 7 | In this exercise you have to make some changes to the Python training script to prepare it for running on multiple GPUs across several processes. 8 | 9 | We will use [torchrun](https://pytorch.org/docs/stable/elastic/run.html) to run the training on all GCDs on a full LUMI node. Torchrun creates and manages one process per GCD, each executing our training script, and provides the following environment variables to each process: 10 | - `WORLD_SIZE`: The total number of processes. 11 | - `RANK`: The process number from `0` to `WORLD_SIZE-1`. 12 | - `LOCAL_WORLD_SIZE`: The total number of processes on the same node. 13 | - `LOCAL_RANK`: The process number within the same node from `0` to `LOCAL_WORLD_SIZE-1`. 14 | - `MASTER_ADDR`: The URL of the host that is running worker with rank 0; used to initialize the Torch Distributed backend. 15 | - `MASTER_PORT`: The port on the `MASTER_ADDR` the different processes use to communicate. 16 | 17 | In this exercise we will use only a single node (but all 8 GPUs from it). Therefore `LOCAL_WORLD_SIZE` and `LOCAL_RANK` will be identical to `WORLD_SIZE` and `RANK`. 18 | 19 | Using torchrun, every process sees all GCDs on a node, so we need to make sure that our script selects one GCD it is going to train on according to its local rank. 20 | 21 | On the other hand, the HuggingFace Trainer will automatically take care of setting up data-parallel training for the model when the above environment variables are set by torchrun, so we do not need to handle that part of the setup manually - but we do need to adjust the batch size handled locally in each process. 22 | 23 | Find the training script in [08_Scaling_to_multiple_GPUs/GPT-neo-IMDB-finetuning.py](08_Scaling_to_multiple_GPUs/GPT-neo-IMDB-finetuning.py). It is the same as the one used earlier for training with a single GCD/GPU. 24 | 25 | 1. 
You will need to make the following changes to the script: 26 | 27 | - select the correct PyTorch device 28 | - adjust the per-device batch size handled per process 29 | - (optional) limit printing of outputs to a single process 30 | 31 | Places where you need to edit the file have been marked with ``. 32 | 33 | 2. Adjust the slurm batch file. 34 | 35 | Now you need to change the slurm batch file to request multiple GCDs (GPUs) on a single node and use `torchrun` to start a training job that 36 | parallelises training across the GCDs. 37 | 38 | 1. Edit the slurm batch file [08_Scaling_to_multiple_GPUs/run.sh](run.sh) for single-node multi-gpu training using torchrun. 39 | 40 | You should specify at least the following: 41 | - the correct slurm partition 42 | - number of GPUs requested (8) 43 | - number of CPUs requested 44 | - RAM requested (we recommend using 60GB per requested GPU to leave some room for the OS overhead) 45 | - requested runtime (20 minutes should be plenty to finish training and running evaluation) 46 | 47 | It can also be helpful to specify a name for the slurm logfile that contains the command line outputs of the script. 48 | 49 | > ** Tip ** 50 | > 51 | > You can use a different `--model-name` than in Exercise 3, to start a fresh training run without overwriting your 52 | > earlier results. The environment variable `MODEL_NAME` is a suggestion for a name that you can use. 53 | 54 | You will also need to add the relevant parts for setting up the PyTorch software environment (these are the same as for Exercise `03_Your_first_AI_training_job_on_LUMI`). 55 | 56 | To invoke torchrun from the batch file, follow the [Single-node multi-worker usage example on the torchrun website](https://pytorch.org/docs/stable/elastic/run.html#single-node-multi-worker). 57 | 58 | 2. Run your job using `sbatch run.sh`. 59 | 60 | 3. Compare how the run time differs between running on a full node and the previous run on a single GCD. 61 | 62 | You don't necessarily need to wait for the run to finish but can compare the estimated total time given by the progress bar. 63 | 64 | 3. (Optional/Bonus): Set up CPU bindings. 65 | 66 | In order to achieve optimal CPU-GPU data transfer performance we can ensure that each script runs on the CPU cores closest to the respective GPU. 67 | As we are using torchrun to manage the worker processes, we cannot handle these CPU bindings via slurm but must set them up in our Python training script. 68 | 69 | 1. Edit [08_Scaling_to_multiple_GPUs/GPT-neo-IMDB-finetuning.py](GPT-neo-IMDB-finetuning.py) to set up the correct CPU-GPU bindings based on the processes rank. 70 | 71 | You can find a [figure showing which cores are closest to which GCD](https://docs.lumi-supercomputer.eu/assets/images/lumig-cpu-gpu-links.svg) on the [LUMI Docs LUMI-G page](https://docs.lumi-supercomputer.eu/hardware/lumig/). 72 | 73 | > **Tip** 74 | > 75 | > Use the `psutil.Process().cpu_affinity(...)` function to set the binding from inside the Python script. 76 | 77 | 4. (Optional/Bonus): Running without torchrun. 78 | 79 | We can also start worker processes directly without using torchrun to have direct control over all processes. 80 | 81 | 1. Change the slurm batch script to 82 | - instruct slurm to start the appropriate number of processes, 83 | - set the environment variables mentioned above manually, 84 | - replace the `torchrun` invocation with direct `python` commands to run the training script. 
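
A rough sketch of the key pieces of such a batch script is shown below; the full, working version is in `reference_solution/run_no_torchrun.sh`. Here `$CONTAINER`, `$MODEL_NAME`, `$OUTPUT_DIR` and `$LOGGING_DIR` are assumed to be set as in the earlier exercises:

```bash
#SBATCH --ntasks-per-node=8   # one task per GCD, added to the other #SBATCH directives at the top

# Rendezvous and world-size variables that torchrun would otherwise provide
export MASTER_ADDR=$(hostname)
export MASTER_PORT=25900
export WORLD_SIZE=$SLURM_NPROCS
export LOCAL_WORLD_SIZE=$SLURM_GPUS_PER_NODE

# RANK and LOCAL_RANK differ per process, so they are set inside each task;
# the escaped \$ delays expansion until the task is actually running.
srun singularity exec $CONTAINER \
    bash -c "RANK=\$SLURM_PROCID LOCAL_RANK=\$SLURM_LOCALID \
             python GPT-neo-IMDB-finetuning.py \
                 --model-name $MODEL_NAME \
                 --output-path $OUTPUT_DIR \
                 --logging-path $LOGGING_DIR \
                 --num-workers $SLURM_CPUS_PER_TASK"
```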
85 | 86 | > **Note** 87 | > 88 | > You can get the hostname of the node running the rank 0 process using the command 89 | > ``` 90 | > hostname 91 | > ``` 92 | 93 | In this setting you could then also do the CPU bindings from the slurm batch file instead of Python, to keep the training script free of system specific setup. 94 | 95 | ## Solutions 96 | 97 | The folder `reference_solution/` contains an example solution for this exercise parts 1, 2 and 4. `reference_solution/prints_only_from_single_process` extends this to ensure that `print` statements in the code are run only by a single process. `reference_solution/with_cpu_bindings` shows how CPU bindings can be used both from within Python (when using torchrun) and directly via SLURM (exercise part 3). 98 | -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/reference_solution/GPT-neo-IMDB-finetuning.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # IMDB movie review text generation 5 | # 6 | # In this script, we'll fine-tune a GPT2-like model to generate more 7 | # movie reviews based on a prompt. 8 | # 9 | # Partly based on this tutorial: 10 | # https://github.com/omidiu/GPT-2-Fine-Tuning/ 11 | 12 | 13 | import argparse 14 | import math 15 | import os 16 | import time 17 | from pprint import pprint 18 | 19 | import torch 20 | from datasets import load_dataset 21 | from util import preprocess_data, get_output_paths 22 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 23 | DataCollatorForLanguageModeling, Trainer, 24 | TrainingArguments) 25 | 26 | 27 | if __name__ == "__main__": 28 | 29 | # First we set up some command line arguments to allow us to specify data/output paths 30 | # and the number of worker processes without changing the code. 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument( 33 | "--model-name", 34 | type=str, 35 | default="gpt-imdb-model", 36 | help="A name for the trained model under. A subdirectory with the given name will be created under the `output-path`.", 37 | ) 38 | parser.add_argument( 39 | "--output-path", 40 | type=str, 41 | help="The root directory under which model checkpoints are stored.", 42 | ) 43 | parser.add_argument( 44 | "--logging-path", 45 | type=str, 46 | help="The root directory under which logging data (for tensorboard) are stored.", 47 | ) 48 | parser.add_argument( 49 | "--num-workers", 50 | type=int, 51 | default=1, 52 | help="The number of CPU worker processes to use.", 53 | ) 54 | args, _ = parser.parse_known_args() 55 | 56 | # Read the environment variables provided by torchrun 57 | rank = int(os.environ["RANK"]) 58 | local_rank = int(os.environ["LOCAL_RANK"]) 59 | world_size = int(os.environ["WORLD_SIZE"]) 60 | local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) 61 | 62 | # Then we determine the device on which to train the model. 
63 | print("Using PyTorch version:", torch.__version__) 64 | if torch.cuda.is_available(): 65 | print( 66 | f"Rank {rank} of {world_size} (local: {local_rank}) sees {torch.cuda.device_count()} devices" 67 | ) 68 | device = torch.device("cuda", local_rank) 69 | print("Using GPU, device name:", torch.cuda.get_device_name(device)) 70 | else: 71 | print("No GPU found, using CPU instead.") 72 | device = torch.device("cpu") 73 | 74 | # We also ensure that output paths exist 75 | output_dir, logging_dir = get_output_paths(args) 76 | 77 | # #### Loading the GPT-neo model 78 | # 79 | # We'll use the gpt-neo-1.3B model from the Hugging Face library: 80 | # https://huggingface.co/EleutherAI/gpt-neo-1.3B 81 | # Let's start with getting the appropriate tokenizer. 82 | pretrained_model = "EleutherAI/gpt-neo-1.3B" 83 | 84 | print("Loading model and tokenizer") 85 | start = time.time() 86 | tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True) 87 | tokenizer.pad_token = tokenizer.eos_token 88 | 89 | # Load the actual base model from Hugging Face 90 | model = AutoModelForCausalLM.from_pretrained(pretrained_model) 91 | model.to(device) 92 | stop = time.time() 93 | print(f"Loading model and tokenizer took: {stop-start:.2f} seconds") 94 | 95 | # #### Loading the IMDb data set 96 | # 97 | # Next we'll load the IMDb data set: https://huggingface.co/docs/datasets/index. 98 | # 99 | # The data set contains 100,000 movies reviews from the Internet Movie 100 | # Database, split into 25,000 reviews for training and 25,000 reviews 101 | # for testing and 50,000 without labels (unsupervised) that we also use for training. 102 | 103 | train_dataset = load_dataset( 104 | "imdb", split="train+unsupervised", trust_remote_code=False, keep_in_memory=True 105 | ) 106 | eval_dataset = load_dataset( 107 | "imdb", split="test", trust_remote_code=False, keep_in_memory=True 108 | ) 109 | 110 | # Let's print one sample from the dataset. 111 | print("Sample from dataset") 112 | pprint(train_dataset[200]) 113 | 114 | # #### Setting up the training configuration 115 | global_train_batch_size = 32 # We keep the overall batch size (across all GPUs) the same as before ... 116 | per_device_train_batch_size = global_train_batch_size // world_size # ... which means we divide by the number of processes for the batch size of each GPU 117 | eval_batch_size = 128 # No optimizer state during evaluation, so can use bigger batches for increased throughput 118 | 119 | training_args = TrainingArguments( 120 | output_dir=output_dir, 121 | save_strategy="steps", 122 | save_steps=100, 123 | save_total_limit=4, 124 | logging_dir=logging_dir, 125 | eval_strategy="steps", 126 | eval_steps=200, # compute validation loss every 200 steps 127 | learning_rate=2e-5, 128 | weight_decay=0.01, 129 | bf16=True, # use 16-bit floating point precision 130 | per_device_train_batch_size=per_device_train_batch_size, 131 | per_device_eval_batch_size=eval_batch_size, 132 | max_steps=1000, 133 | dataloader_num_workers=args.num_workers, 134 | dataloader_pin_memory=True, 135 | report_to=["tensorboard"], # log statistics for tensorboard 136 | ddp_find_unused_parameters=False, # there are no unused parameters, causing PyTorch to issue a warning should this be set to True 137 | ) 138 | 139 | # #### Preprocessing of training data 140 | # We tokenize the data into torch tensors, split training into training and validation and set up a collator that 141 | # is able to arrange single data samples into batches. 
142 | 143 | train_dataset_tokenized, validate_dataset_tokenized, eval_dataset_tokenized = preprocess_data(train_dataset, eval_dataset, tokenizer, training_args) 144 | 145 | collator = DataCollatorForLanguageModeling( 146 | tokenizer, mlm=False, return_tensors="pt" 147 | ) 148 | 149 | # Sanity check: How does the training data look like after preprocessing? 150 | print("Sample of tokenized data") 151 | for b in train_dataset_tokenized: 152 | pprint(b, compact=True) 153 | print("Length of input_ids:", len(b["input_ids"])) 154 | break 155 | print("Length of dataset (tokenized)", len(train_dataset_tokenized)) 156 | 157 | # #### Training 158 | # We use the Hugging Face trainer instead of a manual training loop. 159 | # 160 | # You can read about the many, many different parameters to the 161 | # Hugging Face trainer here: 162 | # https://huggingface.co/docs/transformers/v4.37.0/en/main_classes/trainer#transformers.TrainingArguments 163 | # 164 | 165 | trainer = Trainer( 166 | model=model, 167 | args=training_args, 168 | tokenizer=tokenizer, 169 | data_collator=collator, 170 | train_dataset=train_dataset_tokenized, 171 | eval_dataset=validate_dataset_tokenized, 172 | ) 173 | 174 | # With 1000 steps, batch size 32 and a single GCD, this should take just under 30 minutes. 175 | trainer.train() 176 | 177 | print() 178 | print("Training done, you can find all the model checkpoints in", output_dir) 179 | 180 | # #### Evaluating the finetuned model 181 | with torch.no_grad(): 182 | model.eval() 183 | # Calculate perplexity 184 | eval_results = trainer.evaluate() 185 | test_results = trainer.evaluate(eval_dataset_tokenized) 186 | 187 | print(f'Perplexity on validation: {math.exp(eval_results["eval_loss"]):.2f}') 188 | print(f'Perplexity on test: {math.exp(test_results["eval_loss"]):.2f}') 189 | 190 | # Let's print a few sample generated reviews; this is the same as in the previous exercise 191 | # but now we use the finetuned model 192 | prompt = "The movie 'How to run ML on LUMI - A documentation' was great because" 193 | inputs = tokenizer(prompt, return_tensors="pt").to(device) 194 | outputs = model.generate( 195 | **inputs, do_sample=True, max_length=80, num_return_sequences=4 196 | ) 197 | decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) 198 | 199 | print("Sample generated review:") 200 | for txt in decoded_outputs: 201 | print("-", txt) 202 | -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/reference_solution/prints_only_from_single_process/GPT-neo-IMDB-finetuning.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # IMDB movie review text generation 5 | # 6 | # In this script, we'll fine-tune a GPT2-like model to generate more 7 | # movie reviews based on a prompt. 
8 | # 9 | # Partly based on this tutorial: 10 | # https://github.com/omidiu/GPT-2-Fine-Tuning/ 11 | 12 | 13 | import argparse 14 | import math 15 | import os 16 | import time 17 | from pprint import pprint 18 | 19 | import torch 20 | from datasets import load_dataset 21 | from util import preprocess_data, get_output_paths 22 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 23 | DataCollatorForLanguageModeling, Trainer, 24 | TrainingArguments) 25 | 26 | 27 | if __name__ == "__main__": 28 | 29 | # First we set up some command line arguments to allow us to specify data/output paths 30 | # and the number of worker processes without changing the code. 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument( 33 | "--model-name", 34 | type=str, 35 | default="gpt-imdb-model", 36 | help="A name for the trained model under. A subdirectory with the given name will be created under the `output-path`.", 37 | ) 38 | parser.add_argument( 39 | "--output-path", 40 | type=str, 41 | help="The root directory under which model checkpoints are stored.", 42 | ) 43 | parser.add_argument( 44 | "--logging-path", 45 | type=str, 46 | help="The root directory under which logging data (for tensorboard) are stored.", 47 | ) 48 | parser.add_argument( 49 | "--num-workers", 50 | type=int, 51 | default=1, 52 | help="The number of CPU worker processes to use.", 53 | ) 54 | args, _ = parser.parse_known_args() 55 | 56 | # Read the environment variables provided by torchrun 57 | rank = int(os.environ["RANK"]) 58 | local_rank = int(os.environ["LOCAL_RANK"]) 59 | world_size = int(os.environ["WORLD_SIZE"]) 60 | local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) 61 | 62 | # Then we determine the device on which to train the model. 63 | print("Using PyTorch version:", torch.__version__) 64 | if torch.cuda.is_available(): 65 | print( 66 | f"Rank {rank} of {world_size} (local: {local_rank}) sees {torch.cuda.device_count()} devices" 67 | ) 68 | device = torch.device("cuda", local_rank) 69 | print("Using GPU, device name:", torch.cuda.get_device_name(device)) 70 | else: 71 | print("No GPU found, using CPU instead.") 72 | device = torch.device("cpu") 73 | 74 | # We also ensure that output paths exist 75 | output_dir, logging_dir = get_output_paths(args) 76 | 77 | # #### Loading the GPT-neo model 78 | # 79 | # We'll use the gpt-neo-1.3B model from the Hugging Face library: 80 | # https://huggingface.co/EleutherAI/gpt-neo-1.3B 81 | # Let's start with getting the appropriate tokenizer. 82 | pretrained_model = "EleutherAI/gpt-neo-1.3B" 83 | 84 | if rank == 0: 85 | print("Loading model and tokenizer") 86 | start = time.time() 87 | tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True) 88 | tokenizer.pad_token = tokenizer.eos_token 89 | 90 | # Load the actual base model from Hugging Face 91 | model = AutoModelForCausalLM.from_pretrained(pretrained_model) 92 | model.to(device) 93 | stop = time.time() 94 | if rank == 0: 95 | print(f"Loading model and tokenizer took: {stop-start:.2f} seconds") 96 | 97 | # #### Loading the IMDb data set 98 | # 99 | # Next we'll load the IMDb data set: https://huggingface.co/docs/datasets/index. 100 | # 101 | # The data set contains 100,000 movies reviews from the Internet Movie 102 | # Database, split into 25,000 reviews for training and 25,000 reviews 103 | # for testing and 50,000 without labels (unsupervised) that we also use for training. 
104 | 105 | train_dataset = load_dataset( 106 | "imdb", split="train+unsupervised", trust_remote_code=False, keep_in_memory=True 107 | ) 108 | eval_dataset = load_dataset( 109 | "imdb", split="test", trust_remote_code=False, keep_in_memory=True 110 | ) 111 | 112 | # Let's print one sample from the dataset. 113 | if rank == 0: 114 | print("Sample from dataset") 115 | pprint(train_dataset[200]) 116 | 117 | # #### Setting up the training configuration 118 | global_train_batch_size = 32 # We keep the overall batch size (across all GPUs) the same as before ... 119 | per_device_train_batch_size = global_train_batch_size // world_size # ... which means we divide by the number of processes for the batch size of each GPU 120 | eval_batch_size = 128 # No optimizer state during evaluation, so can use bigger batches for increased throughput 121 | 122 | training_args = TrainingArguments( 123 | output_dir=output_dir, 124 | save_strategy="steps", 125 | save_steps=100, 126 | save_total_limit=4, 127 | logging_dir=logging_dir, 128 | eval_strategy="steps", 129 | eval_steps=200, # compute validation loss every 200 steps 130 | learning_rate=2e-5, 131 | weight_decay=0.01, 132 | bf16=True, # use 16-bit floating point precision 133 | per_device_train_batch_size=per_device_train_batch_size, 134 | per_device_eval_batch_size=eval_batch_size, 135 | max_steps=1000, 136 | dataloader_num_workers=args.num_workers, 137 | dataloader_pin_memory=True, 138 | report_to=["tensorboard"], # log statistics for tensorboard 139 | ddp_find_unused_parameters=False, # there are no unused parameters, causing PyTorch to issue a warning should this be set to True 140 | ) 141 | 142 | # #### Preprocessing of training data 143 | # We tokenize the data into torch tensors, split training into training and validation and set up a collator that 144 | # is able to arrange single data samples into batches. 145 | 146 | train_dataset_tokenized, validate_dataset_tokenized, eval_dataset_tokenized = preprocess_data(train_dataset, eval_dataset, tokenizer, training_args) 147 | 148 | collator = DataCollatorForLanguageModeling( 149 | tokenizer, mlm=False, return_tensors="pt" 150 | ) 151 | 152 | # Sanity check: How does the training data look like after preprocessing? 153 | if rank == 0: 154 | print("Sample of tokenized data") 155 | for b in train_dataset_tokenized: 156 | pprint(b, compact=True) 157 | print("Length of input_ids:", len(b["input_ids"])) 158 | break 159 | print("Length of dataset (tokenized)", len(train_dataset_tokenized)) 160 | 161 | # #### Training 162 | # We use the Hugging Face trainer instead of a manual training loop. 163 | # 164 | # You can read about the many, many different parameters to the 165 | # Hugging Face trainer here: 166 | # https://huggingface.co/docs/transformers/v4.37.0/en/main_classes/trainer#transformers.TrainingArguments 167 | # 168 | 169 | trainer = Trainer( 170 | model=model, 171 | args=training_args, 172 | tokenizer=tokenizer, 173 | data_collator=collator, 174 | train_dataset=train_dataset_tokenized, 175 | eval_dataset=validate_dataset_tokenized, 176 | ) 177 | 178 | # With 1000 steps, batch size 32 and a single GCD, this should take just under 30 minutes. 
179 | trainer.train() 180 | 181 | if rank == 0: 182 | print() 183 | print("Training done, you can find all the model checkpoints in", output_dir) 184 | 185 | # #### Evaluating the finetuned model 186 | with torch.no_grad(): 187 | model.eval() 188 | # Calculate perplexity 189 | eval_results = trainer.evaluate() 190 | test_results = trainer.evaluate(eval_dataset_tokenized) 191 | 192 | if rank == 0: 193 | print( 194 | f'Perplexity on validation: {math.exp(eval_results["eval_loss"]):.2f}' 195 | ) 196 | print(f'Perplexity on test: {math.exp(test_results["eval_loss"]):.2f}') 197 | 198 | # Let's print a few sample generated reviews; this is the same as in the previous exercise 199 | # but now we use the finetuned model 200 | prompt = ( 201 | "The movie 'How to run ML on LUMI - A documentation' was great because" 202 | ) 203 | inputs = tokenizer(prompt, return_tensors="pt").to(device) 204 | outputs = model.generate( 205 | **inputs, do_sample=True, max_length=80, num_return_sequences=4 206 | ) 207 | decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) 208 | 209 | print("Sample generated review:") 210 | for txt in decoded_outputs: 211 | print("-", txt) 212 | -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/reference_solution/prints_only_from_single_process/run_no_torchrun.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_465001958 3 | #SBATCH --reservation=AI_workshop_2 # comment this out if the reservation is no longer available 4 | #SBATCH --partition=standard-g 5 | #SBATCH --nodes=1 6 | #SBATCH --gpus-per-node=8 7 | #SBATCH --ntasks-per-node=8 # we want one process per GPU 8 | #SBATCH --cpus-per-task=7 9 | #SBATCH --mem-per-gpu=60G 10 | #SBATCH --time=0:20:00 11 | 12 | # Set up the software environment 13 | # NOTE: the loaded module makes relevant filesystem locations available inside the singularity container 14 | # (/scratch, /project, etc) as well as mounts some important system libraries that are optimized for LUMI 15 | # If you are interested, you can check the exact paths being mounted from 16 | # /appl/local/containers/ai-modules/singularity-AI-bindings/24.03.lua 17 | module purge 18 | module use /appl/local/containers/ai-modules 19 | module load singularity-AI-bindings 20 | 21 | CONTAINER=/project/project_465001958/containers/pytorch_transformers.sif 22 | 23 | # Some environment variables to set up cache directories 24 | SCRATCH="/scratch/${SLURM_JOB_ACCOUNT}" 25 | FLASH="/flash/${SLURM_JOB_ACCOUNT}" 26 | export TORCH_HOME=$SCRATCH/torch-cache 27 | export HF_HOME=$FLASH/hf-cache 28 | mkdir -p $TORCH_HOME $HF_HOME 29 | 30 | # Disable internal parallelism of huggingface's tokenizer since we 31 | # want to retain direct control of parallelism options. 
32 | export TOKENIZERS_PARALLELISM=false 33 | 34 | # Path to where the trained model and logging data will go 35 | export OUTPUT_DIR=$SCRATCH/$USER/data/ 36 | export LOGGING_DIR=$SCRATCH/$USER/runs/ 37 | export MODEL_NAME=gpt-imdb-model-multigpu-no-torchrun 38 | 39 | set -xv # print the command so that we can verify setting arguments correctly from the logs 40 | 41 | # Set up variables to control distributed PyTorch training 42 | export MASTER_ADDR=$(hostname) 43 | export MASTER_PORT=25900 44 | export WORLD_SIZE=$SLURM_NPROCS 45 | export LOCAL_WORLD_SIZE=$SLURM_GPUS_PER_NODE 46 | 47 | # As opposed to the example in `run_torchrun.sh`, we can set the CPU binds directly via the slurm command, since we have 48 | # one task per GPU. In this case we do NOT need to set them from within the Python code itself. 49 | srun singularity exec $CONTAINER \ 50 | bash -c "RANK=\$SLURM_PROCID \ 51 | LOCAL_RANK=\$SLURM_LOCALID \ 52 | python GPT-neo-IMDB-finetuning.py \ 53 | --model-name $MODEL_NAME \ 54 | --output-path $OUTPUT_DIR \ 55 | --logging-path $LOGGING_DIR \ 56 | --num-workers ${SLURM_CPUS_PER_TASK}" 57 | -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/reference_solution/prints_only_from_single_process/run_torchrun.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_465001958 3 | #SBATCH --reservation=AI_workshop_2 # comment this out if the reservation is no longer available 4 | #SBATCH --partition=standard-g 5 | #SBATCH --nodes=1 6 | #SBATCH --gpus-per-node=8 7 | #SBATCH --ntasks-per-node=1 # we start a single torchrun process, which will take care of spawning more 8 | #SBATCH --cpus-per-task=56 # 7 cores per GPU 9 | #SBATCH --mem-per-gpu=60G 10 | #SBATCH --time=0:20:00 11 | 12 | # Set up the software environment 13 | # NOTE: the loaded module makes relevant filesystem locations available inside the singularity container 14 | # (/scratch, /project, etc) as well as mounts some important system libraries that are optimized for LUMI 15 | # If you are interested, you can check the exact paths being mounted from 16 | # /appl/local/containers/ai-modules/singularity-AI-bindings/24.03.lua 17 | module purge 18 | module use /appl/local/containers/ai-modules 19 | module load singularity-AI-bindings 20 | 21 | CONTAINER=/project/project_465001958/containers/pytorch_transformers.sif 22 | 23 | # Some environment variables to set up cache directories 24 | SCRATCH="/scratch/${SLURM_JOB_ACCOUNT}" 25 | FLASH="/flash/${SLURM_JOB_ACCOUNT}" 26 | export TORCH_HOME=$SCRATCH/torch-cache 27 | export HF_HOME=$FLASH/hf-cache 28 | mkdir -p $TORCH_HOME $HF_HOME 29 | 30 | # Disable internal parallelism of huggingface's tokenizer since we 31 | # want to retain direct control of parallelism options. 32 | export TOKENIZERS_PARALLELISM=false 33 | 34 | # Path to where the trained model and logging data will go 35 | export OUTPUT_DIR=$SCRATCH/$USER/data/ 36 | export LOGGING_DIR=$SCRATCH/$USER/runs/ 37 | export MODEL_NAME=gpt-imdb-model-multigpu 38 | 39 | set -xv # print the command so that we can verify setting arguments correctly from the logs 40 | 41 | # Since we start only one task with slurm which then starts subprocesses, we cannot use slurm to configure CPU binds. 42 | # Therefore we need to set them up in the Python code itself. 
43 | 44 | srun singularity exec $CONTAINER \ 45 | torchrun --standalone \ 46 | --nnodes=1 \ 47 | --nproc-per-node=${SLURM_GPUS_PER_NODE} \ 48 | GPT-neo-IMDB-finetuning.py \ 49 | --model-name $MODEL_NAME \ 50 | --output-path $OUTPUT_DIR \ 51 | --logging-path $LOGGING_DIR \ 52 | --num-workers $(( SLURM_CPUS_PER_TASK / SLURM_GPUS_PER_NODE )) 53 | -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/reference_solution/prints_only_from_single_process/slurm-9304946.out: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lumi-supercomputer/Getting_Started_with_AI_workshop/ce4d7cc37162350ff0d620d0301b8f9e7f0b5c98/08_Scaling_to_multiple_GPUs/reference_solution/prints_only_from_single_process/slurm-9304946.out -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/reference_solution/prints_only_from_single_process/util.py: -------------------------------------------------------------------------------- 1 | ../util.py -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/reference_solution/run_no_torchrun.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_465001958 3 | #SBATCH --reservation=AI_workshop_2 # comment this out if the reservation is no longer available 4 | #SBATCH --partition=standard-g 5 | #SBATCH --nodes=1 6 | #SBATCH --gpus-per-node=8 7 | #SBATCH --ntasks-per-node=8 # we want one process per GPU 8 | #SBATCH --cpus-per-task=7 9 | #SBATCH --mem-per-gpu=60G 10 | #SBATCH --time=0:20:00 11 | 12 | # Set up the software environment 13 | # NOTE: the loaded module makes relevant filesystem locations available inside the singularity container 14 | # (/scratch, /project, etc) as well as mounts some important system libraries that are optimized for LUMI 15 | # If you are interested, you can check the exact paths being mounted from 16 | # /appl/local/containers/ai-modules/singularity-AI-bindings/24.03.lua 17 | module purge 18 | module use /appl/local/containers/ai-modules 19 | module load singularity-AI-bindings 20 | 21 | CONTAINER=/project/project_465001958/containers/pytorch_transformers.sif 22 | 23 | # Some environment variables to set up cache directories 24 | SCRATCH="/scratch/${SLURM_JOB_ACCOUNT}" 25 | FLASH="/flash/${SLURM_JOB_ACCOUNT}" 26 | export TORCH_HOME=$SCRATCH/torch-cache 27 | export HF_HOME=$FLASH/hf-cache 28 | mkdir -p $TORCH_HOME $HF_HOME 29 | 30 | # Disable internal parallelism of huggingface's tokenizer since we 31 | # want to retain direct control of parallelism options. 32 | export TOKENIZERS_PARALLELISM=false 33 | 34 | # Path to where the trained model and logging data will go 35 | export OUTPUT_DIR=$SCRATCH/$USER/data/ 36 | export LOGGING_DIR=$SCRATCH/$USER/runs/ 37 | export MODEL_NAME=gpt-imdb-model-multigpu-no-torchrun 38 | 39 | set -xv # print the command so that we can verify setting arguments correctly from the logs 40 | 41 | # Set up variables to control distributed PyTorch training 42 | export MASTER_ADDR=$(hostname) 43 | export MASTER_PORT=25900 44 | export WORLD_SIZE=$SLURM_NPROCS 45 | export LOCAL_WORLD_SIZE=$SLURM_GPUS_PER_NODE 46 | 47 | # As opposed to the example in `run_torchrun.sh`, we can set the CPU binds directly via the slurm command, since we have 48 | # one task per GPU. In this case we do NOT need to set them from within the Python code itself. 
49 | srun singularity exec $CONTAINER \ 50 | bash -c "RANK=\$SLURM_PROCID \ 51 | LOCAL_RANK=\$SLURM_LOCALID \ 52 | python GPT-neo-IMDB-finetuning.py \ 53 | --model-name $MODEL_NAME \ 54 | --output-path $OUTPUT_DIR \ 55 | --logging-path $LOGGING_DIR \ 56 | --num-workers ${SLURM_CPUS_PER_TASK}" 57 | -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/reference_solution/run_torchrun.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_465001958 3 | #SBATCH --reservation=AI_workshop_2 # comment this out if the reservation is no longer available 4 | #SBATCH --partition=standard-g 5 | #SBATCH --nodes=1 6 | #SBATCH --gpus-per-node=8 7 | #SBATCH --ntasks-per-node=1 # we start a single torchrun process, which will take care of spawning more 8 | #SBATCH --cpus-per-task=56 # 7 cores per GPU 9 | #SBATCH --mem-per-gpu=60G 10 | #SBATCH --time=0:20:00 11 | 12 | # Set up the software environment 13 | # NOTE: the loaded module makes relevant filesystem locations available inside the singularity container 14 | # (/scratch, /project, etc) as well as mounts some important system libraries that are optimized for LUMI 15 | # If you are interested, you can check the exact paths being mounted from 16 | # /appl/local/containers/ai-modules/singularity-AI-bindings/24.03.lua 17 | module purge 18 | module use /appl/local/containers/ai-modules 19 | module load singularity-AI-bindings 20 | 21 | CONTAINER=/project/project_465001958/containers/pytorch_transformers.sif 22 | 23 | # Some environment variables to set up cache directories 24 | SCRATCH="/scratch/${SLURM_JOB_ACCOUNT}" 25 | FLASH="/flash/${SLURM_JOB_ACCOUNT}" 26 | export TORCH_HOME=$SCRATCH/torch-cache 27 | export HF_HOME=$FLASH/hf-cache 28 | mkdir -p $TORCH_HOME $HF_HOME 29 | 30 | # Disable internal parallelism of huggingface's tokenizer since we 31 | # want to retain direct control of parallelism options. 32 | export TOKENIZERS_PARALLELISM=false 33 | 34 | # Path to where the trained model and logging data will go 35 | export OUTPUT_DIR=$SCRATCH/$USER/data/ 36 | export LOGGING_DIR=$SCRATCH/$USER/runs/ 37 | export MODEL_NAME=gpt-imdb-model-multigpu 38 | 39 | set -xv # print the command so that we can verify setting arguments correctly from the logs 40 | 41 | # Since we start only one task with slurm which then starts subprocesses, we cannot use slurm to configure CPU binds. 42 | # Therefore we need to set them up in the Python code itself. 
43 | 44 | srun singularity exec $CONTAINER \ 45 | torchrun --standalone \ 46 | --nnodes=1 \ 47 | --nproc-per-node=${SLURM_GPUS_PER_NODE} \ 48 | GPT-neo-IMDB-finetuning.py \ 49 | --model-name $MODEL_NAME \ 50 | --output-path $OUTPUT_DIR \ 51 | --logging-path $LOGGING_DIR \ 52 | --num-workers $(( SLURM_CPUS_PER_TASK / SLURM_GPUS_PER_NODE )) 53 | -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/reference_solution/util.py: -------------------------------------------------------------------------------- 1 | ../util.py -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/reference_solution/with_cpu_bindings/GPT-neo-IMDB-finetuning.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # IMDB movie review text generation 5 | # 6 | # In this script, we'll fine-tune a GPT2-like model to generate more 7 | # movie reviews based on a prompt. 8 | # 9 | # Partly based on this tutorial: 10 | # https://github.com/omidiu/GPT-2-Fine-Tuning/ 11 | 12 | 13 | import argparse 14 | import math 15 | import os 16 | import time 17 | from pprint import pprint 18 | 19 | import psutil 20 | import torch 21 | from datasets import load_dataset 22 | from util import preprocess_data, get_output_paths 23 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 24 | DataCollatorForLanguageModeling, Trainer, 25 | TrainingArguments) 26 | 27 | 28 | def set_cpu_affinity(local_rank): 29 | LUMI_GPU_CPU_map = { 30 | # A mapping from GCD to the closest CPU cores in a LUMI-G node 31 | # Note that CPU cores 0, 8, 16, 24, 32, 40, 48, 56 are reserved for the 32 | # system and not available for the user 33 | # See https://docs.lumi-supercomputer.eu/hardware/lumig/ 34 | 0: [49, 50, 51, 52, 53, 54, 55], 35 | 1: [57, 58, 59, 60, 61, 62, 63], 36 | 2: [17, 18, 19, 20, 21, 22, 23], 37 | 3: [25, 26, 27, 28, 29, 30, 31], 38 | 4: [1, 2, 3, 4, 5, 6, 7], 39 | 5: [9, 10, 11, 12, 13, 14, 15], 40 | 6: [33, 34, 35, 36, 37, 38, 39], 41 | 7: [41, 42, 43, 44, 45, 46, 47], 42 | } 43 | cpu_list = LUMI_GPU_CPU_map[local_rank] 44 | print(f"Rank {rank} (local {local_rank}) binding to cpus: {cpu_list}") 45 | psutil.Process().cpu_affinity(cpu_list) 46 | 47 | 48 | if __name__ == "__main__": 49 | 50 | # First we set up some command line arguments to allow us to specify data/output paths 51 | # and the number of worker processes without changing the code. 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument( 54 | "--model-name", 55 | type=str, 56 | default="gpt-imdb-model", 57 | help="A name for the trained model under. 
A subdirectory with the given name will be created under the `output-path`.", 58 | ) 59 | parser.add_argument( 60 | "--output-path", 61 | type=str, 62 | help="The root directory under which model checkpoints are stored.", 63 | ) 64 | parser.add_argument( 65 | "--logging-path", 66 | type=str, 67 | help="The root directory under which logging data (for tensorboard) are stored.", 68 | ) 69 | parser.add_argument( 70 | "--num-workers", 71 | type=int, 72 | default=1, 73 | help="The number of CPU worker processes to use.", 74 | ) 75 | parser.add_argument( 76 | "--set-cpu-binds", 77 | default=False, 78 | action="store_true", 79 | help="Bind the process to the CPU cores closest to the GPU used by the process (identified by the LOCAL_RANK environment variable).", 80 | ) 81 | args, _ = parser.parse_known_args() 82 | 83 | # Read the environment variables provided by torchrun 84 | rank = int(os.environ["RANK"]) 85 | local_rank = int(os.environ["LOCAL_RANK"]) 86 | world_size = int(os.environ["WORLD_SIZE"]) 87 | local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) 88 | 89 | # Set up CPU binding if --set-cpu-binds is given 90 | if args.set_cpu_binds: 91 | set_cpu_affinity(local_rank) 92 | 93 | # Then we determine the device on which to train the model. 94 | print("Using PyTorch version:", torch.__version__) 95 | if torch.cuda.is_available(): 96 | print( 97 | f"Rank {rank} of {world_size} (local: {local_rank}) sees {torch.cuda.device_count()} devices" 98 | ) 99 | device = torch.device("cuda", local_rank) 100 | print("Using GPU, device name:", torch.cuda.get_device_name(device)) 101 | else: 102 | print("No GPU found, using CPU instead.") 103 | device = torch.device("cpu") 104 | 105 | # We also ensure that output paths exist 106 | output_dir, logging_dir = get_output_paths(args) 107 | 108 | # #### Loading the GPT-neo model 109 | # 110 | # We'll use the gpt-neo-1.3B model from the Hugging Face library: 111 | # https://huggingface.co/EleutherAI/gpt-neo-1.3B 112 | # Let's start with getting the appropriate tokenizer. 113 | pretrained_model = "EleutherAI/gpt-neo-1.3B" 114 | 115 | print("Loading model and tokenizer") 116 | start = time.time() 117 | tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True) 118 | tokenizer.pad_token = tokenizer.eos_token 119 | 120 | # Load the actual base model from Hugging Face 121 | model = AutoModelForCausalLM.from_pretrained(pretrained_model) 122 | model.to(device) 123 | stop = time.time() 124 | print(f"Loading model and tokenizer took: {stop-start:.2f} seconds") 125 | 126 | # #### Loading the IMDb data set 127 | # 128 | # Next we'll load the IMDb data set: https://huggingface.co/docs/datasets/index. 129 | # 130 | # The data set contains 100,000 movies reviews from the Internet Movie 131 | # Database, split into 25,000 reviews for training and 25,000 reviews 132 | # for testing and 50,000 without labels (unsupervised) that we also use for training. 133 | 134 | train_dataset = load_dataset( 135 | "imdb", split="train+unsupervised", trust_remote_code=False, keep_in_memory=True 136 | ) 137 | eval_dataset = load_dataset( 138 | "imdb", split="test", trust_remote_code=False, keep_in_memory=True 139 | ) 140 | 141 | # Let's print one sample from the dataset. 142 | print("Sample from dataset") 143 | pprint(train_dataset[200]) 144 | 145 | # #### Setting up the training configuration 146 | global_train_batch_size = 32 # We keep the overall batch size (across all GPUs) the same as before ... 147 | per_device_train_batch_size = global_train_batch_size // world_size # ... 
which means we divide by the number of processes for the batch size of each GPU 148 | eval_batch_size = 128 # No optimizer state during evaluation, so can use bigger batches for increased throughput 149 | 150 | training_args = TrainingArguments( 151 | output_dir=output_dir, 152 | save_strategy="steps", 153 | save_steps=100, 154 | save_total_limit=4, 155 | logging_dir=logging_dir, 156 | eval_strategy="steps", 157 | eval_steps=200, # compute validation loss every 200 steps 158 | learning_rate=2e-5, 159 | weight_decay=0.01, 160 | bf16=True, # use 16-bit floating point precision 161 | per_device_train_batch_size=per_device_train_batch_size, 162 | per_device_eval_batch_size=eval_batch_size, 163 | max_steps=1000, 164 | dataloader_num_workers=args.num_workers, 165 | dataloader_pin_memory=True, 166 | report_to=["tensorboard"], # log statistics for tensorboard 167 | ddp_find_unused_parameters=False, # there are no unused parameters, causing PyTorch to issue a warning should this be set to True 168 | ) 169 | 170 | # #### Preprocessing of training data 171 | # We tokenize the data into torch tensors, split training into training and validation and set up a collator that 172 | # is able to arrange single data samples into batches. 173 | 174 | train_dataset_tokenized, validate_dataset_tokenized, eval_dataset_tokenized = preprocess_data(train_dataset, eval_dataset, tokenizer, training_args) 175 | 176 | collator = DataCollatorForLanguageModeling( 177 | tokenizer, mlm=False, return_tensors="pt" 178 | ) 179 | 180 | # Sanity check: How does the training data look like after preprocessing? 181 | print("Sample of tokenized data") 182 | for b in train_dataset_tokenized: 183 | pprint(b, compact=True) 184 | print("Length of input_ids:", len(b["input_ids"])) 185 | break 186 | print("Length of dataset (tokenized)", len(train_dataset_tokenized)) 187 | 188 | # #### Training 189 | # We use the Hugging Face trainer instead of a manual training loop. 190 | # 191 | # You can read about the many, many different parameters to the 192 | # Hugging Face trainer here: 193 | # https://huggingface.co/docs/transformers/v4.37.0/en/main_classes/trainer#transformers.TrainingArguments 194 | # 195 | 196 | trainer = Trainer( 197 | model=model, 198 | args=training_args, 199 | tokenizer=tokenizer, 200 | data_collator=collator, 201 | train_dataset=train_dataset_tokenized, 202 | eval_dataset=validate_dataset_tokenized, 203 | ) 204 | 205 | # With 1000 steps, batch size 32 and a single GCD, this should take just under 30 minutes. 
206 | trainer.train() 207 | 208 | print() 209 | print("Training done, you can find all the model checkpoints in", output_dir) 210 | 211 | # #### Evaluating the finetuned model 212 | with torch.no_grad(): 213 | model.eval() 214 | # Calculate perplexity 215 | eval_results = trainer.evaluate() 216 | test_results = trainer.evaluate(eval_dataset_tokenized) 217 | 218 | print(f'Perplexity on validation: {math.exp(eval_results["eval_loss"]):.2f}') 219 | print(f'Perplexity on test: {math.exp(test_results["eval_loss"]):.2f}') 220 | 221 | # Let's print a few sample generated reviews; this is the same as in the previous exercise 222 | # but now we use the finetuned model 223 | prompt = "The movie 'How to run ML on LUMI - A documentation' was great because" 224 | inputs = tokenizer(prompt, return_tensors="pt").to(device) 225 | outputs = model.generate( 226 | **inputs, do_sample=True, max_length=80, num_return_sequences=4 227 | ) 228 | decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) 229 | 230 | print("Sample generated review:") 231 | for txt in decoded_outputs: 232 | print("-", txt) 233 | -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/reference_solution/with_cpu_bindings/run_no_torchrun.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_465001958 3 | #SBATCH --reservation=AI_workshop_2 # comment this out if the reservation is no longer available 4 | #SBATCH --partition=standard-g 5 | #SBATCH --nodes=1 6 | #SBATCH --gpus-per-node=8 7 | #SBATCH --ntasks-per-node=8 # we want one process per GPU 8 | #SBATCH --cpus-per-task=7 9 | #SBATCH --mem-per-gpu=60G 10 | #SBATCH --time=0:20:00 11 | 12 | # Set up the software environment 13 | # NOTE: the loaded module makes relevant filesystem locations available inside the singularity container 14 | # (/scratch, /project, etc) as well as mounts some important system libraries that are optimized for LUMI 15 | # If you are interested, you can check the exact paths being mounted from 16 | # /appl/local/containers/ai-modules/singularity-AI-bindings/24.03.lua 17 | module purge 18 | module use /appl/local/containers/ai-modules 19 | module load singularity-AI-bindings 20 | 21 | CONTAINER=/project/project_465001958/containers/pytorch_transformers.sif 22 | 23 | # Some environment variables to set up cache directories 24 | SCRATCH="/scratch/${SLURM_JOB_ACCOUNT}" 25 | FLASH="/flash/${SLURM_JOB_ACCOUNT}" 26 | export TORCH_HOME=$SCRATCH/torch-cache 27 | export HF_HOME=$FLASH/hf-cache 28 | mkdir -p $TORCH_HOME $HF_HOME 29 | 30 | # Disable internal parallelism of huggingface's tokenizer since we 31 | # want to retain direct control of parallelism options. 32 | export TOKENIZERS_PARALLELISM=false 33 | 34 | # Path to where the trained model and logging data will go 35 | export OUTPUT_DIR=$SCRATCH/$USER/data/ 36 | export LOGGING_DIR=$SCRATCH/$USER/runs/ 37 | export MODEL_NAME=gpt-imdb-model-multigpu-no-torchrun 38 | 39 | set -xv # print the command so that we can verify setting arguments correctly from the logs 40 | 41 | # Set up variables to control distributed PyTorch training 42 | export MASTER_ADDR=$(hostname) 43 | export MASTER_PORT=25900 44 | export WORLD_SIZE=$SLURM_NPROCS 45 | export LOCAL_WORLD_SIZE=$SLURM_GPUS_PER_NODE 46 | 47 | # As opposed to the example in `run_torchrun.sh`, we can set the CPU binds directly via the slurm command, since we have 48 | # one task per GPU. 
In this case we do NOT need to set them from within the Python code itself. 49 | 50 | # Set up the CPU bind masks (can only be used with full node runs (standard-g or small-g with slurm argument `--exclusive`)) 51 | CPU_BIND_MASKS="0x00fe000000000000,0xfe00000000000000,0x0000000000fe0000,0x00000000fe000000,0x00000000000000fe,0x000000000000fe00,0x000000fe00000000,0x0000fe0000000000" 52 | 53 | # tell slurm to configure the cpu binds specified by the mask, additional option v prints to configuration to the logs 54 | srun --cpu-bind=v,mask_cpu=$CPU_BIND_MASKS \ 55 | singularity exec $CONTAINER \ 56 | bash -c "RANK=\$SLURM_PROCID \ 57 | LOCAL_RANK=\$SLURM_LOCALID \ 58 | python GPT-neo-IMDB-finetuning.py \ 59 | --model-name $MODEL_NAME \ 60 | --output-path $OUTPUT_DIR \ 61 | --logging-path $LOGGING_DIR \ 62 | --num-workers ${SLURM_CPUS_PER_TASK}" 63 | -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/reference_solution/with_cpu_bindings/run_torchrun.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_465001958 3 | #SBATCH --reservation=AI_workshop_2 # comment this out if the reservation is no longer available 4 | #SBATCH --partition=standard-g 5 | #SBATCH --nodes=1 6 | #SBATCH --gpus-per-node=8 7 | #SBATCH --ntasks-per-node=1 # we start a single torchrun process, which will take care of spawning more 8 | #SBATCH --cpus-per-task=56 # 7 cores per GPU 9 | #SBATCH --mem-per-gpu=60G 10 | #SBATCH --time=0:20:00 11 | 12 | # Set up the software environment 13 | # NOTE: the loaded module makes relevant filesystem locations available inside the singularity container 14 | # (/scratch, /project, etc) as well as mounts some important system libraries that are optimized for LUMI 15 | # If you are interested, you can check the exact paths being mounted from 16 | # /appl/local/containers/ai-modules/singularity-AI-bindings/24.03.lua 17 | module purge 18 | module use /appl/local/containers/ai-modules 19 | module load singularity-AI-bindings 20 | 21 | CONTAINER=/project/project_465001958/containers/pytorch_transformers.sif 22 | 23 | # Some environment variables to set up cache directories 24 | SCRATCH="/scratch/${SLURM_JOB_ACCOUNT}" 25 | FLASH="/flash/${SLURM_JOB_ACCOUNT}" 26 | export TORCH_HOME=$SCRATCH/torch-cache 27 | export HF_HOME=$FLASH/hf-cache 28 | mkdir -p $TORCH_HOME $HF_HOME 29 | 30 | # Disable internal parallelism of huggingface's tokenizer since we 31 | # want to retain direct control of parallelism options. 32 | export TOKENIZERS_PARALLELISM=false 33 | 34 | # Path to where the trained model and logging data will go 35 | export OUTPUT_DIR=$SCRATCH/$USER/data/ 36 | export LOGGING_DIR=$SCRATCH/$USER/runs/ 37 | export MODEL_NAME=gpt-imdb-model-multigpu 38 | 39 | set -xv # print the command so that we can verify setting arguments correctly from the logs 40 | 41 | # Since we start only one task with slurm which then starts subprocesses, we cannot use slurm to configure CPU binds. 42 | # Therefore we need to set them up in the Python code itself. 
43 | 44 | srun singularity exec $CONTAINER \ 45 | torchrun --standalone \ 46 | --nnodes=1 \ 47 | --nproc-per-node=${SLURM_GPUS_PER_NODE} \ 48 | GPT-neo-IMDB-finetuning.py \ 49 | --model-name $MODEL_NAME \ 50 | --output-path $OUTPUT_DIR \ 51 | --logging-path $LOGGING_DIR \ 52 | --num-workers $(( SLURM_CPUS_PER_TASK / SLURM_GPUS_PER_NODE )) \ 53 | --set-cpu-binds # enable setting of the CPU binds in the training script (can only be used with full node runs (standard-g or small-g with slurm argument `--exclusive`)) 54 | -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/reference_solution/with_cpu_bindings/util.py: -------------------------------------------------------------------------------- 1 | ../util.py -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=project_465001958 3 | #SBATCH --reservation=AI_workshop_2 # comment this out if the reservation is no longer available 4 | #SBATCH --partition=... 5 | ## 6 | 7 | # Set up the software environment 8 | # NOTE: the loaded module makes relevant filesystem locations available inside the singularity container 9 | # (/scratch, /project, etc) as well as mounts some important system libraries that are optimized for LUMI 10 | # If you are interested, you can check the exact paths being mounted from 11 | # /appl/local/containers/ai-modules/singularity-AI-bindings/24.03.lua 12 | module purge 13 | module use /appl/local/containers/ai-modules 14 | module load singularity-AI-bindings 15 | 16 | CONTAINER=/project/project_465001958/containers/pytorch_transformers.sif 17 | 18 | # Some environment variables to set up cache directories 19 | SCRATCH="/scratch/${SLURM_JOB_ACCOUNT}" 20 | FLASH="/flash/${SLURM_JOB_ACCOUNT}" 21 | export TORCH_HOME=$SCRATCH/torch-cache 22 | export HF_HOME=$FLASH/hf-cache 23 | mkdir -p $TORCH_HOME $HF_HOME 24 | 25 | # Disable internal parallelism of huggingface's tokenizer since we 26 | # want to retain direct control of parallelism options. 27 | export TOKENIZERS_PARALLELISM=false 28 | 29 | # Path to where the trained model and logging data will go 30 | export OUTPUT_DIR=$SCRATCH/$USER/data/ 31 | export LOGGING_DIR=$SCRATCH/$USER/runs/ 32 | export MODEL_NAME=gpt-imdb-model-multigpu 33 | 34 | ## 35 | -------------------------------------------------------------------------------- /08_Scaling_to_multiple_GPUs/util.py: -------------------------------------------------------------------------------- 1 | ../03_Your_first_AI_training_job_on_LUMI/util.py -------------------------------------------------------------------------------- /09_Extreme_scale_AI/README.md: -------------------------------------------------------------------------------- 1 | # 09 Extreme scale AI 2 | 3 | These examples are based on the ROCm container provided to you at: 4 | ``` 5 | /appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif 6 | ``` 7 | 8 | The examples also assume there is an allocation in place to be used for one or more nodes. 
That could be accomplished with, e.g.: 9 | ``` 10 | N=2 ; salloc -p standard-g --account=project_465001958 --reservation=AI_workshop_2 --threads-per-core 1 --exclusive -N $N --gpus $((N*8)) -t 1:00:00 --mem 0 11 | ``` 12 | 13 | With the allocation and container set up, we can do a quick smoke test to make sure Pytorch can detect the GPUs available in a node: 14 | ``` 15 | srun singularity exec \ 16 | /appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif \ 17 | bash -c '$WITH_CONDA ; \ 18 | python -c "import torch; print(torch.cuda.device_count())"' 19 | ``` 20 | Each node of the allocation should report `8` GPUs available. 21 | 22 | ## Things to consider when scaling up a training model 23 | 24 | If you have a model that already uses multiple GPUs, scaling it further should not require modifications. It should just work across nodes as it does across a node. 25 | 26 | However, there are a few performance implications that result from how the GPUs are wired inside the node: 27 | 28 | ![image](https://docs.olcf.ornl.gov/_images/Frontier_Node_Diagram.jpg) 29 | 30 | One corollary of this diagram is the CPU binding needed to match the GPUs, which we already know we can accomplish with the SLURM option: 31 | ``` 32 | CPU_BIND_MASKS="0x00fe000000000000,0xfe00000000000000,0x0000000000fe0000,0x00000000fe000000,0x00000000000000fe,0x000000000000fe00,0x000000fe00000000,0x0000fe0000000000" 33 | 34 | --cpu-bind=mask_cpu=$CPU_BIND_MASKS 35 | ``` 36 | 37 | As you can see, the GPUs connect directly to the high-speed interfaces used to communicate with GPUs in other nodes. The communication library used by most AI applications is RCCL, and it is a good idea to tell RCCL which interfaces it should use. This can be accomplished with: 38 | ``` 39 | export NCCL_SOCKET_IFNAME=hsn0,hsn1,hsn2,hsn3 40 | ``` 41 | We should also direct RCCL to do GPU RDMA whenever it can with: 42 | ``` 43 | export NCCL_NET_GDR_LEVEL=PHB 44 | ``` 45 | 46 | The other aspect to this is that RCCL needs to be able to communicate with the network provider. When the application starts, RCCL looks for plugins that it can leverage for this communication. On LUMI we leverage the AWS CXI plugin, which has been hipified to support AMD GPUs: https://github.com/ROCm/aws-ofi-rccl. 47 | 48 | Luckily, if you use the containers provided, this plugin is already available there, so you don't need to worry about it. However, if you get a container from the internet, it most likely won't include the plugin, and your scaling across nodes will be poor as a result. There are a few environment variables that can help us see what RCCL is doing, e.g.: 49 | ``` 50 | export NCCL_DEBUG=INFO 51 | export NCCL_DEBUG_SUBSYS=INIT,COLL 52 | export NCCL_DEBUG_FILE=/tmp/$(whoami)-rccl-rank$SLURM_PROCID.txt 53 | ``` 54 | This will produce information on the initialization of RCCL as well as on the execution of collectives. It is also a good idea to direct this information to a file with `NCCL_DEBUG_FILE`, as writing it to stdout might start affecting progress. 55 | 56 | As we learned before, there is a master rank that coordinates the execution, and all ranks must know which one it is. When we were using a single node we did: 57 | ``` 58 | export MASTER_ADDR=$(hostname) 59 | ``` 60 | However, with multiple nodes this is not good enough, as all ranks, regardless of the node, need to know which node is the first node of the allocation. Luckily, SLURM can help.
We can get the first node of an allocation with: 61 | ``` 62 | export MASTER_ADDR=$(scontrol show hostname "$SLURM_NODELIST" | head -n1) 63 | ``` 64 | Note, however, that there are no SLURM tools inside the container, so the master address must be defined outside the container and propagated inside. 65 | 66 | Another aspect to consider when scaling is where to direct the cache folders of some of the libraries, namely MIOpen. MIOpen provides many optimized AI kernels and performs just-in-time compilation whose results are cached in your home folder by default. With many nodes racing for that cache, the file system might not cope with the required locks, so we recommend pointing the cache to each node's `/tmp`. If your application relies heavily on these caches, you can save them at the end of your job and reinstate them at the beginning of your next job. 67 | 68 | Selecting different cache locations is as easy as setting the environment variables: 69 | 70 | ``` 71 | export MIOPEN_USER_DB_PATH="/tmp/$(whoami)-miopen-cache-$SLURM_NODEID" 72 | export MIOPEN_CUSTOM_CACHE_DIR=$MIOPEN_USER_DB_PATH 73 | ``` 74 | 75 | ## LLM hands-on exercises 76 | 77 | We'll continue with our LLM example to explore our scaling opportunities. You might want to collate the different steps into a batch script or run them interactively as presented. But first... 78 | 79 | ### 1. Setting up some run scripts is a great idea! 80 | 81 | As discussed above, there are a lot of components to set up and monitor to get the right environment for our training jobs. Set up a run script with all the relevant bits so that you can then allow yourself to forget about them. Can you do that? Let's call it `run.sh`; there is an example in the `reference_solution` folder. 82 | 83 | ### 2. Scaling our LLM example. 84 | 85 | Let's recover our multiple GPU LLM training application: 86 | ``` 87 | curl -o GPT-neo-IMDB-finetuning.py -L https://github.com/Lumi-supercomputer/Getting_Started_with_AI_workshop/raw/main/03_Your_first_AI_training_job_on_LUMI/reference_solution/GPT-neo-IMDB-finetuning.py 88 | curl -o util.py -L https://github.com/Lumi-supercomputer/Getting_Started_with_AI_workshop/raw/main/03_Your_first_AI_training_job_on_LUMI/util.py 89 | ``` 90 | The only change we will make is to select a different multiprocessing start method. We will add: 91 | ``` 92 | torch.multiprocessing.set_start_method('spawn') 93 | ``` 94 | right after: 95 | ``` 96 | if __name__ == "__main__": 97 | ``` 98 | This is meant to work around an issue in Pytorch around the registration of threads (https://github.com/pytorch/pytorch/issues/119845). 99 | The updated file should be named `GPT-neo-IMDB-finetuning-mp.py`. 100 | 101 | You might need to reduce the training batch size from 32 if it uses too much memory.
That can be done by setting: 102 | ``` 103 | train_batch_size = 24 104 | ``` 105 | 106 | Now we can run in a single node: 107 | ``` 108 | MASTER_ADDR=$(scontrol show hostname "$SLURM_NODELIST" | head -n1) \ 109 | srun -N1 -n8 --gpus 8 \ 110 | --cpu-bind=mask_cpu=0x00fe000000000000,0xfe00000000000000,0x0000000000fe0000,0x00000000fe000000,0x00000000000000fe,0x000000000000fe00,0x000000fe00000000,0x0000fe0000000000\ 111 | singularity exec \ 112 | -B .:/workdir \ 113 | /appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif \ 114 | /workdir/run.sh \ 115 | python -u /workdir/GPT-neo-IMDB-finetuning-mp.py \ 116 | --model-name gpt-imdb-model \ 117 | --output-path /workdir/train-output \ 118 | --logging-path /workdir/train-logging \ 119 | --num-workers 7 120 | ``` 121 | or in a couple of nodes: 122 | ``` 123 | MASTER_ADDR=$(scontrol show hostname "$SLURM_NODELIST" | head -n1) \ 124 | srun -N2 -n16 --gpus 16 \ 125 | --cpu-bind=mask_cpu=0x00fe000000000000,0xfe00000000000000,0x0000000000fe0000,0x00000000fe000000,0x00000000000000fe,0x000000000000fe00,0x000000fe00000000,0x0000fe0000000000\ 126 | singularity exec \ 127 | -B /var/spool/slurmd \ 128 | -B /opt/cray \ 129 | -B /usr/lib64/libcxi.so.1 \ 130 | -B .:/workdir \ 131 | /appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif \ 132 | /workdir/run.sh \ 133 | python -u /workdir/GPT-neo-IMDB-finetuning-mp.py \ 134 | --model-name gpt-imdb-model \ 135 | --output-path /workdir/train-output \ 136 | --logging-path /workdir/train-logging \ 137 | --num-workers 7 138 | ``` 139 | Notice that, the moment one wants to run across nodes, binding `/var/spool/slurmd`, `/opt/cray` and `/usr/lib64/libcxi.so.1` is a requirement. 140 | 141 | ### 3. Monitoring GPU activity 142 | 143 | We can monitor activity as before. However, if you want to use the profiler when multiple ranks are being run, it makes more sense to profile a few selected ones, otherwise the overhead will be too high. Most AI training is balanced, so what we see in one rank can be extrapolated to the others. 144 | 145 | To profile just a single rank you can create a copy of your run script, let's call it `run-profile.sh`, and replace the last `eval` command with: 146 | ``` 147 | pcmd='' 148 | if [ $RANK -eq 14 ]; then 149 | pcmd='rocprof --hip-trace --stats' 150 | fi 151 | 152 | eval "$pcmd $@" 153 | ``` 154 | This will only profile rank number 14. You could select any other rank.
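For convenience, you could also make the profiled rank configurable instead of hard-coding it. A small sketch, assuming you export a (hypothetical) `PROFILE_RANK` variable before launching the job:
```
pcmd=''
if [ "$RANK" -eq "${PROFILE_RANK:-14}" ]; then
    pcmd='rocprof --hip-trace --stats'
fi

eval "$pcmd $@"
```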
155 | We can now use the same strategy as before to profile just 32 steps and then run with: 156 | ``` 157 | MASTER_ADDR=$(scontrol show hostname "$SLURM_NODELIST" | head -n1) \ 158 | srun -N2 -n16 --gpus 16 \ 159 | --cpu-bind=mask_cpu=0x00fe000000000000,0xfe00000000000000,0x0000000000fe0000,0x00000000fe000000,0x00000000000000fe,0x000000000000fe00,0x000000fe00000000,0x0000fe0000000000\ 160 | singularity exec \ 161 | -B /var/spool/slurmd \ 162 | -B /opt/cray \ 163 | -B /usr/lib64/libcxi.so.1 \ 164 | -B .:/workdir \ 165 | /appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif \ 166 | /workdir/run-profile.sh \ 167 | python -u /workdir/GPT-neo-IMDB-finetuning-mp.py \ 168 | --model-name gpt-imdb-model \ 169 | --output-path /workdir/train-output \ 170 | --logging-path /workdir/train-logging \ 171 | --num-workers 7 172 | ``` 173 | The resulting profile for the 32 steps would look like: 174 | ![image](https://github.com/Lumi-supercomputer/Getting_Started_with_AI_workshop/raw/main/09_Extreme_scale_AI/images/profile.png) 175 | 176 | Zooming in, we can see the RCCL activity. The moment these kernels dominate the profile, we start to be network-bound. 177 | ![image](https://github.com/Lumi-supercomputer/Getting_Started_with_AI_workshop/raw/main/09_Extreme_scale_AI/images/profile-detail.png) 178 | 179 | 180 | ## Computer vision hands-on exercises 181 | 182 | Large language models are one of the main use cases in production these days. However, computer vision applications are still relevant. The good news is that a lot of the same concepts discussed before can easily be applied to run this kind of application. 183 | 184 | ### 1. Preparing our computer vision example 185 | Let's grab one of the official pytorch examples for image classification: 186 | ``` 187 | curl -L \ 188 | -o cv_example.py \ 189 | https://github.com/pytorch/examples/raw/main/imagenet/main.py 190 | ``` 191 | Let's just add the same fix as before by adding: 192 | ``` 193 | torch.multiprocessing.set_start_method('spawn') 194 | ``` 195 | right after: 196 | ``` 197 | def main(): 198 | ``` 199 | 200 | ### 2. Know where the data lives 201 | 202 | We have downloaded the data set (ImageNet) in advance, as that is a time-consuming process. The image classification labels are determined by the folder names of the data set, which contains many files in a compressed image format. To be able to make quicker progress, we also created a trimmed-down version of the data set with just a fraction of the classes. 203 | 204 | Here's how the data is organized: 205 | * Reduced set in scratch storage: 206 | * /scratch/project_465001958/data-sets/data-resnet-small 207 | * Reduced set in flash storage: 208 | * /flash/project_465001958/data-sets/data-resnet-small 209 | 210 | * Tarball container for the data set: 211 | * /flash/project_465001958/data-sets/data-resnet-small.tar 212 | 213 | The tarball is useful for moving the data around, as it is much faster to move a single large file than many small files, e.g. it is better to untar the tarball than to copy an expanded data set from elsewhere. The folders `/scratch` and `/flash` contain symbolic links, so it is important to mount `/pfs` in your containers, as these links point there. 214 | 215 | ### 3.
Training at scale 216 | We are ready to run with one or more nodes (adjust `N` for the number of nodes) just by issuing: 217 | 218 | ``` 219 | N=1 ; \ 220 | srun -N $N -n $((N*8)) --gpus $((N*8)) \ 221 | --cpu-bind=mask_cpu=0x00fe000000000000,0xfe00000000000000,0x0000000000fe0000,0x00000000fe000000,0x00000000000000fe,0x000000000000fe00,0x000000fe00000000,0x0000fe0000000000\ 222 | singularity exec \ 223 | -B /var/spool/slurmd \ 224 | -B /opt/cray \ 225 | -B /usr/lib64/libcxi.so.1 \ 226 | -B .:/workdir \ 227 | -B /flash -B /pfs \ 228 | /appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif \ 229 | /workdir/run.sh \ 230 | python -u /workdir/cv_example.py \ 231 | -a resnet50 \ 232 | --batch-size $((8*512)) \ 233 | --workers $((8*7)) \ 234 | --gpu \$SLURM_LOCALID \ 235 | --world-size \$SLURM_NPROCS \ 236 | --rank \$SLURM_PROCID \ 237 | --dist-url "tcp://$(scontrol show hostname "$SLURM_NODELIST" | head -n1):45678" \ 238 | --dist-backend 'nccl' \ 239 | --epochs 2 \ 240 | /flash/project_465001958/data-sets/data-resnet-small 241 | ``` 242 | Here we are training ResNet-50 for 2 epochs with a batch size of 512 per GPU. We use the same 7 workers per GPU as before. The data set is given by the last argument - we use the small data set but you are free to try the complete one. The other arguments are similar to what we used before to translate information from the SLURM environment. 243 | 244 | ### 4. Monitor GPU activity 245 | 246 | This example leverages MIOpen, so you can check if `/tmp/$(whoami)-miopen-cache-$SLURM_NODEID` is being populated with the MIOpen databases on each node. Start a parallel SLURM session for that purpose. 247 | 248 | Then, try monitoring the GPU activity as before. You should be able to see snapshots such as this: 249 | 250 | ``` 251 | ======================= ROCm System Management Interface ======================= 252 | ================================= Concise Info ================================= 253 | GPU Temp AvgPwr SCLK MCLK Fan Perf PwrCap VRAM% GPU% 254 | 0 42.0c 128.0W 1650Mhz 1600Mhz 0% manual 500.0W 78% 100% 255 | 1 43.0c N/A 1650Mhz 1600Mhz 0% manual 0.0W 78% 100% 256 | 2 38.0c 132.0W 1650Mhz 1600Mhz 0% manual 500.0W 78% 100% 257 | 3 49.0c N/A 1650Mhz 1600Mhz 0% manual 0.0W 78% 100% 258 | 4 41.0c 111.0W 800Mhz 1600Mhz 0% manual 500.0W 78% 0% 259 | 5 43.0c N/A 1650Mhz 1600Mhz 0% manual 0.0W 78% 100% 260 | 6 45.0c 130.0W 1650Mhz 1600Mhz 0% manual 500.0W 78% 100% 261 | 7 47.0c N/A 1650Mhz 1600Mhz 0% manual 0.0W 78% 100% 262 | ================================================================================ 263 | ============================= End of ROCm SMI Log ============================== 264 | ``` 265 | It looks like some GPUs are intermittently idle, holding everyone else back. This is related to I/O bottlenecks, as images can put more strain on the filesystem. We'll learn more about that in the next session. 266 | 267 | ## Distributed training frameworks hands-on exercise 268 | 269 | Several frameworks for distributed training have been developed, for different purposes and with different levels of integration with the serial part of the model. E.g., Horovod (https://horovod.ai/) supplies an MPI-like approach to control the sharing of data, providing wrappers for the local operations of the most popular AI frameworks like Pytorch and TensorFlow.
Others, like DeepSpeed (https://github.com/microsoft/DeepSpeed), offer more optimized distributed operations tailored for specific problems/optimizers and are widely used for LLMs. DeepSpeed also offers computer vision optimizations. 270 | 271 | ### 1. Prepare DeepSpeed example 272 | We can try one of the DeepSpeed examples on our setup, similar to our computer vision example: 273 | ``` 274 | curl -L -o cv_example_ds.py \ 275 | https://github.com/microsoft/DeepSpeedExamples/raw/master/training/imagenet/main.py 276 | 277 | curl -LO \ 278 | https://github.com/microsoft/DeepSpeedExamples/raw/master/training/imagenet/config/ds_fp16_z1_config.json 279 | ``` 280 | Read through the files to get some understanding of the differences. 281 | 282 | ### 2. Running DeepSpeed with required dependencies 283 | This container already has DeepSpeed installed, so we will leverage it: `/appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif`. 284 | 285 | You can run the example as follows; however, some dependencies might be missing. Can you install them? Can you set up the `spawn` multiprocessing mode? 286 | ``` 287 | N=2 ; \ 288 | MASTER_ADDR=$(scontrol show hostname "$SLURM_NODELIST" | head -n1) \ 289 | srun -N $N -n $((N*8)) --gpus $((N*8)) \ 290 | --cpu-bind=mask_cpu=0x00fe000000000000,0xfe00000000000000,0x0000000000fe0000,0x00000000fe000000,0x00000000000000fe,0x000000000000fe00,0x000000fe00000000,0x0000fe0000000000\ 291 | singularity exec \ 292 | -B /var/spool/slurmd \ 293 | -B /opt/cray \ 294 | -B /usr/lib64/libcxi.so.1 \ 295 | -B .:/workdir \ 296 | -B /flash -B /pfs \ 297 | /appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif \ 298 | /workdir/run.sh \ 299 | python -u /workdir/cv_example_ds.py \ 300 | --deepspeed \ 301 | --deepspeed_config /workdir/ds_fp16_z1_config.json \ 302 | -a resnet50 \ 303 | --batch-size $((8*512)) \ 304 | --workers 7 \ 305 | --gpu \$SLURM_LOCALID \ 306 | --local_rank \$SLURM_LOCALID \ 307 | --world-size \$SLURM_NPROCS \ 308 | --epochs 2 \ 309 | /flash/project_465001958/data-sets/data-resnet-small 310 | ``` 311 | Note that, in spite of this being a similar example to what we tested before, the options and their meaning have changed a bit. E.g. the number of workers is per GPU in this case. 312 | 313 | 314 | ## I/O considerations hands-on exercise 315 | 316 | In our computer vision example, we ran into the I/O limits even though we aimed at using flash storage. If we had used scratch storage, it would have been worse. However, the capacity limits of flash storage would not have allowed us to keep the complete set of files, so there are always tradeoffs to observe. 317 | 318 | ### 1. Play with datasets and training models 319 | 320 | You are welcome to try larger data sets and different storage types to see how that affects the training. The larger the data set, the more time the initialization will take, as the labels are being processed. You can also select smaller models, like Resnet-18, to make things less GPU-bound and to observe more easily the challenges around the data input pipeline. 321 | 322 | ### 2. Pre-stage in memory 323 | 324 | If limited by I/O, we could try in-memory storage. LUMI nodes don't have local SSDs but do have a significant amount of memory, so that could be sufficient for your needs. To store data in memory, it is sufficient to place it as files under `/tmp`, as that lives in memory.
So we can do: 325 | ``` 326 | srun tar -C /tmp -xf /flash/project_465001958/data-sets/data-resnet-small.tar 327 | ``` 328 | to expand the trimmed-down data set into memory, and then we can just run our model training from there: 329 | ``` 330 | N=1 ; \ 331 | srun -N $N -n $((N*8)) --gpus $((N*8)) \ 332 | --cpu-bind=mask_cpu=0x00fe000000000000,0xfe00000000000000,0x0000000000fe0000,0x00000000fe000000,0x00000000000000fe,0x000000000000fe00,0x000000fe00000000,0x0000fe0000000000\ 333 | singularity exec \ 334 | -B /var/spool/slurmd \ 335 | -B /opt/cray \ 336 | -B /usr/lib64/libcxi.so.1 \ 337 | -B .:/workdir \ 338 | -B /flash -B /pfs \ 339 | /appl/local/containers/sif-images/lumi-pytorch-rocm-6.1.3-python-3.12-pytorch-v2.4.1.sif \ 340 | /workdir/run.sh \ 341 | python -u /workdir/cv_example.py \ 342 | -a resnet50 \ 343 | --batch-size $((8*512)) \ 344 | --workers $((8*7)) \ 345 | --gpu \$SLURM_LOCALID \ 346 | --world-size \$SLURM_NPROCS \ 347 | --rank \$SLURM_PROCID \ 348 | --dist-url "tcp://$(scontrol show hostname "$SLURM_NODELIST" | head -n1):45678" \ 349 | --dist-backend 'nccl' \ 350 | --epochs 2 \ 351 | /tmp/data-resnet 352 | ``` 353 | ### 3. Monitor GPU activity 354 | Try monitoring the activity and scaling to more nodes. You will see that the training completes much faster, as the I/O pipeline can keep up with the GPU demands. 355 | 356 | Note, however, that in general, as the scale goes up, one tends to start using smaller batch sizes, which means less work per GPU, which can bring us back to an I/O bottleneck situation. 357 | -------------------------------------------------------------------------------- /09_Extreme_scale_AI/images/profile-detail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lumi-supercomputer/Getting_Started_with_AI_workshop/ce4d7cc37162350ff0d620d0301b8f9e7f0b5c98/09_Extreme_scale_AI/images/profile-detail.png -------------------------------------------------------------------------------- /09_Extreme_scale_AI/images/profile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lumi-supercomputer/Getting_Started_with_AI_workshop/ce4d7cc37162350ff0d620d0301b8f9e7f0b5c98/09_Extreme_scale_AI/images/profile.png -------------------------------------------------------------------------------- /09_Extreme_scale_AI/reference_solution/README.md: -------------------------------------------------------------------------------- 1 | # 09 Extreme scale AI - Solution 2 | ## LLM hands-on exercises 3 | ### 1. Setting up some run scripts is a great idea! 4 | 5 | Here's how to set up a wrapper script: 6 | 7 | ``` 8 | cat > run.sh << EOF 9 | #!/bin/bash -e 10 | 11 | # Report affinity 12 | echo "Rank \$SLURM_PROCID --> \$(taskset -p \$\$)" 13 | 14 | # Report GPUs 15 | if [ \$SLURM_LOCALID -eq 0 ] ; then 16 | rocm-smi 17 | else 18 | sleep 2 19 | fi 20 | 21 | # Start conda environment inside the container 22 | \$WITH_CONDA 23 | 24 | # Setting the caches relevant to our application.
25 | export TORCH_HOME=/workdir/torch-cache 26 | export HF_HOME=/workdir/hf-cache 27 | export TOKENIZERS_PARALLELISM=false 28 | 29 | # Tell RCCL to use only Slingshot interfaces and GPU RDMA 30 | export NCCL_SOCKET_IFNAME=hsn0,hsn1,hsn2,hsn3 31 | export NCCL_NET_GDR_LEVEL=PHB 32 | 33 | # Tell MIOpen where to store its cache 34 | export MIOPEN_USER_DB_PATH="/tmp/$(whoami)-miopen-cache-\$SLURM_NODEID" 35 | export MIOPEN_CUSTOM_CACHE_DIR=\$MIOPEN_USER_DB_PATH 36 | 37 | if [ \$SLURM_LOCALID -eq 0 ] ; then 38 | rm -rf \$MIOPEN_USER_DB_PATH 39 | mkdir -p \$MIOPEN_USER_DB_PATH 40 | else 41 | sleep 2 42 | fi 43 | 44 | # export NCCL_DEBUG=INFO 45 | # export NCCL_DEBUG_SUBSYS=INIT,COLL 46 | # export NCCL_DEBUG_FILE=/tmp/$(whoami)-rccl-rank\$SLURM_PROCID.txt 47 | 48 | # Translate SLURM environment 49 | 50 | export MASTER_PORT=25900 51 | export WORLD_SIZE=\$SLURM_NPROCS 52 | export LOCAL_WORLD_SIZE=8 53 | export RANK=\$SLURM_PROCID 54 | export LOCAL_RANK=\$SLURM_LOCALID 55 | 56 | set -x 57 | 58 | # Run application 59 | eval "\$@" 60 | 61 | EOF 62 | chmod +x run.sh 63 | ``` 64 | 65 | Let's take a look at what is going on here, from top to bottom: 66 | * We leverage the `taskset` tool to report the affinity of the current process. This allows us to verify we are getting the affinity we expect. 67 | * Then, we report the GPUs available using rocm-smi. This is a smoke test that the GPUs are up and running. We do this only for the first rank in a node - that rank will have `SLURM_LOCALID` set to `0`. 68 | * Then, we set up our conda environment as well as a few other environment variables to control the Pytorch and HuggingFace caches for our application. 69 | * Then we configure RCCL to use the high-speed interfaces as well as GPU RDMA. 70 | * Next step is the MIOpen cache. We also have the first rank in each node creating the cache folder. Note that this is not used by our LLM application, as it doesn't use MIOpen kernels. However, it doesn't do any harm and keeps you covered for other models you might want to train. 71 | * Then, there are a few RCCL environment variables that you may choose to uncomment to get logging of the RCCL activity. 72 | * Next, we translate the SLURM environment to something that the Pytorch distributed module understands. 73 | * Finally, the arguments of the run script are expanded and executed. 74 | -------------------------------------------------------------------------------- /10_Coupling_AI_and_HPC/README.md: -------------------------------------------------------------------------------- 1 | # 10 Coupling AI & HPC 2 | 3 | No other material than slides exists for this lecture. 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 LUMI Supercomputer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Training material for "Moving your AI training jobs to LUMI: A Hands-On Workshop" 2 | 3 | The workshop is held in Amsterdam, Netherlands, May 27-28, 2025. It is co-organized by the [LUMI User Support Team (LUST)](https://lumi-supercomputer.eu/user-support/), the [EuroCC National Competence Centers (NCCs) in Finland](https://www.eurocc-access.eu/about-us/meet-the-nccs/ncc-finland/), the [Danish e-Infrastructure Consortium](https://www.deic.dk/en) and [SURF](https://www.surf.nl/en), the IT cooperative of Dutch education and research institutions. The workshop is hosted by [SURF](https://www.surf.nl/en) and [EuroCC NL](https://eurocc-netherlands.nl/nl/). 4 | 5 | This repository contains Jupyter notebooks, jobscripts, and other files related to the examples and hands-on exercises presented with each lecture in the workshop. 6 | 7 | ## Structure of this repository 8 | 9 | All files related to each lecture are placed in a subfolder created for that lecture. Each such subfolder contains a `README.md` file that lists the examples and exercises for that lecture. 10 | 11 | Additionally, the `bonus_material` subfolder contains training material that is not related to a specific lecture but rather to the workshop in its entirety. Each such "set" of bonus material is placed in its own subfolder of the `bonus_material` folder. 12 | -------------------------------------------------------------------------------- /bonus_material/README.md: -------------------------------------------------------------------------------- 1 | # Bonus material 2 | 3 | - **exercise_container_recipes**: cotainr files for building the containers used in the workshop exercises. 4 | -------------------------------------------------------------------------------- /bonus_material/exercise_container_recipes/README.md: -------------------------------------------------------------------------------- 1 | # Exercise Container Recipes 2 | 3 | This folder contains the cotainr recipes for building the container used in Exercises 3 and 8.
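After building an image with the script below, a quick sanity check is to import the main libraries from inside the container (a minimal sketch, assuming you run it from the directory holding the freshly built `pytorch_transformers.sif`):
```
singularity exec pytorch_transformers.sif \
    python -c "import torch, transformers; print(torch.__version__, transformers.__version__)"
```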
4 | -------------------------------------------------------------------------------- /bonus_material/exercise_container_recipes/build_pytorch_transformers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | module purge 4 | module load CrayEnv cotainr 5 | 6 | cotainr build pytorch_transformers.sif --system lumi-g --conda-env=./pytorch_transformers.yml 7 | -------------------------------------------------------------------------------- /bonus_material/exercise_container_recipes/pytorch_transformers.yml: -------------------------------------------------------------------------------- 1 | name: pytorch_transformers 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - datasets=2.19.1 6 | - filelock=3.17.0 7 | - numpy=2.2.2 8 | - python=3.12.8 9 | - tensorboardx=2.6.2.2 10 | - transformers=4.45.2 11 | - huggingface_hub=0.24.6 12 | - pip: 13 | - --extra-index-url https://download.pytorch.org/whl/rocm6.2/ 14 | - accelerate==1.3.0 15 | - pytorch-triton-rocm==3.1.0 16 | - torch==2.5.1+rocm6.2 17 | - torchaudio==2.5.1+rocm6.2 18 | - torchvision==0.20.1+rocm6.2 19 | --------------------------------------------------------------------------------