spark-optimization

Optimize Apache Spark jobs with partitioning, caching, shuffle optimization, and memory tuning. Use when improving Spark performance, debugging slow jobs, or scaling data processing pipelines.


Apache Spark Optimization

Production patterns for optimizing Apache Spark jobs including partitioning strategies, memory management, shuffle optimization, and performance tuning.

Do not use this skill when

  • The task is unrelated to Apache Spark optimization

  • You need a different domain or tool outside this scope

Instructions

  • Clarify goals, constraints, and required inputs.

  • Apply relevant best practices and validate outcomes.

  • Provide actionable steps and verification.

  • If detailed examples are required, open resources/implementation-playbook.md.

Use this skill when

  • Optimizing slow Spark jobs

  • Tuning memory and executor configuration

  • Implementing efficient partitioning strategies

  • Debugging Spark performance issues

  • Scaling Spark pipelines for large datasets

  • Reducing shuffle and data skew

Core Concepts

    1. Spark Execution Model

    Driver Program
      → Job (triggered by an action)
        → Stages (separated by shuffles)
          → Tasks (one per partition)
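
    As a quick illustration of how a single action maps onto this hierarchy (a minimal sketch; the input path is hypothetical):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName("ExecutionModelDemo").getOrCreate()

    # Transformations are lazy -- nothing executes yet
    df = spark.read.parquet("s3://bucket/events/")   # hypothetical path
    grouped = df.groupBy("category").count()         # wide transformation (shuffle boundary)

    # The action below triggers a job; the shuffle splits it into two stages,
    # and each stage runs one task per partition
    grouped.show()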

    2. Key Performance Factors

    | Factor        | Impact                 | Solution                      |
    |---------------|------------------------|-------------------------------|
    | Shuffle       | Network I/O, disk I/O  | Minimize wide transformations |
    | Data Skew     | Uneven task duration   | Salting, broadcast joins      |
    | Serialization | CPU overhead           | Use Kryo, columnar formats    |
    | Memory        | GC pressure, spills    | Tune executor memory          |
    | Partitions    | Parallelism            | Right-size partitions         |

    Quick Start

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    # Create optimized Spark session
    spark = (SparkSession.builder
        .appName("OptimizedJob")
        .config("spark.sql.adaptive.enabled", "true")
        .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
        .config("spark.sql.adaptive.skewJoin.enabled", "true")
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        .config("spark.sql.shuffle.partitions", "200")
        .getOrCreate())

    # Read with optimized settings
    df = (spark.read
        .format("parquet")
        .option("mergeSchema", "false")
        .load("s3://bucket/data/"))

    # Efficient transformations
    result = (df
        .filter(F.col("date") >= "2024-01-01")
        .select("id", "amount", "category")
        .groupBy("category")
        .agg(F.sum("amount").alias("total")))

    result.write.mode("overwrite").parquet("s3://bucket/output/")

    Patterns

    Pattern 1: Optimal Partitioning

    # Calculate optimal partition count
    def calculate_partitions(data_size_gb: float, partition_size_mb: int = 128) -> int:
        """
        Optimal partition size: 128MB - 256MB
        Too few: Under-utilization, memory pressure
        Too many: Task scheduling overhead
        """
        return max(int(data_size_gb * 1024 / partition_size_mb), 1)

    # Repartition for even distribution
    df_repartitioned = df.repartition(200, "partition_key")

    # Coalesce to reduce partitions (no shuffle)
    df_coalesced = df.coalesce(100)

    # Partition pruning with predicate pushdown
    df = (spark.read.parquet("s3://bucket/data/")
        .filter(F.col("date") == "2024-01-01"))  # Spark pushes this down

    # Write with partitioning for future queries
    (df.write
        .partitionBy("year", "month", "day")
        .mode("overwrite")
        .parquet("s3://bucket/partitioned_output/"))

    Pattern 2: Join Optimization

    from pyspark.sql import functions as F
    from pyspark.sql.types import *

    # 1. Broadcast Join - Small table joins
    # Best when: One side < 10MB (configurable)
    small_df = spark.read.parquet("s3://bucket/small_table/")  # < 10MB
    large_df = spark.read.parquet("s3://bucket/large_table/")  # TBs

    # Explicit broadcast hint
    result = large_df.join(
        F.broadcast(small_df),
        on="key",
        how="left"
    )

    # 2. Sort-Merge Join - Default for large tables
    # Requires shuffle, but handles any size
    result = large_df1.join(large_df2, on="key", how="inner")

    # 3. Bucket Join - Pre-sorted, no shuffle at join time
    # Write bucketed tables
    (df.write
        .bucketBy(200, "customer_id")
        .sortBy("customer_id")
        .mode("overwrite")
        .saveAsTable("bucketed_orders"))

    # Join bucketed tables (no shuffle!)
    orders = spark.table("bucketed_orders")
    customers = spark.table("bucketed_customers")  # Same bucket count
    result = orders.join(customers, on="customer_id")

    # 4. Skew Join Handling
    # Enable AQE skew join optimization
    spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
    spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
    spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")

    # Manual salting for severe skew
    def salt_join(df_skewed, df_other, key_col, num_salts=10):
        """Add salt to distribute skewed keys"""
        # Add salt to skewed side
        df_salted = df_skewed.withColumn(
            "salt",
            (F.rand() * num_salts).cast("int")
        ).withColumn(
            "salted_key",
            F.concat(F.col(key_col), F.lit("_"), F.col("salt"))
        )

        # Explode other side with all salts
        df_exploded = df_other.crossJoin(
            spark.range(num_salts).withColumnRenamed("id", "salt")
        ).withColumn(
            "salted_key",
            F.concat(F.col(key_col), F.lit("_"), F.col("salt"))
        )

        # Join on salted key
        return df_salted.join(df_exploded, on="salted_key", how="inner")

    Pattern 3: Caching and Persistence

    from pyspark import StorageLevel

    # Cache when reusing a DataFrame multiple times
    df = spark.read.parquet("s3://bucket/data/")
    df_filtered = df.filter(F.col("status") == "active")

    # Cache in memory (MEMORY_AND_DISK is the default)
    df_filtered.cache()

    # Or with a specific storage level
    df_filtered.persist(StorageLevel.MEMORY_AND_DISK_SER)

    # Force materialization
    df_filtered.count()

    # Use in multiple actions
    agg1 = df_filtered.groupBy("category").count()
    agg2 = df_filtered.groupBy("region").sum("amount")

    # Unpersist when done
    df_filtered.unpersist()

    Storage levels explained:

      • MEMORY_ONLY - Fast, but may not fit
      • MEMORY_AND_DISK - Spills to disk if needed (recommended)
      • MEMORY_ONLY_SER - Serialized, less memory, more CPU
      • DISK_ONLY - When memory is tight
      • OFF_HEAP - Tungsten off-heap memory
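
    Note that OFF_HEAP only takes effect when off-heap memory is enabled on the executors; a minimal sketch of the relevant settings (the 2g size is illustrative):

    from pyspark.sql import SparkSession
    from pyspark import StorageLevel

    spark = (SparkSession.builder
        .config("spark.memory.offHeap.enabled", "true")
        .config("spark.memory.offHeap.size", "2g")  # illustrative size
        .getOrCreate())

    # df_filtered as defined above
    df_filtered.persist(StorageLevel.OFF_HEAP)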

    # Checkpoint for complex lineage
    spark.sparkContext.setCheckpointDir("s3://bucket/checkpoints/")
    df_complex = (df
        .join(other_df, "key")
        .groupBy("category")
        .agg(F.sum("amount")))
    df_complex = df_complex.checkpoint()  # Breaks lineage, materializes; use the returned DataFrame

    Pattern 4: Memory Tuning

    # Executor memory configuration
    spark-submit --executor-memory 8g --executor-cores 4

    Memory breakdown (8GB executor):

    - spark.memory.fraction = 0.6 (60% = 4.8GB for execution + storage)
    - spark.memory.storageFraction = 0.5 (50% of 4.8GB = 2.4GB for cache)
    - Remaining 2.4GB for execution (shuffles, joins, sorts)
    - The other 40% = 3.2GB for user data structures and internal metadata
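
    A quick way to sanity-check these numbers for any executor size (a minimal sketch; the 300MB term is Spark's reserved memory, subtracted before spark.memory.fraction is applied):

    def memory_breakdown(executor_mem_gb: float,
                         memory_fraction: float = 0.6,
                         storage_fraction: float = 0.5) -> dict:
        """Approximate the unified-memory split for a given executor heap size."""
        usable_gb = executor_mem_gb - 0.3        # ~300MB reserved memory
        unified = usable_gb * memory_fraction    # execution + storage pool
        storage = unified * storage_fraction     # cache region (evictable boundary)
        execution = unified - storage            # shuffles, joins, sorts
        user = usable_gb - unified               # user data structures, metadata
        return {"unified_gb": unified, "storage_gb": storage,
                "execution_gb": execution, "user_gb": user}

    print(memory_breakdown(8))  # roughly the 4.8 / 2.4 / 2.4 / 3.2 GB figures above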

    spark = (SparkSession.builder
        .config("spark.executor.memory", "8g")
        .config("spark.executor.memoryOverhead", "2g")  # For non-JVM memory
        .config("spark.memory.fraction", "0.6")
        .config("spark.memory.storageFraction", "0.5")
        .config("spark.sql.shuffle.partitions", "200")
        # For memory-intensive operations
        .config("spark.sql.autoBroadcastJoinThreshold", "50MB")
        # Prevent OOM on large shuffles
        .config("spark.sql.files.maxPartitionBytes", "128MB")
        .getOrCreate())

    # Monitor memory usage
    def print_memory_usage(spark):
        """Print current memory usage per executor"""
        sc = spark.sparkContext
        for executor in sc._jsc.sc().getExecutorMemoryStatus().keySet().toArray():
            mem_status = sc._jsc.sc().getExecutorMemoryStatus().get(executor)
            total = mem_status._1() / (1024 ** 3)
            free = mem_status._2() / (1024 ** 3)
            print(f"{executor}: {total:.2f}GB total, {free:.2f}GB free")

    Pattern 5: Shuffle Optimization

    # Reduce shuffle data size
    spark.conf.set("spark.sql.shuffle.partitions", "auto")  # "auto" on platforms that support it (e.g. Databricks); otherwise set a number and let AQE coalesce
    spark.conf.set("spark.shuffle.compress", "true")
    spark.conf.set("spark.shuffle.spill.compress", "true")

    # Pre-aggregate before shuffle
    df_optimized = (df
        # Local aggregation first (combiner)
        .groupBy("key", "partition_col")
        .agg(F.sum("value").alias("partial_sum"))
        # Then global aggregation
        .groupBy("key")
        .agg(F.sum("partial_sum").alias("total")))

    # Avoid shuffle with map-side operations

    # BAD: Shuffle for each distinct
    distinct_count = df.select("category").distinct().count()

    # GOOD: Approximate distinct (no shuffle)
    approx_count = df.select(F.approx_count_distinct("category")).collect()[0][0]

    # Use coalesce instead of repartition when reducing partitions
    df_reduced = df.coalesce(10)  # No shuffle

    # Optimize shuffle with compression
    spark.conf.set("spark.io.compression.codec", "lz4")  # Fast compression

    Pattern 6: Data Format Optimization

    # Parquet optimizations
    (df.write
        .option("compression", "snappy")  # Fast compression
        .option("parquet.block.size", 128 * 1024 * 1024)  # 128MB row groups
        .parquet("s3://bucket/output/"))

    # Column pruning - only read needed columns
    df = (spark.read.parquet("s3://bucket/data/")
        .select("id", "amount", "date"))  # Spark only reads these columns

    # Predicate pushdown - filter at storage level
    df = (spark.read.parquet("s3://bucket/partitioned/year=2024/")
        .filter(F.col("status") == "active"))  # Pushed to Parquet reader

    # Delta Lake optimizations
    (df.write
        .format("delta")
        .option("optimizeWrite", "true")  # Bin-packing
        .option("autoCompact", "true")    # Compact small files
        .mode("overwrite")
        .save("s3://bucket/delta_table/"))

    # Z-ordering for multi-dimensional queries
    spark.sql("""
        OPTIMIZE delta.`s3://bucket/delta_table/`
        ZORDER BY (customer_id, date)
    """)

    Pattern 7: Monitoring and Debugging

    # Enable detailed metrics
    spark.conf.set("spark.sql.codegen.wholeStage", "true")
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

    # Explain query plan
    df.explain(mode="extended")
    # Modes: simple, extended, codegen, cost, formatted

    # Get physical plan statistics
    df.explain(mode="cost")

    # Monitor task metrics
    def analyze_stage_metrics(spark):
        """Analyze currently active stage metrics"""
        status_tracker = spark.sparkContext.statusTracker()

        for stage_id in status_tracker.getActiveStageIds():
            stage_info = status_tracker.getStageInfo(stage_id)
            print(f"Stage {stage_id}:")
            print(f"  Tasks: {stage_info.numTasks}")
            print(f"  Completed: {stage_info.numCompletedTasks}")
            print(f"  Failed: {stage_info.numFailedTasks}")

    # Identify data skew
    def check_partition_skew(df):
        """Check for partition skew"""
        partition_counts = (df
            .withColumn("partition_id", F.spark_partition_id())
            .groupBy("partition_id")
            .count()
            .orderBy(F.desc("count")))

        partition_counts.show(20)

        stats = partition_counts.select(
            F.min("count").alias("min"),
            F.max("count").alias("max"),
            F.avg("count").alias("avg"),
            F.stddev("count").alias("stddev")
        ).collect()[0]

        skew_ratio = stats["max"] / stats["avg"]
        print(f"Skew ratio: {skew_ratio:.2f}x (>2x indicates skew)")

    Configuration Cheat Sheet

    # Production configuration template
    spark_configs = {
        # Adaptive Query Execution (AQE)
        "spark.sql.adaptive.enabled": "true",
        "spark.sql.adaptive.coalescePartitions.enabled": "true",
        "spark.sql.adaptive.skewJoin.enabled": "true",

        # Memory
        "spark.executor.memory": "8g",
        "spark.executor.memoryOverhead": "2g",
        "spark.memory.fraction": "0.6",
        "spark.memory.storageFraction": "0.5",

        # Parallelism
        "spark.sql.shuffle.partitions": "200",
        "spark.default.parallelism": "200",

        # Serialization
        "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
        "spark.sql.execution.arrow.pyspark.enabled": "true",

        # Compression
        "spark.io.compression.codec": "lz4",
        "spark.shuffle.compress": "true",

        # Broadcast
        "spark.sql.autoBroadcastJoinThreshold": "50MB",

        # File handling
        "spark.sql.files.maxPartitionBytes": "128MB",
        "spark.sql.files.openCostInBytes": "4MB",
    }
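
    One way to apply such a template when constructing the session (a minimal sketch; only the spark_configs dict above is assumed):

    from pyspark.sql import SparkSession

    builder = SparkSession.builder.appName("ProductionJob")
    for key, value in spark_configs.items():
        builder = builder.config(key, value)
    spark = builder.getOrCreate()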

    Best Practices

    Do's


  • Enable AQE - Adaptive query execution handles many issues

  • Use Parquet/Delta - Columnar formats with compression

  • Broadcast small tables - Avoid shuffle for small joins

  • Monitor Spark UI - Check for skew, spills, GC

  • Right-size partitions - 128MB - 256MB per partition

    Don'ts


  • Don't collect large data - Keep data distributed

  • Don't use UDFs unnecessarily - Use built-in functions

  • Don't over-cache - Memory is limited

  • Don't ignore data skew - It dominates job time

  • Don't use .count() for existence - Use .take(1) or .isEmpty() (see the sketch below)
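
    A minimal sketch of the existence check (df and handle_failures() are hypothetical; DataFrame.isEmpty() is available from Spark 3.3, take(1) on any version):

    from pyspark.sql import functions as F

    # BAD: counts every row just to test existence
    if df.filter(F.col("status") == "failed").count() > 0:
        handle_failures()  # hypothetical handler

    # GOOD: stops after finding a single matching row
    if len(df.filter(F.col("status") == "failed").take(1)) > 0:
        handle_failures()

    # Also fine on Spark 3.3+: DataFrame.isEmpty()
    if not df.filter(F.col("status") == "failed").isEmpty():
        handle_failures()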

    Resources

  • Spark Performance Tuning

  • Spark Configuration

  • Databricks Optimization Guide
