Skip to content

uploader: request ServerInfo from frontend #2879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Nov 15, 2019
Merged
1 change: 1 addition & 0 deletions tensorboard/uploader/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ py_library(
":auth",
":dev_creds",
":exporter_lib",
":server_info",
":uploader_lib",
"//tensorboard:expect_absl_app_installed",
"//tensorboard:expect_absl_flags_argparse_flags_installed",
Expand Down
14 changes: 14 additions & 0 deletions tensorboard/uploader/server_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,20 @@ def create_server_info(frontend_origin, api_endpoint):
return result


def experiment_url(server_info, experiment_id):
    """Builds the URL at which the given experiment can be viewed.

    The server describes its URL scheme as a template string plus a
    placeholder token; every occurrence of the placeholder within the
    template is substituted with the experiment ID.

    Args:
      server_info: A `server_info_pb2.ServerInfoResponse` message.
      experiment_id: A string; the ID of the experiment to link to.

    Returns:
      A URL resolving to the given experiment, as a string.
    """
    fmt = server_info.url_format
    template = fmt.template
    placeholder = fmt.id_placeholder
    return template.replace(placeholder, experiment_id)


class CommunicationError(RuntimeError):
"""Raised upon failure to communicate with the server."""

Expand Down
11 changes: 11 additions & 0 deletions tensorboard/uploader/server_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,17 @@ def test(self):
self.assertEqual(actual_url, expected_url)


class ExperimentUrlTest(tb_test.TestCase):
    """Checks the URL formatting performed by `experiment_url`."""

    def test(self):
        # The placeholder token inside the template should be swapped for
        # the experiment ID.
        response = server_info_pb2.ServerInfoResponse()
        url_format = response.url_format
        url_format.template = "https://unittest.tensorboard.dev/x/???"
        url_format.id_placeholder = "???"
        expected = "https://unittest.tensorboard.dev/x/123"
        self.assertEqual(server_info.experiment_url(response, "123"), expected)


def _localhost():
"""Gets family and nodename for a loopback address."""
s = socket
Expand Down
4 changes: 2 additions & 2 deletions tensorboard/uploader/uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ def __init__(self, writer_client, logdir, rate_limiter=None):
self._logdir, directory_loader_factory)

def create_experiment(self):
    """Creates an Experiment for this upload session and returns the ID.

    Issues a `CreateExperimentRequest` through the writer API client (with
    retries) and sets up the request builder for the new experiment.

    Returns:
      The `experiment_id` field of the server's response. Callers that
      need a viewable URL should format one from this ID rather than
      relying on the response's `url` field.
    """
    logger.info("Creating experiment")
    request = write_service_pb2.CreateExperimentRequest()
    response = grpc_util.call_with_retries(self._api.CreateExperiment, request)
    # The request builder is keyed on the experiment ID, so it can only be
    # constructed once the server has assigned one.
    self._request_builder = _RequestBuilder(response.experiment_id)
    return response.experiment_id

