API - Distributed

Helper sessions and methods for distributed training. Please check the mnist example.

TaskSpecDef

Specification for a distributed task.

TaskSpec()

Returns a TaskSpecDef based on the environment variables set for distributed training.

DistributedSession([task_spec, ...])

Creates a distributed session.

StopAtTimeHook

Hook that requests stop after a specified time.

LoadCheckpoint

Hook that loads a checkpoint after the session is created.

Distributed training

TaskSpecDef

tensorlayer.distributed.TaskSpecDef(task_type='master', index=0, trial=None, ps_hosts=None, worker_hosts=None, master=None)[source]

Specification for a distributed task.

Warning

THIS FUNCTION IS DEPRECATED: It will be removed after 2018-10-30. Instructions for updating: use the TensorLayer distributed trainer instead.

It contains the job name, the index of the task, the parameter servers and the worker servers. If you want to use the last worker for continuous evaluation, call the method use_last_worker_as_evaluator, which returns a new TaskSpecDef object without the last worker in the cluster specification (a minimal sketch follows the notes below).

Parameters
  • task_type (str) -- Task type. One of master, worker or ps.

  • index (int) -- The zero-based index of the task. Distributed training jobs will have a single master task, one or more parameter servers, and one or more workers.

  • trial (int) -- The identifier of the trial being run.

  • ps_hosts (str OR list of str) -- A comma-separated string of hosts for the parameter servers, or a list of hosts.

  • worker_hosts (str OR list of str) -- A comma-separated string of hosts for the worker servers, or a list of hosts.

  • master (str) -- A string with the master host.

Notes

master might not be included in TF_CONFIG and can be None. The shard_index is adjusted in any case so that 0 is assigned to the master and >= 1 to the workers. This implementation does not support sparse arrays in the TF_CONFIG variable as shown in the official TensorFlow documentation, since they are not supported by the JSON definition.
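
A minimal sketch of building a TaskSpecDef by hand and reserving the last worker for continuous evaluation; the hostnames are placeholders:

from tensorlayer.distributed import TaskSpecDef

# Describe the role of this process in the cluster (hostnames are placeholders).
task_spec = TaskSpecDef(task_type='worker',
                        index=0,
                        ps_hosts='ps0.example.com:2222',
                        worker_hosts='worker0.example.com:2222,worker1.example.com:2222',
                        master='master.example.com:2222')

# Reserve the last worker for continuous evaluation instead of training.
task_spec = task_spec.use_last_worker_as_evaluator()
if task_spec.is_evaluator():
    pass  # run the evaluation loop here instead of training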


Create TaskSpecDef from environment variables

tensorlayer.distributed.TaskSpec()

Returns a TaskSpecDef based on the environment variables set for distributed training.

Warning

THIS FUNCTION IS DEPRECATED: It will be removed after 2018-10-30. Instructions for updating: use the TensorLayer distributed trainer instead.
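
TaskSpec builds the TaskSpecDef from the TF_CONFIG environment variable. A hedged sketch of how this is typically used; the TF_CONFIG layout follows the standard TensorFlow cluster specification and the hostnames are placeholders:

import json
import os

from tensorlayer.distributed import TaskSpec

# Normally TF_CONFIG is set by the cluster manager; it is set by hand here only for illustration.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'ps': ['ps0.example.com:2222'],
        'worker': ['worker0.example.com:2222', 'worker1.example.com:2222'],
    },
    'task': {'type': 'worker', 'index': 0},
})

task_spec = TaskSpec()  # may be None when no distributed configuration is found
if task_spec is not None:
    print(task_spec.num_workers, task_spec.shard_index)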


Distributed Session object

tensorlayer.distributed.DistributedSession(task_spec=None, checkpoint_dir=None, scaffold=None, hooks=None, chief_only_hooks=None, save_checkpoint_secs=600, save_summaries_steps=<object object>, save_summaries_secs=<object object>, config=None, stop_grace_period_secs=120, log_step_count_steps=100)

Creates a distributed session.

Warning

THIS FUNCTION IS DEPRECATED: It will be removed after 2018-10-30. Instructions for updating: use the TensorLayer distributed trainer instead.

It calls MonitoredTrainingSession to create a MonitoredSession for distributed training.

Parameters
  • task_spec (TaskSpecDef.) -- The task spec definition from create_task_spec_def()

  • checkpoint_dir (str) -- Optional path to a directory from which to restore variables.

  • scaffold (Scaffold) -- A Scaffold used for gathering or building supportive ops. If not specified, a default one is created. It's used to finalize the graph.

  • hooks (list of SessionRunHook objects) -- Optional list of hooks to attach to the session.

  • chief_only_hooks (list of SessionRunHook objects.) -- Activate these hooks if is_chief==True, ignore otherwise.

  • save_checkpoint_secs (int) -- The frequency, in seconds, that a checkpoint is saved using a default checkpoint saver. If save_checkpoint_secs is set to None, then the default checkpoint saver isn't used.

  • save_summaries_steps (int) -- The frequency, in number of global steps, that the summaries are written to disk using a default summary saver. If both save_summaries_steps and save_summaries_secs are set to None, then the default summary saver isn't used. Default 100.

  • save_summaries_secs (int) -- The frequency, in seconds, that the summaries are written to disk using a default summary saver. If both save_summaries_steps and save_summaries_secs are set to None, then the default summary saver isn't used. Not enabled by default.

  • config (tf.ConfigProto) -- An instance of tf.ConfigProto used to configure the session. It's the config argument of the constructor of tf.Session.

  • stop_grace_period_secs (int) -- Number of seconds given to threads to stop after close() has been called.

  • log_step_count_steps (int) -- The frequency, in number of global steps, that the global step/sec is logged.

Examples

A simple example for distributed training where all the workers use the same dataset:

>>> task_spec = TaskSpec()
>>> with tf.device(task_spec.device_fn()):
>>>      tensors = create_graph()
>>> with tl.DistributedSession(task_spec=task_spec,
...                            checkpoint_dir='/tmp/ckpt') as session:
>>>      while not session.should_stop():
>>>           session.run(tensors)

An example where the dataset is shared among the workers (see https://www.tensorflow.org/programmers_guide/datasets):

>>> task_spec = TaskSpec()
>>> # dataset is a :class:`tf.data.Dataset` with the raw data
>>> dataset = create_dataset()
>>> if task_spec is not None:
>>>     dataset = dataset.shard(task_spec.num_workers, task_spec.shard_index)
>>> # shuffle or apply a map function to the new sharded dataset, for example:
>>> dataset = dataset.shuffle(buffer_size=10000)
>>> dataset = dataset.batch(batch_size)
>>> dataset = dataset.repeat(num_epochs)
>>> # create the iterator for the dataset and the input tensor
>>> iterator = dataset.make_one_shot_iterator()
>>> next_element = iterator.get_next()
>>> with tf.device(task_spec.device_fn()):
>>>      # next_element is the input for the graph
>>>      tensors = create_graph(next_element)
>>> with tl.DistributedSession(task_spec=task_spec,
...                            checkpoint_dir='/tmp/ckpt') as session:
>>>      while not session.should_stop():
>>>           session.run(tensors)


Data sharding

We want to split the data into shards, one per training server, instead of copying the whole dataset to every server. TensorFlow >= 1.4 provides helper classes to support data sharding: Datasets

Note that shuffling the data is very important when sharding; these operations are performed when the shards are set up:

from tensorflow.contrib.data import TextLineDataset
from tensorflow.contrib.data import Dataset

task_spec = TaskSpec()
# Build a text-line dataset from the files matching the pattern.
files_dataset = Dataset.list_files(files_pattern)
dataset = TextLineDataset(files_dataset)
dataset = dataset.map(your_python_map_function, num_threads=4)
if task_spec is not None:
    # Each worker reads only its own shard of the data.
    dataset = dataset.shard(task_spec.num_workers, task_spec.shard_index)
dataset = dataset.shuffle(buffer_size)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.device(task_spec.device_fn()):
    tensors = create_graph(next_element)
with tl.DistributedSession(task_spec=task_spec,
                           checkpoint_dir='/tmp/ckpt') as session:
    while not session.should_stop():
        session.run(tensors)

Logging

We can use task_spec to log only from the master server:

while not session.should_stop():
    should_log = task_spec.is_master() and your_conditions
    if should_log:
        results = session.run(tensors_with_log_info)
        logging.info(...)
    else:
        results = session.run(tensors)

Continuous evaluation

We can use one of the workers to continuously evaluate the checkpoints as they are saved:

import time

import tensorflow as tf
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.monitored_session import SingularMonitoredSession

class Evaluator(session_run_hook.SessionRunHook):
    def __init__(self, checkpoints_path, output_path, max_time_secs=None):
        self.checkpoints_path = checkpoints_path
        self.summary_writer = tf.summary.FileWriter(output_path)
        self.latest_checkpoint = ''
        self.max_time_secs = max_time_secs

    def after_create_session(self, session, coord):
        # Wait until a new checkpoint is available, then restore it.
        checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
        while self.latest_checkpoint == checkpoint:
            time.sleep(30)
            checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
        self.saver.restore(session, checkpoint)
        self.latest_checkpoint = checkpoint

    def end(self, session):
        super(Evaluator, self).end(session)
        # Save the summaries tagged with the checkpoint's global step.
        step = int(self.latest_checkpoint.split('-')[-1])
        self.summary_writer.add_summary(self.summary, step)

    def _create_graph(self):
        # your code to create the graph with the dataset
        pass

    def run_evaluation(self):
        with tf.Graph().as_default():
            summary_tensors = self._create_graph()
            self.saver = tf.train.Saver(var_list=tf_variables.trainable_variables())
            hooks = [self]  # add any extra evaluation hooks here
            if self.max_time_secs and self.max_time_secs > 0:
                hooks.append(StopAtTimeHook(self.max_time_secs))
            # this evaluation runs indefinitely, until the process is killed
            while True:
                with SingularMonitoredSession(hooks=hooks) as session:
                    try:
                        while not session.should_stop():
                            self.summary = session.run(summary_tensors)
                    except tf.errors.OutOfRangeError:
                        pass
                    # end of evaluation

task_spec = TaskSpec().use_last_worker_as_evaluator()
if task_spec.is_evaluator():
    Evaluator(checkpoints_path, output_path).run_evaluation()
else:
    pass  # run normal training

Session Hooks

TensorFlow provides some Session Hooks to operate on sessions. Here we add more helpers to cover common operations.

Stop after maximum time

tensorlayer.distributed.StopAtTimeHook(*args, **kwargs)[source]

Hook that requests stop after a specified time.

Warning

THIS FUNCTION IS DEPRECATED: It will be removed after 2018-10-30. Instructions for updating: use the TensorLayer distributed trainer instead.

Parameters

time_running (int) -- Maximum running time in seconds.
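
A minimal usage sketch, reusing the graph-building placeholders from the earlier examples and attaching the hook to a DistributedSession:

from tensorlayer.distributed import StopAtTimeHook

task_spec = TaskSpec()
with tf.device(task_spec.device_fn()):
    tensors = create_graph()

# Request a stop after two hours of training.
stop_hook = StopAtTimeHook(2 * 60 * 60)
with tl.DistributedSession(task_spec=task_spec,
                           checkpoint_dir='/tmp/ckpt',
                           hooks=[stop_hook]) as session:
    while not session.should_stop():
        session.run(tensors)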

Initialize network with checkpoint

tensorlayer.distributed.LoadCheckpoint(*args, **kwargs)[source]

Hook that loads a checkpoint after the session is created.

Warning

THIS FUNCTION IS DEPRECATED: It will be removed after 2018-10-30. Instructions for updating: use the TensorLayer distributed trainer instead.

>>> from tensorflow.python.ops import variables as tf_variables
>>> from tensorflow.python.training.monitored_session import SingularMonitoredSession
>>>
>>> tensors = create_graph()
>>> saver = tf.train.Saver(var_list=tf_variables.trainable_variables())
>>> checkpoint_hook = LoadCheckpoint(saver, my_checkpoint_file)
>>> with SingularMonitoredSession(hooks=[checkpoint_hook]) as session:
>>>      while not session.should_stop():
>>>           session.run(tensors)