diff --git a/run_inference_tpu.py b/run_inference_tpu.py new file mode 100644 index 0000000..960251c --- /dev/null +++ b/run_inference_tpu.py @@ -0,0 +1,67 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Runs FFN inference within a dense bounding box. + +Inference is performed within a single process. +""" + +import os +import time + +from google.protobuf import text_format +from absl import app +from absl import flags +from tensorflow import gfile + +from ffn.utils import bounding_box_pb2 +from ffn.inference import inference +from ffn.inference import inference_flags + +import tensorflow as tf +from tensorflow.contrib import tpu +from tensorflow.contrib.cluster_resolver import TPUClusterResolver + +FLAGS = flags.FLAGS + +flags.DEFINE_string('bounding_box', None, + 'BoundingBox proto in text format defining the area ' + 'to segmented.') + + +def main(unused_argv): + tpu_grpc_url = TPUClusterResolver(tpu=[os.environ['TPU_NAME']]).get_master() + + request = inference_flags.request_from_flags() + + if not gfile.Exists(request.segmentation_output_dir): + gfile.MakeDirs(request.segmentation_output_dir) + + bbox = bounding_box_pb2.BoundingBox() + text_format.Parse(FLAGS.bounding_box, bbox) + + runner = inference.Runner() + + with tf.Session(tpu_grpc_url, config=tf.ConfigProto(log_device_placement=True)) as sess: + runner.start(request, session=sess) + runner.run((bbox.start.z, bbox.start.y, bbox.start.x), + (bbox.size.z, bbox.size.y, bbox.size.x)) + + counter_path = os.path.join(request.segmentation_output_dir, 'counters.txt') + if not gfile.Exists(counter_path): + runner.counters.dump(counter_path) + + +if __name__ == '__main__': + app.run(main)