panderaでDataFrameをバリデーションする

pandasのDataFrameは柔軟なテーブル構造を提供してくれますが、時に柔軟すぎて困ることもしばしばです。本番運用するアプリケーションですとなおさらこの欠点が目立ちます。

  • 入力データに依存して意図しない型に変わってしまったり...
    • ex. [1, 2, 3]だとint64、[1, None, 3]だとfloat64で解釈される
  • そもそも意図しないフォーマットの入力データが問題なく入ってしまったり...

DataFrameのバリデーションを行うライブラリはいくつかありますが、今回は pandera を紹介します。仮説検定を行う機能 (Hypothesis) も備えていますが、用途が限定的ですので、単純な値バリデーション (Check) のみに説明を絞りたいと思います。

% python --version
Python 3.8.5

% pip list | grep pandera
pandera           0.4.4

アクセスログを集計して得られたセッションログ (session_df) を例とし、これに対してバリデーションしてみます。

import pandas as pd

# セッションID (id) がインデックス
session_df = pd.DataFrame(
    {
        # ログインしている場合は3桁, していない場合はNone
        "login_id": ["U10", "I22", None, "U05"],
        # PC, SD, APのいずれか
        "device": ["SD", "PC", "SD", "AP"],
        # 9/11ランディングのみ
        "landing_time": pd.to_datetime(
            [
                "2020-09-11T00:00:00",
                "2020-09-11T00:00:12",
                "2020-09-11T00:01:07",
                "2020-09-11T00:01:30",
            ]
        ),
        # パスは"/"開始
        "landing_path": ["/", "/page/hoge", "/", "/pages"],
        # 滞在時間は0秒以上
        "duration_secs": [121, 63, 0, 90],
    },
    index=pd.Index([1001, 1002, 1003, 1004], name="id"),
)

print(session_df)
#      login_id device        landing_time landing_path  duration_secs
# id
# 1001      U10     SD 2020-09-11 00:00:00            /            121
# 1002      I22     PC 2020-09-11 00:00:12   /page/hoge             63
# 1003     None     SD 2020-09-11 00:01:07            /              0
# 1004      U05     AP 2020-09-11 00:01:30       /pages             90

panderaでのバリデーションは2ステップです。

  • ステップ1 DataFrameSchemaでインデックスやカラムごとのルールを定義
    • checksに1つ以上 (複数の場合はlist) のルール (= _CheckBaseのサブクラス) をセット
    • pandera.Check以下にstr_lengthやisin、rangeなど様々なルールが定義されています
    • スキーマはyamlでも定義できます
  • ステップ2 validateメソッドでバリデーション
    • 成功すると、入力したDataFrameが返されます
    • 失敗すると、SchemaErrorがraiseされます
import pandera as pa

# ステップ1
session_df_schema = pa.DataFrameSchema(
    index=pa.Index(pa.Int, name="id", allow_duplicates=False),
    columns={
        "login_id": pa.Column(
            pa.String,
            nullable=True,
            checks=pa.Check.str_length(min_value=3, max_value=3),
        ),
        "device": pa.Column(pa.String, checks=pa.Check.isin(["PC", "SD", "AP"])),
        "landing_time": pa.Column(
            pa.DateTime,
            checks=pa.Check.in_range(
                min_value=pd.to_datetime("2020-09-11T00:00:00"),
                max_value=pd.to_datetime("2020-09-12T00:00:00"),
                include_max=False,
            ),
        ),
        "landing_path": pa.Column(pa.String, checks=pa.Check.str_startswith("/")),
        "duration_secs": pa.Column(
            pa.Int, checks=pa.Check.greater_than_or_equal_to(0)
        ),
    },
)

# ステップ2
session_df_schema.validate(session_df)

誤った値 ("U100"で4桁) が入っているとvalidateでSchemaErrorが投げられますが、エラーメッセージに誤ったインデックス・カラム・ルールを含みますので、その後の調査や修正もしやすいです。

