日々の学び

AirflowでEnd-To-End Pipeline Testsを行うためにAirflow APIを調べてみた話

Airflow自体にDAGの実行結果をテスト(End-To-End Pipeline Tests)する仕組みは無いようで、 以下のような地道な仕組みを自力で作る必要がありそうです。 テストデータを用意する Airflowが提供するAirflow APIを使用してDAGを実行する DAGの終了を待つ 結果をAssertする 他にAirflow CLIも使えそうですが、pythonコードの一部にするならAPIの方が使い勝手が良さそうです。 API仕様書を上から読んでみたので、その感想を書いてみます。 他にもあるのですが、今回の用途に使いそうなものを抜粋しています。 \"読んでみた\"だけなので、誤りがあるかもしれません。概要を理解するぐらいの気持ちで読んでください。 [arst_toc tag=\"h4\"] Airflow API概要 今日時点のAirflow APIのAPI仕様書は以下です。 Airflow API (Stable) (2.10.0) RESTful APIとなっていて、Resourceに対するCRUDをHTTP Methodで表現します。 1つ、update_maskという考え方があります。リソースの値を更新する際、リソースjsonと同時に クエリパラメタで\"変更したい値は何か\"を渡すことで、リソースjsonの該当値のみを更新できます。 resource = request.get(\'/resource/my-id\').json() resource[\'my_field\'] = \'new-value\' request.patch(\'/resource/my-id?update_mask=my_field\', data=json.dumps(resource)) API Authenticationがusername/passwordで雑ですが、 DAGのis_pausedをtrueにするには、以下の通りpatchを叩くようです。 curl -X PATCH \'https://example.com/api/v1/dags/{dag_id}?update_mask=is_paused\' -H \'Content-Type: application/json\' --user \"username:password\" -d \'{ \"is_paused\": true }\' CORSを有効にする必要があります。Enabling CORS 様々なAPI認証が用意されています。API認証はAirflowのauth managerで管理されます。Authentication エラーはRFC7807準拠です。つまり、Unauthenticated、PermissionDenied、BadRequest、NotFound、MethodNotAllowed、NotAcceptable、AlreadyExistsが扱われます。Errors Connections ざっとAPIを眺めていきます。 まずはConnection。順当なCRUDです。patchでupdate_maskが使われます。 コードから一通りConnectionを触れそうです。 Testって何か調べてみました。 デフォルトでdisabledになっていますが、Airflow UI(Connections)から\"Test\"ボタンを押下できます。 Connectionと関連付けられたhookのtest_connection()メソッドを実行するようです。 これと同等の機能が動くようです。 Method Endpoint Overview Response GET /connections List Connection array of objects(ConnectionCollectionItem). POST /connections Create a Connection created connection. GET /connections/{connection_id} Get a connection connection PATCH /connections/{connection_id} Update a connection updated connection DELETE /connections/{connection_id} Delete a connection (Status) POST /connections/test Test a connection (Status) DAG 次はDAG。まずDAG一覧に対する操作。一覧に対してpatchを叩ける様子です。 Method Endpoint Overview GET /dags List DAGs in the database. dag_id_pattern can be set to match dags of a specific pattern PATCH /dags Update DAGs of a given dag_id_pattern using UpdateMask. This endpoint allows specifying ~ as the dag_id_pattern to update all DAGs. New in version 2.3.0 次は個別のDAGに対する操作。 Method Endpoint Overview GET /dags/{dag_id} Get basic information about a DAG.Presents only information available in database (DAGModel). If you need detailed information, consider using GET /dags/{dag_id}/details. PATCH /dags/{dag_id} Update a DAG. DELETE /dags/{dag_id} Deletes all metadata related to the DAG, including finished DAG Runs and Tasks. Logs are not deleted. This action cannot be undone.New in version 2.2.0 GET /dags/{dag_id}/tasks/detail Get simplified representation of a task. GET /dags/{dag_id}/detail Get a simplified representation of DAG.The response contains many DAG attributes, so the response can be large. If possible, consider using GET /dags/{dag_id}. Airflowにおいて、Operatorのインスタンスに\"Task\"という用語が割り当てられています。 つまり、「Operatorに定義した処理を実際に実行すること」が\"Task\"としてモデリングされています。 「\"Task\"をA月B日X時Y分Z秒に実行すること」が、\"TaskInstance\"としてモデリングされています。 あるDAGは、実行日/実行時間ごとの複数の\"TaskInstance\"を保持しています。 以下のAPIにおいて、DAGが保持する\"Task\",\"日付レンジ\"等を指定して実行します。 \"TaskInstance\"を\"Clear(再実行)\"します。また、\"TaskInstance\"の状態を一気に更新します。 Method Endpoint Overview POST /dags/{dag_id}/clearTaskInstances Clears a set of task instances associated with the DAG for a specified date range. POST /dags/{dag_id}/updateTaskInstancesState Updates the state for multiple task instances simultaneously. GET /dags/{dag_id}/tasks Get tasks for DAG. なんだこれ、ソースコードを取得できるらしいです。 Method Endpoint Overview GET /dagSources/{file_token} Get a source code using file token. DAGRun \"Task\"と\"TaskInstance\"の関係と同様に\"DAG\"と\"DAGRun\"が関係しています。 「A月B日X時Y分Z秒のDAG実行」が\"DAGRun\"です。DAGRun。順当な感じです。 新規にトリガしたり、既存のDAGRunを取得して更新したり削除したり、再実行したりできます。 Method Endpoint Overview GET /dags/{dag_id}/dagRuns List DAG runs.This endpoint allows specifying ~ as the dag_id to retrieve DAG runs for all DAGs. POST /dags/{dag_id}/dagRuns Trigger a new DAG run.This will initiate a dagrun. If DAG is paused then dagrun state will remain queued, and the task won\'t run. POST /dags/~/dagRuns/list List DAG runs (batch).This endpoint is a POST to allow filtering across a large number of DAG IDs, where as a GET it would run in to maximum HTTP request URL length limit. GET /dags/{dag_id}/dagRuns/{dag_run_id} Get a DAG run. DELETE /dags/{dag_id}/dagRuns/{dag_run_id} Delete a DAG run. PATCH /dags/{dag_id}/dagRuns/{dag_run_id} Modify a DAG run.New in version 2.2.0 POST /dags/{dag_id}/dagRuns/{dag_run_id}/clear Clear a DAG run.New in version 2.4.0 以下はスキップ.. Method Endpoint Overview GET /dags/{dag_id}/dagRuns/{dag_run_id}/upstreamDatasetEvents Get datasets for a dag run.New in version 2.4.0 PATCH /dags/{dag_id}/dagRuns/{dag_run_id}/setNote Update the manual user note of a DagRun.New in version 2.5.0 DAGWarning DAGのimport_errors一覧を返します。 Method Endpoint Overview GET /dagWarnings List Dag Waranings. DAGStats A DAG Run status is determined when the execution of the DAG is finished. The execution of the DAG depends on its containing tasks and their dependencies. The status is assigned to the DAG Run when all of the tasks are in the one of the terminal states (i.e. if there is no possible transition to another state) like success, failed or skipped. The DAG Run is having the status assigned based on the so-called “leaf nodes” or simply “leaves”. Leaf nodes are the tasks with no children. There are two possible terminal states for the DAG Run: success if all of the leaf nodes states are either success or skipped, failed if any of the leaf nodes state is either failed or upstream_failed. Method Endpoint Overview GET /dagStats List Dag statistics. ImportError Airflow Best PractiveのTesting a DagにDAGのテスト観点に関する記述が(サラッと)書かれています。 まず、DAGは普通のpythonコードなので、pythonインタプリタで実行する際にエラーが起きないことを確認すべし、とのことです。 以下の実行により、未解決の依存関係、文法エラーをチェックします。もちろん、どこで実行するかが重要なので、DAG実行環境と合わせる必要があります。 Airflow APIにより、このレベルのエラーがDAGファイルにあるか確認できるようです。 $ python your-dag-file.py Method Endpoint Overview GET /importErrors List import errors. GET /importErrors/{import_error_id} Get an import error. Variables DAGに記述したくないCredentials等を管理する仕組みで、Airflow UIからポチポチ操作すると作れます。 Variableはkey-valueそのままです。DAGからkeyを指定することで参照できます。 Airflow APIからもVariableをCRUDできます。 Method Endpoint Overview GET /variables List variables.The collection does not contain data. To get data, you must get a single entity. POST /variables Create a variable. GET /variables/{variable_key} Get a variable by key. PATCH /variables/{variable_key} Update a variable by key. DELETE /variables/{variable_key} Delete a variable by key. まとめ RESTfulAPIが用意されているということは、内部のオブジェクトをCRUD出来るということなのだろう、 という推測のもと、Airflow APIのAPI仕様書を読んで感想を書いてみました。 Airflowの概念と対応するリソースはAPIに出現していて、End-To-End Pipeline Testを書く際に、Assert、実行制御を記述できそうな気持ちになりました。 Assert、実行制御、だけなら、こんなに要らない気もします。 API呼び出し自体の煩雑さがあり、Testの記述量が増えてしまうかもしれません。 以下の記事のようにwrapperを書く必要があるかもしれません。 https://github.com/chandulal/airflow-testing/blob/master/src/integrationtest/python/airflow_api.py DAGの入力側/出力側Endに対するファイル入出力は別で解決が必要そうです。 「API仕様書を読んでみた」の次の記事が書けるときになったら、再度まとめ記事を書いてみようと思います。

