二八定律解决 PySpark
分布式思维
它和普通 Pandas/Python 脚本的本质区别:
Lazy Evaluation (延迟计算):你写的所有
select,filter,join(称为 Transformations)都不会立刻执行,只是在构建一个执行计划图(DAG)。只有当你调用show(),count(),collect(),write(称为 Actions)时,Spark 才会把任务下发到集群真正开始算。Partition (分区):数据不是存在一台机器上的,而是被切分成多个 Partition 分布在多台机器(Executors)上并行处理。
Shuffle (洗牌):当你执行
groupBy、join或Window操作时,相同 Key 的数据必须被拉取到同一个节点上,这个网络传输过程叫 Shuffle,它是所有 Spark 性能问题的万恶之源。
20% API (高频操作)
在代码中养成统一的导入习惯:
1
2
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
A. 读写与输入输出 (I/O)
读:
spark.read.parquet("path"),spark.sql("SELECT * FROM hive_table")写:
df.write.mode("overwrite").partitionBy("date").parquet("path")- 注:数仓中通常按日期
partitionBy__,这是极度高频的操作。
- 注:数仓中通常按日期
B. SQL 与 DataFrame 无缝切换 (降维打击)
如果你已经精通 SQL,不需要强行把所有逻辑翻译成 DataFrame API。你可以直接把 DataFrame 注册成虚拟表,用 SQL 写最复杂的逻辑:
1
2
df.createOrReplaceTempView("my_temp_table")
result_df = spark.sql("SELECT a, SUM(b) FROM my_temp_table GROUP BY a")
C. 列操作与条件分支 (pyspark.sql.functions)
F 模块是你最常用的工具箱:
选列与重命名:
df.select(F.col("a").alias("new_a"), "b")条件赋值 (If-Else):
df.withColumn("status", F.when(F.col("age") > 18, "adult").otherwise("minor"))类型转换:
F.col("str_num").cast("int")常量列:
F.lit(1)
D. 聚合与关联 (Join & Agg)
聚合:
df.groupBy("store_id").agg(F.sum("sales").alias("total_sales"), F.countDistinct("user_id"))Join:
df_a.join(df_b, on="user_id", how="left")Broadcast Join (必须掌握的优化):当一张表很小(比如几百MB以内),另一张表很大时,务必用
F.broadcast(小表)。这能彻底消除大表的 Shuffle 过程,性能提升十几倍。df_large.join(F.broadcast(df_small), on=”id”)
E. 窗口函数 (Window Functions)
ETL 中处理“过去7天滚动平均”、“同组排序取TopN”的唯一利器。
1
2
w = Window.partitionBy("user_id").orderBy("date").rowsBetween(-6, 0) *# 过去7天*
df.withColumn("7d_avg", F.avg("sales").over(w))
3. 高级特性:Pandas UDF
由于你是做 AI/Algo Infra 相关的(从你刚打开的 fit_daily_volatility_spark.py 可以看出),普通的 Spark UDF 性能极差(JVM 和 Python 进程之间存在巨大的序列化开销)。 永远优先使用基于 Apache Arrow 的 Pandas UDF 或 applyInPandas。
场景:需要按组(比如每个门店)进行复杂的算法拟合或 Pandas 操作。
用法(正是你代码中用到的模式):
1 2 3 4 5
def fit_model(pdf: pd.DataFrame) -> pd.DataFrame: *# 这里的 pdf 是一个组内全部数据的原生 Pandas DataFrame* res = my_algo(pdf['features']) return pd.DataFrame({"result": [res]}) df.groupBy("store_id").applyInPandas(fit_model, schema="result string")
4. 生产环境避坑指南 (Troubleshooting)
学会在集群上活下来比学 API 更重要:
Driver OOM (内存溢出):绝大多数是因为你对着几亿条数据调用了
df.collect()或df.toPandas()。记住,collect会把所有机器上的数据拉回一台机器的内存里。在拉取前先df.limit(1000)。数据倾斜 (Data Skew):一个任务跑到了 99%,卡了半小时不动。大概率是
join或groupBy时,某个 Key(比如null或者是极其热门的store_id)的数据量太大,全被分到了同一个节点。- 解法:如果是 Join 导致的,尝试 Broadcast 小表;如果是过滤不严,先
filter(F.col("key").isNotNull());或者给 Key 加随机后缀(打散/加盐)。
- 解法:如果是 Join 导致的,尝试 Broadcast 小表;如果是过滤不严,先
看懂 Spark UI:当你提交任务后,务必打开 Spark UI。看
Stages页面,关注哪个 Stage 耗时最长,里面是否有少数几个 Tasks 耗时远超其他 Tasks(数据倾斜的铁证)。