日々の学び

React+Next.jsでDummy JSONのCRUDをCSR/SSRの両方で作成して違いを調べてみた話

Next.jsでCRUDを作りSSRの挙動を調べてみた。いったんPage Routerを選択。 バックエンドとなるAPIとしてDummyJSONのPosts-Docs APIを使用した。 一覧、詳細、更新、削除が用意される。ただし更新、削除はダミーで永続化されない。 [clink implicit=\"false\" url=\"https://dummyjson.com/docs/posts\" imgurl=\"https://dummyjson.com/public/img/hero-image.svg\" title=\"Free Fake REST API for Placeholder JSON Data\" excerpt=\"Develop, Build, Test.Get instant dummy JSON data for your frontend with DummyJSON Server — no backend setup needed!\"] 目次は以下。 [arst_toc tag=\"h4\"] 構成 CSR版/SSR版の2パターンについてCRUDを行うアプリをClaude(Sonnet4.5)で環境を構築した。 ルーティングについては今回の調査範囲外のため、いったんシンプルなPage Routerを使用した。 npm run dev で next dev --turbopack が動く何かが作られた。turbopackはrust製のwebpackの後継。 いったん実行環境の詳細な把握をスキップして上物の理解を進めることにする。上物の構成は以下。 . ├── app/ │ ├── page.tsx # ホームページ │ └── posts/ │ ├── page.tsx # 投稿一覧選択(CSR/SSR) │ ├── csr/ │ │ ├── page.tsx # 投稿一覧(Client Component版) │ │ ├── [id]/ │ │ │ ├── page.tsx # 投稿詳細(Client Component版) │ │ │ └── edit/ │ │ │ └── page.tsx # 投稿編集(Client Component版) │ │ └── new/ │ │ └── page.tsx # 新規投稿作成(Client Component版) │ ├── ssr/ │ │ ├── page.tsx # 投稿一覧(Server Component版) │ │ ├── [id]/ │ │ │ ├── page.tsx # 投稿詳細(Server Component版) │ │ │ └── edit/ │ │ │ └── page.tsx # 投稿編集(Server Actions版) │ │ └── new/ │ │ └── page.tsx # 新規投稿作成(Server Actions版) │ └── _components/ │ └── DeleteButton.tsx # 削除ボタン(Client Component) ├── lib/ │ └── api.ts # API関数 ├── types/ │ └── post.ts # 型定義 ├── Dockerfile # Dockerイメージ設定 ├── docker-compose.yml # Docker Compose設定 └── next.config.ts # Next.js設定 以下のようなアプリができた。 [video width=\"2280\" height=\"1792\" webm=\"https://ikuty.com/wp-content/uploads/2025/10/recording.webm\"][/video] ReactとNext.jsの関係性と役割分担 ReactはUIを作るためのJavaScriptライブラリとして機能する。コンポーネント、フック、JSXを提供する。 Next.jsはReactを使ったフレームワークであり、ルーティング、ビルド、最適化などの機能を提供する。 React (ライブラリ) Next.js (フレームワーク) 役割の焦点 UI (ユーザーインターフェース) 構築 Webアプリケーション全体の構築 主な提供物 コンポーネント (UI要素)、JSX、Hooks (状態管理など) ルーティング、レンダリング戦略、最適化機能、バックエンド連携 ルーティング 非搭載。別途React Routerなどのライブラリが必要。 ファイルベースルーティングが組み込みで提供される。 レンダリング クライアントサイドレンダリング (CSR) が基本。ブラウザでJavaScriptが実行されてUIを描画する。 プリレンダリング (SSR/SSG) が組み込みで提供される。レンディングのタイミングと場所を制御する。 データ取得 非搭載。fetch APIなどをコンポーネント内で使用する。 データ取得パターン (Server Components, Route Handlersなど) とキャッシュの仕組みが組み込みで提供される。 クライアントサイドレンダリング(CSR) ブラウザ側で動的にHTMLを生成する。useState、useEffect、イベントハンドラが使える。 Step1.初期レンダリング(サーバ側) app/layout.tsx がサーバーで実行される。<html><body>の枠組みを作る。 app/posts/[id]/page.tsxがClinet Componentとして認識され初期HTMLを作る。 ブラウザに初期HTMLを送り、ブラウザは初期HTMLを表示する // app/layout.tsx (20-34行目) export default function RootLayout({ children }: { children: React.ReactNode }) { return ( <html lang=\"en\"> <body> {children} {/* ← ここに子コンポーネントが入る */} </body> </html> ); } ... //ブラウザに送られる初期HTMLの例 <html> <body> <div class=\"min-h-screen p-8\"> <p>読み込み中...</p> ← loading=trueの状態 </div> <script src=\"/_next/...\"></script> ← クライアント用JS </body> </html> Step2.ハイドレーション(Hydration) JavaScriptが読み込まれる Reactがコンポーネントを「水分補給」(Hydrate)、HTMLに機能を追加 初期state: loading = true, post = null Step3.useEffectの実行(副作用) コンポーネントが画面に表示された直後に1回実行される。 // app/posts/[id]/page.tsx (16-18行目) useEffect(() => { loadPost(); // ← コンポーネントがマウントされたら実行 }, [params.id]); Step4.データフェッチとstate更新 api.getPost()を実行。-> fetch(\'https://dummyjson.com/posts/1\') const data でレスポンスを受け取る setPost(data)でstateを更新 setLoading(false)でローディング終了。loading=false // app/posts/[id]/page.tsx (20-32行目) const loadPost = async () => { try { setLoading(true); // ローディング表示 const data = await api.getPost(Number(params.id)); // API呼び出し setPost(data); // ← state更新! setError(null); } catch (err) { setError(\'投稿の読み込みに失敗しました\'); } finally { setLoading(false); // ← state更新! } }; Step5.再レンダリング(stateが変わったので) post stateが更新されたので再レンダリング 条件分岐を再評価 最終的なJSXをDOMに反映 // app/posts/[id]/page.tsx (47-55行目) if (loading) { // loading = false なので通過 return 読み込み中...; } if (error || !post) { // error = null, post = データあり なので通過 return エラー表示; } // ここが実行される! return ( <div className=\"min-h-screen p-8\"> <h1>{post.title}</h1> {/* ← post.title を表示 */} <p>{post.body}</p> {/* ← post.body を表示 */} {/* ... */} </div> ); Step6.リストの動的レンダリング // app/posts/[id]/page.tsx (133-140行目) {post.tags.map((tag, index) => ( // ← 配列をループ <span key={index} className=\"...\"> {tag} {/* ← 各タグを表示 */} </span> ))} 実行結果: post.tags = [\"history\", \"american\", \"crime\"] ↓ map() で変換 <span key={0}>history</span> <span key={1}>american</span> <span key={2}>crime</span> 全体のレンダリングフロー [ユーザーが /posts/1 にアクセス] ↓ ┌──────────────────────────────┐ │ サーバー側(Next.js Server) │ ├──────────────────────────────┤ │ 1. app/layout.tsx を実行 │ │ → を生成 │ │ │ │ 2. app/posts/[id]/page.tsx │ │ を「クライアントコンポーネント」│ │ として認識 │ │ → 初期HTML生成 │ │ (loading=true状態) │ └──────────────────────────────┘ ↓ HTML + JS送信 ┌──────────────────────────────┐ │ ブラウザ側(Client) │ ├──────────────────────────────┤ │ 3. HTMLを表示 │ │ 「読み込み中...」 │ │ │ │ 4. JavaScriptロード │ │ → Hydration実行 │ │ │ │ 5. useEffect発火 │ │ → loadPost()実行 │ │ │ │ 6. API呼び出し │ │ fetch(https://dummyjson.com/posts/1) │ ↓ │ │ レスポンス受信 │ │ │ │ 7. setState実行 │ │ setPost(data) │ │ setLoading(false) │ │ ↓ │ │ 8. 再レンダリング │ │ → 投稿詳細を表示 │ └──────────────────────────────┘ サーバーサイドレンダリング(SSR) サーバ側でHTMLが生成される。DBやAPIに直接アクセスできる。useState,useEffectを使わない。 例えば、当アプリにおいて / へのアクセスに対してNext.jsが app/page.tsx を実行する。 HTMLを生成してブラウザに送信し、ブラウザはHTMLを表示する。 1周回って戻ってきたというか、LaravelやRailsにフロントエンドから寄っていくスタイル。 バックエンドにAPIを用意せずDBを直接操作できるため、SPAが不要な簡易的な管理画面など、 大幅な工数削減が可能になると思う。 Laravel,Railsだと、フロントエンドの記述にVue/Reactを導入する必要があるため、 バックエンド・フロントエンド、という棲み分けが発生してしまうが、 Next.jsのSSR(+CSR混在)により、フロントエンドとバックエンドを同じ仕組みで実現できる点で 管理する対象が大幅に減るのもメリットだと思う。 import Link from \'next/link\'; import { api } from \'@/lib/api\'; import DeleteButton from \'@/app/posts/_components/DeleteButton\'; // Server Component(デフォルト) // \'use client\' ディレクティブがないため、サーバー側で実行される export default async function PostsPageSSR() { // サーバー側で直接データ取得 // useEffect や useState は不要 const data = await api.getPosts(); const posts = data.posts; return ( <div className=\"min-h-screen p-8 bg-gray-50\"> <div className=\"max-w-4xl mx-auto\"> {/* ヘッダー部分 */} <div className=\"bg-blue-50 border border-blue-200 rounded-lg p-4 mb-6\"> <p className=\"text-sm text-blue-800\"> <strong>Server Component版 - このページはサーバー側でレンダリングされ、HTMLに既にデータが含まれています </p> </div> ... ); } SSRとCSRの統合 SSRモードとCSRモードの2つのモードが存在する訳ではなく、SSRとCSRは同時に存在し得る。 例えば、今回作成したSSR版アプリの投稿一覧画面において、CSRで削除ボタンを実現している。 コンポーネント単位でSSR/CSRの分離が起こるだけで、アーキ全体ではSSRとCSRは混在できる。 TypeScriptにより型安全にpropsを渡せるし、状態管理がReactの仕組みで統一できる。 部分的な更新は可能 (router.refresh() )。 // app/posts/ssr/page.tsx (Server Component) export default async function PostsPageSSR() { const data = await api.getPosts(); // サーバー側で実行 const posts = data.posts; return ( <div> {posts.map(post => ( <div key={post.id}> <h2>{post.title}</h2> {/* Client Componentをそのまま埋め込める */} <DeleteButton postId={post.id} /> </div> ))} </div> ); } // app/posts/_components/DeleteButton.tsx (Client Component) \'use client\'; export default function DeleteButton({ postId }: { postId: number }) { const router = useRouter(); const handleDelete = async () => { if (!confirm(\'削除しますか?\')) return; await api.deletePost(postId); router.refresh(); // この部分だけ更新 }; return ( <button onClick={handleDelete}>削除</button> ); } まとめ Next.jsのHello WorldをしつつSSRとCSRの挙動を確認した。 フロント側フレームワークの枠組みを越え、フロント・バックエンドを統一的に扱えることを確認した。 アプリケーションの要件次第で、SSRを中心に部分的にCSRとすることで大幅な工数削減を期待できそう。

go-txdbを使ってgolang, gin, gorm(gen)+sqlite構成のAPI をテストケース毎に管理する

データベース依存のテストケースを作る際に、テストケース毎にDBがクリーンな状態を維持したい。 go-txdbはDBへの接続時にトランザクションを開始、切断時にトランザクションを終了するSQLドライバ。 テスト実行中にトランザクション内で発行したステートメント・行はテスト終了時には消滅する。 DB毎に実装方法は異なり、例えばSQLiteでは\"トランザクション\"ではなくsaveponitで実装される。 [clink implicit=\"false\" url=\"https://github.com/DATA-DOG/go-txdb\" imgurl=\"https://avatars.githubusercontent.com/u/6613360?s=48&v=4\" title=\"Single transaction based sql.Driver for GO\" excerpt=\"Package txdb is a single transaction based database sql driver. When the connection is opened, it starts a transaction and all operations performed on this sql.DB will be within that transaction. If concurrent actions are performed, the lock is acquired and connection is always released the statements and rows are not holding the connection.\"] [arst_toc tag=\"h4\"] 環境構築 Claude Code (Sonnet4.5) で以下の環境を構成した。途中15回のエラーリカバリを挟んで 期待通りの環境が出来上がった。 main.goがアプリケーションのルーティング(ハンドラ共有)、 main_test.goが main.goのルートに対するテスト。テストにはTestMain()が含まれている。 test_repository_test.goはGinが生成したリポジトリ層(モデル)をルートを経由せずテストする。 $ tree . -L 2 . ├── data │   └── db.sqlite # SQLite DBファイル ├── docker-compose.yml # Go+sqlite ├── Dockerfile # golang:1.23-alpineベース ├── gen.go # GORM Genコード生成スクリプト ├── go.mod # 依存関係の定義(go getやgo mod tidyで更新) ├── go.sum # 依存関係の検証用ハッシュ(自動) ├── init.sql # DDL,初期レコード ├── main.go # Gin初期化,ルーティング ├── main_test.go # main.goのルーティングに対するテストコード ├── models # モデル │   ├── model # testsテーブルと対応する構造体定義 (自動生成) │   └── query # (自動生成) ├── repository │   └── test_repository_test.go # リポジトリ層(データアクセス層)のテスト └── testhelper └── db.go # TxDB初期化等テスト用ヘルパー サンプルデータの準備 testsというテーブルに id, value というカラムを用意し、hoge, fuga レコードを挿入しておく。 簡略化のためにSQLiteを使用しており、ホスト側のファイルをbindマウントし初期実行判定して投入した。 -- Create tests table CREATE TABLE IF NOT EXISTS tests ( id INTEGER PRIMARY KEY, value TEXT NOT NULL ); -- Insert initial data INSERT OR IGNORE INTO tests (id, value) VALUES (1, \'hoge\'); INSERT OR IGNORE INTO tests (id, value) VALUES (2, \'fuga\'); CRUD ルーティング gin, gorm(gen) を使用して testsテーブルに対するCRUDを行う以下のルートを定義した。 それぞれ、genを使用しGolang言語のレベルでオブジェクトを操作している。 | メソッド | エンドポイント| 説明 | 仕様 | |--------|------------|----------------|-------------------------------------------------| | GET | /hello | 全レコード取得 | Find()で全レコードを取得し返却 | | GET | /hello/:id | 指定IDのレコード取得 | URLパラメータからIDを取得し、該当レコードを返却 | | POST | /hello | 新規レコード追加 | JSONリクエストボディからidとvalueを受け取り新規作成 | | PATCH | /hello/:id | 指定IDのレコード更新 | URLパラメータのIDとJSONボディのvalueでレコード更新 | | DELETE | /hello/:id | 指定IDのレコード削除 | URLパラメータのIDでレコード削除. | 各ハンドラの詳細な実装は冗長なので割愛。 手動リクエストと応答 各エンドポイント に対するリクエストとレスポンスの関係は以下。期待通り。 # 全件取得し応答 $ curl http://localhost:8080/hello [{\"id\":1,\"value\":\"hoge\"},{\"id\":2,\"value\":\"fuga\"}] # id=1を取得し応答 $ curl http://localhost:8080/hello/1 {\"id\":1,\"value\":\"hoge\"} # id=3を追加 $ curl -X POST http://localhost:8080/hello -H \"Content-Type: application/json\" -d \'{\"id\":3,\"value\":\"piyo\"}\' {\"id\":3,\"value\":\"piyo\"} $ curl http://localhost:8080/hello [{\"id\":1,\"value\":\"hoge\"},{\"id\":2,\"value\":\"fuga\"},{\"id\":3,\"value\":\"piyo\"}] # id=3を変更 $ curl -X PATCH http://localhost:8080/hello/3 -H \"Content-Type: application/json\" -d \'{\"value\":\"updated_piyo\"}\' {\"id\":3,\"value\":\"updated_piyo\"} # id=3を削除 $ curl -X DELETE http://localhost:8080/hello/3 {\"message\":\"record deleted successfully\"} # 全件取得し応答 $ curl http://localhost:8080/hello [{\"id\":1,\"value\":\"hoge\"},{\"id\":2,\"value\":\"fuga\"}] txdbを使用するためのテスト用ヘルパー関数 txdbを使用するためのテスト用ヘルパー関数を以下のように定義しておく。 package testhelper import ( \"database/sql\" \"fmt\" \"os\" \"sync\" \"sync/atomic\" \"github.com/DATA-DOG/go-txdb\" _ \"github.com/mattn/go-sqlite3\" \"gorm.io/driver/sqlite\" \"gorm.io/gorm\" ) var ( once sync.Once connID atomic.Uint64 ) // SetupTxDB initializes txdb driver for testing func SetupTxDB() { once.Do(func() { // Get database path dbPath := os.Getenv(\"DB_PATH\") if dbPath == \"\" { dbPath = \"./data/db.sqlite\" } // Register txdb driver with SQLite-specific options // Use WAL mode and configure for better concurrent access dsn := fmt.Sprintf(\"%s?_journal_mode=WAL&_busy_timeout=5000\", dbPath) txdb.Register(\"txdb\", \"sqlite3\", dsn) }) } // NewTestDB creates a new test database connection with txdb // Each connection will be isolated in a transaction and rolled back after test func NewTestDB() (*gorm.DB, error) { SetupTxDB() // Open connection using txdb driver with unique connection ID // This ensures each test gets its own isolated transaction id := connID.Add(1) sqlDB, err := sql.Open(\"txdb\", fmt.Sprintf(\"connection_%d\", id)) if err != nil { return nil, fmt.Errorf(\"failed to open txdb connection: %w\", err) } // Wrap with GORM db, err := gorm.Open(sqlite.Dialector{ Conn: sqlDB, }, &gorm.Config{}) if err != nil { return nil, fmt.Errorf(\"failed to open gorm connection: %w\", err) } return db, nil } テストの命名規則と共通処理 テストの関数名はTestXXX()のようにTestから始まりキャメルケースを続ける。 TestMain()内に全ての処理の前に実行する処理、後に実行する処理を記述できる。 package main import ( \"bytes\" \"encoding/json\" \"net/http\" \"net/http/httptest\" \"os\" \"testing\" \"gin_txdb/testhelper\" \"github.com/gin-gonic/gin\" \"github.com/stretchr/testify/assert\" \"github.com/stretchr/testify/require\" ) func TestMain(m *testing.M) { // Set DB_PATH for testing os.Setenv(\"DB_PATH\", \"./data/db.sqlite\") // Set Gin to test mode gin.SetMode(gin.TestMode) // Run tests code := m.Run() os.Exit(code) } 全件取得のテスト ヘルパー関数のNewTestDB()を使用することでtxdbを使用してDBに接続している。 defer func()内でコネクションを明示的にクローズする処理を遅延評価(=テスト完了時評価)しているが、 テスト実行中にエラーやpanicが起きた場合に開いたDBを切ることができなくなる問題への対処。 特にSQLiteの場合「接続は常に1つ」なので、切り忘れで接続が開きっぱなしになると、 次のテスト実行でロックエラーが発生する。明示的に閉じることでこの問題を確実に回避できる。 後はアサートを書いていく。 func TestGetAllTests(t *testing.T) { // Setup test database with txdb db, err := testhelper.NewTestDB() require.NoError(t, err) defer func() { sqlDB, _ := db.DB() sqlDB.Close() }() // Setup router using main.go\'s SetupRouter router := SetupRouter(db) // Create request req, _ := http.NewRequest(http.MethodGet, \"/hello\", nil) w := httptest.NewRecorder() // Perform request router.ServeHTTP(w, req) // Assert response assert.Equal(t, http.StatusOK, w.Code) var response []map[string]interface{} err = json.Unmarshal(w.Body.Bytes(), &response) require.NoError(t, err) // Should have 2 initial records assert.Len(t, response, 2) assert.Equal(t, float64(1), response[0][\"id\"]) assert.Equal(t, \"hoge\", response[0][\"value\"]) assert.Equal(t, float64(2), response[1][\"id\"]) assert.Equal(t, \"fuga\", response[1][\"value\"]) } このテストだけ実行してみる。-run オプションでテスト名を指定する。 $ go test -run TestGetAllTests [GIN] 2025/10/15 - 17:17:44 | 200 | 238.666µs | | GET \"/hello\" PASS ok gin_txdb 0.496s 1件取得のテスト(正常系) [GET] /hello/:id のテスト。指定したIDが存在する正常系。 func TestGetTestByID_Success(t *testing.T) { // Setup test database with txdb db, err := testhelper.NewTestDB() require.NoError(t, err) defer func() { sqlDB, _ := db.DB() sqlDB.Close() }() // Setup router router := SetupRouter(db) // Create request req, _ := http.NewRequest(http.MethodGet, \"/hello/1\", nil) w := httptest.NewRecorder() // Perform request router.ServeHTTP(w, req) // Assert response assert.Equal(t, http.StatusOK, w.Code) var response map[string]interface{} err = json.Unmarshal(w.Body.Bytes(), &response) require.NoError(t, err) assert.Equal(t, float64(1), response[\"id\"]) assert.Equal(t, \"hoge\", response[\"value\"]) } 実行結果は以下の通り。 go test -run TestGetTestByID_Success [GIN] 2025/10/15 - 17:24:41 | 200 | 207.25µs | | GET \"/hello/1\" PASS ok gin_txdb 0.330s 1件取得のテスト(異常系) [GET] /hello/:idのテスト。指定したIDが見つからない異常系。 func TestGetTestByID_NotFound(t *testing.T) { // Setup test database with txdb db, err := testhelper.NewTestDB() require.NoError(t, err) defer func() { sqlDB, _ := db.DB() sqlDB.Close() }() // Setup router router := SetupRouter(db) // Create request for non-existent ID req, _ := http.NewRequest(http.MethodGet, \"/hello/999\", nil) w := httptest.NewRecorder() // Perform request router.ServeHTTP(w, req) // Assert response assert.Equal(t, http.StatusNotFound, w.Code) var response map[string]interface{} err = json.Unmarshal(w.Body.Bytes(), &response) require.NoError(t, err) assert.Equal(t, \"record not found\", response[\"error\"]) } 実行結果は以下の通り。 go test -run TestGetTestByID_NotFound ./gin_txdb/main.go:52 record not found [0.105ms] [rows:0] SELECT * FROM `tests` WHERE `tests`.`id` = 999 ORDER BY `tests`.`id` LIMIT 1 [GIN] 2025/10/15 - 17:22:45 | 404 | 542.875µs | | GET \"/hello/999\" PASS ok gin_txdb 0.672s 1件追加のテスト(正常系) [POST] /helloが正常終了した場合、追加したレコードをレスポンスで返す処理のため、 レスポンスで返ってきたデータをアサートしている。 その後、[GET] /hello/:id のレスポンスを使ってアサートしている。 func TestCreateTest_Success(t *testing.T) { // Setup test database with txdb db, err := testhelper.NewTestDB() require.NoError(t, err) defer func() { sqlDB, _ := db.DB() sqlDB.Close() }() // Setup router router := SetupRouter(db) // Create request body payload := map[string]interface{}{ \"id\": 100, \"value\": \"test_value\", } body, _ := json.Marshal(payload) // Create request req, _ := http.NewRequest(http.MethodPost, \"/hello\", bytes.NewBuffer(body)) req.Header.Set(\"Content-Type\", \"application/json\") w := httptest.NewRecorder() // Perform request router.ServeHTTP(w, req) // Assert response assert.Equal(t, http.StatusCreated, w.Code) var response map[string]interface{} err = json.Unmarshal(w.Body.Bytes(), &response) require.NoError(t, err) assert.Equal(t, float64(100), response[\"id\"]) assert.Equal(t, \"test_value\", response[\"value\"]) // Verify the record was actually created req2, _ := http.NewRequest(http.MethodGet, \"/hello/100\", nil) w2 := httptest.NewRecorder() router.ServeHTTP(w2, req2) assert.Equal(t, http.StatusOK, w2.Code) } 実行結果は以下の通り。 $ go test -run TestCreateTest_Success [GIN] 2025/10/15 - 17:30:04 | 201 | 398.167µs | | POST \"/hello\" [GIN] 2025/10/15 - 17:30:04 | 200 | 47.625µs | | GET \"/hello/100\" PASS ok gin_txdb 0.505s 1件追加のテスト(異常なパラメタ。異常系) testsレコードはid,valueカラムを持つ。idのみ(valueなし)を渡した場合400エラーを返す。 func TestCreateTest_MissingFields(t *testing.T) { // Setup test database with txdb db, err := testhelper.NewTestDB() require.NoError(t, err) defer func() { sqlDB, _ := db.DB() sqlDB.Close() }() // Setup router router := SetupRouter(db) // Create request body with missing value field payload := map[string]interface{}{ \"id\": 100, } body, _ := json.Marshal(payload) // Create request req, _ := http.NewRequest(http.MethodPost, \"/hello\", bytes.NewBuffer(body)) req.Header.Set(\"Content-Type\", \"application/json\") w := httptest.NewRecorder() // Perform request router.ServeHTTP(w, req) // Assert response assert.Equal(t, http.StatusBadRequest, w.Code) } 実行結果は以下の通り。 期待通り400エラーが返ったことをアサートできた。 go test -run TestCreateTest_MissingFields [GIN] 2025/10/15 - 17:36:49 | 400 | 139.709µs | | POST \"/hello\" PASS ok gin_txdb 0.501s txdbが正しくトランザクションを分離していることのテスト Claude Code (Sonnet4.5) が (指示していないのに) 自動的にこのテストを作成してくれた。 お勉強を兼ねたテストプロジェクトであることを伝えたために、気を利かせてくれた感がある。 以下をテストする。 トランザクション内での一貫性 (トランザクション内で作成したデータを同一トランザクション内で観察できる) トランザクション間の分離 (別のトランザクションで作成したデータを観察できない。テストは独立している) 自動ロールバックの動作確認 (txdbがClose()時に自動的にロールバックを実行している。DBは初期状態に戻る) あくまで、一貫性、分離、ロールバックの一例を見せてもらうだけなのだが、 こういうことをやりたいのだな、という背景を理解できたのでお勉強としては十分。 func TestTransactionIsolation(t *testing.T) { // This test demonstrates that each test gets isolated transactions t.Run(\"Test1_CreateRecord\", func(t *testing.T) { db, err := testhelper.NewTestDB() require.NoError(t, err) defer func() { sqlDB, _ := db.DB() sqlDB.Close() }() router := SetupRouter(db) // Create a new record payload := map[string]interface{}{ \"id\": 200, \"value\": \"test200\", } body, _ := json.Marshal(payload) req, _ := http.NewRequest(http.MethodPost, \"/hello\", bytes.NewBuffer(body)) req.Header.Set(\"Content-Type\", \"application/json\") w := httptest.NewRecorder() router.ServeHTTP(w, req) assert.Equal(t, http.StatusCreated, w.Code) // Verify it exists in this transaction req2, _ := http.NewRequest(http.MethodGet, \"/hello/200\", nil) w2 := httptest.NewRecorder() router.ServeHTTP(w2, req2) assert.Equal(t, http.StatusOK, w2.Code) }) t.Run(\"Test2_VerifyRollback\", func(t *testing.T) { db, err := testhelper.NewTestDB() require.NoError(t, err) defer func() { sqlDB, _ := db.DB() sqlDB.Close() }() router := SetupRouter(db) // The record created in Test1 should not exist (rolled back) req, _ := http.NewRequest(http.MethodGet, \"/hello/200\", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) assert.Equal(t, http.StatusNotFound, w.Code) // Should still have only 2 original records req2, _ := http.NewRequest(http.MethodGet, \"/hello\", nil) w2 := httptest.NewRecorder() router.ServeHTTP(w2, req2) var response []map[string]interface{} json.Unmarshal(w2.Body.Bytes(), &response) assert.Len(t, response, 2) }) } まとめ go-txdbを使うことで、テストケース毎にDBを分離できることを確認した。 あればかなり便利だと思う。

gorm互換の型安全なORMであるgenでCRUD APIを試作

GolangとGinでAPIを書くために必要な要素技術を学習中、 ORMは何が良いか考えてみたが、gormは無いなぁと思うに至った。 ORマッパーがどの程度の抽象化を担うべきか、については答えはないと思うが、 Webアプリケーションのシナリオで出てくるテーブル構造と関係程度は完全にSQLを排除して欲しい。 SQLを排除することで可読性が向上するし、静的型付けによる恩恵を得られる。 Genには以下のような特徴がある。 型安全: コンパイル時にエラー検出 自動補完: IDEでメソッドとフィールドが補完される クエリビルダー: Where(q.Product.Name.Like(\"%...\"))のような直感的なAPI GORM互換: 既存のGORMモデルをそのまま使用可能 なぜ\"Gen\"なのかは、ビルド時にGolangコードから静的に(ビルド前に)オブジェクトにアクセスする ために必要なGoオブジェクトを生成する、という仕組みから来ているのではないかと思う。 [clink implicit=\"false\" url=\"https://gorm.io/gen/index.html\" imgurl=\"https://gorm.io/gorm.svg\" title=\"Gen Guides\" excerpt=\"GEN: Friendly & Safer GORM powered by Code Generation.Idiomatic & Reusable API from Dynamic Raw SQL.100% Type-safe DAO API without interface{}.Database To Struct follows GORM conventions.GORM under the hood, supports all features, plugins, DBMS that GORM supports.\"] [arst_toc tag=\"h4\"] 環境構築 サクッとClaudeで環境を作った。実際に商用環境を作るとしたら必要な理解の度合いは上がるだろうが、 試してみるまでの時間が無駄にかかって勿体無いのと、Claudeに入口を教わるのは悪くない。 以下の構成で、Golang+GinにCRUDルートを設定しgenを介してDBアクセスできる。 models以下にテーブルと対応する型定義された構造体が格納される。 また、query以下にGormレベルの操作をGen(Golang)レベルに抽象化する自動生成コードが格納される。 query以下を読むと、GenがGormのラッパーであることが良くわかる。 $ tree . -n 2 . ├── cmd │   └── generate │   └── main.go # マイグレーション ├── config │   └── database.go # DB接続設定 ├── database │   └── database.go # Conenct(), Close(), GetDB()など ├── docker-compose.yml # Golangアプリケーション(8080), PostgreSQL(5432) ├── Dockerfile ├── go.mod ├── go.sum ├── handlers │   └── product.go ├── main.go # CRUD APIのルーティング ├── models │   └── product.go # テーブル->モデル ├── query │   ├── gen.go # モデルを操作するラッパー │   └── products.gen.go # SQLレベルのモデル操作をGolangレベルに抽象化するためのIF └── README.md CRUDルート 早速、CRUD APIのルートを作っていく。Claudeにお任せしたところ商品(Product)のCRUD APIが出来た。 その位置にMigrate置くの本当に良いの? という感があるが、本題はそこではないので省略。 package main import ( \"log\" \"github.com/gin-gonic/gin\" \"github.com/gin-gonic/gin/binding\" \"github.com/ikuty/golang-gin/database\" \"github.com/ikuty/golang-gin/handlers\" \"github.com/ikuty/golang-gin/models\" \"github.com/ikuty/golang-gin/query\" ) func main() { // データベース接続 if err := database.Connect(); err != nil { log.Fatalf(\"Failed to connect to database: %v\", err) } defer database.Close() // マイグレーション実行 db := database.GetDB() if err := db.AutoMigrate(&models.Product{}); err != nil { log.Fatalf(\"Failed to migrate database: %v\", err) } // Gen初期化 query.SetDefault(db) // Ginエンジンの初期化 r := gin.Default() // 8. GORM + PostgreSQL - CRUD操作 r.GET(\"/api/products\", handlers.GetProductsHandler) // 全商品取得 r.GET(\"/api/products/:id\", handlers.GetProductHandler) // 商品詳細取得 r.POST(\"/api/products\", handlers.CreateProductHandler) // 商品作成 r.PUT(\"/api/products/:id\", handlers.UpdateProductHandler) // 商品更新 r.DELETE(\"/api/products/:id\", handlers.DeleteProductHandler) // 商品削除 r.GET(\"/api/products/search\", handlers.SearchProductsHandler) // 商品検索 // サーバー起動 r.Run(\":8080\") } モデル さて、モデル定義(=テーブル構造)はどうなっているかというと、以下の通り。 フィールドの物理型をGenを介してGolangで厳密で管理できるのは動的型付け言語にはない利点。 package models import ( \"time\" \"gorm.io/gorm\" ) // Product は商品モデル type Product struct { ID uint `gorm:\"primarykey\" json:\"id\"` Name string `gorm:\"size:100;not null\" json:\"name\" binding:\"required\"` Description string `gorm:\"size:500\" json:\"description\"` Price float64 `gorm:\"not null\" json:\"price\" binding:\"required,gt=0\"` Stock int `gorm:\"default:0\" json:\"stock\"` Category string `gorm:\"size:50\" json:\"category\"` CreatedAt time.Time `json:\"created_at\"` UpdatedAt time.Time `json:\"updated_at\"` DeletedAt gorm.DeletedAt `gorm:\"index\" json:\"-\"` } // TableName はテーブル名を指定 func (Product) TableName() string { return \"products\" } ハンドラ(商品詳細取得) 素晴らしい。説明が不要なくらいDBアクセスが抽象化されている。 ただ、依存性注入があるEloquentと比べるとロジックと関係ない冗長な処理が残っている。 db,q,Contextは裏側に隠して欲しいという思いはあるものの、これでも良いかとも思う。 Find()はGenにより自動生成される。interfaceが用意されビルド時に全て解決される。 なお、VSCodeなどで補完が効く、というのは、例えば JetBrains環境であれば、 動的型付け言語であってもほぼ実現されているので、それほど実利があるメリットではない。 package handlers import ( \"net/http\" \"strconv\" \"github.com/gin-gonic/gin\" \"github.com/ikuty/golang-gin/database\" \"github.com/ikuty/golang-gin/models\" \"github.com/ikuty/golang-gin/query\" ) // GetProductsHandler は全商品を取得 func GetProductsHandler(c *gin.Context) { db := database.GetDB() q := query.Use(db) products, err := q.Product.WithContext(c.Request.Context()).Find() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ \"error\": \"Failed to fetch products\", }) return } c.JSON(http.StatusOK, gin.H{ \"data\": products, \"count\": len(products), }) } ハンドラ(指定の商品を取得) バリデータを介さず自力でバリデーション(IDがUintか)を行っている。 Productに対してWhereで条件指定し(Order By Ascした後に)先頭のオブジェクトを取得している。 もはや他に説明が必要ないくらい抽象化されていて良い。 // GetProductHandler は指定IDの商品を取得 func GetProductHandler(c *gin.Context) { id := c.Param(\"id\") idUint, err := strconv.ParseUint(id, 10, 32) if err != nil { c.JSON(http.StatusBadRequest, gin.H{ \"error\": \"Invalid ID\", }) return } db := database.GetDB() q := query.Use(db) product, err := q.Product.WithContext(c.Request.Context()).Where(q.Product.ID.Eq(uint(idUint))).First() if err != nil { c.JSON(http.StatusNotFound, gin.H{ \"error\": \"Product not found\", }) return } c.JSON(http.StatusOK, gin.H{ \"data\": product, }) } ハンドラ(商品作成) 次はCreate。モデルオブジェクトを空から生成し入力値をバインドして整形した後に、 Create()に渡している。Create()の内部はGormレベルの(低レイヤの)コードが動く。 // CreateProductHandler は商品を作成 func CreateProductHandler(c *gin.Context) { var product models.Product if err := c.ShouldBindJSON(&product); err != nil { c.JSON(http.StatusBadRequest, gin.H{ \"error\": \"Invalid request\", \"details\": err.Error(), }) return } db := database.GetDB() q := query.Use(db) if err := q.Product.WithContext(c.Request.Context()).Create(&product); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ \"error\": \"Failed to create product\", }) return } c.JSON(http.StatusCreated, gin.H{ \"message\": \"Product created successfully\", \"data\": product, }) } ハンドラ(商品更新) 基本的にはCreate()と同じ。空モデルに入力値をバインドしUpdate()に渡している。 実行後に更新対象のオブジェクトを取得しているがEloquentは確か更新の戻りがオブジェクトだった。 // UpdateProductHandler は商品を更新 func UpdateProductHandler(c *gin.Context) { id := c.Param(\"id\") idUint, err := strconv.ParseUint(id, 10, 32) if err != nil { c.JSON(http.StatusBadRequest, gin.H{ \"error\": \"Invalid ID\", }) return } db := database.GetDB() q := query.Use(db) ctx := c.Request.Context() // 既存の商品を取得 product, err := q.Product.WithContext(ctx).Where(q.Product.ID.Eq(uint(idUint))).First() if err != nil { c.JSON(http.StatusNotFound, gin.H{ \"error\": \"Product not found\", }) return } // 更新データをバインド var updateData models.Product if err := c.ShouldBindJSON(&updateData); err != nil { c.JSON(http.StatusBadRequest, gin.H{ \"error\": \"Invalid request\", \"details\": err.Error(), }) return } // 更新実行 _, err = q.Product.WithContext(ctx).Where(q.Product.ID.Eq(uint(idUint))).Updates(&updateData) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ \"error\": \"Failed to update product\", }) return } // 更新後のデータを取得 product, _ = q.Product.WithContext(ctx).Where(q.Product.ID.Eq(uint(idUint))).First() c.JSON(http.StatusOK, gin.H{ \"message\": \"Product updated successfully\", \"data\": product, }) } ハンドラ(論理削除) DeletedAtフィールドがNULLの場合、そのレコードはアクティブ。非Nullなら論理削除済み。 Unscoped()を介さずDelete()した場合(つまりデフォルトでは)論理削除となる。 DeletedAtは他のAPIから透過的に扱われる。論理削除状態かどうかは把握しなくて良い。 DeletedAtはデフォルトでは*time.Time型だが、のデータ形式の対応も可能。 // DeleteProductHandler は商品を削除(ソフトデリート) func DeleteProductHandler(c *gin.Context) { id := c.Param(\"id\") idUint, err := strconv.ParseUint(id, 10, 32) if err != nil { c.JSON(http.StatusBadRequest, gin.H{ \"error\": \"Invalid ID\", }) return } db := database.GetDB() q := query.Use(db) // ソフトデリート実行 _, err = q.Product.WithContext(c.Request.Context()).Where(q.Product.ID.Eq(uint(idUint))).Delete() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ \"error\": \"Failed to delete product\", }) return } c.JSON(http.StatusOK, gin.H{ \"message\": \"Product deleted successfully\", }) } ハンドラ(商品検索) Where句を複数記述する場合など、手続き的に条件用のオブジェクトを足していける。 一見、productQueryを上から上書きしているように見えるが、Genのクエリビルダーはimmutableパターン として振る舞い、都度実行によりWhereの戻りとなるオブジェクトが累積していく動作となる。 // SearchProductsHandler は商品を検索 func SearchProductsHandler(c *gin.Context) { db := database.GetDB() q := query.Use(db) ctx := c.Request.Context() // クエリパラメータを取得 name := c.Query(\"name\") category := c.Query(\"category\") minPrice := c.Query(\"min_price\") maxPrice := c.Query(\"max_price\") // クエリビルダー productQuery := q.Product.WithContext(ctx) if name != \"\" { productQuery = productQuery.Where(q.Product.Name.Like(\"%\" + name + \"%\")) } if category != \"\" { productQuery = productQuery.Where(q.Product.Category.Eq(category)) } if minPrice != \"\" { if price, err := strconv.ParseFloat(minPrice, 64); err == nil { productQuery = productQuery.Where(q.Product.Price.Gte(price)) } } if maxPrice != \"\" { if price, err := strconv.ParseFloat(maxPrice, 64); err == nil { productQuery = productQuery.Where(q.Product.Price.Lte(price)) } } // 検索実行 products, err := productQuery.Find() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ \"error\": \"Failed to search products\", }) return } c.JSON(http.StatusOK, gin.H{ \"data\": products, \"count\": len(products), }) } 変換後のクエリを見てみる。 $ http://localhost:8080/api/products/search?name=Test&category=Electronics&min_price=1400&max_price=1600 SELECT * FROM \"products\" WHERE \"products\".\"name\" LIKE \'%Test%\' AND \"products\".\"category\" = \'Electronics\' AND \"products\".\"price\" >= 1400 AND \"products\".\"price\" <= 1600 AND "products"."deleted_at" IS NULL まとめ GolangのORMであるGormをラップしたGenを使って、CRUDを行うAPIをGinで書いて動かしてみた。 確かにGormレベル(SQLレベル)の記述が不要であることを確認した。 (まだ見ていないが)テーブルをJOINしていった先にGormを素で触らないといけない場面は発生するだろうが、 多くのシナリオでGenだけで行けるのであれば、Genを導入するメリットとなるのではないだろうか。

やりなおし統計

Fellegi-Sunterモデルに基づく確率的名寄せパッケージ Splinkを試してみる

[mathjax] Record Linkage、Entity Recognitionなど、いわゆる「名寄せ」問題には、割とさまざまな解決策が 提案されている。その1つに確率論的な数学的背景を持つFellegi-Sunterモデルを実装したSplinkがある。 数学的背景について咀嚼して説明してくれているサイトはあまりなく、Splinkの公式ドキュメントが唯一 に近いソースだろうと思う。 今回は5万人の人物データを使って重複を除去するSplinkの主要な機能のサンプルを追ってみた。 数式を読んだところで実りは薄い、ということもあり、なるべく数式を読まずに読み進めたが、 モデルのパラメタ、予測結果を解釈する際にどうしても数式を読まないとわからない箇所があった。 特にベイズ統計について理解がかなり浅くて、正直なところ理解しきれてはいない。 Splinkの公式ドキュメントを読み進めていった際に気付いたメモを書いてみることにした。 [arst_toc tag=\"h4\"] 環境構築・準備 環境構築の実装部分は主な学習要素ではないのでClaude Codeでサクッと作ってみた。 まずはPlan modeで計画を立ててもらう。これぐらいだとほぼ言い値で使える。 DuckDB、Splinkを使う最小構成の環境を作ります。環境はDocker composeで構築します。計画を立ててください。 ⏺ DuckDB、Splinkを使うDocker Compose環境の計画を完了しました。 計画概要 1. 設計文書: .tmp/design.mdで要件、アーキテクチャ、技術仕様を定義 2. タスク管理: .tmp/task.mdで詳細な実装手順を3フェーズに分けて整理 主要コンポーネント - Jupyter Lab環境: DuckDB、Splink、必要なライブラリがプリインストール - データボリューム: サンプルデータ、notebook、出力用のディレクトリ - 最小構成: 開発・学習用途に特化した軽量環境 次のステップ 実装を開始する場合は、Phase 1から順番に進めます: 1. Docker設定ファイル作成 2. ディレクトリ構造構築 3. サンプル実装 4. テストと検証 計画の詳細は.tmp/design.mdと.tmp/task.mdをご確認ください。実装を開始しますか? 実際にファイルを作ってもらうと以下のような構成となる。 得られる中間成果物である docker-compose.yml、Dockerfile は本記事では省略する。 . ├── data ├── docker-compose.yml ├── Dockerfile ├── notebooks ├── output └── requirements.txt 普通に http://localhost:8888 で JupyterLab が開く。 使用するサンプルデータ 5万人の人物データを使って名寄せを行うサンプル。おそらくSplinkの用途として最初に思いつくやつ。 Splinkにデータをロードする前に必要なデータクリーニング手順について説明がある。 公式によると、まずは行に一意のIDを割り当てる必要がある。 データセット内で一意となるIDであって、重複除去した後のエンティティを識別するIDのことではない。 [clink implicit=\"false\" url=\"https://moj-analytical-services.github.io/splink/demos/tutorials/01_Prerequisites.html\" imgurl=\"https://user-images.githubusercontent.com/7570107/85285114-3969ac00-b488-11ea-88ff-5fca1b34af1f.png\" title=\"Data Prerequisites\" excerpt=\"Splink では、リンクする前にデータをクリーンアップし、行に一意の ID を割り当てる必要があります。このセクションでは、Splink にデータをロードする前に必要な追加のデータクリーニング手順について説明します。\"] 使用するサンプルデータは以下の通り。 from splink import splink_datasets df = splink_datasets.historical_50k df.head() データの分布を可視化 splink.exploratoryのprofile_columnsを使って分布を可視化してみる。 from splink import DuckDBAPI from splink.exploratory import profile_columns db_api = DuckDBAPI() profile_columns(df, db_api, column_expressions=[\"first_name\", \"substr(surname,1,2)\"]) 同じ姓・名の人が大量にいることがわかる。 ブロッキングとブロッキングルールの評価 テーブル内のレコードが他のレコードと「同一かどうか」を調べるためには、 基本的には、他のすべてのレコードとの何らかの比較操作を行うこととなる。 全てのレコードについて全てのカラム同士を比較したいのなら、 対象のテーブルをCROSS JOINした結果、各カラム同士を比較することとなる。 SELECT ... FROM input_tables as l CROSS JOIN input_tables as r あるカラムが条件に合わなければ、もうその先は見ても意味がない、 というケースは多い。例えば、まず first_name 、surname が同じでなければ、 その先の比較を行わない、というのはあり得る。 SELECT ... FROM input_tables as l INNER JOIN input_tables as r ON l.first_name = r.first_name AND l.surname = r.surname このような考え方をブロッキング、ON句の条件をブロッキングルールと言う。 ただ、これだと性と名が完全一致していないレコードが残らない。 そこで、ブロッキングルールを複数定義し、いずれかが真であれば残すことができる。 ここでポイントなのが、ブロッキングルールを複数定義したとき、 それぞれのブロッキングルールで重複して選ばれるレコードが発生した場合、 Splinkが自動的に排除してくれる。 このため、ブロッキングルールを重ねがけすると、最終的に残るレコード数は一致する。 ただ、順番により、同じルールで残るレコード数は変化する。 逆に言うと、ブロッキングルールを足すことで、重複除去後のOR条件が増えていく。 積算グラフにして、ブロッキングルールとその順番の効果を見ることができる。 from splink import DuckDBAPI, block_on from splink.blocking_analysis import ( cumulative_comparisons_to_be_scored_from_blocking_rules_chart, ) blocking_rules = [ block_on(\"substr(first_name,1,3)\", \"substr(surname,1,4)\"), block_on(\"surname\", \"dob\"), block_on(\"first_name\", \"dob\"), block_on(\"postcode_fake\", \"first_name\"), block_on(\"postcode_fake\", \"surname\"), block_on(\"dob\", \"birth_place\"), block_on(\"substr(postcode_fake,1,3)\", \"dob\"), block_on(\"substr(postcode_fake,1,3)\", \"first_name\"), block_on(\"substr(postcode_fake,1,3)\", \"surname\"), block_on(\"substr(first_name,1,2)\", \"substr(surname,1,2)\", \"substr(dob,1,4)\"), ] db_api = DuckDBAPI() cumulative_comparisons_to_be_scored_from_blocking_rules_chart( table_or_tables=df, blocking_rules=blocking_rules, db_api=db_api, link_type=\"dedupe_only\", ) 積算グラフは以下の通り。積み上がっている数値は「比較の数」。 要は、論理和で条件を足していって、次第に緩和されている様子がわかる。 DuckDBでは比較の数を2,000万件以内、Athena,Sparkでは1億件以内を目安にせよとのこと。 比較の定義 Splinkは Fellegi-Sunter model モデル (というかフレームワーク) に基づいている。 https://moj-analytical-services.github.io/splink/topic_guides/theory/fellegi_sunter.html 各カラムの同士をカラムの特性に応じた距離を使って比較し、重みを計算していく。 各カラムの比較に使うためのメソッドが予め用意されているので、特性に応じて選んでいく。 以下では、first_name, sur_name に ForenameSurnameComparison が使われている。 dobにDateOfBirthComparison、birth_place、ocupationにExactMatchが使われている。 import splink.comparison_library as cl from splink import Linker, SettingsCreator settings = SettingsCreator( link_type=\"dedupe_only\", blocking_rules_to_generate_predictions=blocking_rules, comparisons=[ cl.ForenameSurnameComparison( \"first_name\", \"surname\", forename_surname_concat_col_name=\"first_name_surname_concat\", ), cl.DateOfBirthComparison( \"dob\", input_is_string=True ), cl.PostcodeComparison(\"postcode_fake\"), cl.ExactMatch(\"birth_place\").configure(term_frequency_adjustments=True), cl.ExactMatch(\"occupation\").configure(term_frequency_adjustments=True), ], retain_intermediate_calculation_columns=True, ) # Needed to apply term frequencies to first+surname comparison df[\"first_name_surname_concat\"] = df[\"first_name\"] + \" \" + df[\"surname\"] linker = Linker(df, settings, db_api=db_api) ComparisonとComparison Level ここでSplinkツール内の比較の概念の説明。以下の通り概念に名前がついている。 Data Linking Model ├─-- Comparison: Date of birth │ ├─-- ComparisonLevel: Exact match │ ├─-- ComparisonLevel: One character difference │ ├─-- ComparisonLevel: All other ├─-- Comparison: First name │ ├─-- ComparisonLevel: Exact match on first_name │ ├─-- ComparisonLevel: first_names have JaroWinklerSimilarity > 0.95 │ ├─-- ComparisonLevel: first_names have JaroWinklerSimilarity > 0.8 │ ├─-- ComparisonLevel: All other モデルのパラメタ推定 モデルの実行に必要なパラメタは以下の3つ。Splinkを用いてパラメタを得る。 ちなみに u は \"\'U\'nmatch\"、m は \"\'M\'atch\"。背後の数式の説明で現れる。 No パラメタ 説明 1 無作為に選んだレコードが一致する確率 入力データからランダムに取得した2つのレコードが一致する確率 (通常は非常に小さい数値) 2 u値(u確率) 実際には一致しないレコードの中で各 ComparisonLevel に該当するレコードの割合。具体的には、レコード同士が同じエンティティを表すにも関わらず値が異なる確率。例えば、同じ人なのにレコードによって生年月日が違う確率。これは端的には「データ品質」を表す。名前であればタイプミス、別名、ニックネーム、ミドルネーム、結婚後の姓など。 3 m値(m確率) 実際に一致するレコードの中で各 ComparisonLevel に該当するレコードの割合。具体的には、レコード同士が異なるエンティティを表すにも関わらず値が同じである確率。例えば別人なのにレコードによって性・名が同じ確率 (同姓同名)。性別は男か女かしかないので別人でも50%の確率で一致してしまう。 無作為に選んだレコードが一致する確率 入力データからランダムに抽出した2つのレコードが一致する確率を求める。 値は0.000136。すべての可能なレコードのペア比較のうち7,362.31組に1組が一致すると予想される。 合計1,279,041,753組の比較が可能なため、一致するペアは合計で約173,728.33組になると予想される、 とのこと。 linker.training.estimate_probability_two_random_records_match( [ block_on(\"first_name\", \"surname\", \"dob\"), block_on(\"substr(first_name,1,2)\", \"surname\", \"substr(postcode_fake,1,2)\"), block_on(\"dob\", \"postcode_fake\"), ], recall=0.6, ) > Probability two random records match is estimated to be 0.000136. > This means that amongst all possible pairwise record comparisons, > one in 7,362.31 are expected to match. > With 1,279,041,753 total possible comparisons, > we expect a total of around 173,728.33 matching pairs u確率の推定 実際には一致しないレコードの中でComparisonの評価結果がPositiveである確率。 基本、無作為に抽出したレコードは一致しないため、「無作為に抽出したレコード」を 「実際には一致しないレコード」として扱える、という点がミソ。 probability_two_random_records_match によって得られた値を使ってu確率を求める。 estimate_u_using_random_sampling によって、ラベルなし、つまり教師なしでu確率を得られる。 レコードのペアをランダムでサンプルして上で定義したComparisonを評価する。 ランダムサンプルなので大量の不一致が発生するが、各Comparisonにおける不一致の分布を得ている。 これは、例えば性別について、50%が一致、50%が不一致である、という分布を得ている。 一方、例えば生年月日について、一致する確率は 1%、1 文字の違いがある確率は 3%、 その他はすべて 96% の確率で発生する、という分布を得ている。 linker.training.estimate_u_using_random_sampling(max_pairs=5e6) > ----- Estimating u probabilities using random sampling ----- > > Estimated u probabilities using random sampling > > Your model is not yet fully trained. Missing estimates for: > - first_name_surname (no m values are trained). > - dob (no m values are trained). > - postcode_fake (no m values are trained). > - birth_place (no m values are trained). > - occupation (no m values are trained). m確率の推定 「実際に一致するレコード」の中で、Comparisonの評価がNegativeになる確率。 そもそも、このモデルを使って名寄せ、つまり「一致するレコード」を見つけたいのだから、 モデルを作るために「実際に一致するレコード」を計算しなければならないのは矛盾では..となる。 無作為抽出結果から求められるu確率とは異なり、m確率を求めるのは難しい。 もしラベル付けされた「一致するレコード」、つまり教師データセットがあるのであれば、 そのデータセットを使ってm確率を求められる。 例えば、日本人全員にマイナンバーが振られて、全てのレコードにマイナンバーが振られている、 というアナザーワールドがあるのであれば、マイナンバーを使ってm確率を推定する。(どういう状況??) ラベル付けされたデータがないのであれば、EMアルゴリズムでm確率を求めることになっている。 EMアルゴリズムは反復的な手法で、メモリや収束速度の点でペア数を減らす必要があり、 例ではブロッキングルールを設定している。 以下のケースでは、first_nameとsurnameをブロッキングルールとしている。 つまり、first_name, surnameが完全に一致するレコードについてペア比較を行う。 この仮定を設定したため、first_name, surname (first_name_surname) のパラメタを推定できない。 training_blocking_rule = block_on(\"first_name\", \"surname\") training_session_names = ( linker.training.estimate_parameters_using_expectation_maximisation( training_blocking_rule, estimate_without_term_frequencies=True ) ) > ----- Starting EM training session ----- > > Estimating the m probabilities of the model by blocking on: > (l.\"first_name\" = r.\"first_name\") AND (l.\"surname\" = r.\"surname\") > > Parameter estimates will be made for the following comparison(s): > - dob > - postcode_fake > - birth_place > - occupation > > Parameter estimates cannot be made for the following comparison(s) since they are used in the blocking rules: > - first_name_surname > > Iteration 1: Largest change in params was 0.248 in probability_two_random_records_match > Iteration 2: Largest change in params was 0.0929 in probability_two_random_records_match > Iteration 3: Largest change in params was -0.0237 in the m_probability of birth_place, level `Exact match on > birth_place` > Iteration 4: Largest change in params was 0.00961 in the m_probability of birth_place, level `All other >comparisons` > Iteration 5: Largest change in params was -0.00457 in the m_probability of birth_place, level `Exact match on birth_place` > Iteration 6: Largest change in params was -0.00256 in the m_probability of birth_place, level `Exact match on birth_place` > Iteration 7: Largest change in params was 0.00171 in the m_probability of dob, level `Abs date difference Iteration 8: Largest change in params was 0.00115 in the m_probability of dob, level `Abs date difference Iteration 9: Largest change in params was 0.000759 in the m_probability of dob, level `Abs date difference Iteration 10: Largest change in params was 0.000498 in the m_probability of dob, level `Abs date difference Iteration 11: Largest change in params was 0.000326 in the m_probability of dob, level `Abs date difference Iteration 12: Largest change in params was 0.000213 in the m_probability of dob, level `Abs date difference Iteration 13: Largest change in params was 0.000139 in the m_probability of dob, level `Abs date difference Iteration 14: Largest change in params was 9.04e-05 in the m_probability of dob, level `Abs date difference <= 10 year` 同様にdobをブロッキングルールに設定して実行すると、dob以外の列についてパラメタを推定できる。 training_blocking_rule = block_on(\"dob\") training_session_dob = ( linker.training.estimate_parameters_using_expectation_maximisation( training_blocking_rule, estimate_without_term_frequencies=True ) ) > ----- Starting EM training session ----- > > Estimating the m probabilities of the model by blocking on: > l.\"dob\" = r.\"dob\" > > Parameter estimates will be made for the following comparison(s): > - first_name_surname > - postcode_fake > - birth_place > - occupation > > Parameter estimates cannot be made for the following comparison(s) since they are used in the blocking rules: > - dob > > Iteration 1: Largest change in params was -0.474 in the m_probability of first_name_surname, level `Exact match on first_name_surname_concat` > Iteration 2: Largest change in params was 0.052 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 3: Largest change in params was 0.0174 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 4: Largest change in params was 0.00532 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 5: Largest change in params was 0.00165 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 6: Largest change in params was 0.00052 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 7: Largest change in params was 0.000165 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 8: Largest change in params was 5.29e-05 in the m_probability of first_name_surname, level `All other comparisons` > > EM converged after 8 iterations > > Your model is not yet fully trained. Missing estimates for: > - first_name_surname (some u values are not trained). モデルパラメタの可視化 m確率、u確率の可視化。 マッチウェイトの可視化。マッチウェイトは (log_2 (m / u))で計算される。 linker.visualisations.match_weights_chart() モデルの保存と読み込み 以下でモデルを保存できる。 settings = linker.misc.save_model_to_json( \"./saved_model_from_demo.json\", overwrite=True ) 以下で保存したモデルを読み込める。 import json settings = json.load( open(\'./saved_model_from_demo.json\', \'r\') ) リンクするのに十分な情報が含まれていないレコード 「John Smith」のみを含み、他のすべてのフィールドがnullであるレコードは、 他のレコードにリンクされている可能性もあるが、潜在的なリンクを明確にするには十分な情報がない。 以下により可視化できる。 linker.evaluation.unlinkables_chart() 横軸は「マッチウェイトの閾値」。縦軸は「リンクするのに十分な情報が含まれないレコード」の割合。 マッチウェイト閾値=6.11ぐらいのところを見ると、入力データセットのレコードの約1.3%が リンクできないことが示唆される。 訓練済みモデルを使って未知データのマッチウェイトを予測 上で構築した推定モデルを使用し、どのペア比較が一致するかを予測する。 内部的には以下を行うとのこと。 blocking_rules_to_generate_predictionsの少なくとも1つと一致するペア比較を生成 Comparisonで指定されたルールを使用して、入力データの類似性を評価 推定された一致重みを使用し、要求に応じて用語頻度調整を適用して、最終的な一致重みと一致確率スコアを生成 df_predictions = linker.inference.predict(threshold_match_probability=0.2) df_predictions.as_pandas_dataframe(limit=1) > Blocking time: 0.88 seconds > Predict time: 1.91 seconds > > -- WARNING -- > You have called predict(), but there are some parameter estimates which have neither been estimated or > specified in your settings dictionary. To produce predictions the following untrained trained parameters will > use default values. > Comparison: \'first_name_surname\': > u values not fully trained records_to_plot = df_e.to_dict(orient=\"records\") linker.visualisations.waterfall_chart(records_to_plot, filter_nulls=False) predictしたマッチウェイトの可視化、数式との照合 predictしたマッチウェイトは、ウォーターフォール図で可視化できる。 マッチウェイトは、モデル内の各特徴量によって一致の証拠がどの程度提供されるかを示す中心的な指標。 (lambda)は無作為抽出した2つのレコードが一致する確率。(K=m/u)はベイズ因子。 begin{align} M &= log_2 ( frac{lambda}{1-lambda} ) + log_2 K \\ &= log_2 ( frac{lambda}{1-lambda} ) + log_2 m - log_2 u end{align} 異なる列の比較が互いに独立しているという仮定を置いていて、 2つのレコードのベイズ係数が各列比較のベイズ係数の積として扱う。 begin{eqnarray} K_{feature} = K_{first_name_surname} + K_{dob} + K_{postcode_fake} + K_{birth_place} + K_{occupation} + cdots end{eqnarray} マッチウェイトは以下の和。 begin{eqnarray} M_{observe} = M_{prior} + M_{feature} end{eqnarray} ここで begin{align} M_{prior} &= log_2 (frac{lambda}{1-lambda}) \\ M_{feature} &= M_{first_name_surname} + M_{dob} + M_{postcode_fake} + M_{birth_place} + M_{occupation} + cdots end{align} 以下のように書き換える。 begin{align} M_{observe} &= log_2 (frac{lambda}{1-lambda}) + sum_i^{feature} log_2 (frac{m_i}{u_i}) \\ &= log_2 (frac{lambda}{1-lambda}) + log_2 (prod_i^{feature} (frac{m_i}{u_i}) ) end{align} ウォーターフォール図の一番左、赤いバーは(M_{prior} = log_2 (frac{lambda}{1-lambda}))。 特徴に関する追加の知識が考慮されていない場合のマッチウェイト。 横に並んでいる薄い緑のバーは (M_{first_name_surname} + M_{dob} + M_{postcode_fake} + M_{birth_place} + M_{occupation} + cdots)。 各特徴量のマッチウェイト。 一番右の濃い緑のバーは2つのレコードの合計マッチウェイト。 begin{align} M_{feature} &= M_{first_name_surname} + M_{dob} + M_{postcode_fake} + M_{birth_place} + M_{occupation} + cdots \\ &= 8.50w end{align} まとめ 長くなったのでいったん終了。この記事では教師なし確率的名寄せパッケージSplinkを使用してモデルを作ってみた。 次の記事では、作ったモデルを使用して実際に名寄せをしてみる。 途中、DuckDBが楽しいことに気づいたので、DuckDBだけで何個か記事にしてみようと思う。

分散と標準偏差を計算しやすくする

[mathjax] 分散と標準偏差を計算しやすく変形できる。 いちいち偏差(x_i-bar{x})を計算しておかなくても、2乗和(x_i^2)と平均(bar{x})がわかっていればOK。 begin{eqnarray} s^2 &=& frac{1}{n} sum_{i=1}^n (x_i - bar{x})^2 \\ &=& frac{1}{n} sum_{i=1}^n ( x_i^2 -2 x_i bar{x} + bar{x}^2 ) \\ &=& frac{1}{n} ( sum_{i=1}^n x_i^2 -2 bar{x} sum_{i=1}^n x_i + bar{x}^2 sum_{i=1}^n 1) \\ &=& frac{1}{n} ( sum_{i=1}^n x_i^2 -2 n bar{x}^2 + nbar{x}^2 ) \\ &=& frac{1}{n} ( sum_{i=1}^n x_i^2 - nbar{x}^2 ) \\ &=& frac{1}{n} sum_{i=1}^n x_i^2 - bar{x}^2 \\ s &=& sqrt{frac{1}{n} sum_{i=1}^n x_i^2 - bar{x}^2 } end{eqnarray} 以下みたい使える。平均と標準偏差と2乗和の関係。 begin{eqnarray} sum_{i=1}^n (x_i - bar{x})^2 &=& sum_{i=1}^n x_i^2 - nbar{x}^2 \\ ns^2 &=& sum_{i=1}^n x_i^2 - nbar{x}^2 \\ sum_{i=1}^n x_i^2 &=& n(s^2 + bar{x} ) end{eqnarray}

標本調査に必要なサンプル数の下限を与える2次関数

[mathjax] 2項分布に従う母集団の母平均を推測するために有意水準を設定して95%信頼区間を求めてみた。 母平均のあたりがついていない状況だとやりにくい。 [clink url=\"https://ikuty.com/2019/01/11/sampling/\"] (hat{p})がどんな値であっても下限は(hat{p})の関数で抑えられると思ったので、 気になって(hat{p})を変数のまま残すとどうなるかやってみた。 begin{eqnarray} 1.96sqrt{frac{hat{p}(1-hat{p})}{n}} le 0.05 \\ frac{1.96}{0.05}sqrt{hat{p}(1-hat{p})} le sqrt{n} \\ 39.2^2 hat{p}(1-hat{p}) le n end{eqnarray} 左辺を(f(hat{p}))と置くと (f(hat{p}))は下に凸の2次関数であって、 (frac{d}{dhat{p}}f(hat{p})=0)の時に最大となる。というか(hat{p}=0.5)。 (hat{p}=0.5)であるとすると、これはアンケートを取るときのサンプル数を求める式と同じで、 非常に有名な以下の定数が出てくる。 begin{eqnarray} 1537 * 0.5 (1-0.5) le n \\ 384 le n end{eqnarray} (hat{p})がどんな値であっても、サンプル数を400とれば、 有意水準=5%の95%信頼区間を得られる。 だから、アンケートの(n)数はだいたい400で、となる。 さらに、有意水準を10%にとれば、(n)の下限は100で抑えられる。 なるはやのアンケートなら100、ちゃんとやるには400、というやつがこれ。

深層学習

勾配降下法

[mathjax] 各地点において関数の値を最大にするベクトル((frac{partial f}{partial x_0},frac{partial f}{partial x_1}))を全地点に対して計算したものを勾配とかいう。 ある地点において、このベクトルの方向に向かうことにより最も関数の値を大きくする。 で、今後のために正負を反転して関数の値を最小にするベクトルを考えることにした。 関数の値を小さくする操作を繰り返していけば、いずれ\"最小値\"が見つかるはず。 というモチベを続けるのが勾配降下法。学習率(eta)を使って以下みたいに書ける。。 begin{eqnarray} x_0 = x_0 - eta frac{partial f}{partial x_0} \\ x_1 = x_1 - eta frac{partial f}{partial x_1} end{eqnarray} ということで(f(x_0,x_1)=x_0^2+x_1^2)の最小値を初期値((3.0,4.0))、 学習率(eta=0.1)に設定して計算してみる。 import numpy as np def numerical_gradient(f, x): h = 1e-4 grad = np.zeros_like(x) for idx in range(x.size): tmp_val = x[idx] x[idx] = tmp_val + h fxh1 = f(x) x[idx] = tmp_val - h fxh2 = f(x) grad[idx] = (fxh1 - fxh2) / (2*h) x[idx] = tmp_val return grad def gradient_descent(f, init_x, lr=0.01, step_num=100): x = init_x for i in range(step_num): grad = numerical_gradient(f,x) x -= lr * grad return x def function2(x): return x[0]**2 + x[1]**2 init_x = np.array([-3.0, 4.0]) v = gradient_descent(function2, init_x=init_x, lr=0.1, step_num=100) v # array([-6.11110793e-10, 8.14814391e-10]) ((0,0))に収束した。 ニューラルネットワークの勾配 損失関数を重みパラメータで微分する。以下みたいな感じ。 損失関数の大小を見るとして、例えば(w_{11})以外の重みを固定したとして(w_{11})をわずかに 増やしたときに損失関数の値がどれだけ大きくなるか。 損失関数の値はパラメータ(W)と入力(x)から決まるベクトルだけれども、それぞれ乱数と入力値が設定されている。 begin{eqnarray} W= begin{pmatrix} w_{11} & w_{12} & w_{13} \\ w_{21} & w_{22} & w_{23} end{pmatrix}, frac{partial L}{partial W}= begin{pmatrix} frac{partial L}{partial w_{11}} & frac{partial L}{partial w_{12}} & frac{partial L}{partial w_{13}} \\ frac{partial L}{partial w_{21}} & frac{partial L}{partial w_{22}} & frac{partial L}{partial w_{23}} end{pmatrix} end{eqnarray} 重み(W)が乱数で決まるネットワークがあるとする。このネットワークは入力と重みの積を出力 として返す。出力はSoftmaxを経由するとする。 ネットワークの出力と教師データのクロスエントロピー誤差を誤差として使う。 その前に、数値微分関数を多次元対応する。 普通、配列の次元が(n)個になると(n)重ループが必要になるけれども、 Numpy.nditer()を使うと(n)乗ループを1回のループにまとめることができる。 下のmulti_indexが((0,0),(0,1),(0,2),(1,0),(1,1),(1,2))みたいに イテレータが(n)次のタプルを返す。反復回数はタプルの要素数の直積。 Numpy配列にそのタプルでアクセスすることで晴れて全ての要素にアクセスできる。 def numerical_gradient_md(f, x): h = 1e-4 grad = np.zeros_like(x) it = np.nditer(x, flags=[\'multi_index\'], op_flags=[\'readwrite\']) while not it.finished: idx = it.multi_index tmp_val = x[idx] x[idx] = tmp_val + h fxh1 = f(x) # f(x+h) x[idx] = tmp_val - h fxh2 = f(x) # f(x-h) grad[idx] = (fxh1 - fxh2) / (2*h) x[idx] = tmp_val # 値を元に戻す it.iternext() return grad 初期値(x=(0.6,0.9))、教師データ(t=(0,0,1))をネットワークに入力する。 predict()は(1 times 3)を返す。 それをSoftmax()を通して、(t)とのクロスエントロピー誤差を求めたものが以下。 import numpy as np def cross_entropy_error(y, t): if y.ndim == 1: t = t.reshape(1, t.size) y = y.reshape(1,y.size) batch_size = y.shape[0] delta = 1e-7 return -np.sum( t * np.log( y + delta)) / batch_size def softmax(x): c = np.max(x) return np.exp(x-c) / np.sum(np.exp(x-c)) import sys, os sys.path.append(os.pardir) import numpy as np class simpleNet: def __init__(self): self.W = np.random.randn(2,3) def predict(self, x): return np.dot(x, self.W) def loss(self, x, t): z = self.predict(x) y = softmax(z) loss = cross_entropy_error(y, t) return loss net = simpleNet() x = np.array([0.6, 0.9]) p = net.predict(x) t = np.array([0, 0, 1]) net.loss(x, t) # 0.9463818740797788 このlossを(W)で微分したのが以下。 あえてパラメータ(W)を引数にとり損失関数の値を計算する(f(W))を定義することで、 数値微分が何と何の演算なのかをわかりやすくしている。 実際は(f(W))は(W)とは関係なく(x)と(t)だけから結果を返すけれども、 損失関数(f(W))を(W)で微分するという操作が自明になるようにコードを合わせている。 def f(W): return net.loss(x, t) dW = numerical_gradient_md(f, net.W) dW # array([[ 0.07627371, 0.49923236, -0.57550607], # [ 0.11441057, 0.74884853, -0.8632591 ]]) 結果の解釈 上記の(w),(W),(t)から(frac{partial L}{partial W})が求まった。 損失関数が何か複雑な形をしているという状況で、 (frac{partial L}{partial w_{11}})は(w_{11})がわずかに動いたときに損失関数の値が変化する量を表している。 それが(w_{11})から(w_{32})まで6個分存在する。 begin{eqnarray} frac{partial L}{partial W} = begin{pmatrix} frac{partial L}{partial w_{11}} & frac{partial L}{partial w_{21}} & frac{partial L}{partial w_{31}} \\ frac{partial L}{partial w_{12}} & frac{partial L}{partial w_{22}} & frac{partial L}{partial w_{32}} end{pmatrix} = begin{pmatrix} 0.07627371 & 0.49923236 & -0.57550607 \\ 0.11441057 & 0.74884853 & -0.8632591 end{pmatrix} end{eqnarray}

勾配の可視化

[mathjax] 2変数関数(f(x_0,x_1))を各変数で偏微分する。 地点((i,j))におけるベクトル((frac{partial f(x_0,j)}{partial x_0},frac{partial f(i,x_1)}{partial x_1}))を全地点で記録していき、ベクトル場を得る。 このベクトル場が勾配(gradient)。 (f(x_0,x_1)=x_0^2+x_1^2)について、(-4.0 le x_0 le 4.0)、(-4.0 le x_1 le 4.0)の範囲で、 勾配を求めてみる。また、勾配を可視化してみる。 まず、2変数関数(f(x_0,x_1))の偏微分係数を求める関数の定義。 ((3.0,3.0))の偏微分係数は((6.00..,6.00..))。 def numerical_gradient(f, x): h = 10e-4 grad = np.zeros_like(x) for idx in range(x.size): tmp_val = x[idx] x[idx] = tmp_val + h fxh1 = f(x) x[idx] = tmp_val - h fxh2 = f(x) grad[idx] = (fxh1 - fxh2) / 2*h x[idx] = tmp_val return grad def function2(x): return x[0]**2 + x[1]**2 p = np.array([3.0,3.0]) v = numerical_gradient(function2, p) v # array([6.e-06, 6.e-06]) (-4.0 le x_0 le 4.0)、(-4.0 le x_1 le 4.0)の範囲((0.5)刻み)で偏微分係数を求めて、 ベクトル場っぽく表示してみる。matplotlibのquiver()は便利。 各地点において関数の値を最も増やす方向が表示されている。 w_range = 4 dw = 0.5 w0 = np.arange(-w_range, w_range, dw) w1 = np.arange(-w_range, w_range, dw) wn = w0.shape[0] diff_w0 = np.zeros((len(w0), len(w1))) diff_w1 = np.zeros((len(w0), len(w1))) for i0 in range(wn): for i1 in range(wn): d = numerical_gradient(function2, np.array([ w0[i0], w1[i1] ])) diff_w0[i1, i0], diff_w1[i1, i0] = (d[0], d[1]) plt.xlabel(\'$x_0$\',fontsize=14) #x軸のラベル plt.ylabel(\'$x_1$\',fontsize=14) #y軸のラベル plt.xticks(range(-w_range,w_range+1,1)) #x軸に表示する値 plt.yticks(range(-w_range,w_range+1,1)) #y軸に表示する値 plt.quiver(w0, w1, diff_w0, diff_w1) plt.show() 値が大きい方向に矢印が向いている。例えば((-3.0,3.0))における偏微分係数は((-6.0,6.0))。 左上方向へのベクトル。 参考にしている本にはことわりが書いてあり、勾配にマイナスをつけたものを図にしている。 その場合、関数の値を最も減らす方向が表示されることになる。 各地点において、この勾配を参照することで、どちらに移動すれば関数の値を最も小さくできるかがわかる。

おっさんが数値微分を復習する

引き続き、ゼロDの写経を続ける。今回は、学生の頃が懐かしい懐ワード、数値微分。 全然Deepに入れないけれどおっさんの復習。解析的な微分と数値微分が一致するところを確認してみる。 昔と違うのは、PythonとJupyterNotebookで超絶簡単に実験できるし、 こうやってWordPressでLaTeXで記事を書いたりできる点。 [mathjax] まず、微分の基本的な考え方は以下の通り。高校数学の数3の範囲。 begin{eqnarray} frac{df(x)}{fx} = lim_{hrightarrow infty} frac{f(x+h)-f(x)}{h} end{eqnarray} 情報系学科に入って最初の方でEuler法とRunge-Kutta法を教わってコードを 書いたりレポート書いたりする。懐すぎる..。 または、基本情報の試験かなんかで、小さい値と小さい値どうしの計算で発生する問題が現れる。 ゼロDにはこの定義を少し改良した方法が載っている。へぇ。 begin{eqnarray} frac{df(x)}{fx} = lim_{hrightarrow infty} frac{f(x+h)-f(x-h)}{2h} end{eqnarray} 写経なので、がんばって数値微分を書いて動かしてみる。 簡単な2次関数(f(x))。 begin{eqnarray} f(x) &=& x^2 - 5x +3 \\ f\'(x) &=& 2x - 5 end{eqnarray} def numerical_diff(f, x): h = 10e-4 return (f(x+h) - f(x-h)) / (2*h) (f(x))と、(x=2.5)のところの(f\'(x))をmatplotlibで書いてみる。懐い... import matplotlib.pyplot as plt import numpy as np def f(x): return x**2 - 5*x + 3 x = np.arange(-10, 10, 0.1) y = f(x) dy = numerical_diff(f,x) plt.plot(x, y) x1 = -2.5 dy1 = numerical_diff(f, x1) y1 = f(x1) # y-y1 = dy1(x-x1) -> y = dy1(x-x1) + y1 j = lambda x: dy1 * (x-x1) + y1 plt.plot(x,j(x)) plt.xlabel(\'x\') plt.ylabel(\'y\') plt.grid() plt.show() 偏微分 2変数以上の関数の数値微分は以下の通り。片方を止める。 数値微分の方法は上記と同じものを使った。 begin{eqnarray} frac{partial f(x_0,x_1)}{partial x_0} &=& lim_{hrightarrow infty} frac{f(x_0 +h,x_1)-f(x_0-h,x_1)}{2h} \\ frac{partial f(x_0,x_1)}{partial x_1} &=& lim_{hrightarrow infty} frac{f(x_0,x_1+h)-f(x_0,x_1-h)}{2h} end{eqnarray} ((x_0,x_1)=(1,1))における(x_0)に対する偏微分(frac{partial f(x_0,x_1)}{x_0})、(x_1)に対する偏微分(frac{partial f(x_0,x_1)}{x_1})を求めてみる。 ちゃんと(frac{partial f(x_0,1.0)}{x_0}=2.00..)、(frac{partial f(1.0,x_1)}{x_1}=2.00..)になった。 import matplotlib.pyplot as plt import numpy as np from mpl_toolkits.mplot3d import Axes3D def f(x): return x[0]**2 + x[1]**2 X = np.meshgrid(np.arange(-5., 5., 0.2),np.arange(-5., 5., 0.2)) Z = f(X) fig = plt.figure(figsize=(6, 6)) axes = fig.add_subplot(111, projection=\'3d\') axes.plot_surface(X[0],X[1], Z) f0 = lambda x: x**2 + 1.0**2 f1 = lambda x: 1.0**2 + x**2 df0 = numerical_diff(f0, 1.0) df1 = numerical_diff(f1, 1.0) print(df0) # 2.0000000000000018 print(df1) # 2.0000000000000018 plt.show()