PySpark GroupBy: Shuffle Cost, Skew, and Production Patterns
groupBy creates a shuffle boundary. The default output is 200 partitions. collect_list on skewed keys causes OOM. Interviewers test whether you know the shuffle cost of every aggregation you write.
Basic GroupBy with Multiple Aggregations
from pyspark.sql import functions as F
# groupBy creates a shuffle boundary.
# The number of output partitions = spark.sql.shuffle.partitions
# which defaults to 200. Every row moves across the network.
result = df.groupBy("department").agg(
    F.sum("salary").alias("total_salary"),
    F.avg("salary").alias("avg_salary"),
    F.count("*").alias("headcount"),
    F.countDistinct("title").alias("unique_titles")
)
# Always alias output columns. Without an alias, Spark generates
# names like "avg(salary)" that break downstream column references.

groupBy creates a shuffle boundary. Spark redistributes every row so that rows sharing the same key land on the same partition. The output partition count equals spark.sql.shuffle.partitions, which defaults to 200. If you have 10 departments, 190 of those 200 partitions will be empty. If you have 50,000 distinct keys, 200 partitions may be too few.
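The empty-partition arithmetic can be sketched in plain Python. This is a toy hash partitioner for illustration only; Spark uses a Murmur3-based hash internally, not CRC32, but the conclusion is the same: 10 distinct keys can occupy at most 10 partitions.

```python
import zlib

# Toy hash partitioner: map 10 department keys onto 200 partitions,
# mirroring the spark.sql.shuffle.partitions default. Illustrative
# only -- Spark's real partitioner is Murmur3-based, not CRC32.
NUM_PARTITIONS = 200

departments = [f"dept_{i}" for i in range(10)]
used = {zlib.crc32(d.encode()) % NUM_PARTITIONS for d in departments}

empty = NUM_PARTITIONS - len(used)
print(f"{empty} of {NUM_PARTITIONS} partitions receive no data")
```

No matter which hash function Spark uses, at least 190 of the 200 partitions sit idle when there are only 10 grouping keys.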
GroupBy on Multiple Columns
# Spark shuffles by ALL grouping columns combined
result = df.groupBy("department", F.year("hire_date").alias("hire_year")).agg(
    F.count("*").alias("hires")
)
# Grouping by high-cardinality columns (e.g., user_id)
# produces many small groups spread across 200 partitions.
# Grouping by low-cardinality columns (e.g., country)
# produces few large groups, some potentially skewed.Spark hashes all grouping columns together to determine the target partition. High-cardinality keys distribute evenly but create many small groups. Low-cardinality keys create fewer groups, and if one value dominates (e.g., country="US" is 60% of traffic), that partition becomes a bottleneck.
Pivot Tables
# Always pass the values list to pivot(). Without it,
# Spark does a full scan to discover unique values first,
# then scans again to build the pivot. Two passes, not one.
pivoted = df.groupBy("department").pivot(
    "quarter", ["Q1", "Q2", "Q3", "Q4"]
).agg(
    F.sum("revenue")
)
# Result: one row per department, columns: Q1, Q2, Q3, Q4

pivot() without explicit values triggers a full extra scan of the data to discover unique values. In production, this doubles your I/O on large tables. Always pass the values list as the second argument. Interviewers test whether you know this optimization.
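The wide shape pivot() produces can be sketched in plain Python. Toy data; None marks a (department, quarter) cell with no input rows:

```python
# Each input row (department, quarter, revenue) becomes one cell in a
# wide table: one row per department, one column per known quarter.
QUARTERS = ["Q1", "Q2", "Q3", "Q4"]  # known up front: no discovery scan

rows = [
    ("eng", "Q1", 100), ("eng", "Q2", 150),
    ("sales", "Q1", 80), ("sales", "Q4", 120),
]

pivoted = {}
for dept, quarter, revenue in rows:
    cells = pivoted.setdefault(dept, {q: None for q in QUARTERS})
    cells[quarter] = (cells[quarter] or 0) + revenue

print(pivoted["eng"])  # {'Q1': 100, 'Q2': 150, 'Q3': None, 'Q4': None}
```

Because QUARTERS is fixed in advance, the output schema is known before the data is touched -- the same reason passing the values list to pivot() saves Spark a scan.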
collect_list and collect_set (The OOM Trap)
# collect_list pulls ALL group values into a single array
# on ONE partition. This is safe for small groups.
result = df.groupBy("order_id").agg(
    F.collect_list("item_name").alias("items")
)
# But on skewed keys, one partition gets all the data.
# A key with 50M rows tries to build a 50M-element array
# in one executor's memory. OOM.
# Safer alternative: aggregate first, then collect
result = (
    df.groupBy("order_id", "item_name")
      .agg(F.count("*").alias("qty"))
      .groupBy("order_id")
      .agg(F.collect_list(
          F.struct("item_name", "qty")
      ).alias("items"))
)

collect_list and collect_set pull all group values to one partition. On skewed keys, this causes OOM. A key that appears 50M times tries to build a 50M-element array in a single executor. The fix is to pre-aggregate before collecting, or to set a max group size and filter out keys that exceed it. Interviewers test this because it exposes whether you have dealt with real data skew.
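The max-group-size guard can be sketched in plain Python: count rows per key first, then refuse to collect keys whose group exceeds the cap. The cap value and data here are illustrative, not Spark settings:

```python
from collections import Counter

# Guard against runaway groups: count rows per key first, then only
# collect keys whose group fits under the cap.
MAX_GROUP_SIZE = 3
rows = [("o1", "a"), ("o1", "b"), ("o2", "a"),
        ("o1", "c"), ("o1", "d")]  # o1 has 4 items: over the cap

sizes = Counter(order_id for order_id, _ in rows)
safe_keys = {k for k, n in sizes.items() if n <= MAX_GROUP_SIZE}

collected = {}
for order_id, item in rows:
    if order_id in safe_keys:
        collected.setdefault(order_id, []).append(item)

print(collected)  # {'o2': ['a']}
```

In PySpark the same shape is a count per key, a join or filter to drop oversized keys, and only then the collect_list.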
GroupBy vs Window Functions
# GroupBy collapses rows: one output row per group
dept_totals = df.groupBy("department").agg(
    F.sum("salary").alias("dept_total")
)
# Window keeps all rows, adds a computed column
from pyspark.sql.window import Window
w = Window.partitionBy("department")
with_total = df.withColumn("dept_total", F.sum("salary").over(w))
# Use groupBy when you need summary rows.
# Use window when you need individual rows with group context.
# Both trigger a shuffle. The window shuffle is often larger
# because it preserves all columns on every row.

GroupBy collapses rows into summaries. Window functions keep every row and add a computed column. Both create a shuffle boundary. The tradeoff: window shuffles move more data because they preserve all columns on every row, while groupBy shuffles only the grouping key plus the aggregation inputs.
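The two output shapes can be sketched side by side in plain Python (toy data):

```python
from collections import defaultdict

rows = [("eng", 100), ("eng", 150), ("sales", 80)]

# groupBy-style: collapse to one summary row per department
totals = defaultdict(int)
for dept, salary in rows:
    totals[dept] += salary
print(dict(totals))  # {'eng': 250, 'sales': 80}

# window-style: every input row survives, annotated with its group total
with_total = [(dept, salary, totals[dept]) for dept, salary in rows]
print(with_total)  # [('eng', 100, 250), ('eng', 150, 250), ('sales', 80, 80)]
```

Three rows in, two rows out for the groupBy shape; three rows in, three wider rows out for the window shape -- which is exactly why the window shuffle carries more data.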
What Production Teaches You About PySpark GroupBy
In production, the default 200 shuffle partitions is almost never right. A 10GB dataset split into 200 partitions produces 50MB partitions, which is reasonable. A 1TB dataset produces 5GB partitions, which spill to disk and slow everything down. A 100MB dataset produces 0.5MB partitions, which waste scheduling overhead. Set spark.sql.shuffle.partitions based on your data size, not the default.
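That sizing rule can be turned into a small helper. A sketch, assuming a 128 MB target partition size -- a common rule of thumb, not a Spark constant:

```python
# Size shuffle partitions from data volume instead of taking the
# default 200. The 128 MB target is an illustrative rule of thumb.
TARGET_PARTITION_BYTES = 128 * 1024 * 1024

def suggested_shuffle_partitions(shuffle_bytes: int) -> int:
    """Round up so no partition exceeds the target size."""
    return max(1, -(-shuffle_bytes // TARGET_PARTITION_BYTES))

GB = 1024 ** 3
print(suggested_shuffle_partitions(10 * GB))          # 80
print(suggested_shuffle_partitions(1024 * GB))        # 8192
print(suggested_shuffle_partitions(100 * 1024**2))    # 1
```

You would then set the result via spark.conf.set("spark.sql.shuffle.partitions", n) before the groupBy, or rely on adaptive query execution to coalesce partitions for you on Spark 3.x.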
Filter rows and select only needed columns before the groupBy. Every extra column you carry through the shuffle gets serialized, sent across the network, and deserialized. On a 50-column table where you only aggregate 3 columns, the other 47 are dead weight in the shuffle.
approx_count_distinct is typically 2 to 5% off the true count but runs 3 to 10x faster than countDistinct on large datasets. If your dashboard rounds to the nearest thousand anyway, exact counts are wasted work. Interviewers at senior levels test whether you reach for approximate functions when precision is not required.
When the same groupBy key appears in multiple downstream operations, repartition by that key once and cache. This avoids reshuffling the data for each subsequent groupBy. The tradeoff: caching consumes memory. Only do this when the reuse is real and measurable.
PySpark GroupBy FAQ
Does groupBy trigger a shuffle in PySpark?
What is the difference between groupBy and groupByKey in PySpark?
How do I handle data skew in groupBy?
Can I group by a computed column in PySpark?
Practice PySpark GroupBy Patterns Before Your Interview
DataDriven has PySpark challenges that test aggregations, pivot tables, skew handling, and collect_list traps against real datasets.