PySpark (+Jupyter Notebook) でDataFrameを扱う

前回の投稿では、PySparkをJupyter Notebookから操作するための環境を作りました。

ohke.hateblo.jp

今回は上の環境を使って、PySparkでDataFrameを扱う方法についてまとめます。
(そのため上の環境構築が済んでいる前提となります。)

SparkのDataFrame

Sparkで、分散させるデータを扱うためのAPIが3種類あります。

  1. RDD
  2. DataFrame
  3. DataSet

このうち、PySparkでできるのはRDD or DataFrameとなります (DataSet APIは型安全性を確保するのがメインテーマであるため、Scala/Javaのみサポートされてます) 。

ではどちらを使うべきかですが、DataFrame APIを使うのが良いそうです。
というのも、DataFrame APIはCatalyst Optimizerがラップされており、これがクエリの実行計画を最適化するため、Scala/Javaと比較しても遜色ない程度のパフォーマンスが得られるためです (RDDではpy4jによって発生するPythonとJVMの間のコンテキストスイッチに対して最適化されていないので、Scala/Javaと比較して格段に遅くなりやすかったのです) 。

DataFrameの作成

最初にpysparkをインストールしておきます。

$ pip install pyspark

今回はUCIから提供されている肺がんデータセットをダウンロードして使います。

http://mlr.cs.umass.edu/ml/datasets/Breast+Cancer+Wisconsin+%28Original%29

ダウンロードしたCSVファイルをDataFrameを作りますが、その前にSparkSessionでセッションを作成します。
インタラクティブシェルであれば起動時に自動的にセッションが作成あれ、sparkオブジェクトを介してアクセスできます。インタラクティブシェルを使わない場合 (Jupyter Notebookなど) は、以下のように自前でセッションを確立します。

from pyspark.sql import SparkSession
from pyspark.sql.types import *
import urllib.request

# 肺がんデータセットをダウンロード
urllib.request.urlretrieve('http://mlr.cs.umass.edu/ml/machine-learning-databases/breast-cancer-wisconsin/breast-cancer-wisconsin.data', 'breast-cancer-wisconsin.data')

# セッションの作成
spark_session = SparkSession.builder.getOrCreate()
# セッションを切るときは spark_session.stop()

上で確立したSparkSessionを使って、DataFrameを作成します。

  • スキーマはStructTypeで定義します
  • CSVファイルから作成しますので、DataFrameReaderクラスのcsvメソッドにファイルパス (とスキーマ定義) を渡してます
# CSVファイルのスキーマを定義
data_schema = StructType([
    StructField("id", StringType(), False),
    StructField('clump_thickness', LongType(), True),
    StructField('uniformity_of_cell_size', LongType(), True),
    StructField('uniformity_of_cell_shape', LongType(), True),
    StructField('marginal_adhesion', LongType(), True),
    StructField('single_epithelial_cell_size', LongType(), True),
    StructField('bare_nuclei', LongType(), True),
    StructField('bland_chromatin', LongType(), True),
    StructField('normal_nucleoli', LongType(), True),
    StructField('mitoses', LongType(), True),
    StructField('classification', LongType(), True),
])

# DataFrameの作成
data = spark_session.read.csv('breast-cancer-wisconsin.data', schema=data_schema)
data.show(1)

#+-------+---------------+-----------------------+------------------------+-----------------+---------------------------+-----------+---------------+---------------+-------+--------------+
|     #id|clump_thickness|uniformity_of_cell_size|uniformity_of_cell_shape|marginal_adhesion|single_epithelial_cell_size|bare_nuclei|bland_chromatin|normal_nucleoli|mitoses|classification|
#+-------+---------------+-----------------------+------------------------+-----------------+---------------------------+-----------+---------------+---------------+-------+--------------+
#|1000025|              5|                      1|                       1|                1|                          2|          1|              3|              1|      1|             2|
#+-------+---------------+-----------------------+------------------------+-----------------+---------------------------+-----------+---------------+---------------+-------+--------------+
#only showing top 1 row

DataFrameへのクエリ

DataFrameへのクエリは、DataFrameオブジェクトのAPIを使う方法と、SQLで記述する方法の2種類があります。

DataFrameオブジェクトのAPI

作成したDataFrameオブジェクトのAPIを使ってクエリする例は以下です。

メソッドチェーンで選択、射影、結合を記述します。それぞれfilterselectjoinのメソッドで行われます。
アクション (showcounttakecollectなど) を実行するまでは遅延評価されます。C#のLINQなどにイメージは近いです。

# id列だけ5件を表示する
data.select('id').show(5)
# +-------+
# |     id|
# +-------+
# |1000025|
# |1002945|
# |1015425|
# |1016277|
# |1017023|
# +-------+

# classificationが4の行数を返す
data.filter("classification == 4").count() # 241

# idの末尾が0で、かつ、classificationが4の行を3行取得
data.select('id', 'classification').filter("id like '%0' and classification == 4").take(3)
# [Row(id='1047630', classification=4),
#  Row(id='1050670', classification=4),
#  Row(id='1054590', classification=4)]

# classificationが2ならば"benign", 4ならば"malignant"とするJSONファイルを作成
with open('label.json', mode='w') as f:
    f.write("""
    [
        { "classification": 2, "label": "benign" },
        { "classification": 4, "label": "malignant" }
    ]
    """)

# JSONファイルからDataFrameを作成
label = spark_session.read.json('label.json', multiLine=True)

# dataとjoinする
data.join(label, data.classification == label.classification).select('id', 'label').take(3)
# [Row(id='1000025', label='benign'),
#  Row(id='1002945', label='benign'),
#  Row(id='1015425', label='benign')]

SQLを使ったクエリ

セッションを使って、SQLで記述することもできます。

SQLの場合、DataFrameから一時テーブルを作成する必要があります。ここでは、createOrReplaceTempViewメソッドで一時テーブル化してます。
SQLでは、作成時に渡された名前でテーブルを参照します。

# 一時テーブルを作成
data.createOrReplaceTempView('cancer')
label.createOrReplaceTempView('label')

spark_session.sql("""
select id, label
from cancer 
    inner join label
        on cancer.classification = label.classification
where id like '%0'
""").take(3)
# [Row(id='1047630', label='malignant'),
#  Row(id='1050670', label='malignant'),
#  Row(id='1054590', label='malignant'),
#  Row(id='1071760', label='benign'),
#  Row(id='1074610', label='benign')]

まとめ

今回はPySparkでDataFrameを扱う方法について整理しました。最初にセッションを確立し、DataFrameのAPI or SQLで操作できることを確認しました。

入門 PySpark ―PythonとJupyterで活用するSpark 2エコシステム

入門 PySpark ―PythonとJupyterで活用するSpark 2エコシステム