CustomOperatorのUnitTestを理解するためGCSToBigQueryOperatorのUnitTestを読んでみた話

未知の連携先との入出力を行う際、CustomOperatorを作るという解決策があります。 CustomOperatorを自作した場合、そのテストをどう書くか、という問題が発生します。 ビルトインのGCSToBigQueryOperatorがどうテストされているかを読むと、雰囲気がわかりました。 UnitTestコードを読んで見ましたので、本記事で感想を書いてみます。 https://github.com/apache/airflow/blob/main/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py 前提となる知識 Airflowのhookについて理解する必要がありました。 フワッとしていますが、コードを読んで使われ方をながめているとイメージが湧いてきます。 hook しばしば外部からデータを入力したり外部へデータを出力する必要が出てくる。 外部と接続する際にcredentialsを保管し使用する必要があるが、 Airflowはconnectionという概念のオブジェクトを用意している。 connection は conn_id により識別される。Airflow UIやCLIから管理できる。 connectionを直接操作するようなlow-levelコードを書くこともできるが、 煩雑にならないよう、外部リソース毎にhookというhigh-levelインターフェースが用意されている。 Connections & Hooks pythonのunittestも理解する必要がありました。 unittestのmockについて以下が参考になりました。 [clink implicit=\"false\" url=\"https://qiita.com/satamame/items/1c56e7ff3fc7b2986003\" imgurl=\"https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQqktBhX0kv-C4zk1lu0D8T0ExDUFQdNdu9dQ&s\" title=\"Python の unittest の mock\" excerpt=\"Python の unittest を使っていて、mock が何をするものかは分かっているけど、まだちょっと得体が知れない、怖い、という段階があると思います。この段階を克服するために、何も知らない状態から徐々に mock を理解するためのステップを作りたいと思いました。対象は Python 3.x です。\"] UnitTestを読んでいく TestGCSToBigQueryOperatorというクラスにUnitTestメソッドの実装例が書かれています。 python built-inのテストパッケージであるunittestが使用されています。 @mock.patchデコレータを使用しBigQueryHookをpatchしています。 BigQueryHookのmockインスタンスがhookとして渡ります。 hookのreturn_value, side_effectを差し替えてGCSToBigQueryOperatorインスタンスを実行します。 insert_job(),generate_job_id(),split_table_name(),get_job()の差し替えを行なっています。 メソッドの階層をドット(.)で繋いでより深い場所を差し替えられる様子です。 unittestを書いた人はコードが何に依存しているか分かるので、知識に基づいて依存しているものをmockします。 import json from unittest import mock from unittest.mock import MagicMock, call TASK_ID = \"test-gcs-to-bq-operator\" TEST_EXPLICIT_DEST = \"test-project.dataset.table\" WRITE_DISPOSITION = \"WRITE_TRUNCATE\" SCHEMA_FIELDS = [ {\"name\": \"id\", \"type\": \"STRING\", \"mode\": \"NULLABLE\"}, {\"name\": \"name\", \"type\": \"STRING\", \"mode\": \"NULLABLE\"}, ] MAX_ID_KEY = \"id\" JOB_PROJECT_ID = \"job-project-id\" TEST_BUCKET = \"test-bucket\" TEST_SOURCE_OBJECTS = \"test/objects/test.csv\" DATASET = \"dataset\" TABLE = \"table\" GCS_TO_BQ_PATH = \"airflow.providers.google.cloud.transfers.gcs_to_bigquery.{}\" job_id = \"123456\" hash_ = \"hash\" REAL_JOB_ID = f\"{job_id}_{hash_}\" class TestGCSToBigQueryOperator: @mock.patch(GCS_TO_BQ_PATH.format(\"BigQueryHook\")) def test_max_value_external_table_should_execute_successfully(self, hook): hook.return_value.insert_job.side_effect = [ MagicMock(job_id=REAL_JOB_ID, error_result=False), REAL_JOB_ID, ] hook.return_value.generate_job_id.return_value = REAL_JOB_ID hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE) hook.return_value.get_job.return_value.result.return_value = (\"1\",) operator = GCSToBigQueryOperator( task_id=TASK_ID, bucket=TEST_BUCKET, source_objects=TEST_SOURCE_OBJECTS, destination_project_dataset_table=TEST_EXPLICIT_DEST, write_disposition=WRITE_DISPOSITION, schema_fields=SCHEMA_FIELDS, max_id_key=MAX_ID_KEY, external_table=True, project_id=JOB_PROJECT_ID, ) \"基づく知識\"は第三者には理解不能ですが、GCSToBigQueryOperator.pyを読むと理由がわかります。 GCSToBigQueryOperatorのexecute(self, context:Context)を読むと、 先頭でBigQueryHookのインスタンスを取得し、BaseOperator由来のself.hookに設定しているようです。 generate_job_id()により、job_idを取得しています。 _use_existing_table()内で、split_table_name()により,ProjectID,Dataset,Tableを取得しています。 mockしたjob_idが既に存在している場合、get_job()で既存を取得しています。 def execute(self, context: Context): hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, location=self.location, impersonation_chain=self.impersonation_chain, ) self.hook = hook self.source_format = self.source_format.upper() job_id = self.hook.generate_job_id( job_id=self.job_id, dag_id=self.dag_id, task_id=self.task_id, logical_date=context[\"logical_date\"], configuration=self.configuration, force_rerun=self.force_rerun, ) さて、Assertは以下のように書かれています。 GCSToBigQueryOperatorは、Source(GCS)から.csv等を読み込みDest(BigQuery)へ配置するものです。 Destの然るべき場所にテーブルが作られ、値が入ります。 execute()すると、max_id_keyで指定したカラムの最大値が戻るようです。 \"test-bucket\"に配置した\"test/objects/test.csv\"は\"id\",\"name\"の2列からなるCSVで、 例えば\"id\"=\"1\", \"name\"=\"hoge\"ならば、\"id\"列の最大値である1が戻るため、1をassertすればOKです。 result = operator.execute(context=MagicMock()) assert result == \"1\" これだと、分岐をだいぶすっ飛ばしているので、だいぶ薄いカバレッジになるかと思います。 まとめ GCSToBigQueryOperatorのUnitTestを読んでみました。分かってしまうと普通のUnitTestでした。 Source to Destのパターンはだいたい似たようになるのかも、なので、 作るUnitTestも似たような感じになるのかもしれません。