def start_uploading(self):
"""Blocks forever to continuously upload data from the logdir.
Expand Down
84 changes: 70 additions & 14 deletions tensorboard/uploader/uploader_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
from tensorboard.uploader.proto import write_service_pb2_grpc
from tensorboard.uploader import auth
from tensorboard.uploader import exporter as exporter_lib
from tensorboard.uploader import server_info as server_info_lib
from tensorboard.uploader import uploader as uploader_lib
from tensorboard.uploader.proto import server_info_pb2
from tensorboard import program
from tensorboard.plugins import base_plugin

Expand Down Expand Up @@ -65,6 +67,11 @@
_AUTH_SUBCOMMAND_FLAG = '_uploader__subcommand_auth'
_AUTH_SUBCOMMAND_KEY_REVOKE = 'REVOKE'

_DEFAULT_ORIGIN = "https://tensorboard.dev"
# Compatibility measure until server-side /api/uploader support is
# rolled out and stable.
_HARDCODED_API_ENDPOINT = "api.tensorboard.dev:443"


def _prompt_for_user_ack(intent):
"""Prompts for user consent, exiting the program if they decline."""
Expand All @@ -91,10 +98,19 @@ def _define_flags(parser):
subparsers = parser.add_subparsers()

parser.add_argument(
'--endpoint',
'--origin',
type=str,
default='api.tensorboard.dev:443',
help='URL for the API server accepting write requests.')
default='',
help='Experimental. Origin for TensorBoard.dev service to which '
'to connect. If not set, defaults to %r.' % _DEFAULT_ORIGIN)

parser.add_argument(
'--api_endpoint',
type=str,
default='',
help='Experimental. Direct URL for the API server accepting '
'write requests. If set, will skip initial server handshake '
'unless `--origin` is also set.')

parser.add_argument(
'--grpc_creds_type',
Expand Down Expand Up @@ -222,15 +238,26 @@ def _run(flags):
msg = 'Invalid --grpc_creds_type %s' % flags.grpc_creds_type
raise base_plugin.FlagsError(msg)

try:
server_info = _get_server_info(flags)
except server_info_lib.CommunicationError as e:
_die(str(e))
_handle_server_info(server_info)

if not server_info.api_server.endpoint:
logging.error('Server info response: %s', server_info)
_die('Internal error: frontend did not specify an API server')
composite_channel_creds = grpc.composite_channel_credentials(
channel_creds, auth.id_token_call_credentials(credentials))

# TODO(@nfelt): In the `_UploadIntent` case, consider waiting until
# logdir exists to open channel.
channel = grpc.secure_channel(
flags.endpoint, composite_channel_creds, options=channel_options)
server_info.api_server.endpoint,
composite_channel_creds,
options=channel_options)
with channel:
intent.execute(channel)
intent.execute(server_info, channel)


@six.add_metaclass(abc.ABCMeta)
Expand All @@ -254,10 +281,11 @@ def get_ack_message_body(self):
pass

@abc.abstractmethod
def execute(self, channel):
def execute(self, server_info, channel):
"""Carries out this intent with the specified gRPC channel.

