diff --git a/train_script.py b/train_script.py
new file mode 100644
index 000000000..529daea48
--- /dev/null
+++ b/train_script.py
@@ -0,0 +1,179 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""Training executable for detection models.
+This executable is used to train DetectionModels. There are two ways of
+configuring the training job:
+1) A single pipeline_pb2.TrainEvalPipelineConfig configuration file
+can be specified by --pipeline_config_path.
+Example usage:
+    ./train \
+        --logtostderr \
+        --train_dir=path/to/train_dir \
+        --pipeline_config_path=pipeline_config.pbtxt
+2) Three configuration files can be provided: a model_pb2.DetectionModel
+configuration file to define what type of DetectionModel is being trained, an
+input_reader_pb2.InputReader file to specify what training data will be used and
+a train_pb2.TrainConfig file to configure training parameters.
+Example usage:
+    ./train \
+        --logtostderr \
+        --train_dir=path/to/train_dir \
+        --model_config_path=model_config.pbtxt \
+        --train_config_path=train_config.pbtxt \
+        --input_config_path=train_input_config.pbtxt
+"""
+
+import functools
+import json
+import os
+import tensorflow as tf
+
+from object_detection.builders import dataset_builder
+from object_detection.builders import graph_rewriter_builder
+from object_detection.builders import model_builder
+from object_detection.legacy import trainer
+from object_detection.utils import config_util
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+flags = tf.app.flags
+flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
+flags.DEFINE_integer('task', 0, 'task id')
+flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy per worker.')
+flags.DEFINE_boolean('clone_on_cpu', False,
+                     'Force clones to be deployed on CPU. Note that even if '
+                     'set to False (allowing ops to run on gpu), some ops may '
+                     'still be run on the CPU if they have no GPU kernel.')
+flags.DEFINE_integer('worker_replicas', 1, 'Number of worker+trainer '
+                     'replicas.')
+flags.DEFINE_integer('ps_tasks', 0,
+                     'Number of parameter server tasks. If None, does not use '
+                     'a parameter server.')
+flags.DEFINE_string('train_dir', '',
+                    'Directory to save the checkpoints and training summaries.')
+
+flags.DEFINE_string('pipeline_config_path', '',
+                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
+                    'file. If provided, other configs are ignored.')
+
+flags.DEFINE_string('train_config_path', '',
+                    'Path to a train_pb2.TrainConfig config file.')
+flags.DEFINE_string('input_config_path', '',
+                    'Path to an input_reader_pb2.InputReader config file.')
+flags.DEFINE_string('model_config_path', '',
+                    'Path to a model_pb2.DetectionModel config file.')
+
+FLAGS = flags.FLAGS
+
+
+@tf.contrib.framework.deprecated(None, 'Use object_detection/model_main.py.')
+def main(_):
+  assert FLAGS.train_dir, '`train_dir` is missing.'
+  if FLAGS.task == 0: tf.gfile.MakeDirs(FLAGS.train_dir)
+  if FLAGS.pipeline_config_path:
+    configs = config_util.get_configs_from_pipeline_file(
+        FLAGS.pipeline_config_path)
+    if FLAGS.task == 0:
+      tf.gfile.Copy(FLAGS.pipeline_config_path,
+                    os.path.join(FLAGS.train_dir, 'pipeline.config'),
+                    overwrite=True)
+  else:
+    configs = config_util.get_configs_from_multiple_files(
+        model_config_path=FLAGS.model_config_path,
+        train_config_path=FLAGS.train_config_path,
+        train_input_config_path=FLAGS.input_config_path)
+    if FLAGS.task == 0:
+      for name, config in [('model.config', FLAGS.model_config_path),
+                           ('train.config', FLAGS.train_config_path),
+                           ('input.config', FLAGS.input_config_path)]:
+        tf.gfile.Copy(config, os.path.join(FLAGS.train_dir, name),
+                      overwrite=True)
+
+  model_config = configs['model']
+  train_config = configs['train_config']
+  input_config = configs['train_input_config']
+
+  model_fn = functools.partial(
+      model_builder.build,
+      model_config=model_config,
+      is_training=True)
+
+  def get_next(config):
+    return dataset_builder.make_initializable_iterator(
+        dataset_builder.build(config)).get_next()
+
+  create_input_dict_fn = functools.partial(get_next, input_config)
+
+  env = json.loads(os.environ.get('TF_CONFIG', '{}'))
+  cluster_data = env.get('cluster', None)
+  cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
+  task_data = env.get('task', None) or {'type': 'master', 'index': 0}
+  task_info = type('TaskSpec', (object,), task_data)
+
+  # Parameters for a single worker.
+  ps_tasks = 0
+  worker_replicas = 1
+  worker_job_name = 'lonely_worker'
+  task = 0
+  is_chief = True
+  master = ''
+
+  if cluster_data and 'worker' in cluster_data:
+    # Number of total worker replicas includes "worker"s and the "master".
+    worker_replicas = len(cluster_data['worker']) + 1
+  if cluster_data and 'ps' in cluster_data:
+    ps_tasks = len(cluster_data['ps'])
+
+  if worker_replicas > 1 and ps_tasks < 1:
+    raise ValueError('At least 1 ps task is needed for distributed training.')
+
+  if worker_replicas >= 1 and ps_tasks > 0:
+    # Set up distributed training.
+    server = tf.train.Server(tf.train.ClusterSpec(cluster), protocol='grpc',
+                             job_name=task_info.type,
+                             task_index=task_info.index)
+    if task_info.type == 'ps':
+      server.join()
+      return
+
+    worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
+    task = task_info.index
+    is_chief = (task_info.type == 'master')
+    master = server.target
+
+  graph_rewriter_fn = None
+  if 'graph_rewriter_config' in configs:
+    graph_rewriter_fn = graph_rewriter_builder.build(
+        configs['graph_rewriter_config'], is_training=True)
+
+  trainer.train(
+      create_input_dict_fn,
+      model_fn,
+      train_config,
+      master,
+      task,
+      FLAGS.num_clones,
+      worker_replicas,
+      FLAGS.clone_on_cpu,
+      ps_tasks,
+      worker_job_name,
+      is_chief,
+      FLAGS.train_dir,
+      graph_hook_fn=graph_rewriter_fn)
+
+
+if __name__ == '__main__':
+  tf.app.run()
\ No newline at end of file
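
Note on distributed runs: main() discovers the cluster through the standard
TF_CONFIG environment variable (the env['cluster'] and env['task'] lookups
above), so a multi-machine job only needs TF_CONFIG set before each process is
launched. A minimal sketch of the expected shape, with hypothetical host
addresses that are not part of this patch:

    # Hypothetical 1-master / 2-worker / 1-ps cluster. With two "worker"
    # entries plus the "master", worker_replicas is 3, so at least one 'ps'
    # task is required or main() raises ValueError.
    import json
    import os

    os.environ['TF_CONFIG'] = json.dumps({
        'cluster': {
            'master': ['10.0.0.1:2222'],
            'worker': ['10.0.0.2:2222', '10.0.0.3:2222'],
            'ps': ['10.0.0.4:2222'],
        },
        # Each process sets its own role; 'ps' processes block in server.join().
        'task': {'type': 'worker', 'index': 0},
    })

With TF_CONFIG unset, the script falls back to the single-worker defaults
(worker_replicas = 1, ps_tasks = 0) and trains locally.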