MNISTの手書き文字認識用データ取得クラスの作成
MNISTから手書き文字認識用のデータセットをロードするクラスを作ってみた。 ロードしたデータセットをpickleでシリアライズ、デシリアライズする機能付き。 後々改造する予定でここに貼ったのはメンテしない予定。 たったこれだけ書くのに40分も要してしまった...。 Pythonの冗長感が半端ないけども慣れるしかない。 マジックコードだらけだけども、MNISTの手書き文字認識用データ取得専用だから仕方ない。 get_image()関数により訓練データ、訓練ラベル、テストデータ、テストラベルを取れる。 データは1次元になって入る。つまり1ファイルごとに28*28のサイズがある。 import os import urllib.request import numpy as np import gzip import pickle class MnistLoader: def __init__(self, mnistdir): self.url_base = \'http://yann.lecun.com/exdb/mnist/\' self.dataset_dir = mnistdir self.save_path = self.dataset_dir + \'mnist.pkl\' self.key_file = { \'train_img\':\'train-images-idx3-ubyte.gz\', \'train_label\':\'train-labels-idx1-ubyte.gz\', \'test_img\':\'t10k-images-idx3-ubyte.gz\', \'test_label\':\'t10k-labels-idx1-ubyte.gz\' } self.dataset = {} if os.path.exists(self.save_path): with open(self.save_path,\'rb\') as f: self.dataset = pickle.load(f) else: self.__load_mnist() self.dataset[\'train_img\'] = self.__load_img(self.key_file[\'train_img\']) self.dataset[\'train_label\'] = self.__load_label(self.key_file[\'train_label\']) self.dataset[\'test_img\'] = self.__load_img(self.key_file[\'test_img\']) self.dataset[\'test_label\'] = self.__load_label(self.key_file[\'test_label\']) with open(self.save_path, \'wb\') as f: pickle.dump(dataset, f, -1) def __load_mnist(self): \'\'\' load mnist data and store to file \'\'\' for v in self.key_file.values(): file_path = self.dataset_dir + v urllib.request.urlretrieve(url_base + v, self.dataset_dir + v) def __load_img(self, file_name): file_path = self.dataset_dir + file_name with gzip.open(file_path, \'rb\') as f: data = np.frombuffer(f.read(), np.uint8, offset=16) data = data.reshape(-1, 784) return data def __load_label(self,file_name): file_path = dataset_dir + \'/\' + file_name with gzip.open(file_path, \'rb\') as f: labels = np.frombuffer(f.read(), np.uint8, offset=8) return labels def get_image(self): return (self.dataset[\'train_img\'] , self.dataset[\'train_label\']), (self.dataset[\'test_img\'], self.dataset[\'test_label\']) mnist_loader = MnistLoader(mnistdir=\'/Users/ikuty/Documents/mnist/\') PILを使って訓練データの1枚目を表示するテスト。 import sys, os sys.path.append(os.pardir) import numpy as np from PIL import Image def img_show(img): pil_img = Image.fromarray(np.uint8(img)) pil_img.show() (x_train, t_train), (x_test, t_test) = mnist_loader.get_image() img = x_train[0] label = t_train[0] bimg = img.reshape(28, 28) img_show(bimg) こんなのが出る。(28x28しかなくて小さすぎなので256x256に引き伸ばして表示。)