日々の学び

gorm互換の型安全なORMであるgenでCRUD APIを試作

GolangとGinでAPIを書くために必要な要素技術を学習中、 ORMは何が良いか考えてみたが、gormは無いなぁと思うに至った。 ORマッパーがどの程度の抽象化を担うべきか、については答えはないと思うが、 Webアプリケーションのシナリオで出てくるテーブル構造と関係程度は完全にSQLを排除して欲しい。 SQLを排除することで可読性が向上するし、静的型付けによる恩恵を得られる。 Genには以下のような特徴がある。 型安全: コンパイル時にエラー検出 自動補完: IDEでメソッドとフィールドが補完される クエリビルダー: Where(q.Product.Name.Like(\"%...\"))のような直感的なAPI GORM互換: 既存のGORMモデルをそのまま使用可能 なぜ\"Gen\"なのかは、ビルド時にGolangコードから静的に(ビルド前に)オブジェクトにアクセスする ために必要なGoオブジェクトを生成する、という仕組みから来ているのではないかと思う。 [clink implicit=\"false\" url=\"https://gorm.io/gen/index.html\" imgurl=\"https://gorm.io/gorm.svg\" title=\"Gen Guides\" excerpt=\"GEN: Friendly & Safer GORM powered by Code Generation.Idiomatic & Reusable API from Dynamic Raw SQL.100% Type-safe DAO API without interface{}.Database To Struct follows GORM conventions.GORM under the hood, supports all features, plugins, DBMS that GORM supports.\"] [arst_toc tag=\"h4\"] 環境構築 サクッとClaudeで環境を作った。実際に商用環境を作るとしたら必要な理解の度合いは上がるだろうが、 試してみるまでの時間が無駄にかかって勿体無いのと、Claudeに入口を教わるのは悪くない。 以下の構成で、Golang+GinにCRUDルートを設定しgenを介してDBアクセスできる。 models以下にテーブルと対応する型定義された構造体が格納される。 また、query以下にGormレベルの操作をGen(Golang)レベルに抽象化する自動生成コードが格納される。 query以下を読むと、GenがGormのラッパーであることが良くわかる。 $ tree . -n 2 . ├── cmd │   └── generate │   └── main.go # マイグレーション ├── config │   └── database.go # DB接続設定 ├── database │   └── database.go # Conenct(), Close(), GetDB()など ├── docker-compose.yml # Golangアプリケーション(8080), PostgreSQL(5432) ├── Dockerfile ├── go.mod ├── go.sum ├── handlers │   └── product.go ├── main.go # CRUD APIのルーティング ├── models │   └── product.go # テーブル->モデル ├── query │   ├── gen.go # モデルを操作するラッパー │   └── products.gen.go # SQLレベルのモデル操作をGolangレベルに抽象化するためのIF └── README.md CRUDルート 早速、CRUD APIのルートを作っていく。Claudeにお任せしたところ商品(Product)のCRUD APIが出来た。 その位置にMigrate置くの本当に良いの? という感があるが、本題はそこではないので省略。 package main import ( \"log\" \"github.com/gin-gonic/gin\" \"github.com/gin-gonic/gin/binding\" \"github.com/ikuty/golang-gin/database\" \"github.com/ikuty/golang-gin/handlers\" \"github.com/ikuty/golang-gin/models\" \"github.com/ikuty/golang-gin/query\" ) func main() { // データベース接続 if err := database.Connect(); err != nil { log.Fatalf(\"Failed to connect to database: %v\", err) } defer database.Close() // マイグレーション実行 db := database.GetDB() if err := db.AutoMigrate(&models.Product{}); err != nil { log.Fatalf(\"Failed to migrate database: %v\", err) } // Gen初期化 query.SetDefault(db) // Ginエンジンの初期化 r := gin.Default() // 8. GORM + PostgreSQL - CRUD操作 r.GET(\"/api/products\", handlers.GetProductsHandler) // 全商品取得 r.GET(\"/api/products/:id\", handlers.GetProductHandler) // 商品詳細取得 r.POST(\"/api/products\", handlers.CreateProductHandler) // 商品作成 r.PUT(\"/api/products/:id\", handlers.UpdateProductHandler) // 商品更新 r.DELETE(\"/api/products/:id\", handlers.DeleteProductHandler) // 商品削除 r.GET(\"/api/products/search\", handlers.SearchProductsHandler) // 商品検索 // サーバー起動 r.Run(\":8080\") } モデル さて、モデル定義(=テーブル構造)はどうなっているかというと、以下の通り。 フィールドの物理型をGenを介してGolangで厳密で管理できるのは動的型付け言語にはない利点。 package models import ( \"time\" \"gorm.io/gorm\" ) // Product は商品モデル type Product struct { ID uint `gorm:\"primarykey\" json:\"id\"` Name string `gorm:\"size:100;not null\" json:\"name\" binding:\"required\"` Description string `gorm:\"size:500\" json:\"description\"` Price float64 `gorm:\"not null\" json:\"price\" binding:\"required,gt=0\"` Stock int `gorm:\"default:0\" json:\"stock\"` Category string `gorm:\"size:50\" json:\"category\"` CreatedAt time.Time `json:\"created_at\"` UpdatedAt time.Time `json:\"updated_at\"` DeletedAt gorm.DeletedAt `gorm:\"index\" json:\"-\"` } // TableName はテーブル名を指定 func (Product) TableName() string { return \"products\" } ハンドラ(商品詳細取得) 素晴らしい。説明が不要なくらいDBアクセスが抽象化されている。 ただ、依存性注入があるEloquentと比べるとロジックと関係ない冗長な処理が残っている。 db,q,Contextは裏側に隠して欲しいという思いはあるものの、これでも良いかとも思う。 Find()はGenにより自動生成される。interfaceが用意されビルド時に全て解決される。 なお、VSCodeなどで補完が効く、というのは、例えば JetBrains環境であれば、 動的型付け言語であってもほぼ実現されているので、それほど実利があるメリットではない。 package handlers import ( \"net/http\" \"strconv\" \"github.com/gin-gonic/gin\" \"github.com/ikuty/golang-gin/database\" \"github.com/ikuty/golang-gin/models\" \"github.com/ikuty/golang-gin/query\" ) // GetProductsHandler は全商品を取得 func GetProductsHandler(c *gin.Context) { db := database.GetDB() q := query.Use(db) products, err := q.Product.WithContext(c.Request.Context()).Find() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ \"error\": \"Failed to fetch products\", }) return } c.JSON(http.StatusOK, gin.H{ \"data\": products, \"count\": len(products), }) } ハンドラ(指定の商品を取得) バリデータを介さず自力でバリデーション(IDがUintか)を行っている。 Productに対してWhereで条件指定し(Order By Ascした後に)先頭のオブジェクトを取得している。 もはや他に説明が必要ないくらい抽象化されていて良い。 // GetProductHandler は指定IDの商品を取得 func GetProductHandler(c *gin.Context) { id := c.Param(\"id\") idUint, err := strconv.ParseUint(id, 10, 32) if err != nil { c.JSON(http.StatusBadRequest, gin.H{ \"error\": \"Invalid ID\", }) return } db := database.GetDB() q := query.Use(db) product, err := q.Product.WithContext(c.Request.Context()).Where(q.Product.ID.Eq(uint(idUint))).First() if err != nil { c.JSON(http.StatusNotFound, gin.H{ \"error\": \"Product not found\", }) return } c.JSON(http.StatusOK, gin.H{ \"data\": product, }) } ハンドラ(商品作成) 次はCreate。モデルオブジェクトを空から生成し入力値をバインドして整形した後に、 Create()に渡している。Create()の内部はGormレベルの(低レイヤの)コードが動く。 // CreateProductHandler は商品を作成 func CreateProductHandler(c *gin.Context) { var product models.Product if err := c.ShouldBindJSON(&product); err != nil { c.JSON(http.StatusBadRequest, gin.H{ \"error\": \"Invalid request\", \"details\": err.Error(), }) return } db := database.GetDB() q := query.Use(db) if err := q.Product.WithContext(c.Request.Context()).Create(&product); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ \"error\": \"Failed to create product\", }) return } c.JSON(http.StatusCreated, gin.H{ \"message\": \"Product created successfully\", \"data\": product, }) } ハンドラ(商品更新) 基本的にはCreate()と同じ。空モデルに入力値をバインドしUpdate()に渡している。 実行後に更新対象のオブジェクトを取得しているがEloquentは確か更新の戻りがオブジェクトだった。 // UpdateProductHandler は商品を更新 func UpdateProductHandler(c *gin.Context) { id := c.Param(\"id\") idUint, err := strconv.ParseUint(id, 10, 32) if err != nil { c.JSON(http.StatusBadRequest, gin.H{ \"error\": \"Invalid ID\", }) return } db := database.GetDB() q := query.Use(db) ctx := c.Request.Context() // 既存の商品を取得 product, err := q.Product.WithContext(ctx).Where(q.Product.ID.Eq(uint(idUint))).First() if err != nil { c.JSON(http.StatusNotFound, gin.H{ \"error\": \"Product not found\", }) return } // 更新データをバインド var updateData models.Product if err := c.ShouldBindJSON(&updateData); err != nil { c.JSON(http.StatusBadRequest, gin.H{ \"error\": \"Invalid request\", \"details\": err.Error(), }) return } // 更新実行 _, err = q.Product.WithContext(ctx).Where(q.Product.ID.Eq(uint(idUint))).Updates(&updateData) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ \"error\": \"Failed to update product\", }) return } // 更新後のデータを取得 product, _ = q.Product.WithContext(ctx).Where(q.Product.ID.Eq(uint(idUint))).First() c.JSON(http.StatusOK, gin.H{ \"message\": \"Product updated successfully\", \"data\": product, }) } ハンドラ(論理削除) DeletedAtフィールドがNULLの場合、そのレコードはアクティブ。非Nullなら論理削除済み。 Unscoped()を介さずDelete()した場合(つまりデフォルトでは)論理削除となる。 DeletedAtは他のAPIから透過的に扱われる。論理削除状態かどうかは把握しなくて良い。 DeletedAtはデフォルトでは*time.Time型だが、のデータ形式の対応も可能。 // DeleteProductHandler は商品を削除(ソフトデリート) func DeleteProductHandler(c *gin.Context) { id := c.Param(\"id\") idUint, err := strconv.ParseUint(id, 10, 32) if err != nil { c.JSON(http.StatusBadRequest, gin.H{ \"error\": \"Invalid ID\", }) return } db := database.GetDB() q := query.Use(db) // ソフトデリート実行 _, err = q.Product.WithContext(c.Request.Context()).Where(q.Product.ID.Eq(uint(idUint))).Delete() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ \"error\": \"Failed to delete product\", }) return } c.JSON(http.StatusOK, gin.H{ \"message\": \"Product deleted successfully\", }) } ハンドラ(商品検索) Where句を複数記述する場合など、手続き的に条件用のオブジェクトを足していける。 一見、productQueryを上から上書きしているように見えるが、Genのクエリビルダーはimmutableパターン として振る舞い、都度実行によりWhereの戻りとなるオブジェクトが累積していく動作となる。 // SearchProductsHandler は商品を検索 func SearchProductsHandler(c *gin.Context) { db := database.GetDB() q := query.Use(db) ctx := c.Request.Context() // クエリパラメータを取得 name := c.Query(\"name\") category := c.Query(\"category\") minPrice := c.Query(\"min_price\") maxPrice := c.Query(\"max_price\") // クエリビルダー productQuery := q.Product.WithContext(ctx) if name != \"\" { productQuery = productQuery.Where(q.Product.Name.Like(\"%\" + name + \"%\")) } if category != \"\" { productQuery = productQuery.Where(q.Product.Category.Eq(category)) } if minPrice != \"\" { if price, err := strconv.ParseFloat(minPrice, 64); err == nil { productQuery = productQuery.Where(q.Product.Price.Gte(price)) } } if maxPrice != \"\" { if price, err := strconv.ParseFloat(maxPrice, 64); err == nil { productQuery = productQuery.Where(q.Product.Price.Lte(price)) } } // 検索実行 products, err := productQuery.Find() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ \"error\": \"Failed to search products\", }) return } c.JSON(http.StatusOK, gin.H{ \"data\": products, \"count\": len(products), }) } 変換後のクエリを見てみる。 $ http://localhost:8080/api/products/search?name=Test&category=Electronics&min_price=1400&max_price=1600 SELECT * FROM \"products\" WHERE \"products\".\"name\" LIKE \'%Test%\' AND \"products\".\"category\" = \'Electronics\' AND \"products\".\"price\" >= 1400 AND \"products\".\"price\" <= 1600 AND "products"."deleted_at" IS NULL まとめ GolangのORMであるGormをラップしたGenを使って、CRUDを行うAPIをGinで書いて動かしてみた。 確かにGormレベル(SQLレベル)の記述が不要であることを確認した。 (まだ見ていないが)テーブルをJOINしていった先にGormを素で触らないといけない場面は発生するだろうが、 多くのシナリオでGenだけで行けるのであれば、Genを導入するメリットとなるのではないだろうか。

Golang + Gin カスタムバリデーション

Golang+GinによるAPI構築で使いそうなフィーチャーを試してみるシリーズ。 今回はカスタムバリデーションを試してみる。 [clink implicit=\"false\" url=\"https://gin-gonic.com/ja/docs/examples/custom-validators/\" imgurl=\"https://gin-gonic.com/_astro/gin.D6H2T_2v_ZD2G7l.webp\" title=\"カスタムバリデーション\" excerpt=\"カスタムしたバリデーションを使用することもできます。サンプルコード も見てみてください。\"] [arst_toc tag=\"h4\"] ルーティング バリデーションを外部に移譲することで、ハンドラからロジック以外の冗長な処理を除くことができる。 Ginはカスタムバリデータを用意している。以下の例では、ユーザ登録を行うPOSTリクエストの例。 組み込みのバリデーション・バインディングと合わせて、パスワードバリデーションロジックの 追加を行っている。 package main import ( \"github.com/gin-gonic/gin\" \"github.com/gin-gonic/gin/binding\" \"github.com/go-playground/validator/v10\" \"github.com/ikuty/golang-gin/handlers\" ) func main() { // Ginエンジンの初期化 r := gin.Default() // カスタムバリデーターを登録 if v, ok := binding.Validator.Engine().(*validator.Validate); ok { handlers.InitCustomValidators(v) } // 7. カスタムバリデーション r.POST(\"/api/register\", handlers.RegisterValidatorHandler) // サーバー起動 r.Run(\":8080\") } ハンドラ リクエストで受けたJSONをRegisterRequest構造体にバインディングする際に、組み込みの バリデーションルールを定義するのとは別に、strongpassword というカスタムルールを定義している。 strongpasswordルールの実体は strongPassword() 。 例に出現するオブジェクトの使い方は、まぁこう使うのかぐらいで、ありがちな感じ。 カスタムバリデータ関数がチェック結果をTrue/Falseで返せばよさそう。 組み込みバリデータ、または、カスタムバリデータのバリデーション結果と文字列の対応を定義し、 その文字列をレスポンスに付与して返す、というのは良くあるパターンで、 Ginで実装する場合は、 また、カスタムバリデータのバリデーション結果と文字列の対応を定義しレスポンスに含める、 というパターンは良くありそうで、構造体へのバインディングで発生したエラー(err)を取得し、 errに対する型アサーションを行った上で、errを validator.ValidationErrors型として扱う。 動的型付けだと、発生したerrが本当に期待したオブジェクトなのか実行するまで分からなが、 全ての処理が静的型付けを通して、実行前に実行可能であることが確認される。 package handlers import ( \"net/http\" \"regexp\" \"github.com/gin-gonic/gin\" \"github.com/go-playground/validator/v10\" ) // RegisterRequest はユーザー登録リクエストの構造体(高度なバリデーション付き) type RegisterRequest struct { Username string `json:\"username\" binding:\"required,min=3,max=20,alphanum\"` Email string `json:\"email\" binding:\"required,email\"` Password string `json:\"password\" binding:\"required,min=8,max=50,strongpassword\"` Age int `json:\"age\" binding:\"required,gte=18,lte=100\"` Website string `json:\"website\" binding:\"omitempty,url\"` Phone string `json:\"phone\" binding:\"omitempty,e164\"` // E.164 形式の電話番号 } // カスタムバリデーター: 強力なパスワードチェック func strongPassword(fl validator.FieldLevel) bool { password := fl.Field().String() // 最低1つの大文字、1つの小文字、1つの数字を含む hasUpper := regexp.MustCompile(`[A-Z]`).MatchString(password) hasLower := regexp.MustCompile(`[a-z]`).MatchString(password) hasNumber := regexp.MustCompile(`[0-9]`).MatchString(password) return hasUpper && hasLower && hasNumber } // RegisterValidatorHandler はカスタムバリデーターを使用するハンドラー func RegisterValidatorHandler(c *gin.Context) { var req RegisterRequest // JSON をバインド if err := c.ShouldBindJSON(&req); err != nil { // バリデーションエラーを詳細に返す c.JSON(http.StatusBadRequest, gin.H{ \"error\": \"Validation failed\", \"details\": formatValidationError(err), }) return } c.JSON(http.StatusCreated, gin.H{ \"message\": \"Registration successful\", \"username\": req.Username, \"email\": req.Email, }) } // formatValidationError はバリデーションエラーをわかりやすく整形 func formatValidationError(err error) []string { var errors []string if validationErrors, ok := err.(validator.ValidationErrors); ok { for _, e := range validationErrors { var message string switch e.Tag() { case \"required\": message = e.Field() + \" is required\" case \"email\": message = e.Field() + \" must be a valid email address\" case \"min\": message = e.Field() + \" must be at least \" + e.Param() + \" characters\" case \"max\": message = e.Field() + \" must be at most \" + e.Param() + \" characters\" case \"alphanum\": message = e.Field() + \" must contain only letters and numbers\" case \"gte\": message = e.Field() + \" must be greater than or equal to \" + e.Param() case \"lte\": message = e.Field() + \" must be less than or equal to \" + e.Param() case \"url\": message = e.Field() + \" must be a valid URL\" case \"e164\": message = e.Field() + \" must be a valid phone number (E.164 format)\" case \"strongpassword\": message = e.Field() + \" must contain at least one uppercase letter, one lowercase letter, and one number\" default: message = e.Field() + \" is invalid\" } errors = append(errors, message) } } else { errors = append(errors, err.Error()) } return errors } // InitCustomValidators はカスタムバリデーターを登録する func InitCustomValidators(v *validator.Validate) { v.RegisterValidation(\"strongpassword\", strongPassword) } 実行結果 リクエストに対してバリデーションが行われ、期待通りバリデーションエラーがアサートされていて、 アサートと対応するカスタム文字列がレスポンスに含まれていることが確認できる。 $ curl -X POST http://localhost:8080/api/register -H \"Content-Type: application/json\" -d \'{\"username\":\"john123\",\"email\":\"john@example.com\",\"password\":\"SecurePass123\",\"age\":25,\"website\":\"https://example.com\"}\' {\"email\":\"john@example.com\",\"message\":\"Registration successful\",\"username\":\"john123\"} 2. ユーザー名が短すぎる {\"details\":[\"Username must be at least 3 characters\"],\"error\":\"Validation failed\"} 3. 弱いパスワード(カスタムバリデーター) {\"details\":[\"Password must contain at least one uppercase letter, one lowercase letter, and one number\"],\"error\":\"Validation failed\"} 4. 年齢が18歳未満 {\"details\":[\"Age must be greater than or equal to 18\"],\"error\":\"Validation failed\"} まとめ 組み込みバリデーションの他に、カスタムバリデーションを追加できることを確認した。 静的型付けにより実行時エラーに頼ることのないある種の堅牢さがあることも見てとれた。

Golang + Gin Framework で Hello World してみた話 〜基本的なルーティング、バスパラメタ・クエリパラメタ・JSON Req/Res、フォームデータ

Golang+GinでAPIを大量に書くことになりそうなので予習することにする。 コード自体はAI Agentで書こうと思うが、まずはGinのフィーチャーを把握する必要がある。 AI Agentを使用してAPI毎にフィーチャーを試せる学習用プロジェクトを構築する。 著者のスペックは、昔仕事でLaravelでWebアプリを書いたことがある。 [arst_toc tag=\"h4\"] Ginについて 🚀 高速なパフォーマンス martini に似たAPIを持ちながら、httprouter のおかげでそれより40倍以上も速いパフォーマンスがあります。 **基数木(Radix Tree)**ベースのルーティングを採用しており、メモリ効率が良く、高速なルーティングを実現しています。 他のGo製Webフレームワークと比較して、ベンチマークで優れた速度を示すことが多く、特に高スループットな REST API や マイクロサービス の構築に適しています。 Laravelは遅くて有名だったが、速いのは良いこと。 Golang自体ネイティブ実行だし、Golang用フレームワークの中でも速度にフィーチャーした構造。 たいした同時実行数を捌かないなら別に遅くても良いし、速いなら良いよね、ぐらい。 🧩 ミドルウェアのサポート 受信したHTTPリクエストを、ミドルウェアのチェーンと最終的なアクション(ハンドラー)で処理する仕組みを提供します。 ロガー、認証、GZIP圧縮など、様々な機能を簡単に組み込むことができます。 ミドルウェアくらい使えないと困るよね。認証を書きたい。 🛡️ クラッシュフリー HTTPリクエスト処理中に発生したpanicをキャッチし、**リカバリー(回復)**する機能が組み込まれています。 これにより、サーバーがクラッシュするのを防ぎ、サービスを常に利用可能な状態に保ちます。 🔗 ルートのグループ化 認証が必要なルートやAPIのバージョンごとなど、関連するルートをグループ化して整理する機能があり、共通のミドルウェアを適用しやすいです。 フルスタックフレームワークではないので、これだけしか書かれていない。 シンプルであることは良いこと。 学習用プロジェクトの構成 いったん、こんな感じで構成。 golang-gin/ ├── docker-compose.yml ├── Dockerfile ├── go.mod ├── go.sum ├── main.go ├── README.md └── handlers/ ├── hello.go # Hello World API ├── params.go # パラメータ処理 ├── json.go # JSON処理 ├── middleware.go # ミドルウェア ├── validation.go # バリデーション ├── file.go # ファイルアップロード └── grouping.go # ルートグループ化 学習計画とAPI API毎にフィーチャーを実装していくスタイルとする。 Claude Codeにその一覧を出力すると以下の通り。 | No. | 機能 | エンドポイント | メソッド | 説明 | |-----|--------------|----------------------|------|----------------------| | 1 | 基本的なルーティング | /hello | GET | Hello World を返す基本API | | 2 | パスパラメータ | /users/:id | GET | URL パスからパラメータを取得 | | 3 | クエリパラメータ | /search | GET | クエリ文字列からパラメータを取得 | | 4 | JSON レスポンス | /api/user | GET | 構造体を JSON で返す | | 5 | JSON リクエスト | /api/user | POST | JSON をバインドして処理 | | 6 | フォームデータ | /form | POST | フォームデータの受け取り | | 7 | バリデーション | /api/register | POST | 入力データのバリデーション | | 8 | ファイルアップロード | /upload | POST | 単一ファイルのアップロード | | 9 | 複数ファイルアップロード | /upload/multiple | POST | 複数ファイルのアップロード | | 10 | ミドルウェア (ログ) | /api/protected | GET | カスタムミドルウェアの実装 | | 11 | ルートグループ化 | /v1/users, /v2/users | GET | API バージョニング | | 12 | エラーハンドリング | /error | GET | エラーレスポンスの処理 | | 13 | カスタムバリデーター | /api/validate | POST | カスタムバリデーションルール | | 14 | リダイレクト | /redirect | GET | リダイレクト処理 | | 15 | 静的ファイル配信 | /static/* | GET | 静的ファイルの提供 | Hello World まずは Hello World を返すAPIを作る。 main.goは以下の通り。./handlers 以下に実態を書いていく。 package main import ( \"github.com/gin-gonic/gin\" \"github.com/ikuty/golang-gin/handlers\" ) func main() { // Ginエンジンの初期化 r := gin.Default() // Hello World API r.GET(\"/hello\", handlers.HelloHandler) // サーバー起動 r.Run(\":8080\") } ./handlers/hello.go は以下の通り。 package handlers import ( \"net/http\" \"github.com/gin-gonic/gin\" ) // HelloHandler は Hello World を返すハンドラー func HelloHandler(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ \"message\": \"Hello World\", }) } 試す。入門した。 $ curl http://localhost:8080/hello {\"message\":\"Hello World\"} パスパラメータ URL内にプレースホルダを設定し、URLのプレースホルダと対応する値を変数で受けられる機能。 package main import ( \"github.com/gin-gonic/gin\" \"github.com/ikuty/golang-gin/handlers\" ) func main() { // Ginエンジンの初期化 r := gin.Default() // 1. 基本的なルーティング r.GET(\"/hello\", handlers.HelloHandler) // 2. パスパラメータ r.GET(\"/users/:id\", handlers.GetUserByIDHandler) // サーバー起動 r.Run(\":8080\") } ./handlers/params.goは以下。 Laravelと同じところに違和感。型はどこいった..? Ginでは、パスパラメータは常に文字列(string)として取得される。 URLから取得したパラメータを別の型(intやuintなど)として扱いたい場合は、 取得した文字列を明示的に型変換する必要がある。 package handlers import ( \"net/http\" \"github.com/gin-gonic/gin\" ) // GetUserByIDHandler は URL パスパラメータからユーザーIDを取得するハンドラー func GetUserByIDHandler(c *gin.Context) { // パスパラメータ :id を取得 id := c.Param(\"id\") c.JSON(http.StatusOK, gin.H{ \"user_id\": id, \"message\": \"User ID retrieved from path parameter\", }) } 実行。 # 数値IDのテスト $ curl http://localhost:8080/users/123 {\"message\":\"User ID retrieved from path parameter\",\"user_id\":\"123\"} # 文字列IDのテスト $ curl http://localhost:8080/users/alice {\"message\":\"User ID retrieved from path parameter\",\"user_id\":\"alice\"} クエリパラメータ クエリパラメータを受け取る方法は以下。 まぁシンプル。 package handlers import ( \"net/http\" \"github.com/gin-gonic/gin\" ) // SearchHandler は クエリパラメータから検索条件を取得するハンドラー func SearchHandler(c *gin.Context) { // クエリパラメータを取得 query := c.Query(\"q\") // ?q=keyword page := c.DefaultQuery(\"page\", \"1\") // ?page=2 (デフォルト値: \"1\") limit := c.DefaultQuery(\"limit\", \"10\") // ?limit=20 (デフォルト値: \"10\") // オプショナルなパラメータ sort := c.Query(\"sort\") // 値がない場合は空文字列 c.JSON(http.StatusOK, gin.H{ \"query\": query, \"page\": page, \"limit\": limit, \"sort\": sort, \"message\": \"Query parameters retrieved successfully\", }) } 実行結果は以下。 # パスパラメータ $ curl http://localhost:8080/users/123 {\"message\":\"User ID retrieved from path parameter\",\"user_id\":\"123\"} # クエリパラメータ $ curl \"http://localhost:8080/search?q=test&page=2\" {\"limit\":\"10\",\"message\":\"Query parameters retrieved successfully\",\"page\":\"2\",\"query\":\"test\",\"sort\":\"\"} JSONリクエスト/JSONレスポンス Content-Type: application/json で半構造化データ(JSON)を送り、構造体で受けることができる。 また、構造体を Content-Type: application/json でJSON文字列を返すことができる。 構造体のメンバに型を定義しておくことで、文字列がメンバ型に変換(バインド)できる。 まずルーティングは以下の通り。 package main import ( \"github.com/gin-gonic/gin\" \"github.com/ikuty/golang-gin/handlers\" ) func main() { // Ginエンジンの初期化 r := gin.Default() // 4. JSON レスポンス r.GET(\"/api/user\", handlers.GetUserHandler) // 5. JSON リクエスト r.POST(\"/api/user\", handlers.CreateUserHandler) // サーバー起動 r.Run(\":8080\") } ハンドラは以下の通り。 バインドの記述が興味深い。バインド時にバリデーションを実行している。 package handlers import ( \"net/http\" \"github.com/gin-gonic/gin\" ) // User 構造体 type User struct { ID int `json:\"id\"` Name string `json:\"name\"` Email string `json:\"email\"` Age int `json:\"age\"` IsActive bool `json:\"is_active\"` } // GetUserHandler は 構造体を JSON で返すハンドラー func GetUserHandler(c *gin.Context) { // サンプルユーザーデータ user := User{ ID: 1, Name: \"John Doe\", Email: \"john@example.com\", Age: 30, IsActive: true, } c.JSON(http.StatusOK, user) } // CreateUserRequest はユーザー作成リクエストの構造体 type CreateUserRequest struct { Name string `json:\"name\" binding:\"required\"` Email string `json:\"email\" binding:\"required,email\"` Age int `json:\"age\" binding:\"required,gte=0,lte=150\"` } // CreateUserHandler は JSON リクエストをバインドして処理するハンドラー func CreateUserHandler(c *gin.Context) { var req CreateUserRequest // JSON をバインド(バリデーションも実行される) if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ \"error\": \"Invalid request\", \"details\": err.Error(), }) return } // 作成されたユーザーを返す(実際はDBに保存する) user := User{ ID: 100, // 仮のID Name: req.Name, Email: req.Email, Age: req.Age, IsActive: true, } c.JSON(http.StatusCreated, gin.H{ \"message\": \"User created successfully\", \"user\": user, }) } 実行結果は以下。 1. GET - JSON レスポンス $ curl http://localhost:8080/api/user {\"id\":1,\"name\":\"John Doe\",\"email\":\"john@example.com\",\"age\":30,\"is_active\":true} 2. POST - 正常なリクエスト $ curl -X POST http://localhost:8080/api/user -H \"Content-Type: application/json\" -d \'{\"name\":\"Alice\",\"email\":\"alice@example.com\",\"age\":25}\' {\"message\":\"User created successfully\",\"user\":{\"id\":100,\"name\":\"Alice\",\"email\":\"alice@example.com\",\"age\":25,\"is_active\":true}} 3. POST - バリデーションエラー(メール形式) $ curl -X POST http://localhost:8080/api/user -H \"Content-Type: application/json\" -d \'{\"name\":\"Bob\",\"email\":\"invalid-email\",\"age\":30}\' {\"details\":\"Key: \'CreateUserRequest.Email\' Error:Field validation for \'Email\' failed on the \'email\' tag\",\"error\":\"Invalid request\"} 4. POST - バリデーションエラー(年齢範囲) $ curl -X POST http://localhost:8080/api/user -H \"Content-Type: application/json\" -d \'{\"name\":\"Charlie\",\"email\":\"charlie@example.com\",\"age\":200}\' {\"details\":\"Key: \'CreateUserRequest.Age\' Error:Field validation for \'Age\' failed on the \'lte\' tag\",\"error\":\"Invalid request\"} フォームデータ フォームデータの送信例。ルーティングは以下。 POSTで送ったフィールドを丸っと構造体にする例と、 それぞれのフィールドを個別に取得する例の2つ。 package main import ( \"github.com/gin-gonic/gin\" \"github.com/ikuty/golang-gin/handlers\" ) func main() { // Ginエンジンの初期化 r := gin.Default() // 6. フォームデータ r.POST(\"/form/login\", handlers.LoginHandler) r.POST(\"/form/post\", handlers.PostFormHandler) // サーバー起動 r.Run(\":8080\") } ハンドラは以下。丸っとフォームデータを構造体にバインドできるし、 個別にアクセスすることもできる。 シンプルというか、少ない道具でなんとかするタイプ。 package handlers import ( \"net/http\" \"github.com/gin-gonic/gin\" ) // LoginForm はログインフォームの構造体 type LoginForm struct { Username string `form:\"username\" binding:\"required\"` Password string `form:\"password\" binding:\"required,min=6\"` Remember bool `form:\"remember\"` } // LoginHandler はフォームデータを受け取るハンドラー func LoginHandler(c *gin.Context) { var form LoginForm // フォームデータをバインド if err := c.ShouldBind(&form); err != nil { c.JSON(http.StatusBadRequest, gin.H{ \"error\": \"Invalid form data\", \"details\": err.Error(), }) return } // 実際はここで認証処理を行う c.JSON(http.StatusOK, gin.H{ \"message\": \"Login successful\", \"username\": form.Username, \"remember\": form.Remember, }) } // PostFormHandler は個別にフォームフィールドを取得するハンドラー func PostFormHandler(c *gin.Context) { // 個別のフォームフィールドを取得 title := c.PostForm(\"title\") content := c.DefaultPostForm(\"content\", \"No content provided\") tags := c.PostFormArray(\"tags\") // 配列として取得 c.JSON(http.StatusOK, gin.H{ \"message\": \"Form data received\", \"title\": title, \"content\": content, \"tags\": tags, }) } 実行例は以下。 1. ログインフォーム - 正常 $ curl -X POST http://localhost:8080/form/login -d \"username=john&password=secret123\" {\"message\":\"Login successful\",\"remember\":false,\"username\":\"john\"} 2. ログインフォーム - remember 付き $ curl -X POST http://localhost:8080/form/login -d \"username=alice&password=pass123&remember=true\" {\"message\":\"Login successful\",\"remember\":true,\"username\":\"alice\"} 3. ログインフォーム - バリデーションエラー $ curl -X POST http://localhost:8080/form/login -d \"username=bob&password=123\" {\"details\":\"Key: \'LoginForm.Password\' Error:Field validation for \'Password\' failed on the \'min\' tag\",\"error\":\"Invalid form data\"} 4. 投稿フォーム - 配列データ $ curl -X POST http://localhost:8080/form/post -d \"title=Hello&content=World&tags=go&tags=gin&tags=api\" {\"content\":\"World\",\"message\":\"Form data received\",\"tags\":[\"go\",\"gin\",\"api\"],\"title\":\"Hello\"} まとめ いったん、以下を試した。 基本的なルーティング バスパラメタ・クエリパラメタ JSON Request/Response フォームデータ シンプルすぎてClaude Codeが機能を絞っているのか疑ったが、 公式を読む限り、若干バリエーションが増える程度の様子。 これならわざわざClaudeに入門コースを作ってもらわなくても上から読めば良いかな。

やりなおし統計

Fellegi-Sunterモデルに基づく確率的名寄せパッケージ Splinkを試してみる

[mathjax] Record Linkage、Entity Recognitionなど、いわゆる「名寄せ」問題には、割とさまざまな解決策が 提案されている。その1つに確率論的な数学的背景を持つFellegi-Sunterモデルを実装したSplinkがある。 数学的背景について咀嚼して説明してくれているサイトはあまりなく、Splinkの公式ドキュメントが唯一 に近いソースだろうと思う。 今回は5万人の人物データを使って重複を除去するSplinkの主要な機能のサンプルを追ってみた。 数式を読んだところで実りは薄い、ということもあり、なるべく数式を読まずに読み進めたが、 モデルのパラメタ、予測結果を解釈する際にどうしても数式を読まないとわからない箇所があった。 特にベイズ統計について理解がかなり浅くて、正直なところ理解しきれてはいない。 Splinkの公式ドキュメントを読み進めていった際に気付いたメモを書いてみることにした。 [arst_toc tag=\"h4\"] 環境構築・準備 環境構築の実装部分は主な学習要素ではないのでClaude Codeでサクッと作ってみた。 まずはPlan modeで計画を立ててもらう。これぐらいだとほぼ言い値で使える。 DuckDB、Splinkを使う最小構成の環境を作ります。環境はDocker composeで構築します。計画を立ててください。 ⏺ DuckDB、Splinkを使うDocker Compose環境の計画を完了しました。 計画概要 1. 設計文書: .tmp/design.mdで要件、アーキテクチャ、技術仕様を定義 2. タスク管理: .tmp/task.mdで詳細な実装手順を3フェーズに分けて整理 主要コンポーネント - Jupyter Lab環境: DuckDB、Splink、必要なライブラリがプリインストール - データボリューム: サンプルデータ、notebook、出力用のディレクトリ - 最小構成: 開発・学習用途に特化した軽量環境 次のステップ 実装を開始する場合は、Phase 1から順番に進めます: 1. Docker設定ファイル作成 2. ディレクトリ構造構築 3. サンプル実装 4. テストと検証 計画の詳細は.tmp/design.mdと.tmp/task.mdをご確認ください。実装を開始しますか? 実際にファイルを作ってもらうと以下のような構成となる。 得られる中間成果物である docker-compose.yml、Dockerfile は本記事では省略する。 . ├── data ├── docker-compose.yml ├── Dockerfile ├── notebooks ├── output └── requirements.txt 普通に http://localhost:8888 で JupyterLab が開く。 使用するサンプルデータ 5万人の人物データを使って名寄せを行うサンプル。おそらくSplinkの用途として最初に思いつくやつ。 Splinkにデータをロードする前に必要なデータクリーニング手順について説明がある。 公式によると、まずは行に一意のIDを割り当てる必要がある。 データセット内で一意となるIDであって、重複除去した後のエンティティを識別するIDのことではない。 [clink implicit=\"false\" url=\"https://moj-analytical-services.github.io/splink/demos/tutorials/01_Prerequisites.html\" imgurl=\"https://user-images.githubusercontent.com/7570107/85285114-3969ac00-b488-11ea-88ff-5fca1b34af1f.png\" title=\"Data Prerequisites\" excerpt=\"Splink では、リンクする前にデータをクリーンアップし、行に一意の ID を割り当てる必要があります。このセクションでは、Splink にデータをロードする前に必要な追加のデータクリーニング手順について説明します。\"] 使用するサンプルデータは以下の通り。 from splink import splink_datasets df = splink_datasets.historical_50k df.head() データの分布を可視化 splink.exploratoryのprofile_columnsを使って分布を可視化してみる。 from splink import DuckDBAPI from splink.exploratory import profile_columns db_api = DuckDBAPI() profile_columns(df, db_api, column_expressions=[\"first_name\", \"substr(surname,1,2)\"]) 同じ姓・名の人が大量にいることがわかる。 ブロッキングとブロッキングルールの評価 テーブル内のレコードが他のレコードと「同一かどうか」を調べるためには、 基本的には、他のすべてのレコードとの何らかの比較操作を行うこととなる。 全てのレコードについて全てのカラム同士を比較したいのなら、 対象のテーブルをCROSS JOINした結果、各カラム同士を比較することとなる。 SELECT ... FROM input_tables as l CROSS JOIN input_tables as r あるカラムが条件に合わなければ、もうその先は見ても意味がない、 というケースは多い。例えば、まず first_name 、surname が同じでなければ、 その先の比較を行わない、というのはあり得る。 SELECT ... FROM input_tables as l INNER JOIN input_tables as r ON l.first_name = r.first_name AND l.surname = r.surname このような考え方をブロッキング、ON句の条件をブロッキングルールと言う。 ただ、これだと性と名が完全一致していないレコードが残らない。 そこで、ブロッキングルールを複数定義し、いずれかが真であれば残すことができる。 ここでポイントなのが、ブロッキングルールを複数定義したとき、 それぞれのブロッキングルールで重複して選ばれるレコードが発生した場合、 Splinkが自動的に排除してくれる。 このため、ブロッキングルールを重ねがけすると、最終的に残るレコード数は一致する。 ただ、順番により、同じルールで残るレコード数は変化する。 逆に言うと、ブロッキングルールを足すことで、重複除去後のOR条件が増えていく。 積算グラフにして、ブロッキングルールとその順番の効果を見ることができる。 from splink import DuckDBAPI, block_on from splink.blocking_analysis import ( cumulative_comparisons_to_be_scored_from_blocking_rules_chart, ) blocking_rules = [ block_on(\"substr(first_name,1,3)\", \"substr(surname,1,4)\"), block_on(\"surname\", \"dob\"), block_on(\"first_name\", \"dob\"), block_on(\"postcode_fake\", \"first_name\"), block_on(\"postcode_fake\", \"surname\"), block_on(\"dob\", \"birth_place\"), block_on(\"substr(postcode_fake,1,3)\", \"dob\"), block_on(\"substr(postcode_fake,1,3)\", \"first_name\"), block_on(\"substr(postcode_fake,1,3)\", \"surname\"), block_on(\"substr(first_name,1,2)\", \"substr(surname,1,2)\", \"substr(dob,1,4)\"), ] db_api = DuckDBAPI() cumulative_comparisons_to_be_scored_from_blocking_rules_chart( table_or_tables=df, blocking_rules=blocking_rules, db_api=db_api, link_type=\"dedupe_only\", ) 積算グラフは以下の通り。積み上がっている数値は「比較の数」。 要は、論理和で条件を足していって、次第に緩和されている様子がわかる。 DuckDBでは比較の数を2,000万件以内、Athena,Sparkでは1億件以内を目安にせよとのこと。 比較の定義 Splinkは Fellegi-Sunter model モデル (というかフレームワーク) に基づいている。 https://moj-analytical-services.github.io/splink/topic_guides/theory/fellegi_sunter.html 各カラムの同士をカラムの特性に応じた距離を使って比較し、重みを計算していく。 各カラムの比較に使うためのメソッドが予め用意されているので、特性に応じて選んでいく。 以下では、first_name, sur_name に ForenameSurnameComparison が使われている。 dobにDateOfBirthComparison、birth_place、ocupationにExactMatchが使われている。 import splink.comparison_library as cl from splink import Linker, SettingsCreator settings = SettingsCreator( link_type=\"dedupe_only\", blocking_rules_to_generate_predictions=blocking_rules, comparisons=[ cl.ForenameSurnameComparison( \"first_name\", \"surname\", forename_surname_concat_col_name=\"first_name_surname_concat\", ), cl.DateOfBirthComparison( \"dob\", input_is_string=True ), cl.PostcodeComparison(\"postcode_fake\"), cl.ExactMatch(\"birth_place\").configure(term_frequency_adjustments=True), cl.ExactMatch(\"occupation\").configure(term_frequency_adjustments=True), ], retain_intermediate_calculation_columns=True, ) # Needed to apply term frequencies to first+surname comparison df[\"first_name_surname_concat\"] = df[\"first_name\"] + \" \" + df[\"surname\"] linker = Linker(df, settings, db_api=db_api) ComparisonとComparison Level ここでSplinkツール内の比較の概念の説明。以下の通り概念に名前がついている。 Data Linking Model ├─-- Comparison: Date of birth │ ├─-- ComparisonLevel: Exact match │ ├─-- ComparisonLevel: One character difference │ ├─-- ComparisonLevel: All other ├─-- Comparison: First name │ ├─-- ComparisonLevel: Exact match on first_name │ ├─-- ComparisonLevel: first_names have JaroWinklerSimilarity > 0.95 │ ├─-- ComparisonLevel: first_names have JaroWinklerSimilarity > 0.8 │ ├─-- ComparisonLevel: All other モデルのパラメタ推定 モデルの実行に必要なパラメタは以下の3つ。Splinkを用いてパラメタを得る。 ちなみに u は \"\'U\'nmatch\"、m は \"\'M\'atch\"。背後の数式の説明で現れる。 No パラメタ 説明 1 無作為に選んだレコードが一致する確率 入力データからランダムに取得した2つのレコードが一致する確率 (通常は非常に小さい数値) 2 u値(u確率) 実際には一致しないレコードの中で各 ComparisonLevel に該当するレコードの割合。具体的には、レコード同士が同じエンティティを表すにも関わらず値が異なる確率。例えば、同じ人なのにレコードによって生年月日が違う確率。これは端的には「データ品質」を表す。名前であればタイプミス、別名、ニックネーム、ミドルネーム、結婚後の姓など。 3 m値(m確率) 実際に一致するレコードの中で各 ComparisonLevel に該当するレコードの割合。具体的には、レコード同士が異なるエンティティを表すにも関わらず値が同じである確率。例えば別人なのにレコードによって性・名が同じ確率 (同姓同名)。性別は男か女かしかないので別人でも50%の確率で一致してしまう。 無作為に選んだレコードが一致する確率 入力データからランダムに抽出した2つのレコードが一致する確率を求める。 値は0.000136。すべての可能なレコードのペア比較のうち7,362.31組に1組が一致すると予想される。 合計1,279,041,753組の比較が可能なため、一致するペアは合計で約173,728.33組になると予想される、 とのこと。 linker.training.estimate_probability_two_random_records_match( [ block_on(\"first_name\", \"surname\", \"dob\"), block_on(\"substr(first_name,1,2)\", \"surname\", \"substr(postcode_fake,1,2)\"), block_on(\"dob\", \"postcode_fake\"), ], recall=0.6, ) > Probability two random records match is estimated to be 0.000136. > This means that amongst all possible pairwise record comparisons, > one in 7,362.31 are expected to match. > With 1,279,041,753 total possible comparisons, > we expect a total of around 173,728.33 matching pairs u確率の推定 実際には一致しないレコードの中でComparisonの評価結果がPositiveである確率。 基本、無作為に抽出したレコードは一致しないため、「無作為に抽出したレコード」を 「実際には一致しないレコード」として扱える、という点がミソ。 probability_two_random_records_match によって得られた値を使ってu確率を求める。 estimate_u_using_random_sampling によって、ラベルなし、つまり教師なしでu確率を得られる。 レコードのペアをランダムでサンプルして上で定義したComparisonを評価する。 ランダムサンプルなので大量の不一致が発生するが、各Comparisonにおける不一致の分布を得ている。 これは、例えば性別について、50%が一致、50%が不一致である、という分布を得ている。 一方、例えば生年月日について、一致する確率は 1%、1 文字の違いがある確率は 3%、 その他はすべて 96% の確率で発生する、という分布を得ている。 linker.training.estimate_u_using_random_sampling(max_pairs=5e6) > ----- Estimating u probabilities using random sampling ----- > > Estimated u probabilities using random sampling > > Your model is not yet fully trained. Missing estimates for: > - first_name_surname (no m values are trained). > - dob (no m values are trained). > - postcode_fake (no m values are trained). > - birth_place (no m values are trained). > - occupation (no m values are trained). m確率の推定 「実際に一致するレコード」の中で、Comparisonの評価がNegativeになる確率。 そもそも、このモデルを使って名寄せ、つまり「一致するレコード」を見つけたいのだから、 モデルを作るために「実際に一致するレコード」を計算しなければならないのは矛盾では..となる。 無作為抽出結果から求められるu確率とは異なり、m確率を求めるのは難しい。 もしラベル付けされた「一致するレコード」、つまり教師データセットがあるのであれば、 そのデータセットを使ってm確率を求められる。 例えば、日本人全員にマイナンバーが振られて、全てのレコードにマイナンバーが振られている、 というアナザーワールドがあるのであれば、マイナンバーを使ってm確率を推定する。(どういう状況??) ラベル付けされたデータがないのであれば、EMアルゴリズムでm確率を求めることになっている。 EMアルゴリズムは反復的な手法で、メモリや収束速度の点でペア数を減らす必要があり、 例ではブロッキングルールを設定している。 以下のケースでは、first_nameとsurnameをブロッキングルールとしている。 つまり、first_name, surnameが完全に一致するレコードについてペア比較を行う。 この仮定を設定したため、first_name, surname (first_name_surname) のパラメタを推定できない。 training_blocking_rule = block_on(\"first_name\", \"surname\") training_session_names = ( linker.training.estimate_parameters_using_expectation_maximisation( training_blocking_rule, estimate_without_term_frequencies=True ) ) > ----- Starting EM training session ----- > > Estimating the m probabilities of the model by blocking on: > (l.\"first_name\" = r.\"first_name\") AND (l.\"surname\" = r.\"surname\") > > Parameter estimates will be made for the following comparison(s): > - dob > - postcode_fake > - birth_place > - occupation > > Parameter estimates cannot be made for the following comparison(s) since they are used in the blocking rules: > - first_name_surname > > Iteration 1: Largest change in params was 0.248 in probability_two_random_records_match > Iteration 2: Largest change in params was 0.0929 in probability_two_random_records_match > Iteration 3: Largest change in params was -0.0237 in the m_probability of birth_place, level `Exact match on > birth_place` > Iteration 4: Largest change in params was 0.00961 in the m_probability of birth_place, level `All other >comparisons` > Iteration 5: Largest change in params was -0.00457 in the m_probability of birth_place, level `Exact match on birth_place` > Iteration 6: Largest change in params was -0.00256 in the m_probability of birth_place, level `Exact match on birth_place` > Iteration 7: Largest change in params was 0.00171 in the m_probability of dob, level `Abs date difference Iteration 8: Largest change in params was 0.00115 in the m_probability of dob, level `Abs date difference Iteration 9: Largest change in params was 0.000759 in the m_probability of dob, level `Abs date difference Iteration 10: Largest change in params was 0.000498 in the m_probability of dob, level `Abs date difference Iteration 11: Largest change in params was 0.000326 in the m_probability of dob, level `Abs date difference Iteration 12: Largest change in params was 0.000213 in the m_probability of dob, level `Abs date difference Iteration 13: Largest change in params was 0.000139 in the m_probability of dob, level `Abs date difference Iteration 14: Largest change in params was 9.04e-05 in the m_probability of dob, level `Abs date difference <= 10 year` 同様にdobをブロッキングルールに設定して実行すると、dob以外の列についてパラメタを推定できる。 training_blocking_rule = block_on(\"dob\") training_session_dob = ( linker.training.estimate_parameters_using_expectation_maximisation( training_blocking_rule, estimate_without_term_frequencies=True ) ) > ----- Starting EM training session ----- > > Estimating the m probabilities of the model by blocking on: > l.\"dob\" = r.\"dob\" > > Parameter estimates will be made for the following comparison(s): > - first_name_surname > - postcode_fake > - birth_place > - occupation > > Parameter estimates cannot be made for the following comparison(s) since they are used in the blocking rules: > - dob > > Iteration 1: Largest change in params was -0.474 in the m_probability of first_name_surname, level `Exact match on first_name_surname_concat` > Iteration 2: Largest change in params was 0.052 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 3: Largest change in params was 0.0174 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 4: Largest change in params was 0.00532 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 5: Largest change in params was 0.00165 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 6: Largest change in params was 0.00052 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 7: Largest change in params was 0.000165 in the m_probability of first_name_surname, level `All other comparisons` > Iteration 8: Largest change in params was 5.29e-05 in the m_probability of first_name_surname, level `All other comparisons` > > EM converged after 8 iterations > > Your model is not yet fully trained. Missing estimates for: > - first_name_surname (some u values are not trained). モデルパラメタの可視化 m確率、u確率の可視化。 マッチウェイトの可視化。マッチウェイトは (log_2 (m / u))で計算される。 linker.visualisations.match_weights_chart() モデルの保存と読み込み 以下でモデルを保存できる。 settings = linker.misc.save_model_to_json( \"./saved_model_from_demo.json\", overwrite=True ) 以下で保存したモデルを読み込める。 import json settings = json.load( open(\'./saved_model_from_demo.json\', \'r\') ) リンクするのに十分な情報が含まれていないレコード 「John Smith」のみを含み、他のすべてのフィールドがnullであるレコードは、 他のレコードにリンクされている可能性もあるが、潜在的なリンクを明確にするには十分な情報がない。 以下により可視化できる。 linker.evaluation.unlinkables_chart() 横軸は「マッチウェイトの閾値」。縦軸は「リンクするのに十分な情報が含まれないレコード」の割合。 マッチウェイト閾値=6.11ぐらいのところを見ると、入力データセットのレコードの約1.3%が リンクできないことが示唆される。 訓練済みモデルを使って未知データのマッチウェイトを予測 上で構築した推定モデルを使用し、どのペア比較が一致するかを予測する。 内部的には以下を行うとのこと。 blocking_rules_to_generate_predictionsの少なくとも1つと一致するペア比較を生成 Comparisonで指定されたルールを使用して、入力データの類似性を評価 推定された一致重みを使用し、要求に応じて用語頻度調整を適用して、最終的な一致重みと一致確率スコアを生成 df_predictions = linker.inference.predict(threshold_match_probability=0.2) df_predictions.as_pandas_dataframe(limit=1) > Blocking time: 0.88 seconds > Predict time: 1.91 seconds > > -- WARNING -- > You have called predict(), but there are some parameter estimates which have neither been estimated or > specified in your settings dictionary. To produce predictions the following untrained trained parameters will > use default values. > Comparison: \'first_name_surname\': > u values not fully trained records_to_plot = df_e.to_dict(orient=\"records\") linker.visualisations.waterfall_chart(records_to_plot, filter_nulls=False) predictしたマッチウェイトの可視化、数式との照合 predictしたマッチウェイトは、ウォーターフォール図で可視化できる。 マッチウェイトは、モデル内の各特徴量によって一致の証拠がどの程度提供されるかを示す中心的な指標。 (lambda)は無作為抽出した2つのレコードが一致する確率。(K=m/u)はベイズ因子。 begin{align} M &= log_2 ( frac{lambda}{1-lambda} ) + log_2 K \\ &= log_2 ( frac{lambda}{1-lambda} ) + log_2 m - log_2 u end{align} 異なる列の比較が互いに独立しているという仮定を置いていて、 2つのレコードのベイズ係数が各列比較のベイズ係数の積として扱う。 begin{eqnarray} K_{feature} = K_{first_name_surname} + K_{dob} + K_{postcode_fake} + K_{birth_place} + K_{occupation} + cdots end{eqnarray} マッチウェイトは以下の和。 begin{eqnarray} M_{observe} = M_{prior} + M_{feature} end{eqnarray} ここで begin{align} M_{prior} &= log_2 (frac{lambda}{1-lambda}) \\ M_{feature} &= M_{first_name_surname} + M_{dob} + M_{postcode_fake} + M_{birth_place} + M_{occupation} + cdots end{align} 以下のように書き換える。 begin{align} M_{observe} &= log_2 (frac{lambda}{1-lambda}) + sum_i^{feature} log_2 (frac{m_i}{u_i}) \\ &= log_2 (frac{lambda}{1-lambda}) + log_2 (prod_i^{feature} (frac{m_i}{u_i}) ) end{align} ウォーターフォール図の一番左、赤いバーは(M_{prior} = log_2 (frac{lambda}{1-lambda}))。 特徴に関する追加の知識が考慮されていない場合のマッチウェイト。 横に並んでいる薄い緑のバーは (M_{first_name_surname} + M_{dob} + M_{postcode_fake} + M_{birth_place} + M_{occupation} + cdots)。 各特徴量のマッチウェイト。 一番右の濃い緑のバーは2つのレコードの合計マッチウェイト。 begin{align} M_{feature} &= M_{first_name_surname} + M_{dob} + M_{postcode_fake} + M_{birth_place} + M_{occupation} + cdots \\ &= 8.50w end{align} まとめ 長くなったのでいったん終了。この記事では教師なし確率的名寄せパッケージSplinkを使用してモデルを作ってみた。 次の記事では、作ったモデルを使用して実際に名寄せをしてみる。 途中、DuckDBが楽しいことに気づいたので、DuckDBだけで何個か記事にしてみようと思う。

分散と標準偏差を計算しやすくする

[mathjax] 分散と標準偏差を計算しやすく変形できる。 いちいち偏差(x_i-bar{x})を計算しておかなくても、2乗和(x_i^2)と平均(bar{x})がわかっていればOK。 begin{eqnarray} s^2 &=& frac{1}{n} sum_{i=1}^n (x_i - bar{x})^2 \\ &=& frac{1}{n} sum_{i=1}^n ( x_i^2 -2 x_i bar{x} + bar{x}^2 ) \\ &=& frac{1}{n} ( sum_{i=1}^n x_i^2 -2 bar{x} sum_{i=1}^n x_i + bar{x}^2 sum_{i=1}^n 1) \\ &=& frac{1}{n} ( sum_{i=1}^n x_i^2 -2 n bar{x}^2 + nbar{x}^2 ) \\ &=& frac{1}{n} ( sum_{i=1}^n x_i^2 - nbar{x}^2 ) \\ &=& frac{1}{n} sum_{i=1}^n x_i^2 - bar{x}^2 \\ s &=& sqrt{frac{1}{n} sum_{i=1}^n x_i^2 - bar{x}^2 } end{eqnarray} 以下みたい使える。平均と標準偏差と2乗和の関係。 begin{eqnarray} sum_{i=1}^n (x_i - bar{x})^2 &=& sum_{i=1}^n x_i^2 - nbar{x}^2 \\ ns^2 &=& sum_{i=1}^n x_i^2 - nbar{x}^2 \\ sum_{i=1}^n x_i^2 &=& n(s^2 + bar{x} ) end{eqnarray}

標本調査に必要なサンプル数の下限を与える2次関数

[mathjax] 2項分布に従う母集団の母平均を推測するために有意水準を設定して95%信頼区間を求めてみた。 母平均のあたりがついていない状況だとやりにくい。 [clink url=\"https://ikuty.com/2019/01/11/sampling/\"] (hat{p})がどんな値であっても下限は(hat{p})の関数で抑えられると思ったので、 気になって(hat{p})を変数のまま残すとどうなるかやってみた。 begin{eqnarray} 1.96sqrt{frac{hat{p}(1-hat{p})}{n}} le 0.05 \\ frac{1.96}{0.05}sqrt{hat{p}(1-hat{p})} le sqrt{n} \\ 39.2^2 hat{p}(1-hat{p}) le n end{eqnarray} 左辺を(f(hat{p}))と置くと (f(hat{p}))は下に凸の2次関数であって、 (frac{d}{dhat{p}}f(hat{p})=0)の時に最大となる。というか(hat{p}=0.5)。 (hat{p}=0.5)であるとすると、これはアンケートを取るときのサンプル数を求める式と同じで、 非常に有名な以下の定数が出てくる。 begin{eqnarray} 1537 * 0.5 (1-0.5) le n \\ 384 le n end{eqnarray} (hat{p})がどんな値であっても、サンプル数を400とれば、 有意水準=5%の95%信頼区間を得られる。 だから、アンケートの(n)数はだいたい400で、となる。 さらに、有意水準を10%にとれば、(n)の下限は100で抑えられる。 なるはやのアンケートなら100、ちゃんとやるには400、というやつがこれ。

深層学習

勾配降下法

[mathjax] 各地点において関数の値を最大にするベクトル((frac{partial f}{partial x_0},frac{partial f}{partial x_1}))を全地点に対して計算したものを勾配とかいう。 ある地点において、このベクトルの方向に向かうことにより最も関数の値を大きくする。 で、今後のために正負を反転して関数の値を最小にするベクトルを考えることにした。 関数の値を小さくする操作を繰り返していけば、いずれ\"最小値\"が見つかるはず。 というモチベを続けるのが勾配降下法。学習率(eta)を使って以下みたいに書ける。。 begin{eqnarray} x_0 = x_0 - eta frac{partial f}{partial x_0} \\ x_1 = x_1 - eta frac{partial f}{partial x_1} end{eqnarray} ということで(f(x_0,x_1)=x_0^2+x_1^2)の最小値を初期値((3.0,4.0))、 学習率(eta=0.1)に設定して計算してみる。 import numpy as np def numerical_gradient(f, x): h = 1e-4 grad = np.zeros_like(x) for idx in range(x.size): tmp_val = x[idx] x[idx] = tmp_val + h fxh1 = f(x) x[idx] = tmp_val - h fxh2 = f(x) grad[idx] = (fxh1 - fxh2) / (2*h) x[idx] = tmp_val return grad def gradient_descent(f, init_x, lr=0.01, step_num=100): x = init_x for i in range(step_num): grad = numerical_gradient(f,x) x -= lr * grad return x def function2(x): return x[0]**2 + x[1]**2 init_x = np.array([-3.0, 4.0]) v = gradient_descent(function2, init_x=init_x, lr=0.1, step_num=100) v # array([-6.11110793e-10, 8.14814391e-10]) ((0,0))に収束した。 ニューラルネットワークの勾配 損失関数を重みパラメータで微分する。以下みたいな感じ。 損失関数の大小を見るとして、例えば(w_{11})以外の重みを固定したとして(w_{11})をわずかに 増やしたときに損失関数の値がどれだけ大きくなるか。 損失関数の値はパラメータ(W)と入力(x)から決まるベクトルだけれども、それぞれ乱数と入力値が設定されている。 begin{eqnarray} W= begin{pmatrix} w_{11} & w_{12} & w_{13} \\ w_{21} & w_{22} & w_{23} end{pmatrix}, frac{partial L}{partial W}= begin{pmatrix} frac{partial L}{partial w_{11}} & frac{partial L}{partial w_{12}} & frac{partial L}{partial w_{13}} \\ frac{partial L}{partial w_{21}} & frac{partial L}{partial w_{22}} & frac{partial L}{partial w_{23}} end{pmatrix} end{eqnarray} 重み(W)が乱数で決まるネットワークがあるとする。このネットワークは入力と重みの積を出力 として返す。出力はSoftmaxを経由するとする。 ネットワークの出力と教師データのクロスエントロピー誤差を誤差として使う。 その前に、数値微分関数を多次元対応する。 普通、配列の次元が(n)個になると(n)重ループが必要になるけれども、 Numpy.nditer()を使うと(n)乗ループを1回のループにまとめることができる。 下のmulti_indexが((0,0),(0,1),(0,2),(1,0),(1,1),(1,2))みたいに イテレータが(n)次のタプルを返す。反復回数はタプルの要素数の直積。 Numpy配列にそのタプルでアクセスすることで晴れて全ての要素にアクセスできる。 def numerical_gradient_md(f, x): h = 1e-4 grad = np.zeros_like(x) it = np.nditer(x, flags=[\'multi_index\'], op_flags=[\'readwrite\']) while not it.finished: idx = it.multi_index tmp_val = x[idx] x[idx] = tmp_val + h fxh1 = f(x) # f(x+h) x[idx] = tmp_val - h fxh2 = f(x) # f(x-h) grad[idx] = (fxh1 - fxh2) / (2*h) x[idx] = tmp_val # 値を元に戻す it.iternext() return grad 初期値(x=(0.6,0.9))、教師データ(t=(0,0,1))をネットワークに入力する。 predict()は(1 times 3)を返す。 それをSoftmax()を通して、(t)とのクロスエントロピー誤差を求めたものが以下。 import numpy as np def cross_entropy_error(y, t): if y.ndim == 1: t = t.reshape(1, t.size) y = y.reshape(1,y.size) batch_size = y.shape[0] delta = 1e-7 return -np.sum( t * np.log( y + delta)) / batch_size def softmax(x): c = np.max(x) return np.exp(x-c) / np.sum(np.exp(x-c)) import sys, os sys.path.append(os.pardir) import numpy as np class simpleNet: def __init__(self): self.W = np.random.randn(2,3) def predict(self, x): return np.dot(x, self.W) def loss(self, x, t): z = self.predict(x) y = softmax(z) loss = cross_entropy_error(y, t) return loss net = simpleNet() x = np.array([0.6, 0.9]) p = net.predict(x) t = np.array([0, 0, 1]) net.loss(x, t) # 0.9463818740797788 このlossを(W)で微分したのが以下。 あえてパラメータ(W)を引数にとり損失関数の値を計算する(f(W))を定義することで、 数値微分が何と何の演算なのかをわかりやすくしている。 実際は(f(W))は(W)とは関係なく(x)と(t)だけから結果を返すけれども、 損失関数(f(W))を(W)で微分するという操作が自明になるようにコードを合わせている。 def f(W): return net.loss(x, t) dW = numerical_gradient_md(f, net.W) dW # array([[ 0.07627371, 0.49923236, -0.57550607], # [ 0.11441057, 0.74884853, -0.8632591 ]]) 結果の解釈 上記の(w),(W),(t)から(frac{partial L}{partial W})が求まった。 損失関数が何か複雑な形をしているという状況で、 (frac{partial L}{partial w_{11}})は(w_{11})がわずかに動いたときに損失関数の値が変化する量を表している。 それが(w_{11})から(w_{32})まで6個分存在する。 begin{eqnarray} frac{partial L}{partial W} = begin{pmatrix} frac{partial L}{partial w_{11}} & frac{partial L}{partial w_{21}} & frac{partial L}{partial w_{31}} \\ frac{partial L}{partial w_{12}} & frac{partial L}{partial w_{22}} & frac{partial L}{partial w_{32}} end{pmatrix} = begin{pmatrix} 0.07627371 & 0.49923236 & -0.57550607 \\ 0.11441057 & 0.74884853 & -0.8632591 end{pmatrix} end{eqnarray}

勾配の可視化

[mathjax] 2変数関数(f(x_0,x_1))を各変数で偏微分する。 地点((i,j))におけるベクトル((frac{partial f(x_0,j)}{partial x_0},frac{partial f(i,x_1)}{partial x_1}))を全地点で記録していき、ベクトル場を得る。 このベクトル場が勾配(gradient)。 (f(x_0,x_1)=x_0^2+x_1^2)について、(-4.0 le x_0 le 4.0)、(-4.0 le x_1 le 4.0)の範囲で、 勾配を求めてみる。また、勾配を可視化してみる。 まず、2変数関数(f(x_0,x_1))の偏微分係数を求める関数の定義。 ((3.0,3.0))の偏微分係数は((6.00..,6.00..))。 def numerical_gradient(f, x): h = 10e-4 grad = np.zeros_like(x) for idx in range(x.size): tmp_val = x[idx] x[idx] = tmp_val + h fxh1 = f(x) x[idx] = tmp_val - h fxh2 = f(x) grad[idx] = (fxh1 - fxh2) / 2*h x[idx] = tmp_val return grad def function2(x): return x[0]**2 + x[1]**2 p = np.array([3.0,3.0]) v = numerical_gradient(function2, p) v # array([6.e-06, 6.e-06]) (-4.0 le x_0 le 4.0)、(-4.0 le x_1 le 4.0)の範囲((0.5)刻み)で偏微分係数を求めて、 ベクトル場っぽく表示してみる。matplotlibのquiver()は便利。 各地点において関数の値を最も増やす方向が表示されている。 w_range = 4 dw = 0.5 w0 = np.arange(-w_range, w_range, dw) w1 = np.arange(-w_range, w_range, dw) wn = w0.shape[0] diff_w0 = np.zeros((len(w0), len(w1))) diff_w1 = np.zeros((len(w0), len(w1))) for i0 in range(wn): for i1 in range(wn): d = numerical_gradient(function2, np.array([ w0[i0], w1[i1] ])) diff_w0[i1, i0], diff_w1[i1, i0] = (d[0], d[1]) plt.xlabel(\'$x_0$\',fontsize=14) #x軸のラベル plt.ylabel(\'$x_1$\',fontsize=14) #y軸のラベル plt.xticks(range(-w_range,w_range+1,1)) #x軸に表示する値 plt.yticks(range(-w_range,w_range+1,1)) #y軸に表示する値 plt.quiver(w0, w1, diff_w0, diff_w1) plt.show() 値が大きい方向に矢印が向いている。例えば((-3.0,3.0))における偏微分係数は((-6.0,6.0))。 左上方向へのベクトル。 参考にしている本にはことわりが書いてあり、勾配にマイナスをつけたものを図にしている。 その場合、関数の値を最も減らす方向が表示されることになる。 各地点において、この勾配を参照することで、どちらに移動すれば関数の値を最も小さくできるかがわかる。

おっさんが数値微分を復習する

引き続き、ゼロDの写経を続ける。今回は、学生の頃が懐かしい懐ワード、数値微分。 全然Deepに入れないけれどおっさんの復習。解析的な微分と数値微分が一致するところを確認してみる。 昔と違うのは、PythonとJupyterNotebookで超絶簡単に実験できるし、 こうやってWordPressでLaTeXで記事を書いたりできる点。 [mathjax] まず、微分の基本的な考え方は以下の通り。高校数学の数3の範囲。 begin{eqnarray} frac{df(x)}{fx} = lim_{hrightarrow infty} frac{f(x+h)-f(x)}{h} end{eqnarray} 情報系学科に入って最初の方でEuler法とRunge-Kutta法を教わってコードを 書いたりレポート書いたりする。懐すぎる..。 または、基本情報の試験かなんかで、小さい値と小さい値どうしの計算で発生する問題が現れる。 ゼロDにはこの定義を少し改良した方法が載っている。へぇ。 begin{eqnarray} frac{df(x)}{fx} = lim_{hrightarrow infty} frac{f(x+h)-f(x-h)}{2h} end{eqnarray} 写経なので、がんばって数値微分を書いて動かしてみる。 簡単な2次関数(f(x))。 begin{eqnarray} f(x) &=& x^2 - 5x +3 \\ f\'(x) &=& 2x - 5 end{eqnarray} def numerical_diff(f, x): h = 10e-4 return (f(x+h) - f(x-h)) / (2*h) (f(x))と、(x=2.5)のところの(f\'(x))をmatplotlibで書いてみる。懐い... import matplotlib.pyplot as plt import numpy as np def f(x): return x**2 - 5*x + 3 x = np.arange(-10, 10, 0.1) y = f(x) dy = numerical_diff(f,x) plt.plot(x, y) x1 = -2.5 dy1 = numerical_diff(f, x1) y1 = f(x1) # y-y1 = dy1(x-x1) -> y = dy1(x-x1) + y1 j = lambda x: dy1 * (x-x1) + y1 plt.plot(x,j(x)) plt.xlabel(\'x\') plt.ylabel(\'y\') plt.grid() plt.show() 偏微分 2変数以上の関数の数値微分は以下の通り。片方を止める。 数値微分の方法は上記と同じものを使った。 begin{eqnarray} frac{partial f(x_0,x_1)}{partial x_0} &=& lim_{hrightarrow infty} frac{f(x_0 +h,x_1)-f(x_0-h,x_1)}{2h} \\ frac{partial f(x_0,x_1)}{partial x_1} &=& lim_{hrightarrow infty} frac{f(x_0,x_1+h)-f(x_0,x_1-h)}{2h} end{eqnarray} ((x_0,x_1)=(1,1))における(x_0)に対する偏微分(frac{partial f(x_0,x_1)}{x_0})、(x_1)に対する偏微分(frac{partial f(x_0,x_1)}{x_1})を求めてみる。 ちゃんと(frac{partial f(x_0,1.0)}{x_0}=2.00..)、(frac{partial f(1.0,x_1)}{x_1}=2.00..)になった。 import matplotlib.pyplot as plt import numpy as np from mpl_toolkits.mplot3d import Axes3D def f(x): return x[0]**2 + x[1]**2 X = np.meshgrid(np.arange(-5., 5., 0.2),np.arange(-5., 5., 0.2)) Z = f(X) fig = plt.figure(figsize=(6, 6)) axes = fig.add_subplot(111, projection=\'3d\') axes.plot_surface(X[0],X[1], Z) f0 = lambda x: x**2 + 1.0**2 f1 = lambda x: 1.0**2 + x**2 df0 = numerical_diff(f0, 1.0) df1 = numerical_diff(f1, 1.0) print(df0) # 2.0000000000000018 print(df1) # 2.0000000000000018 plt.show()