Args:
server_info: A `server_info_pb2.ServerInfoResponse` value.
channel: A connected gRPC channel whose server provides the TensorBoard
reader and writer services.
"""
Expand All @@ -271,7 +299,7 @@ def get_ack_message_body(self):
"""Must not be called."""
raise AssertionError('No user ack needed to revoke credentials')

def execute(self, channel):
def execute(self, server_info, channel):
"""Execute handled specially by `main`. Must not be called."""
raise AssertionError('_AuthRevokeIntent should not be directly executed')

Expand All @@ -296,7 +324,7 @@ def __init__(self, experiment_id):
def get_ack_message_body(self):
return self._MESSAGE_TEMPLATE.format(experiment_id=self.experiment_id)

def execute(self, channel):
def execute(self, server_info, channel):
api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub(channel)
experiment_id = self.experiment_id
if not experiment_id:
Expand Down Expand Up @@ -329,14 +357,13 @@ class _ListIntent(_Intent):
def get_ack_message_body(self):
return self._MESSAGE

def execute(self, channel):
def execute(self, server_info, channel):
api_client = export_service_pb2_grpc.TensorBoardExporterServiceStub(channel)
gen = exporter_lib.list_experiments(api_client)
count = 0
for experiment_id in gen:
count += 1
# TODO(@wchargin): Once #2879 is in, remove this hard-coded URL pattern.
url = 'https://tensorboard.dev/experiment/%s/' % experiment_id
url = server_info_lib.experiment_url(server_info, experiment_id)
print(url)
sys.stdout.flush()
if not count:
Expand Down Expand Up @@ -366,10 +393,11 @@ def __init__(self, logdir):
def get_ack_message_body(self):
return self._MESSAGE_TEMPLATE.format(logdir=self.logdir)

def execute(self, channel):
def execute(self, server_info, channel):
api_client = write_service_pb2_grpc.TensorBoardWriterServiceStub(channel)
uploader = uploader_lib.TensorBoardUploader(api_client, self.logdir)
url = uploader.create_experiment()
experiment_id = uploader.create_experiment()
url = server_info_lib.experiment_url(server_info, experiment_id)
print("Upload started and will continue reading any new data as it's added")
print("to the logdir. To stop uploading, press Ctrl-C.")
print("View your TensorBoard live at: %s" % url)
Expand Down Expand Up @@ -407,7 +435,7 @@ def __init__(self, output_dir):
def get_ack_message_body(self):
return self._MESSAGE_TEMPLATE.format(output_dir=self.output_dir)

def execute(self, channel):
def execute(self, server_info, channel):
api_client = export_service_pb2_grpc.TensorBoardExporterServiceStub(channel)
outdir = self.output_dir
try:
Expand Down Expand Up @@ -476,6 +504,34 @@ def _get_intent(flags):
raise AssertionError('Unknown subcommand %r' % (cmd,))


def _get_server_info(flags):
    """Determines the server info to use for this invocation.

    When `--origin` is not given, no handshake is performed: a server info
    value is synthesized from the default origin and either `--api_endpoint`
    or the hardcoded API endpoint. When `--origin` is given, server info is
    fetched from that origin, and `--api_endpoint` (if set) replaces the
    API endpoint reported by the server.

    Args:
      flags: Parsed command-line flags; reads `origin` and `api_endpoint`.

    Returns:
      The server info value produced by `server_info_lib`.
    """
    if flags.origin:
        info = server_info_lib.fetch_server_info(flags.origin)
        # Override with any API server explicitly specified on the command
        # line, but only if the server accepted our initial handshake.
        if flags.api_endpoint and info.api_server.endpoint:
            info.api_server.endpoint = flags.api_endpoint
        return info

    # Temporary fallback to hardcoded API endpoint when not specified.
    if flags.api_endpoint:
        endpoint = flags.api_endpoint
    else:
        endpoint = _HARDCODED_API_ENDPOINT
    return server_info_lib.create_server_info(_DEFAULT_ORIGIN, endpoint)


def _handle_server_info(info):
    """Reacts to the compatibility verdict in a server info response.

    A warning verdict prints the server's details to stderr; an error
    verdict is passed to `_die`; any other verdict is treated as OK, and
    the details (if any) are still written to stderr.

    Args:
      info: A server info response; reads `compatibility.verdict` and
        `compatibility.details`.
    """
    compat = info.compatibility
    verdict = compat.verdict
    if verdict == server_info_pb2.VERDICT_WARN:
        sys.stderr.write('Warning [from server]: %s\n' % compat.details)
        sys.stderr.flush()
        return
    if verdict == server_info_pb2.VERDICT_ERROR:
        _die('Error [from server]: %s' % compat.details)
        return
    # OK or unknown; assume OK.
    if compat.details:
        sys.stderr.write('%s\n' % compat.details)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sys.stderr.flush() too?

Though now that I actually go look at the docs, apparently both it and stdout are supposed to be line-buffered, so maybe the fact that I usually flush() after doing a raw write() is just cargo culting?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fascinating. I had always assumed that stderr was unbuffered, like it is
in every other environment, because that’s one of the main points of
stderr. And so it was in Python 2, but indeed this appears to have
changed in Python 3. I’ll keep this in mind from now on; thanks.

(It’s not clear to me whether we should flush here: superfluous in the
interactive case, but not when outputting to a file. I guess the perf
isn’t really a problem since this isn’t in a loop, so I’ll go ahead and
add it.)

sys.stderr.flush()


def _die(message):
sys.stderr.write('%s\n' % (message,))
sys.stderr.flush()
Expand Down
6 changes: 3 additions & 3 deletions tensorboard/uploader/uploader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,16 @@ def _create_mock_client(self):
stub = write_service_pb2_grpc.TensorBoardWriterServiceStub(test_channel)
mock_client = mock.create_autospec(stub)
fake_exp_response = write_service_pb2.CreateExperimentResponse(
experiment_id="123", url="https://example.com/123")
experiment_id="123", url="should not be used!")
mock_client.CreateExperiment.return_value = fake_exp_response
return mock_client

def test_create_experiment(self):
    """Checks that `create_experiment` returns the ID from the response."""
    logdir = "/logs/foo"
    mock_client = self._create_mock_client()
    uploader = uploader_lib.TensorBoardUploader(mock_client, logdir)
    eid = uploader.create_experiment()
    # Only the experiment ID is part of the return contract; the `url`
    # field of the mocked response must not leak into the result.
    self.assertEqual(eid, "123")

def test_start_uploading_without_create_experiment_fails(self):
mock_client = self._create_mock_client()
Expand Down