CustomOperatorのUnitTestを理解するためGCSToBigQueryOperatorのUnitTestを読んでみた話

未知の連携先との入出力を行う際、CustomOperatorを作るという解決策があります。
CustomOperatorを自作した場合、そのテストをどう書くか、という問題が発生します。
ビルトインのGCSToBigQueryOperatorがどうテストされているかを読むと、雰囲気がわかりました。
UnitTestコードを読んで見ましたので、本記事で感想を書いてみます。

https://github.com/apache/airflow/blob/main/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py

前提となる知識

Airflowのhookについて理解する必要がありました。
フワッとしていますが、コードを読んで使われ方をながめているとイメージが湧いてきます。

hook

しばしば外部からデータを入力したり外部へデータを出力する必要が出てくる。
外部と接続する際にcredentialsを保管し使用する必要があるが、
Airflowはconnectionという概念のオブジェクトを用意している。
connection は conn_id により識別される。Airflow UIやCLIから管理できる。

connectionを直接操作するようなlow-levelコードを書くこともできるが、
煩雑にならないよう、外部リソース毎にhookというhigh-levelインターフェースが用意されている。

Connections & Hooks

pythonのunittestも理解する必要がありました。
unittestのmockについて以下が参考になりました。

UnitTestを読んでいく

TestGCSToBigQueryOperatorというクラスにUnitTestメソッドの実装例が書かれています。
python built-inのテストパッケージであるunittestが使用されています。
@mock.patchデコレータを使用しBigQueryHookをpatchしています。

BigQueryHookのmockインスタンスがhookとして渡ります。

hookのreturn_value, side_effectを差し替えてGCSToBigQueryOperatorインスタンスを実行します。
insert_job(),generate_job_id(),split_table_name(),get_job()の差し替えを行なっています。
メソッドの階層をドット(.)で繋いでより深い場所を差し替えられる様子です。
unittestを書いた人はコードが何に依存しているか分かるので、知識に基づいて依存しているものをmockします。


import json
from unittest import mock
from unittest.mock import MagicMock, call

TASK_ID = "test-gcs-to-bq-operator"
TEST_EXPLICIT_DEST = "test-project.dataset.table"
WRITE_DISPOSITION = "WRITE_TRUNCATE"
SCHEMA_FIELDS = [
    {"name": "id", "type": "STRING", "mode": "NULLABLE"},
    {"name": "name", "type": "STRING", "mode": "NULLABLE"},
]
MAX_ID_KEY = "id"
JOB_PROJECT_ID = "job-project-id"

TEST_BUCKET = "test-bucket"
TEST_SOURCE_OBJECTS = "test/objects/test.csv"
DATASET = "dataset"
TABLE = "table"

GCS_TO_BQ_PATH = "airflow.providers.google.cloud.transfers.gcs_to_bigquery.{}"

job_id = "123456"
hash_ = "hash"
REAL_JOB_ID = f"{job_id}_{hash_}"

class TestGCSToBigQueryOperator:
    @mock.patch(GCS_TO_BQ_PATH.format("BigQueryHook"))
    def test_max_value_external_table_should_execute_successfully(self, hook):
        hook.return_value.insert_job.side_effect = [
            MagicMock(job_id=REAL_JOB_ID, error_result=False),
            REAL_JOB_ID,
        ]
        hook.return_value.generate_job_id.return_value = REAL_JOB_ID
        hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
        hook.return_value.get_job.return_value.result.return_value = ("1",)
        operator = GCSToBigQueryOperator(
            task_id=TASK_ID,
            bucket=TEST_BUCKET,
            source_objects=TEST_SOURCE_OBJECTS,
            destination_project_dataset_table=TEST_EXPLICIT_DEST,
            write_disposition=WRITE_DISPOSITION,
            schema_fields=SCHEMA_FIELDS,
            max_id_key=MAX_ID_KEY,
            external_table=True,
            project_id=JOB_PROJECT_ID,
        )

“基づく知識”は第三者には理解不能ですが、GCSToBigQueryOperator.pyを読むと理由がわかります。
GCSToBigQueryOperatorのexecute(self, context:Context)を読むと、
先頭でBigQueryHookのインスタンスを取得し、BaseOperator由来のself.hookに設定しているようです。
generate_job_id()により、job_idを取得しています。
_use_existing_table()内で、split_table_name()により,ProjectID,Dataset,Tableを取得しています。
mockしたjob_idが既に存在している場合、get_job()で既存を取得しています。


    def execute(self, context: Context):
        hook = BigQueryHook(
            gcp_conn_id=self.gcp_conn_id,
            location=self.location,
            impersonation_chain=self.impersonation_chain,
        )
        self.hook = hook
        self.source_format = self.source_format.upper()

        job_id = self.hook.generate_job_id(
            job_id=self.job_id,
            dag_id=self.dag_id,
            task_id=self.task_id,
            logical_date=context["logical_date"],
            configuration=self.configuration,
            force_rerun=self.force_rerun,
        )

さて、Assertは以下のように書かれています。
GCSToBigQueryOperatorは、Source(GCS)から.csv等を読み込みDest(BigQuery)へ配置するものです。
Destの然るべき場所にテーブルが作られ、値が入ります。
execute()すると、max_id_keyで指定したカラムの最大値が戻るようです。

“test-bucket”に配置した”test/objects/test.csv”は”id”,”name”の2列からなるCSVで、
例えば”id”=”1″, “name”=”hoge”ならば、”id”列の最大値である1が戻るため、1をassertすればOKです。


        result = operator.execute(context=MagicMock())
        assert result == "1"

これだと、分岐をだいぶすっ飛ばしているので、だいぶ薄いカバレッジになるかと思います。

まとめ

GCSToBigQueryOperatorのUnitTestを読んでみました。分かってしまうと普通のUnitTestでした。
Source to Destのパターンはだいたい似たようになるのかも、なので、
作るUnitTestも似たような感じになるのかもしれません。