# Source code for tensorlayer.files.dataset_loaders.cifar10_dataset

#! /usr/bin/python
# -*- coding: utf-8 -*-

import os
import pickle
import sys

import numpy as np

from tensorlayer import logging
from tensorlayer.files.utils import maybe_download_and_extract

# Public API of this module: the only name exported by a star-import.
__all__ = ['load_cifar10_dataset']


def _unpickle_cifar10_batch(file_path):
    """Load one pickled CIFAR-10 batch file and return the object stored in it.

    Raises
    ------
    RuntimeError
        If the interpreter's major version is neither 2 nor 3.
    """
    # NOTE(security): unpickling can execute arbitrary code. The file read here
    # comes from a downloaded archive, so only point the loader at a trusted
    # source / trusted ``path``.
    major = sys.version_info.major
    if major not in (2, 3):
        raise RuntimeError("Sys Version Unsupported")

    # The context manager guarantees the handle is closed even if loading fails.
    with open(file_path, 'rb') as fp:
        if major == 2:
            return pickle.load(fp)
        # The batches were pickled by Python 2; 'latin-1' lets Python 3 decode
        # their byte strings.
        return pickle.load(fp, encoding='latin-1')


def _plot_cifar10_examples(X_train, shape):
    """Show the first 100 training images in a 10 x 10 grid for three seconds.

    Raises
    ------
    Exception
        If ``shape`` is neither (-1, 3, 32, 32) nor (-1, 32, 32, 3).
    """
    channels_first = shape == (-1, 3, 32, 32)
    if not channels_first and shape != (-1, 32, 32, 3):
        raise Exception("Do not support the given 'shape' to plot the image examples")

    # Imported lazily so matplotlib is only required when plotting is requested.
    import matplotlib.pyplot as plt

    fig = plt.figure(1)
    plt.ion()  # interactive mode
    for index in range(100):
        fig.add_subplot(10, 10, index + 1)
        image = X_train[index]
        if channels_first:
            # imshow expects (height, width, channel).
            image = np.transpose(image, (1, 2, 0))
        plt.imshow(image, interpolation='nearest')
        # Hide the axis ticks of every thumbnail.
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.draw()  # interactive mode
    plt.pause(3)  # interactive mode


def load_cifar10_dataset(shape=(-1, 32, 32, 3), path='data', plotable=False):
    """Load CIFAR-10 dataset.

    It consists of 60000 32x32 colour images in 10 classes, with
    6000 images per class. There are 50000 training images and 10000 test images.

    The dataset is divided into five training batches and one test batch, each with
    10000 images. The test batch contains exactly 1000 randomly-selected images from
    each class. The training batches contain the remaining images in random order,
    but some training batches may contain more images from one class than another.
    Between them, the training batches contain exactly 5000 images from each class.

    Parameters
    ----------
    shape : tuple
        The shape of digit images e.g. (-1, 3, 32, 32) and (-1, 32, 32, 3).
        A list with the same values is accepted and treated like the tuple.
    path : str
        The directory the data is downloaded to; the files are stored in its
        ``cifar10`` sub-directory. Defaults to ``data`` (i.e. ``data/cifar10/``).
    plotable : boolean
        Whether to plot some image examples, False as default.

    Returns
    -------
    X_train, y_train, X_test, y_test : tuple of numpy.ndarray
        Images as ``float32`` arrays reshaped to ``shape`` and labels as
        ``int32`` arrays.

    Raises
    ------
    RuntimeError
        If the running Python major version is neither 2 nor 3.
    Exception
        If ``plotable`` is True and ``shape`` is not one of the two image
        layouts listed above.

    Examples
    --------
    >>> X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3))

    References
    ----------
    - `CIFAR website <https://www.cs.toronto.edu/~kriz/cifar.html>`__
    - `Data download link <https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz>`__
    - `<https://teratail.com/questions/28932>`__

    """
    path = os.path.join(path, 'cifar10')
    logging.info("Load or Download cifar10 > {}".format(path))

    # The layout checks below compare against tuples; without this a list such
    # as [-1, 32, 32, 3] would silently take the plain reshape and return
    # scrambled images.
    if isinstance(shape, list):
        shape = tuple(shape)

    filename = 'cifar-10-python.tar.gz'
    url = 'https://www.cs.toronto.edu/~kriz/'
    # Download and uncompress file
    maybe_download_and_extract(filename, path, url, extract=True)

    # Unpickle the five training batches; stack them once at the end instead of
    # re-copying the growing array on every iteration.
    batch_dir = os.path.join(path, 'cifar-10-batches-py/')
    train_chunks = []
    train_labels = []
    for i in range(1, 6):
        batch = _unpickle_cifar10_batch(os.path.join(batch_dir, "data_batch_{}".format(i)))
        train_chunks.append(batch['data'])
        train_labels.extend(batch['labels'])
    X_train = np.vstack(train_chunks)
    y_train = np.array(train_labels)

    test_batch = _unpickle_cifar10_batch(os.path.join(batch_dir, "test_batch"))
    X_test = test_batch['data']
    y_test = np.array(test_batch['labels'])

    if shape == (-1, 32, 32, 3):
        # Per the CIFAR-10 format description each row is stored channel-first.
        # A Fortran-order reshape yields (N, width, height, channel); swapping
        # axes 1 and 2 gives (N, height, width, channel).
        X_test = np.transpose(X_test.reshape(shape, order='F'), (0, 2, 1, 3))
        X_train = np.transpose(X_train.reshape(shape, order='F'), (0, 2, 1, 3))
    else:
        # Covers the channel-first layout (-1, 3, 32, 32) and any other
        # caller-chosen shape: a plain reshape of the stored rows.
        X_test = X_test.reshape(shape)
        X_train = X_train.reshape(shape)

    if plotable:
        logging.info('\nCIFAR-10')
        # Shapes are wrapped in a 1-tuple: '%s' % (32, 32, 3) would raise
        # TypeError because the tuple is taken as the argument list.
        logging.info('Shape of a training image: X_train[0] %s' % (X_train[0].shape, ))
        _plot_cifar10_examples(X_train, shape)

        logging.info("X_train: %s" % (X_train.shape, ))
        logging.info("y_train: %s" % (y_train.shape, ))
        logging.info("X_test: %s" % (X_test.shape, ))
        logging.info("y_test: %s" % (y_test.shape, ))

    X_train = np.asarray(X_train, dtype=np.float32)
    X_test = np.asarray(X_test, dtype=np.float32)
    y_train = np.asarray(y_train, dtype=np.int32)
    y_test = np.asarray(y_test, dtype=np.int32)

    return X_train, y_train, X_test, y_test