├── scripts
│   ├── requirements.txt
│   ├── inference.py
│   └── train.py
├── images
│   ├── wav2vec2.png
│   └── solution_overview.png
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── CONTRIBUTING.md
└── sagemaker-fine-tune-wav2vec2.ipynb

--------------------------------------------------------------------------------
/scripts/requirements.txt:
--------------------------------------------------------------------------------
1 | jiwer
--------------------------------------------------------------------------------
/images/wav2vec2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/My-Machine-Learning-Projects-CT/amazon-sagemaker-fine-tune-and-deploy-wav2vec2-huggingface/main/images/wav2vec2.png
--------------------------------------------------------------------------------
/images/solution_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/My-Machine-Learning-Projects-CT/amazon-sagemaker-fine-tune-and-deploy-wav2vec2-huggingface/main/images/solution_overview.png
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | 
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | this software and associated documentation files (the "Software"), to deal in
5 | the Software without restriction, including without limitation the rights to
6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | the Software, and to permit persons to whom the Software is furnished to do so.
8 | 
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
10 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
11 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
12 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
13 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
14 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
15 | 
16 | 
--------------------------------------------------------------------------------
/scripts/inference.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import json
3 | import torch
4 | 
5 | from transformers import Wav2Vec2ForCTC
6 | from transformers import Wav2Vec2Processor
7 | 
8 | logger = logging.getLogger()
9 | logger.setLevel(logging.INFO)
10 | logger.addHandler(logging.StreamHandler())
11 | 
12 | # SageMaker extracts model.tar.gz into this directory inside the serving container
13 | model_path = '/opt/ml/model'
14 | logger.info("Libraries are loaded")
15 | 
16 | def model_fn(model_dir):
17 |     """Load the fine-tuned Wav2Vec2 model from the model directory."""
18 |     device = get_device()
19 | 
20 |     model = Wav2Vec2ForCTC.from_pretrained(model_dir).to(device)
21 |     logger.info("Model is loaded")
22 | 
23 |     return model
24 | 
25 | def input_fn(json_request_data, content_type='application/json'):
26 |     """Deserialize the JSON request body."""
27 |     input_data = json.loads(json_request_data)
28 |     logger.info("Input data is processed")
29 | 
30 |     return input_data
31 | 
32 | def predict_fn(input_data, model):
33 |     """Transcribe the incoming speech array with the fine-tuned model."""
34 |     logger.info("Starting inference.")
35 |     device = get_device()
36 | 
37 |     speech_array = input_data['speech_array']
38 |     sampling_rate = input_data['sampling_rate']
39 |     logger.info("Received audio sampled at %s Hz", sampling_rate)
40 | 
41 |     # the processor (feature extractor + tokenizer) is loaded from the same model directory;
42 |     # it could also be loaded once in model_fn and cached if request latency matters
43 |     processor = Wav2Vec2Processor.from_pretrained(model_path)
44 |     input_values = processor(speech_array, sampling_rate=sampling_rate, return_tensors="pt").input_values.to(device)
45 | 
46 |     with torch.no_grad():
47 |         logits = model(input_values).logits
48 |     pred_ids = torch.argmax(logits, dim=-1)
49 |     transcript = processor.batch_decode(pred_ids)[0]
50 | 
51 |     return transcript
52 | 
53 | def output_fn(transcript, accept='application/json'):
54 |     return json.dumps(transcript), accept
55 | 
56 | def get_device():
57 |     device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
58 |     return device
59 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # wav2vec2-huggingface-sagemaker
2 | **Fine-tune and deploy a Wav2Vec2 model for speech recognition with Hugging Face and SageMaker**
3 | 
4 | In this repository, we use the SUPERB dataset available from the Hugging Face [Datasets](https://huggingface.co/datasets/superb) library, fine-tune the Wav2Vec2 model, and deploy it as a SageMaker endpoint for real-time inference on an ASR task.
5 | 
6 | 
7 | First, we show how to load and preprocess the SUPERB dataset in a SageMaker environment in order to obtain the tokenizer and feature extractor that are required for fine-tuning the Wav2Vec2 model. Then we use SageMaker Script Mode for the training and inference steps, which lets you define and use custom training and inference scripts while SageMaker provides supported Hugging Face framework Docker containers. For more information about training and serving Hugging Face models on SageMaker, see [Use Hugging Face with Amazon SageMaker](https://docs.aws.amazon.com/sagemaker/latest/dg/hugging-face.html). This functionality is available through the Hugging Face [AWS Deep Learning Container (DLC)](https://docs.aws.amazon.com/deep-learning-containers/latest/devguide/what-is-dlc.html). A condensed sketch of the end-to-end flow is shown after the environment list below.
8 | 
9 | This notebook has been tested in both SageMaker Studio and SageMaker notebook instance environments. The detailed setup is shown below.
10 | - SageMaker Studio: **ml.m5.xlarge** instance with **Data Science** kernel.
11 | - SageMaker Notebook: **ml.m5.xlarge** instance with **conda_python3** kernel. 
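
The snippet below is a condensed, illustrative sketch of the flow the notebook walks through (fine-tune with the Hugging Face estimator, then host the artifacts with a custom inference script). It is not meant to run on its own: `ROLE`, `hyperparameters`, `metric_definitions`, `training_input_path`, and `test_input_path` are all created in earlier cells of `sagemaker-fine-tune-wav2vec2.ipynb`, which remains the runnable reference.

```python
from sagemaker.huggingface import HuggingFace, HuggingFaceModel

# fine-tune Wav2Vec2 with the custom training script in ./scripts
huggingface_estimator = HuggingFace(
    entry_point="train.py",
    source_dir="./scripts",
    instance_type="ml.p3.2xlarge",
    instance_count=1,
    transformers_version="4.6.1",
    pytorch_version="1.7.1",
    py_version="py36",
    role=ROLE,
    hyperparameters=hyperparameters,        # epochs, batch size, model_name, vocab_url, ...
    metric_definitions=metric_definitions,  # regexes that surface eval_loss / eval_wer in CloudWatch
)
huggingface_estimator.fit({"train": training_input_path, "test": test_input_path})

# host the trained artifacts behind a real-time endpoint, using the custom inference script
huggingface_model = HuggingFaceModel(
    entry_point="inference.py",
    source_dir="./scripts",
    model_data=huggingface_estimator.model_data,
    transformers_version="4.6.1",
    pytorch_version="1.7.1",
    py_version="py36",
    role=ROLE,
)
predictor = huggingface_model.deploy(initial_instance_count=1, instance_type="ml.g4dn.xlarge")
```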
12 | 13 | ## Requirements 14 | 15 | * sagemaker version: 2.78.0 16 | * transformers version: 4.6.1 17 | * datasets version: 1.18.4 18 | * s3fs version: 2022.02.0 19 | * pytorch version: 1.7.1 20 | * jiwer 21 | * soundfile 22 | 23 | ## Security 24 | 25 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 26 | 27 | ## License 28 | 29 | This library is licensed under the MIT-0 License. See the LICENSE file. 30 | 31 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 | 
52 | 
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 | 
56 | 
57 | ## Licensing
58 | 
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 | 
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | from transformers import (
2 |     Wav2Vec2ForCTC,
3 |     Trainer,
4 |     TrainingArguments,
5 |     Wav2Vec2CTCTokenizer,
6 |     Wav2Vec2FeatureExtractor,
7 |     Wav2Vec2Processor)
8 | from datasets import load_from_disk, load_metric
9 | from dataclasses import dataclass
10 | from typing import Dict, List, Optional, Union
11 | import logging
12 | import sys
13 | import argparse
14 | import os
15 | import torch
16 | import numpy as np
17 | import boto3
18 | 
19 | 
20 | if __name__ == "__main__":
21 | 
22 |     parser = argparse.ArgumentParser()
23 | 
24 |     # hyperparameters sent by the client are passed as command-line arguments to the script
25 |     parser.add_argument("--epochs", type=int, default=10)
26 |     parser.add_argument("--train_batch_size", type=int, default=8)
27 |     parser.add_argument("--eval_batch_size", type=int, default=8)
28 |     parser.add_argument("--warmup_steps", type=int, default=500)
29 |     parser.add_argument("--model_name", type=str, default="facebook/wav2vec2-base")
30 |     parser.add_argument("--learning_rate", type=float, default=1e-4)
31 |     parser.add_argument("--weight_decay", type=float, default=0.005)
32 |     parser.add_argument("--vocab_url", type=str)
33 | 
34 |     # Data, model, and output directories
35 |     parser.add_argument("--output-data-dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"])
36 |     parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
37 |     parser.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"])
38 |     parser.add_argument("--training_dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"])
39 |     parser.add_argument("--test_dir", type=str, default=os.environ["SM_CHANNEL_TEST"])
40 | 
41 |     args, _ = parser.parse_known_args()
42 | 
43 |     # set up logging
44 |     logger = logging.getLogger(__name__)
45 | 
46 |     logging.basicConfig(
47 |         level=logging.getLevelName("INFO"),
48 |         handlers=[logging.StreamHandler(sys.stdout)],
49 |         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
50 |     )
51 | 
52 |     # load datasets
53 |     logger.info(args.training_dir)
54 |     logger.info(args.test_dir)
55 |     train_dataset = load_from_disk(args.training_dir)
56 |     test_dataset = load_from_disk(args.test_dir)
57 | 
58 |     logger.info(f" loaded train_dataset length is: {len(train_dataset)}")
59 |     logger.info(f" loaded test_dataset length is: {len(test_dataset)}")
60 | 
61 | 
62 |     # download the vocabulary file built in the notebook, then build the tokenizer, feature extractor and processor
63 |     s3 = boto3.client('s3')
64 |     vocab_url = args.vocab_url
65 |     s3.download_file(vocab_url.split('/')[2], '/'.join(vocab_url.split('/')[3:]), 'vocab.json')
66 | 
67 |     tokenizer = Wav2Vec2CTCTokenizer("vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
68 | 
69 |     feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
70 | 
71 |     processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
72 | 
73 |     # The data collator below dynamically pads the inputs, which speeds up data processing on CPU/GPU. The code is from:
74 |     # https://github.com/huggingface/transformers/blob/9a06b6b11bdfc42eea08fa91d0c737d1863c99e3/examples/research_projects/wav2vec2/run_asr.py#L81
75 |     @dataclass
76 |     class DataCollatorCTCWithPadding:
77 |         """
78 |         Data collator that will dynamically pad the inputs received.
79 |         Args:
80 |             processor (:class:`~transformers.Wav2Vec2Processor`)
81 |                 The processor used for processing the data.
82 |             padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
83 |                 Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
84 |                 among:
85 |                 * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
86 |                   sequence is provided).
87 |                 * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
88 |                   maximum acceptable input length for the model if that argument is not provided.
89 |                 * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
90 |                   different lengths).
91 |             max_length (:obj:`int`, `optional`):
92 |                 Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
93 |             max_length_labels (:obj:`int`, `optional`):
94 |                 Maximum length of the ``labels`` returned list and optionally padding length (see above).
95 |             pad_to_multiple_of (:obj:`int`, `optional`):
96 |                 If set will pad the sequence to a multiple of the provided value.
97 |                 This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
98 |                 7.5 (Volta).
99 |         """
100 |         processor: Wav2Vec2Processor
101 |         padding: Union[bool, str] = True
102 |         max_length: Optional[int] = None
103 |         max_length_labels: Optional[int] = None
104 |         pad_to_multiple_of: Optional[int] = None
105 |         pad_to_multiple_of_labels: Optional[int] = None
106 | 
107 |         def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
108 |             # split inputs and labels since they have to be of different lengths and need
109 |             # different padding methods
110 |             input_features = [{"input_values": feature["input_values"]} for feature in features]
111 |             label_features = [{"input_ids": feature["labels"]} for feature in features]
112 | 
113 |             batch = self.processor.pad(
114 |                 input_features,
115 |                 padding=self.padding,
116 |                 max_length=self.max_length,
117 |                 pad_to_multiple_of=self.pad_to_multiple_of,
118 |                 return_tensors="pt",
119 |             )
120 |             with self.processor.as_target_processor():
121 |                 labels_batch = self.processor.pad(
122 |                     label_features,
123 |                     padding=self.padding,
124 |                     max_length=self.max_length_labels,
125 |                     pad_to_multiple_of=self.pad_to_multiple_of_labels,
126 |                     return_tensors="pt",
127 |                 )
128 | 
129 |             # replace padding with -100 to ignore loss correctly
130 |             labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
131 | 
132 |             batch["labels"] = labels
133 | 
134 |             return batch
135 | 
136 |     data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
137 | 
138 |     # compute metrics
139 |     wer_metric = load_metric("wer")
140 | 
141 |     def compute_metrics(pred):
142 |         pred_logits = pred.predictions
143 |         pred_ids = np.argmax(pred_logits, axis=-1)
144 | 
145 |         pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
146 | 
147 |         pred_str = processor.batch_decode(pred_ids)
148 |         label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
149 | 
150 |         wer = wer_metric.compute(predictions=pred_str, references=label_str)
151 | 
152 |         return {"wer": wer}
153 | 
154 | 
155 |     model = Wav2Vec2ForCTC.from_pretrained(
156 |         args.model_name,
157 |         gradient_checkpointing=True,
158 |         ctc_loss_reduction="mean",
159 |         pad_token_id=processor.tokenizer.pad_token_id,
160 |     )
161 |     model.freeze_feature_extractor()
162 | 
163 |     # define training args
164 |     training_args = TrainingArguments(
165 |         output_dir=args.model_dir,
166 |         group_by_length=True,
167 |         per_device_train_batch_size=args.train_batch_size,
168 |         per_device_eval_batch_size=args.eval_batch_size,
169 |         evaluation_strategy="steps",
170 |         num_train_epochs=args.epochs,
171 |         fp16=True,  # enable mixed-precision training
172 |         save_steps=50,
173 |         eval_steps=50,
174 |         logging_steps=50,
175 |         learning_rate=float(args.learning_rate),
176 |         weight_decay=float(args.weight_decay),
177 |         warmup_steps=args.warmup_steps,
178 |         save_total_limit=2,
179 |         logging_dir=f"{args.output_data_dir}/logs",
180 |     )
181 | 
182 |     # create Trainer instance
183 |     trainer = Trainer(
184 |         model=model,
185 |         data_collator=data_collator,
186 |         args=training_args,
187 |         compute_metrics=compute_metrics,
188 |         train_dataset=train_dataset,
189 |         eval_dataset=test_dataset,
190 |         tokenizer=processor.feature_extractor,
191 |     )
192 | 
193 |     # train model
194 |     trainer.train()
195 | 
196 |     # evaluate model
197 |     eval_result = trainer.evaluate(eval_dataset=test_dataset)
198 | 
199 |     # write the eval results to a file that can be accessed later in the S3 output
200 |     with open(os.path.join(args.output_data_dir, "eval_results.txt"), "w") as writer:
201 |         print("***** Eval results *****")
202 |         for key, value in sorted(eval_result.items()):
203 |             writer.write(f"{key} = {value}\n")
204 | 
205 |     # save the model artifacts; SageMaker uploads the contents of model_dir to S3 after training
206 |     trainer.save_model(args.model_dir)
207 |     tokenizer.save_pretrained(args.model_dir)
208 | 
209 | 
--------------------------------------------------------------------------------
/sagemaker-fine-tune-wav2vec2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Fine-tune and deploy Wav2Vec2 model for speech recognition with Hugging Face and SageMaker\n",
8 |     "\n",
9 |     "## Background\n",
10 |     "\n",
11 |     "Wav2Vec2 is a transformer-based architecture for ASR tasks that was released in September 2020. We show its simplified architecture diagram below. For more details, see the [original paper](https://arxiv.org/abs/2006.11477). The model is composed of a multi-layer convolutional network (CNN) that acts as a feature extractor: it takes the input audio signal and outputs audio representations, also considered as features. These are fed into a transformer network to generate contextualized representations. This part of the training can be self-supervised, meaning that the transformer can be trained on large amounts of unlabeled speech and learn from it. The model is then fine-tuned on labeled data with the Connectionist Temporal Classification (CTC) algorithm for specific ASR tasks. The base model we use in this post is [Wav2Vec2-Base-960h](https://huggingface.co/facebook/wav2vec2-base-960h), which is fine-tuned on 960 hours of Librispeech 16kHz sampled speech audio. \n",
12 |     "\n",
13 |     "\n",
14 |     "Connectionist Temporal Classification (CTC) is a character-based algorithm. During training, it is able to demarcate each character of the transcription in the speech automatically, so no time-frame alignment between the audio signal and the transcription is required. For example, if an audio clip says “Hello World”, we don’t need to know in which second the word “hello” is located. This saves a lot of labeling effort for ASR use cases. If you are interested in how the algorithm works underneath, see [this article](https://distill.pub/2017/ctc/) for more information. \n",
15 |     "\n",
16 |     "\n",
17 |     "## Notebook Overview \n",
18 |     "\n",
19 |     "In this notebook, we use the [SUPERB \n",
20 |     "(Speech processing Universal PERformance Benchmark) dataset](https://huggingface.co/datasets/superb) available from the Hugging Face Datasets library, fine-tune the Wav2Vec2 model, and deploy it as a SageMaker endpoint for real-time inference on an ASR task. \n",
21 |     "\n",
22 |     "\n",
23 |     "First, we show how to load and preprocess the SUPERB dataset in a SageMaker environment in order to obtain the tokenizer and feature extractor that are required for fine-tuning the Wav2Vec2 model. Then we use SageMaker Script Mode for the training and inference steps, which lets you define and use custom training and inference scripts while SageMaker provides supported Hugging Face framework Docker containers. For more information about training and serving Hugging Face models on SageMaker, see [Use Hugging Face with Amazon SageMaker](https://docs.aws.amazon.com/sagemaker/latest/dg/hugging-face.html). This functionality is available through the Hugging Face [AWS Deep Learning Container (DLC)](https://docs.aws.amazon.com/deep-learning-containers/latest/devguide/what-is-dlc.html). \n",
24 |     "\n",
25 |     "This notebook has been tested in both SageMaker Studio and SageMaker Notebook environments. 
Below shows detailed setup. \n", 26 | "- SageMaker Studio: **ml.m5.xlarge** instance with **Data Science** kernel.\n", 27 | "- SageMaker Notebook: **ml.m5.xlarge** instance with **conda_python3** kernel. \n" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "## Set up \n", 35 | "First, install the dependencies." 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": { 42 | "scrolled": true 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "!pip install sagemaker --upgrade\n", 47 | "!pip install \"transformers>=4.4.2\" \n", 48 | "!pip install s3fs --upgrade\n", 49 | "!pip install datasets --upgrade \n", 50 | "!pip install librosa\n", 51 | "!pip install torch # framework is required for transformer " 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "**soundfile** library will be used to read raw audio files and convert them into arrays. Before installing **soundfile** python library, package **libsndfile** needs to be installed. " 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "!conda install -c conda-forge libsndfile -y\n", 68 | "!pip install soundfile" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "Following let's import common python libraries. Create a S3 bucket in AWS console for this project, and replace **[BUCKET_NAME]** with your bucket. \n", 76 | "Get the execution role which allows training and servering jobs to access your data. " 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "import json\n", 86 | "import time\n", 87 | "import boto3\n", 88 | "import numpy as np\n", 89 | "import random\n", 90 | "import soundfile \n", 91 | "import sagemaker\n", 92 | "import sagemaker.huggingface\n", 93 | "\n", 94 | "BUCKET=\"[BUCKET_NAME]\" # please use your bucket name\n", 95 | "PREFIX = \"huggingface-blog\" \n", 96 | "ROLE = sagemaker.get_execution_role()\n", 97 | "sess = sagemaker.Session(default_bucket=BUCKET)\n", 98 | "\n", 99 | "print(f\"sagemaker role arn: {ROLE}\")\n", 100 | "print(f\"sagemaker bucket: {sess.default_bucket()}\")\n", 101 | "print(f\"sagemaker session region: {sess.boto_region_name}\")" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "## Data Pre-processing\n", 109 | "We are using SUPERB dataset for this notebook, which can be loaded from Hugging Face [dataset library](https://huggingface.co/datasets/superb) directly using `load_dataset` function. SUPERB is a leaderboard to benchmark the performance of a shared model across a wide range of speech processing tasks with minimal architecture changes and labeled data. It also includes speaker_id and chapter_id etc., these columns are removed from the dataset, and we only keep audio files and transcriptions to fine-tune the Wav2Vec2 model for an audio recognition task, which transcribes speech to text. " 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "from datasets import load_dataset, DatasetDict\n", 119 | "data = load_dataset(\"superb\", 'asr', ignore_verifications=True) \n", 120 | "data = data.remove_columns(['speaker_id', 'chapter_id', 'id'])\n", 121 | "# reduce the data volume for this example. 
only take the test data from the original dataset for fine-tune\n", 122 | "data = data['test'] \n", 123 | "\n", 124 | "train_test = data.train_test_split(test_size=0.2)\n", 125 | "dataset = DatasetDict({\n", 126 | " 'train': train_test['train'],\n", 127 | " 'test': train_test['test']})\n", 128 | "\n", 129 | "# helper function to remove special characters and convert texts to lower case\n", 130 | "def remove_special_characters(batch):\n", 131 | " import re\n", 132 | " chars_to_ignore_regex = '[\\,\\?\\.\\!\\-\\;\\:\\\"]'\n", 133 | " \n", 134 | " batch[\"text\"] = re.sub(chars_to_ignore_regex, '', batch[\"text\"]).lower()\n", 135 | " return batch\n", 136 | "\n", 137 | "dataset = dataset.map(remove_special_characters)\n", 138 | "print(dataset)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "dataset['train'][0]" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "### Build vocabulary file \n", 155 | "Wav2Vec2 model is using [CTC](https://en.wikipedia.org/wiki/Connectionist_temporal_classification) algorithm to train deep neural networks in sequence problems, and its output is a single letter or blank. It uses a character-based tokenizer. Hence, we extract distinct letters from the dataset and build the vocabulary file. " 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "def extract_characters(batch):\n", 165 | " texts = \" \".join(batch[\"text\"])\n", 166 | " vocab = list(set(texts))\n", 167 | " return {\"vocab\": [vocab], \"texts\": [texts]}\n", 168 | "\n", 169 | "vocabs = dataset.map(extract_characters, batched=True, batch_size=-1, \n", 170 | " keep_in_memory=True, remove_columns=dataset.column_names[\"train\"])\n", 171 | "\n", 172 | "vocab_list = list(set(vocabs[\"train\"][\"vocab\"][0]) | set(vocabs[\"test\"][\"vocab\"][0]))\n", 173 | "\n", 174 | "vocab_dict = {v: k for k, v in enumerate(vocab_list)}\n", 175 | "\n", 176 | "vocab_dict[\"|\"] = vocab_dict[\" \"]\n", 177 | "del vocab_dict[\" \"]\n", 178 | "\n", 179 | "vocab_dict[\"[UNK]\"] = len(vocab_dict) # add \"unknown\" token \n", 180 | "vocab_dict[\"[PAD]\"] = len(vocab_dict) # add a padding token that corresponds to CTC's \"blank token\"\n", 181 | "\n", 182 | "with open('vocab.json', 'w') as vocab_file:\n", 183 | " json.dump(vocab_dict, vocab_file)\n", 184 | " \n", 185 | "# vocab.json file will be used in training container, hence upload it to s3 bucket for later steps \n", 186 | "s3 = boto3.client('s3')\n", 187 | "s3.upload_file('vocab.json', BUCKET, f'{PREFIX}/vocab.json')" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "### Create tokenizer with vocabulary file and feature extractor \n", 195 | "Wav2Vec2 model contains tokenizer and feature extractor. We can use vocab.json that created from previous step to create the Wav2Vec2CTCTokenizer. Wav2Vec2FeatureExtractor is to make sure that the dataset used in fine-tune has the same audio sampling rate as the dataset used for pretraining. 
Finally, create a Wav2Vec2 processor can wrap the feature extractor and the tokenizer into one single processor.\n" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "from transformers import Wav2Vec2CTCTokenizer,Wav2Vec2FeatureExtractor, Wav2Vec2Processor\n", 205 | "\n", 206 | "# create Wav2Vec2 tokenizer\n", 207 | "tokenizer = Wav2Vec2CTCTokenizer(\"vocab.json\", unk_token=\"[UNK]\", pad_token=\"[PAD]\", word_delimiter_token=\"|\")\n", 208 | "\n", 209 | "# create Wav2Vec2 feature extractor\n", 210 | "feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, \n", 211 | " padding_value=0.0, do_normalize=True, return_attention_mask=False)\n", 212 | "# create a processor pipeline \n", 213 | "processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "### Prepare train and test datasets" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "def extract_array_samplingrate(batch):\n", 230 | " batch[\"speech\"] = batch['audio']['array'].tolist()\n", 231 | " batch[\"sampling_rate\"] = batch['audio']['sampling_rate']\n", 232 | " batch[\"target_text\"] = batch[\"text\"]\n", 233 | " return batch\n", 234 | "\n", 235 | "dataset = dataset.map(extract_array_samplingrate, remove_columns=dataset.column_names[\"train\"])" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "dataset" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "# check one audio file from the training dataset\n", 254 | "import IPython.display as ipd\n", 255 | "\n", 256 | "rand_int = random.randint(0, len(dataset[\"train\"]))\n", 257 | "print(dataset[\"train\"][rand_int][\"target_text\"])\n", 258 | "ipd.Audio(data=np.asarray(dataset[\"train\"][rand_int][\"speech\"]), autoplay=True, rate=16000)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "# process the dataset with processor pipeline that created above\n", 268 | "def process_dataset(batch): \n", 269 | " batch[\"input_values\"] = processor(batch[\"speech\"], sampling_rate=batch[\"sampling_rate\"][0]).input_values\n", 270 | "\n", 271 | " with processor.as_target_processor():\n", 272 | " batch[\"labels\"] = processor(batch[\"target_text\"]).input_ids\n", 273 | " return batch\n", 274 | "\n", 275 | "data_processed = dataset.map(process_dataset, remove_columns=dataset.column_names[\"train\"], batch_size=8, batched=True)\n", 276 | "\n", 277 | "train_dataset = data_processed['train']\n", 278 | "test_dataset = data_processed['test']" 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "metadata": {}, 284 | "source": [ 285 | "Next we upload train and test data to S3. 
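The cell below uses `datasets.filesystems.S3FileSystem` and the `fs=` argument, which match the `datasets` 1.18 / `s3fs` versions pinned for this example. If you are on a newer `datasets` release (roughly 2.8 or later), the `fs=` argument is deprecated and `save_to_disk` accepts an S3 URI directly (with `s3fs` installed), so a rough, version-dependent sketch would instead look like this:\n",
    "\n",
    "```python\n",
    "# sketch for newer datasets releases only; the cell below is the version used in this notebook\n",
    "train_dataset.save_to_disk(f\"s3://{BUCKET}/{PREFIX}/train\")\n",
    "test_dataset.save_to_disk(f\"s3://{BUCKET}/{PREFIX}/test\")\n",
    "```\n",
    "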
" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [ 294 | "from datasets.filesystems import S3FileSystem\n", 295 | "s3 = S3FileSystem()\n", 296 | "\n", 297 | "# save train_dataset to s3\n", 298 | "training_input_path = f's3://{BUCKET}/{PREFIX}/train'\n", 299 | "train_dataset.save_to_disk(training_input_path,fs=s3)\n", 300 | "\n", 301 | "# save test_dataset to s3\n", 302 | "test_input_path = f's3://{BUCKET}/{PREFIX}/test'\n", 303 | "test_dataset.save_to_disk(test_input_path,fs=s3)" 304 | ] 305 | }, 306 | { 307 | "cell_type": "markdown", 308 | "metadata": {}, 309 | "source": [ 310 | "## Fine-tune the HuggingFace model (Wav2Vec2)" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": {}, 316 | "source": [ 317 | "### Training script\n", 318 | "\n", 319 | "Here we are using SageMaker HuggingFace DLC (Deep Learning Container) script mode to construct the training and inference job, which allows you to write custom trianing and serving code and using HuggingFace framework containers that maintained and supported by AWS. \n", 320 | "\n", 321 | "When we create a training job using the script mode, the `entry_point` script, hyperparameters, its dependencies (inside requirements.txt) and input data (train and test datasets) will be copied into the container. Then it invokes the `entry_point` training script, where the train and test datasets will be loaded, training steps will be executed and model artifacts will be saved in `/opt/ml/model` in the container. After training, artifacts in this directory are uploaded to S3 for later model hosting.\n", 322 | "\n", 323 | "This script is saved in directory `scripts`, and you can inspect the training script by running the next cell. " 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": null, 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "!pygmentize scripts/train.py" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "metadata": {}, 338 | "source": [ 339 | "### Creating an Estimator and start a training job\n", 340 | "\n", 341 | "Worth to highlight that, when you create a Hugging Face Estimator, you can configure hyperparameters and provide a custom parameter into the training script, such as `vocab_url` in this example. Also you can specify the metrics in the Estimator, and parse the logs of metrics and send them to CloudWatch to monitor and track the training performance. " 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "from sagemaker.huggingface import HuggingFace\n", 351 | "\n", 352 | "#create an unique id to tag training job, model name and endpoint name. 
\n",
353 |     "id = int(time.time())\n",
354 |     "\n",
355 |     "TRAINING_JOB_NAME = f\"huggingface-wav2vec2-training-{id}\"\n",
356 |     "print('Training job name: ', TRAINING_JOB_NAME)\n",
357 |     "\n",
358 |     "vocab_url = f\"s3://{BUCKET}/{PREFIX}/vocab.json\"\n",
359 |     "hyperparameters = {'epochs': 10,  # you can increase the epoch number to improve model accuracy\n",
360 |     "                   'train_batch_size': 8,\n",
361 |     "                   'model_name': \"facebook/wav2vec2-base\",\n",
362 |     "                   'vocab_url': vocab_url\n",
363 |     "                   }\n",
364 |     "\n",
365 |     "# define metric definitions\n",
366 |     "metric_definitions=[\n",
367 |     "    {'Name': 'eval_loss', 'Regex': \"'eval_loss': ([0-9]+(.|e\\\-)[0-9]+),?\"},\n",
368 |     "    {'Name': 'eval_wer', 'Regex': \"'eval_wer': ([0-9]+(.|e\\\-)[0-9]+),?\"},\n",
369 |     "    {'Name': 'eval_runtime', 'Regex': \"'eval_runtime': ([0-9]+(.|e\\\-)[0-9]+),?\"},\n",
370 |     "    {'Name': 'eval_samples_per_second', 'Regex': \"'eval_samples_per_second': ([0-9]+(.|e\\\-)[0-9]+),?\"},\n",
371 |     "    {'Name': 'epoch', 'Regex': \"'epoch': ([0-9]+(.|e\\\-)[0-9]+),?\"}]"
372 |    ]
373 |   },
374 |   {
375 |    "cell_type": "markdown",
376 |    "metadata": {},
377 |    "source": [
378 |     "We use the [HuggingFace estimator class](https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/sagemaker.huggingface.html) to train our model. When creating the estimator, the following parameters need to be specified. \n",
379 |     "\n",
380 |     "* **entry_point**: the name of the training script. It loads data from the input channels, configures training with hyperparameters, trains a model, and saves the model. \n",
381 |     "* **source_dir**: the location of the training scripts. \n",
382 |     "* **transformers_version**: the Hugging Face transformers library version we want to use.\n",
383 |     "* **pytorch_version**: the PyTorch version that is compatible with the transformers library. \n",
384 |     "\n",
385 |     "**Instance Selection**: For this use case and dataset, we use one ml.p3.2xlarge instance, and the training job is able to finish within two hours. You can select a more powerful instance to reduce the training time; however, it will incur more cost. "
386 |    ]
387 |   },
388 |   {
389 |    "cell_type": "code",
390 |    "execution_count": null,
391 |    "metadata": {},
392 |    "outputs": [],
393 |    "source": [
394 |     "OUTPUT_PATH = f's3://{BUCKET}/{PREFIX}/{TRAINING_JOB_NAME}/output/'\n",
395 |     "\n",
396 |     "huggingface_estimator = HuggingFace(entry_point='train.py',\n",
397 |     "                                    source_dir='./scripts',\n",
398 |     "                                    output_path=OUTPUT_PATH,\n",
399 |     "                                    instance_type='ml.p3.2xlarge',\n",
400 |     "                                    instance_count=1,\n",
401 |     "                                    transformers_version='4.6.1',\n",
402 |     "                                    pytorch_version='1.7.1',\n",
403 |     "                                    py_version='py36',\n",
404 |     "                                    role=ROLE,\n",
405 |     "                                    hyperparameters=hyperparameters,\n",
406 |     "                                    metric_definitions=metric_definitions,\n",
407 |     "                                    )\n",
408 |     "\n",
409 |     "# start the training job using the fit function; training takes approximately 2 hours to complete\n",
410 |     "huggingface_estimator.fit({'train': training_input_path, 'test': test_input_path},\n",
411 |     "                          job_name=TRAINING_JOB_NAME)"
412 |    ]
413 |   },
414 |   {
415 |    "cell_type": "markdown",
416 |    "metadata": {},
417 |    "source": [
418 |     "From the training logs you can see that, after 10 epochs of training, the model evaluation metric WER (word error rate) reaches around 0.32 on this subset of the SUPERB dataset. You can increase the number of epochs or use the full dataset to improve the model further. 
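To pull the metrics captured by the `metric_definitions` above into a dataframe for a quick look, a sketch along the following lines should work; it assumes the training job has already emitted evaluation metrics, and you can adjust `metric_names` as needed.\n",
    "\n",
    "```python\n",
    "from sagemaker.analytics import TrainingJobAnalytics\n",
    "\n",
    "# pull the CloudWatch metrics emitted by the training job defined above\n",
    "metrics_df = TrainingJobAnalytics(\n",
    "    training_job_name=TRAINING_JOB_NAME,\n",
    "    metric_names=[\"eval_loss\", \"eval_wer\"],\n",
    ").dataframe()\n",
    "metrics_df.head()\n",
    "```\n",
    "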
" 419 | ] 420 | }, 421 | { 422 | "cell_type": "markdown", 423 | "metadata": {}, 424 | "source": [ 425 | "## Deploy the model as endpoint on SageMaker and inference the model" 426 | ] 427 | }, 428 | { 429 | "cell_type": "markdown", 430 | "metadata": {}, 431 | "source": [ 432 | "### Inference script\n", 433 | "\n", 434 | "We are using [SageMaker HuggingFace inference tool kit](https://github.com/aws/sagemaker-huggingface-inference-toolkit) to host our fine-tuned model. It provides default functions of preprocessing, predict and postprocessing for certain tasks. However, the default capabilities are not able to inference our model properly. Hence, we defined below functions in `inference.py` script to override the default settings with custom requirements.\n", 435 | "\n", 436 | "* `model_fn(model_dir)`: overrides the default method for loading the model, the return value model will be used in the predict() for predicitions. It receives argument the model_dir, the path to your unzipped model.tar.gz.\n", 437 | "* `input_fn(input_data, content_type)`: overrides the default method for prerprocessing, the return value data will be used in the predict() method for predicitions. The input is input_data, the raw body of your request and content_type, the content type form the request Header.\n", 438 | "* `predict_fn(processed_data, model)`: overrides the default method for predictions, the return value predictions will be used in the postprocess() method. The input is processed_data, the result of the preprocess() method.\n", 439 | "* `output_fn(prediction, accept)`: overrides the default method for postprocessing, the return value result will be the respond of your request(e.g.JSON). The inputs are predictions, the result of the predict() method and accept the return accept type from the HTTP Request, e.g. application/json\n", 440 | "\n", 441 | "**Note**: Inference tool kit can inference tasks from architectures that ending with: 'TapasForQuestionAnswering', 'ForQuestionAnswering', 'ForTokenClassification', 'ForSequenceClassification', 'ForMultipleChoice', 'ForMaskedLM', 'ForCausalLM', 'ForConditionalGeneration', 'MTModel', 'EncoderDecoderModel', 'GPT2LMHeadModel', 'T5WithLMHeadModel' as of Jan2022. \n", 442 | "\n", 443 | "This script is saved in directory `scripts`, you can inspect the inference script by running the next cell. " 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": null, 449 | "metadata": {}, 450 | "outputs": [], 451 | "source": [ 452 | "!pygmentize scripts/inference.py" 453 | ] 454 | }, 455 | { 456 | "cell_type": "markdown", 457 | "metadata": {}, 458 | "source": [ 459 | "### Create a HuggingFaceModel from the estimator " 460 | ] 461 | }, 462 | { 463 | "cell_type": "markdown", 464 | "metadata": {}, 465 | "source": [ 466 | "We use the [HuggingFaceModel class](https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/sagemaker.huggingface.html#hugging-face-model) to create a model object, which can be deployed to a SageMaker endpoint. When creating the model, the following parameters need to specify. \n", 467 | "\n", 468 | "* **entry_point**: the name of the inference script. The methods defined in the inference script will be implemented to the endpoint. \n", 469 | "* **source_dir**: the location of the inference scripts. \n", 470 | "* **transformers_version**: the Hugging Face transformers library version we want to use. It should be consistent with training step. 
471 |     "* **pytorch_version**: the PyTorch version that is compatible with the transformers library. It should be consistent with the training step.\n",
472 |     "* **model_data**: the Amazon S3 location of a SageMaker model data `.tar.gz` file\n"
473 |    ]
474 |   },
475 |   {
476 |    "cell_type": "code",
477 |    "execution_count": null,
478 |    "metadata": {},
479 |    "outputs": [],
480 |    "source": [
481 |     "from sagemaker.huggingface import HuggingFaceModel\n",
482 |     "\n",
483 |     "huggingface_model = HuggingFaceModel(\n",
484 |     "    entry_point='inference.py',\n",
485 |     "    source_dir='./scripts',\n",
486 |     "    name=f'huggingface-wav2vec2-model-{id}',\n",
487 |     "    transformers_version='4.6.1',\n",
488 |     "    pytorch_version='1.7.1',\n",
489 |     "    py_version='py36',\n",
490 |     "    model_data=huggingface_estimator.model_data,\n",
491 |     "    role=ROLE,\n",
492 |     ")"
493 |    ]
494 |   },
495 |   {
496 |    "cell_type": "markdown",
497 |    "metadata": {},
498 |    "source": [
499 |     "### Deploy the model on an endpoint \n",
500 |     "\n",
501 |     "Next, we create a predictor by using the `model.deploy` function. You can change the instance count and instance type based on your performance requirements. "
502 |    ]
503 |   },
504 |   {
505 |    "cell_type": "code",
506 |    "execution_count": null,
507 |    "metadata": {},
508 |    "outputs": [],
509 |    "source": [
510 |     "predictor = huggingface_model.deploy(\n",
511 |     "    initial_instance_count=1,\n",
512 |     "    instance_type=\"ml.g4dn.xlarge\",\n",
513 |     "    endpoint_name=f'huggingface-wav2vec2-endpoint-{id}'\n",
514 |     ")"
515 |    ]
516 |   },
517 |   {
518 |    "cell_type": "markdown",
519 |    "metadata": {},
520 |    "source": [
521 |     "### Inference audio files \n",
522 |     "\n",
523 |     "After the endpoint is deployed, you can run the prediction tests below to check the model performance. "
524 |    ]
525 |   },
526 |   {
527 |    "cell_type": "code",
528 |    "execution_count": null,
529 |    "metadata": {},
530 |    "outputs": [],
531 |    "source": [
532 |     "# run inference on an audio file downloaded from an S3 bucket, or on a local audio file \n",
533 |     "import soundfile\n",
534 |     "\n",
535 |     "# s3.download_file(BUCKET, 'huggingface-blog/sample_audio/xxxxxx.wav', 'downloaded.wav')\n",
536 |     "# file_name ='downloaded.wav'\n",
537 |     "\n",
538 |     "# download a sample audio file by using the link below\n",
539 |     "!wget https://datashare.ed.ac.uk/bitstream/handle/10283/343/MKH800_19_0001.wav\n",
540 |     " \n",
541 |     "file_name = 'MKH800_19_0001.wav'\n",
542 |     "\n",
543 |     "speech_array, sampling_rate = soundfile.read(file_name)\n",
544 |     "\n",
545 |     "ipd.Audio(data=np.asarray(speech_array), autoplay=True, rate=16000)"
546 |    ]
547 |   },
548 |   {
549 |    "cell_type": "code",
550 |    "execution_count": null,
551 |    "metadata": {},
552 |    "outputs": [],
553 |    "source": [
554 |     "%%time\n",
555 |     "json_request_data = {\"speech_array\": speech_array.tolist(),\n",
556 |     "                     \"sampling_rate\": sampling_rate}\n",
557 |     "\n",
558 |     "prediction = predictor.predict(json_request_data)\n",
559 |     "print(prediction)"
560 |    ]
561 |   },
562 |   {
563 |    "cell_type": "markdown",
564 |    "metadata": {},
565 |    "source": [
566 |     "**Please note**: as we are using a real-time inference endpoint, the maximum payload size is 6 MB. If you see an error message like \"Received client error (413) from primary and could not load the entire response body\", please use the code below to check your payload size. 
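If the payload is too large, one option is to downsample long clips to 16 kHz mono (the rate the model expects) or trim them before invoking the endpoint. The snippet below is a rough sketch using librosa, which was installed in the setup step; for very long recordings you may prefer to split the audio into shorter chunks instead.\n",
    "\n",
    "```python\n",
    "import librosa\n",
    "\n",
    "# re-load the audio at 16 kHz mono; this usually shrinks the payload when the\n",
    "# source file has a higher sample rate or multiple channels\n",
    "speech_array, sampling_rate = librosa.load(file_name, sr=16000, mono=True)\n",
    "```\n",
    "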
" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": null, 572 | "metadata": {}, 573 | "outputs": [], 574 | "source": [ 575 | "import sys\n", 576 | "sys.getsizeof(speech_array) " 577 | ] 578 | }, 579 | { 580 | "cell_type": "markdown", 581 | "metadata": {}, 582 | "source": [ 583 | "## Cleanup\n", 584 | "\n", 585 | "Finally, please remember to delete the Amazon SageMaker endpoint to avoid charges:" 586 | ] 587 | }, 588 | { 589 | "cell_type": "code", 590 | "execution_count": null, 591 | "metadata": {}, 592 | "outputs": [], 593 | "source": [ 594 | "predictor.delete_endpoint()" 595 | ] 596 | } 597 | ], 598 | "metadata": { 599 | "instance_type": "ml.m5.xlarge", 600 | "kernelspec": { 601 | "display_name": "Python 3 (Data Science)", 602 | "language": "python", 603 | "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:eu-west-1:470317259841:image/datascience-1.0" 604 | }, 605 | "language_info": { 606 | "codemirror_mode": { 607 | "name": "ipython", 608 | "version": 3 609 | }, 610 | "file_extension": ".py", 611 | "mimetype": "text/x-python", 612 | "name": "python", 613 | "nbconvert_exporter": "python", 614 | "pygments_lexer": "ipython3", 615 | "version": "3.7.10" 616 | } 617 | }, 618 | "nbformat": 4, 619 | "nbformat_minor": 4 620 | } 621 | --------------------------------------------------------------------------------