├── LICENSE
├── README.md
└── src
    └── checkers
        ├── __init__.py
        ├── chained_function_checker.py
        ├── function_call_checker.py
        ├── logic_op_complexity_checker.py
        ├── pylint_utils.py
        ├── select_alias_checker.py
        ├── select_cast_checker.py
        └── statement_call_checker.py

/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2020 Palantir Technologies, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PySpark Style Guide

PySpark is a wrapper language that allows users to interface with an Apache Spark backend to quickly process data.
Spark can operate on massive datasets across a distributed network of servers, providing major performance and reliability benefits when utilized correctly. It presents challenges, even for experienced Python developers, as the PySpark syntax draws on the JVM heritage of Spark and therefore implements code patterns that may be unfamiliar.

This opinionated guide to PySpark code style presents common situations we've encountered and the associated best practices based on the most frequent recurring topics across PySpark repos.

Beyond PySpark specifics, the general practices of clean code are important in PySpark repositories. The Google [PyGuide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md) is a strong starting point for learning more about these practices.


# Prefer implicit column selection to direct access, except for disambiguation

```python
# bad
df = df.select(F.lower(df1.colA), F.upper(df2.colB))

# good
df = df.select(F.lower(F.col('colA')), F.upper(F.col('colB')))

# better - since Spark 3.0
df = df.select(F.lower('colA'), F.upper('colB'))
```

In most situations, it's best to avoid the first and second styles and just reference the column by its name, using a string, as in the third example. Spark 3.0 [greatly expanded](https://issues.apache.org/jira/browse/SPARK-26979) the cases where this works. When the string method is not possible, however, we must resort to a more verbose approach.

In many situations the first style can be simpler, shorter and visually less polluted. However, we have found that it faces a number of limitations that lead us to prefer the second style:

- If the dataframe variable name is long, expressions involving it quickly become unwieldy;
- If the column name has a space or other unsupported character, the bracket operator must be used instead.
  This generates inconsistency, and `df1['colA']` is just as difficult to write as `F.col('colA')`;
- Column expressions involving the dataframe aren't reusable and can't be used for defining abstract functions;
- Renaming a dataframe variable can be error-prone, as all column references must be updated in tandem.

Additionally, the dot syntax encourages use of short and non-descriptive variable names for the dataframes, which we have found to be harmful for maintainability. Remember that dataframes are containers for data, and descriptive names are a helpful way to quickly set expectations about what's contained within.

By contrast, `F.col('colA')` will always reference a column designated `colA` in the dataframe being operated on, named `df`, in this case. It does not require keeping track of other dataframes' states at all, so the code becomes more local and less susceptible to "spooky interaction at a distance," which is often challenging to debug.

### Caveats

In some contexts there may be access to columns from more than one dataframe, and there may be an overlap in names. A common example is in matching expressions like `df.join(df2, on=(df.key == df2.key), how='left')`. In such cases it is fine to reference columns by their dataframe directly. You can also disambiguate joins using dataframe aliases (see more in the **Joins** section in this guide).


# Refactor complex logical operations

Logical operations, which often reside inside `.filter()` or `F.when()`, need to be readable. We apply the same rule as with chaining functions, keeping logic expressions inside the same code block to *three (3) expressions at most*. If they grow longer, it is often a sign that the code can be simplified or extracted out. Extracting complex logical operations into variables makes the code easier to read and reason about, which also reduces bugs.

```python
# bad
F.when( (F.col('prod_status') == 'Delivered') | (((F.datediff('deliveryDate_actual', 'current_date') < 0) & ((F.col('currentRegistration') != '') | ((F.datediff('deliveryDate_actual', 'current_date') < 0) & ((F.col('originalOperator') != '') | (F.col('currentOperator') != '')))))), 'In Service')
```

The code above can be simplified in different ways. To start, focus on grouping the logic steps in a few named variables. Because Python's `&` and `|` operators bind more tightly than comparisons such as `==` and `<`, each comparison in a PySpark condition must be wrapped in parentheses. This, mixed with the parentheses that group logical operations, can hurt readability. For example, the code above has a redundant `(F.datediff('deliveryDate_actual', 'current_date') < 0)` that the original author didn't notice because it's very hard to spot.

```python
# better
has_operator = ((F.col('originalOperator') != '') | (F.col('currentOperator') != ''))
delivery_date_passed = (F.datediff('deliveryDate_actual', 'current_date') < 0)
has_registration = (F.col('currentRegistration').rlike('.+'))
is_delivered = (F.col('prod_status') == 'Delivered')

F.when(is_delivered | (delivery_date_passed & (has_registration | has_operator)), 'In Service')
```

The above example drops the redundant expression and is easier to read. We can improve it further by reducing the number of operations.

```python
# good
has_operator = ((F.col('originalOperator') != '') | (F.col('currentOperator') != ''))
delivery_date_passed = (F.datediff('deliveryDate_actual', 'current_date') < 0)
has_registration = (F.col('currentRegistration').rlike('.+'))
is_delivered = (F.col('prod_status') == 'Delivered')
is_active = (has_registration | has_operator)

F.when(is_delivered | (delivery_date_passed & is_active), 'In Service')
```