session_df = pd.DataFrame(
    {
        "login_id": ["U100", "I22", None, "U05"],
    # ...

session_df_schema.validate(session_df)
# ...
# pandera.errors.SchemaError: <Schema Column: 'login_id' type=string> failed element-wise validator 0:
# <Check _str_length: str_length(3, 3)>
# failure cases:
#    index failure_case
# 0   1001         U100

CheckにSeriesを受け取る関数を渡すことで、ルールを自分で定義することもできます。landing_pathのチェックを自前で実装している例を示します。

session_df_schema = pa.DataFrameSchema(
        ...,
        # "landing_path": pa.Column(pa.String, checks=pa.Check.str_startswith("/")),
        "landing_path": pa.Column(
            pa.String, checks=pa.Check(lambda s: s.str.startswith("/"))
        ),
        ...
    },
)

MySQL: カラムをちょっと変えてテーブルをコピーしたい

テーブルのレコードをコピーしたい、だけどちょっとカラムの定義や値を変えたいケースがあります。

例として、以下の要件を満たしつつ、postsからcopied_postsへコピーしたいとします。

  • コピーした日時 (copied_at) を追加したい
  • emailをNULLにしたい
mysql> create table copied_posts (
    ->     id int unsigned not null,
    ->     title varchar(256) not null,
    ->     body text,
    ->     email varchar(256),
    ->     created_at datetime not null,
    ->     copied_at datetime not null,
    ->     primary key (id)
    -> );

mysql> insert into posts (title, body, email, created_at) values ('title1', 'body1', 'hoge@fuga.com', now());
mysql> insert into posts (title, body, email, created_at) values ('title2', 'body2', 'hoge@fuga.com', now());
mysql> insert into posts (title, body, email, created_at) values ('title3', 'body3', 'hoge@fuga.com', now());

mysql> select * from posts;
+----+--------+-------+---------------+---------------------+
| id | title  | body  | email         | created_at          |
+----+--------+-------+---------------+---------------------+
|  1 | title1 | body1 | hoge@fuga.com | 2020-09-05 11:36:07 |
|  2 | title2 | body2 | hoge@fuga.com | 2020-09-05 11:36:33 |
|  3 | title3 | body3 | hoge@fuga.com | 2020-09-05 11:36:37 |
+----+--------+-------+---------------+---------------------+

mysql> create table copied_posts (
    ->     id int unsigned not null,
    ->     title varchar(256) not null,
    ->     body text,
    ->     email varchar(256),
    ->     created_at datetime not null,
    ->     copied_at datetime not null,
    ->     primary key (id)
    -> );

postsとcopied_postsのカラムの定義が異なるので、直接insertしようとすると当然エラーになります。

mysql> insert into copied_posts select * from posts;
ERROR 1136 (21S01): Column count doesn't match value count at row 1

コピーするだけなのにアプリケーションを組むのもハイコストですので、できればSQLで完結させたいところです。

最初にtemporary table (ここではtmp_posts) を作って、そのテーブルに対してalterやupdateを実行し、最後に目的テーブルにinsertすると楽にできます。

  • temporary tableはセッションが終了したら自動的にdropされます (参考)
mysql> create temporary table tmp_posts select * from posts;
mysql> alter table tmp_posts add column copied_at datetime;
mysql> update tmp_posts set copied_at = now();
mysql> update tmp_posts set email = NULL;
mysql> insert into copied_posts select * from tmp_posts;

Kubeflow Pipelines SDKを用いた並列処理の実装

最近はお仕事でKubeflow Pipelinesを触り始めています。
PythonでDAGを定義し、SDK (KFP) を使ってArgo Workflowのマニフェストを出力して、それをKubeflowにアップロードしてパイプラインを作る、という流れで開発しています。

今回は、タスクの一部を並列化する方法について備忘録としてまとめます。PythonとSDK (KFP) は以下のバージョンを用いています。

$ python --version
Python 3.8.5

$ pip list | grep kfp
kfp                      0.5.1

サンプルパイプライン

サンプルとして、1〜nの数字をm個ずつに分割してそれぞれで合計を求めるパイプラインを考えます。ユーザから与えられるパラメータはnとmです。

上を実現するために preprocessで分割 -> processで合計 という簡単な2段階のパイプラインを実装したものが、以下コードです (ここではmain.py) 。次に$ python main.pyで実行することでマニフェストを生成してKubeflowにアップロードします。

import kfp
from kfp import compiler, dsl, components

def _preprocess(n: int, m: int) -> list:
    num_lists = []
    for num in range(0, n, m):
        num_lists.append([i for i in range(num + 1, num + m + 1) if i <= n])
    return num_lists

def _process(num_lists: list) -> list:
    return [sum(nums) for nums in num_lists]

@kfp.dsl.pipeline(name="serial-pipeline")
def pipeline(n: int, m: int) -> None:
    image = "python:3.8-alpine"

    preprocess_func = components.func_to_container_op(_preprocess, base_image=image)
    preprocess_op = preprocess_func(n, m)

    process_func = components.func_to_container_op(_process, base_image=image)
    process_op = process_func(preprocess_op.output)

    process_op.after(preprocess_op)

if __name__ == "__main__":
    compiler.Compiler().compile(pipeline, "pipeline.yaml")

マニフェストをアップロードすると、以下のパイプラインが定義されます。このパイプラインを n=100, m=30 で実行すると [465, 1365, 2265, 955] がprocessから出力されます。

processの並列化

2段目のprocessを並列化していきます。

main.pyの実装を以下の用に変更します。キモはParallelForで、これを使うとprprocessの出力がprocessの入力にwithParamで渡されるマニフェストへ変換されます (https://argoproj.github.io/argo/examples/#loopsも参照) 。これによって、今回のように入力パラメータによって並列数を変えることができます。

  • _preprocessではJSON文字列にする
    • withParamで解釈できるフォーマットにしないといけないため
  • コンテキスト (with) の範囲ならprocessからさらに他のオペレータを連結させることもできます
import kfp
from kfp import compiler, dsl, components

def _preprocess(n: int, m: int) -> str:
    import json

    num_lists = []
    for num in range(0, n, m):
        num_lists.append([str(i) for i in range(num + 1, num + m + 1) if i <= n])
    return json.dumps([",".join(nums) for nums in num_lists])

def _process(nums: str) -> int:
    return sum([int(num) for num in nums.split(",")])

@kfp.dsl.pipeline(name="parallel-pipeline")
def pipeline(n: int, m: int) -> None:
    image = "python:3.8-alpine"

    preprocess_func = components.func_to_container_op(_preprocess, base_image=image)
    preprocess_op = preprocess_func(n, m)

    with dsl.ParallelFor(preprocess_op.output) as nums:
        process_func = components.func_to_container_op(_process, base_image=image)
        process_op = process_func(nums)
        process_op.after(preprocess_op)

    return process_op.output

if __name__ == "__main__":
    compiler.Compiler().compile(pipeline, "pipeline.yaml")

Kubeflowにアップロードすると、loopが入って3段のように見えます (左図) 。同じく n=100, m=30 で実行すると、processが4並列で実行されることが確認できます (右図) 。

並列後の集約

ここまでできたらあとは結果を集約 (fan-in) するタスクを組みたくなりますが、残念ながら2020/8時点のKFP 0.5.1では簡単にはできないようです。対応を待ちましょう。