# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import gzip
import struct

import numpy as np


def load_data(filename, label=False):
    """Load a gzip-compressed MNIST IDX file and return it as a NumPy array.

    The header is read as big-endian unsigned 32-bit integers: a magic
    number (consumed but not interpreted), the item count and, for image
    files only, the row and column counts. The remaining bytes are the
    ``uint8`` payload.

    Args:
        filename: Path to the ``.gz`` file, passed to ``gzip.open``.
        label: ``False`` (default) to parse an image file, ``True`` to parse
            a label file, which has no row/column fields in its header.

    Returns:
        A ``uint8`` array of shape ``(n_items, n_rows * n_cols)`` for images
        or ``(n_items, 1)`` for labels. The array is built with
        ``np.frombuffer`` over an immutable ``bytes`` object, so it is
        read-only; copy it before modifying in place.

    Raises:
        struct.error: If the file ends before the header is complete.
        ValueError: If the payload is shorter than the header announces.
    """
    with gzip.open(filename, 'rb') as gz:
        # Magic number: skipped, but still unpacked so that a file shorter
        # than four bytes fails here instead of further down.
        struct.unpack('>I', gz.read(4))
        n_items = struct.unpack('>I', gz.read(4))[0]
        if label:
            # One byte per label.
            item_size = 1
        else:
            n_rows, n_cols = struct.unpack('>II', gz.read(8))
            item_size = n_rows * n_cols
        expected = n_items * item_size
        payload = gz.read(expected)

    # Fail with a clear message instead of an opaque reshape error.
    if len(payload) != expected:
        raise ValueError(
            'truncated data in {!r}: expected {} bytes, got {}'.format(
                filename, expected, len(payload)))

    return np.frombuffer(payload, dtype=np.uint8).reshape(n_items, item_size)


def one_hot_encode(array, num_of_classes):
    """One-hot encode integer class labels.

    Args:
        array: Integer labels as an array or sequence of any shape; it is
            flattened before encoding. Each value must be a valid row index
            into a ``num_of_classes`` x ``num_of_classes`` identity matrix.
        num_of_classes: Number of classes, i.e. the width of each encoded row.

    Returns:
        A float array of shape ``(array.size, num_of_classes)`` in which row
        ``i`` is all zeros except for a one at column ``array[i]``.
    """
    # np.asarray lets plain Python sequences through as well as ndarrays.
    labels = np.asarray(array).reshape(-1)
    # Row k of the identity matrix is exactly the one-hot vector for class k.
    return np.eye(num_of_classes)[labels]