├── .gitignore ├── README.md └── code ├── Ch02 ├── end_of_chapter.py └── word_count_submit.py ├── Ch03 ├── word_count.py └── word_count_submit.py ├── Ch04 ├── checkpoint.py └── commercials.py ├── Ch05 ├── checkpoint.py └── commercials.py ├── Ch07 ├── download_backblaze_data.py └── most_reliable_drives.py └── Ch14 ├── custom_feature.py ├── data_prep.py ├── hasInputCol.py └── read_write.py /.gitignore: -------------------------------------------------------------------------------- 1 | # We don't want the PDF to take up useless space 2 | *.pdf 3 | 4 | # Mac-specific useless files 5 | .DS_Store 6 | 7 | # Emacs temporary files 8 | .#* 9 | 10 | # Python temporary files (cache and helpers) 11 | .mypy_cache/ 12 | 13 | # Relevant stuff for paperwork 14 | paperwork/ 15 | 16 | # Early drawings 17 | drawings/ 18 | 19 | # Chapter revisions 20 | manuscript/chapter_revisions 21 | 22 | # Large data files 23 | data/broadcast_logs/BroadcastLogs_2018_Q3_M8.CSV 24 | 25 | # Scribbles 26 | scribbles.* 27 | 28 | # Data files from Chapter 7 that are too big to commit to git. Download them using the script instead. 29 | /data/backblaze/*.zip 30 | /data/backblaze/d* 31 | 32 | # References for chapters and appendixes 33 | chapter_reference.txt 34 | appendix_reference.txt 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Data Analysis with Python and PySpark 2 | 3 | This is the companion repository for the _Data Analysis with Python and PySpark_ 4 | book (Manning, 2022). It contains the source 5 | code and, where pertinent, the data download scripts. 6 | 7 | ## Get the data 8 | 9 | The complete data set for the book is around 1 GB. Because of this, [I 10 | moved the data sources to another repository]( 11 | https://github.com/jonesberg/DataAnalysisWithPythonAndPySpark-Data) to 12 | avoid cloning a gigantic repository just to get the code. The book assumes the data is under 13 | `./data`. 14 | 15 | ## Mistakes or omissions 16 | 17 | If you encounter mistakes in the book manuscript (including the printed source 18 | code), please use the Manning platform to provide feedback. 19 | -------------------------------------------------------------------------------- /code/Ch02/end_of_chapter.py: -------------------------------------------------------------------------------- 1 | # end_of_chapter.py############################################################ 2 | # 3 | # Use this to get a free pass from Chapter 2 to Chapter 3. 4 | # 5 | # Remember, with great power comes great responsibility. Make sure you 6 | # understand the code before running it! If necessary, refer to the text in 7 | # Chapter 2.
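# (A usage sketch, not from the book: assuming PySpark is installed and the data
# repository has been cloned under ./data at the repository root, this script can
# be run from the code/Ch02 directory with `python end_of_chapter.py` or
# `spark-submit end_of_chapter.py`.)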
8 | # 9 | ############################################################################### 10 | 11 | from pyspark.sql import SparkSession 12 | from pyspark.sql.functions import col, split, explode, lower, regexp_extract 13 | 14 | spark = SparkSession.builder.getOrCreate() 15 | 16 | book = spark.read.text("../../data/gutenberg_books/1342-0.txt") 17 | 18 | lines = book.select(split(book.value, " ").alias("line")) 19 | 20 | words = lines.select(explode(col("line")).alias("word")) 21 | 22 | words_lower = words.select(lower(col("word")).alias("word_lower")) 23 | 24 | words_clean = words_lower.select( 25 | regexp_extract(col("word_lower"), "[a-z]*", 0).alias("word") 26 | ) 27 | 28 | words_nonull = words_clean.where(col("word") != "") 29 | -------------------------------------------------------------------------------- /code/Ch02/word_count_submit.py: -------------------------------------------------------------------------------- 1 | from pyspark.sql import SparkSession 2 | import pyspark.sql.functions as F 3 | 4 | 5 | spark = SparkSession.builder.appName( 6 | "Counting word occurrences from a book." 7 | ).getOrCreate() 8 | 9 | spark.sparkContext.setLogLevel("WARN") 10 | 11 | # If you need to read multiple text files, replace `1342-0` by `*`. 12 | results = ( 13 | spark.read.text("../../data/gutenberg_books/1342-0.txt") 14 | .select(F.split(F.col("value"), " ").alias("line")) 15 | .select(F.explode(F.col("line")).alias("word")) 16 | .select(F.lower(F.col("word")).alias("word")) 17 | .select(F.regexp_extract(F.col("word"), "[a-z']*", 0).alias("word")) 18 | .where(F.col("word") != "") 19 | .groupby(F.col("word")) 20 | .count() 21 | ) 22 | 23 | results.orderBy("count", ascending=False).show(10) 24 | results.coalesce(1).write.csv("./results_single_partition.csv") 25 | -------------------------------------------------------------------------------- /code/Ch03/word_count.py: -------------------------------------------------------------------------------- 1 | from pyspark.sql import SparkSession 2 | from pyspark.sql.functions import ( 3 | col, 4 | explode, 5 | lower, 6 | regexp_extract, 7 | split, 8 | ) 9 | 10 | spark = SparkSession.builder.appName( 11 | "Analyzing the vocabulary of Pride and Prejudice." 12 | ).getOrCreate() 13 | 14 | book = spark.read.text("./data/gutenberg_books/1342-0.txt") 15 | 16 | lines = book.select(split(book.value, " ").alias("line")) 17 | 18 | words = lines.select(explode(col("line")).alias("word")) 19 | 20 | words_lower = words.select(lower(col("word")).alias("word")) 21 | 22 | words_clean = words_lower.select( 23 | regexp_extract(col("word"), "[a-z']*", 0).alias("word") 24 | ) 25 | 26 | words_nonull = words_clean.where(col("word") != "") 27 | 28 | results = words_nonull.groupby(col("word")).count() 29 | 30 | results.orderBy("count", ascending=False).show(10) 31 | 32 | results.coalesce(1).write.csv("./simple_count_single_partition.csv") 33 | -------------------------------------------------------------------------------- /code/Ch03/word_count_submit.py: -------------------------------------------------------------------------------- 1 | from pyspark.sql import SparkSession 2 | import pyspark.sql.functions as F 3 | 4 | 5 | spark = SparkSession.builder.appName( 6 | "Counting word occurrences from a book." 7 | ).getOrCreate() 8 | 9 | spark.sparkContext.setLogLevel("WARN") 10 | 11 | # If you need to read multiple text files, replace `1342-0` by `*`.
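# For instance (a sketch, not in the book's listing), reading every book in the
# folder at once would look like:
#
#     spark.read.text("../../data/gutenberg_books/*.txt")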
12 | results = ( 13 | spark.read.text("../../data/gutenberg_books/1342-0.txt") 14 | .select(F.split(F.col("value"), " ").alias("line")) 15 | .select(F.explode(F.col("line")).alias("word")) 16 | .select(F.lower(F.col("word")).alias("word")) 17 | .select(F.regexp_extract(F.col("word"), "[a-z']*", 0).alias("word")) 18 | .where(F.col("word") != "") 19 | .groupby(F.col("word")) 20 | .count() 21 | ) 22 | 23 | results.orderBy("count", ascending=False).show(10) 24 | results.coalesce(1).write.csv("./results_single_partition.csv") 25 | -------------------------------------------------------------------------------- /code/Ch04/checkpoint.py: -------------------------------------------------------------------------------- 1 | """Checkpoint code for the book Data Analysis with Python and PySpark, Chapter 4.""" 2 | 3 | import os 4 | from pyspark.sql import SparkSession 5 | import pyspark.sql.functions as F 6 | 7 | spark = SparkSession.builder.getOrCreate() 8 | 9 | DIRECTORY = "./data/broadcast_logs" 10 | logs = ( 11 | spark.read.csv( 12 | os.path.join(DIRECTORY, "BroadcastLogs_2018_Q3_M8.CSV"), 13 | sep="|", 14 | header=True, 15 | inferSchema=True, 16 | timestampFormat="yyyy-MM-dd", 17 | ) 18 | .drop("BroadcastLogID", "SequenceNO") 19 | .withColumn( 20 | "duration_seconds", 21 | ( 22 | F.col("Duration").substr(1, 2).cast("int") * 60 * 60 23 | + F.col("Duration").substr(4, 2).cast("int") * 60 24 | + F.col("Duration").substr(7, 2).cast("int") 25 | ), 26 | ) 27 | ) 28 | -------------------------------------------------------------------------------- /code/Ch04/commercials.py: -------------------------------------------------------------------------------- 1 | # commercials.py ############################################################# 2 | # 3 | # This program computes the commercial ratio for each channel present in the 4 | # dataset. 5 | # 6 | ############################################################################### 7 | 8 | import os 9 | 10 | import pyspark.sql.functions as F 11 | from pyspark.sql import SparkSession 12 | 13 | spark = SparkSession.builder.appName( 14 | "Getting the Canadian TV channels with the highest/lowest proportion of commercials." 
15 | ).getOrCreate() 16 | 17 | spark.sparkContext.setLogLevel("WARN") 18 | 19 | ############################################################################### 20 | # Reading all the relevant data sources 21 | ############################################################################### 22 | 23 | DIRECTORY = "./data/Ch03" 24 | 25 | logs = spark.read.csv( 26 | os.path.join(DIRECTORY, "BroadcastLogs_2018_Q3_M8.CSV"), 27 | sep="|", 28 | header=True, 29 | inferSchema=True, 30 | ) 31 | 32 | log_identifier = spark.read.csv( 33 | "./data/Ch03/ReferenceTables/LogIdentifier.csv", 34 | sep="|", 35 | header=True, 36 | inferSchema=True, 37 | ) 38 | 39 | cd_category = spark.read.csv( 40 | "./data/Ch03/ReferenceTables/CD_Category.csv", 41 | sep="|", 42 | header=True, 43 | inferSchema=True, 44 | ).select( 45 | "CategoryID", 46 | "CategoryCD", 47 | F.col("EnglishDescription").alias("Category_Description"), 48 | ) 49 | 50 | cd_program_class = spark.read.csv( 51 | "./data/Ch03/ReferenceTables/CD_ProgramClass.csv", 52 | sep="|", 53 | header=True, 54 | inferSchema=True, 55 | ).select( 56 | "ProgramClassID", 57 | "ProgramClassCD", 58 | F.col("EnglishDescription").alias("ProgramClass_Description"), 59 | ) 60 | 61 | ############################################################################### 62 | # Data processing 63 | ############################################################################### 64 | 65 | logs = logs.drop("BroadcastLogID", "SequenceNO") 66 | 67 | logs = logs.withColumn( 68 | "duration_seconds", 69 | ( 70 | F.col("Duration").substr(1, 2).cast("int") * 60 * 60 71 | + F.col("Duration").substr(4, 2).cast("int") * 60 72 | + F.col("Duration").substr(7, 2).cast("int") 73 | ), 74 | ) 75 | 76 | log_identifier = log_identifier.where(F.col("PrimaryFG") == 1) 77 | 78 | logs_and_channels = logs.join(log_identifier, "LogServiceID") 79 | 80 | full_log = logs_and_channels.join(cd_category, "CategoryID", how="left").join( 81 | cd_program_class, "ProgramClassID", how="left" 82 | ) 83 | 84 | full_log.groupby("LogIdentifierID").agg( 85 | F.sum( 86 | F.when( 87 | F.trim(F.col("ProgramClassCD")).isin( 88 | ["COM", "PRC", "PGI", "PRO", "LOC", "SPO", "MER", "SOL"] 89 | ), 90 | F.col("duration_seconds"), 91 | ).otherwise(0) 92 | ).alias("duration_commercial"), 93 | F.sum("duration_seconds").alias("duration_total"), 94 | ).withColumn( 95 | "commercial_ratio", F.col("duration_commercial") / F.col("duration_total") 96 | ).orderBy( 97 | "commercial_ratio", ascending=False 98 | ).show( 99 | 1000, False 100 | ) 101 | -------------------------------------------------------------------------------- /code/Ch05/checkpoint.py: -------------------------------------------------------------------------------- 1 | """Checkpoint code for the book Data Analysis with Python and PySpark, Chapter 4.""" 2 | 3 | import os 4 | from pyspark.sql import SparkSession 5 | import pyspark.sql.functions as F 6 | 7 | spark = SparkSession.builder.getOrCreate() 8 | 9 | DIRECTORY = "./data/broadcast_logs" 10 | logs = ( 11 | spark.read.csv( 12 | os.path.join(DIRECTORY, "BroadcastLogs_2018_Q3_M8.CSV"), 13 | sep="|", 14 | header=True, 15 | inferSchema=True, 16 | timestampFormat="yyyy-MM-dd", 17 | ) 18 | .drop("BroadcastLogID", "SequenceNO") 19 | .withColumn( 20 | "duration_seconds", 21 | ( 22 | F.col("Duration").substr(1, 2).cast("int") * 60 * 60 23 | + F.col("Duration").substr(4, 2).cast("int") * 60 24 | + F.col("Duration").substr(7, 2).cast("int") 25 | ), 26 | ) 27 | ) 28 | 
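# A quick sanity check (a sketch, not part of the book's checkpoint file): the
# substr arithmetic above parses the "HH:MM:SS" Duration string, so "00:30:00"
# becomes 0 * 3600 + 30 * 60 + 0 = 1800 seconds. To eyeball a few parsed values:
#
#     logs.select("Duration", "duration_seconds").distinct().show(5, False)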
-------------------------------------------------------------------------------- /code/Ch05/commercials.py: -------------------------------------------------------------------------------- 1 | # commercials.py ############################################################# 2 | # 3 | # This program computes the commercial ratio for each channel present in the 4 | # dataset. 5 | # 6 | ############################################################################### 7 | 8 | import os 9 | 10 | import pyspark.sql.functions as F 11 | from pyspark.sql import SparkSession 12 | 13 | spark = SparkSession.builder.appName( 14 | "Getting the Canadian TV channels with the highest/lowest proportion of commercials." 15 | ).getOrCreate() 16 | 17 | spark.sparkContext.setLogLevel("WARN") 18 | 19 | ############################################################################### 20 | # Reading all the relevant data sources 21 | ############################################################################### 22 | 23 | DIRECTORY = "./data/Ch03" 24 | 25 | logs = spark.read.csv( 26 | os.path.join(DIRECTORY, "BroadcastLogs_2018_Q3_M8.CSV"), 27 | sep="|", 28 | header=True, 29 | inferSchema=True, 30 | ) 31 | 32 | log_identifier = spark.read.csv( 33 | "./data/Ch03/ReferenceTables/LogIdentifier.csv", 34 | sep="|", 35 | header=True, 36 | inferSchema=True, 37 | ) 38 | 39 | cd_category = spark.read.csv( 40 | "./data/Ch03/ReferenceTables/CD_Category.csv", 41 | sep="|", 42 | header=True, 43 | inferSchema=True, 44 | ).select( 45 | "CategoryID", 46 | "CategoryCD", 47 | F.col("EnglishDescription").alias("Category_Description"), 48 | ) 49 | 50 | cd_program_class = spark.read.csv( 51 | "./data/Ch03/ReferenceTables/CD_ProgramClass.csv", 52 | sep="|", 53 | header=True, 54 | inferSchema=True, 55 | ).select( 56 | "ProgramClassID", 57 | "ProgramClassCD", 58 | F.col("EnglishDescription").alias("ProgramClass_Description"), 59 | ) 60 | 61 | ############################################################################### 62 | # Data processing 63 | ############################################################################### 64 | 65 | logs = logs.drop("BroadcastLogID", "SequenceNO") 66 | 67 | logs = logs.withColumn( 68 | "duration_seconds", 69 | ( 70 | F.col("Duration").substr(1, 2).cast("int") * 60 * 60 71 | + F.col("Duration").substr(4, 2).cast("int") * 60 72 | + F.col("Duration").substr(7, 2).cast("int") 73 | ), 74 | ) 75 | 76 | log_identifier = log_identifier.where(F.col("PrimaryFG") == 1) 77 | 78 | logs_and_channels = logs.join(log_identifier, "LogServiceID") 79 | 80 | full_log = logs_and_channels.join(cd_category, "CategoryID", how="left").join( 81 | cd_program_class, "ProgramClassID", how="left" 82 | ) 83 | 84 | full_log.groupby("LogIdentifierID").agg( 85 | F.sum( 86 | F.when( 87 | F.trim(F.col("ProgramClassCD")).isin( 88 | ["COM", "PRC", "PGI", "PRO", "LOC", "SPO", "MER", "SOL"] 89 | ), 90 | F.col("duration_seconds"), 91 | ).otherwise(0) 92 | ).alias("duration_commercial"), 93 | F.sum("duration_seconds").alias("duration_total"), 94 | ).withColumn( 95 | "commercial_ratio", F.col("duration_commercial") / F.col("duration_total") 96 | ).orderBy( 97 | "commercial_ratio", ascending=False 98 | ).show( 99 | 1000, False 100 | ) 101 | -------------------------------------------------------------------------------- /code/Ch07/download_backblaze_data.py: -------------------------------------------------------------------------------- 1 | ##download_backblaze_data.py################################################### 2 | # 3 | # Requirements 4 | # 
5 | # - wget (`pip install wget`) 6 | # 7 | # How to use: 8 | # 9 | # `python download_backblaze_data.py [parameter]` 10 | # 11 | # Parameters: 12 | # 13 | # - parameter: either `full` or `minimal` 14 | # 15 | # If set to `full`, the script will download the data sets used in Chapter 7 (4 files, 16 | # ~2.3GB compressed, 12.4GB uncompressed). 17 | # 18 | # If set to `minimal`, it will download only 2019 Q3 (1 file, 574MB compressed, 19 | # 3.1GB uncompressed). 20 | # 21 | ############################################################################### 22 | 23 | import sys 24 | import wget 25 | 26 | 27 | DATASETS_FULL = [ 28 | "https://f001.backblazeb2.com/file/Backblaze-Hard-Drive-Data/data_Q1_2019.zip", 29 | "https://f001.backblazeb2.com/file/Backblaze-Hard-Drive-Data/data_Q2_2019.zip", 30 | "https://f001.backblazeb2.com/file/Backblaze-Hard-Drive-Data/data_Q3_2019.zip", 31 | "https://f001.backblazeb2.com/file/Backblaze-Hard-Drive-Data/data_Q4_2019.zip", 32 | ] 33 | 34 | DATASETS_MINIMAL = DATASETS_FULL[2:3] # Slice to keep as a list. Simplifies 35 | # the code later. 36 | 37 | if __name__ == "__main__": 38 | 39 | try: 40 | param = sys.argv[1] 41 | 42 | if param.lower() == "full": 43 | datasets = DATASETS_FULL 44 | elif param.lower() == "minimal": 45 | datasets = DATASETS_MINIMAL 46 | else: 47 | raise AssertionError() 48 | except (AssertionError, IndexError): 49 | print( 50 | "Parameter missing or invalid. Refer to the documentation at the top of the source code for more information." 51 | ) 52 | sys.exit(1) 53 | 54 | for dataset in datasets: 55 | print("\n", dataset.split("/")[-1]) 56 | wget.download(dataset, out="../../data/Ch07/") 57 |
4 | # 5 | # We are looking at the following capacities (within the tolerance set by the `precision` parameter of the function below, 0.25 by default): 6 | # 7 | # - 500GB 8 | # - 1TB 9 | # - 2TB 10 | # - 4TB 11 | # 12 | ############################################################################### 13 | # tag::ch07-code-final-ingestion[] 14 | from functools import reduce 15 | 16 | import pyspark.sql.functions as F 17 | from pyspark.sql import SparkSession 18 | 19 | spark = SparkSession.builder.getOrCreate() 20 | 21 | DATA_DIRECTORY = "../../data/Ch07/" 22 | 23 | DATA_FILES = [ 24 | "drive_stats_2019_Q1", 25 | "data_Q2_2019", 26 | "data_Q3_2019", 27 | "data_Q4_2019", 28 | ] 29 | 30 | data = [ 31 | spark.read.csv(DATA_DIRECTORY + file, header=True, inferSchema=True) 32 | for file in DATA_FILES 33 | ] 34 | 35 | common_columns = list( 36 | reduce(lambda x, y: x.intersection(y), [set(df.columns) for df in data]) 37 | ) 38 | 39 | assert set(["model", "capacity_bytes", "date", "failure"]).issubset( 40 | set(common_columns) 41 | ) 42 | 43 | full_data = reduce( 44 | lambda x, y: x.select(common_columns).union(y.select(common_columns)), data 45 | ) 46 | # end::ch07-code-final-ingestion[] 47 | 48 | # tag::ch07-code-final-processing[] 49 | full_data = full_data.selectExpr( 50 | "model", "capacity_bytes / pow(1024, 3) capacity_GB", "date", "failure" 51 | ) 52 | 53 | drive_days = full_data.groupby("model", "capacity_GB").agg( 54 | F.count("*").alias("drive_days") 55 | ) 56 | 57 | failures = ( 58 | full_data.where("failure = 1") 59 | .groupby("model", "capacity_GB") 60 | .agg(F.count("*").alias("failures")) 61 | ) 62 | 63 | summarized_data = ( 64 | drive_days.join(failures, on=["model", "capacity_GB"], how="left") 65 | .fillna(0.0, ["failures"]) 66 | .selectExpr("model", "capacity_GB", "failures / drive_days failure_rate") 67 | .cache() 68 | ) 69 | # end::ch07-code-final-processing[] 70 | 71 | 72 | # tag::ch07-code-final-function[] 73 | 74 | def most_reliable_drive_for_capacity(data, capacity_GB=2048, precision=0.25, top_n=3): 75 | """Returns the `top_n` most reliable drives (3 by default) for a given approximate capacity.
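    A usage sketch (not from the book), assuming the `summarized_data` frame
    built above:

        most_reliable_drive_for_capacity(summarized_data, capacity_GB=1024).show()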
76 | 77 | Given a capacity in GB and a precision as a decimal number, we keep the N 78 | drives where: 79 | 80 | - the capacity is between (capacity * 1/(1+precision)), capacity * (1+precision) 81 | - the failure rate is the lowest 82 | 83 | """ 84 | capacity_min = capacity_GB / (1 + precision) 85 | capacity_max = capacity_GB * (1 + precision) 86 | 87 | answer = ( 88 | data.where(f"capacity_GB between {capacity_min} and {capacity_max}") # <1> 89 | .orderBy("failure_rate", "capacity_GB", ascending=[True, False]) 90 | .limit(top_n) # <2> 91 | ) 92 | 93 | return answer 94 | # end::ch07-code-final-function[] 95 | 96 | 97 | if __name__ == "__main__": 98 | pass 99 | -------------------------------------------------------------------------------- /code/Ch14/custom_feature.py: -------------------------------------------------------------------------------- 1 | import pyspark.sql.functions as F 2 | from pyspark import keyword_only 3 | from pyspark.ml import Estimator, Model, Transformer 4 | from pyspark.ml.param import Param, Params, TypeConverters 5 | from pyspark.ml.param.shared import ( 6 | HasInputCol, 7 | HasInputCols, 8 | HasOutputCol, 9 | HasOutputCols, 10 | ) 11 | from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable 12 | 13 | 14 | class ScalarNAFiller( 15 | Transformer, 16 | HasInputCol, 17 | HasOutputCol, 18 | HasInputCols, 19 | HasOutputCols, 20 | DefaultParamsReadable, 21 | DefaultParamsWritable, 22 | ): 23 | """Fills the `null` values of inputCol with a scalar value `filler`.""" 24 | 25 | filler = Param( 26 | Params._dummy(), 27 | "filler", 28 | "Value we want to replace our null values with.", 29 | typeConverter=TypeConverters.toFloat, 30 | ) 31 | 32 | @keyword_only 33 | def __init__( 34 | self, 35 | inputCol=None, 36 | outputCol=None, 37 | inputCols=None, 38 | outputCols=None, 39 | filler=None, 40 | ): 41 | super().__init__() 42 | self._setDefault(filler=None) 43 | kwargs = self._input_kwargs 44 | self.setParams(**kwargs) 45 | 46 | @keyword_only 47 | def setParams( 48 | self, 49 | inputCol=None, 50 | outputCol=None, 51 | inputCols=None, 52 | outputCols=None, 53 | filler=None, 54 | ): 55 | kwargs = self._input_kwargs 56 | return self._set(**kwargs) 57 | 58 | def setFiller(self, new_filler): 59 | return self.setParams(filler=new_filler) 60 | 61 | def setInputCol(self, new_inputCol): 62 | return self.setParams(inputCol=new_inputCol) 63 | 64 | def setOutputCol(self, new_outputCol): 65 | return self.setParams(outputCol=new_outputCol) 66 | 67 | def setInputCols(self, new_inputCols): 68 | return self.setParams(inputCols=new_inputCols) 69 | 70 | def setOutputCols(self, new_outputCols): 71 | return self.setParams(outputCols=new_outputCols) 72 | 73 | def getFiller(self): 74 | return self.getOrDefault(self.filler) 75 | 76 | def checkParams(self): 77 | # Test #1: either inputCol or inputCols can be set (but not both). 78 | if self.isSet("inputCol") and (self.isSet("inputCols")): 79 | raise ValueError( 80 | "Only one of `inputCol` and `inputCols`" "must be set." 81 | ) 82 | 83 | # Test #2: at least one of inputCol or inputCols must be set. 84 | if not (self.isSet("inputCol") or self.isSet("inputCols")): 85 | raise ValueError( 86 | "One of `inputCol` or `inputCols` must be set." 
87 | ) 88 | 89 | # Test #3: if `inputCols` is set, then `outputCols` 90 | # must be a list of the same len() 91 | if self.isSet("inputCols"): 92 | if len(self.getInputCols()) != len(self.getOutputCols()): 93 | raise ValueError( 94 | "The length of `inputCols` does not match" 95 | " the length of `outputCols`" 96 | ) 97 | 98 | def _transform(self, dataset): 99 | self.checkParams() 100 | 101 | # If `inputCol` / `outputCol`, we wrap into a single-item list 102 | input_columns = ( 103 | [self.getInputCol()] 104 | if self.isSet("inputCol") 105 | else self.getInputCols() 106 | ) 107 | output_columns = ( 108 | [self.getOutputCol()] 109 | if self.isSet("outputCol") 110 | else self.getOutputCols() 111 | ) 112 | 113 | answer = dataset 114 | 115 | # If input_columns == output_columns, we overwrite and no need to create 116 | # new columns. 117 | if input_columns != output_columns: 118 | for in_col, out_col in zip(input_columns, output_columns): 119 | answer = answer.withColumn(out_col, F.col(in_col)) 120 | 121 | na_filler = self.getFiller() 122 | return answer.fillna(na_filler, output_columns) 123 | 124 | 125 | class _ExtremeValueCapperParams( 126 | HasInputCol, HasOutputCol, DefaultParamsWritable, DefaultParamsReadable 127 | ): 128 | 129 | boundary = Param( 130 | Params._dummy(), 131 | "boundary", 132 | "Multiple of standard deviation for the cap and floor. Default = 0.0.", 133 | TypeConverters.toFloat, 134 | ) 135 | 136 | def __init__(self, *args): 137 | super().__init__(*args) 138 | self._setDefault(boundary=0.0) 139 | 140 | def getBoundary(self): 141 | return self.getOrDefault(self.boundary) 142 | 143 | 144 | class ExtremeValueCapperModel(Model, _ExtremeValueCapperParams): 145 | 146 | cap = Param( 147 | Params._dummy(), 148 | "cap", 149 | "Upper bound of the values `inputCol` can take." 150 | "Values will be capped to this value.", 151 | TypeConverters.toFloat, 152 | ) 153 | floor = Param( 154 | Params._dummy(), 155 | "floor", 156 | "Lower bound of the values `inputCol` can take." 157 | "Values will be floored to this value.", 158 | TypeConverters.toFloat, 159 | ) 160 | 161 | @keyword_only 162 | def __init__( 163 | self, inputCol=None, outputCol=None, cap=None, floor=None 164 | ): 165 | super().__init__() 166 | kwargs = self._input_kwargs 167 | self.setParams(**kwargs) 168 | 169 | @keyword_only 170 | def setParams( 171 | self, inputCol=None, outputCol=None, cap=None, floor=None 172 | ): 173 | kwargs = self._input_kwargs 174 | return self._set(**kwargs) 175 | 176 | def setCap(self, new_cap): 177 | return self.setParams(cap=new_cap) 178 | 179 | def setFloor(self, new_floor): 180 | return self.setParams(floor=new_floor) 181 | 182 | def setInputCol(self, new_inputCol): 183 | return self.setParams(inputCol=new_inputCol) 184 | 185 | def setOutputCol(self, new_outputCol): 186 | return self.setParams(outputCol=new_outputCol) 187 | 188 | def getCap(self): 189 | return self.getOrDefault(self.cap) 190 | 191 | def getFloor(self): 192 | return self.getOrDefault(self.floor) 193 | 194 | def _transform(self, dataset): 195 | if not self.isSet("inputCol"): 196 | raise ValueError( 197 | "No input column set for the " 198 | "ExtremeValueCapperModel transformer." 
199 | ) 200 | input_column = dataset[self.getInputCol()] 201 | output_column = self.getOutputCol() 202 | cap_value = self.getOrDefault("cap") 203 | floor_value = self.getOrDefault("floor") 204 | 205 | return dataset.withColumn( 206 | output_column, 207 | F.when(input_column > cap_value, cap_value) 208 | .when(input_column < floor_value, floor_value) 209 | .otherwise(input_column), 210 | ) 211 | 212 | 213 | class ExtremeValueCapper(Estimator, _ExtremeValueCapperParams): 214 | @keyword_only 215 | def __init__(self, inputCol=None, outputCol=None, boundary=None): 216 | super().__init__() 217 | kwargs = self._input_kwargs 218 | self.setParams(**kwargs) 219 | 220 | @keyword_only 221 | def setParams(self, inputCol=None, outputCol=None, boundary=None): 222 | kwargs = self._input_kwargs 223 | return self._set(**kwargs) 224 | 225 | def setBoundary(self, new_boundary): 226 | self.setParams(boundary=new_boundary) 227 | 228 | def setInputCol(self, new_inputCol): 229 | return self.setParams(inputCol=new_inputCol) 230 | 231 | def setOutputCol(self, new_outputCol): 232 | return self.setParams(outputCol=new_outputCol) 233 | 234 | def _fit(self, dataset): 235 | input_column = self.getInputCol() 236 | output_column = self.getOutputCol() 237 | boundary = self.getBoundary() 238 | 239 | avg, stddev = dataset.agg( 240 | F.mean(input_column), F.stddev(input_column) 241 | ).head() 242 | 243 | cap_value = avg + boundary * stddev 244 | floor_value = avg - boundary * stddev 245 | return ExtremeValueCapperModel( 246 | inputCol=input_column, 247 | outputCol=output_column, 248 | cap=cap_value, 249 | floor=floor_value, 250 | ) 251 | -------------------------------------------------------------------------------- /code/Ch14/data_prep.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pylint: disable=missing-function-docstring 4 | 5 | from typing import Optional 6 | 7 | import pyspark.sql.functions as F 8 | import pyspark.sql.types as T 9 | from pyspark.ml.feature import Imputer, MinMaxScaler, VectorAssembler 10 | from pyspark.sql import SparkSession 11 | 12 | spark = ( 13 | SparkSession.builder.appName("Recipes ML model - Are you a dessert?") 14 | .config("spark.driver.memory", "8g") 15 | .getOrCreate() 16 | ) 17 | 18 | food = spark.read.csv( 19 | "./data/recipes/epi_r.csv", inferSchema=True, header=True 20 | ) 21 | 22 | 23 | def sanitize_column_name(name): 24 | """Drops unwanted characters from the column name. 25 | 26 | We replace spaces, dashes and slashes with underscore, 27 | and only keep alphanumeric characters.""" 28 | answer = name 29 | for i, j in ((" ", "_"), ("-", "_"), ("/", "_"), ("&", "and")): 30 | answer = answer.replace(i, j) 31 | return "".join( 32 | [ 33 | char 34 | for char in answer 35 | if char.isalpha() or char.isdigit() or char == "_" 36 | ] 37 | ) 38 | 39 | 40 | food = food.toDF(*[sanitize_column_name(name) for name in food.columns]) 41 | 42 | 43 | # Keeping only the relevant values for `cakeweek` and `wasteless`. 44 | # Check the exercises for a more robust approach to this. 
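# One more robust variant (a sketch, not the book's exercise solution) casts the
# flag columns to doubles before filtering, so stray non-numeric values are
# dropped explicitly rather than relying on the comparison below:
#
#     for flag in ["cakeweek", "wasteless"]:
#         food = food.where(
#             F.col(flag).cast(T.DoubleType()).isin([0.0, 1.0])
#             | F.col(flag).isNull()
#         )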
45 | food = food.where( 46 | (F.col("cakeweek").isin([0.0, 1.0]) | F.col("cakeweek").isNull()) 47 | & (F.col("wasteless").isin([0.0, 1.0]) | F.col("wasteless").isNull()) 48 | ) 49 | 50 | IDENTIFIERS = ["title"] 51 | 52 | CONTINUOUS_COLUMNS = [ 53 | "rating", 54 | "calories", 55 | "protein", 56 | "fat", 57 | "sodium", 58 | ] 59 | 60 | TARGET_COLUMN = ["dessert"] 61 | 62 | BINARY_COLUMNS = [ 63 | x 64 | for x in food.columns 65 | if x not in CONTINUOUS_COLUMNS 66 | and x not in TARGET_COLUMN 67 | and x not in IDENTIFIERS 68 | ] 69 | 70 | food = food.dropna( 71 | how="all", 72 | subset=[x for x in food.columns if x not in IDENTIFIERS], 73 | ) 74 | 75 | food = food.dropna(subset=TARGET_COLUMN) 76 | 77 | @F.udf(T.BooleanType()) 78 | def is_a_number(value: Optional[str]) -> bool: 79 | if not value: 80 | return True 81 | try: 82 | _ = float(value) 83 | except ValueError: 84 | return False 85 | return True 86 | 87 | 88 | for column in ["rating", "calories"]: 89 | food = food.where(is_a_number(F.col(column))) 90 | food = food.withColumn(column, F.col(column).cast(T.DoubleType())) 91 | 92 | # TODO: REMOVE THIS 93 | maximum = { 94 | "calories": 3203.0, 95 | "protein": 173.0, 96 | "fat": 207.0, 97 | "sodium": 5661.0, 98 | } 99 | 100 | 101 | inst_sum_of_binary_columns = [ 102 | F.sum(F.col(x)).alias(x) for x in BINARY_COLUMNS 103 | ] 104 | 105 | sum_of_binary_columns = ( 106 | food.select(*inst_sum_of_binary_columns).head().asDict() 107 | ) 108 | 109 | num_rows = food.count() 110 | too_rare_features = [ 111 | k 112 | for k, v in sum_of_binary_columns.items() 113 | if v < 10 or v > (num_rows - 10) 114 | ] 115 | 116 | BINARY_COLUMNS = list(set(BINARY_COLUMNS) - set(too_rare_features)) 117 | 118 | food = food.withColumn( 119 | "protein_ratio", F.col("protein") * 4 / F.col("calories") 120 | ).withColumn("fat_ratio", F.col("fat") * 9 / F.col("calories")) 121 | 122 | CONTINUOUS_COLUMNS += ["protein_ratio", "fat_ratio"] 123 | 124 | 125 | from pyspark.ml.classification import LogisticRegression 126 | 127 | lr = LogisticRegression( 128 | featuresCol="features", labelCol="dessert", predictionCol="prediction" 129 | ) 130 | 131 | from pyspark.ml import Pipeline 132 | import pyspark.ml.feature as MF 133 | 134 | imputer = MF.Imputer( # <1> 135 | strategy="mean", 136 | inputCols=[ 137 | "calories", 138 | "protein", 139 | "fat", 140 | "sodium", 141 | "protein_ratio", 142 | "fat_ratio", 143 | ], 144 | outputCols=[ 145 | "calories_i", 146 | "protein_i", 147 | "fat_i", 148 | "sodium_i", 149 | "protein_ratio_i", 150 | "fat_ratio_i", 151 | ], 152 | ) 153 | 154 | continuous_assembler = MF.VectorAssembler( 155 | inputCols=["rating", "calories_i", "protein_i", "fat_i", "sodium_i"], 156 | outputCol="continuous", 157 | ) 158 | 159 | continuous_scaler = MF.MinMaxScaler( 160 | inputCol="continuous", 161 | outputCol="continuous_scaled", 162 | ) 163 | 164 | preml_assembler = MF.VectorAssembler( 165 | inputCols=BINARY_COLUMNS 166 | + ["continuous_scaled"] 167 | + ["protein_ratio_i", "fat_ratio_i"], 168 | outputCol="features", 169 | ) 170 | -------------------------------------------------------------------------------- /code/Ch14/hasInputCol.py: -------------------------------------------------------------------------------- 1 | class HasInputCols(Params): 2 | """Mixin for param inputCols: input column names.""" 3 | 4 | inputCols = Param( # <1> 5 | Params._dummy(), 6 | "inputCols", "input column names.", 7 | typeConverter=TypeConverters.toListString, 8 | ) 9 | 10 | def __init__(self): 11 | super(HasInputCols, self).__init__() 12 | 
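    # (Annotation, not in the original PySpark source: HasInputCols mirrors the
    # single-column HasInputCol mixin but holds a list of column names. It only
    # declares the Param and a getter; setting the value is left to the concrete
    # transformer, as ScalarNAFiller does in custom_feature.py with setInputCols.)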
13 | def getInputCols(self): 14 | """Gets the value of inputCols or its default value.""" 15 | return self.getOrDefault(self.inputCols) 16 | -------------------------------------------------------------------------------- /code/Ch14/read_write.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # tag::ch14-params-read-write[] 4 | 5 | from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable 6 | 7 | class ScalarNAFiller( 8 | Transformer, 9 | HasInputCol, 10 | HasOutputCol, 11 | HasInputCols, 12 | HasOutputCols, 13 | DefaultParamsReadable, 14 | DefaultParamsWritable, 15 | ): 16 | ...  # rest of the class here (see custom_feature.py for the full definition) 17 | 18 | class _ExtremeValueCapperParams( 19 | HasInputCol, HasOutputCol, DefaultParamsWritable, DefaultParamsReadable 20 | ): 21 | ...  # rest of the class here (see custom_feature.py for the full definition) 22 | 23 | # end::ch14-params-read-write[] 24 | --------------------------------------------------------------------------------
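A minimal round-trip sketch (not part of the repository) of what the DefaultParamsReadable / DefaultParamsWritable mixins shown in read_write.py provide, assuming ScalarNAFiller is importable from code/Ch14/custom_feature.py and an active SparkSession:

    from custom_feature import ScalarNAFiller

    filler = ScalarNAFiller(inputCol="calories", outputCol="calories", filler=0.0)
    filler.write().overwrite().save("./scalar_na_filler")  # persists the Params to disk
    reloaded = ScalarNAFiller.load("./scalar_na_filler")   # rebuilds an equivalent instance
    assert reloaded.getFiller() == filler.getFiller()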