Note how the `F.when` expression is now succinct and readable, and the desired behavior is clear to anyone reviewing this code. The reader only needs to visit the individual expressions if they suspect there is an error. It also makes each chunk of logic easy to unit test if you want to abstract the expressions into functions.

There is still some duplication of code in the final example: how to remove that duplication is an exercise for the reader.


# Use `select` statements to specify a schema contract

Doing a select at the beginning of a PySpark transform, or before returning, is considered good practice. This `select` statement specifies the contract with both the reader and the code about the expected dataframe schema for inputs and outputs. Any select should be seen as a cleaning operation that is preparing the dataframe for consumption by the next step in the transform.

Keep select statements as simple as possible. Due to common SQL idioms, allow only *one* function from `pyspark.sql.functions` to be used per selected column, plus an optional `.alias()` to give it a meaningful name. Keep in mind that this should be used sparingly. If there are more than *three* such uses in the same select, refactor it into a separate function like `clean_()` to encapsulate the operation.

Expressions involving more than one dataframe, or conditional operations like `.when()`, are discouraged in a select, unless required for performance reasons.


```python
# bad
aircraft = aircraft.select(
    'aircraft_id',
    'aircraft_msn',
    F.col('aircraft_registration').alias('registration'),
    'aircraft_type',
    F.avg('staleness').alias('avg_staleness'),
    F.col('number_of_economy_seats').cast('long'),
    F.avg('flight_hours').alias('avg_flight_hours'),
    'operator_code',
    F.col('number_of_business_seats').cast('long'),
)
```

Unless order matters to you, try to cluster together operations of the same type.

```python
# good
aircraft = aircraft.select(
    'aircraft_id',
    'aircraft_msn',
    'aircraft_type',
    'operator_code',
    F.col('aircraft_registration').alias('registration'),
    F.col('number_of_economy_seats').cast('long'),
    F.col('number_of_business_seats').cast('long'),
    F.avg('staleness').alias('avg_staleness'),
    F.avg('flight_hours').alias('avg_flight_hours'),
)
```

The `select()` statement redefines the schema of a dataframe, so it naturally supports the inclusion or exclusion of columns, old and new, as well as the redefinition of pre-existing ones. By centralising all such operations in a single statement, it becomes much easier to identify the final schema, which makes code more readable. It also makes code more concise.

Instead of calling `withColumnRenamed()`, use aliases:


```python
# bad
df.select('key', 'comments').withColumnRenamed('comments', 'num_comments')

# good
df.select('key', F.col('comments').alias('num_comments'))
```

Instead of using `withColumn()` to redefine type, cast in the select:
```python
# bad
df.select('comments').withColumn('comments', F.col('comments').cast('double'))

# good
df.select(F.col('comments').cast('double'))
```

But keep it simple:
```python
# bad
df.select(
    ((F.coalesce(F.unix_timestamp('closed_at'), F.unix_timestamp())
        - F.unix_timestamp('created_at')) / 86400).alias('days_open')
)

# good
df.withColumn(
    'days_open',
    (F.coalesce(F.unix_timestamp('closed_at'), F.unix_timestamp()) - F.unix_timestamp('created_at')) / 86400
)
```

Avoid including columns in the select statement if they are going to remain unused, and choose instead an explicit set of columns. This is a preferred alternative to using `.drop()`, since it guarantees that schema mutations won't cause unexpected columns to bloat your dataframe. However, dropping columns isn't inherently discouraged in all cases; for instance, it is commonly appropriate to drop columns after joins, since joins often introduce redundant columns.

Finally, instead of adding new columns via the select statement, `.withColumn()` is recommended for single columns. When adding or manipulating tens or hundreds of columns, use a single `.select()` for performance reasons: each `.withColumn()` call adds another projection to the query plan, and long sequences of them (for example, in a loop) can make the plan very large.

# Empty columns

If you need to add an empty column to satisfy a schema, always use `F.lit(None)` for populating that column. Never use an empty string or some other string signalling an empty value (such as `NA`).

Beyond being semantically correct, one practical reason for using `F.lit(None)` is preserving the ability to use utilities like `isNull`, instead of having to verify empty strings, nulls, and `'NA'`, etc.


```python
# bad
df = df.withColumn('foo', F.lit(''))

# bad
df = df.withColumn('foo', F.lit('NA'))

# good
df = df.withColumn('foo', F.lit(None))
```

# Using comments

While comments can provide useful insight into code, it is often more valuable to refactor the code to improve its readability. The code should be readable by itself. If you are using comments to explain the logic step by step, you should refactor it.

```python
# bad

# Cast the timestamp columns
cols = ['start_date', 'delivery_date']
for c in cols:
    df = df.withColumn(c, F.from_unixtime(F.col(c) / 1000).cast(TimestampType()))
```

In the example above, we can see that those columns are getting cast to Timestamp. The comment doesn't add much value. Moreover, a more verbose comment might still be unhelpful if it only provides information that already exists in the code. For example:

```python
# bad

# Go through each column, divide by 1000 because millis and cast to timestamp
cols = ['start_date', 'delivery_date']
for c in cols:
    df = df.withColumn(c, F.from_unixtime(F.col(c) / 1000).cast(TimestampType()))
```

