日々の学び

Snowflake External OAuthについての公式ドキュメントを読んでみた話

はじめに Enterpriseにおいて「お前は誰か?」を確認する手段は非常に多岐にわたる。 セキュリティと絡んで手段は拡大傾向にあり、新しい認証手段への追従が求められるケースは多い。 自前で認証情報を保有、管理し、セキュリティの保証を担保した手順を用意するのは不可能に近い。 現実的には認証情報の保有と管理、および認証手段を専用のプラットフォームに移譲させたい。 実際、認証の泥臭いプロセスはIdP(Identity Provider)が面倒を見てくれる。 SnowflakeはIdPと薄く関係して、IdPによる認証結果を使い回すことができる。 SnowflakeはIdPがどういったプロセスで認証したのかは一切関与しない。 認証後、「お前にこの権限を与えて良いか?」を実装しなければならない場合、 アプリ側に機能サポートがなければ、コードでそれを保証しなければならない。 Snowflakeは、ここをExternal OAuth統合として汎化しフルにサポートしている。 具体的には、SnowflakeはExternal OAuth統合として汎化していて、 OAuth2.0認可サーバと統合し、RBACとの紐付けまでを面倒みてくれる。 RBACの最小範囲であるスキーマより細かい粒度を区別する場合でなければ、 RBACだけで区別が完了することとなり、大幅な工数削減と品質安定化を達成できる。 昔Fitbit APIのOAuth2.0フローを実装した時から始まり、 過去に何件かWebアプリ開発で認証認可まわりの実装をしたと思う。 Webアプリの認証認可F/Wはかなり枯れていて、正直中身を知らなくても書けてしまう。 開発者人口が少ないSaaSサービスであるSnowflakeがブラックボックス化した 認証認可の仕組みを読み解くのは、Webアプリのそれとは次元の違う大変さがある。 (こと認証認可の文脈では安全性の保証がセットとなるため) Snowflake External OAuthについて厳密に調べる機会があったので、 生成AIを使わず100%自分の思考と言葉で記事を起こしていく。 [arst_toc tag=\"h4\"] 認証(AuthN) 認証、つまり、Authenticationは、「お前は誰か」を確認すること。 IdPにID/PWを登録しておきID/PWを入力したりMFAを通ることで「確かに〇〇さんだ」と確認すること。 単一要素認証(SFA)、多要素認証(MFA)、パスキー認証、FIDO2認証、他、多様な認証方式がある。 またシングルサインオン(SSO)、により組織を跨ぐ連携を行うことができる。 サービス間のSSO方式としてSAML2.0、API等のSSO方式としてOIDC2.0が広く使われている。 顧客管理のIdPによる認証を本IdPに引き継ぐIDフェデレーションにより組織間認証連携を実現できる。 認可(AuthZ) 一方認可、つまり、Authorizationは、「お前にこの権限を与えて良いか」を確認すること。 認可とは「誰がどのデータにどんなルールでアクセスして良いか」をコントロールする設計パターン。 「ルール作りの設計思想」と「システム間で権限をやり取りする技術規格」がごっちゃに扱われがち だが、レイヤが異なる2つの話を分けておくと少しわかりやすくなる。 「ルール作りの設計思想」 例えば以下のようにルールを定める。 ロールベースアクセス制御/Role Based Access Control ユーザ個人ではなく役割に対して権限を付与しユーザをそのロールに所属させる方式。管理者権限のユーザには作成・削除を与え、一般権限のユーザには閲覧のみを与えるなど、一般的な認可方式。SnowflakeのロールモデルはまさにRBACに基づく。 属性ベースアクセス制御/Attribute Based Access Control ロールだけでなくユーザの所属、勤務地、アクセスする時間帯、デバイスの種類など、複数の属性(コンテキスト)を組み合わせて動的に認可を判断する方式。 「システム間で権限をやり取りする技術規格」 例えば以下のようにルールを実現する技術規格を表す。 OAuth2.0 現在のWebで最も普及している「トークンベース」の認可フレームワーク。認可サーバーが発行した「アクセストークン(時限式のカードキー)」をアプリが提示し、リソースサーバー(Snowflakeなど)がそれを検証してアクセスを許可する。「権限の証明書」としてJWT(JSON Web Token)が実際にやり取りされる。JWTは、SON形式のデータを暗号論的に署名したもので、中身に「ユーザー名」「有効期限」、「付与されたロール(権限スコープ)」などが書き込まれている。 ケルベロス認証・認可 (Kerberos) 主に一昔前からの 社内ネットワーク(Active Directory)環境などで広く使われている方式。チケット」と呼ばれる暗号化されたデータをやり取りすることで一度のログインで社内のファイルサーバーやプリンタなどの利用権限(認可)をシームレスに得る。 あああ External OAuth External OAuthは顧客のOAuth2.0認可サーバを統合してシームレスなSSOを実現する。 認証プロセスはサービス側が気にするものではなく、本機能は認可の統合であることに注意すること。 なお公式(外部 OAuth の概要)は間違いなく認証・認可と言う言葉をごっちゃにしている。 OAuth2.0はRFC6749でThe OAuth2.0 Authorization Frameworkと定義されている。 受け渡しされるトークンはOIDCのような認証トークンではなく、OAuth2.0の認可トークンである。 外部OAuthという(認可の)仕組みをSnowflakeに設定しておくことで、 「外部のIdPが認証したという証明書」をSnowflakeが安全に受け取ってデータアクセス認可する仕組みだ。 公式(外部 OAuth の概要)によると、以下に公式に対応している。 公式にない場合は、外部 OAuth 用のカスタム認証サーバーを構成するで構成できる。 なお「公式」でないからといって「非対応」ではない。「公式」になくても汎用OAuth2.0用のカスタム認証サーバーとして構成できる。 Okta - 外部OAuth用Oktaの構成 Auth0はOktaファミリーだが↑では構成できない。カスタム認証サーバーとして構成が必要 Microsoft EntraID - 外部 OAuth 用 Microsoft Entra ID の構成 Ping Identity PingFederate - 外部 OAuth 用 Microsoft Entra ID の構成 Microsoft PowerBI - Power BI SSO からSnowflakeへ 公式にはExternal OAuthのメリットとして以下が挙げられている。 トークンの発行を認証サーバーに委任し、発行されたトークンの管理に集中できるようになる。 ログイン時のセキュリティルール(MFAやIP制限、承認フローなど)を、Prj IdP側に統合できる。 ユーザがその認証と許可に関する厳しいルール(テスト)をクリアしない限り、IdPはトークンを発行しない。 怪しいユーザはSnowflakeの入り口にすら辿り着けず、データは完璧に守られる。 認証をIdPに持たせることでSnowflake側から認証情報を除去できるためセキュアになる。 一見して認証のことしか書かれていないようだが、implicitに認可が書かれている。 Snowflakeは認可をIdPに完全に移譲し、認証とセットで認可が行われたトークンを確認するだけ、 ということは、Snowflake側に認可コードを一切書くことなしに認可を実現することと同義。 External OAuthの認証部分の基本フロー 公式に基本フローの図が貼ってある。ステップ1だけ構成時にのみ行う。他は都度実行される。 最初にセキュリティ統合の構成と、アプリ内の実装が開発者側の責務となる。 ベスプラに従ってルールから逸脱しないように構成することで、後はSaaSサービス間の自動連携となる。 外部OAuth認証サーバとSnowflakeのセキュリティ統合を構成し信頼性を確立する ユーザはアプリを介してSnowflakeにアクセスしようとする。アプリはユーザを確認しようとする 認証サーバはOAuthトークンをアプリに返す SnowflakeドライバはOAuthトークンを使用して接続文字列をSnowflakeに渡す SnowflakeはOAuthトークンを検証する Snowflakeはユーザ検索を実行する Snowflakeはユーザのロールに基づいてセッションをインスタンス化する External OAuthの認可部分、スコープ いきなり「スコープ」というワードが出てくるが、これ、JWTの\"scope\"キー/バリューのこと。 OAuth2.0においてJWTで認可範囲を設定するのだ、という理解と記憶がなければ読めない。 JWTは以下のような構成となっておりscopeを格納する場所がある。 認可サーバ側で何らかの許可処理の結果、ユーザのスコープが決まり、Snowflakeに送られる。 このトークンがSnowflakeに届くと、Snowflakeはscopeキーのバリューを読み取り、 「このユーザにはST_USER_ROLEというロール(権限)を適用してセッションを始めるべき」と判断する。 { \"iss\": \"https://your-project-idp.auth0.com/\", \"sub\": \"user_12345\", \"email\": \"user@client.com\", \"exp\": 1719100000, \"scope\": \"session:role:ST_USER_ROLE\" <-- 🌟これが「スコープ」 } Okta, PingFederate, カスタムの場合は以下のパターンを使用しなければならない。 スコープ 説明 session:role:<custom_role> Snowflakeのカスタムロールにマップする。例えばsession:role:ST_USER_ROLEで、ST_USER_ROLEにマップ session:role:public Snowflakeの PUBLIC ロールにマップ session:role-any 外部OAuthサーバでのSnowflakeロール管理を行わない場合これを渡す。特定のロールを固定せず、そのユーザに付与されているロールであれば、ログイン後に自由に切り替えて(USE ROLEして)使って良い、という少し緩めの認可 なお、以下のビルトインロールはデフォルトではブロックされる。 ACCOUNTADMIN GLOBALORGADMIN ORGADMIN SECURITYADMIN Snowflake OAuthは、セッション中のロールのセカンダリロールへの切り替えをサポートしていないが、 External OAuthでのセカンダリロールの使用はサポートしている。 External OAuth特有のセキュリティの抜け穴と対策 Snowflakeにおいて、アカウントレベルでネットワークポリシーによりIP制限をかけていたとしても、 External OAuthと合わせて構成するSecurity Integrationを経由してログインしてくる場合、 そのユーザ個人のIP制限が無視されてしまう、という仕様がある。 つまり、IdP側のIP制限が破られたり、トークンが盗まれたりした場合、 攻撃者はどこからでもSnowflakeのデータにアクセスできてしまう状態になる。 Snowflakeは、External OAuth自体にもネットワークポリシーを直接紐づけることを推奨している。 具体的にはSecurity Integrationにネットワークポリシーを直接紐づける。 これによりIdPから届いたトークンであっても、ネットワークポリシーで許可されたIPアドレス以外からの リクエストであれば、Snowflakeはセッションを開始しない。 これはIdPフェデレーション等で複雑化したユーザ組織の通信経路を全て把握する必要性を言っている。 こういうの、デフォルトで安全側に振って欲しいなとは思う。 カスタム認証サーバーの構成・トークンペイロード要件 カスタム認証サーバーがSnowflakeに送信するアクセストークンには、下表が含まれている必要がある。 クレーム 説明 scp Snowflake のカスタムロールを指定する文字列が含まれていること。値として session:role:ST_USER_ROLE のような Snowflake 指定の形式の文字列を、配列またはスペース区切りの文字列で必ず埋め込まなければならない。 scope 同上。IdPプロダクトによりscpかscopeのどちらかを入れる。 aud Snowflake アカウントの完全な URL(https://.snowflakecomputing.com)が含まれている必要がある。 exp 有効時間。トークンの有効期限が UNIX タイムスタンプ(エポック秒)で刻まれている必要がある。Snowflake はトークンを受け取った瞬間の時刻とこの exp を比較します。有効期限が過去の時刻になっている(期限切れ)場合は、その時点で認可を即座に拒否する。 iss 発行者。アクセストークンを発行したプリンシパルを文字列 URI として識別。つまりトークンを発行した IdPのアイデンティティ(例: https://your-project-idp.auth0.com/)。最後のスラッシュ(/)の有無まで1文字違わず一致させる必要がある。Snowflake 側の EXTERNAL_OAUTH_ISSUER で指定した文字列と完全に一致する必要がある。 iat 発行時刻。必須。JWT が発行された時刻を識別 カスタム認証サーバーの構成・セキュリティ統合の作成 External OAuth を実現する Snowflakeのリソースの実体。 カスタム認証サーバからのアクセストークンと安全に通信して検証し、アクセストークンに 関連付けられたユーザーロールに基づいてSnowflakeへのアクセスをユーザに提供する。 create security integration external_oauth_custom type = external_oauth enabled = true external_oauth_type = custom external_oauth_issuer = \'\' external_oauth_rsa_public_key = \'\' external_oauth_audience_list = (\'\', \'\') external_oauth_token_user_mapping_claim = \'upn\' external_oauth_snowflake_user_mapping_attribute = \'login_name\'; それぞれの内容は下表の通り。 パラメータ 説明 EXTERNAL_OAUTH_ISSUER 外部認証サーバー(IdP)を一意に識別するURL(発行元URL)を指定する。IdPから発行されるアクセストークン(JWT)の iss クレームの値と完全に一致する必要がある。 EXTERNAL_OAUTH_JWS_KEYS_URL 外部認証サーバーが公開している、デジタル署名の検証に必要な公開鍵(JWKS)が配置されたURLを指定する。SnowflakeはこのURLにアクセスしてトークンの妥当性を検証する。 EXTERNAL_OAUTH_TOKEN_USER_MAPPING_CLAIM 外部認証サーバーが発行するアクセストークン(JWT)の中で、ユーザーの識別情報(メールアドレスやユーザーIDなど)が格納されている「キー(クレーム名)」を指定する。 EXTERNAL_OAUTH_SNOWFLAKE_USER_MAPPING_ATTRIBUTE トークンから抽出したユーザー識別情報を、Snowflake側の USER オブジェクトのどの属性(EMAIL_ADDRESS または LOGIN_NAME)と一致させるかを指定する。 カスタム認証サーバーの構成・テスト 公式では、最短パスで構成を検証するため、curl で HTTP Post を送る手順が書かれている。 IdP側にテストユーザを作成しておく。テストユーザはパスワードを持つ必要がある Snowflake側にも、上記と同じメールアドレス(または識別子)を持つ USER オブジェクトを事前に作っておく。login_name, または emailでマッピングする IdP側の画面でこのテスト用のアカウントを作成し、専用のClientID, ClinetSecretを取得する 次のように、 OAuth 2.0クライアントがカスタムトークンエンドポイントに POST リクエストすることを許可 OAuth 2.0の用語でいう grant_type = password(Resource Owner Password Credentials Grant)方式を使うこと。すなわち「リソース所有者に設定された付与タイプ」であり、アプリ画面を介さず、ユーザーのID/PWを直接リクエストに含めてトークンを即時発行してもらう、テスト専用の最短ルートを構築する。 準備で用意したclientID と clientSecretをHTTP Basic認証ヘッダーに含めること リクエストのBody(送信データ)には、FORM形式(application/x-www-form-urlencoded)で、テストユーザーのID/PWと、Snowflakeに渡したいスコープを指定すること curl -X POST -H \"Content-Type: application/x-www-form-urlencoded;charset=UTF-8\" --user : --data-urlencode \"username=\" --data-urlencode \"password=\" --data-urlencode \"grant_type=password\" --data-urlencode \"scope=session:role:analyst\" 公式対応認証サーバーと非公式(カスタム対応)の違い 公式対応認証サーバーと、非公式(カスタム対応)の違いをまとめてみる。 ケース1:IdPの「署名用公開鍵」がローテーション(変更)されたとき JWT(トークン)が偽造されていないかを証明するための「公開鍵」は、 セキュリティ担保のために数ヶ月ごとに自動で新しいものにローテーションするのが一般的。 公式対応の場合、SnowflakeがOkta側の鍵更新スケジュールや新しい公開鍵の 取得先をあらかじめ知っているため、Snowflake側が自動で追従する。 開発者は何のアクションも起こす必要はなく、システムは止まらない。 カスタム、つまり非公式の場合であっても基本的には指定したURL (.well-known/jwks.json)を見に行ってくれるので自動追従するが、 もしIdP側のメジャーアップデート等で「公開鍵を配置するURLの仕様そのもの」 が変わった場合は、Snowflakeの設定パラメータ(EXTERNAL_OAUTH_JWS_KEYS_URL) を開発者が手動で新しいURLに書き換えるまで、認証・認可がすべてエラーになってシステムが停止する。 IdP側のセキュリティ仕様やエンドポイントの仕様が変更されたとき 近年、サイバー攻撃の高度化に伴い、IdP側(OktaやMicrosoftなど)がトークンの発行ルールや、 検証用APIの仕様(プロトコル)をより安全なものへ強制アップデートすることがある。 SnowflakeはOktaやMicrosoftと強固な技術パートナーシップを結んでいるため、 IdP側の仕様変更がリリースされる前に、Snowflake側の「特急レーン(専用プログラム)」を 事前にアップデートして追従させる。そのため、開発者がコードや設定を修正することなく、 シームレスに新しいセキュリティ基準へ移行できる。 カスタム、つまり非マネージドの場合、Snowflakeは「汎用的なOAuth 2.0の標準規格(RFC)」に 準拠していることしか保証しない。そのため、IdP側が独自のセキュリティ拡張を行ったり、 標準規格の解釈を変更したりした場合、トークンのペイロード構造(キー・バリュー)が変わり、 Snowflakeがトークンを解読できなくなるリスクがある。 この場合、開発者がIdP側の設定を手動で修正して追従する必要がある。 まとめ SnowflakeにおけるExternal OAuth統合の仕組みを「認証」と「認可」のレイヤを分離して読んでみた。 認証・認可を完全にIdPに移譲し、Snowflakeアプリケーション内で一切の認可コードを書かずに済む。 数あるIdPのうち、いくつかについてはSnowflakeが公式対応している。 公式IdP構成はテクノロジーパートナーシップに基づき、Snowflakeのマネージド構成の一部として、 Snowflake側がIdP側の変更に自動追従する可能性が高い。結果としてダウンタイムの発生を回避できる。 公式対応IdPでなくても、OAuth2.0 RFC準拠の認証サーバとしてカスタム連携することができるが、 SnowflakeがIdP側の変更に自動追従する性質ではなく、運用者・開発者がIdP側の変更に適用する必要がある。

Streamlit in Snowflakeの開発環境を整備して初めてのアプリケーションを実装した話

はじめに Streamlit in Snowflakeの開発を開始するには、Snowflakeアカウント、適切なIDE設定、ローカル開発環境の構築といった複数のステップが必要。この記事では、前提条件の確認、アプリケーション実装といった標準的なセットアップ手順をまとめる。 前提条件と必須の準備作業 Streamlit in Snowflakeの開発を始める前に、複数の前提条件を満たす必要がある。 前提条件の詳細: Snowflakeアカウントへのアクセス - 有効なSnowflakeアカウントと、CREATE APPLICATION PACKAGE 権限を持つロールが必須である。ロール設計を行い、この権限を付与したカスタムロールを使用する Pythonの開発環境 - Python 3.8以上がインストールされており、pipやcondaといったパッケージ管理ツールが動作する状態が前提である。Streamlit in SnowflakeはPython 3.10以上での動作を推奨している Snowparkライブラリ - ローカル開発環境にsnowpark、snowflake-snowpark-python といったパッケージをインストール済みであることが必須 Snowflake CLIツール - Snowflake提供の公式CLIツール(snow)をシステムに導入する必要がある。このツールを通じてSnowflakeを対話的に操作する 認証情報の管理 - ローカル開発では、Snowflakeへの接続情報をコードに埋め込まないことが重要である。環境変数、~/.snowsql/config ファイル、またはキーペア認証を使用して管理する。本番環境へのデプロイ時には、AWS Secrets Manager、Azure Key Vault、HashiCorp Vaultといった外部認証サービスの利用が推奨される IDE統合と開発環境の構築 Visual Studio Codeの統合により、ローカル開発フェーズ全体をエディタ内で完結させられる。Pythonコード編集、ローカルテスト実行(`streamlit run` での動作確認)、Snowflakeへのクエリ検証、デプロイまでの準備がVS Code内で実現される。一方、デプロイ後の本番環境ではSnowflakeウェブコンソール内でアプリケーションが動作する。公式のSnowflake拡張機能を利用することで、Snowflakeへの接続管理、SQL文の実行、デバッグが統一されたインターフェース内で実現される。 IDE統合のセットアップ手順: Visual Studio CodeにSnowflake拡張機能をインストールする。拡張機能マーケットプレイスから「Snowflake」を検索し、公式のSnowflake Inc.提供版をインストールする 拡張機能をインストール後、接続設定ファイル(通常は~/.snowsql/config)を確認し、接続情報が正確に記述されていることを検証する コマンドパレット(Ctrl+Shift+P または Cmd+Shift+P)からSnowflakeの接続を確立する。接続テストが成功することで、Snowflakeへの通信が確認される SQLエディタを開き、簡単なクエリ(例:SELECT CURRENT_USER())を実行してSnowflakeとの疎通確認を行う IDE統合によって、ローカルでのPythonコード編集とSnowflakeのデータ参照が同一画面で実現され、開発の効率が劇的に向上する。 GitHub Codespacesでの開発も可能: ローカルマシンの環境管理を避けたい場合、GitHub Codespacesを使用した開発も実用的に使用できる。CodespacesにおいてSnowflake拡張機能、Snowflake CLI、Streamlit CLIがサポートされている。`streamlit run` コマンドで起動したアプリケーションはCodespaces内で自動的にポート転送され、ブラウザプレビューが利用可能である。環境構築を.devcontainer/devcontainer.jsonで定義すれば、チーム全体で統一された開発環境を即座に立ち上げられる。 ローカル開発環境のセットアップ ローカルマシンでStreamlit in Snowflakeアプリケーションを開発するには、複数のPythonパッケージが必要である。仮想環境の構築を通じて、プロジェクト固有の依存関係を隔離することが実務上の標準である。 開発環境の選択肢について: Docker環境を用いてローカル開発環境を構築することも技術的には可能だが、本番環境がSnowflake内の管理コンテナ上で実行されるため、ローカルのDocker環境と本番環境の構成を統一することはできない。開発環境をDockerで隔離したとしても、本番デプロイ時には別の実行環境へ移行するため、Docker化による環境共通化のメリットは限定的である。仮想環境による環境隔離で十分であり、Docker導入による複雑さの増加は費用対効果が低い。チーム規模が大きく、開発環境の統一が重要な場合のみDocker化を検討する価値がある。 # 仮想環境を作成 python3 -m venv streamlit_env # 仮想環境を有効化(macOS/Linux) source streamlit_env/bin/activate # 仮想環境を有効化(Windows) streamlit_envScriptsactivate # 必須パッケージをインストール pip install streamlit pip install snowflake-snowpark-python pip install snowflake-cli-labs # requirements.txtを作成し、プロジェクトの依存関係を記録 pip freeze > requirements.txt requirements.txtファイルの内容例: streamlit==1.28.0 snowflake-snowpark-python==1.10.0 snowflake-cli-labs==2.0.0 pandas==2.0.0 仮想環境の隔離により、異なるプロジェクト間での依存パッケージの競合を回避できる。これは本番環境へのデプロイ時にも重要であり、requirements.txtはアプリケーションと共にSnowflakeにアップロードされる。 ローカルでの初期テストと動作確認 ローカル開発環境が構築されたら、Streamlitが正常に動作するか確認する必要がある。最小限のアプリケーションコードでSnowflakeへの接続テストを行う。 # app.py import streamlit as st from snowflake.snowpark.context import get_active_session st.title(\"Streamlit in Snowflake - 初期テスト\") try: session = get_active_session() user = session.sql(\"SELECT CURRENT_USER()\").collect()[0][0] st.success(f\"Snowflakeへの接続成功。現在のユーザー: {user}\") except Exception as e: st.error(f\"接続エラー: {str(e)}\") # 簡単なデータクエリ if st.checkbox(\"テーブル一覧を表示\"): try: databases = session.sql(\"SHOW DATABASES\").collect() st.write(f\"利用可能なデータベース数: {len(databases)}\") except Exception as e: st.error(f\"クエリ実行エラー: {str(e)}\") このテストアプリケーションを実行する場合、ローカルではStreamlit CLIでの実行が可能である。 streamlit run app.py ただし、ローカルでの実行にはSnowflakeへの認証情報が必要である。環境変数で接続情報を提供する方法が一般的である。 必須の環境変数: SNOWFLAKE_ACCOUNT - Snowflakeアカウント識別子(例:xy12345.us-east-1) SNOWFLAKE_USER - ログインユーザー名 SNOWFLAKE_PASSWORD - ユーザーのパスワード(パスワード認証の場合) SNOWFLAKE_WAREHOUSE - クエリ実行用のウェアハウス名 SNOWFLAKE_DATABASE - デフォルトのデータベース名 SNOWFLAKE_SCHEMA - デフォルトのスキーマ名 キーペア認証を用いる場合は、SNOWFLAKE_PASSWORD の代わりに SNOWFLAKE_PRIVATE_KEY_PATH と SNOWFLAKE_PRIVATE_KEY_PASSPHRASE を設定する。環境変数の設定例: export SNOWFLAKE_ACCOUNT=\"xy12345.us-east-1\" export SNOWFLAKE_USER=\"developer_user\" export SNOWFLAKE_PASSWORD=\"your_secure_password\" export SNOWFLAKE_WAREHOUSE=\"dev_warehouse\" export SNOWFLAKE_DATABASE=\"analytics_db\" export SNOWFLAKE_SCHEMA=\"dev_schema\" # その後、streamlit run app.py を実行 streamlit run app.py 別の方法として、~/.snowsql/config ファイルに接続情報を記述し、Snowpark が自動的に読み込む設定も可能である。 初めてのアプリケーション実装 前提条件とローカル環境が整備されたら、Snowflakeアカウント内に実際のアプリケーションを作成する準備が整う。最小限のアプリケーションを実装し、Snowflakeへのデプロイが正常に機能することを確認する。 最小限のアプリケーション実装: # app.py import streamlit as st st.title(\"初めてのアプリケーション\") st.write(\"Hello World.\") このシンプルな実装で、Streamlit in Snowflakeへのデプロイが正常に完了し、本番環境でUIが表示されることを確認できる。Snowflake側でアプリケーション作成用のステージとメタデータを準備する必要がある。 -- Snowflakeで実行:アプリケーション用のステージを作成 CREATE STAGE IF NOT EXISTS app_stage; -- アプリケーション設定ファイルを準備 -- manifest.ymlを作成してステージにアップロード Snowflake CLIを使用して、ローカルのアプリケーションコードをSnowflakeにデプロイする。 # Snowflake CLIでプロジェクトを初期化 snow project init # ローカルの開発コードをSnowflakeにデプロイ snow app deploy デプロイ後の検証: Snowflakeウェブコンソールにログインし、アプリケーション一覧から新規作成したアプリケーションが表示されていることを確認する アプリケーションをクリックして開き、UIが正常に表示され、Snowflakeへのクエリが実行される状態を確認する デプロイ後の最初の実行はコールドスタートのため、数秒の遅延が発生するが、以後のアクセスは高速化される 以下はエラーハンドリングを組み込んだ実装例。Snowflake環境において発生したネットワークエラー、タイムアウト、権限不足といった例外を補足し表示してみた。 # エラーハンドリングを含む実装例 import streamlit as st from snowflake.snowpark.context import get_active_session st.set_page_config(page_title=\"データダッシュボード\", layout=\"wide\") try: session = get_active_session() st.header(\"データ参照アプリケーション\") # ユーザー情報の取得 current_user = session.sql(\"SELECT CURRENT_USER()\").collect()[0][0] st.sidebar.write(f\"ユーザー: {current_user}\") # データベース選択 db_list = session.sql(\"SHOW DATABASES\").collect() databases = [row[1] for row in db_list] selected_db = st.selectbox(\"データベースを選択\", databases) st.success(f\"接続完了\") except Exception as e: st.error(f\"エラーが発生しました: {type(e).__name__}\") st.info(\"管理者に連絡してください\") まとめ 本記事では、前提条件の確認、IDE統合(Visual Studio Code、Snowflake拡張機能のセットアップ)、GitHub Codespacesでの開発環境構築の検討、ローカル開発環境の準備(仮想環境、パッケージインストール、Docker化の考慮を含む)について言及した。また、ローカルテスト実行時の環境変数設定方法についてまとめた。最後に、最小限のアプリケーションを実装し、Snowflakeへデプロイ後、動作確認を行なった。

Streamlit in Snowflakeにおける分離コンテナ環境とセッション管理の仕組みを理解した話

はじめに Streamlit in Snowflakeで本番環境のアプリケーションを構築する際、実行環境とセッション管理の仕組みを理解することは必須である。標準的なStreamlitとは異なり、Snowflake統合版はSnowflakeの管理するコンテナ内で実行され、アプリケーションのライフサイクル、パフォーマンス特性、状態管理が大きく異なる。本稿では、この実行モデルの核心部分に焦点を当て、本番環境での実装判断に必要な知識を整理する。標準的なStreamlitの開発経験がある技術者であっても、Snowflake統合版の独特なアーキテクチャを把握することで、より堅牢で効率的なアプリケーション設計が可能となる。 Snowflakeの管理するコンテナ内での実行 Streamlit in Snowflakeのアプリケーションは、Snowflakeのアカウント内で管理された隔離されたコンテナプロセス上で実行される。ローカルマシンのPythonプロセスのように直接制御することはなく、Snowflakeのインフラストラクチャが実行環境全体を統制する。 実行環境の核心的な特性: 各アプリケーションはSnowflakeのアカウント領域内で独立した仮想環境として分離されており、他のテナントや他のアプリケーションとの干渉を受けない アプリケーションの起動、実行、終了はSnowflakeの制御下にあり、ユーザーのアクセスパターンに応じた動的なスケーリングが自動的に実行される Pythonランタイムは事前にコンテナ内にプリロードされており、ユーザーがアプリケーションにアクセスした時点でコードの実行が即座に開始される コンテナはステートレスな設計であり、複数のユーザーセッション間でローカルのファイルシステム上の状態は保持されない メモリ、CPU、ネットワーク帯域幅などのリソースは制限されており、無限に大規模なデータセットをメモリに展開することはできない この設計により、スケーラビリティと管理負荷の削減が実現される。開発者はインフラストラクチャの保守運用から解放され、アプリケーション本体の開発に集中できる。一方で、アプリケーション開発者は「各セッションは独立している」「ローカル状態は永続しない」という前提でコーディングする必要があり、この認識がなければ本番環境で予期しない動作が発生する可能性がある。 ExecutionContextとSnowflakeのセッション情報へのアクセス Streamlit in Snowflakeで最も重要な概念がExecutionContextである。これはSnowflakeのセッション情報とアプリケーション実行の状態を統合したオブジェクトであり、アプリケーションコード内から直接アクセスすることが可能である。 ExecutionContextを通じて、認証済みユーザーの識別子、割り当てられたウェアハウス、セッションのロール情報、現在のデータベースとスキーマといった情報が取得できる。これらの情報はSnowflakeの権限管理体系と一体化しており、アプリケーションが実行するすべてのSQL文はこのコンテキストの権限に基づいて検証される。 from snowflake.snowpark.context import get_active_session session = get_active_session() # 現在のユーザーを取得 current_user = session.sql(\"SELECT CURRENT_USER()\").collect()[0][0] # 割り当てられたウェアハウスを確認 current_warehouse = session.sql(\"SELECT CURRENT_WAREHOUSE()\").collect()[0][0] # 現在のロール情報 current_role = session.sql(\"SELECT CURRENT_ROLE()\").collect()[0][0] # アプリケーション領域のスキーマを取得 current_schema = session.sql(\"SELECT CURRENT_SCHEMA()\").collect()[0][0] ExecutionContextから取得可能な情報の実用的な用途: 認証済みユーザーID:このユーザーが属するテナント、部門、権限レベルをデータベースから検索し、表示内容を動的に制御する基盤となる 割り当てられたウェアハウス:クエリの実行リソースがどのウェアハウスに割り当てられるかを把握し、多くの重い処理が実行される時間帯を避けるといった最適化判断に活用される セッションのロール情報:ロールベースアクセス制御の実装において、現在のユーザーが実行可能な操作を制限する際に利用される 現在のデータベースおよびスキーマ:アプリケーションが参照するテーブルやストアドプロシージャの名前空間を把握し、正確なクエリを構築する際に用いられる ExecutionContextはSnowflakeの行レベルセキュリティ(RLS)および動的データマスキング(DDM)と組み合わせることで、マルチテナント環境でのデータ分離を実装できる。ユーザーが属するテナント情報をExecutionContextから抽出し、その情報をSQLクエリに動的にフィルタリング条件として付与するパターンが一般的である。 セッション状態の管理と永続化戦略 Streamlit in Snowflakeでは、標準的なStreamlitの`st.session_state`メカニズムが使用される。ただし、その永続性と可用性については、通常のStreamlitと異なる考慮が必要である。 セッション状態の保持期間と動作: ユーザーがブラウザを閉じるまで、またはセッションのタイムアウト(デフォルト約60分)が発生するまで、`st.session_state`に格納されたPythonオブジェクトは保持される セッション終了後、メモリ上の状態は完全に消失し、その後のユーザーアクセスでは初期化された状態から再出発する 複数のユーザーセッションが並行して実行される場合、各セッションのメモリ空間は完全に独立しており、相互干渉は発生しない 分散環境ではコンテナのリバランシングが発生する可能性があり、メモリ内状態への依存度が高いと予期しない状態喪失が発生する危険性がある セッション状態を効果的に使用するパターンとしては、ユーザーの入力フォーム状態、フィルタ条件、ページネーション状態、一時的なキャッシュなど、セッション内での短期的な状態に限定することが推奨される。 import streamlit as st from snowflake.snowpark.context import get_active_session session = get_active_session() # セッション状態で一時的なUIフィルタを保持 if \'selected_date_range\' not in st.session_state: st.session_state.selected_date_range = (None, None) if \'filter_status\' not in st.session_state: st.session_state.filter_status = \'all\' # ユーザーインタラクションでセッション状態を更新 date_range = st.date_input(\"期間を選択\", st.session_state.selected_date_range) st.session_state.selected_date_range = date_range # 永続化が必要な設定はSnowflakeテーブルに明示的に保存 if st.button(\'設定を保存\'): current_user = session.sql(\"SELECT CURRENT_USER()\").collect()[0][0] session.sql(f\"\"\" UPDATE user_preferences SET ui_settings = parse_json(?) WHERE user_id = ? \"\"\", params=[str(st.session_state.filter_status), current_user]).collect() st.success(\"設定を保存しました\") セッション終了後も保持する必要があるデータ(ユーザー設定、保存された検索条件、レポート結果など)は、Snowflakeのテーブルに明示的に書き込む必要がある。この明確な分離により、アプリケーションの動作が予測可能になり、バグの温床となる隠れた状態共有が回避される。 パフォーマンス特性とコールドスタート最適化 Streamlit in Snowflakeのパフォーマンス特性は、コンテナの起動時間、リソースの割り当て、クエリの実行効率によって大きく影響を受ける。 パフォーマンスに関わる重要な指標: 初期化時間:ユーザーがアプリケーションに初めてアクセスする際、Snowflakeがコンテナを起動し、Pythonランタイムを初期化するまでに3秒から10秒程度要する場合がある。これをコールドスタートと呼ぶ SQLクエリ実行時間:SQLクエリの実行時間は主にSnowflakeのクエリプランニングと分散処理の時間に依存し、ネットワークレイテンシは相対的に最小限である メモリ制限:各コンテナプロセスのメモリは制限されており、gigabyte単位の大規模なデータセットを一度にメモリにロードすることは技術的に不可能である リソース競合:同一のウェアハウス上で複数のアプリケーションやクエリが並行実行される場合、リソース争奪による性能低下が発生する可能性がある キャッシュ効果:頻繁にアクセスされるテーブルやクエリ結果はSnowflakeの内部キャッシュに保持され、2度目以降のアクセスは高速化される 本番環境ではコールドスタート対策として、アプリケーション初期化時の処理を最小化し、必要なデータは遅延読み込みするパターンが採用される。また、複雑な分析処理やデータ変換はSnowflakeのストアドプロシージャに委譲し、アプリケーション層では結果の表示と対話的なUIの提供に専念することが効率的である。 ウェアハウスとリソース割り当ての考慮 Streamlit in Snowflakeのアプリケーションが実行するすべてのSQL文は、指定されたウェアハウスのコンピュート能力を消費する。ウェアハウスの選択は、クエリの実行速度、同時実行可能なセッション数、運用コストに大きな影響を与える。 ウェアハウス選択の実務的考慮: 小規模なウェアハウス(XSMALL、SMALL)はコストが低く、軽量なクエリや低アクセス頻度のアプリケーションに適しており、一方で大量のユーザーからの並行アクセスには不向きである 大規模なウェアハウス(LARGE、XLARGE以上)は並行クエリ処理の能力が高く、多くのユーザーからの同時アクセスに対応できるが、アイドル状態であってもコストが発生する オートスケール機能を有効にすることで、負荷に応じたウェアハウスの自動拡張が可能になり、ピーク時の対応と非ピーク時のコスト削減を両立させられる 複数のアプリケーションが同一ウェアハウスを共有する場合、負荷分散戦略を立案し、一つのアプリケーションの過度なリソース消費が他のアプリケーションに悪影響を与えないようにする必要がある リソースの効率的な利用と高いパフォーマンスの両立は、ウェアハウスのサイズ選択、クエリの最適化、適切なキャッシング戦略によって初めて実現される。 まとめ Streamlit in Snowflakeは、Snowflakeの管理するマネージドコンテナ環境内で動作し、ExecutionContextを通じてSnowflake側のセッション情報に直接アクセスできるアーキテクチャである。セッション状態は短期的なUI状態の保持に限定し、永続化が必要なデータはSnowflakeテーブルに委譲することが設計の原則である。また、コールドスタートやリソース共有の課題を念頭に置いて、初期化処理の最小化とクエリの最適化によるパフォーマンス改善アプローチを検討する必要がある。これらの理解があれば、本番環境での実装判断が格段に容易になり、堅牢で拡張性の高いアプリケーション設計が可能となる。

やりなおし統計

Fellegi-Sunterモデルに基づく確率的名寄せパッケージ Splinkを試してみる

[mathjax] Record Linkage、Entity Recognitionなど、いわゆる「名寄せ」問題には、割とさまざまな解決策が 提案されている。その1つに確率論的な数学的背景を持つFellegi-Sunterモデルを実装したSplinkがある。 数学的背景について咀嚼して説明してくれているサイトはあまりなく、Splinkの公式ドキュメントが唯一 に近いソースだろうと思う。 今回は5万人の人物データを使って重複を除去するSplinkの主要な機能のサンプルを追ってみた。 数式を読んだところで実りは薄い、ということもあり、なるべく数式を読まずに読み進めたが、 モデルのパラメタ、予測結果を解釈する際にどうしても数式を読まないとわからない箇所があった。 特にベイズ統計について理解がかなり浅くて、正直なところ理解しきれてはいない。 Splinkの公式ドキュメントを読み進めていった際に気付いたメモを書いてみることにした。 [arst_toc tag=\"h4\"] 環境構築・準備 環境構築の実装部分は主な学習要素ではないのでClaude Codeでサクッと作ってみた。 まずはPlan modeで計画を立ててもらう。これぐらいだとほぼ言い値で使える。 DuckDB、Splinkを使う最小構成の環境を作ります。環境はDocker composeで構築します。計画を立ててください。 ⏺ DuckDB、Splinkを使うDocker Compose環境の計画を完了しました。 計画概要 1. 設計文書: .tmp/design.mdで要件、アーキテクチャ、技術仕様を定義 2. タスク管理: .tmp/task.mdで詳細な実装手順を3フェーズに分けて整理 主要コンポーネント - Jupyter Lab環境: DuckDB、Splink、必要なライブラリがプリインストール - データボリューム: サンプルデータ、notebook、出力用のディレクトリ - 最小構成: 開発・学習用途に特化した軽量環境 次のステップ 実装を開始する場合は、Phase 1から順番に進めます: 1. Docker設定ファイル作成 2. ディレクトリ構造構築 3. サンプル実装 4. テストと検証 計画の詳細は.tmp/design.mdと.tmp/task.mdをご確認ください。実装を開始しますか? 実際にファイルを作ってもらうと以下のような構成となる。 得られる中間成果物である docker-compose.yml、Dockerfile は本記事では省略する。 . ├── data ├── docker-compose.yml ├── Dockerfile ├── notebooks ├── output └── requirements.txt 普通に http://localhost:8888 で JupyterLab が開く。 使用するサンプルデータ 5万人の人物データを使って名寄せを行うサンプル。おそらくSplinkの用途として最初に思いつくやつ。 Splinkにデータをロードする前に必要なデータクリーニング手順について説明がある。 公式によると、まずは行に一意のIDを割り当てる必要がある。 データセット内で一意となるIDであって、重複除去した後のエンティティを識別するIDのことではない。 [clink implicit=\"false\" url=\"https://moj-analytical-services.github.io/splink/demos/tutorials/01_Prerequisites.html\" imgurl=\"https://user-images.githubusercontent.com/7570107/85285114-3969ac00-b488-11ea-88ff-5fca1b34af1f.png\" title=\"Data Prerequisites\" excerpt=\"Splink では、リンクする前にデータをクリーンアップし、行に一意の ID を割り当てる必要があります。このセクションでは、Splink にデータをロードする前に必要な追加のデータクリーニング手順について説明します。\"] 使用するサンプルデータは以下の通り。 from splink import splink_datasets df = splink_datasets.historical_50k df.head() データの分布を可視化 splink.exploratoryのprofile_columnsを使って分布を可視化してみる。 from splink import DuckDBAPI from splink.exploratory import profile_columns db_api = DuckDBAPI() profile_columns(df, db_api, column_expressions=[\"first_name\", \"substr(surname,1,2)\"]) 同じ姓・名の人が大量にいることがわかる。 ブロッキングとブロッキングルールの評価 テーブル内のレコードが他のレコードと「同一かどうか」を調べるためには、 基本的には、他のすべてのレコードとの何らかの比較操作を行うこととなる。 全てのレコードについて全てのカラム同士を比較したいのなら、 対象のテーブルをCROSS JOINした結果、各カラム同士を比較することとなる。 SELECT ... FROM input_tables as l CROSS JOIN input_tables as r あるカラムが条件に合わなければ、もうその先は見ても意味がない、 というケースは多い。例えば、まず first_name 、surname が同じでなければ、 その先の比較を行わない、というのはあり得る。 SELECT ... FROM input_tables as l INNER JOIN input_tables as r ON l.first_name = r.first_name AND l.surname = r.surname このような考え方をブロッキング、ON句の条件をブロッキングルールと言う。 ただ、これだと性と名が完全一致していないレコードが残らない。 そこで、ブロッキングルールを複数定義し、いずれかが真であれば残すことができる。 ここでポイントなのが、ブロッキングルールを複数定義したとき、 それぞれのブロッキングルールで重複して選ばれるレコードが発生した場合、 Splinkが自動的に排除してくれる。 このため、ブロッキングルールを重ねがけすると、最終的に残るレコード数は一致する。 ただ、順番により、同じルールで残るレコード数は変化する。 逆に言うと、ブロッキングルールを足すことで、重複除去後のOR条件が増えていく。 積算グラフにして、ブロッキングルールとその順番の効果を見ることができる。 from splink import DuckDBAPI, block_on from splink.blocking_analysis import ( cumulative_comparisons_to_be_scored_from_blocking_rules_chart, ) blocking_rules = [ block_on(\"substr(first_name,1,3)\", \"substr(surname,1,4)\"), block_on(\"surname\", \"dob\"), block_on(\"first_name\", \"dob\"), block_on(\"postcode_fake\", \"first_name\"), block_on(\"postcode_fake\", \"surname\"), block_on(\"dob\", \"birth_place\"), block_on(\"substr(postcode_fake,1,3)\", \"dob\"), block_on(\"substr(postcode_fake,1,3)\", \"first_name\"), block_on(\"substr(postcode_fake,1,3)\", \"surname\"), block_on(\"substr(first_name,1,2)\", \"substr(surname,1,2)\", \"substr(dob,1,4)\"), ] db_api = DuckDBAPI() cumulative_comparisons_to_be_scored_from_blocking_rules_chart( table_or_tables=df, blocking_rules=blocking_rules, db_api=db_api, link_type=\"dedupe_only\", ) 積算グラフは以下の通り。積み上がっている数値は「比較の数」。 要は、論理和で条件を足していって、次第に緩和されている様子がわかる。 DuckDBでは比較の数を2,000万件以内、Athena,Sparkでは1億件以内を目安にせよとのこと。 比較の定義 Splinkは Fellegi-Sunter model モデル (というかフレームワーク) に基づいている。 https://moj-analytical-services.github.io/splink/topic_guides/theory/fellegi_sunter.html 各カラムの同士をカラムの特性に応じた距離を使って比較し、重みを計算していく。 各カラムの比較に使うためのメソッドが予め用意されているので、特性に応じて選んでいく。 以下では、first_name, sur_name に ForenameSurnameComparison が使われている。 dobにDateOfBirthComparison、birth_place、ocupationにExactMatchが使われている。 import splink.comparison_library as cl from splink import Linker, SettingsCreator settings = SettingsCreator( link_type=\"dedupe_only\", blocking_rules_to_generate_predictions=blocking_rules, comparisons=[ cl.ForenameSurnameComparison( \"first_name\", \"surname\", forename_surname_concat_col_name=\"first_name_surname_concat\", ), cl.DateOfBirthComparison( \"dob\", input_is_string=True ), cl.PostcodeComparison(\"postcode_fake\"), cl.ExactMatch(\"birth_place\").configure(term_frequency_adjustments=True), cl.ExactMatch(\"occupation\").configure(term_frequency_adjustments=True), ], retain_intermediate_calculation_columns=True, ) # Needed to apply term frequencies to first+surname comparison df[\"first_name_surname_concat\"] = df[\"first_name\"] + \" \" + df[\"surname\"] linker = Linker(df, settings, db_api=db_api) ComparisonとComparison Level ここでSplinkツール内の比較の概念の説明。以下の通り概念に名前がついている。 Data Linking Model ├─-- Comparison: Date of birth │ ├─-- ComparisonLevel: Exact match │ ├─-- ComparisonLevel: One character difference │ ├─-- ComparisonLevel: All other ├─-- Comparison: First name │ ├─-- ComparisonLevel: Exact match on first_name │ ├─-- ComparisonLevel: first_names have JaroWinklerSimilarity > 0.95 │ ├─-- ComparisonLevel: first_names have JaroWinklerSimilarity > 0.8 │ ├─-- ComparisonLevel: All other モデルのパラメタ推定 モデルの実行に必要なパラメタは以下の3つ。Splinkを用いてパラメタを得る。 ちなみに u は \"\'U\'nmatch\"、m は \"\'M\'atch\"。背後の数式の説明で現れる。 No パラメタ 説明 1 無作為に選んだレコードが一致する確率 入力データからランダムに取得した2つのレコードが一致する確率 (通常は非常に小さい数値) 2 u値(u確率) 実際には一致しないレコードの中で各 ComparisonLevel に該当するレコードの割合。具体的には、レコード同士が同じエンティティを表すにも関わらず値が異なる確率。例えば、同じ人なのにレコードによって生年月日が違う確率。これは端的には「データ品質」を表す。名前であればタイプミス、別名、ニックネーム、ミドルネーム、結婚後の姓など。 3 m値(m確率) 実際に一致するレコードの中で各 ComparisonLevel に該当するレコードの割合。具体的には、レコード同士が異なるエンティティを表すにも関わらず値が同じである確率。例えば別人なのにレコードによって性・名が同じ確率 (同姓同名)。性別は男か女かしかないので別人でも50%の確率で一致してしまう。 無作為に選んだレコードが一致する確率 入力データからランダムに抽出した2つのレコードが一致する確率を求める。 値は0.000136。すべての可能なレコードのペア比較のうち7,362.31組に1組が一致すると予想される。 合計1,279,041,753組の比較が可能なため、一致するペアは合計で約173,728.33組になると予想される、 とのこと。 linker.training.estimate_probability_two_random_records_match( [ block_on(\"first_name\", \"surname\", \"dob\"), block_on(\"substr(first_name,1,2)\", \"surname\", \"substr(postcode_fake,1,2)\"), block_on(\"dob\", \"postcode_fake\"), ], recall=0.6, ) > Probability two random records match is estimated to be 0.000136. > This means that amongst all possible pairwise record comparisons, > one in 7,362.31 are expected to match. > With 1,279,041,753 total possible comparisons, > we expect a total of around 173,728.33 matching pairs u確率の推定 実際には一致しないレコードの中でComparisonの評価結果がPositiveである確率。 基本、無作為に抽出したレコードは一致しないため、「無作為に抽出したレコード」を 「実際には一致しないレコード」として扱える、という点がミソ。 probability_two_random_records_match によって得られた値を使ってu確率を求める。 estimate_u_using_random_sampling によって、ラベルなし、つまり教師なしでu確率を得られる。 レコードのペアをランダムでサンプルして上で定義したComparisonを評価する。 ランダムサンプルなので大量の不一致が発生するが、各Comparisonにおける不一致の分布を得ている。 これは、例えば性別について、50%が一致、50%が不一致である、という分布を得ている。 一方、例えば生年月日について、一致する確率は 1%、1 文字の違いがある確率は 3%、 その他はすべて 96% の確率で発生する、という分布を得ている。 linker.training.estimate_u_using_random_sampling(max_pairs=5e6) > ----- Estimating u probabilities using random sampling ----- > > Estimated u probabilities using random sampling > > Your model is not yet fully trained. Missing estimates for: > - first_name_surname (no m values are trained). > - dob (no m values are trained). > - postcode_fake (no m values are trained). > - birth_place (no m values are trained). > - occupation (no m values are trained). m確率の推定 「実際に一致するレコード」の中で、Comparisonの評価がNegativeになる確率。 そもそも、このモデルを使って名寄せ、つまり「一致するレコード」を見つけたいのだから、 モデルを作るために「実際に一致するレコード」を計算しなければならないのは矛盾では..となる。 無作為抽出結果から求められるu確率とは異なり、m確率を求めるのは難しい。 もしラベル付けされた「一致するレコード」、つまり教師データセットがあるのであれば、 そのデータセットを使ってm確率を求められる。 例えば、日本人全員にマイナンバーが振られて、全てのレコードにマイナンバーが振られている、 というアナザーワールドがあるのであれば、マイナンバーを使ってm確率を推定する。(どういう状況??) ラベル付けされたデータがないのであれば、EMアルゴリズムでm確率を求めることになっている。 EMアルゴリズムは反復的な手法で、メモリや収束速度の点でペア数を減らす必要があり、 例ではブロッキングルールを設定している。 以下のケースでは、first_nameとsurnameをブロッキングルールとしている。 つまり、first_name, surnameが完全に一致するレコードについてペア比較を行う。 この仮定を設定したため、first_name, surname (first_name_surname) のパラメタを推定できない。 training_blocking_rule = block_on(\"first_name\", \"surname\") training_session_names = ( linker.training.estimate_parameters_using_expectation_maximisation( training_blocking_rule, estimate_without_term_frequencies=True ) ) > ----- Starting EM training session ----- > > Estimating the m probabilities of the model by blocking on: > (l.\"first_name\" = r.\"first_name\") AND (l.\"surname\" = r.\"surname\") > > Parameter estimates will be made for the following comparison(s): > - dob > - postcode_fake > - birth_place > - occupation > > Parameter estimates cannot be made for the following comparison(s) since they are used in the blocking rules: > - first_name_surname > > Iteration 1: Largest change in params was 0.248 in probability_two_random_records_match > Iteration 2: Largest change in params was 0.0929 in probability_two_random_records_match > Iteration 3: Largest change in params was -0.0237 in the m_probability of birth_place, level `Exact match on > birth_place` > Iteration 4: Largest change in params was 0.00961 in the m_probability of birth_place, level `All other >comparisons` > Iteration 5: Largest change in params was -0.00457 in the m_probability of birth_place, level `Exact match on birth_place` > Iteration 6: Largest change in params was -0.00256 in the m_probability of birth_place, level `Exact match on birth_place` > Iteration 7: Largest change in params was 0.00171 in the m_probability of dob, level `Abs date difference Iteration 8: Largest change in params was 0.00115 in the m_probability of dob, level `Abs date difference Iteration 9: Largest change in params was 0.000759 in the m_probability of dob, level `Abs date difference Iteration 10: Largest change in params was 0.000498 in the m_probability of dob, level `Abs date difference Iteration 11: Largest change in params was 0.000326 in the m_probability of dob, level `Abs date difference Iteration 12: Largest change in params was 0.000213 in the m_probability of dob, level `Abs date difference Iteration 13: Largest change in params was 0.000139 in the m_probability of dob, level `Abs date difference Iteration 14: Largest change in params was 9.04e-05 in the m_probability of dob, level `Abs date difference <= 10 year` 同様にdobをブロッキングルールに設定して実行すると、dob以外の列についてパラメタを推定できる。 training_blocking_rule = block_on(\"dob\") training_session_dob = ( linker.training.estimate_parameters_using_expectation_maximisation( training_blocking_rule, estimate_without_term_frequencies=True ) ) > ----- Starting EM training session ----- > > Estimating the m probabilities of the model by blocking on: > l.\"dob\" = r.\"dob\" > > Parameter estimates will be made for the following comparison(s): > - first_name_surname > - postcode_fake > - birth_place > - occupation > > Parameter estimates cannot be made for the following comparison(s) since they are used in the blocking rules: > - dob > > Iteration 1: Largest change in params was -0.474 in the m_probability of first_name_surname, level `Exact match on first_name_surname_concat` > Iteration 2: Largest change in params was 0.052 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 3: Largest change in params was 0.0174 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 4: Largest change in params was 0.00532 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 5: Largest change in params was 0.00165 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 6: Largest change in params was 0.00052 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 7: Largest change in params was 0.000165 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 8: Largest change in params was 5.29e-05 in the m_probability of first_name_surname, level `All other comparisons` > > EM converged after 8 iterations > > Your model is not yet fully trained. Missing estimates for: > - first_name_surname (some u values are not trained). モデルパラメタの可視化 m確率、u確率の可視化。 マッチウェイトの可視化。マッチウェイトは (log_2 (m / u))で計算される。 linker.visualisations.match_weights_chart() モデルの保存と読み込み 以下でモデルを保存できる。 settings = linker.misc.save_model_to_json( \"./saved_model_from_demo.json\", overwrite=True ) 以下で保存したモデルを読み込める。 import json settings = json.load( open(\'./saved_model_from_demo.json\', \'r\') ) リンクするのに十分な情報が含まれていないレコード 「John Smith」のみを含み、他のすべてのフィールドがnullであるレコードは、 他のレコードにリンクされている可能性もあるが、潜在的なリンクを明確にするには十分な情報がない。 以下により可視化できる。 linker.evaluation.unlinkables_chart() 横軸は「マッチウェイトの閾値」。縦軸は「リンクするのに十分な情報が含まれないレコード」の割合。 マッチウェイト閾値=6.11ぐらいのところを見ると、入力データセットのレコードの約1.3%が リンクできないことが示唆される。 訓練済みモデルを使って未知データのマッチウェイトを予測 上で構築した推定モデルを使用し、どのペア比較が一致するかを予測する。 内部的には以下を行うとのこと。 blocking_rules_to_generate_predictionsの少なくとも1つと一致するペア比較を生成 Comparisonで指定されたルールを使用して、入力データの類似性を評価 推定された一致重みを使用し、要求に応じて用語頻度調整を適用して、最終的な一致重みと一致確率スコアを生成 df_predictions = linker.inference.predict(threshold_match_probability=0.2) df_predictions.as_pandas_dataframe(limit=1) > Blocking time: 0.88 seconds > Predict time: 1.91 seconds > > -- WARNING -- > You have called predict(), but there are some parameter estimates which have neither been estimated or > specified in your settings dictionary. To produce predictions the following untrained trained parameters will > use default values. > Comparison: \'first_name_surname\': > u values not fully trained records_to_plot = df_e.to_dict(orient=\"records\") linker.visualisations.waterfall_chart(records_to_plot, filter_nulls=False) predictしたマッチウェイトの可視化、数式との照合 predictしたマッチウェイトは、ウォーターフォール図で可視化できる。 マッチウェイトは、モデル内の各特徴量によって一致の証拠がどの程度提供されるかを示す中心的な指標。 (lambda)は無作為抽出した2つのレコードが一致する確率。(K=m/u)はベイズ因子。 begin{align} M &= log_2 ( frac{lambda}{1-lambda} ) + log_2 K \\ &= log_2 ( frac{lambda}{1-lambda} ) + log_2 m - log_2 u end{align} 異なる列の比較が互いに独立しているという仮定を置いていて、 2つのレコードのベイズ係数が各列比較のベイズ係数の積として扱う。 begin{eqnarray} K_{feature} = K_{first_name_surname} + K_{dob} + K_{postcode_fake} + K_{birth_place} + K_{occupation} + cdots end{eqnarray} マッチウェイトは以下の和。 begin{eqnarray} M_{observe} = M_{prior} + M_{feature} end{eqnarray} ここで begin{align} M_{prior} &= log_2 (frac{lambda}{1-lambda}) \\ M_{feature} &= M_{first_name_surname} + M_{dob} + M_{postcode_fake} + M_{birth_place} + M_{occupation} + cdots end{align} 以下のように書き換える。 begin{align} M_{observe} &= log_2 (frac{lambda}{1-lambda}) + sum_i^{feature} log_2 (frac{m_i}{u_i}) \\ &= log_2 (frac{lambda}{1-lambda}) + log_2 (prod_i^{feature} (frac{m_i}{u_i}) ) end{align} ウォーターフォール図の一番左、赤いバーは(M_{prior} = log_2 (frac{lambda}{1-lambda}))。 特徴に関する追加の知識が考慮されていない場合のマッチウェイト。 横に並んでいる薄い緑のバーは (M_{first_name_surname} + M_{dob} + M_{postcode_fake} + M_{birth_place} + M_{occupation} + cdots)。 各特徴量のマッチウェイト。 一番右の濃い緑のバーは2つのレコードの合計マッチウェイト。 begin{align} M_{feature} &= M_{first_name_surname} + M_{dob} + M_{postcode_fake} + M_{birth_place} + M_{occupation} + cdots \\ &= 8.50w end{align} まとめ 長くなったのでいったん終了。この記事では教師なし確率的名寄せパッケージSplinkを使用してモデルを作ってみた。 次の記事では、作ったモデルを使用して実際に名寄せをしてみる。 途中、DuckDBが楽しいことに気づいたので、DuckDBだけで何個か記事にしてみようと思う。

分散と標準偏差を計算しやすくする

[mathjax] 分散と標準偏差を計算しやすく変形できる。 いちいち偏差(x_i-bar{x})を計算しておかなくても、2乗和(x_i^2)と平均(bar{x})がわかっていればOK。 begin{eqnarray} s^2 &=& frac{1}{n} sum_{i=1}^n (x_i - bar{x})^2 \\ &=& frac{1}{n} sum_{i=1}^n ( x_i^2 -2 x_i bar{x} + bar{x}^2 ) \\ &=& frac{1}{n} ( sum_{i=1}^n x_i^2 -2 bar{x} sum_{i=1}^n x_i + bar{x}^2 sum_{i=1}^n 1) \\ &=& frac{1}{n} ( sum_{i=1}^n x_i^2 -2 n bar{x}^2 + nbar{x}^2 ) \\ &=& frac{1}{n} ( sum_{i=1}^n x_i^2 - nbar{x}^2 ) \\ &=& frac{1}{n} sum_{i=1}^n x_i^2 - bar{x}^2 \\ s &=& sqrt{frac{1}{n} sum_{i=1}^n x_i^2 - bar{x}^2 } end{eqnarray} 以下みたい使える。平均と標準偏差と2乗和の関係。 begin{eqnarray} sum_{i=1}^n (x_i - bar{x})^2 &=& sum_{i=1}^n x_i^2 - nbar{x}^2 \\ ns^2 &=& sum_{i=1}^n x_i^2 - nbar{x}^2 \\ sum_{i=1}^n x_i^2 &=& n(s^2 + bar{x} ) end{eqnarray}

標本調査に必要なサンプル数の下限を与える2次関数

[mathjax] 2項分布に従う母集団の母平均を推測するために有意水準を設定して95%信頼区間を求めてみた。 母平均のあたりがついていない状況だとやりにくい。 [clink url=\"https://ikuty.com/2019/01/11/sampling/\"] (hat{p})がどんな値であっても下限は(hat{p})の関数で抑えられると思ったので、 気になって(hat{p})を変数のまま残すとどうなるかやってみた。 begin{eqnarray} 1.96sqrt{frac{hat{p}(1-hat{p})}{n}} le 0.05 \\ frac{1.96}{0.05}sqrt{hat{p}(1-hat{p})} le sqrt{n} \\ 39.2^2 hat{p}(1-hat{p}) le n end{eqnarray} 左辺を(f(hat{p}))と置くと (f(hat{p}))は下に凸の2次関数であって、 (frac{d}{dhat{p}}f(hat{p})=0)の時に最大となる。というか(hat{p}=0.5)。 (hat{p}=0.5)であるとすると、これはアンケートを取るときのサンプル数を求める式と同じで、 非常に有名な以下の定数が出てくる。 begin{eqnarray} 1537 * 0.5 (1-0.5) le n \\ 384 le n end{eqnarray} (hat{p})がどんな値であっても、サンプル数を400とれば、 有意水準=5%の95%信頼区間を得られる。 だから、アンケートの(n)数はだいたい400で、となる。 さらに、有意水準を10%にとれば、(n)の下限は100で抑えられる。 なるはやのアンケートなら100、ちゃんとやるには400、というやつがこれ。

深層学習

勾配降下法

[mathjax] 各地点において関数の値を最大にするベクトル((frac{partial f}{partial x_0},frac{partial f}{partial x_1}))を全地点に対して計算したものを勾配とかいう。 ある地点において、このベクトルの方向に向かうことにより最も関数の値を大きくする。 で、今後のために正負を反転して関数の値を最小にするベクトルを考えることにした。 関数の値を小さくする操作を繰り返していけば、いずれ\"最小値\"が見つかるはず。 というモチベを続けるのが勾配降下法。学習率(eta)を使って以下みたいに書ける。。 begin{eqnarray} x_0 = x_0 - eta frac{partial f}{partial x_0} \\ x_1 = x_1 - eta frac{partial f}{partial x_1} end{eqnarray} ということで(f(x_0,x_1)=x_0^2+x_1^2)の最小値を初期値((3.0,4.0))、 学習率(eta=0.1)に設定して計算してみる。 import numpy as np def numerical_gradient(f, x): h = 1e-4 grad = np.zeros_like(x) for idx in range(x.size): tmp_val = x[idx] x[idx] = tmp_val + h fxh1 = f(x) x[idx] = tmp_val - h fxh2 = f(x) grad[idx] = (fxh1 - fxh2) / (2*h) x[idx] = tmp_val return grad def gradient_descent(f, init_x, lr=0.01, step_num=100): x = init_x for i in range(step_num): grad = numerical_gradient(f,x) x -= lr * grad return x def function2(x): return x[0]**2 + x[1]**2 init_x = np.array([-3.0, 4.0]) v = gradient_descent(function2, init_x=init_x, lr=0.1, step_num=100) v # array([-6.11110793e-10, 8.14814391e-10]) ((0,0))に収束した。 ニューラルネットワークの勾配 損失関数を重みパラメータで微分する。以下みたいな感じ。 損失関数の大小を見るとして、例えば(w_{11})以外の重みを固定したとして(w_{11})をわずかに 増やしたときに損失関数の値がどれだけ大きくなるか。 損失関数の値はパラメータ(W)と入力(x)から決まるベクトルだけれども、それぞれ乱数と入力値が設定されている。 begin{eqnarray} W= begin{pmatrix} w_{11} & w_{12} & w_{13} \\ w_{21} & w_{22} & w_{23} end{pmatrix}, frac{partial L}{partial W}= begin{pmatrix} frac{partial L}{partial w_{11}} & frac{partial L}{partial w_{12}} & frac{partial L}{partial w_{13}} \\ frac{partial L}{partial w_{21}} & frac{partial L}{partial w_{22}} & frac{partial L}{partial w_{23}} end{pmatrix} end{eqnarray} 重み(W)が乱数で決まるネットワークがあるとする。このネットワークは入力と重みの積を出力 として返す。出力はSoftmaxを経由するとする。 ネットワークの出力と教師データのクロスエントロピー誤差を誤差として使う。 その前に、数値微分関数を多次元対応する。 普通、配列の次元が(n)個になると(n)重ループが必要になるけれども、 Numpy.nditer()を使うと(n)乗ループを1回のループにまとめることができる。 下のmulti_indexが((0,0),(0,1),(0,2),(1,0),(1,1),(1,2))みたいに イテレータが(n)次のタプルを返す。反復回数はタプルの要素数の直積。 Numpy配列にそのタプルでアクセスすることで晴れて全ての要素にアクセスできる。 def numerical_gradient_md(f, x): h = 1e-4 grad = np.zeros_like(x) it = np.nditer(x, flags=[\'multi_index\'], op_flags=[\'readwrite\']) while not it.finished: idx = it.multi_index tmp_val = x[idx] x[idx] = tmp_val + h fxh1 = f(x) # f(x+h) x[idx] = tmp_val - h fxh2 = f(x) # f(x-h) grad[idx] = (fxh1 - fxh2) / (2*h) x[idx] = tmp_val # 値を元に戻す it.iternext() return grad 初期値(x=(0.6,0.9))、教師データ(t=(0,0,1))をネットワークに入力する。 predict()は(1 times 3)を返す。 それをSoftmax()を通して、(t)とのクロスエントロピー誤差を求めたものが以下。 import numpy as np def cross_entropy_error(y, t): if y.ndim == 1: t = t.reshape(1, t.size) y = y.reshape(1,y.size) batch_size = y.shape[0] delta = 1e-7 return -np.sum( t * np.log( y + delta)) / batch_size def softmax(x): c = np.max(x) return np.exp(x-c) / np.sum(np.exp(x-c)) import sys, os sys.path.append(os.pardir) import numpy as np class simpleNet: def __init__(self): self.W = np.random.randn(2,3) def predict(self, x): return np.dot(x, self.W) def loss(self, x, t): z = self.predict(x) y = softmax(z) loss = cross_entropy_error(y, t) return loss net = simpleNet() x = np.array([0.6, 0.9]) p = net.predict(x) t = np.array([0, 0, 1]) net.loss(x, t) # 0.9463818740797788 このlossを(W)で微分したのが以下。 あえてパラメータ(W)を引数にとり損失関数の値を計算する(f(W))を定義することで、 数値微分が何と何の演算なのかをわかりやすくしている。 実際は(f(W))は(W)とは関係なく(x)と(t)だけから結果を返すけれども、 損失関数(f(W))を(W)で微分するという操作が自明になるようにコードを合わせている。 def f(W): return net.loss(x, t) dW = numerical_gradient_md(f, net.W) dW # array([[ 0.07627371, 0.49923236, -0.57550607], # [ 0.11441057, 0.74884853, -0.8632591 ]]) 結果の解釈 上記の(w),(W),(t)から(frac{partial L}{partial W})が求まった。 損失関数が何か複雑な形をしているという状況で、 (frac{partial L}{partial w_{11}})は(w_{11})がわずかに動いたときに損失関数の値が変化する量を表している。 それが(w_{11})から(w_{32})まで6個分存在する。 begin{eqnarray} frac{partial L}{partial W} = begin{pmatrix} frac{partial L}{partial w_{11}} & frac{partial L}{partial w_{21}} & frac{partial L}{partial w_{31}} \\ frac{partial L}{partial w_{12}} & frac{partial L}{partial w_{22}} & frac{partial L}{partial w_{32}} end{pmatrix} = begin{pmatrix} 0.07627371 & 0.49923236 & -0.57550607 \\ 0.11441057 & 0.74884853 & -0.8632591 end{pmatrix} end{eqnarray}

勾配の可視化

[mathjax] 2変数関数(f(x_0,x_1))を各変数で偏微分する。 地点((i,j))におけるベクトル((frac{partial f(x_0,j)}{partial x_0},frac{partial f(i,x_1)}{partial x_1}))を全地点で記録していき、ベクトル場を得る。 このベクトル場が勾配(gradient)。 (f(x_0,x_1)=x_0^2+x_1^2)について、(-4.0 le x_0 le 4.0)、(-4.0 le x_1 le 4.0)の範囲で、 勾配を求めてみる。また、勾配を可視化してみる。 まず、2変数関数(f(x_0,x_1))の偏微分係数を求める関数の定義。 ((3.0,3.0))の偏微分係数は((6.00..,6.00..))。 def numerical_gradient(f, x): h = 10e-4 grad = np.zeros_like(x) for idx in range(x.size): tmp_val = x[idx] x[idx] = tmp_val + h fxh1 = f(x) x[idx] = tmp_val - h fxh2 = f(x) grad[idx] = (fxh1 - fxh2) / 2*h x[idx] = tmp_val return grad def function2(x): return x[0]**2 + x[1]**2 p = np.array([3.0,3.0]) v = numerical_gradient(function2, p) v # array([6.e-06, 6.e-06]) (-4.0 le x_0 le 4.0)、(-4.0 le x_1 le 4.0)の範囲((0.5)刻み)で偏微分係数を求めて、 ベクトル場っぽく表示してみる。matplotlibのquiver()は便利。 各地点において関数の値を最も増やす方向が表示されている。 w_range = 4 dw = 0.5 w0 = np.arange(-w_range, w_range, dw) w1 = np.arange(-w_range, w_range, dw) wn = w0.shape[0] diff_w0 = np.zeros((len(w0), len(w1))) diff_w1 = np.zeros((len(w0), len(w1))) for i0 in range(wn): for i1 in range(wn): d = numerical_gradient(function2, np.array([ w0[i0], w1[i1] ])) diff_w0[i1, i0], diff_w1[i1, i0] = (d[0], d[1]) plt.xlabel(\'$x_0$\',fontsize=14) #x軸のラベル plt.ylabel(\'$x_1$\',fontsize=14) #y軸のラベル plt.xticks(range(-w_range,w_range+1,1)) #x軸に表示する値 plt.yticks(range(-w_range,w_range+1,1)) #y軸に表示する値 plt.quiver(w0, w1, diff_w0, diff_w1) plt.show() 値が大きい方向に矢印が向いている。例えば((-3.0,3.0))における偏微分係数は((-6.0,6.0))。 左上方向へのベクトル。 参考にしている本にはことわりが書いてあり、勾配にマイナスをつけたものを図にしている。 その場合、関数の値を最も減らす方向が表示されることになる。 各地点において、この勾配を参照することで、どちらに移動すれば関数の値を最も小さくできるかがわかる。

おっさんが数値微分を復習する

引き続き、ゼロDの写経を続ける。今回は、学生の頃が懐かしい懐ワード、数値微分。 全然Deepに入れないけれどおっさんの復習。解析的な微分と数値微分が一致するところを確認してみる。 昔と違うのは、PythonとJupyterNotebookで超絶簡単に実験できるし、 こうやってWordPressでLaTeXで記事を書いたりできる点。 [mathjax] まず、微分の基本的な考え方は以下の通り。高校数学の数3の範囲。 begin{eqnarray} frac{df(x)}{fx} = lim_{hrightarrow infty} frac{f(x+h)-f(x)}{h} end{eqnarray} 情報系学科に入って最初の方でEuler法とRunge-Kutta法を教わってコードを 書いたりレポート書いたりする。懐すぎる..。 または、基本情報の試験かなんかで、小さい値と小さい値どうしの計算で発生する問題が現れる。 ゼロDにはこの定義を少し改良した方法が載っている。へぇ。 begin{eqnarray} frac{df(x)}{fx} = lim_{hrightarrow infty} frac{f(x+h)-f(x-h)}{2h} end{eqnarray} 写経なので、がんばって数値微分を書いて動かしてみる。 簡単な2次関数(f(x))。 begin{eqnarray} f(x) &=& x^2 - 5x +3 \\ f\'(x) &=& 2x - 5 end{eqnarray} def numerical_diff(f, x): h = 10e-4 return (f(x+h) - f(x-h)) / (2*h) (f(x))と、(x=2.5)のところの(f\'(x))をmatplotlibで書いてみる。懐い... import matplotlib.pyplot as plt import numpy as np def f(x): return x**2 - 5*x + 3 x = np.arange(-10, 10, 0.1) y = f(x) dy = numerical_diff(f,x) plt.plot(x, y) x1 = -2.5 dy1 = numerical_diff(f, x1) y1 = f(x1) # y-y1 = dy1(x-x1) -> y = dy1(x-x1) + y1 j = lambda x: dy1 * (x-x1) + y1 plt.plot(x,j(x)) plt.xlabel(\'x\') plt.ylabel(\'y\') plt.grid() plt.show() 偏微分 2変数以上の関数の数値微分は以下の通り。片方を止める。 数値微分の方法は上記と同じものを使った。 begin{eqnarray} frac{partial f(x_0,x_1)}{partial x_0} &=& lim_{hrightarrow infty} frac{f(x_0 +h,x_1)-f(x_0-h,x_1)}{2h} \\ frac{partial f(x_0,x_1)}{partial x_1} &=& lim_{hrightarrow infty} frac{f(x_0,x_1+h)-f(x_0,x_1-h)}{2h} end{eqnarray} ((x_0,x_1)=(1,1))における(x_0)に対する偏微分(frac{partial f(x_0,x_1)}{x_0})、(x_1)に対する偏微分(frac{partial f(x_0,x_1)}{x_1})を求めてみる。 ちゃんと(frac{partial f(x_0,1.0)}{x_0}=2.00..)、(frac{partial f(1.0,x_1)}{x_1}=2.00..)になった。 import matplotlib.pyplot as plt import numpy as np from mpl_toolkits.mplot3d import Axes3D def f(x): return x[0]**2 + x[1]**2 X = np.meshgrid(np.arange(-5., 5., 0.2),np.arange(-5., 5., 0.2)) Z = f(X) fig = plt.figure(figsize=(6, 6)) axes = fig.add_subplot(111, projection=\'3d\') axes.plot_surface(X[0],X[1], Z) f0 = lambda x: x**2 + 1.0**2 f1 = lambda x: 1.0**2 + x**2 df0 = numerical_diff(f0, 1.0) df1 = numerical_diff(f1, 1.0) print(df0) # 2.0000000000000018 print(df1) # 2.0000000000000018 plt.show()