API - Files

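Download benchmark datasets, and save and load models and data. TensorFlow provides the .ckpt file format for saving and loading models, but for better cross-platform compatibility we recommend NumPy's standard .npz file format.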

import tensorflow as tf
import tensorlayer as tl
# Assumes an existing TensorFlow session `sess` and a TensorLayer network `network`.

## Save the model as .ckpt
saver = tf.train.Saver()
save_path = saver.save(sess, "model.ckpt")
# Load the model from .ckpt
saver = tf.train.Saver()
saver.restore(sess, "model.ckpt")

## Save the model as .npz
tl.files.save_npz(network.all_params, name='model.npz')
# Load the model from .npz (method 1)
load_params = tl.files.load_npz(name='model.npz')
tl.files.assign_params(sess, load_params, network)
# Load the model from .npz (method 2)
tl.files.load_and_assign_npz(sess=sess, name='model.npz', network=network)

## You can also load only some of the pre-trained parameters
# Load the first parameter
tl.files.assign_params(sess, [load_params[0]], network)
# Load the first three parameters
tl.files.assign_params(sess, load_params[:3], network)
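
assign_params assigns parameters by position in the network's parameter list, which is why passing a slice such as load_params[:3] restores only the leading parameters. The NumPy sketch below imitates that positional behavior with plain arrays standing in for network variables; `assign_in_order` is a hypothetical helper written for illustration, not a TensorLayer function, and the shapes are made up.

```python
import numpy as np

def assign_in_order(targets, params):
    """Copy params[i] into targets[i] for each i in range(len(params)).

    Hypothetical stand-in for tl.files.assign_params: `targets` plays the role of
    network.all_params and `params` is the (possibly sliced) list from load_npz.
    """
    if len(params) > len(targets):
        raise ValueError("more parameters than network variables")
    for target, value in zip(targets, params):
        if target.shape != value.shape:
            raise ValueError(f"shape mismatch: {target.shape} vs {value.shape}")
        target[...] = value  # in-place copy, playing the role of an assign op
    return targets

# Four "variables" (W1, b1, W2, b2), all initialised to zero.
network_params = [np.zeros((3, 2)), np.zeros(2), np.zeros((2, 1)), np.zeros(1)]
# Pretend these came from tl.files.load_npz(name='model.npz').
load_params = [np.ones((3, 2)), np.full(2, 2.0), np.full((2, 1), 3.0), np.full(1, 4.0)]

# Restore only the first layer (W1, b1); the remaining variables keep their values.
assign_in_order(network_params, load_params[:2])
print([float(p.sum()) for p in network_params])  # [6.0, 4.0, 0.0, 0.0]
```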

TensorLayer provides a rich set of layer implementations tailored for various benchmarks and domain-specific problems. In addition, it supports transparent access to native TensorFlow parameters. For example, it provides not only layers for local response normalization, but also layers that let users apply tf.nn.lrn on network.outputs. More functions can be found in the TensorFlow API.

load_mnist_dataset([shape, path]) Load the original MNIST dataset.
load_fashion_mnist_dataset([shape, path]) Load the Fashion-MNIST dataset.
load_cifar10_dataset([shape, path, plotable]) Load CIFAR-10 dataset.
load_cropped_svhn([path, include_extra]) Load Cropped SVHN.
load_ptb_dataset([path]) Load Penn TreeBank (PTB) dataset.
load_matt_mahoney_text8_dataset([path]) Load Matt Mahoney's dataset.
load_imdb_dataset([path, nb_words, ...]) Load IMDB dataset.
load_nietzsche_dataset([path]) Load Nietzsche dataset.
load_wmt_en_fr_dataset([path]) Load WMT'15 English-to-French translation dataset.
load_flickr25k_dataset([tag, path, ...]) Load Flickr25K dataset.
load_flickr1M_dataset([tag, size, path, ...]) Load Flickr1M dataset.
load_cyclegan_dataset([filename, path]) Load images from CycleGAN's database.
load_celebA_dataset([path]) Load CelebA dataset
load_voc_dataset([path, dataset, ...]) Pascal VOC 2007/2012 Dataset.
load_mpii_pose_dataset([path, is_16_pos_only]) Load MPII Human Pose Dataset.
download_file_from_google_drive(ID, destination) Download file from Google Drive.
save_npz([save_list, name, sess]) Save a list of parameters into a .npz file.
load_npz([path, name]) Load the parameters of a model saved by tl.files.save_npz().
assign_params(sess, params, network) Assign the given parameters to the TensorLayer network.
load_and_assign_npz([sess, name, network]) Load parameters from a .npz file and assign them to a network.
save_npz_dict([save_list, name, sess]) Save parameters as a dictionary into a .npz file.
load_and_assign_npz_dict([name, sess]) Restore the parameters saved by tl.files.save_npz_dict().
save_graph([network, name]) Save the architecture of a TL model into a pickle file.
load_graph([name]) Restore a TL model architecture from a pickle file.
save_graph_and_params([network, name, sess]) Save the TL model architecture and parameters (i.e. the whole model) into a graph file and an npz file, respectively.
load_graph_and_params([name, sess]) Load TL model architecture and parameters from graph file and npz file, respectively.
save_ckpt([sess, mode_name, save_dir, ...]) Save parameters into ckpt file.
load_ckpt([sess, mode_name, save_dir, ...]) Load parameters from ckpt file.
save_any_to_npy([save_dict, name]) Save variables to .npy file.
load_npy_to_any([path, name]) Load .npy file.
file_exists(filepath) Check whether a file exists at the given file path.
folder_exists(folderpath) Check whether a folder exists at the given folder path.
del_file(filepath) Delete the file at the given file path.
del_folder(folderpath) Delete the folder at the given folder path.
read_file(filepath) Read a file and return its content as a string.
load_file_list([path, regx, printable, ...]) Return the list of files in a folder whose names match a regular expression.
load_folder_list([path]) Return the list of subfolders in the given folder.
exists_or_mkdir(path[, verbose]) Check whether a folder exists; if it does not, create it and return False, otherwise return True.
maybe_download_and_extract(filename, ...[, ...]) Check whether a file exists in working_directory; if not, try to download it, and optionally try to extract it if it is a ".zip" or ".tar" archive.
natural_keys(text) Sort a list of strings containing numbers in human (natural) order.
npz_to_W_pdf([path, regx]) Convert the first weight matrix of .npz file to .pdf by using tl.visualize.W().

Download datasets

MNIST

tensorlayer.files.load_mnist_dataset(shape=(-1, 784), path='data')

Load the original MNIST dataset.

Automatically download the MNIST dataset and return the training, validation and test sets with 50000, 10000 and 10000 digit images, respectively.

Parameters:
  • shape (tuple) -- The shape of the digit images (default (-1, 784); alternatively (-1, 28, 28, 1)).
  • path (str) -- The path that the data is downloaded to.
Returns:

X_train, y_train, X_val, y_val, X_test, y_test -- The training, validation and test splits, respectively.

Return type:

tuple

Examples

>>> X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1,784), path='datasets')
>>> X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
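
The 50000/10000 training/validation split comes out of the 60000-image MNIST training set, and shape only changes how each 784-pixel image is laid out. The NumPy sketch below illustrates both points with placeholder arrays; it assumes the last 10000 training images form the validation set, so check the loader's source if the exact split matters.

```python
import numpy as np

# Placeholder pixels and labels with the shapes of the 60000-image MNIST training set.
X = np.zeros((60000, 784), dtype=np.uint8)
y = np.zeros(60000, dtype=np.int64)

# Assumed split: hold out the last 10000 training images as the validation set.
X_train, X_val = X[:-10000], X[-10000:]
y_train, y_val = y[:-10000], y[-10000:]

# shape=(-1, 784) keeps flat vectors; shape=(-1, 28, 28, 1) gives NHWC images for CNNs.
X_train_img = X_train.reshape((-1, 28, 28, 1))
print(X_train.shape, X_val.shape, X_train_img.shape)
# (50000, 784) (10000, 784) (50000, 28, 28, 1)
```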

Fashion-MNIST

tensorlayer.files.load_fashion_mnist_dataset(shape=(-1, 784), path='data')

Load the Fashion-MNIST dataset.

Automatically download the Fashion-MNIST dataset and return the training, validation and test sets with 50000, 10000 and 10000 fashion images, respectively.

Parameters:
  • shape (tuple) -- The shape of the images (default (-1, 784); alternatively (-1, 28, 28, 1)).
  • path (str) -- The path that the data is downloaded to.
Returns:

X_train, y_train, X_val, y_val, X_test, y_test -- The training, validation and test splits, respectively.

Return type:

tuple

Examples

>>> X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_fashion_mnist_dataset(shape=(-1,784), path='datasets')
>>> X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_fashion_mnist_dataset(shape=(-1, 28, 28, 1))

CIFAR-10

tensorlayer.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), path='data', plotable=False)

Load the CIFAR-10 dataset.

It consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

The dataset is divided into five training batches and one test batch, each with 10000 images. The test batch contains exactly 1000 randomly selected images from each class. The training batches contain the remaining images in random order, but some training batches may contain more images from one class than another. Between them, the training batches contain exactly 5000 images from each class.

Parameters:
  • shape (tuple) -- The shape of the images, e.g. (-1, 3, 32, 32) or (-1, 32, 32, 3).
  • path (str) -- The path that the data is downloaded to; by default, data/cifar10/.
  • plotable (boolean) -- Whether to plot some example images; default is False.

Examples

>>> X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3))
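
The shape argument chooses between a channels-first (-1, 3, 32, 32) and a channels-last (-1, 32, 32, 3) layout. Converting between the two needs a transpose rather than a reshape: a reshape gives the right shape but mixes values from different channels. A small NumPy check on a random batch (the batch itself is made up):

```python
import numpy as np

rng = np.random.default_rng(0)
batch_nchw = rng.integers(0, 256, size=(4, 3, 32, 32))  # channels first

# Move the channel axis to the end: (N, C, H, W) -> (N, H, W, C).
batch_nhwc = batch_nchw.transpose(0, 2, 3, 1)
# A plain reshape has the same shape but scrambles pixels across channels.
wrong = batch_nchw.reshape(-1, 32, 32, 3)

print(batch_nhwc.shape)                                       # (4, 32, 32, 3)
print(np.array_equal(batch_nhwc[..., 0], batch_nchw[:, 0]))   # True: first channel preserved
```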

SVHN

tensorlayer.files.load_cropped_svhn(path='data', include_extra=True)

Load the Cropped SVHN dataset.

The Cropped Street View House Numbers (SVHN) dataset contains 32x32x3 RGB images. Digit '1' has label 1, '9' has label 9 and '0' has label 0 (the original dataset uses 10 to represent '0'); see the UFLDL website.

Parameters:
  • path (str) -- The path that the data is downloaded to.
  • include_extra (boolean) -- If True (default), add the extra images to the training set.
Returns:

X_train, y_train, X_test, y_test -- The training and test splits, respectively.

Return type:

tuple

Examples

>>> X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(include_extra=False)
>>> tl.vis.save_images(X_train[0:100], [10, 10], 'svhn.png')
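
Because the original SVHN data encodes the digit '0' as label 10, labels have to be remapped so that each one equals the digit it names. In NumPy this is a one-line operation; the label array below is made up for illustration.

```python
import numpy as np

# Labels as encoded in the original SVHN data: the digit '0' is stored as 10.
y_raw = np.array([10, 1, 9, 10, 5])

# Map 10 -> 0 and leave every other label unchanged.
y = np.where(y_raw == 10, 0, y_raw)
print(y.tolist())  # [0, 1, 9, 0, 5]
```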

Penn TreeBank (PTB)

tensorlayer.files.load_ptb_dataset(path='data')

Load the Penn TreeBank (PTB) dataset.

It is used in many language modeling papers, including "Empirical Evaluation and Combination of Advanced Language Modeling Techniques" and "Recurrent Neural Network Regularization". It consists of 929k training words, 73k validation words and 82k test words, with a vocabulary of 10k words.

Parameters:path (str) -- The path that the data is downloaded to; by default, data/ptb/.
Returns:
  • train_data, valid_data, test_data (list of int) -- The training, validation and test data as integer word ids.
  • vocab_size (int) -- The vocabulary size.

Examples

>>> train_data, valid_data, test_data, vocab_size = tl.files.load_ptb_dataset()
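
train_data, valid_data and test_data are the corpus with every word replaced by an integer id. The sketch below shows one common way to build such a mapping (most frequent word first, ties broken alphabetically), in the style of the reader from the TensorFlow PTB tutorial; the tiny corpus is made up, and the exact id ordering the loader uses is an assumption.

```python
import collections

def build_vocab(words):
    """Map each word to an id, most frequent first (ties broken alphabetically)."""
    counter = collections.Counter(words)
    ordered = sorted(counter.items(), key=lambda kv: (-kv[1], kv[0]))
    return {word: idx for idx, (word, _) in enumerate(ordered)}

def words_to_ids(words, word_to_id):
    """Replace each word with its id, skipping words that are not in the vocabulary."""
    return [word_to_id[w] for w in words if w in word_to_id]

# PTB text marks sentence ends with <eos>; this corpus is only an illustration.
train_words = "the cat sat <eos> the dog sat <eos>".split()
word_to_id = build_vocab(train_words)
train_data = words_to_ids(train_words, word_to_id)
print(train_data, len(word_to_id))  # [2, 3, 1, 0, 2, 4, 1, 0] 5
```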

Notes

  • If you want to get the raw data, see the source code.

Matt Mahoney's text8

tensorlayer.files.load_matt_mahoney_text8_dataset(path='data')

Load Matt Mahoney's text8 dataset.

Download a text file from Matt Mahoney's website if it is not already present, and make sure it has the right size. Extract the first file in the zip archive as a list of words. This dataset can be used for training word embeddings.

Parameters:path (str) -- The path that the data is downloaded to; by default, data/mm_test8/.
Returns:The raw text data, e.g. [.... 'their', 'families', 'who', 'were', 'expelled', 'from', 'jerusalem', ...]
Return type:list of str

Examples

>>> words = tl.files.load_matt_mahoney_text8_dataset()
>>> print('Data size', len(words))
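
Reading "the first file enclosed in a zip file as a list of words" can be sketched with the standard library alone. The example builds a small in-memory archive in place of the real text8.zip, and it assumes words are separated by whitespace; `read_first_member_as_words` is a hypothetical helper, not the loader's code.

```python
import io
import zipfile

def read_first_member_as_words(zip_bytes):
    """Return the first file in a zip archive as a list of whitespace-separated words."""
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
        first_name = zf.namelist()[0]
        text = zf.read(first_name).decode("utf-8")
    return text.split()

# Build a small in-memory archive standing in for text8.zip.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as zf:
    zf.writestr("text8", " anarchism originated as a term of abuse")
words = read_first_member_as_words(buffer.getvalue())
print("Data size", len(words))  # Data size 7
```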

IMDB

tensorlayer.files.load_imdb_dataset(path='data', nb_words=None, skip_top=0, maxlen=None, test_split=0.2, seed=113, start_char=1, oov_char=2, index_from=3)

Load the IMDB dataset.

Parameters:
  • path (str) -- The path that the data is downloaded to; by default, data/imdb/.
  • nb_words (int) -- Number of words to keep (the most frequent ones).
  • skip_top (int) -- Number of most frequent words to ignore (they will appear as the oov_char value in the sequence data).
  • maxlen (int) -- Maximum sequence length. Any longer sequence will be truncated.
  • test_split (float) -- Fraction of the data to use as the test set.
  • seed (int) -- Seed for reproducible data shuffling.
  • start_char (int) -- The start of a sequence will be marked with this character. Set to 1 because 0 is usually the padding character.
  • oov_char (int) -- Words that were cut out because of the nb_words or skip_top limit will be replaced with this character.
  • index_from (int) -- Index actual words with this index and higher.

Examples

>>> X_train, y_train, X_test, y_test = tl.files.load_imdb_dataset(
...                                 nb_words=20000, test_split=0.2)
>>> print('X_train.shape', X_train.shape)
(20000,)  [[1, 62, 74, ... 1033, 507, 27],[1, 60, 33, ... 13, 1053, 7]..]
>>> print('y_train.shape', y_train.shape)
(20000,)  [1 0 0 ..., 1 0 1]
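
start_char, oov_char, index_from, nb_words and skip_top interact in a way that is easy to misread. The sketch below shows how they combine under a Keras-style convention, which is an assumption here: raw word ranks (1 = most frequent word) are shifted by index_from, each sequence is prefixed with start_char, and ids outside [skip_top, nb_words) become oov_char. `encode_review` is a hypothetical helper; verify against the loader's source before relying on exact ids.

```python
def encode_review(ranks, nb_words=None, skip_top=0, start_char=1, oov_char=2, index_from=3):
    """Turn raw word ranks into model input ids (assumed Keras-style convention)."""
    # Shift ranks so the ids below index_from stay free for padding/start/oov markers.
    seq = [start_char] + [r + index_from for r in ranks]
    # Keep ids in [skip_top, nb_words); everything else becomes oov_char.
    return [
        w if w >= skip_top and (nb_words is None or w < nb_words) else oov_char
        for w in seq
    ]

print(encode_review([1, 5, 20000]))                  # [1, 4, 8, 20003]
print(encode_review([1, 5, 20000], nb_words=10000))  # [1, 4, 8, 2]
```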

Nietzsche

tensorlayer.files.load_nietzsche_dataset(path='data')

Load the Nietzsche dataset.

Parameters:path (str) -- The path that the data is downloaded to; by default, data/nietzsche/.
Returns:The text content.
Return type:str

Examples

See tutorial_generate_text.py.

>>> words = tl.files.load_nietzsche_dataset()
>>> words = tl.nlp.basic_clean_str(words)
>>> words = words.split()

WMT'15 English-to-French translation data

tensorlayer.files.load_wmt_en_fr_dataset(path='data')

Load the WMT'15 English-to-French translation dataset.

It downloads the training data from the WMT'15 website (the 10^9 French-English corpus), and the 2013 news test set from the same site as the development set. Returns the directories of the training data and the test data.

Parameters:path (str) -- The path that the data is downloaded to; by default, data/wmt_en_fr/.

References

  • Code modified from /tensorflow/models/rnn/translation/data_utils.py

Notes

Downloading this dataset usually takes a long time.

Flickr25k

tensorlayer.files.load_flickr25k_dataset(tag='sky', path='data', n_threads=50, printable=False)

Load the Flickr25K dataset.

Returns a list of images with a given tag from the Flickr25k dataset. The dataset is downloaded from the official website the first time you use it.

Parameters:
  • tag (str or None) --
    Which images to return.
    • To get images with a tag, use a string like 'dog' or 'red'; see Flickr Search.
    • To get all images, set to None.
  • path (str) -- The path that the data is downloaded to; by default, data/flickr25k/.
  • n_threads (int) -- The number of threads used to read images.
  • printable (boolean) -- Whether to print information when reading images; default is False.

Examples

Get images tagged 'sky'

>>> images = tl.files.load_flickr25k_dataset(tag='sky')

Get all images

>>> images = tl.files.load_flickr25k_dataset(tag=None, n_threads=100, printable=True)

Flickr1M

tensorlayer.files.load_flickr1M_dataset(tag='sky', size=10, path='data', n_threads=50, printable=False)

Load the Flickr1M dataset.

Returns a list of images with a given tag from the Flickr1M dataset. The dataset is downloaded from the official website the first time you use it.

Parameters:
  • tag (str or None) --
    Which images to return.
    • To get images with a tag, use a string like 'dog' or 'red'; see Flickr Search.
    • To get all images, set to None.
  • size (int) -- An integer between 1 and 10: 1 means 100k images, ..., 5 means 500k images, and 10 means all 1 million images. Default is 10.
  • path (str) -- The path that the data is downloaded to; by default, data/flickr25k/.
  • n_threads (int) -- The number of threads used to read images.
  • printable (boolean) -- Whether to print information when reading images; default is False.

Examples

Use 200k images

>>> images = tl.files.load_flickr1M_dataset(tag='zebra', size=2)

Use 1 Million images

>>> images = tl.files.load_flickr1M_dataset(tag='zebra')

CycleGAN

tensorlayer.files.load_cyclegan_dataset(filename='summer2winter_yosemite', path='data')

Load images from CycleGAN's database.

Parameters:
  • filename (str) -- The name of the dataset to load, e.g. 'summer2winter_yosemite'.
  • path (str) -- The path that the data is downloaded to; by default, data/cyclegan/.

Examples

>>> im_train_A, im_train_B, im_test_A, im_test_B = tl.files.load_cyclegan_dataset(filename='summer2winter_yosemite')

CelebA

tensorlayer.files.load_celebA_dataset(path='data')

Load the CelebA dataset.

Returns a list of image paths.

Parameters:path (str) -- The path that the data is downloaded to; by default, data/celebA/.

VOC 2007/2012

tensorlayer.files.load_voc_dataset(path='data', dataset='2012', contain_classes_in_person=False)

Load the Pascal VOC 2007/2012 dataset.

It has 20 object classes: aeroplane, bicycle, bird, boat, bottle, bus, car, cat, chair, cow, diningtable, dog, horse, motorbike, person, pottedplant, sheep, sofa, train and tvmonitor, plus 3 additional classes for parts of a person: head, hand and foot.

Parameters:
  • path (str) -- The path that the data is downloaded to; by default, data/VOC.
  • dataset (str) -- The VOC dataset version: 2012, 2007, 2007test or 2012test. Models are usually trained on 2007+2012 and tested on 2007test.
  • contain_classes_in_person (boolean) -- Whether to include the head, hand and foot annotations; default is False.
Returns:

  • imgs_file_list (list of str) -- Full paths of all images.
  • imgs_semseg_file_list (list of str) -- Full paths of all maps for semantic segmentation. Note that not all images have this map!
  • imgs_insseg_file_list (list of str) -- Full paths of all maps for instance segmentation. Note that not all images have this map!
  • imgs_ann_file_list (list of str) -- Full paths of all annotations for bounding boxes and object classes; every image has this annotation.
  • classes (list of str) -- Classes in order.
  • classes_in_person (list of str) -- The person-part classes.
  • classes_dict (dictionary) -- Mapping from class label to integer.
  • n_objs_list (list of int) -- Number of objects in each image of imgs_file_list, in order.
  • objs_info_list (list of str) -- Darknet-format annotation of each image in imgs_file_list, in order: [class_id x_centre y_centre width height], with coordinates as ratios of the image size.
  • objs_info_dicts (dictionary) -- The annotations of all images in imgs_file_list, {imgs_file_list : dictionary for annotation}, in the format of TensorFlow/Models/object-detection.

Examples

>>> (imgs_file_list, imgs_semseg_file_list, imgs_insseg_file_list, imgs_ann_file_list,
...     classes, classes_in_person, classes_dict,
...     n_objs_list, objs_info_list, objs_info_dicts) = tl.files.load_voc_dataset(dataset="2012", contain_classes_in_person=False)
>>> idx = 26
>>> print(classes)
['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
>>> print(classes_dict)
{'sheep': 16, 'horse': 12, 'bicycle': 1, 'bottle': 4, 'cow': 9, 'sofa': 17, 'car': 6, 'dog': 11, 'cat': 7, 'person': 14, 'train': 18, 'diningtable': 10, 'aeroplane': 0, 'bus': 5, 'pottedplant': 15, 'tvmonitor': 19, 'chair': 8, 'bird': 2, 'boat': 3, 'motorbike': 13}
>>> print(imgs_file_list[idx])
data/VOC/VOC2012/JPEGImages/2007_000423.jpg
>>> print(n_objs_list[idx])
2
>>> print(imgs_ann_file_list[idx])
data/VOC/VOC2012/Annotations/2007_000423.xml
>>> print(objs_info_list[idx])
14 0.173 0.461333333333 0.142 0.496
14 0.828 0.542666666667 0.188 0.594666666667
>>> ann = tl.prepro.parse_darknet_ann_str_to_list(objs_info_list[idx])
>>> print(ann)
[[14, 0.173, 0.461333333333, 0.142, 0.496], [14, 0.828, 0.542666666667, 0.188, 0.594666666667]]
>>> c, b = tl.prepro.parse_darknet_ann_list_to_cls_box(ann)
>>> print(c, b)
[14, 14] [[0.173, 0.461333333333, 0.142, 0.496], [0.828, 0.542666666667, 0.188, 0.594666666667]]
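
Each line of objs_info_list[idx] describes one object as class_id x_centre y_centre width height, with coordinates given as fractions of the image size. The pure-Python sketch below parses such a string and converts one box to pixel corners. `parse_darknet_lines` and `to_pixel_corners` are hypothetical helpers that illustrate what the annotation means, not the implementation of tl.prepro.parse_darknet_ann_str_to_list, and the 500x375 image size is made up.

```python
def parse_darknet_lines(ann_str):
    """Parse 'class x y w h' lines into [[class_id, x, y, w, h], ...]."""
    objects = []
    for line in ann_str.strip().split("\n"):
        fields = line.split()
        objects.append([int(fields[0])] + [float(v) for v in fields[1:5]])
    return objects

def to_pixel_corners(box, img_w, img_h):
    """Convert a ratio box (x_centre, y_centre, w, h) to pixel (x_min, y_min, x_max, y_max)."""
    x, y, w, h = box
    return ((x - w / 2) * img_w, (y - h / 2) * img_h,
            (x + w / 2) * img_w, (y + h / 2) * img_h)

ann_str = "14 0.173 0.461333333333 0.142 0.496\n14 0.828 0.542666666667 0.188 0.594666666667"
ann = parse_darknet_lines(ann_str)
print(ann[0])  # [14, 0.173, 0.461333333333, 0.142, 0.496]
print([round(v) for v in to_pixel_corners(ann[0][1:], img_w=500, img_h=375)])  # [51, 80, 122, 266]
```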

MPII

tensorlayer.files.load_mpii_pose_dataset(path='data', is_16_pos_only=False)

Load the MPII Human Pose dataset.

Parameters:
  • path (str) -- The path that the data is downloaded to.
  • is_16_pos_only (boolean) -- If True, only return people annotated with all 16 pose keypoints (usually used for single-person pose estimation).
Returns:

  • img_train_list (list of str) -- The image paths of the training data.
  • ann_train_list (list of dict) -- The annotations of the training data.
  • img_test_list (list of str) -- The image paths of the testing data.
  • ann_test_list (list of dict) -- The annotations of the testing data.

Examples

>>> import pprint
>>> import tensorlayer as tl
>>> img_train_list, ann_train_list, img_test_list, ann_test_list = tl.files.load_mpii_pose_dataset()
>>> image = tl.vis.read_image(img_train_list[0])
>>> tl.vis.draw_mpii_pose_to_image(image, ann_train_list[0], 'image.png')
>>> pprint.pprint(ann_train_list[0])

Google Drive

tensorlayer.files.download_file_from_google_drive(ID, destination)

Download a file from Google Drive.

See tl.files.load_celebA_dataset for an example.

Parameters:
  • ID (str) -- The Google Drive file ID.
  • destination (str) -- The destination path for the downloaded file.

Save and load models

Save a model to .npz as a list

tensorlayer.files.save_npz(save_list=None, name='model.npz', sess=None)

Save a list of parameters into a .npz file with the given file name. Use tl.files.load_npz() to restore them.

Parameters:
  • save_list (list of tensor) -- A list of parameters (tensors) to be saved.
  • name (str) -- The name of the .npz file.
  • sess (None or Session) -- A session may be required in some cases.

Examples

Save model to npz

>>> tl.files.save_npz(network.all_params, name='model.npz', sess=sess)

Load model from npz (Method 1)

>>> load_params = tl.files.load_npz(name='model.npz')
>>> tl.files.assign_params(sess, load_params, network)

Load model from npz (Method 2)

>>> tl.files.load_and_assign_npz(sess=sess, name='model.npz', network=network)
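
Because the .npz route stores plain NumPy arrays, a saved model can be inspected or reloaded without TensorFlow. The sketch below shows the idea behind save_npz/load_npz with NumPy alone: an ordered list of arrays is written to one file and read back in the same order. It uses np.savez's positional keys (arr_0, arr_1, ...) as a swapped-in layout, which is not necessarily TensorLayer's on-disk format, and the two helpers are hypothetical.

```python
import os
import tempfile
import numpy as np

def save_params_npz(params, path):
    """Save an ordered list of arrays; np.savez names them arr_0, arr_1, ..."""
    np.savez(path, *params)

def load_params_npz(path):
    """Load the arrays back as a list, in the order they were saved."""
    with np.load(path) as data:
        keys = sorted(data.files, key=lambda k: int(k.split("_")[1]))
        return [data[k] for k in keys]

params = [np.ones((3, 2), dtype=np.float32), np.zeros(2, dtype=np.float32)]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.npz")
    save_params_npz(params, path)
    restored = load_params_npz(path)
print([p.shape for p in restored])  # [(3, 2), (2,)]
```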

Notes

If you run into session issues, change value.eval() to value.eval(session=sess).

References

Saving dictionary using numpy

Load a parameter list saved by save_npz

tensorlayer.files.load_npz(path='', name='model.npz')

Load the parameters of a model saved by tl.files.save_npz().

Parameters:
  • path (str) -- Folder path to the .npz file.
  • name (str) -- The name of the .npz file.
Returns:

A list of parameters, in order.

Return type:

list of array

Examples

  • See tl.files.save_npz

Assign parameters to a model

tensorlayer.files.assign_params(sess, params, network)

Assign the given parameters to the TensorLayer network.

Parameters:
  • sess (Session) -- TensorFlow Session.
  • params (list of array) -- A list of parameters (arrays), in order.
  • network (Layer) -- The network to be assigned.
Returns:

A list of TensorFlow ops that assign the parameters, in order. You can also run them manually with sess.run(ops).

Return type:

list of operations

Examples

  • See tl.files.save_npz

Load parameters from .npz and assign them to a model

tensorlayer.files.load_and_assign_npz(sess=None, name=None, network=None)

Load parameters from a .npz file and assign them to a network.

Parameters:
  • sess (Session) -- TensorFlow Session.
  • name (str) -- The name of the .npz file.
  • network (Layer) -- The network to be assigned.
Returns:

The network, or False if the model file does not exist.

Return type:

False or network

Examples

  • See tl.files.save_npz

Save a model to .npz as a dictionary

tensorlayer.files.save_npz_dict(save_list=None, name='model.npz', sess=None)

Save parameters as a dictionary into a .npz file with the given file name.

Use tl.files.load_and_assign_npz_dict() to restore them.

Parameters:
  • save_list (list of parameters) -- A list of parameters (tensors) to be saved.
  • name (str) -- The name of the .npz file.
  • sess (Session) -- TensorFlow Session.
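
Saving as a dictionary keys each array by a name, so parameters can be matched by name instead of by position when they are restored. A NumPy sketch of that idea; the variable names are made up, the dict `network` stands in for real network variables, and the layout is not claimed to be TensorLayer's exact file format.

```python
import os
import tempfile
import numpy as np

# Hypothetical variable names mapped to their values.
named_params = {"dense1_W": np.ones((4, 2)), "dense1_b": np.zeros(2)}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.npz")
    np.savez(path, **named_params)  # one array per key
    with np.load(path) as data:
        restored = {name: data[name] for name in data.files}

# Restoring by name does not depend on the order the variables were created in.
network = {"dense1_b": np.full(2, 9.0), "dense1_W": np.full((4, 2), 9.0)}
for name, value in restored.items():
    network[name][...] = value
print(network["dense1_W"].sum(), network["dense1_b"].sum())  # 8.0 0.0
```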

Load parameters saved by save_npz_dict

tensorlayer.files.load_and_assign_npz_dict(name='model.npz', sess=None)

Restore the parameters saved by tl.files.save_npz_dict().

Parameters:
  • name (str) -- The name of the .npz file.
  • sess (Session) -- TensorFlow Session.

Save the model architecture

tensorlayer.files.save_graph(network=None, name='graph.pkl')

Save the architecture of a TL model into a pickle file. Parameters are not saved.

Parameters:
  • network (TensorLayer layer) -- The network to save.
  • name (str) -- The name of the graph file.

Examples

Save the architecture

>>> tl.files.save_graph(net_test, 'graph.pkl')

Load the architecture in another script (parameters are not restored)

>>> net = tl.files.load_graph('graph.pkl')

Load the model architecture

tensorlayer.files.load_graph(name='model.pkl')

Restore a TL model architecture from a pickle file. Parameters are not restored.

Parameters:name (str) -- The name of the graph file.
Returns:network -- The input placeholders become attributes of the returned TL layer object.
Return type:TensorLayer layer

Examples

  • see tl.files.save_graph

Save the model architecture and parameters

tensorlayer.files.save_graph_and_params(network=None, name='model', sess=None)

Save the TL model architecture and parameters (i.e. the whole model) into a graph file and an npz file, respectively.

Parameters:
  • network (TensorLayer layer) -- The network to save.
  • name (str) -- The folder name to save the graph and parameters to.
  • sess (Session) -- TensorFlow Session.

Examples

Save architecture and parameters

>>> tl.files.save_graph_and_params(net, 'model', sess)

Load architecture and parameters

>>> net = tl.files.load_graph_and_params('model', sess)

Load the model architecture and parameters

tensorlayer.files.load_graph_and_params(name='model', sess=None)

Load the TL model architecture and parameters from the graph file and npz file, respectively.

Parameters:
  • name (str) -- The folder name to load the graph and parameters from.
  • sess (Session) -- TensorFlow Session.

Save a model to .ckpt as a list

tensorlayer.files.save_ckpt(sess=None, mode_name='model.ckpt', save_dir='checkpoint', var_list=None, global_step=None, printable=False)

Save parameters into a ckpt file.

Parameters:
  • sess (Session) -- TensorFlow Session.
  • mode_name (str) -- The name of the model; default is model.ckpt.
  • save_dir (str) -- The path / file directory of the ckpt; default is checkpoint.
  • var_list (list of tensor) -- The parameters / variables (tensors) to be saved. If empty, all global variables are saved (default).
  • global_step (int or None) -- Step number.
  • printable (boolean) -- Whether to print information about all parameters.

See also

load_ckpt()

Load parameters from .ckpt and assign them to a model

tensorlayer.files.load_ckpt(sess=None, mode_name='model.ckpt', save_dir='checkpoint', var_list=None, is_latest=True, printable=False)

Load parameters from a ckpt file.

Parameters:
  • sess (Session) -- TensorFlow Session.
  • mode_name (str) -- The name of the model; default is model.ckpt.
  • save_dir (str) -- The path / file directory of the ckpt; default is checkpoint.
  • var_list (list of tensor) -- The parameters / variables (tensors) to be restored. If empty, all global variables are restored (default).
  • is_latest (boolean) -- Whether to load the latest ckpt; if False, load the ckpt named mode_name.
  • printable (boolean) -- Whether to print information about all parameters.

Examples

  • Save all global parameters.
>>> tl.files.save_ckpt(sess=sess, mode_name='model.ckpt', save_dir='model', printable=True)
  • Save specific parameters.
>>> tl.files.save_ckpt(sess=sess, mode_name='model.ckpt', var_list=net.all_params, save_dir='model', printable=True)
  • Load latest ckpt.
>>> tl.files.load_ckpt(sess=sess, var_list=net.all_params, save_dir='model', printable=True)
  • Load specific ckpt.
>>> tl.files.load_ckpt(sess=sess, mode_name='model.ckpt', var_list=net.all_params, save_dir='model', is_latest=False, printable=True)

Save and load data

Save data to a .npy file

tensorlayer.files.save_any_to_npy(save_dict=None, name='file.npy')

Save variables to a .npy file.

Parameters:
  • save_dict (dict) -- The variables to be saved.
  • name (str) -- File name.

Examples

>>> tl.files.save_any_to_npy(save_dict={'data': ['a','b']}, name='test.npy')
>>> data = tl.files.load_npy_to_any(name='test.npy')
>>> print(data)
{'data': ['a', 'b']}
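
A .npy file normally holds a single array. One way to store a dict in it, sketched here with NumPy, is to let np.save wrap the dict in a 0-d object array, which is pickled on disk; loading then needs allow_pickle=True and .item() to unwrap the dict. Whether save_any_to_npy does exactly this is an assumption, and pickled files should only be loaded from sources you trust.

```python
import os
import tempfile
import numpy as np

save_dict = {"data": ["a", "b"]}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "test.npy")
    np.save(path, save_dict)  # stored as a pickled 0-d object array
    # allow_pickle=True is required for object arrays; .item() unwraps the dict.
    data = np.load(path, allow_pickle=True).item()
print(data)  # {'data': ['a', 'b']}
```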

Load data from a .npy file

tensorlayer.files.load_npy_to_any(path='', name='file.npy')

Load a .npy file.

Parameters:
  • path (str) -- Path to the file (optional).
  • name (str) -- File name.

Examples

  • see tl.files.save_any_to_npy()

Folder and file functions

Check whether a file exists

tensorlayer.files.file_exists(filepath)

Check whether a file exists at the given file path.

Check whether a folder exists

tensorlayer.files.folder_exists(folderpath)

Check whether a folder exists at the given folder path.

Delete a file

tensorlayer.files.del_file(filepath)

Delete the file at the given file path.

Delete a folder

tensorlayer.files.del_folder(folderpath)

Delete the folder at the given folder path.

Read a file

tensorlayer.files.read_file(filepath)

Read a file and return its content as a string.

Examples

>>> data = tl.files.read_file('data.txt')

Get the list of files in a folder

tensorlayer.files.load_file_list(path=None, regx='\\.jpg', printable=True, keep_prefix=False)

Return the list of files in a folder whose names match a regular expression.

Parameters:
  • path (str or None) -- A folder path; if None, use the current directory.
  • regx (str) -- The regular expression for file names.
  • printable (boolean) -- Whether to print information about the files.
  • keep_prefix (boolean) -- Whether to keep the folder path in the file names.

Examples

>>> file_list = tl.files.load_file_list(path=None, regx=r'w1pre_[0-9]+\.(npz)')
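
The filtering that load_file_list performs can be sketched with os and re. This version assumes re.search semantics (the pattern may match anywhere in the name) and sorts the result; `list_matching_files` is a hypothetical helper. Note that a plain sort puts 'w1pre_10' before 'w1pre_2'; tl.files.natural_keys addresses that.

```python
import os
import re
import tempfile

def list_matching_files(path, regx, keep_prefix=False):
    """Return sorted names of files in `path` whose names match `regx` (via re.search)."""
    names = [n for n in os.listdir(path)
             if os.path.isfile(os.path.join(path, n)) and re.search(regx, n)]
    names.sort()
    return [os.path.join(path, n) for n in names] if keep_prefix else names

with tempfile.TemporaryDirectory() as tmp:
    for name in ["w1pre_10.npz", "w1pre_2.npz", "notes.txt"]:
        open(os.path.join(tmp, name), "w").close()  # create empty files
    found = list_matching_files(tmp, r"w1pre_[0-9]+\.(npz)")
print(found)  # ['w1pre_10.npz', 'w1pre_2.npz']
```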

Get the list of subfolders in a folder

tensorlayer.files.load_folder_list(path='')

Return the list of subfolders in the given folder.

Parameters:path (str) -- A folder path.

Check for or create a folder

tensorlayer.files.exists_or_mkdir(path, verbose=True)

Check whether a folder exists; if it does not, create it and return False; if it does, return True.

Parameters:
  • path (str) -- A folder path.
  • verbose (boolean) -- If True (default), print the result.
Returns:

True if the folder already exists; otherwise False (and the folder is created).

Return type:

boolean

Examples

>>> tl.files.exists_or_mkdir("checkpoints/train")
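
The return-value contract above (True if the folder was already there, False if it had to be created) can be sketched with os alone. `ensure_dir` is a hypothetical re-implementation for illustration, and its messages are not TensorLayer's actual output.

```python
import os
import tempfile

def ensure_dir(path, verbose=True):
    """Return True if `path` exists; otherwise create it (with parents) and return False."""
    if os.path.exists(path):
        if verbose:
            print(f"{path} already exists")
        return True
    if verbose:
        print(f"creating {path}")
    os.makedirs(path)
    return False

with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "checkpoints", "train")
    first = ensure_dir(target, verbose=False)   # False: created now
    second = ensure_dir(target, verbose=False)  # True: already there
print(first, second)  # False True
```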

Download or extract

tensorlayer.files.maybe_download_and_extract(filename, working_directory, url_source, extract=False, expected_bytes=None)

Check whether a file exists in working_directory; if not, try to download it, and optionally try to extract it if it is a ".zip" or ".tar" archive.

Parameters:
  • filename (str) -- The name of the file to be downloaded.
  • working_directory (str) -- A folder path in which to look for the file and to which it is downloaded.
  • url_source (str) -- The base URL to download the file from (see the examples).
  • extract (boolean) -- If True, try to uncompress the downloaded file if it is a ".tar.gz"/".tar.bz2" or ".zip" file; default is False.
  • expected_bytes (int or None) -- If set, verify that the downloaded file has the specified size and raise an Exception otherwise; the default None performs no check.
Returns:

The file path of the downloaded (uncompressed) file.

Return type:

str

Examples

>>> down_file = tl.files.maybe_download_and_extract(filename='train-images-idx3-ubyte.gz',
...                                            working_directory='data/',
...                                            url_source='http://yann.lecun.com/exdb/mnist/')
>>> tl.files.maybe_download_and_extract(filename='ADEChallengeData2016.zip',
...                                             working_directory='data/',
...                                             url_source='http://sceneparsing.csail.mit.edu/data/',
...                                             extract=True)
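
The check-download-verify-extract sequence can be sketched with the standard library. This is not TensorLayer's implementation: it assumes the download URL is url_source + filename (as the examples suggest), returns the archive path rather than the extracted path, and the demo runs offline because the archive already exists locally (the URL is a placeholder and is never fetched).

```python
import os
import tarfile
import tempfile
import urllib.request
import zipfile

def maybe_download_and_extract_sketch(filename, working_directory, url_source,
                                      extract=False, expected_bytes=None):
    """Fetch url_source + filename into working_directory unless it is already there."""
    os.makedirs(working_directory, exist_ok=True)
    filepath = os.path.join(working_directory, filename)
    if not os.path.exists(filepath):
        urllib.request.urlretrieve(url_source + filename, filepath)
    size = os.path.getsize(filepath)
    if expected_bytes is not None and size != expected_bytes:
        raise Exception(f"{filename} is {size} bytes, expected {expected_bytes}")
    if extract:
        # Check for zip first; only extract archives from sources you trust.
        if zipfile.is_zipfile(filepath):
            with zipfile.ZipFile(filepath) as archive:
                archive.extractall(working_directory)
        elif tarfile.is_tarfile(filepath):  # covers .tar, .tar.gz and .tar.bz2
            with tarfile.open(filepath) as archive:
                archive.extractall(working_directory)
    return filepath

# Offline demo: the archive already exists, so nothing is downloaded.
with tempfile.TemporaryDirectory() as tmp:
    local_zip = os.path.join(tmp, "sample.zip")
    with zipfile.ZipFile(local_zip, "w") as zf:
        zf.writestr("hello.txt", "hi")
    result = maybe_download_and_extract_sketch(
        "sample.zip", tmp, "https://example.com/data/",
        extract=True, expected_bytes=os.path.getsize(local_zip))
    with open(os.path.join(tmp, "hello.txt")) as f:
        extracted_text = f.read()
print(os.path.basename(result), extracted_text)  # sample.zip hi
```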

Sorting

Sort strings containing numbers

tensorlayer.files.natural_keys(text)

Key function for sorting a list of strings containing numbers in human (natural) order.

Examples

>>> l = ['im1.jpg', 'im31.jpg', 'im11.jpg', 'im21.jpg', 'im03.jpg', 'im05.jpg']
>>> l.sort(key=tl.files.natural_keys)
>>> l
['im1.jpg', 'im03.jpg', 'im05.jpg', 'im11.jpg', 'im21.jpg', 'im31.jpg']
>>> l.sort()  # plain lexicographic sort, which is not what we want
>>> l
['im03.jpg', 'im05.jpg', 'im1.jpg', 'im11.jpg', 'im21.jpg', 'im31.jpg']

Visualize .npz files

tensorlayer.files.npz_to_W_pdf(path=None, regx='w1pre_[0-9]+\\.(npz)')

Convert the first weight matrix of a .npz file to .pdf using tl.visualize.W().

Parameters:
  • path (str) -- A folder path to the npz files.
  • regx (str) -- The regular expression for file names.

Examples

Convert the first weight matrix of w1_pre...npz file to w1_pre...pdf.

>>> tl.files.npz_to_W_pdf(path='/Users/.../npz_file/', regx=r'w1pre_[0-9]+\.(npz)')