Instead of leaving comments that only describe the logic you wrote, aim to leave comments that give context, that explain the "*why*" of decisions you made when writing the code. This is particularly important for PySpark, since the reader can understand your code, but often doesn't have context on the data that feeds into your PySpark transform.
Small pieces of logic might have involved hours of digging through data to understand the correct behavior, in which case comments explaining the rationale are especially valuable.

```python
# good

# The consumer of this dataset expects a timestamp instead of a date, and we need
# to divide by 1000 because the original datasource is storing these as millis
# even though the documentation says it's actually a date.
cols = ['start_date', 'delivery_date']
for c in cols:
    df = df.withColumn(c, F.from_unixtime(F.col(c) / 1000).cast(TimestampType()))
```

# UDFs (user defined functions)

It is highly recommended to avoid UDFs in all situations, as they are dramatically less performant than native PySpark. Data passed to a Python UDF must be serialized between the JVM and a Python worker process, and the UDF's logic is opaque to Spark's optimizer. In most situations, logic that seems to necessitate a UDF can be refactored to use only native PySpark functions.

# Joins

Be careful with joins! If you perform a left join, and the right side has multiple matches for a key, that row will be duplicated as many times as there are matches. This is called a "join explosion" and can dramatically bloat the output of your transforms job. Always double check your assumptions to see that the key you are joining on is unique, unless you are expecting the multiplication.

Bad joins are the source of many tricky-to-debug issues. Some things help, like specifying the `how` explicitly, even if you are using the default value (`inner`):


```python
# bad
flights = flights.join(aircraft, 'aircraft_id')

# also bad
flights = flights.join(aircraft, 'aircraft_id', 'inner')

# good
flights = flights.join(aircraft, 'aircraft_id', how='inner')
```

Avoid `right` joins. If you are about to use a `right` join, switch the order of your dataframes and use a `left` join instead.
It is more intuitive since the dataframe you are doing the operation on is the one that you are centering your join around.

```python
# bad
flights = aircraft.join(flights, 'aircraft_id', how='right')

# good
flights = flights.join(aircraft, 'aircraft_id', how='left')
```

Avoid renaming every column to prevent collisions. Instead, give an alias to the whole dataframe, and use that alias to select which columns you want in the end.

```python
# bad
columns = ['start_time', 'end_time', 'idle_time', 'total_time']
for col in columns:
    flights = flights.withColumnRenamed(col, 'flights_' + col)
    parking = parking.withColumnRenamed(col, 'parking_' + col)

flights = flights.join(parking, on='flight_code', how='left')

flights = flights.select(
    F.col('flights_start_time').alias('flight_start_time'),
    F.col('flights_end_time').alias('flight_end_time'),
    F.col('parking_total_time').alias('client_parking_total_time')
)

# good
flights = flights.alias('flights')
parking = parking.alias('parking')

flights = flights.join(parking, on='flight_code', how='left')

flights = flights.select(
    F.col('flights.start_time').alias('flight_start_time'),
    F.col('flights.end_time').alias('flight_end_time'),
    F.col('parking.total_time').alias('client_parking_total_time')
)
```

In such cases, keep in mind:

1. It's probably best to drop overlapping columns *prior* to joining if you don't need both;
2. In case you do need both, it might be best to rename one of them prior to joining;
3. You should always resolve ambiguous columns before outputting a dataset. After the transform is finished running you can no longer distinguish them.

As a last word about joins, don't use `.dropDuplicates()` or `.distinct()` as a crutch.
If unexpected duplicate rows are observed, there's almost always an underlying reason why they appear. Adding `.dropDuplicates()` only masks this problem and adds overhead to the runtime.

# Window Functions

