
PySpark GroupBy: Shuffle Cost, Skew, and Production Patterns

groupBy creates a shuffle boundary. The default output is 200 partitions. collect_list on skewed keys causes OOM. Interviewers test whether you know the shuffle cost of every aggregation you write.

Basic GroupBy with Multiple Aggregations

from pyspark.sql import functions as F

# groupBy creates a shuffle boundary.
# The number of output partitions = spark.sql.shuffle.partitions
# which defaults to 200. Every row moves across the network.
result = df.groupBy("department").agg(
    F.sum("salary").alias("total_salary"),
    F.avg("salary").alias("avg_salary"),
    F.count("*").alias("headcount"),
    F.countDistinct("title").alias("unique_titles")
)

# Always alias output columns. Without alias, Spark generates
# names like "avg(salary)" which break downstream column references.

groupBy creates a shuffle boundary. Spark redistributes every row so that rows sharing the same key land on the same partition. The output partition count equals spark.sql.shuffle.partitions, which defaults to 200. If you have 10 departments, 190 of those 200 partitions will be empty. If you have 50,000 distinct keys, 200 partitions may be too few.
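You can see the empty-partition effect with a plain-Python sketch. Python's hash() here is only a stand-in for Spark's Murmur3 partitioner, and 200 mirrors the default spark.sql.shuffle.partitions:

```python
# Minimal sketch (plain Python, not Spark internals) of hash
# partitioning: each key maps to hash(key) % num_partitions.
def shuffle_partition(key, num_partitions=200):
    # hash() stands in for Spark's Murmur3 hash of the grouping key.
    return hash(key) % num_partitions

departments = [f"dept_{i}" for i in range(10)]
used = {shuffle_partition(d) for d in departments}
empty = 200 - len(used)
print(f"{len(used)} partitions used, {empty} empty")
```

With 10 distinct keys, at most 10 of the 200 partitions receive data; the other 190+ are scheduled, launched, and finish with zero rows.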

GroupBy on Multiple Columns

# Spark shuffles by ALL grouping columns combined
result = df.groupBy("department", F.year("hire_date").alias("hire_year")).agg(
    F.count("*").alias("hires")
)

# Grouping by high-cardinality columns (e.g., user_id)
# produces many small groups spread across 200 partitions.
# Grouping by low-cardinality columns (e.g., country)
# produces few large groups, some potentially skewed.

Spark hashes all grouping columns together to determine the target partition. High-cardinality keys distribute evenly but create many small groups. Low-cardinality keys create fewer groups, and if one value dominates (e.g., country="US" is 60% of traffic), that partition becomes a bottleneck.
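The bottleneck is easy to simulate in plain Python. This sketch uses hash() as a stand-in for Spark's partitioner and a hypothetical traffic mix where "US" is 60% of rows:

```python
import random
from collections import Counter

random.seed(0)
# Hypothetical traffic mix: "US" dominates at ~60% of rows.
population = ["US"] * 60 + ["DE", "FR", "JP", "BR"] * 10
rows = [random.choice(population) for _ in range(10_000)]

# Every row with the same key hashes to the same partition, so the
# partition that receives "US" carries most of the data.
loads = Counter(hash(c) % 200 for c in rows)
print("max partition load:", max(loads.values()), "of", len(rows))
```

One partition ends up holding roughly 60% of the rows while the rest sit nearly idle, which is exactly the straggler pattern you see in the Spark UI on a skewed groupBy.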

Pivot Tables

# Always pass the values list to pivot(). Without it,
# Spark does a full scan to discover unique values first,
# then scans again to build the pivot. Two passes, not one.
pivoted = df.groupBy("department").pivot(
    "quarter", ["Q1", "Q2", "Q3", "Q4"]
).agg(
    F.sum("revenue")
)

# Result: one row per department, columns: Q1, Q2, Q3, Q4

pivot() without explicit values triggers a full extra scan of the data to discover unique values. In production, this doubles your I/O on large tables. Always pass the values list as the second argument. Interviewers test whether you know this optimization.

collect_list and collect_set (The OOM Trap)

# collect_list pulls ALL group values into a single array
# on ONE partition. This is safe for small groups.
result = df.groupBy("order_id").agg(
    F.collect_list("item_name").alias("items")
)

# But on skewed keys, one partition gets all the data.
# A key with 50M rows tries to build a 50M-element array
# in one executor's memory. OOM.

# Safer alternative: aggregate first, then collect
result = (
    df.groupBy("order_id", "item_name")
    .agg(F.count("*").alias("qty"))
    .groupBy("order_id")
    .agg(F.collect_list(
        F.struct("item_name", "qty")
    ).alias("items"))
)

collect_list and collect_set pull all group values to one partition. On skewed keys, this causes OOM. A key that appears 50M times tries to build a 50M-element array in a single executor. The fix is to pre-aggregate before collecting, or to set a max group size and filter out keys that exceed it. Interviewers test this because it exposes whether you have dealt with real data skew.
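The size difference is easy to demonstrate outside Spark. This plain-Python sketch mimics the two-step pattern on a hypothetical skewed order with 100,000 rows but only 5 distinct items:

```python
from collections import Counter

# Hypothetical skewed group: one order_id repeated 100,000 times
# across only 5 distinct items.
rows = [("order_1", f"item_{i % 5}") for i in range(100_000)]

# Naive collect_list: materializes the full 100,000-element array
# for a single key, all in one place.
naive = [item for oid, item in rows if oid == "order_1"]

# Pre-aggregate (order_id, item) -> qty first, then collect the much
# smaller (item, qty) pairs, mirroring the two-step groupBy above.
counts = Counter(item for _oid, item in rows)
pre_aggregated = [(item, qty) for item, qty in counts.items()]

print(len(naive), "->", len(pre_aggregated))
```

The collected structure shrinks from 100,000 elements to 5, while the quantities still sum to the original row count, so no information is lost for downstream consumers that only need item counts.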

GroupBy vs Window Functions

# GroupBy collapses rows: one output row per group
dept_totals = df.groupBy("department").agg(
    F.sum("salary").alias("dept_total")
)

# Window keeps all rows, adds a computed column
from pyspark.sql.window import Window
w = Window.partitionBy("department")
with_total = df.withColumn("dept_total", F.sum("salary").over(w))

# Use groupBy when you need summary rows.
# Use window when you need individual rows with group context.
# Both trigger a shuffle. The window shuffle is often larger
# because it preserves all columns on every row.

GroupBy collapses rows into summaries. Window functions keep every row and add a computed column. Both create a shuffle boundary. The tradeoff: window shuffles move more data because they preserve all columns on every row, while groupBy only shuffles the grouping key plus the aggregation inputs.

What Production Teaches You About PySpark GroupBy

In production, the default 200 shuffle partitions is almost never right. A 10GB dataset split into 200 partitions produces 50MB partitions, which is reasonable. A 1TB dataset produces 5GB partitions, which spill to disk and slow everything down. A 100MB dataset produces 0.5MB partitions, which waste scheduling overhead. Set spark.sql.shuffle.partitions based on your data size, not the default.
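A back-of-envelope helper makes the sizing concrete. This is plain Python, not a Spark API, and the 128MB target is an assumption based on the commonly cited per-partition sweet spot:

```python
# Hypothetical helper: pick a shuffle partition count that targets
# roughly 128MB per partition (an assumed sweet spot, not a Spark API).
def suggested_shuffle_partitions(total_bytes, target_bytes=128 * 1024**2):
    return max(1, round(total_bytes / target_bytes))

GB = 1024**3
print(suggested_shuffle_partitions(10 * GB))        # 80 partitions
print(suggested_shuffle_partitions(1024 * GB))      # 8192 partitions
print(suggested_shuffle_partitions(100 * 1024**2))  # 1 partition
```

You would then feed the result into spark.conf.set("spark.sql.shuffle.partitions", n) before the groupBy, or let AQE coalesce partitions for you on Spark 3.x.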

Filter rows and select only needed columns before the groupBy. Every extra column you carry through the shuffle gets serialized, sent across the network, and deserialized. On a 50-column table where you only aggregate 3 columns, the other 47 are dead weight in the shuffle.

approx_count_distinct is typically 2 to 5% off the true count but runs 3 to 10x faster than countDistinct on large datasets. If your dashboard rounds to the nearest thousand anyway, exact counts are wasted compute. Interviewers at senior levels test whether you reach for approximate functions when precision is not required.

When the same groupBy key appears in multiple downstream operations, repartition by that key once and cache. This avoids reshuffling the data for each subsequent groupBy. The tradeoff: caching consumes memory. Only do this when the reuse is real and measurable.

PySpark GroupBy FAQ

Does groupBy trigger a shuffle in PySpark?
Yes. groupBy creates a shuffle boundary. Spark redistributes all rows so that rows with the same key end up on the same partition. The number of output partitions equals spark.sql.shuffle.partitions, which defaults to 200. Shuffle write and read is the most common bottleneck in slow Spark jobs.
What is the difference between groupBy and groupByKey in PySpark?
groupBy().agg() uses Spark SQL's Catalyst optimizer and Tungsten memory management. It performs partial aggregation (map-side combine) before the shuffle, reducing data movement. groupByKey() is an RDD-level operation that sends all values across the network before aggregation. On a 1B-row dataset, groupByKey can move 10x more data than groupBy().agg().
How do I handle data skew in groupBy?
If one key holds more than 10x the median partition size, that partition becomes a bottleneck. Options: (1) salting the key to spread it across multiple partitions, then re-aggregating, (2) filtering out the hot key and processing it separately, (3) increasing spark.sql.shuffle.partitions to reduce per-partition size. AQE in Spark 3.x can also split skewed partitions automatically (by default, when a partition exceeds 5x the median AND is larger than 256MB), though its automatic skew splitting targets joins rather than plain aggregations.
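The salting option can be sketched in plain Python: spread the hot key across a handful of random salt buckets, aggregate each salted sub-key, then re-aggregate back to the original key. The key names and salt count here are made up for illustration:

```python
import random
from collections import defaultdict

random.seed(1)
SALTS = 8  # number of salt buckets for spreading the hot key

# Hypothetical skewed input: "US" is the hot key.
rows = [("US", 1)] * 9_000 + [("DE", 1)] * 500 + [("FR", 1)] * 500

# Step 1: append a random salt so the hot key spreads over SALTS
# sub-keys (in Spark, each sub-key lands on its own partition).
salted = defaultdict(int)
for key, value in rows:
    salted[(key, random.randrange(SALTS))] += value

# Step 2: re-aggregate the partial sums back to the original key.
totals = defaultdict(int)
for (key, _salt), partial in salted.items():
    totals[key] += partial

print(dict(totals))
```

The second aggregation only touches one small partial sum per salt bucket, so the totals match the unsalted result while no single partition ever held all 9,000 hot-key rows.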
Can I group by a computed column in PySpark?
Yes. Pass an expression directly: df.groupBy(F.year("date").alias("year")).agg(...). Spark evaluates the expression before the shuffle, so the computed value becomes the shuffle key.

Practice PySpark GroupBy Patterns Before Your Interview

DataDriven has PySpark challenges that test aggregations, pivot tables, skew handling, and collect_list traps against real datasets.
