SageMaker用のコードをローカルで動かす – scikit-learnの決定木でアヤメの種類を分類

SageMakerはローカルで使うことができるので、それを試してみた。
この記事を書くにあたって以下の公式の記事を参考にしています。
オンプレミス環境から Amazon SageMaker を利用する

機械学習のHelloWorld

アヤメデータをschikit-learnの決定木分類器で学習して種類を予測する.
結構いろいろなところで機械学習のHelloWorldとして使われている例題を題材にしていく.
sagemaker-python-sdk/scikit_learn_iris

既に SageMaker用のサンプルコードがあるので、
これをローカルで学習・推論できるように修正していく。

構成は以下の通り.
公式のブログの通り、SageMaker Notebook用に書かれた.ipynb を Local用に微修正するだけで動く。

  • SageMaker Notebookで動かすための.ipynb
  • .ipynbから呼ぶschikit-learnコード

SageMakerのサンプルはSageMakerのJupyterNotebookで動くように書かれているがが、
ちょっと修正するだけでローカルで動くようになる様子。(1つしか試してないけど)

前準備

学習と推論をローカルで行うが、そのために裏でDockerのコンテナが走る。
ローカルコンピュータ用にDockerをインストールしておく必要がある。

CredentialsとIAM

以下が必要。

  • AmazonSageMakerFullAccess 権限をもった IAM ユーザの Credential
  • AmazonSageMakerFullAccess の IAM ロール

ローカルコードからAWSリソースにアクセスするために aws configure を使って設定する。
Credentialsが書かれたcsvをダウンロードし aws configure の応答に答えていく。


$ pip install awscli --upgrade --user
$ aws configure
AWS Access Key ID [None]: ******************
AWS Secret Access Key [None]: ************************************
Default region name [None]: ap-northeast-1
Default output format [None]: json

SageMaker PythonSDKインストール

SageMaker PythonSDKをインストールする。
実行するコードに応じてSDKのバージョンを指定することができる。


$ pip install -U sagemaker >=2.15

バージョンを指定しない場合は以下の通り。


$ pip install sagemaker

ローカルのJupyter Notebookでファイルを修正

scikit_learn_estimator_example_with_batch_transform.ipynb を
ローカルのJupyter Notebookで修正していく。

SageMaker ローカルSessionを開始

SageMakerを想定したコードは以下。


# S3 prefix
prefix = "Scikit-iris"

import sagemaker
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()

# Get a SageMaker-compatible role used by this Notebook Instance.
role = get_execution_role()

それをローカルで動かすために以下のように修正する
localSession()というセッションが用意されているのでそれを使用する。
ローカルでは get_execution_role()では ロールを取得できないので直接ロールのARNを指定する。


# S3 prefix
prefix = "Scikit-iris"

import sagemaker
from sagemaker import get_execution_role

# LocalSession()を使用する
sagemaker_session = sagemaker.local.LocalSession() # sagemaker.Session()から変更

# Get a SageMaker-compatible role used by this Notebook Instance.
# ローカルでは get_execution_role()は使えない。直接ロールのARNを指定する。
# role = get_execution_role()
role = 'arn:aws:iam::(12桁のAWSアカウントID):role/(ロール名)'

学習用データの準備 (変更なし)

学習用データが巨大であればS3にデータを準備する(と書かれている).
アヤメデータは軽量なので、ローカルファイルに保存する。


import numpy as np
import os
from sklearn import datasets

# Load Iris dataset, then join labels and features
iris = datasets.load_iris()
joined_iris = np.insert(iris.data, 0, iris.target, axis=1)

# Create directory and write csv
os.makedirs("./data", exist_ok=True)
np.savetxt("./data/iris.csv", joined_iris, delimiter=",", fmt="%1.1f, %1.3f, %1.3f, %1.3f, %1.3f")

その後、用意したローカルデータをSageMaker Python SDKに食わせる。


WORK_DIRECTORY = "data"

train_input = sagemaker_session.upload_data(
    WORK_DIRECTORY, key_prefix="{}/{}".format(prefix, WORK_DIRECTORY)
)

Scikit learn Estimator

scikit-learnの機械学習は以下の3段構成になっている。

  • Estimator: 与えられたデータから学習(fit)する
  • Transformer: 与えられたデータを変換(transform)する
  • Predictor: 与えられたデータから結果を予測(Predict)する

SageMakerは機械学習プラットフォームであって、かなり多くのライブラリや手法がサポートされている。
その中で、scikit-learnもサポートされていて、SKLearn Estimatorとして使用できる。
要は、schikit-learn のI/Fに準じたコードを SageMaker に内包することができる。
SKLearn Estimatorに scikit-learn コードを食わせると SageMakerから SKLearn インスタンスとして操作できる.

例えば、.ipynbで以下のように書く.


from sagemaker.sklearn.estimator import SKLearn

FRAMEWORK_VERSION = "0.23-1"
script_path = "scikit_learn_iris.py"

sklearn = SKLearn(
    entry_point=script_path,
    framework_version=FRAMEWORK_VERSION,
    instance_type="local",
    role=role,
    sagemaker_session=sagemaker_session,
    hyperparameters={"max_leaf_nodes": 30},
)