Always specify an explicit frame when using window functions, using either [row frames](https://spark.apache.org/docs/latest/api/java/org/apache/spark/sql/expressions/WindowSpec.html#rowsBetween-long-long-) or [range frames](https://spark.apache.org/docs/latest/api/java/org/apache/spark/sql/expressions/WindowSpec.html#rangeBetween-long-long-). If you do not specify a frame, Spark will generate one, in a way that might not be easy to predict. In particular, the generated frame will change depending on whether the window is ordered (see [here](https://github.com/apache/spark/blob/v3.0.1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala#L2899)): an unordered window covers the whole partition, while an ordered one defaults to a range frame from the start of the partition up to the current row. To see how this can be confusing, consider the following example:

```python
from pyspark.sql import functions as F, Window as W
df = spark.createDataFrame([('a', 1), ('a', 2), ('a', 3), ('a', 4)], ['key', 'num'])

# bad
w1 = W.partitionBy('key')
w2 = W.partitionBy('key').orderBy('num')

df.select('key', F.sum('num').over(w1).alias('sum')).collect()
# => [Row(key='a', sum=10), Row(key='a', sum=10), Row(key='a', sum=10), Row(key='a', sum=10)]

df.select('key', F.sum('num').over(w2).alias('sum')).collect()
# => [Row(key='a', sum=1), Row(key='a', sum=3), Row(key='a', sum=6), Row(key='a', sum=10)]

df.select('key', F.first('num').over(w2).alias('first')).collect()
# => [Row(key='a', first=1), Row(key='a', first=1), Row(key='a', first=1), Row(key='a', first=1)]

df.select('key', F.last('num').over(w2).alias('last')).collect()
# => [Row(key='a', last=1), Row(key='a', last=2), Row(key='a', last=3), Row(key='a', last=4)]
```

It is much safer to always specify an explicit frame:
```python
# good
w3 = W.partitionBy('key').orderBy('num').rowsBetween(W.unboundedPreceding, 0)
w4 = W.partitionBy('key').orderBy('num').rowsBetween(W.unboundedPreceding, W.unboundedFollowing)

df.select('key', F.sum('num').over(w3).alias('sum')).collect()
# => [Row(key='a', sum=1), Row(key='a', sum=3), Row(key='a', sum=6), Row(key='a', sum=10)]

df.select('key', F.sum('num').over(w4).alias('sum')).collect()
# => [Row(key='a', sum=10), Row(key='a', sum=10), Row(key='a', sum=10), Row(key='a', sum=10)]

df.select('key', F.first('num').over(w4).alias('first')).collect()
# => [Row(key='a', first=1), Row(key='a', first=1), Row(key='a', first=1), Row(key='a', first=1)]

df.select('key', F.last('num').over(w4).alias('last')).collect()
# => [Row(key='a', last=4), Row(key='a', last=4), Row(key='a', last=4), Row(key='a', last=4)]
```

## Dealing with nulls

While nulls are ignored for aggregate functions (like `F.sum()` and `F.max()`), they will generally impact the result of analytic functions (like `F.first()` and `F.lead()`):
```python
df_nulls = spark.createDataFrame([('a', None), ('a', 1), ('a', 2), ('a', None)], ['key', 'num'])

df_nulls.select('key', F.first('num').over(w4).alias('first')).collect()
# => [Row(key='a', first=None), Row(key='a', first=None), Row(key='a', first=None), Row(key='a', first=None)]

df_nulls.select('key', F.last('num').over(w4).alias('last')).collect()
# => [Row(key='a', last=2), Row(key='a', last=2), Row(key='a', last=2), Row(key='a', last=2)]
```

Here only `F.first()` is affected, because ascending order places nulls first by default; ordering by `F.desc('num')` would put the nulls last and affect `F.last()` instead.

It is best to avoid this problem by enabling the `ignorenulls` flag:
```python
df_nulls.select('key', F.first('num', ignorenulls=True).over(w4).alias('first')).collect()
# => [Row(key='a', first=1), Row(key='a', first=1), Row(key='a', first=1), Row(key='a', first=1)]

df_nulls.select('key', F.last('num', ignorenulls=True).over(w4).alias('last')).collect()
# => [Row(key='a', last=2), Row(key='a', last=2), Row(key='a', last=2), Row(key='a', last=2)]
```

Also be mindful of explicit ordering of nulls to make sure the expected results are obtained. Offset functions such as `F.lead()` and `F.lag()` are an exception to the explicit-frame rule: they define their own frame, and Spark rejects a window that specifies a different one, so these windows only set the ordering:
```python
w5 = W.partitionBy('key').orderBy(F.asc_nulls_first('num'))
w6 = W.partitionBy('key').orderBy(F.asc_nulls_last('num'))

df_nulls.select('key', F.lead('num').over(w5).alias('lead')).collect()
# => [Row(key='a', lead=None), Row(key='a', lead=1), Row(key='a', lead=2), Row(key='a', lead=None)]

df_nulls.select('key', F.lead('num').over(w6).alias('lead')).collect()
# => [Row(key='a', lead=2), Row(key='a', lead=None), Row(key='a', lead=None), Row(key='a', lead=None)]
```

## Empty `partitionBy()`

Spark window functions can be applied over all rows, using a global frame. This is accomplished by specifying zero columns in the partition by expression (i.e. `W.partitionBy()`).

Code like this should be avoided, however, as it forces Spark to combine all data into a single partition, which can be extremely harmful for performance.

Prefer to use aggregations whenever possible:

```python
# bad
w = W.partitionBy()
df = df.select(F.sum('num').over(w).alias('sum'))

# good
df = df.agg(F.sum('num').alias('sum'))
```

Note that `agg()` returns a single row rather than repeating the total on every row. If every row needs the value, joining the one-row result back (for example with `crossJoin()`) still avoids moving all rows into one partition.

# Chaining of expressions

Chaining expressions is a contentious topic; however, since this is an opinionated guide, we are opting to recommend some limits on the usage of chaining. See the conclusion of this section for a discussion of the rationale behind this recommendation.

Avoid chaining expressions of different types into multi-line expressions, particularly if they have different behaviours or contexts. For example, avoid mixing column creation or joining with selecting and filtering.

```python
# bad
df = (
    df
    .select('a', 'b', 'c', 'key')
    .filter(F.col('a') == 'truthiness')
    .withColumn('boverc', F.col('b') / F.col('c'))
    .join(df2, 'key', how='inner')
    .join(df3, 'key', how='left')
    .drop('c')
)

# better (separating into steps)
# first: we select and trim down the data that we need
# second: we create the columns that we need to have
# third: joining with other dataframes

df = (
    df
    .select('a', 'b', 'c', 'key')
    .filter(F.col('a') == 'truthiness')
)

df = df.withColumn('boverc', F.col('b') / F.col('c'))

df = (
    df
    .join(df2, 'key', how='inner')
    .join(df3, 'key', how='left')
    .drop('c')
)
```

Having each group of expressions isolated into its own logical code block improves legibility and makes it easier to find relevant logic.
For example, a reader of the code below will probably jump to where they see dataframes being assigned `df = df...`.

```python
# bad
df = (
    df
    .select('foo', 'bar', 'foobar', 'abc')
    .filter(F.col('abc') == 123)
    .join(another_table, 'some_field')
)

# better
df = (
    df
    .select('foo', 'bar', 'foobar', 'abc')
    .filter(F.col('abc') == 123)
)

df = df.join(another_table, 'some_field', how='inner')
```

There are legitimate reasons to chain expressions together. These commonly represent atomic logic steps, and are acceptable. Cap the number of chained expressions in the same block to keep the code readable.
We recommend chains of no longer than 5 statements.

If you find you are making longer chains, or having trouble because of the size of your variables, consider extracting the logic into a separate function:

```python
# bad
customers_with_shipping_address = (
    customers_with_shipping_address
    .select('a', 'b', 'c', 'key')
    .filter(F.col('a') == 'truthiness')
    .withColumn('boverc', F.col('b') / F.col('c'))
    .join(df2, 'key', how='inner')
)

# also bad
customers_with_shipping_address = customers_with_shipping_address.select('a', 'b', 'c', 'key')
customers_with_shipping_address = customers_with_shipping_address.filter(F.col('a') == 'truthiness')

customers_with_shipping_address = customers_with_shipping_address.withColumn('boverc', F.col('b') / F.col('c'))

customers_with_shipping_address = customers_with_shipping_address.join(df2, 'key', how='inner')

# better
def join_customers_with_shipping_address(customers, df_to_join):

    customers = (
        customers
        .select('a', 'b', 'c', 'key')
        .filter(F.col('a') == 'truthiness')
    )

    customers = customers.withColumn('boverc', F.col('b') / F.col('c'))
    customers = customers.join(df_to_join, 'key', how='inner')
    return customers
```

Chains of more than 3 statements are prime candidates to factor into separate, well-named functions since they are already encapsulated, isolated blocks of logic.

The rationale for why we've set these limits on chaining:

1. Differentiation between PySpark code and SQL code. Chaining is something that goes against most, if not all, other Python styling. You don't chain in Python, you assign.
2. Discourage the creation of large single code blocks. These would often make more sense extracted as a named function.
3. It doesn't need to be all or nothing, but a maximum of five lines of chaining balances practicality with legibility.
4. If you are using an IDE, it makes it easier to use automatic extractions or do code movements (e.g. `cmd + shift + up` in PyCharm).
5. Large chains are hard to read and maintain, particularly if chains are nested.


# Multi-line expressions

You can chain expressions because PySpark was developed from Spark, which comes from JVM languages, and some design patterns, specifically chainability, were carried over. However, Python doesn't support multiline expressions gracefully; the only alternatives are to either provide explicit line continuations with `\`, or wrap the expression in parentheses. You only need explicit line continuations if the chain happens at the root node. For example:

```python
# needs `\`
df = df.filter(F.col('event') == 'executing')\
    .filter(F.col('has_tests') == True)\
    .drop('has_tests')

# chain not in root node so it doesn't need the `\`
df = df.withColumn('safety', F.when(F.col('has_tests') == True, 'is safe')
                              .when(F.col('has_executed') == True, 'no tests but runs')
                              .otherwise('not safe'))
```

To keep things consistent, please wrap the entire expression into a single parenthesized block, and avoid using `\`:

```python
# bad
df = df.filter(F.col('event') == 'executing')\
    .filter(F.col('has_tests') == True)\
    .drop('has_tests')

# good
df = (
    df
    .filter(F.col('event') == 'executing')
    .filter(F.col('has_tests') == True)
    .drop('has_tests')
)
```


# Other Considerations and Recommendations

1. Be wary of functions that grow too large. As a general rule, a file should not be over 250 lines, and a function should not be over 70 lines.
2. Try to keep your code in logical blocks. For example, if you have multiple lines referencing the same things, try to keep them together. Separating them reduces context and readability.
3. Test your code! If you *can* run the local tests, do so and make sure that your new code is covered by the tests. If you can't run the local tests, build the datasets on your branch and manually verify that the data looks as expected.
4. Avoid `.otherwise(value)` as a general fallback. If you are mapping a list of keys to a list of values and a number of unknown keys appear, using `otherwise` will mask all of these into one value.
5. Do not keep commented out code checked in the repository. This applies to single lines of code, functions, classes or modules. Rely on git and its capabilities of branching or looking at history instead.
6. When encountering a large single transformation composed of integrating multiple different source tables, split it into the natural sub-steps and extract the logic to functions. This allows for easier higher level readability and allows for code re-usability and consistency between transforms.
7. Try to be as explicit and descriptive as possible when naming functions or variables. Strive to capture what the function is actually doing as opposed to naming it based on the objects used inside of it.
8. Think twice about introducing new import aliases, unless there is a good reason to do so. Some of the established ones are `types` and `functions` from PySpark `from pyspark.sql import types as T, functions as F`.
9. Avoid using literal strings or integers in filtering conditions, new values of columns etc. Instead, to capture their meaning, extract them into variables, constants, dicts or classes as suitable. This makes the code more readable and enforces consistency across the repository.

WIP - To enforce a consistent code style, each main repository should have [Pylint](https://www.pylint.org/) enabled with the same configuration. We provide some PySpark-specific checkers in `src/checkers` that you can load into Pylint to match the rules listed in this document. These checkers still need more work, so feel free to contribute and improve them.


[pylint]: https://www.pylint.org/

--------------------------------------------------------------------------------
/src/checkers/__init__.py:
--------------------------------------------------------------------------------
from .function_call_checker import FunctionCallChecker
from .logic_op_complexity_checker import LogicOpComplexityChecker
from .select_alias_checker import SelectAliasChecker
from .select_cast_checker import SelectCastChecker
from .statement_call_checker import StatementCallChecker
from .chained_function_checker import ChainedDotFunctionsSyntaxChecker


def register(linter):
    """Entry point Pylint calls when this package is loaded as a plugin."""
    linter.register_checker(FunctionCallChecker(linter))
    linter.register_checker(LogicOpComplexityChecker(linter))
    linter.register_checker(SelectAliasChecker(linter))
    linter.register_checker(SelectCastChecker(linter))
    linter.register_checker(StatementCallChecker(linter))
    linter.register_checker(ChainedDotFunctionsSyntaxChecker(linter))
--------------------------------------------------------------------------------
/src/checkers/chained_function_checker.py:
--------------------------------------------------------------------------------
##
# Copyright 2020 Palantir Technologies, Inc. All rights reserved.
# Licensed under the MIT License (the "License"); you may obtain a copy of the
# license at https://github.com/palantir/pyspark-style-guide/blob/develop/LICENSE
##


import astroid

from pylint import checkers
from pylint import interfaces


class ChainedDotFunctionsSyntaxChecker(checkers.BaseChecker):
    __implements__ = interfaces.IAstroidChecker

    name = 'chained-function-length'

    msgs = {
        'E1085': (
            'Chained functions applied on a variable should not be more than 3.',
            'chained-function-length',
            'Applied functions should not be more than 3 as it becomes difficult to follow.'
        ),
    }

    def visit_assign(self, node):
        depth = get_num_of_first_level_functions(node, 0)
        if depth > 3:
            self.add_message('chained-function-length', node=node)


def get_num_of_first_level_functions(node, counter):
    """Count the calls chained on the right-hand side of `node`.

    Starts from an Assign's `.value` (the outermost call), then follows each
    call's `.func` (an Attribute), whose `.expr` is the previous call in the
    chain. For `y = df.a().b()` this returns 2.
    """
    if hasattr(node, 'value'):
        curr_value = node.value
        if isinstance(curr_value, astroid.Call):
            return get_num_of_first_level_functions(curr_value.func, counter + 1)
    elif hasattr(node, 'expr'):
        curr_value = node.expr
        if isinstance(curr_value, astroid.Call):
            return get_num_of_first_level_functions(curr_value.func, counter + 1)

    return counter

--------------------------------------------------------------------------------
/src/checkers/function_call_checker.py:
--------------------------------------------------------------------------------
##
# Copyright 2020 Palantir Technologies, Inc. All rights reserved.
# Licensed under the MIT License (the "License"); you may obtain a copy of the
# license at https://github.com/palantir/pyspark-style-guide/blob/develop/LICENSE
##

import astroid

from pylint.checkers import BaseChecker
from pylint.interfaces import IAstroidChecker

from .pylint_utils import compute_arguments_length
from .pylint_utils import is_line_split
from .pylint_utils import get_length


class FunctionCallChecker(BaseChecker):
    __implements__ = IAstroidChecker

    name = 'split-call'
    priority = -1
    msgs = {
        'E1072': (
            'Function call arguments should be on the same line if they fit.',
            'unnecessarily-split-call',
            'All arguments for this function fit into a single line and as such should be in a single line'
        ),
    }

    def is_line_split(self, function):
        line = function.lineno
        for arg in function.args.args:
            if arg.lineno != line:
                return True
        return False

    def __init__(self, linter=None):
        super(FunctionCallChecker, self).__init__(linter)
        self._function_stack = []

    def visit_functiondef(self, node):
        for statement in node.body:
            # Only bare call statements such as `df.show(...)` are checked.
            if (isinstance(statement, astroid.nodes.Expr)
                    and isinstance(statement.value, astroid.nodes.Call)):
                if is_line_split(statement.value):
                    args_length = compute_arguments_length(
                        statement.value.args)
                    # get_length handles both `foo(...)` (Name) and
                    # `df.foo(...)` (Attribute) callees.
                    total_length = get_length(statement.value.func) + \
                        statement.col_offset + args_length
                    if total_length <= 120:
                        self.add_message('unnecessarily-split-call', node=statement)
        self._function_stack.append([])

    def leave_functiondef(self, node):
        self._function_stack.pop()

    def visit_functionheadercorrect(self, node):
        return

--------------------------------------------------------------------------------
/src/checkers/logic_op_complexity_checker.py:
--------------------------------------------------------------------------------
##
# Copyright 2020 Palantir Technologies, Inc. All rights reserved.
# Licensed under the MIT License (the "License"); you may obtain a copy of the
# license at https://github.com/palantir/pyspark-style-guide/blob/develop/LICENSE
##

import astroid

from pylint.checkers import BaseChecker
from pylint.interfaces import IAstroidChecker

from .pylint_utils import get_binary_op_complexity


class LogicOpComplexityChecker(BaseChecker):
    __implements__ = IAstroidChecker

    name = 'logic-op-complexity'
    priority = -1
    msgs = {
        'E1074': (
            'Complexity of inline logic statement should be 3 or lower.',
            'high-logic-op-complexity',
            'This statement has a high complexity and should be split into several smaller statements assigned with descriptive variable names.'
        ),
    }

    def is_line_split(self, function):
        line = function.lineno
        for arg in function.args.args:
            if arg.lineno != line:
                return True
        return False

    def __init__(self, linter=None):
        super(LogicOpComplexityChecker, self).__init__(linter)
        self._function_stack = []

    def visit_functiondef(self, node):
        self._function_stack.append([])

    def visit_binop(self, node):
        # PySpark combines conditions with `&` and `|`, which are binary
        # operators (BinOp nodes), not the boolean `and`/`or` (BoolOp).
        if get_binary_op_complexity(node) > 3:
            self.add_message('high-logic-op-complexity', node=node)

    def leave_functiondef(self, node):
        self._function_stack.pop()

--------------------------------------------------------------------------------
/src/checkers/pylint_utils.py:
--------------------------------------------------------------------------------
##
# Copyright 2020 Palantir Technologies, Inc. All rights reserved.
# Licensed under the MIT License (the "License"); you may obtain a copy of the
# license at https://github.com/palantir/pyspark-style-guide/blob/develop/LICENSE
##

import astroid


def compute_arguments_length(arguments):
    args_length = 0
    for arg in arguments:
        args_length += get_length(arg)
    return args_length


def compute_target_lengths(targets):
    total_length = 0
    for target in targets:
        total_length += get_length(target)
    return total_length


def _select_contains_method_call(expression, method_name):
    """Return True if `expression.value` is a `.select(...)` call with a
    `.<method_name>(...)` call among its positional arguments."""
    call = expression.value
    if not isinstance(call, astroid.nodes.Call):
        return False
    # Attribute callees (`df.select`) have `attrname`; Name callees do not.
    if getattr(call.func, 'attrname', None) != 'select':
        return False
    for arg in call.args:
        if (isinstance(arg, astroid.nodes.Call)
                and getattr(arg.func, 'attrname', None) == method_name):
            return True
    return False


def select_contains_alias_call(expression):
    return _select_contains_method_call(expression, 'alias')


def select_contains_cast_call(expression):
    return _select_contains_method_call(expression, 'cast')


def is_line_split(val):
    """Return True if an argument or element of `val` is on a different line than `val`."""
    line = val.lineno
    if isinstance(val, astroid.nodes.FunctionDef):
        for arg in val.args.args:
            if arg.lineno != line:
                return True
        return False
    if isinstance(val, astroid.nodes.Call):
        for arg in val.args:
            if arg.lineno != line:
                return True
        return False
    if isinstance(val, astroid.nodes.Assign):
        # Only call arguments are iterable here; a Lambda's `.args` is not.
        if isinstance(val.value, astroid.nodes.Call):
            for arg in val.value.args:
                if arg.lineno != line:
                    return True
        if hasattr(val.value, 'elts'):
            for arg in val.value.elts:
                if arg.lineno != line:
                    return True
    return False


def get_binary_op_complexity(arg):
    if isinstance(arg, astroid.nodes.BinOp):
        nested_complexity = get_binary_op_complexity(arg.left) + \
            get_binary_op_complexity(arg.right)
        # Leaves score 0, so an operation on two leaves scores 2 and every
        # enclosing operation adds 1 to the sum of its operands' scores.
        return 1 + (nested_complexity if nested_complexity > 0 else 1)
    return 0


def get_length(arg):
    """Roughly estimate how many characters `arg` takes when written on one line."""
    if isinstance(arg, astroid.nodes.Const):
        length = len(str(arg.value))
        if isinstance(arg.value, str):
            length += 2  # Quotes
        return length
    if isinstance(arg, astroid.nodes.BinOp):
        base_length = 3  # _+_
        compound_length = get_length(arg.left) + get_length(arg.right)
        return base_length + compound_length
    if isinstance(arg, astroid.nodes.Call):
        base_length = 2  # Open and closing brackets
        length = get_length(arg.func) + \
            compute_arguments_length(arg.args) + base_length
        return length
    if isinstance(arg, astroid.nodes.Attribute):
        base_length = 1  # Period
        if hasattr(arg.expr, 'name'):
            expr_length = len(arg.expr.name)
        else:
            expr_length = get_length(arg.expr)
        total_length = expr_length + base_length + len(arg.attrname)
        return total_length
    if isinstance(arg, astroid.nodes.Assign):
        base_length = 2  # Brackets
        target_length = compute_target_lengths(arg.targets)
        value_length = get_length(arg.value)
        return base_length + value_length + target_length
    if isinstance(arg, astroid.nodes.Tuple):
        args_length = 0
        for value in arg.elts:
            args_length += get_length(value)
        if len(arg.elts) > 1:
            args_length += ((len(arg.elts) - 1) * 2)  # ", " separators
        return args_length
    if isinstance(arg, astroid.nodes.Name):
        return len(arg.name)
    if isinstance(arg, astroid.nodes.AssignName):
        return len(arg.name)
    if isinstance(arg, astroid.nodes.Compare):
        total_length = 0
        for op in arg.ops:
            # Each op is an (operator, right-hand node) pair, e.g. ('==', node).
            total_length += 2 + len(op[0]) + get_length(op[1])
        total_length += get_length(arg.left)
        return total_length
    if isinstance(arg, astroid.nodes.List):
        total_length = 0
        for value in arg.elts:
            total_length += get_length(value)
        return total_length
    print("Unhandled %s" % arg)
    return 0

--------------------------------------------------------------------------------
/src/checkers/select_alias_checker.py:
--------------------------------------------------------------------------------
##
# Copyright 2020 Palantir Technologies, Inc. All rights reserved.
# Licensed under the MIT License (the "License"); you may obtain a copy of the
# license at https://github.com/palantir/pyspark-style-guide/blob/develop/LICENSE
##


import astroid

from pylint.checkers import BaseChecker
from pylint.interfaces import IAstroidChecker

from .pylint_utils import select_contains_alias_call


class SelectAliasChecker(BaseChecker):
    __implements__ = IAstroidChecker

    name = 'select-alias'
    priority = -1
    msgs = {
        'E1075': (
            'Select statements should not contain .alias calls',
            'select-contains-alias',
            'Readability of select calls can be much improved by extracting .alias calls.'
        ),
    }

    def is_line_split(self, function):
        line = function.lineno
        for arg in function.args.args:
            if arg.lineno != line:
                return True
        return False

    def __init__(self, linter=None):
        super(SelectAliasChecker, self).__init__(linter)
        self._function_stack = []

    def visit_functiondef(self, node):
        self._function_stack.append([])

    def visit_expr(self, node):
        if select_contains_alias_call(node):
            self.add_message('select-contains-alias', node=node)

    def leave_functiondef(self, node):
        self._function_stack.pop()

--------------------------------------------------------------------------------
/src/checkers/select_cast_checker.py:
--------------------------------------------------------------------------------
##
# Copyright 2020 Palantir Technologies, Inc. All rights reserved.
# Licensed under the MIT License (the "License"); you may obtain a copy of the
# license at https://github.com/palantir/pyspark-style-guide/blob/develop/LICENSE
##


import astroid

from pylint.checkers import BaseChecker
from pylint.interfaces import IAstroidChecker

from .pylint_utils import select_contains_cast_call


class SelectCastChecker(BaseChecker):
    __implements__ = IAstroidChecker

    name = 'select-cast'
    priority = -1
    msgs = {
        'E1076': (
            'Select statements should not contain .cast calls',
            'select-contains-cast',
            'Readability of select calls can be much improved by extracting .cast calls.'
        ),
    }

    def is_line_split(self, function):
        line = function.lineno
        for arg in function.args.args:
            if arg.lineno != line:
                return True
        return False

    def __init__(self, linter=None):
        super(SelectCastChecker, self).__init__(linter)
        self._function_stack = []

    def visit_functiondef(self, node):
        self._function_stack.append([])

    def visit_expr(self, node):
        # Check for `.cast` calls, not `.alias` calls.
        if select_contains_cast_call(node):
            self.add_message('select-contains-cast', node=node)

    def leave_functiondef(self, node):
        self._function_stack.pop()

--------------------------------------------------------------------------------
/src/checkers/statement_call_checker.py:
--------------------------------------------------------------------------------
##
# Copyright 2020 Palantir Technologies, Inc. All rights reserved.
# Licensed under the MIT License (the "License"); you may obtain a copy of the
# license at https://github.com/palantir/pyspark-style-guide/blob/develop/LICENSE
##


import astroid

from pylint.checkers import BaseChecker
from pylint.interfaces import IAstroidChecker

from .pylint_utils import compute_arguments_length
from .pylint_utils import is_line_split
from .pylint_utils import get_length


class StatementCallChecker(BaseChecker):
    __implements__ = IAstroidChecker

    name = 'split-statement'
    priority = -1
    msgs = {
        'E1073': (
            'Statements should be on a single line if they fit.',
            'unnecessarily-split-statement',
            'This statement fits into a single line and as such should not be split across multiple lines.'
        ),
    }

    def is_line_split(self, function):
        line = function.lineno
        for arg in function.args.args:
            if arg.lineno != line:
                return True
        return False

    def __init__(self, linter=None):
        super(StatementCallChecker, self).__init__(linter)
        self._function_stack = []

    def visit_functiondef(self, node):
        self._function_stack.append([])

    def visit_assign(self, node):
        if is_line_split(node):
            base_length = 3  # _=_
            if get_length(node) + base_length <= 120:
                self.add_message(
                    'unnecessarily-split-statement', node=node)

    def leave_functiondef(self, node):
        self._function_stack.pop()

--------------------------------------------------------------------------------