Post

二八定律解决 PySpark

二八定律解决 PySpark

分布式思维

它和普通 Pandas/Python 脚本的本质区别:

  • Lazy Evaluation (延迟计算):你写的所有 selectfilterjoin(称为 Transformations)都不会立刻执行,只是在构建一个执行计划图(DAG)。只有当你调用 show()count()collect()write(称为 Actions)时,Spark 才会把任务下发到集群真正开始算。

  • Partition (分区):数据不是存在一台机器上的,而是被切分成多个 Partition 分布在多台机器(Executors)上并行处理。

  • Shuffle (洗牌):当你执行 groupByjoin 或 Window 操作时,相同 Key 的数据必须被拉取到同一个节点上,这个网络传输过程叫 Shuffle,它是所有 Spark 性能问题的万恶之源。

20% API (高频操作)

在代码中养成统一的导入习惯:

1
2
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F

A. 读写与输入输出 (I/O)

  • spark.read.parquet("path")spark.sql("SELECT * FROM hive_table")

  • df.write.mode("overwrite").partitionBy("date").parquet("path")

    • 注:数仓中通常按日期 partitionBy__,这是极度高频的操作。

B. SQL 与 DataFrame 无缝切换 (降维打击)

如果你已经精通 SQL,不需要强行把所有逻辑翻译成 DataFrame API。你可以直接把 DataFrame 注册成虚拟表,用 SQL 写最复杂的逻辑:

1
2
df.createOrReplaceTempView("my_temp_table")
result_df = spark.sql("SELECT a, SUM(b) FROM my_temp_table GROUP BY a")

C. 列操作与条件分支 (pyspark.sql.functions)

F 模块是你最常用的工具箱:

  • 选列与重命名df.select(F.col("a").alias("new_a"), "b")

  • 条件赋值 (If-Else)df.withColumn("status", F.when(F.col("age") > 18, "adult").otherwise("minor"))

  • 类型转换F.col("str_num").cast("int")

  • 常量列F.lit(1)

D. 聚合与关联 (Join & Agg)

  • 聚合df.groupBy("store_id").agg(F.sum("sales").alias("total_sales"), F.countDistinct("user_id"))

  • Joindf_a.join(df_b, on="user_id", how="left")

  • Broadcast Join (必须掌握的优化):当一张表很小(比如几百MB以内),另一张表很大时,务必用 F.broadcast(小表)。这能彻底消除大表的 Shuffle 过程,性能提升十几倍。

    df_large.join(F.broadcast(df_small), on=”id”)

E. 窗口函数 (Window Functions)

ETL 中处理“过去7天滚动平均”、“同组排序取TopN”的唯一利器。

1
2
w = Window.partitionBy("user_id").orderBy("date").rowsBetween(-6, 0) *# 过去7天*
df.withColumn("7d_avg", F.avg("sales").over(w))

3. 高级特性:Pandas UDF

由于你是做 AI/Algo Infra 相关的(从你刚打开的 fit_daily_volatility_spark.py 可以看出),普通的 Spark UDF 性能极差(JVM 和 Python 进程之间存在巨大的序列化开销)。 永远优先使用基于 Apache Arrow 的 Pandas UDF 或 applyInPandas

  • 场景:需要按组(比如每个门店)进行复杂的算法拟合或 Pandas 操作。

  • 用法(正是你代码中用到的模式):

    1
    2
    3
    4
    5
    
    def fit_model(pdf: pd.DataFrame) -> pd.DataFrame:
    *# 这里的 pdf 是一个组内全部数据的原生 Pandas DataFrame*
    res = my_algo(pdf['features'])
    return pd.DataFrame({"result": [res]})
    df.groupBy("store_id").applyInPandas(fit_model, schema="result string")
    

4. 生产环境避坑指南 (Troubleshooting)

学会在集群上活下来比学 API 更重要:

  • Driver OOM (内存溢出):绝大多数是因为你对着几亿条数据调用了 df.collect() 或 df.toPandas()。记住,collect 会把所有机器上的数据拉回一台机器的内存里。在拉取前先 df.limit(1000)

  • 数据倾斜 (Data Skew):一个任务跑到了 99%,卡了半小时不动。大概率是 join 或 groupBy 时,某个 Key(比如 null 或者是极其热门的 store_id)的数据量太大,全被分到了同一个节点。

    • 解法:如果是 Join 导致的,尝试 Broadcast 小表;如果是过滤不严,先 filter(F.col("key").isNotNull());或者给 Key 加随机后缀(打散/加盐)。
  • 看懂 Spark UI:当你提交任务后,务必打开 Spark UI。看 Stages 页面,关注哪个 Stage 耗时最长,里面是否有少数几个 Tasks 耗时远超其他 Tasks(数据倾斜的铁证)。

This post is licensed under CC BY 4.0 by the author.