GoogleによるAirflow DAG実装のベスプラ集を読んでみた – その1

GoogleによるフルマネージドAirflow環境であるCloud Composerを使う必要があり、 いそぎでAirflow+Cloud Composerをキャッチアップすることになりました。 Googleが公開するベスプラ集があることを知り、読んでみることにしました。 Cloud Composerと題されていいますが、ほぼAirflowと読み替えて良いのかなと思います。 書かれているのは、少し基本的なシナリオだと思います。 経験に裏付けられたゴリゴリの集合知、というものはタダでは手に入らないのだろうと思います。 スタート地点に立つ際の道しるべ、ぐらいの気持ちです。 おそらく一緒に使うシナリオが多いであろう、データ変換ツールのdbtと競合するものがあります。 大構造としてはAirflow DAGの下にdbt DAGが来るため、Airflow DAGのベスプラを実現する前提で dbt DAGを書いていくものと考えていました。これだけだとバッティングすると思います。 ウェアハウスとは切り離されています。特にBigQueryを前提にするならもう少し踏み込んだ内容と なるはずだと思いますが、ちょっと書かれていないようです。 いったん半分くらい読んでみたので読書感想文を書いてみました。 [clink implicit=\"false\" url=\"https://cloud.google.com/blog/ja/products/data-analytics/optimize-cloud-composer-via-better-airflow-dags\" imgurl=\"https://www.gstatic.com/pantheon/images/welcome/supercloud.svg\" title=\"Airflow DAG の改良による Cloud Composer の最適化\" excerpt=\"このガイドには、Apache Airflow DAG を作成する際に一般的に適用できる、作業項目のチェックリストが掲載されています。これらのチェック項目は、Google Cloud とオープンソース コミュニティが判断したベスト プラクティスに沿ったものとなっています。一連の高パフォーマンスの DAG によって Cloud Composer の動作が最適化され、標準化された作成方法によって、デベロッパーが数百個、あるいは数千個の DAG でも管理できるようになります。チェック項目のそれぞれが、Cloud Composer 環境と開発プロセスにメリットをもたらします。\"] [arst_toc tag=\"h4\"] はじめに ファイル名を標準化します Workflowの特徴を個別に表現するだけでは不十分で、ファイル名が含む部分文字列がサブ機能や属性のインデックスとなっていて欲しい。さらに、ファイル名から機能概要を逆引き推測できたら便利。 作成した DAG ファイルのコレクションを他のデベロッパーが容易に参照できるようにするためです。例: team_project_workflow_version.py DAG は決定的でなければなりません 入力となるデータが同じであれば、出力は常に同じであるべき、という観点だと思う。 例えば、入力となるデータは同じであっても、実行時間に依存して処理を行ってしまうと、 出力が時間依存となる。テストが無茶苦茶大変になるだろうなと思う。 Airflow DAG単体であれば、そう理解に難しくないポイントだとは思う。 しかし、dbt DAGを含めると一気に縛りがキツくなると思う。 特定の入力によって常に同じ出力が生成される必要があります DAG はべき等でなければなりません 大雑把に書けば、ある操作を1回行っても数回行っても結果が同じであることを言う。 これを実現する仕組みの選択は結構悩ましく、「追加する範囲をいったん削除追加する」が簡単。 しかし、この方法だと無駄なスキャン量が発生する。 dbtを使用する場合、incremental modelがべき等性担保の手段として挙げられることが多いが、 実際にべき等性を担保するには考慮しないといけないことがある。 こちらの記事(dbtで「Incremental」を使わずに冪等性を担保する方法について)が詳しい。 DAG を何度トリガーしても、毎回同じ効果 / 結果が得られなければなりません 例えば、以下のように書けば、入力テーブルが変わらない限りべき等となる。 これを行うには、入力テーブルに「ロード日時」といったメタデータが必要となる。 {{ config( materialized=\"incremental\" ) }} ・・・ {%- if is_incremental() %} WHERE ORDERDATE = TO_DATE(\'{{ var(\'load_date\') }}\') {%- endif %} タスクはアトミック、かつ、べき等でなければなりません ちょっと何言っているかわからない書きっぷり。データベースのACID特性のAtomic性を意識する。 ある操作が一連の処理の完了によって達成される場合、部分的に処理の一部が成功する、という 状態になってしまってはいけない。全ての処理が成功したか、全ての処理が失敗したか、のどちらか、 になっていないといけない。 タスクごとに、他のオペレーションとは独立して再実行できる 1つのオペレーションを処理するようにします。タスクがアトミックな場合、そのタスクの一部が成功したことは、タスク全体が成功したことを意味します。 可能な限り単純な DAG にします ちょっと一般的すぎて良くわからない。 「スケジューリングのコスト」って、「実行のコスト」よりもだいぶ小さいんじゃなかろうか、と思うが、 それでも意識しないといけないのだろうか。 ネストされたツリー構造は、そもそも理解しづらくて避けるべきだろう、とは思う。 タスク間の依存関係が少ない単純な DAG にすると、オーバーヘッドが少なくなるため、スケジューリングのパフォーマンスが向上する傾向があります。一般的に、多数の依存関係がある深くネストされたツリー構造よりも線形構造(例: A->B->C)にしたほうが、効率的な DAG になります。 Python の docstring 規則に従って、各関数のファイルの上部にドキュメントを記述してください AirflowはPythonで書けることが最大の特徴なので、そのメリットを発揮するため、docstringでコメント書けよと。 Python の docstring 規則は、他のデベロッパーやプラットフォーム エンジニアが Airflow DAG を理解するために役立ちます。 関数と同様に BashOperator のドキュメントも作成するようにしてください。DAG で bash スクリプトを参照している場合、スクリプトの目的を記したドキュメントがないと、このスクリプトに詳しくないデベロッパーにはトラブルシューティングが困難です。 DAG の作成を標準化する default_args にオーナーを追加します。 タスクを得る(Operatorをインスタンス化する)際に、各Operatorのコンストラクタに引数を与えるが、 複数のOperatorに渡す引数を共通化したい場合には、default_argsをDAG()に与える。 こうすると、Operatorにdefault_argsで設定した引数を与えたことになる。 各Operatorに引数を与えると、defautl_argsをオーバーライドする動作となる。 過去の公式でdefault_argsは、task_id と owner が mandatory(必須) であるとされている。 これについて、Why is \'owner\' a mandatory argument for tasks and dags? という記事がある。 それに続くPRは More detail on mandatory task arguments であり、mandatoryの根拠を聞いている。 歴史的な理由による、で片付いているな。ベスプラではownerに実装者のメアドなどを書けという。 実装担当者を明らかにせよ、という話であればcomitterを見れば良いだけで、ちょっと意味不明。 mandatoryなので、何か入れないといけないなら、とりあえず実装者のメアドを入れておけ、ということか。 import pendulum with DAG( dag_id=\'my_dag\', start_date=pendulum.datetime(2016, 1, 1, tz=\"UTC\"), schedule_interval=\'@daily\', catchup=False, default_args={\'owner\': \'hoge@ikuty.com\'}, ) as dag: op = BashOperator(task_id=\'dummy\', bash_command=\'Hello World!\') print(op.retries) # 2 dag = DAG() ではなく with DAG() as dag を使用します。 Pythonのwith文の仕様。コンテキストマネージャという言う。 try... except... finally をラップするため、リソースの確保と対応する解放が必ず行われる。 with DAG(...):文の下のインデントの中では、各Operatorのコンストラクタにdagインスタンスを 渡さなくてよくなる。 すべてのオペレーターまたはタスクグループに DAG オブジェクトを渡す必要がなくなるようにします。 DAG ID 内にバージョンを設定します。 以下とのこと。あぁ..バージョニングが実装されていないので手動でバージョニングを行うべし、と。 AirflowはDAGをファイルIDで管理しているため、ファイルIDを変更するとUI上、別のDAGとして扱われるよう。 積極的にDAG IDを変更して、Airflow UIに無駄な情報を出さないようにする、というアイデア。 DAG 内のコードを変更するたびにバージョンを更新します。 こうすると、削除されたタスクログが UI に表示されなくなることや、ステータスのないタスクが古い DAG 実行に対して生成されること、DAG 変更時の一般的な混乱を防ぐことができます。 Airflow オープンソースには、将来的にバージョニングが実装される予定です。 DAG にタグを追加します。 公式はこちら。Add tags to DAGs and use it for filtering in the UI 単にUI上の整理のために留まらず、処理の記述に積極的に使うことができる様子。 これを無秩序に使うと扱いづらくなりそうなので、使う場合は用途を明確にしてから使うべきかと思う。 1. デベロッパーがタグをフィルタして Airflow UI をナビゲートできるようにします。 2. 組織、チーム、プロジェクト、アプリケーションなどを基準に DAG をグループ化します。 DAG の説明を追加します。 唐突で非常に当たり前なのだが、あえて宣言することが大事なんだろうと思う。 他のデベロッパーが自分の DAG の内容を理解できるようにします。 作成時には DAG を一時停止します。 なるほど。 こうすると、誤って DAG が実行されて Cloud Composer 環境の負荷が増すという事態を回避できます。 catchup=Falseに設定して自動キャッチアップによるCloud Composer環境の過負荷を避けます。 まず、catchupの前に、Airflowの実行タイミングが直感的でなさそう。 こちらの記事がとても参考になった。【Airflow】DAG実行タイミングを改めて纏めてみた DAGの実行タイミングはstart_dateとschedule_intervalを利用して計算される。 重要なポイントはschedule_intervalの終了時にDAGが実行される、という点。 また、schedule_intervalはウインドウ枠を表している。 例えば 0 11 * * * であれば、毎日11:00-翌日11:00という時間の幅を表す。 start_date=7月15日、schedule_interval=0 11 * * * のとき、 7月15日 11:00から 7月16日11:00までの期間が終わった後、DAGが開始される。 DAGをデプロイする際、デプロイ日時よりも古いstart_dateを設定することができる。 このとき、start_dateからデプロイ日時までの間で、本来完了しているはずだが実行していない schedule_intervalについてDAGを実行する機能がcatchup。 catchup=Trueとすると、これらのschedule_intervalが全て再実行の対象となる。 一方、catchup=Falseとsるうと、これらのうち、最後だけが再実行の対象となる。 (Falseとしても、最後の1回は再実行される) 過去のデータを自動投入するとか、危ないので、確認しながら手動実行すべきだと思う。 もし本当にcatchupするのであれば、計画的にFalseからTrueにすべきだろうし、 その時は負荷を許容できる状況としないといけない。 DAGが完了せずにCloud Composer環境のリソースが保持されることや、再試行時に競合が引き起こされることのないよう、 dagrun_timeout を設定します DAG、タスク、それぞれにタイムアウトプロパティが存在する。それぞれ理解する必要がある。 DAGタイムアウトはdagrun_timeout、タスクタイムアウトはexecution_timeout。 以下が検証コード。job1のexecution_timeout引数をコメントアウトしている。 コメントアウトした状態では、dagrun_timeoutがDAGのタイムアウト時間となる。 検証コードにおいては、タイムアウト時間が15秒のところ、タスクで20秒かかるのでタイムアウトが起きる。 execution_timeout引数のコメントアウトを外すと、DAGのタイムアウト時間が タスクのタイムアウト時間で上書きされ30秒となる。 タスクで20秒かかってもタイムアウトとならない。 from datetime import timedelta from airflow.utils.dates import days_ago from airflow import DAG from airflow.operators.python import PythonOperator def wait(**context): time.sleep(20) defalut_args = { \"start_date\": days_ago(2), \"provide_context\": True } with DAG( default_args=defalut_args, dagrun_timeout=timedelta(seconds=15), ) as dag: job1 = PythonOperator( task_id=\'wait_task1\', python_callable=wait, # execution_timeout=timedelta(second=30) ) ベスプラの言うところは、ちゃんとタイムアウトを設定しろよ、ということだと思う。 インスタンス化で DAG に引数を渡し、すべてのタスクにデフォルトで同じ start_date が設定されるようにします Airflowでは、Operatorのコンストラクタにstart_dateを与えられるようになっている。 同一DAGに所属するタスクが異なるstart_dateを持つ、という管理が大変なDAGを作ることも出来てしまう。 基本的には、DAGにstart_dateを渡して、タスクのデフォルトを揃えるべき、だそう。 DAG では静的な start_date を使用します。 これがベスプラになっているのはかなり助かる。 動的な start_date を使用した場合、誤った開始日が導き出され、失敗したタスク インスタンスやスキップされた DAG 実行を消去するときにエラーが発生する可能性があります。 retries を、DAG レベルで適用される default_args として設定します。 retriesについても、start_dateと同様にDAGレベルで default_args として設定するそう。 なお、タスクのリトライに関する設定には以下のようなものがある。 retries (int) retry_delay (datetime.timedelta) retry_exponential_backoff (bool) max_retry_delay (datetime.timedelta) on_retry_callback (callable) retries (int)は、タスクが\"失敗\"となる前に実行されるリトライ回数。 retry_delay (datetime.timedelta)はリトライ時の遅延時間。 retry_exponential_backoff (bool)はリトライ遅延での指数関数的後退アルゴリズムによるリトライ間隔の待ち時間を増加させるかどうか max_retry_delay (datetime.timedelta)はリトライ間の最大遅延間隔 on_retry_callback (callable)はリトライ時のコールバック関数 適切な再試行回数は 1~4 回です。再試行回数が多すぎると、Cloud Composer 環境に不要な負荷がかかります。 具体的に retries を何に設定すべきか、について書かれている。 ここまでのまとめ ここまでのステートメントがコードになっている。 わかりやすい。 import airflow from airflow import DAG from airflow.operators.bash_operator import BashOperator # default_args 辞書を定義して、DAG のデフォルト パラメータ(開始日や頻度など)を指定する default_args = { \'owner\': \'me\', \'retries\': 2, # 最大再試行回数は 2~4 回にすること \'retry_delay\': timedelta(minutes=5) } # `with` ステートメントを使用して DAG オブジェクトを定義し、一意の DAG ID と default_args 辞書を指定する with DAG( \'dag_id_v1_0_0\', # ID にバージョンを含める default_args=default_args, description=\'This is a detailed description of the DAG\', # 詳しい説明 start_date=datetime(2022, 1, 1), # 静的な開始日 dagrun_timeout=timedelta(minutes=10), # この DAG に固有のタイムアウト is_paused_upon_creation= True, catchup= False, tags=[\'example\', \'versioned_dag_id\'], # この DAG に固有のタグ schedule_interval=None, ) as dag: # BashOperator を使用してタスクを定義する task = BashOperator( task_id=\'bash_task\', bash_command=\'echo \"Hello World\"\' )

やりなおし統計

分散と標準偏差を計算しやすくする

[mathjax] 分散と標準偏差を計算しやすく変形できる。 いちいち偏差(x_i-bar{x})を計算しておかなくても、2乗和(x_i^2)と平均(bar{x})がわかっていればOK。 begin{eqnarray} s^2 &=& frac{1}{n} sum_{i=1}^n (x_i - bar{x})^2 \\ &=& frac{1}{n} sum_{i=1}^n ( x_i^2 -2 x_i bar{x} + bar{x}^2 ) \\ &=& frac{1}{n} ( sum_{i=1}^n x_i^2 -2 bar{x} sum_{i=1}^n x_i + bar{x}^2 sum_{i=1}^n 1) \\ &=& frac{1}{n} ( sum_{i=1}^n x_i^2 -2 n bar{x}^2 + nbar{x}^2 ) \\ &=& frac{1}{n} ( sum_{i=1}^n x_i^2 - nbar{x}^2 ) \\ &=& frac{1}{n} sum_{i=1}^n x_i^2 - bar{x}^2 \\ s &=& sqrt{frac{1}{n} sum_{i=1}^n x_i^2 - bar{x}^2 } end{eqnarray} 以下みたい使える。平均と標準偏差と2乗和の関係。 begin{eqnarray} sum_{i=1}^n (x_i - bar{x})^2 &=& sum_{i=1}^n x_i^2 - nbar{x}^2 \\ ns^2 &=& sum_{i=1}^n x_i^2 - nbar{x}^2 \\ sum_{i=1}^n x_i^2 &=& n(s^2 + bar{x} ) end{eqnarray}

標本調査に必要なサンプル数の下限を与える2次関数

[mathjax] 2項分布に従う母集団の母平均を推測するために有意水準を設定して95%信頼区間を求めてみた。 母平均のあたりがついていない状況だとやりにくい。 [clink url=\"https://ikuty.com/2019/01/11/sampling/\"] (hat{p})がどんな値であっても下限は(hat{p})の関数で抑えられると思ったので、 気になって(hat{p})を変数のまま残すとどうなるかやってみた。 begin{eqnarray} 1.96sqrt{frac{hat{p}(1-hat{p})}{n}} le 0.05 \\ frac{1.96}{0.05}sqrt{hat{p}(1-hat{p})} le sqrt{n} \\ 39.2^2 hat{p}(1-hat{p}) le n end{eqnarray} 左辺を(f(hat{p}))と置くと (f(hat{p}))は下に凸の2次関数であって、 (frac{d}{dhat{p}}f(hat{p})=0)の時に最大となる。というか(hat{p}=0.5)。 (hat{p}=0.5)であるとすると、これはアンケートを取るときのサンプル数を求める式と同じで、 非常に有名な以下の定数が出てくる。 begin{eqnarray} 1537 * 0.5 (1-0.5) le n \\ 384 le n end{eqnarray} (hat{p})がどんな値であっても、サンプル数を400とれば、 有意水準=5%の95%信頼区間を得られる。 だから、アンケートの(n)数はだいたい400で、となる。 さらに、有意水準を10%にとれば、(n)の下限は100で抑えられる。 なるはやのアンケートなら100、ちゃんとやるには400、というやつがこれ。

標本調査に必要なサンプル数を素人が求めてみる。

[mathjax] ちょっと不思議な計算をしてみる。 仮定に仮定を積み重ねた素人の統計。 成功か失敗かを応答する認証装置があったとする。 1回の試行における成功確率(p)は試行によらず一定でありベルヌーイ試行である。 (n)回の独立な試行を繰り返したとき、成功数(k)を確率変数とする離散確率変数に従う。 二項分布の確率密度関数は以下の通り。 begin{eqnarray} P(X=k)= {}_n C_k p^k (1-p)^{n-k} end{eqnarray} 期待値、分散は、 begin{eqnarray} E(X) &=& np \\ V(X) &=& np(1-p) end{eqnarray} (z)得点(偏差値,つまり平均からの誤差が標準偏差何個分か?)は、 begin{eqnarray} z &=& frac{X-E(X)}{sigma} \\ &=& frac{X-E(X)}{sqrt{V(X)}} \\ &=& frac{X-np}{sqrt{np(1-p)}} end{eqnarray} であり、(z)は標準正規分布に従う。 これを標本比率(hat{p}=frac{X}{n})を使うように式変形する。 begin{eqnarray} z &=& frac{frac{1}{n}}{frac{1}{n}} frac{X-np}{sqrt{np(1-p)}} \\ &=& frac{frac{X}{n}-p}{sqrt{frac{p(1-p)}{n}}} \\ &=& frac{hat{p}-p}{sqrt{frac{p(1-p)}{n}}} end{eqnarray} (n)が十分に大きいとき、(z)は標準正規分布(N(0,1))に従う。 従って、(Z)の95%信頼区間は以下である。 begin{eqnarray} -1.96 le Z le 1.96 end{eqnarray} なので、 begin{eqnarray} -1.96 le frac{hat{p}-p}{sqrt{frac{p(1-p)}{n}}} le 1.96 end{eqnarray} (hat{p})は(p)の一致推定量であるから、(n)が大なるとき(hat{p}=p)とすることができる。 begin{eqnarray} -1.96 le frac{hat{p}-p}{sqrt{frac{hat{p}(1-hat{p})}{n}}} le 1.96 \\ end{eqnarray} (p)について解くと(p)の95%信頼区間が求まる。 begin{eqnarray} hat{p}-1.96 sqrt{frac{hat{p}(1-hat{p})}{n}} le p le hat{p}+1.96 sqrt{frac{hat{p}(1-hat{p})}{n}} end{eqnarray} 上記のにおいて、標準誤差(1.96sqrt{frac{hat{p}(1-hat{p})}{n}})が小さければ小さいほど、 95%信頼区間の幅が狭くなる。この幅が5%以内であることを言うためには以下である必要がある。 (有意水準=5%) begin{eqnarray} 1.96sqrt{frac{hat{p}(1-hat{p})}{n}} le 0.05 end{eqnarray} 観測された(hat{p})が(0.9)であったとして(n)について解くと、 begin{eqnarray} 1.96sqrt{frac{0.9(1-0.9)}{n}} le 0.05 \\ frac{1.96}{0.05} sqrt{0.09} le sqrt{n} \\ 11.76 le sqrt{n} \\ 138.2 le n end{eqnarray} 139回試行すれば、100回中95回は(p)は以下の95%信頼区間に収まる。 つまり95%信頼区間は以下となる。 begin{eqnarray} hat{p}-1.96 sqrt{frac{hat{p}(1-hat{p})}{n}} &le& p le hat{p}+1.96 sqrt{frac{hat{p}(1-hat{p})}{n}} \\ 0.9-1.96 frac{sqrt{0.09}}{sqrt{139}} &le& p le 0.9 + 1.96 frac{sqrt{0.09}}{sqrt{139}} \\ 0.9-1.96 frac{0.3}{11.78} &le& p le 0.9+1.96 frac{0.3}{11.78} \\ 0.85 &le& p le 0.95 end{eqnarray} (n)を下げたい場合は有意水準を下げれば良い。 統計的に有意水準=10%まで許容されることがある。 有意水準が10%であるとすると、(n)は35以上であれば良いことになる。 begin{eqnarray} 1.96sqrt{frac{0.9(1-0.9)}{n}} le 0.1 \\ frac{1.96}{0.1} sqrt{0.09} le sqrt{n} \\ 5.88 le sqrt{n} \\ 34.6 le n end{eqnarray} 信頼区間と有意水準の式において(p)を標本から取ってきたけど、 アンケートにおいてYes/Noを答える場合、(p)は標本における最大値(つまり0.5)を 設定して(n)を求める。 つまり、(p)として利用するのは標本比率ではないのかな?と。 このあたり、(hat{p})を変数として残すとどういうことがわかった。 [clink url=\"https://ikuty.com/2019/01/13/sampling_with2/\"]

深層学習

勾配降下法

[mathjax] 各地点において関数の値を最大にするベクトル((frac{partial f}{partial x_0},frac{partial f}{partial x_1}))を全地点に対して計算したものを勾配とかいう。 ある地点において、このベクトルの方向に向かうことにより最も関数の値を大きくする。 で、今後のために正負を反転して関数の値を最小にするベクトルを考えることにした。 関数の値を小さくする操作を繰り返していけば、いずれ\"最小値\"が見つかるはず。 というモチベを続けるのが勾配降下法。学習率(eta)を使って以下みたいに書ける。。 begin{eqnarray} x_0 = x_0 - eta frac{partial f}{partial x_0} \\ x_1 = x_1 - eta frac{partial f}{partial x_1} end{eqnarray} ということで(f(x_0,x_1)=x_0^2+x_1^2)の最小値を初期値((3.0,4.0))、 学習率(eta=0.1)に設定して計算してみる。 import numpy as np def numerical_gradient(f, x): h = 1e-4 grad = np.zeros_like(x) for idx in range(x.size): tmp_val = x[idx] x[idx] = tmp_val + h fxh1 = f(x) x[idx] = tmp_val - h fxh2 = f(x) grad[idx] = (fxh1 - fxh2) / (2*h) x[idx] = tmp_val return grad def gradient_descent(f, init_x, lr=0.01, step_num=100): x = init_x for i in range(step_num): grad = numerical_gradient(f,x) x -= lr * grad return x def function2(x): return x[0]**2 + x[1]**2 init_x = np.array([-3.0, 4.0]) v = gradient_descent(function2, init_x=init_x, lr=0.1, step_num=100) v # array([-6.11110793e-10, 8.14814391e-10]) ((0,0))に収束した。 ニューラルネットワークの勾配 損失関数を重みパラメータで微分する。以下みたいな感じ。 損失関数の大小を見るとして、例えば(w_{11})以外の重みを固定したとして(w_{11})をわずかに 増やしたときに損失関数の値がどれだけ大きくなるか。 損失関数の値はパラメータ(W)と入力(x)から決まるベクトルだけれども、それぞれ乱数と入力値が設定されている。 begin{eqnarray} W= begin{pmatrix} w_{11} & w_{12} & w_{13} \\ w_{21} & w_{22} & w_{23} end{pmatrix}, frac{partial L}{partial W}= begin{pmatrix} frac{partial L}{partial w_{11}} & frac{partial L}{partial w_{12}} & frac{partial L}{partial w_{13}} \\ frac{partial L}{partial w_{21}} & frac{partial L}{partial w_{22}} & frac{partial L}{partial w_{23}} end{pmatrix} end{eqnarray} 重み(W)が乱数で決まるネットワークがあるとする。このネットワークは入力と重みの積を出力 として返す。出力はSoftmaxを経由するとする。 ネットワークの出力と教師データのクロスエントロピー誤差を誤差として使う。 その前に、数値微分関数を多次元対応する。 普通、配列の次元が(n)個になると(n)重ループが必要になるけれども、 Numpy.nditer()を使うと(n)乗ループを1回のループにまとめることができる。 下のmulti_indexが((0,0),(0,1),(0,2),(1,0),(1,1),(1,2))みたいに イテレータが(n)次のタプルを返す。反復回数はタプルの要素数の直積。 Numpy配列にそのタプルでアクセスすることで晴れて全ての要素にアクセスできる。 def numerical_gradient_md(f, x): h = 1e-4 grad = np.zeros_like(x) it = np.nditer(x, flags=[\'multi_index\'], op_flags=[\'readwrite\']) while not it.finished: idx = it.multi_index tmp_val = x[idx] x[idx] = tmp_val + h fxh1 = f(x) # f(x+h) x[idx] = tmp_val - h fxh2 = f(x) # f(x-h) grad[idx] = (fxh1 - fxh2) / (2*h) x[idx] = tmp_val # 値を元に戻す it.iternext() return grad 初期値(x=(0.6,0.9))、教師データ(t=(0,0,1))をネットワークに入力する。 predict()は(1 times 3)を返す。 それをSoftmax()を通して、(t)とのクロスエントロピー誤差を求めたものが以下。 import numpy as np def cross_entropy_error(y, t): if y.ndim == 1: t = t.reshape(1, t.size) y = y.reshape(1,y.size) batch_size = y.shape[0] delta = 1e-7 return -np.sum( t * np.log( y + delta)) / batch_size def softmax(x): c = np.max(x) return np.exp(x-c) / np.sum(np.exp(x-c)) import sys, os sys.path.append(os.pardir) import numpy as np class simpleNet: def __init__(self): self.W = np.random.randn(2,3) def predict(self, x): return np.dot(x, self.W) def loss(self, x, t): z = self.predict(x) y = softmax(z) loss = cross_entropy_error(y, t) return loss net = simpleNet() x = np.array([0.6, 0.9]) p = net.predict(x) t = np.array([0, 0, 1]) net.loss(x, t) # 0.9463818740797788 このlossを(W)で微分したのが以下。 あえてパラメータ(W)を引数にとり損失関数の値を計算する(f(W))を定義することで、 数値微分が何と何の演算なのかをわかりやすくしている。 実際は(f(W))は(W)とは関係なく(x)と(t)だけから結果を返すけれども、 損失関数(f(W))を(W)で微分するという操作が自明になるようにコードを合わせている。 def f(W): return net.loss(x, t) dW = numerical_gradient_md(f, net.W) dW # array([[ 0.07627371, 0.49923236, -0.57550607], # [ 0.11441057, 0.74884853, -0.8632591 ]]) 結果の解釈 上記の(w),(W),(t)から(frac{partial L}{partial W})が求まった。 損失関数が何か複雑な形をしているという状況で、 (frac{partial L}{partial w_{11}})は(w_{11})がわずかに動いたときに損失関数の値が変化する量を表している。 それが(w_{11})から(w_{32})まで6個分存在する。 begin{eqnarray} frac{partial L}{partial W} = begin{pmatrix} frac{partial L}{partial w_{11}} & frac{partial L}{partial w_{21}} & frac{partial L}{partial w_{31}} \\ frac{partial L}{partial w_{12}} & frac{partial L}{partial w_{22}} & frac{partial L}{partial w_{32}} end{pmatrix} = begin{pmatrix} 0.07627371 & 0.49923236 & -0.57550607 \\ 0.11441057 & 0.74884853 & -0.8632591 end{pmatrix} end{eqnarray}

勾配の可視化

[mathjax] 2変数関数(f(x_0,x_1))を各変数で偏微分する。 地点((i,j))におけるベクトル((frac{partial f(x_0,j)}{partial x_0},frac{partial f(i,x_1)}{partial x_1}))を全地点で記録していき、ベクトル場を得る。 このベクトル場が勾配(gradient)。 (f(x_0,x_1)=x_0^2+x_1^2)について、(-4.0 le x_0 le 4.0)、(-4.0 le x_1 le 4.0)の範囲で、 勾配を求めてみる。また、勾配を可視化してみる。 まず、2変数関数(f(x_0,x_1))の偏微分係数を求める関数の定義。 ((3.0,3.0))の偏微分係数は((6.00..,6.00..))。 def numerical_gradient(f, x): h = 10e-4 grad = np.zeros_like(x) for idx in range(x.size): tmp_val = x[idx] x[idx] = tmp_val + h fxh1 = f(x) x[idx] = tmp_val - h fxh2 = f(x) grad[idx] = (fxh1 - fxh2) / 2*h x[idx] = tmp_val return grad def function2(x): return x[0]**2 + x[1]**2 p = np.array([3.0,3.0]) v = numerical_gradient(function2, p) v # array([6.e-06, 6.e-06]) (-4.0 le x_0 le 4.0)、(-4.0 le x_1 le 4.0)の範囲((0.5)刻み)で偏微分係数を求めて、 ベクトル場っぽく表示してみる。matplotlibのquiver()は便利。 各地点において関数の値を最も増やす方向が表示されている。 w_range = 4 dw = 0.5 w0 = np.arange(-w_range, w_range, dw) w1 = np.arange(-w_range, w_range, dw) wn = w0.shape[0] diff_w0 = np.zeros((len(w0), len(w1))) diff_w1 = np.zeros((len(w0), len(w1))) for i0 in range(wn): for i1 in range(wn): d = numerical_gradient(function2, np.array([ w0[i0], w1[i1] ])) diff_w0[i1, i0], diff_w1[i1, i0] = (d[0], d[1]) plt.xlabel(\'$x_0$\',fontsize=14) #x軸のラベル plt.ylabel(\'$x_1$\',fontsize=14) #y軸のラベル plt.xticks(range(-w_range,w_range+1,1)) #x軸に表示する値 plt.yticks(range(-w_range,w_range+1,1)) #y軸に表示する値 plt.quiver(w0, w1, diff_w0, diff_w1) plt.show() 値が大きい方向に矢印が向いている。例えば((-3.0,3.0))における偏微分係数は((-6.0,6.0))。 左上方向へのベクトル。 参考にしている本にはことわりが書いてあり、勾配にマイナスをつけたものを図にしている。 その場合、関数の値を最も減らす方向が表示されることになる。 各地点において、この勾配を参照することで、どちらに移動すれば関数の値を最も小さくできるかがわかる。

おっさんが数値微分を復習する

引き続き、ゼロDの写経を続ける。今回は、学生の頃が懐かしい懐ワード、数値微分。 全然Deepに入れないけれどおっさんの復習。解析的な微分と数値微分が一致するところを確認してみる。 昔と違うのは、PythonとJupyterNotebookで超絶簡単に実験できるし、 こうやってWordPressでLaTeXで記事を書いたりできる点。 [mathjax] まず、微分の基本的な考え方は以下の通り。高校数学の数3の範囲。 begin{eqnarray} frac{df(x)}{fx} = lim_{hrightarrow infty} frac{f(x+h)-f(x)}{h} end{eqnarray} 情報系学科に入って最初の方でEuler法とRunge-Kutta法を教わってコードを 書いたりレポート書いたりする。懐すぎる..。 または、基本情報の試験かなんかで、小さい値と小さい値どうしの計算で発生する問題が現れる。 ゼロDにはこの定義を少し改良した方法が載っている。へぇ。 begin{eqnarray} frac{df(x)}{fx} = lim_{hrightarrow infty} frac{f(x+h)-f(x-h)}{2h} end{eqnarray} 写経なので、がんばって数値微分を書いて動かしてみる。 簡単な2次関数(f(x))。 begin{eqnarray} f(x) &=& x^2 - 5x +3 \\ f\'(x) &=& 2x - 5 end{eqnarray} def numerical_diff(f, x): h = 10e-4 return (f(x+h) - f(x-h)) / (2*h) (f(x))と、(x=2.5)のところの(f\'(x))をmatplotlibで書いてみる。懐い... import matplotlib.pyplot as plt import numpy as np def f(x): return x**2 - 5*x + 3 x = np.arange(-10, 10, 0.1) y = f(x) dy = numerical_diff(f,x) plt.plot(x, y) x1 = -2.5 dy1 = numerical_diff(f, x1) y1 = f(x1) # y-y1 = dy1(x-x1) -> y = dy1(x-x1) + y1 j = lambda x: dy1 * (x-x1) + y1 plt.plot(x,j(x)) plt.xlabel(\'x\') plt.ylabel(\'y\') plt.grid() plt.show() 偏微分 2変数以上の関数の数値微分は以下の通り。片方を止める。 数値微分の方法は上記と同じものを使った。 begin{eqnarray} frac{partial f(x_0,x_1)}{partial x_0} &=& lim_{hrightarrow infty} frac{f(x_0 +h,x_1)-f(x_0-h,x_1)}{2h} \\ frac{partial f(x_0,x_1)}{partial x_1} &=& lim_{hrightarrow infty} frac{f(x_0,x_1+h)-f(x_0,x_1-h)}{2h} end{eqnarray} ((x_0,x_1)=(1,1))における(x_0)に対する偏微分(frac{partial f(x_0,x_1)}{x_0})、(x_1)に対する偏微分(frac{partial f(x_0,x_1)}{x_1})を求めてみる。 ちゃんと(frac{partial f(x_0,1.0)}{x_0}=2.00..)、(frac{partial f(1.0,x_1)}{x_1}=2.00..)になった。 import matplotlib.pyplot as plt import numpy as np from mpl_toolkits.mplot3d import Axes3D def f(x): return x[0]**2 + x[1]**2 X = np.meshgrid(np.arange(-5., 5., 0.2),np.arange(-5., 5., 0.2)) Z = f(X) fig = plt.figure(figsize=(6, 6)) axes = fig.add_subplot(111, projection=\'3d\') axes.plot_surface(X[0],X[1], Z) f0 = lambda x: x**2 + 1.0**2 f1 = lambda x: 1.0**2 + x**2 df0 = numerical_diff(f0, 1.0) df1 = numerical_diff(f1, 1.0) print(df0) # 2.0000000000000018 print(df1) # 2.0000000000000018 plt.show()