SKLearnにentry_pointとして渡しているのがscikit-learnのコード本体。
内容は以下。普通の決定木分類のコードにSageMakerとのIFに関わるコードが追加されている。

実行時引数として、SM_MODEL_DIR、SM_OUTPUT_DATA_DIR、SM_CHANNEL_TRAINが渡される。
fitで学習した結果(つまり係数)をシリアライズしSM_MODEL_DIRに保存する。

model_fnでは、SM_MODEL_DIRにシリアライズされた係数をデシリアライズし、
scikit-learnの決定木分類木オブジェクトを返す。


#  Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License").
#  You may not use this file except in compliance with the License.
#  A copy of the License is located at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  or in the "license" file accompanying this file. This file is distributed
#  on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
#  express or implied. See the License for the specific language governing
#  permissions and limitations under the License.

from __future__ import print_function

import argparse
import os

import joblib
import pandas as pd
from sklearn import tree

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Hyperparameters are described here. In this simple example we are just including one hyperparameter.
    parser.add_argument("--max_leaf_nodes", type=int, default=-1)

    # Sagemaker specific arguments. Defaults are set in the environment variables.
    parser.add_argument("--output-data-dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"])
    parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--train", type=str, default=os.environ["SM_CHANNEL_TRAIN"])

    args = parser.parse_args()

    # Take the set of files and read them all into a single pandas dataframe
    input_files = [os.path.join(args.train, file) for file in os.listdir(args.train)]
    if len(input_files) == 0:
        raise ValueError(
            (
                "There are no files in {}.\n"
                + "This usually indicates that the channel ({}) was incorrectly specified,\n"
                + "the data specification in S3 was incorrectly specified or the role specified\n"
                + "does not have permission to access the data."
            ).format(args.train, "train")
        )
    raw_data = [pd.read_csv(file, header=None, engine="python") for file in input_files]
    train_data = pd.concat(raw_data)

    # labels are in the first column
    train_y = train_data.iloc[:, 0]
    train_X = train_data.iloc[:, 1:]

    # Here we support a single hyperparameter, 'max_leaf_nodes'. Note that you can add as many
    # as your training my require in the ArgumentParser above.
    max_leaf_nodes = args.max_leaf_nodes

    # Now use scikit-learn's decision tree classifier to train the model.
    clf = tree.DecisionTreeClassifier(max_leaf_nodes=max_leaf_nodes)
    clf = clf.fit(train_X, train_y)

    # Print the coefficients of the trained classifier, and save the coefficients
    joblib.dump(clf, os.path.join(args.model_dir, "model.joblib"))


def model_fn(model_dir):
    """Deserialized and return fitted model

    Note that this should have the same name as the serialized model in the main method
    """
    clf = joblib.load(os.path.join(model_dir, "model.joblib"))
    return clf

学習

SageMaker(またはローカル)の.ipynb は、SKLearnインスタンスに対して fit() を実行するだけで良い。


sklearn.fit({"train": train_input})

推論

なんと、推論はWebインターフェースになっている。
推論コンテナ内でnginxが動作し、PythonWebAppが wsgi(gunicorn) を介してnginxから入力/応答する。

SageMaker(またはローカル)の.ipynbからは、SKLearnインスタンスに対してdeploy()を実行する。
推論コンテナへのインターフェースとなるインスタンスが生成され、後はこのインスタンスに対してpredict()を呼ぶ。

推論用にデータを集めて predict()を実行する例。
テストデータと推論の結果を並べて表示している。 うまくいっていれば同じになるはず。
(こんな風に訓練データとテストデータを拾って良いのかはさておき…)


predictor = sklearn.deploy(initial_instance_count=1, instance_type="local")

import itertools
import pandas as pd

shape = pd.read_csv("data/iris.csv", header=None)

a = [50 * i for i in range(3)]
b = [40 + i for i in range(10)]
indices = [i + j for i, j in itertools.product(a, b)]

test_data = shape.iloc[indices[:-1]]
test_X = test_data.iloc[:, 1:]
test_y = test_data.iloc[:, 0]

print(predictor.predict(test_X.values))
print(test_y.values)

/invocationsというURLに対してPOSTリクエストが発行されている。
応答は以下の通り、
テストデータの説明変数と、predict()の結果得られた値が一致していそう。


hqy7i6eoyi-algo-1-vq8df | 2021-05-22 16:26:50,617 INFO - sagemaker-containers - No GPUs detected (normal if no gpus installed)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 2. 2. 2. 2.
 2. 2. 2. 2. 2.]hqy7i6eoyi-algo-1-vq8df | 172.23.0.1 - - [22/May/2021:16:26:51 +0000] "POST /invocations HTTP/1.1" 200 360 "-" "python-urllib3/1.26.4"

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 2. 2. 2. 2.
 2. 2. 2. 2. 2.]

まとめ

SageMaker用の機械学習のHelloWorldをローカルで動かしてみた。
(かなり雑だけども), SageMakerのサンプルコードをちょっと修正するだけでローカルで動くことがわかった。