├── README.md
├── images
│   └── pipeline.png
├── ray
│   ├── README.md
│   └── code
│       ├── microbenchmark.py
│       ├── ray-10G.py
│       ├── ray-10T.py
│       └── ray-300G.py
├── sagemaker
│   ├── README.md
│   └── code
│       ├── inference.ipynb
│       └── predict.py
└── spark
    ├── README.md
    └── code
        ├── microbenchmark.ipynb
        ├── torch-batch-inference-10G-s3-cpu-only.ipynb
        ├── torch-batch-inference-10G-s3-predict-only.ipynb
        ├── torch-batch-inference-10G-stage-level-scheduling.ipynb
        ├── torch-batch-inference-300G-s3-standard.ipynb
        ├── torch-batch-inference-s3-10G-single-node.ipynb
        ├── torch-batch-inference-s3-10G-standard-iterator-databricks-prefetch.ipynb
        ├── torch-batch-inference-s3-10G-standard-iterator.ipynb
        └── torch-batch-inference-s3-10G-standard.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # Batch Inference Benchmarking
2 | 
3 | This repo contains benchmarks for batch inference using Ray, Apache Spark, and Amazon SageMaker Batch Transform.
4 | 
5 | We use the image classification task from the [MLPerf Inference Benchmark suite](https://arxiv.org/pdf/1911.02549.pdf) in the offline setting.
6 | 
7 | - Images from [ImageNet 2012 Dataset](https://image-net.org/challenges/LSVRC/2012/2012-downloads.php#Images)
8 | - [PyTorch ResNet50 model](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html)
9 | 
10 | The workload is a simple 3-step pipeline:
11 | ![Pipeline](./images/pipeline.png)
12 | 
13 | Images are stored in parquet format (with ~1k images per parquet file) and are read from S3 from within the same region.
14 | 
15 | ## 10 GB
16 | The in-memory dataset size is 10 GB; the compressed on-disk size is much smaller.
17 | 
18 | All experiments used PyTorch v1.12.1 and CUDA 11.3.
19 | 
20 | ## Configurations
21 | 
22 | ### Ray
23 | - 1 `g4dn.12xlarge` instance. Contains 48 CPUs and 4 GPUs.
24 | - Experiments were all run on the [Anyscale platform](https://www.anyscale.com/).
25 | - Uses [Ray Data](https://docs.ray.io/en/latest/data/dataset.html) [nightly version](https://docs.ray.io/en/latest/ray-overview/installation.html#daily-releases-nightlies).
26 | - [Code](ray/code/ray-10G.py)
27 | 
28 | ### Spark
29 | We tried 2 configurations. All experiments were run on Databricks with the Databricks Runtime v12.0, and using the ML GPU runtime when applicable.
30 | 
31 | - **Config 1**: Creates a standard Databricks cluster with a `g4dn.12xlarge` instance.
32 |     - This starts a 2-node cluster: 1 node for the driver that does not run tasks, and 1 node for the executor. Databricks does not support running tasks on the head node.
33 |     - Spark fuses all stages together, so total parallelism, even for CPU tasks, is limited by the # of GPUs.
34 |     - [Code](spark/code/torch-batch-inference-s3-10G-standard.ipynb)
35 | 
36 | - **Config 2**: Uses 2 separate clusters: 1 CPU-only cluster for preprocessing, and 1 GPU cluster for predicting. We use DBFS to store the intermediate preprocessed data. This allows preprocessing to scale independently from prediction, at the cost of having to persist data in between the steps.
37 |     - **CPU cluster**: 1 `m6gd.12xlarge` instance with Photon acceleration enabled. This is the smallest `m6gd` instance that does not OOM.
38 |     - **GPU cluster**: 1 `g4dn.12xlarge` instance.
39 |     - [CPU Code](spark/code/torch-batch-inference-10G-s3-cpu-only.ipynb)
40 |     - [GPU Code](spark/code/torch-batch-inference-10G-s3-predict-only.ipynb)
41 | 
42 | Additional configurations that performed worse were also tried; you can read about them in the [spark directory](spark/README.md).
43 | 
44 | ### SageMaker Batch Transform
45 | SageMaker Batch Transform with 4 `g4dn.xlarge` instances. There is no built-in multi-GPU support, so we cannot use the multi-GPU `g4dn.12xlarge` instance. There are still 4 GPUs total in the cluster.
46 | 
47 | [Code](sagemaker/code/inference.ipynb)
48 | 
49 | Additional configurations that failed were also tried; you can read more about them in the [sagemaker directory](sagemaker/README.md).
50 | 
--------------------------------------------------------------------------------
/images/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amogkam/batch-inference-benchmarks/932b351670414e404546aa1a8e868108211cabc5/images/pipeline.png
--------------------------------------------------------------------------------
/ray/README.md:
--------------------------------------------------------------------------------
1 | # Batch Inference Benchmarking with Ray
2 | 
3 | This repo contains benchmarks for batch inference with [Ray Data](https://docs.ray.io/en/latest/data/dataset.html).
4 | 
5 | We use the image classification task from the [MLPerf Inference Benchmark suite](https://arxiv.org/pdf/1911.02549.pdf) in the offline setting.
6 | 
7 | - Images from ImageNet 2012 Dataset
8 | - ResNet50 model
9 | 
10 | The workload is a simple 3-step pipeline:
11 | ![Pipeline](../images/pipeline.png)
12 | 
13 | Images are saved in parquet format (with ~1k images per parquet file).
14 | 
15 | We tried three dataset sizes: 10 GB, 300 GB, and 10 TB. These sizes are for when the data is loaded in memory. The compressed on-disk size is much smaller.
16 | 
17 | We also run a microbenchmark to measure overhead from Ray Data.
18 | 
19 | All experiments are run in Anyscale using Ray 2.5.
20 | 
21 | ## 10 GB
22 | 10 GB dataset using a single-node cluster.
23 | 
24 | ### Configurations
25 | - 1 `g4dn.12xlarge` instance. Contains 48 CPUs and 4 GPUs.
26 | - Experiments were all run on the [Anyscale platform](https://www.anyscale.com/).
27 | - Uses [Ray Data](https://docs.ray.io/en/latest/data/dataset.html) [nightly version](https://docs.ray.io/en/latest/ray-overview/installation.html#daily-releases-nightlies).
28 | - [Code](code/ray-10G.py)
29 | 
30 | **Throughput**: 312.460 img/sec
31 | 
32 | ## 300 GB
33 | 
34 | We scale up to more nodes for inference on 300 GB of data, using 4 `g4dn.12xlarge` instances (16 GPUs in total).
35 | 
36 | [Code](code/ray-300G.py)
37 | 
38 | **Throughput**: 2658.314 img/sec
39 | 
40 | 
41 | ## 10 TB
42 | 
43 | We scale up to even more nodes for inference on 10 TB of data. Since deep learning workloads are often memory constrained, we use
44 | a heterogeneous cluster consisting of some GPU nodes and some CPU-only nodes to maximize throughput and GPU utilization.
45 | 
46 | The cluster consists of:
47 | - 10 `g4dn.12xlarge` instances
48 | - 10 `m5.16xlarge` instances
49 | 
50 | **Throughput**: 11580.958 img/sec
51 | 
52 | ## Microbenchmark
53 | Run a microbenchmark that reads from S3 and does a dummy preprocessing step with `time.sleep(1)`.
54 | 55 | [Full code is here](code/microbenchmark.py) 56 | 57 | We force execution of the read and cache the result before executing preprocessing to isolate just the preprocessing time. 58 | 59 | Preprocessing takes **4.383** seconds. 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /ray/code/microbenchmark.py: -------------------------------------------------------------------------------- 1 | import ray 2 | import time 3 | 4 | ds = ray.data.read_parquet("s3://air-example-data-2/10G-image-data-synthetic-raw-parquet/") 5 | ds = ds.materialize() 6 | 7 | def dummy_preprocess(image): 8 | time.sleep(1) 9 | return image 10 | 11 | ds = ds.map_batches(dummy_preprocess) 12 | 13 | start_time = time.time() 14 | ds.fully_executed() 15 | end_time = time.time() 16 | print(f"Preprocessing took: {end_time-start_time} seconds") -------------------------------------------------------------------------------- /ray/code/ray-10G.py: -------------------------------------------------------------------------------- 1 | import ray 2 | from torchvision.models import resnet50, ResNet50_Weights 3 | import torch 4 | import time 5 | from torchvision import transforms 6 | 7 | from ray.data import ActorPoolStrategy 8 | 9 | BATCH_SIZE = 1000 10 | 11 | model = resnet50(weights=ResNet50_Weights.DEFAULT) 12 | model_ref = ray.put(model) 13 | 14 | start_time = time.time() 15 | ds = ray.data.read_parquet("s3://air-example-data-2/10G-image-data-synthetic-raw-parquet/") 16 | 17 | def preprocess(image_batch): 18 | preprocess = transforms.Compose( 19 | [ 20 | transforms.Resize(256), 21 | transforms.CenterCrop(224), 22 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 23 | ] 24 | ) 25 | torch_tensor = torch.Tensor(image_batch["image"].transpose(0, 3, 1, 2)) 26 | preprocessed_images = preprocess(torch_tensor).numpy() 27 | return {"image": preprocessed_images} 28 | 29 | class Actor: 30 | def __init__(self, model): 31 | self.model = ray.get(model) 32 | self.model.eval() 33 | self.model.to("cuda") 34 | 35 | def __call__(self, batch): 36 | with torch.inference_mode(): 37 | output = self.model(torch.as_tensor(batch["image"], device="cuda")) 38 | return output.cpu().numpy() 39 | 40 | start_time_without_metadata_fetching = time.time() 41 | ds = ds.map_batches(preprocess, batch_format="numpy") 42 | ds = ds.map_batches(Actor, batch_size=BATCH_SIZE, compute=ActorPoolStrategy(size=4), num_gpus=1, batch_format="numpy", fn_constructor_kwargs={"model": model_ref}, max_concurrency=2) 43 | for _ in ds.iter_batches(batch_size=None, batch_format="pyarrow"): 44 | pass 45 | end_time = time.time() 46 | 47 | print("Total time: ", end_time-start_time) 48 | print("Throughput (img/sec): ", (16232)/(end_time-start_time)) 49 | print("Total time w/o metadata fetching (img/sec) : ", (end_time-start_time_without_metadata_fetching)) 50 | print("Throughput w/o metadata fetching (img/sec) ", (16232)/(end_time-start_time_without_metadata_fetching)) 51 | 52 | print(ds.stats()) -------------------------------------------------------------------------------- /ray/code/ray-10T.py: -------------------------------------------------------------------------------- 1 | import ray 2 | from ray.data.context import DataContext 3 | from torchvision.models import resnet50, ResNet50_Weights 4 | import torch 5 | import time 6 | from torchvision import transforms 7 | 8 | from ray.data import ActorPoolStrategy 9 | 10 | BATCH_SIZE = 1000 11 | 12 | model = resnet50(weights=ResNet50_Weights.DEFAULT) 13 | 
model_ref = ray.put(model) 14 | 15 | start_time = time.time() 16 | ds = ray.data.read_parquet("s3://air-example-data-2/10T-image-data-synthetic-raw-parquet/") 17 | 18 | def preprocess(image_batch): 19 | preprocess = transforms.Compose( 20 | [ 21 | transforms.Resize(256), 22 | transforms.CenterCrop(224), 23 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 24 | ] 25 | ) 26 | torch_tensor = torch.Tensor(image_batch["image"].transpose(0, 3, 1, 2)) 27 | preprocessed_images = preprocess(torch_tensor).numpy() 28 | return {"image": preprocessed_images} 29 | 30 | class Actor: 31 | def __init__(self, model): 32 | self.model = ray.get(model) 33 | self.model.eval() 34 | self.model.to("cuda") 35 | 36 | def __call__(self, batch): 37 | with torch.inference_mode(): 38 | output = self.model(torch.as_tensor(batch["image"], device="cuda")) 39 | return output.cpu().numpy() 40 | 41 | start_time_without_metadata_fetching = time.time() 42 | ds = ds.map_batches(preprocess, batch_format="numpy") 43 | ds = ds.map_batches(Actor, batch_size=BATCH_SIZE, compute=ActorPoolStrategy(size=40), num_gpus=1, batch_format="numpy", fn_constructor_kwargs={"model": model_ref}, max_concurrency=2) 44 | for _ in ds.iter_batches(batch_size=None, batch_format="pyarrow"): 45 | pass 46 | end_time = time.time() 47 | 48 | print("Total time: ", end_time-start_time) 49 | print("Throughput (img/sec): ", (16232000)/(end_time-start_time)) 50 | print("Total time w/o metadata fetching (img/sec) : ", (end_time-start_time_without_metadata_fetching)) 51 | print("Throughput w/o metadata fetching (img/sec) ", (16232000)/(end_time-start_time_without_metadata_fetching)) 52 | 53 | print(ds.stats()) 54 | -------------------------------------------------------------------------------- /ray/code/ray-300G.py: -------------------------------------------------------------------------------- 1 | import ray 2 | from torchvision.models import resnet50, ResNet50_Weights 3 | import torch 4 | import time 5 | from torchvision import transforms 6 | 7 | from ray.data import ActorPoolStrategy 8 | 9 | BATCH_SIZE = 1000 10 | 11 | model = resnet50(weights=ResNet50_Weights.DEFAULT) 12 | model_ref = ray.put(model) 13 | 14 | start_time = time.time() 15 | ds = ray.data.read_parquet("s3://air-example-data-2/300G-image-data-synthetic-raw-parquet/") 16 | 17 | def preprocess(image_batch): 18 | preprocess = transforms.Compose( 19 | [ 20 | transforms.Resize(256), 21 | transforms.CenterCrop(224), 22 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 23 | ] 24 | ) 25 | torch_tensor = torch.Tensor(image_batch["image"].transpose(0, 3, 1, 2)) 26 | preprocessed_images = preprocess(torch_tensor).numpy() 27 | return {"image": preprocessed_images} 28 | 29 | class Actor: 30 | def __init__(self, model): 31 | self.model = ray.get(model) 32 | self.model.eval() 33 | self.model.to("cuda") 34 | 35 | def __call__(self, batch): 36 | with torch.inference_mode(): 37 | output = self.model(torch.as_tensor(batch["image"], device="cuda")) 38 | return output.cpu().numpy() 39 | 40 | start_time_without_metadata_fetching = time.time() 41 | ds = ds.map_batches(preprocess, batch_format="numpy") 42 | ds = ds.map_batches(Actor, batch_size=BATCH_SIZE, compute=ActorPoolStrategy(size=16), num_gpus=1, batch_format="numpy", fn_constructor_kwargs={"model": model_ref}, max_concurrency=2) 43 | for _ in ds.iter_batches(batch_size=None, batch_format="pyarrow"): 44 | pass 45 | end_time = time.time() 46 | 47 | print("Total time: ", end_time-start_time) 48 | 
print("Throughput (img/sec): ", (488207)/(end_time-start_time)) 49 | print("Total time w/o metadata fetching (img/sec) : ", (end_time-start_time_without_metadata_fetching)) 50 | print("Throughput w/o metadata fetching (img/sec) ", (488207)/(end_time-start_time_without_metadata_fetching)) 51 | 52 | print(ds.stats()) -------------------------------------------------------------------------------- /sagemaker/README.md: -------------------------------------------------------------------------------- 1 | # Batch Inference Benchmarking with SageMaker Batch Transform 2 | SageMaker Batch Transform with 4 `g4dn.xlarge` instances. There is no built-in multi-GPU support, so we cannot use the multi-GPU `g4dn.12xlarge` instance. There are still 4 GPUs total in the cluster. 3 | 4 | [Code](sagemaker/code/inference-image.ipynb). Running this code will upload a pre-trained model to S3. It also packages the code in `predict.py` and runs it on the cluster to handle the logic for performing inference. 5 | 6 | 7 | ## Raw Images 8 | 9 | SageMaker Batch Transform reads raw images from S3 and sends them as individual HTTP requests to the cluster. Batching across multiple files is not supported 10 | 11 | > "SageMaker processes each input file separately. It doesn't combine mini-batches from different input files to comply with the MaxPayloadInMB limit." 12 | 13 | https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform.html. 14 | 15 | **Throughput**: 18.702 img/sec 16 | ## Parquet files 17 | 18 | Since SageMaker does not batch multiple image files together, this means our GPU is extremely underutilized. We tried an additional approach involving batching images together into parquet files beforehand and doing inference on these parquet files. 19 | 20 | However, we ran into a few issues: 21 | 1. The max payload size is 100 MB, which is far less than the ideal batch size to maximize GPUs 22 | 2. It's unclear how to actually parse the input request and read it as a parquet file. 
Even though code works locally, the job fails with an unhelpful error message, making it impossible to debug: 23 | 24 | ``` 25 | air-example-data-2/10G-image-data-synthetic-raw-parquet-120-partition/644f08f256a24362b744d6219523cd16_000097.parquet: ClientError: 413 26 | ``` -------------------------------------------------------------------------------- /sagemaker/code/predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torchvision import transforms 4 | from torchvision.models import resnet50 5 | 6 | import os 7 | import io 8 | from PIL import Image 9 | import numpy as np 10 | import pandas as pd 11 | import pyarrow as pa 12 | import pyarrow.parquet as pq 13 | 14 | def model_fn(model_dir): 15 | device = torch.device("cuda") 16 | model = resnet50() 17 | with open(os.path.join(model_dir, "model.ckpt"), "rb") as f: 18 | model.load_state_dict(torch.load(f)) 19 | model = model.to(device) 20 | model.eval() 21 | return model 22 | 23 | # https://stackoverflow.com/questions/62415237/aws-sagemaker-using-parquet-file-for-batch-transform-job 24 | def load_parquet_from_bytearray(request_body): 25 | image_as_bytes = io.BytesIO(request_body) 26 | reader = pa.BufferReader(image_as_bytes) 27 | df = pq.read_table(reader) 28 | batch_dim = len(df) 29 | numpy_batch = np.stack(df["image"].to_numpy()) 30 | reshaped_images = numpy_batch.reshape(batch_dim, 256, 256, 3).astype(float) 31 | torch_tensor = torch.Tensor(reshaped_images.transpose(0, 3, 1, 2)) 32 | 33 | preprocess = transforms.Compose( 34 | [ 35 | transforms.Resize(256), 36 | transforms.CenterCrop(224), 37 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 38 | ] 39 | ) 40 | preprocessed_images = preprocess(torch_tensor) 41 | return preprocessed_images 42 | 43 | 44 | def load_from_bytearray(request_body): 45 | image_as_bytes = io.BytesIO(request_body) 46 | image = Image.open(image_as_bytes) 47 | 48 | preprocess = transforms.Compose( 49 | [ 50 | transforms.Resize(256), 51 | transforms.CenterCrop(224), 52 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 53 | ] 54 | ) 55 | image_tensor = transforms.ToTensor()(image).unsqueeze(0) 56 | image_tensor = preprocess(image_tensor) 57 | return image_tensor 58 | 59 | 60 | def input_fn(request_body, request_content_type): 61 | if request_content_type == "application/x-parquet": 62 | image_tensor = load_parquet_from_bytearray(request_body) 63 | elif request_content_type == "application/x-image": 64 | image_tensor = load_from_bytearray(request_body) 65 | else: 66 | raise ValueError("Expected `application/x-parquet`.") 67 | return image_tensor 68 | 69 | 70 | # Perform prediction on the deserialized object, with the loaded model 71 | def predict_fn(input_object, model): 72 | with torch.inference_mode(): 73 | output = model.forward(input_object.to(torch.device("cuda"))) 74 | 75 | return {"predictions": output.cpu().numpy()} -------------------------------------------------------------------------------- /spark/README.md: -------------------------------------------------------------------------------- 1 | # Batch Inference Benchmarking with Apache Spark 2 | 3 | This repo contains benchmarks for batch inference benchmarking with Apache Spark. 4 | 5 | We use the image classification task from the [MLPerf Inference Benchmark suite](https://arxiv.org/pdf/1911.02549.pdf) in the offline setting. 
6 | 
7 | - Images from ImageNet 2012 Dataset
8 | - ResNet50 model
9 | 
10 | The workload is a simple 3-step pipeline:
11 | ![Pipeline](../images/pipeline.png)
12 | 
13 | Images are saved in parquet format (with ~1k images per parquet file).
14 | 
15 | We tried two dataset sizes, 10 GB and 300 GB. These sizes are for when the data is loaded in memory. The compressed on-disk size is much smaller.
16 | 
17 | We also run a microbenchmark to measure overhead from Spark.
18 | 
19 | All experiments are run in Databricks using Databricks Runtime v12.0, with Spark 3.3.1, and using the ML GPU runtime when applicable.
20 | 
21 | ## 10 GB
22 | 10 GB dataset using a single-node cluster.
23 | 
24 | ### Configurations
25 | 
26 | **Update 5/10**: Thanks to the Databricks Spark developers for pointing out the prefetching flag (`spark.databricks.execution.pandasUDF.prefetch.maxBatches`) available in the Databricks Runtime that can be used with the Iterator API. It is also possible to implement this manually in open-source Spark with background threads. With prefetching set to 4, Spark reaches 159.86 images/s.
27 | 
28 | - **Local**: `g4dn.16xlarge` instance (1 GPU). This is the smallest `g4dn` instance that does not OOM.
29 |     - Creates a [single-node cluster](https://docs.databricks.com/clusters/single-node.html), which starts Spark locally on the driver.
30 |     - Local clusters do not support GPU scheduling. Spark will schedule tasks based on available CPU cores.
31 |     - We have to manually repartition the data between the preprocessing and prediction steps to match the number of GPUs.
32 |     - We cannot use multi-GPU machines since we cannot specify the CUDA visible devices for each task.
33 |     - [Code](code/torch-batch-inference-s3-10G-single-node.ipynb)
34 | 
35 | - **Single-cluster**: Creates a standard Databricks cluster.
36 |     - This starts a 2-node cluster: 1 node for the driver that does not run tasks, and 1 node for the executor.
37 |     - Standard clusters support GPU scheduling.
38 |     - However, since Spark fuses all stages, the effective parallelism is limited by the # of GPUs.
39 |     - Uses a single `g4dn.12xlarge` instance with 4 GPUs.
40 |     - [Code](code/torch-batch-inference-s3-10G-standard.ipynb)
41 |     - Also tried using the [Iterator UDF API](https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.sql.functions.pandas_udf.html#pyspark.sql.functions.pandas_udf). [Code](code/torch-batch-inference-s3-10G-standard-iterator.ipynb)
42 |     - Per feedback from the Databricks Spark developers, also tried using the [Iterator UDF API](https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.sql.functions.pandas_udf.html#pyspark.sql.functions.pandas_udf) with Databricks prefetching support. [Code](code/torch-batch-inference-s3-10G-standard-iterator-databricks-prefetch.ipynb).
43 | 
44 | - **Multi-cluster**: Uses 2 separate clusters: 1 CPU-only cluster for preprocessing, and 1 GPU cluster for predicting. We use DBFS to store the intermediate preprocessed data. This allows preprocessing to scale independently from prediction, at the cost of having to persist data in between the steps.
45 |     - **CPU cluster**: 1 `m6gd.12xlarge` instance with Photon acceleration enabled. This is the smallest `m6gd` instance that does not OOM.
46 |     - **GPU cluster**: 1 `g4dn.12xlarge` instance.
47 |     - [CPU Code](code/torch-batch-inference-10G-s3-cpu-only.ipynb)
48 |     - [GPU Code](code/torch-batch-inference-10G-s3-predict-only.ipynb)
49 | 
50 | 
51 | 
52 | | Configuration                                 | Throughput (img/sec) |
53 | |-----------------------------------------------|----------------------|
54 | | Local                                         | 117.658              |
55 | | Single-cluster                                | 147.848              |
56 | | Single-cluster Iterator                       | 113.353              |
57 | | Single-cluster Iterator + Databricks prefetch | 159.862              |
58 | | Multi-cluster                                 | 108.768              |
59 | 
60 | ## 300 GB
61 | 
62 | We pick the best configuration from the 10 GB experiments and scale up to more nodes for inference on 300 GB of data.
63 | 
64 | [Code](code/torch-batch-inference-300G-s3-standard.ipynb)
65 | 
66 | We use 4 `g4dn.12xlarge` instances.
67 | 
68 | | Configuration  | Throughput (img/sec) |
69 | |----------------|----------------------|
70 | | Single-cluster | 689.084              |
71 | ## Microbenchmark
72 | Run a microbenchmark that reads from S3, caches the dataset, and then does a dummy preprocessing step with `time.sleep(1)`.
73 | 
74 | [Full code is here](code/microbenchmark.ipynb)
75 | 
76 | We force execution of the read before executing preprocessing to isolate just the preprocessing time.
77 | 
78 | Preprocessing takes **128.918** seconds.
79 | 
80 | Profiling shows that each UDF call executes in only ~1 second (the `time.sleep(1)`), so the additional time is coming from some other Spark overhead.
81 | 
82 | ```
83 | ============================================================
84 | Profile of UDF
85 | ============================================================
86 |          48 function calls in 16.015 seconds
87 | 
88 |    Ordered by: internal time, cumulative time
89 | 
90 |    ncalls  tottime  percall  cumtime  percall filename:lineno(function)
91 |        16   16.015    1.001   16.015    1.001 {built-in method time.sleep}
92 |        16    0.000    0.000   16.015    1.001 :7(dummy_preprocess)
93 |        16    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
94 | ```
95 | 
96 | ## Stage-level Scheduling
97 | We also tried enabling [stage-level scheduling](https://books.japila.pl/apache-spark-internals/stage-level-scheduling/) to maximize both CPU and GPU parallelism without having to use two separate clusters.
98 | 
99 | 1. Start a standard cluster with a `g4dn.12xlarge` instance (4 GPUs).
100 | 2. Set `spark.task.resource.gpu.amount` to `# GPUs / # CPUs == 4/48 == 0.0833` during cluster startup. This prevents the GPU count from limiting parallelism during the read+preprocess stage.
101 | 3. After the preprocessing stage, create a new `TaskResourceRequests` with 1 GPU per task. Run prediction on the preprocessed RDD with the newly created resource request.
102 | 
103 | However, the inference stage does not respect the new resource request, leading to more parallelism than there are GPUs and, ultimately, CUDA OOM.
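A condensed sketch of this flow, mirroring the full notebook linked below (`preprocessed_data` is the DataFrame produced by the preprocessing UDF and `predict_custom_batching` is the batched GPU predict function defined in that notebook):

```
from pyspark.resource.profile import ResourceProfileBuilder
from pyspark.resource.requests import TaskResourceRequests

# Request 1 GPU (plus the configured number of CPUs) per task for the inference stage only.
task_res_req = TaskResourceRequests().cpus(int(spark.sparkContext.getConf().get("spark.task.cpus", "1")))
task_res_req.resource("gpu", 1)
res_profile = ResourceProfileBuilder().require(task_res_req).build

# Attach the new resource profile to the preprocessed data and run batched GPU prediction.
# In our runs, the inference stage did not honor this request and hit CUDA OOM.
preprocessed_rdd = preprocessed_data.rdd.withResources(res_profile)
predictions = preprocessed_rdd.mapPartitions(predict_custom_batching).collect()
```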
104 | 105 | [Full code is here](code/torch-batch-inference-10G-stage-level-scheduling.ipynb) 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /spark/code/microbenchmark.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 0, 6 | "metadata": { 7 | "application/vnd.databricks.v1+cell": { 8 | "cellMetadata": { 9 | "byteLimit": 2048000, 10 | "rowLimit": 10000 11 | }, 12 | "inputWidgets": {}, 13 | "nuid": "a4832578-6320-43f6-a462-bd8eb1ae454b", 14 | "showTitle": false, 15 | "title": "" 16 | } 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "# Enable Arrow support.\n", 21 | "spark.conf.set(\"spark.sql.execution.arrow.enabled\", \"true\")" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 0, 27 | "metadata": { 28 | "application/vnd.databricks.v1+cell": { 29 | "cellMetadata": { 30 | "byteLimit": 2048000, 31 | "rowLimit": 10000 32 | }, 33 | "inputWidgets": {}, 34 | "nuid": "525bb7ea-f7b3-44ff-a013-3d8d9fc82135", 35 | "showTitle": false, 36 | "title": "" 37 | } 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "input_data = spark.read.format(\"parquet\").load(\"s3://air-example-data-2/10G-image-data-synthetic-raw-parquet\")\n", 42 | "# Force execution of the read, and cache the data\n", 43 | "cached_input = input_data.cache()" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 0, 49 | "metadata": { 50 | "application/vnd.databricks.v1+cell": { 51 | "cellMetadata": { 52 | "byteLimit": 2048000, 53 | "rowLimit": 10000 54 | }, 55 | "inputWidgets": {}, 56 | "nuid": "0ac17aec-e3d2-45dd-aeeb-f7b3e8935353", 57 | "showTitle": false, 58 | "title": "" 59 | } 60 | }, 61 | "outputs": [ 62 | { 63 | "output_type": "stream", 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "# data partitions: 32\n# Spark max parallelism: 64\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "# More parallelism than data partitions.\n", 73 | "print(\"# data partitions: \", cached_input.rdd.getNumPartitions())\n", 74 | "print(\"# Spark max parallelism: \", sc.defaultParallelism)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 0, 80 | "metadata": { 81 | "application/vnd.databricks.v1+cell": { 82 | "cellMetadata": { 83 | "byteLimit": 2048000, 84 | "rowLimit": 10000 85 | }, 86 | "inputWidgets": {}, 87 | "nuid": "269d5038-ac2c-4415-aa15-2ae4b239698f", 88 | "showTitle": false, 89 | "title": "" 90 | } 91 | }, 92 | "outputs": [ 93 | { 94 | "output_type": "display_data", 95 | "data": { 96 | "application/vnd.databricks.v1+bamboolib_hint": "{\"pd.DataFrames\": [], \"version\": \"0.0.1\"}", 97 | "text/plain": [] 98 | }, 99 | "metadata": { 100 | "application/vnd.databricks.v1+output": { 101 | "addedWidgets": {}, 102 | "arguments": {}, 103 | "data": { 104 | "application/vnd.databricks.v1+bamboolib_hint": "{\"pd.DataFrames\": [], \"version\": \"0.0.1\"}", 105 | "text/plain": "" 106 | }, 107 | "datasetInfos": [], 108 | "executionCount": null, 109 | "metadata": { 110 | "kernelSessionId": "b57204b4-3c4675446be8738502ee6ac0" 111 | }, 112 | "removedWidgets": [], 113 | "type": "mimeBundle" 114 | } 115 | }, 116 | "output_type": "display_data" 117 | } 118 | ], 119 | "source": [ 120 | "from pyspark.sql.functions import col, pandas_udf\n", 121 | "from pyspark.sql.types import ArrayType, FloatType\n", 122 | "\n", 123 | "import pandas as pd\n", 124 | "import time\n", 125 | "\n", 126 | 
"@pandas_udf(ArrayType(FloatType()))\n", 127 | "def dummy_preprocess(image: pd.Series) -> pd.Series:\n", 128 | " time.sleep(1)\n", 129 | " return image" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 0, 135 | "metadata": { 136 | "application/vnd.databricks.v1+cell": { 137 | "cellMetadata": { 138 | "byteLimit": 2048000, 139 | "rowLimit": 10000 140 | }, 141 | "inputWidgets": {}, 142 | "nuid": "3ce8e20a-6f74-4cb8-b94c-b771736c7e43", 143 | "showTitle": false, 144 | "title": "" 145 | } 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "# Preprocess with a 1 second sleep\n", 150 | "# Since the parallelism is more than data partitions, all partitions should run in parallel.\n", 151 | "dummy_preprocessed_data = cached_input.select(dummy_preprocess(col(\"image\")))" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 0, 157 | "metadata": { 158 | "application/vnd.databricks.v1+cell": { 159 | "cellMetadata": { 160 | "byteLimit": 2048000, 161 | "rowLimit": 10000 162 | }, 163 | "inputWidgets": {}, 164 | "nuid": "b004ed5f-7ce5-4f02-b140-773804fa5570", 165 | "showTitle": false, 166 | "title": "" 167 | } 168 | }, 169 | "outputs": [ 170 | { 171 | "output_type": "stream", 172 | "name": "stdout", 173 | "output_type": "stream", 174 | "text": [ 175 | "Preprocessing took: 128.91825461387634 seconds\n" 176 | ] 177 | } 178 | ], 179 | "source": [ 180 | "# Force execution of preprocessing\n", 181 | "\n", 182 | "start_time = time.time()\n", 183 | "dummy_preprocessed_data.write.mode(\"overwrite\").format(\"noop\").save()\n", 184 | "end_time = time.time()\n", 185 | "print(f\"Preprocessing took: {end_time-start_time} seconds\")" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 0, 191 | "metadata": { 192 | "application/vnd.databricks.v1+cell": { 193 | "cellMetadata": { 194 | "byteLimit": 2048000, 195 | "rowLimit": 10000 196 | }, 197 | "inputWidgets": {}, 198 | "nuid": "3065d09f-3c50-431c-9ba7-4189dbdd1da2", 199 | "showTitle": false, 200 | "title": "" 201 | } 202 | }, 203 | "outputs": [ 204 | { 205 | "output_type": "stream", 206 | "name": "stdout", 207 | "output_type": "stream", 208 | "text": [ 209 | "============================================================\nProfile of UDF\n============================================================\n 48 function calls in 16.015 seconds\n\n Ordered by: internal time, cumulative time\n\n ncalls tottime percall cumtime percall filename:lineno(function)\n 16 16.015 1.001 16.015 1.001 {built-in method time.sleep}\n 16 0.000 0.000 16.015 1.001 :7(dummy_preprocess)\n 16 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects}\n\n\n" 210 | ] 211 | } 212 | ], 213 | "source": [ 214 | "sc.show_profiles()" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 0, 220 | "metadata": { 221 | "application/vnd.databricks.v1+cell": { 222 | "cellMetadata": {}, 223 | "inputWidgets": {}, 224 | "nuid": "a4c1b933-45e5-4116-9481-e0525917d640", 225 | "showTitle": false, 226 | "title": "" 227 | } 228 | }, 229 | "outputs": [], 230 | "source": [] 231 | } 232 | ], 233 | "metadata": { 234 | "application/vnd.databricks.v1+notebook": { 235 | "dashboards": [], 236 | "language": "python", 237 | "notebookMetadata": { 238 | "pythonIndentUnit": 4 239 | }, 240 | "notebookName": "microbenchmark", 241 | "notebookOrigID": 566047737056974, 242 | "widgets": {} 243 | } 244 | }, 245 | "nbformat": 4, 246 | "nbformat_minor": 0 247 | } 248 | 
-------------------------------------------------------------------------------- /spark/code/torch-batch-inference-10G-s3-cpu-only.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"code","source":["#print(\"Profiling enabled: \", spark.conf.get(\"spark.python.profile\"))\nprint(\"Executor memory: \", spark.conf.get(\"spark.executor.memory\"))"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"59841386-1fb4-4a0c-afbe-922f8f8049d7","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Executor memory: 37182m\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Executor memory: 37182m\n"]}}],"execution_count":0},{"cell_type":"code","source":["!pip install numpy -U\n!pip install torchvision"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"54e12fe4-68d4-4b8e-b211-cd8c18e5c95c","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Requirement already satisfied: numpy in /local_disk0/.ephemeral_nfs/envs/pythonEnv-11bb0e31-d4a7-46cc-88b9-8edcb48ae8a2/lib/python3.9/site-packages (1.24.1)\r\n\u001B[33mWARNING: You are using pip version 21.2.4; however, version 22.3.1 is available.\r\nYou should consider upgrading via the '/local_disk0/.ephemeral_nfs/envs/pythonEnv-11bb0e31-d4a7-46cc-88b9-8edcb48ae8a2/bin/python -m pip install --upgrade pip' command.\u001B[0m\r\nRequirement already satisfied: torchvision in /local_disk0/.ephemeral_nfs/envs/pythonEnv-11bb0e31-d4a7-46cc-88b9-8edcb48ae8a2/lib/python3.9/site-packages (0.14.1)\r\nRequirement already satisfied: numpy in /local_disk0/.ephemeral_nfs/envs/pythonEnv-11bb0e31-d4a7-46cc-88b9-8edcb48ae8a2/lib/python3.9/site-packages (from torchvision) (1.24.1)\r\nRequirement already satisfied: torch in /local_disk0/.ephemeral_nfs/envs/pythonEnv-11bb0e31-d4a7-46cc-88b9-8edcb48ae8a2/lib/python3.9/site-packages (from torchvision) (1.13.1)\r\nRequirement already satisfied: typing-extensions in /databricks/python3/lib/python3.9/site-packages (from torchvision) (4.1.1)\r\nRequirement already satisfied: requests in /databricks/python3/lib/python3.9/site-packages (from torchvision) (2.27.1)\r\nRequirement already satisfied: pillow!=8.3.*,>=5.3.0 in /databricks/python3/lib/python3.9/site-packages (from torchvision) (9.0.1)\r\nRequirement already satisfied: certifi>=2017.4.17 in /databricks/python3/lib/python3.9/site-packages (from requests->torchvision) (2021.10.8)\r\nRequirement already satisfied: charset-normalizer~=2.0.0 in /databricks/python3/lib/python3.9/site-packages (from requests->torchvision) (2.0.4)\r\nRequirement already satisfied: idna<4,>=2.5 in /databricks/python3/lib/python3.9/site-packages (from requests->torchvision) (3.3)\r\nRequirement already satisfied: urllib3<1.27,>=1.21.1 in /databricks/python3/lib/python3.9/site-packages (from requests->torchvision) (1.26.9)\r\n\u001B[33mWARNING: You are using pip version 21.2.4; however, version 22.3.1 is available.\r\nYou should consider upgrading via the '/local_disk0/.ephemeral_nfs/envs/pythonEnv-11bb0e31-d4a7-46cc-88b9-8edcb48ae8a2/bin/python -m pip install --upgrade pip' 
command.\u001B[0m\r\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Requirement already satisfied: numpy in /local_disk0/.ephemeral_nfs/envs/pythonEnv-11bb0e31-d4a7-46cc-88b9-8edcb48ae8a2/lib/python3.9/site-packages (1.24.1)\r\n\u001B[33mWARNING: You are using pip version 21.2.4; however, version 22.3.1 is available.\r\nYou should consider upgrading via the '/local_disk0/.ephemeral_nfs/envs/pythonEnv-11bb0e31-d4a7-46cc-88b9-8edcb48ae8a2/bin/python -m pip install --upgrade pip' command.\u001B[0m\r\nRequirement already satisfied: torchvision in /local_disk0/.ephemeral_nfs/envs/pythonEnv-11bb0e31-d4a7-46cc-88b9-8edcb48ae8a2/lib/python3.9/site-packages (0.14.1)\r\nRequirement already satisfied: numpy in /local_disk0/.ephemeral_nfs/envs/pythonEnv-11bb0e31-d4a7-46cc-88b9-8edcb48ae8a2/lib/python3.9/site-packages (from torchvision) (1.24.1)\r\nRequirement already satisfied: torch in /local_disk0/.ephemeral_nfs/envs/pythonEnv-11bb0e31-d4a7-46cc-88b9-8edcb48ae8a2/lib/python3.9/site-packages (from torchvision) (1.13.1)\r\nRequirement already satisfied: typing-extensions in /databricks/python3/lib/python3.9/site-packages (from torchvision) (4.1.1)\r\nRequirement already satisfied: requests in /databricks/python3/lib/python3.9/site-packages (from torchvision) (2.27.1)\r\nRequirement already satisfied: pillow!=8.3.*,>=5.3.0 in /databricks/python3/lib/python3.9/site-packages (from torchvision) (9.0.1)\r\nRequirement already satisfied: certifi>=2017.4.17 in /databricks/python3/lib/python3.9/site-packages (from requests->torchvision) (2021.10.8)\r\nRequirement already satisfied: charset-normalizer~=2.0.0 in /databricks/python3/lib/python3.9/site-packages (from requests->torchvision) (2.0.4)\r\nRequirement already satisfied: idna<4,>=2.5 in /databricks/python3/lib/python3.9/site-packages (from requests->torchvision) (3.3)\r\nRequirement already satisfied: urllib3<1.27,>=1.21.1 in /databricks/python3/lib/python3.9/site-packages (from requests->torchvision) (1.26.9)\r\n\u001B[33mWARNING: You are using pip version 21.2.4; however, version 22.3.1 is available.\r\nYou should consider upgrading via the '/local_disk0/.ephemeral_nfs/envs/pythonEnv-11bb0e31-d4a7-46cc-88b9-8edcb48ae8a2/bin/python -m pip install --upgrade pip' command.\u001B[0m\r\n"]}}],"execution_count":0},{"cell_type":"code","source":["import pandas as pd\nfrom torchvision import transforms\nimport time"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"9d5cb47e-49c5-4904-9df3-538972665541","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["# Enable Arrow support.\nspark.conf.set(\"spark.sql.execution.arrow.enabled\", \"true\")"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"a4832578-6320-43f6-a462-bd8eb1ae454b","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["input_data = spark.read.format(\"parquet\").load(\"s3://air-example-data-2/10G-image-data-synthetic-raw-parquet\")"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"525bb7ea-f7b3-44ff-a013-3d8d9fc82135","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["# Preprocessing\nfrom pyspark.sql.functions import col, pandas_udf\nfrom pyspark.sql.types import ArrayType, FloatType\n\nimport numpy as np\n\nimport torch\nimport time\n\n# Read 
documentation here: https://spark.apache.org/docs/3.0.1/sql-pyspark-pandas-with-arrow.html\n\n@pandas_udf(ArrayType(FloatType()))\ndef preprocess(image: pd.Series) -> pd.Series:\n preprocess = transforms.Compose(\n [\n transforms.Resize(256),\n transforms.CenterCrop(224),\n transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n ]\n )\n print(f\"number of images: {len(image)}\")\n # Spark has no tensor support, so it flattens the image tensor to a single array during read.\n # Each image is represented as a flattened numpy array.\n # We have to reshape back to the original number of dimensions.\n # Need to convert to float dtype otherwise torchvision transforms will complain. The data is read as short (int16) by default\n batch_dim = len(image)\n numpy_batch = np.stack(image.values)\n reshaped_images = numpy_batch.reshape(batch_dim, 256, 256, 3).astype(float)\n \n torch_tensor = torch.Tensor(reshaped_images.transpose(0, 3, 1, 2))\n preprocessed_images = preprocess(torch_tensor).numpy()\n # Arrow only works with single dimension numpy arrays, so need to flatten the array before outputting it\n preprocessed_images = [image.flatten() for image in preprocessed_images]\n return pd.Series(preprocessed_images)"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"ef059a6f-8c22-44ef-960c-518956975eb2","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["preprocessed_data = input_data.select(preprocess(col(\"image\")))"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"9af2d6fc-4fb8-4983-aca3-672381e96310","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["#dbutils.fs.rm(\"/preprocessed_data/\", recurse=True)\ndbutils.fs.mkdirs(\"/preprocessed_data/\")\ndbutils.fs.ls(\"/preprocessed_data/\")"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"37e47f37-a300-449d-8e8b-25bbb00d5e22","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Out[10]: []","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Out[10]: []"]}}],"execution_count":0},{"cell_type":"code","source":["start_time = time.time()\npreprocessed_data.write.mode(\"overwrite\").format(\"parquet\").save(\"/preprocessed_data/\")\nend_time = time.time()\nprint(f\"Preprocessing+Writing took: {end_time-start_time} seconds\")\nassert preprocessed_data.count() == 16232"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"00bb4621-2878-41fa-9895-b5415b609380","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Preprocessing+Writing took: 49.70147895812988 seconds\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Preprocessing+Writing took: 49.70147895812988 
seconds\n"]}}],"execution_count":0},{"cell_type":"code","source":[""],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"26fe5ed2-dc5c-43c6-aeef-80b5704ed73a","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0}],"metadata":{"application/vnd.databricks.v1+notebook":{"notebookName":"torch-batch-inference-10G-s3-cpu-only","dashboards":[],"notebookMetadata":{"pythonIndentUnit":4},"language":"python","widgets":{},"notebookOrigID":3607923681779397}},"nbformat":4,"nbformat_minor":0} 2 | -------------------------------------------------------------------------------- /spark/code/torch-batch-inference-10G-s3-predict-only.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"code","source":["import time\nimport torch\ntorch.__version__\ntorch.cuda.is_available()"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"701d2809-daab-4de6-979d-ff8c6aab47cf","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Out[1]: True","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Out[1]: True"]}}],"execution_count":0},{"cell_type":"code","source":["#print(\"Profiling enabled: \", spark.conf.get(\"spark.python.profile\"))\nprint(\"Executor memory: \", spark.conf.get(\"spark.executor.memory\"))"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"59841386-1fb4-4a0c-afbe-922f8f8049d7","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Executor memory: 148728m\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Executor memory: 148728m\n"]}}],"execution_count":0},{"cell_type":"code","source":["import pandas as pd\nfrom torchvision import transforms\nfrom torchvision.models import resnet50, ResNet50_Weights"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"9d5cb47e-49c5-4904-9df3-538972665541","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":{"text/plain":"","application/vnd.databricks.v1+bamboolib_hint":"{\"pd.DataFrames\": [], \"version\": \"0.0.1\"}"},"removedWidgets":[],"addedWidgets":{},"metadata":{"kernelSessionId":"1d550a9e-78b77e858748bf6fc5111be6"},"type":"mimeBundle","arguments":{}}},"output_type":"display_data","data":{"text/plain":"","application/vnd.databricks.v1+bamboolib_hint":"{\"pd.DataFrames\": [], \"version\": \"0.0.1\"}"}}],"execution_count":0},{"cell_type":"code","source":["# Enable Arrow support.\nspark.conf.set(\"spark.sql.execution.arrow.enabled\", \"true\")"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"a4832578-6320-43f6-a462-bd8eb1ae454b","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["# Create and broadcast model state. Equivalent to AIR Checkpoint\nmodel_state = resnet50(weights=ResNet50_Weights.DEFAULT).state_dict()\n# sc is already initialized by Databricks. 
Broadcast the model state to all executors.\nbc_model_state = sc.broadcast(model_state)"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"b6229f4e-80ba-40ea-bac7-9d07bc2b84fe","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Downloading: \"https://download.pytorch.org/models/resnet50-11ad3fa6.pth\" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Downloading: \"https://download.pytorch.org/models/resnet50-11ad3fa6.pth\" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth\n"]}},{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":{"text/plain":" 0%| | 0.00/97.8M [00:00 pd.Series:\n with torch.inference_mode():\n model = resnet50()\n model.load_state_dict(bc_model_state.value)\n model = model.to(torch.device(\"cuda\")) # Move model to GPU\n model.eval()\n \n batch = preprocessed_images\n batch_dim = len(batch)\n numpy_batch = np.stack(batch.values)\n # Spark has no tensor support, so it flattens the image tensor to a single array during read.\n # Each image is represented as a flattened numpy array.\n # We have to reshape back to the original number of dimensions.\n reshaped_images = numpy_batch.reshape(batch_dim, 3, 224, 224)\n gpu_batch = torch.Tensor(reshaped_images).to(torch.device(\"cuda\"))\n predictions = list(model(gpu_batch).cpu().numpy())\n assert len(predictions) == batch_dim\n \n return pd.Series(predictions)"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"48c15069-1c2f-47c3-9b1c-6fe0eb5641e2","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["predictions = input_data.select(predict(col(\"preprocess(image)\")))"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"d30eb582-afe4-433c-9fee-9dc5c6f3cf86","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["start_time = time.time()\npredictions.write.mode(\"overwrite\").format(\"noop\").save()\nend_time = time.time()\nprint(f\"Prediction took: {end_time-start_time} seconds\")\n\nassert predictions.count() == 16232"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"505599b5-fdd0-4f55-bb9d-f17898858f47","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Prediction took: 99.5347626209259 seconds\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Prediction took: 99.5347626209259 seconds\n"]}}],"execution_count":0},{"cell_type":"code","source":[""],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"d27b067d-7e24-49ea-9696-69d7bacaf63f","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0}],"metadata":{"application/vnd.databricks.v1+notebook":{"notebookName":"torch-batch-inference-10G-s3-predict-only","dashboards":[],"notebookMetadata":{"pythonIndentUnit":4},"language":"python","widgets":{},"notebookOrigID":3607923681779423}},"nbformat":4,"nbformat_minor":0} 2 | 
-------------------------------------------------------------------------------- /spark/code/torch-batch-inference-10G-stage-level-scheduling.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"code","source":["import time\nimport torch\ntorch.__version__\ntorch.cuda.is_available()"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"701d2809-daab-4de6-979d-ff8c6aab47cf","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Out[1]: True","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Out[1]: True"]}}],"execution_count":0},{"cell_type":"code","source":["print(\"Nums gpus per task: \", spark.conf.get(\"spark.task.resource.gpu.amount\"))\nprint(\"Executor memory: \", spark.conf.get(\"spark.executor.memory\"))"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"59841386-1fb4-4a0c-afbe-922f8f8049d7","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Nums gpus per task: 0.0833\nExecutor memory: 148728m\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Nums gpus per task: 0.0833\nExecutor memory: 148728m\n"]}}],"execution_count":0},{"cell_type":"code","source":["import pandas as pd\nfrom torchvision import transforms\nfrom torchvision.models import resnet50, ResNet50_Weights"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"9d5cb47e-49c5-4904-9df3-538972665541","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":{"text/plain":"","application/vnd.databricks.v1+bamboolib_hint":"{\"pd.DataFrames\": [], \"version\": \"0.0.1\"}"},"removedWidgets":[],"addedWidgets":{},"metadata":{"kernelSessionId":"594d1087-ded35e044d5dee7765bc0ce2"},"type":"mimeBundle","arguments":{}}},"output_type":"display_data","data":{"text/plain":"","application/vnd.databricks.v1+bamboolib_hint":"{\"pd.DataFrames\": [], \"version\": \"0.0.1\"}"}}],"execution_count":0},{"cell_type":"code","source":["# Enable Arrow support.\nspark.conf.set(\"spark.sql.execution.arrow.enabled\", \"true\")"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"a4832578-6320-43f6-a462-bd8eb1ae454b","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["# Create and broadcast model state. Equivalent to AIR Checkpoint\nmodel_state = resnet50(weights=ResNet50_Weights.DEFAULT).state_dict()\n# sc is already initialized by Databricks. 
Broadcast the model state to all executors.\nbc_model_state = sc.broadcast(model_state)"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"b6229f4e-80ba-40ea-bac7-9d07bc2b84fe","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Downloading: \"https://download.pytorch.org/models/resnet50-11ad3fa6.pth\" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Downloading: \"https://download.pytorch.org/models/resnet50-11ad3fa6.pth\" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth\n"]}},{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":{"text/plain":" 0%| | 0.00/97.8M [00:00 pd.Series:\n preprocess = transforms.Compose(\n [\n transforms.Resize(256),\n transforms.CenterCrop(224),\n transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n ]\n )\n print(f\"number of images: {len(image)}\")\n # Spark has no tensor support, so it flattens the image tensor to a single array during read.\n # Each image is represented as a flattened numpy array.\n # We have to reshape back to the original number of dimensions.\n # Need to convert to float dtype otherwise torchvision transforms will complain. The data is read as short (int16) by default\n batch_dim = len(image)\n numpy_batch = np.stack(image.values)\n reshaped_images = numpy_batch.reshape(batch_dim, 256, 256, 3).astype(float)\n \n torch_tensor = torch.Tensor(reshaped_images.transpose(0, 3, 1, 2))\n preprocessed_images = preprocess(torch_tensor).numpy()\n # Arrow only works with single dimension numpy arrays, so need to flatten the array before outputting it\n preprocessed_images = [image.flatten() for image in preprocessed_images]\n return pd.Series(preprocessed_images)"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"ef059a6f-8c22-44ef-960c-518956975eb2","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["preprocessed_data = input_data.select(preprocess(col(\"image\")))"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"9af2d6fc-4fb8-4983-aca3-672381e96310","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["spark.conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"370ef5d5-55a1-4821-b428-414a62a36620","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["def predict_custom_batching(input_rdd_iter):\n with torch.inference_mode():\n model = resnet50()\n model.load_state_dict(bc_model_state.value)\n model = model.to(torch.device(\"cuda\")) # Move model to GPU\n model.eval()\n \n input_batch = []\n output = []\n for image in input_rdd_iter:\n input_batch.append(image)\n\n if len(input_batch) == 1000:\n numpy_batch = np.array(input_batch)\n reshaped_images = numpy_batch.reshape(1000, 3, 224, 224)\n gpu_batch = torch.Tensor(reshaped_images).to(torch.device(\"cuda\"))\n predictions = list(model(gpu_batch).cpu().numpy())\n assert len(predictions) == 1000\n output.extend(list(predictions))\n\n return 
output"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"48c15069-1c2f-47c3-9b1c-6fe0eb5641e2","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["from pyspark.resource.profile import ResourceProfileBuilder\nfrom pyspark.resource.requests import TaskResourceRequests\n\ntask_res_req = TaskResourceRequests().cpus(int(spark.sparkContext.getConf().get(\"spark.task.cpus\", \"1\")))\ntask_res_req.resource(\"gpu\", 1)\nres_profile = ResourceProfileBuilder().require(task_res_req).build\n\npreprocessed_rdd = preprocessed_data.rdd.withResources(res_profile)\npredictions = preprocessed_rdd.mapPartitions(predict_custom_batching).collect()"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"d30eb582-afe4-433c-9fee-9dc5c6f3cf86","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"data":"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m\n\u001B[0;31mPy4JJavaError\u001B[0m Traceback (most recent call last)\nFile \u001B[0;32m:18\u001B[0m\n\u001B[1;32m 15\u001B[0m res_profile \u001B[38;5;241m=\u001B[39m ResourceProfileBuilder()\u001B[38;5;241m.\u001B[39mrequire(task_res_req)\u001B[38;5;241m.\u001B[39mbuild\n\u001B[1;32m 17\u001B[0m preprocessed_rdd \u001B[38;5;241m=\u001B[39m preprocessed_data\u001B[38;5;241m.\u001B[39mrdd\u001B[38;5;241m.\u001B[39mwithResources(res_profile)\n\u001B[0;32m---> 18\u001B[0m predictions \u001B[38;5;241m=\u001B[39m preprocessed_rdd\u001B[38;5;241m.\u001B[39mmapPartitions(predict_custom_batching)\u001B[38;5;241m.\u001B[39mcollect()\n\nFile \u001B[0;32m/databricks/spark/python/pyspark/instrumentation_utils.py:48\u001B[0m, in \u001B[0;36m_wrap_function..wrapper\u001B[0;34m(*args, **kwargs)\u001B[0m\n\u001B[1;32m 46\u001B[0m start \u001B[38;5;241m=\u001B[39m time\u001B[38;5;241m.\u001B[39mperf_counter()\n\u001B[1;32m 47\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m---> 48\u001B[0m res \u001B[38;5;241m=\u001B[39m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 49\u001B[0m logger\u001B[38;5;241m.\u001B[39mlog_success(\n\u001B[1;32m 50\u001B[0m module_name, class_name, function_name, time\u001B[38;5;241m.\u001B[39mperf_counter() \u001B[38;5;241m-\u001B[39m start, signature\n\u001B[1;32m 51\u001B[0m )\n\u001B[1;32m 52\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m res\n\nFile \u001B[0;32m/databricks/spark/python/pyspark/rdd.py:1219\u001B[0m, in \u001B[0;36mRDD.collect\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 1217\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m SCCallSiteSync(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcontext):\n\u001B[1;32m 1218\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mctx\u001B[38;5;241m.\u001B[39m_jvm \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[0;32m-> 1219\u001B[0m sock_info \u001B[38;5;241m=\u001B[39m 
\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mctx\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_jvm\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mPythonRDD\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mcollectAndServe\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_jrdd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrdd\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 1220\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mlist\u001B[39m(_load_from_socket(sock_info, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_jrdd_deserializer))\n\nFile \u001B[0;32m/databricks/spark/python/lib/py4j-0.10.9.5-src.zip/py4j/java_gateway.py:1321\u001B[0m, in \u001B[0;36mJavaMember.__call__\u001B[0;34m(self, *args)\u001B[0m\n\u001B[1;32m 1315\u001B[0m command \u001B[38;5;241m=\u001B[39m proto\u001B[38;5;241m.\u001B[39mCALL_COMMAND_NAME \u001B[38;5;241m+\u001B[39m\\\n\u001B[1;32m 1316\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcommand_header \u001B[38;5;241m+\u001B[39m\\\n\u001B[1;32m 1317\u001B[0m args_command \u001B[38;5;241m+\u001B[39m\\\n\u001B[1;32m 1318\u001B[0m proto\u001B[38;5;241m.\u001B[39mEND_COMMAND_PART\n\u001B[1;32m 1320\u001B[0m answer \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mgateway_client\u001B[38;5;241m.\u001B[39msend_command(command)\n\u001B[0;32m-> 1321\u001B[0m return_value \u001B[38;5;241m=\u001B[39m \u001B[43mget_return_value\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 1322\u001B[0m \u001B[43m \u001B[49m\u001B[43manswer\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mgateway_client\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mtarget_id\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mname\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 1324\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m temp_arg \u001B[38;5;129;01min\u001B[39;00m temp_args:\n\u001B[1;32m 1325\u001B[0m temp_arg\u001B[38;5;241m.\u001B[39m_detach()\n\nFile \u001B[0;32m/databricks/spark/python/pyspark/sql/utils.py:196\u001B[0m, in \u001B[0;36mcapture_sql_exception..deco\u001B[0;34m(*a, **kw)\u001B[0m\n\u001B[1;32m 194\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mdeco\u001B[39m(\u001B[38;5;241m*\u001B[39ma: Any, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkw: Any) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m Any:\n\u001B[1;32m 195\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m--> 196\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mf\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43ma\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkw\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 197\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m Py4JJavaError \u001B[38;5;28;01mas\u001B[39;00m e:\n\u001B[1;32m 198\u001B[0m converted \u001B[38;5;241m=\u001B[39m convert_exception(e\u001B[38;5;241m.\u001B[39mjava_exception)\n\nFile \u001B[0;32m/databricks/spark/python/lib/py4j-0.10.9.5-src.zip/py4j/protocol.py:326\u001B[0m, in \u001B[0;36mget_return_value\u001B[0;34m(answer, gateway_client, target_id, 
name)\u001B[0m\n\u001B[1;32m 324\u001B[0m value \u001B[38;5;241m=\u001B[39m OUTPUT_CONVERTER[\u001B[38;5;28mtype\u001B[39m](answer[\u001B[38;5;241m2\u001B[39m:], gateway_client)\n\u001B[1;32m 325\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m answer[\u001B[38;5;241m1\u001B[39m] \u001B[38;5;241m==\u001B[39m REFERENCE_TYPE:\n\u001B[0;32m--> 326\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m Py4JJavaError(\n\u001B[1;32m 327\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mAn error occurred while calling \u001B[39m\u001B[38;5;132;01m{0}\u001B[39;00m\u001B[38;5;132;01m{1}\u001B[39;00m\u001B[38;5;132;01m{2}\u001B[39;00m\u001B[38;5;124m.\u001B[39m\u001B[38;5;130;01m\\n\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;241m.\u001B[39m\n\u001B[1;32m 328\u001B[0m \u001B[38;5;28mformat\u001B[39m(target_id, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m.\u001B[39m\u001B[38;5;124m\"\u001B[39m, name), value)\n\u001B[1;32m 329\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 330\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m Py4JError(\n\u001B[1;32m 331\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mAn error occurred while calling \u001B[39m\u001B[38;5;132;01m{0}\u001B[39;00m\u001B[38;5;132;01m{1}\u001B[39;00m\u001B[38;5;132;01m{2}\u001B[39;00m\u001B[38;5;124m. Trace:\u001B[39m\u001B[38;5;130;01m\\n\u001B[39;00m\u001B[38;5;132;01m{3}\u001B[39;00m\u001B[38;5;130;01m\\n\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;241m.\u001B[39m\n\u001B[1;32m 332\u001B[0m \u001B[38;5;28mformat\u001B[39m(target_id, \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m.\u001B[39m\u001B[38;5;124m\"\u001B[39m, name, value))\n\n\u001B[0;31mPy4JJavaError\u001B[0m: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.\n: org.apache.spark.SparkException: Job aborted due to stage failure: Task 3 in stage 9.0 failed 4 times, most recent failure: Lost task 3.3 in stage 9.0 (TID 489) (10.146.255.71 executor 1): org.apache.spark.api.python.PythonException: 'RuntimeError: CUDA error: out of memory\nCUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.', from , line 28. 
Full traceback below:\nTraceback (most recent call last):\n File \"/databricks/spark/python/pyspark/worker.py\", line 1019, in main\n process()\n File \"/databricks/spark/python/pyspark/worker.py\", line 1009, in process\n out_iter = func(split_index, iterator)\n File \"/databricks/spark/python/pyspark/rdd.py\", line 543, in func\n return f(iterator)\n File \"\", line 28, in predict_custom_batching\n File \"/databricks/python/lib/python3.9/site-packages/torch/nn/modules/module.py\", line 927, in to\n return self._apply(convert)\n File \"/databricks/python/lib/python3.9/site-packages/torch/nn/modules/module.py\", line 579, in _apply\n module._apply(fn)\n File \"/databricks/python/lib/python3.9/site-packages/torch/nn/modules/module.py\", line 602, in _apply\n param_applied = fn(param)\n File \"/databricks/python/lib/python3.9/site-packages/torch/nn/modules/module.py\", line 925, in convert\n return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)\nRuntimeError: CUDA error: out of memory\nCUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\n\n\tat org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:694)\n\tat org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:904)\n\tat org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:886)\n\tat org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:647)\n\tat org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)\n\tat scala.collection.Iterator.foreach(Iterator.scala:943)\n\tat scala.collection.Iterator.foreach$(Iterator.scala:943)\n\tat org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)\n\tat scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)\n\tat scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)\n\tat scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)\n\tat scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)\n\tat scala.collection.TraversableOnce.to(TraversableOnce.scala:366)\n\tat scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)\n\tat org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)\n\tat scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)\n\tat scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)\n\tat org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)\n\tat scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)\n\tat scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)\n\tat org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)\n\tat org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1028)\n\tat org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:75)\n\tat com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)\n\tat org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:75)\n\tat com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:55)\n\tat org.apache.spark.scheduler.Task.doRunTask(Task.scala:169)\n\tat org.apache.spark.scheduler.Task.$anonfun$run$4(Task.scala:137)\n\tat 
com.databricks.unity.EmptyHandle$.runWithAndClose(UCSHandle.scala:104)\n\tat org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:137)\n\tat com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:96)\n\tat org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$13(Executor.scala:902)\n\tat org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1702)\n\tat org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:905)\n\tat scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)\n\tat com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:760)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\tat java.lang.Thread.run(Thread.java:750)\n\nDriver stacktrace:\n\tat org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:3312)\n\tat org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:3244)\n\tat org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:3235)\n\tat scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)\n\tat scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)\n\tat scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)\n\tat org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:3235)\n\tat org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1425)\n\tat org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1425)\n\tat scala.Option.foreach(Option.scala:407)\n\tat org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1425)\n\tat org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3524)\n\tat org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3462)\n\tat org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3450)\n\tat org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:51)\n\tat org.apache.spark.scheduler.DAGScheduler.$anonfun$runJob$1(DAGScheduler.scala:1170)\n\tat scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)\n\tat com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:80)\n\tat org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1158)\n\tat org.apache.spark.SparkContext.runJobInternal(SparkContext.scala:2702)\n\tat org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1026)\n\tat org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)\n\tat org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:125)\n\tat org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)\n\tat org.apache.spark.rdd.RDD.withScope(RDD.scala:410)\n\tat org.apache.spark.rdd.RDD.collect(RDD.scala:1024)\n\tat org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:282)\n\tat org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)\n\tat sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\n\tat sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\n\tat 
sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\n\tat java.lang.reflect.Method.invoke(Method.java:498)\n\tat py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\n\tat py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)\n\tat py4j.Gateway.invoke(Gateway.java:306)\n\tat py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\n\tat py4j.commands.CallCommand.execute(CallCommand.java:79)\n\tat py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:195)\n\tat py4j.ClientServerConnection.run(ClientServerConnection.java:115)\n\tat java.lang.Thread.run(Thread.java:750)\nCaused by: org.apache.spark.api.python.PythonException: 'RuntimeError: CUDA error: out of memory\nCUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.', from , line 28. Full traceback below:\nTraceback (most recent call last):\n File \"/databricks/spark/python/pyspark/worker.py\", line 1019, in main\n process()\n File \"/databricks/spark/python/pyspark/worker.py\", line 1009, in process\n out_iter = func(split_index, iterator)\n File \"/databricks/spark/python/pyspark/rdd.py\", line 543, in func\n return f(iterator)\n File \"\", line 28, in predict_custom_batching\n File \"/databricks/python/lib/python3.9/site-packages/torch/nn/modules/module.py\", line 927, in to\n return self._apply(convert)\n File \"/databricks/python/lib/python3.9/site-packages/torch/nn/modules/module.py\", line 579, in _apply\n module._apply(fn)\n File \"/databricks/python/lib/python3.9/site-packages/torch/nn/modules/module.py\", line 602, in _apply\n param_applied = fn(param)\n File \"/databricks/python/lib/python3.9/site-packages/torch/nn/modules/module.py\", line 925, in convert\n return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)\nRuntimeError: CUDA error: out of memory\nCUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\n\n\tat org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:694)\n\tat org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:904)\n\tat org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:886)\n\tat org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:647)\n\tat org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)\n\tat scala.collection.Iterator.foreach(Iterator.scala:943)\n\tat scala.collection.Iterator.foreach$(Iterator.scala:943)\n\tat org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)\n\tat scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)\n\tat scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)\n\tat scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)\n\tat scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)\n\tat scala.collection.TraversableOnce.to(TraversableOnce.scala:366)\n\tat scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)\n\tat org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)\n\tat scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)\n\tat scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)\n\tat 
org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)\n\tat scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)\n\tat scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)\n\tat org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)\n\tat org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1028)\n\tat org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:75)\n\tat com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)\n\tat org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:75)\n\tat com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:55)\n\tat org.apache.spark.scheduler.Task.doRunTask(Task.scala:169)\n\tat org.apache.spark.scheduler.Task.$anonfun$run$4(Task.scala:137)\n\tat com.databricks.unity.EmptyHandle$.runWithAndClose(UCSHandle.scala:104)\n\tat org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:137)\n\tat com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:96)\n\tat org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$13(Executor.scala:902)\n\tat org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1702)\n\tat org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:905)\n\tat scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)\n\tat com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:760)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\t... 
1 more\n","errorSummary":"org.apache.spark.SparkException: Job aborted due to stage failure: Task 3 in stage 9.0 failed 4 times, most recent failure: Lost task 3.3 in stage 9.0 (TID 489) (10.146.255.71 executor 1): org.apache.spark.api.python.PythonException: 'RuntimeError: CUDA error: out of memory","metadata":{},"errorTraceType":"ansi","type":"ipynbError","arguments":{}}},"output_type":"display_data","data":{"text/plain":["\u001B[0;31m---------------------------------------------------------------------------\u001B[0m\n\u001B[0;31mPy4JJavaError\u001B[0m Traceback (most recent call last)\nFile \u001B[0;32m:18\u001B[0m\n\u001B[1;32m 15\u001B[0m res_profile \u001B[38;5;241m=\u001B[39m ResourceProfileBuilder()\u001B[38;5;241m.\u001B[39mrequire(task_res_req)\u001B[38;5;241m.\u001B[39mbuild\n\u001B[1;32m 17\u001B[0m preprocessed_rdd \u001B[38;5;241m=\u001B[39m preprocessed_data\u001B[38;5;241m.\u001B[39mrdd\u001B[38;5;241m.\u001B[39mwithResources(res_profile)\n\u001B[0;32m---> 18\u001B[0m predictions \u001B[38;5;241m=\u001B[39m preprocessed_rdd\u001B[38;5;241m.\u001B[39mmapPartitions(predict_custom_batching)\u001B[38;5;241m.\u001B[39mcollect()\n\nFile \u001B[0;32m/databricks/spark/python/pyspark/instrumentation_utils.py:48\u001B[0m, in \u001B[0;36m_wrap_function..wrapper\u001B[0;34m(*args, **kwargs)\u001B[0m\n\u001B[1;32m 46\u001B[0m start \u001B[38;5;241m=\u001B[39m time\u001B[38;5;241m.\u001B[39mperf_counter()\n\u001B[1;32m 47\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m---> 48\u001B[0m res \u001B[38;5;241m=\u001B[39m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 49\u001B[0m logger\u001B[38;5;241m.\u001B[39mlog_success(\n\u001B[1;32m 50\u001B[0m module_name, class_name, function_name, time\u001B[38;5;241m.\u001B[39mperf_counter() \u001B[38;5;241m-\u001B[39m start, signature\n\u001B[1;32m 51\u001B[0m )\n\u001B[1;32m 52\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m res\n\nFile \u001B[0;32m/databricks/spark/python/pyspark/rdd.py:1219\u001B[0m, in \u001B[0;36mRDD.collect\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m 1217\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m SCCallSiteSync(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcontext):\n\u001B[1;32m 1218\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mctx\u001B[38;5;241m.\u001B[39m_jvm \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[0;32m-> 1219\u001B[0m sock_info \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mctx\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_jvm\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mPythonRDD\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mcollectAndServe\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_jrdd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrdd\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 1220\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mlist\u001B[39m(_load_from_socket(sock_info, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_jrdd_deserializer))\n\nFile 
com.databricks.unity.EmptyHandle$.runWithAndClose(UCSHandle.scala:104)\n\tat org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:137)\n\tat com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:96)\n\tat org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$13(Executor.scala:902)\n\tat org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1702)\n\tat org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:905)\n\tat scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)\n\tat com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:760)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)\n\t... 1 more\n"]}}],"execution_count":0},{"cell_type":"code","source":["start_time = time.time()\npredictions.write.mode(\"overwrite\").format(\"noop\").save()\nend_time = time.time()\nprint(f\"Prediction took: {end_time-start_time} seconds\")\n\nassert predictions.count() == 16232"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"505599b5-fdd0-4f55-bb9d-f17898858f47","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"data":"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m\n\u001B[0;31mPythonException\u001B[0m Traceback (most recent call last)\nFile \u001B[0;32m:3\u001B[0m\n\u001B[1;32m 1\u001B[0m start_time \u001B[38;5;241m=\u001B[39m time\u001B[38;5;241m.\u001B[39mtime()\n\u001B[1;32m 2\u001B[0m predictions\u001B[38;5;241m.\u001B[39mpersist() \u001B[38;5;66;03m# Persist is a lazy operation- need to also have the line below\u001B[39;00m\n\u001B[0;32m----> 3\u001B[0m predictions\u001B[38;5;241m.\u001B[39mwrite\u001B[38;5;241m.\u001B[39mmode(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124moverwrite\u001B[39m\u001B[38;5;124m\"\u001B[39m)\u001B[38;5;241m.\u001B[39mformat(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mnoop\u001B[39m\u001B[38;5;124m\"\u001B[39m)\u001B[38;5;241m.\u001B[39msave()\n\u001B[1;32m 4\u001B[0m end_time \u001B[38;5;241m=\u001B[39m time\u001B[38;5;241m.\u001B[39mtime()\n\u001B[1;32m 5\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mPrediction took: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mend_time\u001B[38;5;241m-\u001B[39mstart_time\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m seconds\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\nFile \u001B[0;32m/databricks/spark/python/pyspark/instrumentation_utils.py:48\u001B[0m, in \u001B[0;36m_wrap_function..wrapper\u001B[0;34m(*args, **kwargs)\u001B[0m\n\u001B[1;32m 46\u001B[0m start \u001B[38;5;241m=\u001B[39m time\u001B[38;5;241m.\u001B[39mperf_counter()\n\u001B[1;32m 47\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m---> 48\u001B[0m res \u001B[38;5;241m=\u001B[39m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 49\u001B[0m logger\u001B[38;5;241m.\u001B[39mlog_success(\n\u001B[1;32m 50\u001B[0m module_name, class_name, function_name, 
time\u001B[38;5;241m.\u001B[39mperf_counter() \u001B[38;5;241m-\u001B[39m start, signature\n\u001B[1;32m 51\u001B[0m )\n\u001B[1;32m 52\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m res\n\nFile \u001B[0;32m/databricks/spark/python/pyspark/sql/readwriter.py:966\u001B[0m, in \u001B[0;36mDataFrameWriter.save\u001B[0;34m(self, path, format, mode, partitionBy, **options)\u001B[0m\n\u001B[1;32m 964\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mformat(\u001B[38;5;28mformat\u001B[39m)\n\u001B[1;32m 965\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m path \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m--> 966\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_jwrite\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msave\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 967\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 968\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_jwrite\u001B[38;5;241m.\u001B[39msave(path)\n\nFile \u001B[0;32m/databricks/spark/python/lib/py4j-0.10.9.5-src.zip/py4j/java_gateway.py:1321\u001B[0m, in \u001B[0;36mJavaMember.__call__\u001B[0;34m(self, *args)\u001B[0m\n\u001B[1;32m 1315\u001B[0m command \u001B[38;5;241m=\u001B[39m proto\u001B[38;5;241m.\u001B[39mCALL_COMMAND_NAME \u001B[38;5;241m+\u001B[39m\\\n\u001B[1;32m 1316\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcommand_header \u001B[38;5;241m+\u001B[39m\\\n\u001B[1;32m 1317\u001B[0m args_command \u001B[38;5;241m+\u001B[39m\\\n\u001B[1;32m 1318\u001B[0m proto\u001B[38;5;241m.\u001B[39mEND_COMMAND_PART\n\u001B[1;32m 1320\u001B[0m answer \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mgateway_client\u001B[38;5;241m.\u001B[39msend_command(command)\n\u001B[0;32m-> 1321\u001B[0m return_value \u001B[38;5;241m=\u001B[39m \u001B[43mget_return_value\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m 1322\u001B[0m \u001B[43m \u001B[49m\u001B[43manswer\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mgateway_client\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mtarget_id\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mname\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 1324\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m temp_arg \u001B[38;5;129;01min\u001B[39;00m temp_args:\n\u001B[1;32m 1325\u001B[0m temp_arg\u001B[38;5;241m.\u001B[39m_detach()\n\nFile \u001B[0;32m/databricks/spark/python/pyspark/sql/utils.py:202\u001B[0m, in \u001B[0;36mcapture_sql_exception..deco\u001B[0;34m(*a, **kw)\u001B[0m\n\u001B[1;32m 198\u001B[0m converted \u001B[38;5;241m=\u001B[39m convert_exception(e\u001B[38;5;241m.\u001B[39mjava_exception)\n\u001B[1;32m 199\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(converted, UnknownException):\n\u001B[1;32m 200\u001B[0m \u001B[38;5;66;03m# Hide where the exception came from that shows a non-Pythonic\u001B[39;00m\n\u001B[1;32m 201\u001B[0m \u001B[38;5;66;03m# JVM exception message.\u001B[39;00m\n\u001B[0;32m--> 202\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m converted \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;28mNone\u001B[39m\n\u001B[1;32m 203\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 204\u001B[0m 
\u001B[38;5;28;01mraise\u001B[39;00m\n\n\u001B[0;31mPythonException\u001B[0m: An exception was thrown from a UDF: 'RuntimeError: CUDA out of memory. Tried to allocate 288.00 MiB (GPU 0; 14.76 GiB total capacity; 97.73 MiB already allocated; 98.75 MiB free; 118.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF', from , line 21. Full traceback below:\nTraceback (most recent call last):\n File \"\", line 21, in predict_not_iterator\nRuntimeError: CUDA out of memory. Tried to allocate 288.00 MiB (GPU 0; 14.76 GiB total capacity; 97.73 MiB already allocated; 98.75 MiB free; 118.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF\n","errorSummary":"PythonException: An exception was thrown from a UDF: 'RuntimeError: CUDA out of memory. Tried to allocate 288.00 MiB (GPU 0; 14.76 GiB total capacity; 97.73 MiB already allocated; 98.75 MiB free; 118.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF', from , line 21. Full traceback below:\nTraceback (most recent call last):\n File \"\", line 21, in predict_not_iterator\nRuntimeError: CUDA out of memory. Tried to allocate 288.00 MiB (GPU 0; 14.76 GiB total capacity; 97.73 MiB already allocated; 98.75 MiB free; 118.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF\n","metadata":{},"errorTraceType":"ansi","type":"ipynbError","arguments":{}}},"output_type":"display_data","data":{"text/plain":["\u001B[0;31m---------------------------------------------------------------------------\u001B[0m\n\u001B[0;31mPythonException\u001B[0m Traceback (most recent call last)\nFile \u001B[0;32m:3\u001B[0m\n\u001B[1;32m 1\u001B[0m start_time \u001B[38;5;241m=\u001B[39m time\u001B[38;5;241m.\u001B[39mtime()\n\u001B[1;32m 2\u001B[0m predictions\u001B[38;5;241m.\u001B[39mpersist() \u001B[38;5;66;03m# Persist is a lazy operation- need to also have the line below\u001B[39;00m\n\u001B[0;32m----> 3\u001B[0m predictions\u001B[38;5;241m.\u001B[39mwrite\u001B[38;5;241m.\u001B[39mmode(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124moverwrite\u001B[39m\u001B[38;5;124m\"\u001B[39m)\u001B[38;5;241m.\u001B[39mformat(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mnoop\u001B[39m\u001B[38;5;124m\"\u001B[39m)\u001B[38;5;241m.\u001B[39msave()\n\u001B[1;32m 4\u001B[0m end_time \u001B[38;5;241m=\u001B[39m time\u001B[38;5;241m.\u001B[39mtime()\n\u001B[1;32m 5\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mPrediction took: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mend_time\u001B[38;5;241m-\u001B[39mstart_time\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m seconds\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\nFile \u001B[0;32m/databricks/spark/python/pyspark/instrumentation_utils.py:48\u001B[0m, in \u001B[0;36m_wrap_function..wrapper\u001B[0;34m(*args, **kwargs)\u001B[0m\n\u001B[1;32m 46\u001B[0m start \u001B[38;5;241m=\u001B[39m time\u001B[38;5;241m.\u001B[39mperf_counter()\n\u001B[1;32m 47\u001B[0m 
\u001B[38;5;28misinstance\u001B[39m(converted, UnknownException):\n\u001B[1;32m 200\u001B[0m \u001B[38;5;66;03m# Hide where the exception came from that shows a non-Pythonic\u001B[39;00m\n\u001B[1;32m 201\u001B[0m \u001B[38;5;66;03m# JVM exception message.\u001B[39;00m\n\u001B[0;32m--> 202\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m converted \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;28mNone\u001B[39m\n\u001B[1;32m 203\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 204\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m\n\n\u001B[0;31mPythonException\u001B[0m: An exception was thrown from a UDF: 'RuntimeError: CUDA out of memory. Tried to allocate 288.00 MiB (GPU 0; 14.76 GiB total capacity; 97.73 MiB already allocated; 98.75 MiB free; 118.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF', from , line 21. Full traceback below:\nTraceback (most recent call last):\n File \"\", line 21, in predict_not_iterator\nRuntimeError: CUDA out of memory. Tried to allocate 288.00 MiB (GPU 0; 14.76 GiB total capacity; 97.73 MiB already allocated; 98.75 MiB free; 118.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF\n"]}}],"execution_count":0}],"metadata":{"application/vnd.databricks.v1+notebook":{"notebookName":"torch-batch-inference-dbfs-stage-level-scheduling","dashboards":[],"notebookMetadata":{"pythonIndentUnit":4},"language":"python","widgets":{},"notebookOrigID":848636385098271}},"nbformat":4,"nbformat_minor":0} 2 | -------------------------------------------------------------------------------- /spark/code/torch-batch-inference-300G-s3-standard.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"code","source":["import time\nimport torch\ntorch.__version__\ntorch.cuda.is_available()"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"701d2809-daab-4de6-979d-ff8c6aab47cf","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Out[1]: True","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Out[1]: True"]}}],"execution_count":0},{"cell_type":"code","source":["print(\"Executor memory: \", spark.conf.get(\"spark.executor.memory\"))"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"59841386-1fb4-4a0c-afbe-922f8f8049d7","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Executor memory: 148728m\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Executor memory: 148728m\n"]}}],"execution_count":0},{"cell_type":"code","source":["import pandas as pd\nfrom torchvision import transforms\nfrom torchvision.models import resnet50, 
ResNet50_Weights"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"9d5cb47e-49c5-4904-9df3-538972665541","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":{"text/plain":"","application/vnd.databricks.v1+bamboolib_hint":"{\"pd.DataFrames\": [], \"version\": \"0.0.1\"}"},"removedWidgets":[],"addedWidgets":{},"metadata":{"kernelSessionId":"859d58e2-8d3023023ccd2dba3d2ccc3b"},"type":"mimeBundle","arguments":{}}},"output_type":"display_data","data":{"text/plain":"","application/vnd.databricks.v1+bamboolib_hint":"{\"pd.DataFrames\": [], \"version\": \"0.0.1\"}"}}],"execution_count":0},{"cell_type":"code","source":["# Enable Arrow support.\nspark.conf.set(\"spark.sql.execution.arrow.enabled\", \"true\")"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"a4832578-6320-43f6-a462-bd8eb1ae454b","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["# Create and broadcast model state. Equivalent to AIR Checkpoint\nmodel_state = resnet50(weights=ResNet50_Weights.DEFAULT).state_dict()\n# sc is already initialized by Databricks. Broadcast the model state to all executors.\nbc_model_state = sc.broadcast(model_state)"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"b6229f4e-80ba-40ea-bac7-9d07bc2b84fe","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Downloading: \"https://download.pytorch.org/models/resnet50-11ad3fa6.pth\" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Downloading: \"https://download.pytorch.org/models/resnet50-11ad3fa6.pth\" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth\n"]}},{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":{"text/plain":" 0%| | 0.00/97.8M [00:00 pd.Series:\n preprocess = transforms.Compose(\n [\n transforms.Resize(256),\n transforms.CenterCrop(224),\n transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n ]\n )\n print(f\"number of images: {len(image)}\")\n # Spark has no tensor support, so it flattens the image tensor to a single array during read.\n # Each image is represented as a flattened numpy array.\n # We have to reshape back to the original number of dimensions.\n # Need to convert to float dtype otherwise torchvision transforms will complain. 
The data is read as short (int16) by default\n batch_dim = len(image)\n numpy_batch = np.stack(image.values)\n reshaped_images = numpy_batch.reshape(batch_dim, 256, 256, 3).astype(np.float)\n \n torch_tensor = torch.Tensor(reshaped_images.transpose(0, 3, 1, 2))\n preprocessed_images = preprocess(torch_tensor).numpy()\n # Arrow only works with single dimension numpy arrays, so need to flatten the array before outputting it\n preprocessed_images = [image.flatten() for image in preprocessed_images]\n return pd.Series(preprocessed_images)"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"ef059a6f-8c22-44ef-960c-518956975eb2","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["preprocessed_data = input_data.select(preprocess(col(\"image\")))"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"9af2d6fc-4fb8-4983-aca3-672381e96310","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["spark.conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"370ef5d5-55a1-4821-b428-414a62a36620","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["@pandas_udf(ArrayType(FloatType()))\ndef predict(preprocessed_images: pd.Series) -> pd.Series:\n with torch.inference_mode():\n model = resnet50()\n model.load_state_dict(bc_model_state.value)\n model = model.to(torch.device(\"cuda\")) # Move model to GPU\n model.eval()\n \n batch = preprocessed_images\n batch_dim = len(batch)\n numpy_batch = np.stack(batch.values)\n # Spark has no tensor support, so it flattens the image tensor to a single array during read.\n # Each image is represented as a flattened numpy array.\n # We have to reshape back to the original number of dimensions.\n reshaped_images = numpy_batch.reshape(batch_dim, 3, 224, 224)\n gpu_batch = torch.Tensor(reshaped_images).to(torch.device(\"cuda\"))\n predictions = list(model(gpu_batch).cpu().numpy())\n assert len(predictions) == batch_dim\n \n return pd.Series(predictions)"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"48c15069-1c2f-47c3-9b1c-6fe0eb5641e2","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["predictions = preprocessed_data.select(predict(col(\"preprocess(image)\")))"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"d30eb582-afe4-433c-9fee-9dc5c6f3cf86","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["start_time = time.time()\npredictions.write.mode(\"overwrite\").format(\"noop\").save()\nend_time = time.time()\nprint(f\"Prediction took: {end_time-start_time} seconds\")\n\n# 300 GB\nassert predictions.count() == 488207"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"505599b5-fdd0-4f55-bb9d-f17898858f47","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Prediction took: 708.4862368106842 seconds\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Prediction took: 708.4862368106842 seconds\n"]}}],"execution_count":0},{"cell_type":"code","source":["assert 
predictions.count() == 488207"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"84b8156d-513c-4417-be00-cb67620c00b7","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":[""],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"f01fcea5-3245-436d-af46-002d4c2e5bab","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0}],"metadata":{"application/vnd.databricks.v1+notebook":{"notebookName":"torch-batch-inference-300G-s3-standard","dashboards":[],"notebookMetadata":{"pythonIndentUnit":4},"language":"python","widgets":{},"notebookOrigID":3607923681779439}},"nbformat":4,"nbformat_minor":0} 2 | -------------------------------------------------------------------------------- /spark/code/torch-batch-inference-s3-10G-single-node.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"code","source":["import time\nimport torch\ntorch.__version__\ntorch.cuda.is_available()"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"701d2809-daab-4de6-979d-ff8c6aab47cf","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Out[1]: True","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Out[1]: True"]}}],"execution_count":0},{"cell_type":"code","source":["print(\"Executor memory: \", spark.conf.get(\"spark.executor.memory\"))"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"59841386-1fb4-4a0c-afbe-922f8f8049d7","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Executor memory: 199584m\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Executor memory: 199584m\n"]}}],"execution_count":0},{"cell_type":"code","source":["import pandas as pd\nfrom torchvision import transforms\nfrom torchvision.models import resnet50, ResNet50_Weights"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"9d5cb47e-49c5-4904-9df3-538972665541","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":{"text/plain":"","application/vnd.databricks.v1+bamboolib_hint":"{\"pd.DataFrames\": [], \"version\": \"0.0.1\"}"},"removedWidgets":[],"addedWidgets":{},"metadata":{"kernelSessionId":"8a06c5a7-c10a247bdebf13efb742f05e"},"type":"mimeBundle","arguments":{}}},"output_type":"display_data","data":{"text/plain":"","application/vnd.databricks.v1+bamboolib_hint":"{\"pd.DataFrames\": [], \"version\": \"0.0.1\"}"}}],"execution_count":0},{"cell_type":"code","source":["# Enable Arrow support.\nspark.conf.set(\"spark.sql.execution.arrow.enabled\", \"true\")"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"a4832578-6320-43f6-a462-bd8eb1ae454b","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["# Create and broadcast model state to all executors\nmodel_state = resnet50(weights=ResNet50_Weights.DEFAULT).state_dict()\nbc_model_state = 
sc.broadcast(model_state)"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"b6229f4e-80ba-40ea-bac7-9d07bc2b84fe","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Downloading: \"https://download.pytorch.org/models/resnet50-11ad3fa6.pth\" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Downloading: \"https://download.pytorch.org/models/resnet50-11ad3fa6.pth\" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth\n"]}},{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":{"text/plain":" 0%| | 0.00/97.8M [00:00 pd.Series:\n preprocess = transforms.Compose(\n [\n transforms.Resize(256),\n transforms.CenterCrop(224),\n transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n ]\n )\n print(f\"number of images: {len(image)}\")\n # Spark has no tensor support, so it flattens the image tensor to a single array during read.\n # Each image is represented as a flattened numpy array.\n # We have to reshape back to the original number of dimensions.\n # Need to convert to float dtype otherwise torchvision transforms will complain. The data is read as short (int16) by default\n batch_dim = len(image)\n numpy_batch = np.stack(image.values)\n reshaped_images = numpy_batch.reshape(batch_dim, 256, 256, 3).astype(np.float)\n \n torch_tensor = torch.Tensor(reshaped_images.transpose(0, 3, 1, 2))\n preprocessed_images = preprocess(torch_tensor).numpy()\n # Arrow only works with single dimension numpy arrays, so need to flatten the array before outputting it\n preprocessed_images = [image.flatten() for image in preprocessed_images]\n return pd.Series(preprocessed_images)"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"ef059a6f-8c22-44ef-960c-518956975eb2","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["preprocessed_data = input_data.select(preprocess(col(\"image\")))"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"9af2d6fc-4fb8-4983-aca3-672381e96310","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["# 1000 is the largest batch size that can fit on GPU. 
Limit batch size to 1000 to avoid CUDA OOM.\nspark.conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"370ef5d5-55a1-4821-b428-414a62a36620","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["@pandas_udf(ArrayType(FloatType()))\ndef predict(preprocessed_images: pd.Series) -> pd.Series:\n with torch.inference_mode():\n model = resnet50()\n model.load_state_dict(bc_model_state.value)\n model = model.to(torch.device(\"cuda\")) # Move model to GPU\n model.eval()\n \n batch = preprocessed_images\n batch_dim = len(batch)\n numpy_batch = np.stack(batch.values)\n # Spark has no tensor support, so it flattens the image tensor to a single array during read.\n # Each image is represented as a flattened numpy array.\n # We have to reshape back to the original number of dimensions.\n reshaped_images = numpy_batch.reshape(batch_dim, 3, 224, 224)\n gpu_batch = torch.Tensor(reshaped_images).to(torch.device(\"cuda\"))\n predictions = list(model(gpu_batch).cpu().numpy())\n assert len(predictions) == batch_dim\n \n return pd.Series(predictions)"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"48c15069-1c2f-47c3-9b1c-6fe0eb5641e2","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["# Repartition to 1 to limit parallelism for GPU prediction.\n# Single node clusters do not have GPU scheduling support.\n# Otherwise 32 tasks will run in parallel causing Cuda OOM.\none_partition_data = preprocessed_data.repartition(1)\npredictions = one_partition_data.select(predict(col(\"preprocess(image)\")))"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"d30eb582-afe4-433c-9fee-9dc5c6f3cf86","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["start_time = time.time()\npredictions.write.mode(\"overwrite\").format(\"noop\").save()\nend_time = time.time()\nprint(f\"Prediction took: {end_time-start_time} seconds\")\n\nassert preprocessed_data.count() == 16232"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"505599b5-fdd0-4f55-bb9d-f17898858f47","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Prediction took: 137.95954656600952 seconds\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Prediction took: 137.95954656600952 seconds\n"]}}],"execution_count":0}],"metadata":{"application/vnd.databricks.v1+notebook":{"notebookName":"torch-batch-inference-s3-10G-single-node","dashboards":[],"notebookMetadata":{"pythonIndentUnit":4},"language":"python","widgets":{},"notebookOrigID":566047737056949}},"nbformat":4,"nbformat_minor":0} 2 | -------------------------------------------------------------------------------- /spark/code/torch-batch-inference-s3-10G-standard-iterator-databricks-prefetch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 0, 6 | "metadata": { 7 | "application/vnd.databricks.v1+cell": { 8 | "cellMetadata": { 9 | "byteLimit": 2048000, 10 | "rowLimit": 10000 11 | }, 12 | "inputWidgets": {}, 13 | "nuid": "e74a8d6d-e81d-4e60-a239-d03eee92e924", 14 | 
"showTitle": false, 15 | "title": "" 16 | } 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "input_data = spark.read.format(\"parquet\").load(\"s3://air-example-data-2/10G-image-data-synthetic-raw-parquet\")" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 0, 26 | "metadata": { 27 | "application/vnd.databricks.v1+cell": { 28 | "cellMetadata": { 29 | "byteLimit": 2048000, 30 | "rowLimit": 10000 31 | }, 32 | "inputWidgets": {}, 33 | "nuid": "701d2809-daab-4de6-979d-ff8c6aab47cf", 34 | "showTitle": false, 35 | "title": "" 36 | } 37 | }, 38 | "outputs": [ 39 | { 40 | "output_type": "execute_result", 41 | "data": { 42 | "text/plain": [ 43 | "True" 44 | ] 45 | }, 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | "import time\n", 53 | "import torch\n", 54 | "torch.__version__\n", 55 | "torch.cuda.is_available()" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 0, 61 | "metadata": { 62 | "application/vnd.databricks.v1+cell": { 63 | "cellMetadata": { 64 | "byteLimit": 2048000, 65 | "rowLimit": 10000 66 | }, 67 | "inputWidgets": {}, 68 | "nuid": "59841386-1fb4-4a0c-afbe-922f8f8049d7", 69 | "showTitle": false, 70 | "title": "" 71 | } 72 | }, 73 | "outputs": [ 74 | { 75 | "output_type": "stream", 76 | "name": "stdout", 77 | "output_type": "stream", 78 | "text": [ 79 | "Executor memory: 148728m\n" 80 | ] 81 | } 82 | ], 83 | "source": [ 84 | "print(\"Executor memory: \", spark.conf.get(\"spark.executor.memory\"))" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 0, 90 | "metadata": { 91 | "application/vnd.databricks.v1+cell": { 92 | "cellMetadata": { 93 | "byteLimit": 2048000, 94 | "rowLimit": 10000 95 | }, 96 | "inputWidgets": {}, 97 | "nuid": "9d5cb47e-49c5-4904-9df3-538972665541", 98 | "showTitle": false, 99 | "title": "" 100 | } 101 | }, 102 | "outputs": [ 103 | { 104 | "output_type": "display_data", 105 | "data": { 106 | "application/vnd.databricks.v1+bamboolib_hint": "{\"pd.DataFrames\": [], \"version\": \"0.0.1\"}", 107 | "text/plain": [] 108 | }, 109 | "metadata": {}, 110 | "output_type": "display_data" 111 | } 112 | ], 113 | "source": [ 114 | "import pandas as pd\n", 115 | "from torchvision import transforms\n", 116 | "from torchvision.models import resnet50, ResNet50_Weights" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 0, 122 | "metadata": { 123 | "application/vnd.databricks.v1+cell": { 124 | "cellMetadata": { 125 | "byteLimit": 2048000, 126 | "rowLimit": 10000 127 | }, 128 | "inputWidgets": {}, 129 | "nuid": "a4832578-6320-43f6-a462-bd8eb1ae454b", 130 | "showTitle": false, 131 | "title": "" 132 | } 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "# Enable Arrow support.\n", 137 | "spark.conf.set(\"spark.sql.execution.arrow.enabled\", \"true\")" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 0, 143 | "metadata": { 144 | "application/vnd.databricks.v1+cell": { 145 | "cellMetadata": { 146 | "byteLimit": 2048000, 147 | "rowLimit": 10000 148 | }, 149 | "inputWidgets": {}, 150 | "nuid": "b6229f4e-80ba-40ea-bac7-9d07bc2b84fe", 151 | "showTitle": false, 152 | "title": "" 153 | } 154 | }, 155 | "outputs": [ 156 | { 157 | "output_type": "stream", 158 | "name": "stderr", 159 | "output_type": "stream", 160 | "text": [ 161 | "Downloading: \"https://download.pytorch.org/models/resnet50-11ad3fa6.pth\" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth\n" 162 | ] 163 | }, 164 | { 165 | 
"output_type": "display_data", 166 | "data": { 167 | "application/vnd.jupyter.widget-view+json": { 168 | "model_id": "34b9a7842b4041a3b66ebd1cd81b60cf", 169 | "version_major": 2, 170 | "version_minor": 0 171 | }, 172 | "text/plain": [ 173 | " 0%| | 0.00/97.8M [00:00 Iterator[pd.Series]:\n", 236 | " with torch.inference_mode():\n", 237 | " model = resnet50()\n", 238 | " model.load_state_dict(bc_model_state.value)\n", 239 | " model = model.to(torch.device(\"cuda\")) # Move model to GPU\n", 240 | " model.eval()\n", 241 | "\n", 242 | " for pandas_series in pandas_series_iter:\n", 243 | " image_batch = torch.tensor(np.stack(pandas_series.values).astype(np.uint8))\n", 244 | " # change uint 0 ~ 255 range values to 0 ~ 1 range float32 values\n", 245 | " image_batch = image_batch / np.float32(256)\n", 246 | "\n", 247 | " image_batch = preprocess(image_batch)\n", 248 | " \n", 249 | " \n", 250 | " image_batch = image_batch.to(torch.device(\"cuda\"))\n", 251 | " \n", 252 | " predictions = list(model(image_batch).cpu().numpy())\n", 253 | " \n", 254 | " yield pd.Series(predictions)\n" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 0, 260 | "metadata": { 261 | "application/vnd.databricks.v1+cell": { 262 | "cellMetadata": { 263 | "byteLimit": 2048000, 264 | "rowLimit": 10000 265 | }, 266 | "inputWidgets": {}, 267 | "nuid": "9af2d6fc-4fb8-4983-aca3-672381e96310", 268 | "showTitle": false, 269 | "title": "" 270 | } 271 | }, 272 | "outputs": [], 273 | "source": [ 274 | "predictions = input_data.select(resnet_predict(col(\"image\")))" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 0, 280 | "metadata": { 281 | "application/vnd.databricks.v1+cell": { 282 | "cellMetadata": { 283 | "byteLimit": 2048000, 284 | "rowLimit": 10000 285 | }, 286 | "inputWidgets": {}, 287 | "nuid": "370ef5d5-55a1-4821-b428-414a62a36620", 288 | "showTitle": false, 289 | "title": "" 290 | } 291 | }, 292 | "outputs": [], 293 | "source": [ 294 | "spark.conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"64\")" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 0, 300 | "metadata": { 301 | "application/vnd.databricks.v1+cell": { 302 | "cellMetadata": { 303 | "byteLimit": 2048000, 304 | "rowLimit": 10000 305 | }, 306 | "inputWidgets": {}, 307 | "nuid": "505599b5-fdd0-4f55-bb9d-f17898858f47", 308 | "showTitle": false, 309 | "title": "" 310 | } 311 | }, 312 | "outputs": [], 313 | "source": [ 314 | "\n", 315 | "run_times = []\n", 316 | "run_N = 6\n", 317 | "for i in range(run_N):\n", 318 | " start_time = time.time()\n", 319 | " predictions.write.mode(\"overwrite\").format(\"noop\").save()\n", 320 | " end_time = time.time()\n", 321 | " run_times.append(end_time-start_time)\n", 322 | "\n", 323 | "assert input_data.count() == 16232" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 0, 329 | "metadata": { 330 | "application/vnd.databricks.v1+cell": { 331 | "cellMetadata": { 332 | "byteLimit": 2048000, 333 | "rowLimit": 10000 334 | }, 335 | "inputWidgets": {}, 336 | "nuid": "d3e66381-0471-40e8-939b-a8de901ec781", 337 | "showTitle": false, 338 | "title": "" 339 | } 340 | }, 341 | "outputs": [ 342 | { 343 | "output_type": "stream", 344 | "name": "stdout", 345 | "output_type": "stream", 346 | "text": [ 347 | "Run times: [126.96703696250916, 106.93883275985718, 99.0119571685791, 95.7514579296112, 90.47398567199707, 90.08183360099792]\nAverge Prediction took: 101.53751734892528 seconds\n" 348 | ] 349 | } 350 | ], 351 | "source": [ 352 
| "print(\"Run times: \", run_times)\n", 353 | "print(f\"Averge Prediction took: {sum(run_times) / run_N} seconds\")" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 0, 359 | "metadata": { 360 | "application/vnd.databricks.v1+cell": { 361 | "cellMetadata": { 362 | "byteLimit": 2048000, 363 | "rowLimit": 10000 364 | }, 365 | "inputWidgets": {}, 366 | "nuid": "4436de79-835a-4fb9-acc1-8630768561ec", 367 | "showTitle": false, 368 | "title": "" 369 | } 370 | }, 371 | "outputs": [ 372 | { 373 | "output_type": "execute_result", 374 | "data": { 375 | "text/html": [ 376 | "\n", 377 | "
\n", 378 | "

SparkSession - hive

\n", 379 | " \n", 380 | "
\n", 381 | "

SparkContext

\n", 382 | "\n", 383 | "

Spark UI

\n", 384 | "\n", 385 | "
\n", 386 | "
Version
\n", 387 | "
v3.4.0
\n", 388 | "
Master
\n", 389 | "
spark://10.146.232.191:7077
\n", 390 | "
AppName
\n", 391 | "
Databricks Shell
\n", 392 | "
\n", 393 | "
\n", 394 | " \n", 395 | "
\n", 396 | " " 397 | ], 398 | "text/plain": [ 399 | "" 400 | ] 401 | }, 402 | "execution_count": 12, 403 | "metadata": {}, 404 | "output_type": "execute_result" 405 | } 406 | ], 407 | "source": [ 408 | "spark" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 0, 414 | "metadata": { 415 | "application/vnd.databricks.v1+cell": { 416 | "cellMetadata": {}, 417 | "inputWidgets": {}, 418 | "nuid": "5e42afc3-a845-4bf2-a6ac-e0fba290dbcd", 419 | "showTitle": false, 420 | "title": "" 421 | } 422 | }, 423 | "outputs": [], 424 | "source": [] 425 | } 426 | ], 427 | "metadata": { 428 | "application/vnd.databricks.v1+notebook": { 429 | "dashboards": [], 430 | "language": "python", 431 | "notebookMetadata": { 432 | "mostRecentlyExecutedCommandWithImplicitDF": { 433 | "commandId": 2878901766425688, 434 | "dataframes": [ 435 | "_sqldf" 436 | ] 437 | }, 438 | "pythonIndentUnit": 4 439 | }, 440 | "notebookName": "(databricks modified) torch-batch-inference-s3-10G-standard-iterator", 441 | "notebookOrigID": 2252021368504006, 442 | "widgets": {} 443 | } 444 | }, 445 | "nbformat": 4, 446 | "nbformat_minor": 0 447 | } 448 | -------------------------------------------------------------------------------- /spark/code/torch-batch-inference-s3-10G-standard-iterator.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 0, 6 | "metadata": { 7 | "application/vnd.databricks.v1+cell": { 8 | "cellMetadata": { 9 | "byteLimit": 2048000, 10 | "rowLimit": 10000 11 | }, 12 | "inputWidgets": {}, 13 | "nuid": "701d2809-daab-4de6-979d-ff8c6aab47cf", 14 | "showTitle": false, 15 | "title": "" 16 | } 17 | }, 18 | "outputs": [ 19 | { 20 | "output_type": "stream", 21 | "name": "stdout", 22 | "output_type": "stream", 23 | "text": [ 24 | "Out[1]: True" 25 | ] 26 | } 27 | ], 28 | "source": [ 29 | "import time\n", 30 | "import torch\n", 31 | "torch.__version__\n", 32 | "torch.cuda.is_available()" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 0, 38 | "metadata": { 39 | "application/vnd.databricks.v1+cell": { 40 | "cellMetadata": { 41 | "byteLimit": 2048000, 42 | "rowLimit": 10000 43 | }, 44 | "inputWidgets": {}, 45 | "nuid": "59841386-1fb4-4a0c-afbe-922f8f8049d7", 46 | "showTitle": false, 47 | "title": "" 48 | } 49 | }, 50 | "outputs": [ 51 | { 52 | "output_type": "stream", 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "Executor memory: 148728m\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "#print(\"Profiling enabled: \", spark.conf.get(\"spark.python.profile\"))\n", 62 | "print(\"Executor memory: \", spark.conf.get(\"spark.executor.memory\"))" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 0, 68 | "metadata": { 69 | "application/vnd.databricks.v1+cell": { 70 | "cellMetadata": { 71 | "byteLimit": 2048000, 72 | "rowLimit": 10000 73 | }, 74 | "inputWidgets": {}, 75 | "nuid": "9d5cb47e-49c5-4904-9df3-538972665541", 76 | "showTitle": false, 77 | "title": "" 78 | } 79 | }, 80 | "outputs": [ 81 | { 82 | "output_type": "display_data", 83 | "data": { 84 | "application/vnd.databricks.v1+bamboolib_hint": "{\"pd.DataFrames\": [], \"version\": \"0.0.1\"}", 85 | "text/plain": [] 86 | }, 87 | "metadata": {}, 88 | "output_type": "display_data" 89 | } 90 | ], 91 | "source": [ 92 | "import pandas as pd\n", 93 | "from torchvision import transforms\n", 94 | "from torchvision.models import resnet50, ResNet50_Weights" 95 | ] 96 | }, 97 | { 98 | 
"cell_type": "code", 99 | "execution_count": 0, 100 | "metadata": { 101 | "application/vnd.databricks.v1+cell": { 102 | "cellMetadata": { 103 | "byteLimit": 2048000, 104 | "rowLimit": 10000 105 | }, 106 | "inputWidgets": {}, 107 | "nuid": "a4832578-6320-43f6-a462-bd8eb1ae454b", 108 | "showTitle": false, 109 | "title": "" 110 | } 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "# Enable Arrow support.\n", 115 | "spark.conf.set(\"spark.sql.execution.arrow.enabled\", \"true\")" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 0, 121 | "metadata": { 122 | "application/vnd.databricks.v1+cell": { 123 | "cellMetadata": { 124 | "byteLimit": 2048000, 125 | "rowLimit": 10000 126 | }, 127 | "inputWidgets": {}, 128 | "nuid": "b6229f4e-80ba-40ea-bac7-9d07bc2b84fe", 129 | "showTitle": false, 130 | "title": "" 131 | } 132 | }, 133 | "outputs": [ 134 | { 135 | "output_type": "stream", 136 | "name": "stdout", 137 | "output_type": "stream", 138 | "text": [ 139 | "Downloading: \"https://download.pytorch.org/models/resnet50-11ad3fa6.pth\" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth\n" 140 | ] 141 | }, 142 | { 143 | "output_type": "display_data", 144 | "data": { 145 | "application/vnd.jupyter.widget-view+json": { 146 | "model_id": "a401f46c6c434755b5c19acf7a119c10", 147 | "version_major": 2, 148 | "version_minor": 0 149 | }, 150 | "text/plain": [ 151 | " 0%| | 0.00/97.8M [00:00 Iterator[pd.Series]:\n", 216 | " preprocess = transforms.Compose(\n", 217 | " [\n", 218 | " transforms.Resize(256),\n", 219 | " transforms.CenterCrop(224),\n", 220 | " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", 221 | " ]\n", 222 | " )\n", 223 | " for image in image_iter:\n", 224 | " print(f\"number of images: {len(image)}\")\n", 225 | " # Spark has no tensor support, so it flattens the image tensor to a single array during read.\n", 226 | " # Each image is represented as a flattened numpy array.\n", 227 | " # We have to reshape back to the original number of dimensions.\n", 228 | " # Need to convert to float dtype otherwise torchvision transforms will complain. 
The data is read as short (int16) by default\n", 229 | " batch_dim = len(image)\n", 230 | " numpy_batch = np.stack(image.values)\n", 231 | " reshaped_images = numpy_batch.reshape(batch_dim, 256, 256, 3).astype(np.float)\n", 232 | " \n", 233 | " torch_tensor = torch.Tensor(reshaped_images.transpose(0, 3, 1, 2))\n", 234 | " preprocessed_images = preprocess(torch_tensor).numpy()\n", 235 | " # Arrow only works with single dimension numpy arrays, so need to flatten the array before outputting it\n", 236 | " preprocessed_images = [image.flatten() for image in preprocessed_images]\n", 237 | " yield pd.Series(preprocessed_images)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 0, 243 | "metadata": { 244 | "application/vnd.databricks.v1+cell": { 245 | "cellMetadata": { 246 | "byteLimit": 2048000, 247 | "rowLimit": 10000 248 | }, 249 | "inputWidgets": {}, 250 | "nuid": "9af2d6fc-4fb8-4983-aca3-672381e96310", 251 | "showTitle": false, 252 | "title": "" 253 | } 254 | }, 255 | "outputs": [], 256 | "source": [ 257 | "preprocessed_data = input_data.select(preprocess(col(\"image\")))" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 0, 263 | "metadata": { 264 | "application/vnd.databricks.v1+cell": { 265 | "cellMetadata": { 266 | "byteLimit": 2048000, 267 | "rowLimit": 10000 268 | }, 269 | "inputWidgets": {}, 270 | "nuid": "370ef5d5-55a1-4821-b428-414a62a36620", 271 | "showTitle": false, 272 | "title": "" 273 | } 274 | }, 275 | "outputs": [], 276 | "source": [ 277 | "# 1000 is the largest batch size that can fit on GPU. Limit batch size to 1000 to avoid CUDA OOM.\n", 278 | "spark.conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 0, 284 | "metadata": { 285 | "application/vnd.databricks.v1+cell": { 286 | "cellMetadata": { 287 | "byteLimit": 2048000, 288 | "rowLimit": 10000 289 | }, 290 | "inputWidgets": {}, 291 | "nuid": "48c15069-1c2f-47c3-9b1c-6fe0eb5641e2", 292 | "showTitle": false, 293 | "title": "" 294 | } 295 | }, 296 | "outputs": [], 297 | "source": [ 298 | "from typing import Iterator\n", 299 | "\n", 300 | "@pandas_udf(ArrayType(FloatType()))\n", 301 | "def predict(preprocessed_images_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:\n", 302 | " with torch.inference_mode():\n", 303 | " model = resnet50()\n", 304 | " model.load_state_dict(bc_model_state.value)\n", 305 | " model = model.to(torch.device(\"cuda\")) # Move model to GPU\n", 306 | " model.eval()\n", 307 | " \n", 308 | " for preprocessed_images in preprocessed_images_iter:\n", 309 | " batch = preprocessed_images\n", 310 | " batch_dim = len(batch)\n", 311 | " numpy_batch = np.stack(batch.values)\n", 312 | " # Spark has no tensor support, so it flattens the image tensor to a single array during read.\n", 313 | " # Each image is represented as a flattened numpy array.\n", 314 | " # We have to reshape back to the original number of dimensions.\n", 315 | " reshaped_images = numpy_batch.reshape(batch_dim, 3, 224, 224)\n", 316 | " gpu_batch = torch.Tensor(reshaped_images).to(torch.device(\"cuda\"))\n", 317 | " predictions = list(model(gpu_batch).cpu().numpy())\n", 318 | " assert len(predictions) == batch_dim\n", 319 | " \n", 320 | " yield pd.Series(predictions)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 0, 326 | "metadata": { 327 | "application/vnd.databricks.v1+cell": { 328 | "cellMetadata": { 329 | "byteLimit": 2048000, 330 | "rowLimit": 10000 331 | }, 332 | 
"inputWidgets": {}, 333 | "nuid": "d30eb582-afe4-433c-9fee-9dc5c6f3cf86", 334 | "showTitle": false, 335 | "title": "" 336 | } 337 | }, 338 | "outputs": [], 339 | "source": [ 340 | "predictions = preprocessed_data.select(predict(col(\"preprocess(image)\")))" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 0, 346 | "metadata": { 347 | "application/vnd.databricks.v1+cell": { 348 | "cellMetadata": { 349 | "byteLimit": 2048000, 350 | "rowLimit": 10000 351 | }, 352 | "inputWidgets": {}, 353 | "nuid": "505599b5-fdd0-4f55-bb9d-f17898858f47", 354 | "showTitle": false, 355 | "title": "" 356 | } 357 | }, 358 | "outputs": [ 359 | { 360 | "output_type": "stream", 361 | "name": "stdout", 362 | "output_type": "stream", 363 | "text": [ 364 | "Prediction took: 143.19921851158142 seconds\n" 365 | ] 366 | } 367 | ], 368 | "source": [ 369 | "start_time = time.time()\n", 370 | "predictions.write.mode(\"overwrite\").format(\"noop\").save()\n", 371 | "end_time = time.time()\n", 372 | "print(f\"Prediction took: {end_time-start_time} seconds\")\n", 373 | "\n", 374 | "assert preprocessed_data.count() == 16232" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": 0, 380 | "metadata": { 381 | "application/vnd.databricks.v1+cell": { 382 | "cellMetadata": {}, 383 | "inputWidgets": {}, 384 | "nuid": "8da01873-5803-4813-b561-0fb6819f9dc8", 385 | "showTitle": false, 386 | "title": "" 387 | } 388 | }, 389 | "outputs": [], 390 | "source": [] 391 | } 392 | ], 393 | "metadata": { 394 | "application/vnd.databricks.v1+notebook": { 395 | "dashboards": [], 396 | "language": "python", 397 | "notebookMetadata": { 398 | "pythonIndentUnit": 4 399 | }, 400 | "notebookName": "torch-batch-inference-s3-10G-standard-iterator", 401 | "notebookOrigID": 4465860287893945, 402 | "widgets": {} 403 | } 404 | }, 405 | "nbformat": 4, 406 | "nbformat_minor": 0 407 | } 408 | -------------------------------------------------------------------------------- /spark/code/torch-batch-inference-s3-10G-standard.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"code","source":["import time\nimport torch\ntorch.__version__\ntorch.cuda.is_available()"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"701d2809-daab-4de6-979d-ff8c6aab47cf","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Out[1]: True","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Out[1]: True"]}}],"execution_count":0},{"cell_type":"code","source":["#print(\"Profiling enabled: \", spark.conf.get(\"spark.python.profile\"))\nprint(\"Executor memory: \", spark.conf.get(\"spark.executor.memory\"))"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"59841386-1fb4-4a0c-afbe-922f8f8049d7","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Executor memory: 148728m\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Executor memory: 148728m\n"]}}],"execution_count":0},{"cell_type":"code","source":["import pandas as pd\nfrom torchvision import transforms\nfrom torchvision.models import resnet50, 
ResNet50_Weights"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"9d5cb47e-49c5-4904-9df3-538972665541","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":{"text/plain":"","application/vnd.databricks.v1+bamboolib_hint":"{\"pd.DataFrames\": [], \"version\": \"0.0.1\"}"},"removedWidgets":[],"addedWidgets":{},"metadata":{"kernelSessionId":"bb7346f4-55fd67bf8d59a800bcda876e"},"type":"mimeBundle","arguments":{}}},"output_type":"display_data","data":{"text/plain":"","application/vnd.databricks.v1+bamboolib_hint":"{\"pd.DataFrames\": [], \"version\": \"0.0.1\"}"}}],"execution_count":0},{"cell_type":"code","source":["# Enable Arrow support.\nspark.conf.set(\"spark.sql.execution.arrow.enabled\", \"true\")"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"a4832578-6320-43f6-a462-bd8eb1ae454b","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["# Create and broadcast model state. Equivalent to AIR Checkpoint\nmodel_state = resnet50(weights=ResNet50_Weights.DEFAULT).state_dict()\n# sc is already initialized by Databricks. Broadcast the model state to all executors.\nbc_model_state = sc.broadcast(model_state)"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"b6229f4e-80ba-40ea-bac7-9d07bc2b84fe","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Downloading: \"https://download.pytorch.org/models/resnet50-11ad3fa6.pth\" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Downloading: \"https://download.pytorch.org/models/resnet50-11ad3fa6.pth\" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth\n"]}},{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":{"text/plain":" 0%| | 0.00/97.8M [00:00 pd.Series:\n preprocess = transforms.Compose(\n [\n transforms.Resize(256),\n transforms.CenterCrop(224),\n transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n ]\n )\n print(f\"number of images: {len(image)}\")\n # Spark has no tensor support, so it flattens the image tensor to a single array during read.\n # Each image is represented as a flattened numpy array.\n # We have to reshape back to the original number of dimensions.\n # Need to convert to float dtype otherwise torchvision transforms will complain. 
The data is read as short (int16) by default\n batch_dim = len(image)\n numpy_batch = np.stack(image.values)\n reshaped_images = numpy_batch.reshape(batch_dim, 256, 256, 3).astype(np.float)\n \n torch_tensor = torch.Tensor(reshaped_images.transpose(0, 3, 1, 2))\n preprocessed_images = preprocess(torch_tensor).numpy()\n # Arrow only works with single dimension numpy arrays, so need to flatten the array before outputting it\n preprocessed_images = [image.flatten() for image in preprocessed_images]\n return pd.Series(preprocessed_images)"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"ef059a6f-8c22-44ef-960c-518956975eb2","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["preprocessed_data = input_data.select(preprocess(col(\"image\")))"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"9af2d6fc-4fb8-4983-aca3-672381e96310","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["# 1000 is the largest batch size that can fit on GPU. Limit batch size to 1000 to avoid CUDA OOM.\nspark.conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"370ef5d5-55a1-4821-b428-414a62a36620","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["@pandas_udf(ArrayType(FloatType()))\ndef predict(preprocessed_images: pd.Series) -> pd.Series:\n with torch.inference_mode():\n model = resnet50()\n model.load_state_dict(bc_model_state.value)\n model = model.to(torch.device(\"cuda\")) # Move model to GPU\n model.eval()\n \n batch = preprocessed_images\n batch_dim = len(batch)\n numpy_batch = np.stack(batch.values)\n # Spark has no tensor support, so it flattens the image tensor to a single array during read.\n # Each image is represented as a flattened numpy array.\n # We have to reshape back to the original number of dimensions.\n reshaped_images = numpy_batch.reshape(batch_dim, 3, 224, 224)\n gpu_batch = torch.Tensor(reshaped_images).to(torch.device(\"cuda\"))\n predictions = list(model(gpu_batch).cpu().numpy())\n assert len(predictions) == batch_dim\n \n return pd.Series(predictions)"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"48c15069-1c2f-47c3-9b1c-6fe0eb5641e2","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["predictions = preprocessed_data.select(predict(col(\"preprocess(image)\")))"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"d30eb582-afe4-433c-9fee-9dc5c6f3cf86","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0},{"cell_type":"code","source":["start_time = time.time()\npredictions.write.mode(\"overwrite\").format(\"noop\").save()\nend_time = time.time()\nprint(f\"Prediction took: {end_time-start_time} seconds\")\n\nassert preprocessed_data.count() == 16232"],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"505599b5-fdd0-4f55-bb9d-f17898858f47","inputWidgets":{},"title":""}},"outputs":[{"output_type":"display_data","metadata":{"application/vnd.databricks.v1+output":{"datasetInfos":[],"data":"Prediction took: 109.78821015357971 seconds\n","removedWidgets":[],"addedWidgets":{},"metadata":{},"type":"ansi","arguments":{}}},"output_type":"display_data","data":{"text/plain":["Prediction 
took: 109.78821015357971 seconds\n"]}}],"execution_count":0},{"cell_type":"code","source":[""],"metadata":{"application/vnd.databricks.v1+cell":{"showTitle":false,"cellMetadata":{},"nuid":"8da01873-5803-4813-b561-0fb6819f9dc8","inputWidgets":{},"title":""}},"outputs":[],"execution_count":0}],"metadata":{"application/vnd.databricks.v1+notebook":{"notebookName":"torch-batch-inference-s3-10G-standard","dashboards":[],"notebookMetadata":{"pythonIndentUnit":4},"language":"python","widgets":{},"notebookOrigID":566047737057004}},"nbformat":4,"nbformat_minor":0} 2 | --------------------------------------------------------------------------------