Speed Up Spark Dataset Writes Using Python's ThreadPoolExecutor

TL;DR

When dealing with large-scale ML outputs, writing results back to storage often becomes the bottleneck. Spark handles distributed computation efficiently, but once you’ve materialized results into a single DataFrame, persisting it back to Parquet can still take a significant amount of time. This blog shows how to split a large result DataFrame into chunks and write them concurrently from Python threads to get past the IO bottleneck.

In the previous blog article, we discussed how to split up CPU-bound computation to speed up ML pipelines. Now we face a new issue: getting the results back into storage quickly.

The Solution

The trick is to chunk the DataFrame and use Python’s ThreadPoolExecutor to parallelize the writes across multiple threads. Each chunk of rows is converted back into a Spark DataFrame with the same schema, and each thread then writes its chunk to Parquet in append mode.

This way, instead of waiting for a single sequential write, you saturate available IO bandwidth by performing multiple writes concurrently.
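
Stripped to its essentials, the pattern looks like the sketch below. This is a minimal outline, not the real pipeline: `records` stands in for the Pandas result, `output_path` and `CHUNK_SIZE` are placeholders, and an existing SparkSession is assumed to be available as `spark`. The full version, with the real schema and paths, follows in the walkthrough.

import concurrent.futures

CHUNK_SIZE = 100_000
output_path = '/tmp/results.parquet'  # placeholder output location

def write_chunk(pdf_slice, mode='append'):
    # Convert one slice of the Pandas result back into a Spark DataFrame
    # and write it; in practice you would pass an explicit schema here,
    # as the walkthrough below does.
    spark.createDataFrame(pdf_slice).write.parquet(output_path, mode=mode)

# Slice the result into fixed-size chunks
slices = [records.iloc[i:i + CHUNK_SIZE] for i in range(0, len(records), CHUNK_SIZE)]

# First chunk creates a clean base dataset, the rest append in parallel
write_chunk(slices[0], mode='overwrite')
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
    list(executor.map(write_chunk, slices[1:]))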

Code Walkthrough

Here’s the high-level pattern:

  1. Define a schema to ensure consistency across chunks.
  2. Split the Pandas DataFrame (Y) into chunks.
  3. Convert each chunk back into a Spark DataFrame.
  4. Use ThreadPoolExecutor to write each chunk in parallel.

from pyspark.sql.types import StructType, StructField, StringType, FloatType
import concurrent.futures
from datetime import datetime

# Schema defined ahead of time (abbreviated here)
schema = StructType([
    StructField('ID', StringType(), True),
    StructField('PROBABILITY', FloatType(), True),
    # ... many more fields
])

def write_to_parquet(args):
    # One blocking Spark write job per call; append mode adds this
    # chunk's files to the same Parquet output directory.
    i, chunk = args
    start_time = datetime.now()
    print('Chunk', i, 'started')
    chunk.write.parquet(
        f'/.../prediction_result_data_{prediction_month}.parquet',
        mode='append'
    )
    print('Chunk', i, 'ended', datetime.now() - start_time)

# Split into chunks of 100k rows, converting each back to a Spark DataFrame up front
chunks = [
    (i, spark.createDataFrame(
        Y.iloc[i:i+100000, :].astype({
            'TOTAL_INSTALLMENTS': 'string',
            # ... more type casts
        }), schema=schema
    ))
    for i in range(0, len(Y), 100000)
]

print("Writing", len(chunks), "chunks...")

# First chunk is written with overwrite to (re)create the base Parquet directory
chunks[0][1].write.parquet(
    f'/.../prediction_result_data_{prediction_month}.parquet',
    mode='overwrite'
)

# Rest are appended in parallel
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
    results = list(executor.map(write_to_parquet, chunks[1:]))

Why It Works

Writing Parquet here is IO-bound rather than CPU-bound. Each write.parquet call blocks while the JVM does the actual work, and during that wait the Python thread releases the GIL, so several threads can keep several Spark write jobs in flight at once. Spark runs jobs submitted from different threads concurrently, so the storage layer receives multiple streams of data instead of sitting idle between sequential writes.
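
A toy illustration of the effect, with time.sleep standing in for a blocking, IO-bound write (nothing Spark-specific is needed to see it):

import time
import concurrent.futures

def fake_write(i):
    # Stand-in for an IO-bound write: sleep releases the GIL, just as a
    # blocking call into the JVM does.
    time.sleep(1)
    return i

start = time.time()
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
    list(executor.map(fake_write, range(8)))
print('8 "writes" took', round(time.time() - start, 1), 's')  # ~1 s, not ~8 s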

Key Considerations

  1. Every chunk must match the declared schema exactly (hence the explicit schema and the astype casts), otherwise the appended files won’t line up when the dataset is read back.
  2. The overwrite of the first chunk has to finish before the appends start; it is what clears out any output left over from a previous run.
  3. executor.map only raises a chunk’s exception when its results are consumed, so a failed write can silently leave rows missing; check the results or re-count the output afterwards (see the sketch below).
  4. More workers is not automatically faster: treat max_workers=8 as a starting point and tune it against your cluster and storage throughput.
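
If you want failures and the final row count to be visible, a sanity-check pass along these lines can help. This is a hedged sketch, not part of the original pipeline: it replaces the plain executor.map call above, reuses chunks, write_to_parquet, prediction_month, and Y from the walkthrough, and the checks themselves are assumptions about what you might want to verify.

import concurrent.futures

with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
    # submit/as_completed lets each chunk's exception surface individually
    futures = {executor.submit(write_to_parquet, args): args[0] for args in chunks[1:]}
    for future in concurrent.futures.as_completed(futures):
        if future.exception() is not None:
            print('Chunk', futures[future], 'failed:', future.exception())

# Read the dataset back and confirm no chunk went missing
written = spark.read.parquet(
    f'/.../prediction_result_data_{prediction_month}.parquet'
)
assert written.count() == len(Y), 'row count mismatch after parallel write'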

Takeaway

When working with massive ML pipelines, computation isn’t the only bottleneck: writing results back to disk can be just as painful. Splitting the output into chunks and writing them from parallel threads is a simple yet powerful way to scale the IO-bound parts of your pipeline.