diff --git a/tensorboard/data/server/BUILD b/tensorboard/data/server/BUILD
index a72c6da28a..0eb5e6e8b1 100644
--- a/tensorboard/data/server/BUILD
+++ b/tensorboard/data/server/BUILD
@@ -27,6 +27,7 @@ rust_library(
     name = "rustboard_core",
     srcs = [
         "lib.rs",
+        "commit.rs",
         "data_compat.rs",
         "event_file.rs",
         "masked_crc.rs",
diff --git a/tensorboard/data/server/commit.rs b/tensorboard/data/server/commit.rs
new file mode 100644
index 0000000000..9b0104d21b
--- /dev/null
+++ b/tensorboard/data/server/commit.rs
@@ -0,0 +1,141 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+//! Shared state for sampled data available to readers.
+
+use std::collections::HashMap;
+use std::sync::RwLock;
+
+use crate::proto::tensorboard as pb;
+use crate::reservoir::Basin;
+use crate::types::{Run, Step, Tag, WallTime};
+
+/// Current state of in-memory sampled data.
+///
+/// A commit is an internally mutable structure. All readers and writers should keep a shared
+/// reference to a single commit. When writers need to update it, they grab an exclusive lock to
+/// the contents.
+///
+/// Deadlock safety: any thread should obtain the outer lock (around the hash map) before an inner
+/// lock (around the run data), and should obtain at most one `RunData` lock at once.
+#[derive(Debug, Default)]
+pub struct Commit {
+    /// Data for every run in this commit, keyed by run name, with one inner lock per run.
+    pub runs: RwLock<HashMap<Run, RwLock<RunData>>>,
+}
+
+impl Commit {
+    /// Creates a new, empty commit.
+    pub fn new() -> Self {
+        Commit::default()
+    }
+}
+
+/// Data for a single run.
+///
+/// This contains all data and metadata for a run. For now, that data includes only scalars;
+/// tensors and blob sequences will come soon.
+#[derive(Debug, Default)]
+pub struct RunData {
+    /// The time of the first event recorded for this run.
+    ///
+    /// Used to define an ordering on runs that is stable as new runs are added, so that existing
+    /// runs aren't constantly changing color.
+    pub start_time: Option<WallTime>,
+
+    /// Scalar time series for this run.
+    pub scalars: TagStore<ScalarValue>,
+}
+
+/// Time series for a single run, keyed by tag.
+pub type TagStore<V> = HashMap<Tag, TimeSeries<V>>;
+
+/// A time series with values of type `V`: its summary metadata together with sampled points.
+#[derive(Debug)]
+pub struct TimeSeries<V> {
+    /// Summary metadata for this time series.
+    pub metadata: Box<pb::SummaryMetadata>,
+
+    /// Reservoir basin for data points in this time series.
+    ///
+    /// See [`TimeSeries::valid_values`] for a client-friendly view that omits `DataLoss` points
+    /// and transposes `Step`s into the tuple.
+    pub basin: Basin<(WallTime, Result<V, DataLoss>)>,
+}
+
+impl<V> TimeSeries<V> {
+    /// Creates a new time series from the given summary metadata.
+    pub fn new(metadata: Box<pb::SummaryMetadata>) -> Self {
+        TimeSeries {
+            metadata,
+            basin: Basin::new(),
+        }
+    }
+
+    /// Gets an iterator over the points in `self.basin` that omits `DataLoss` points.
+    pub fn valid_values(&self) -> impl Iterator<Item = (Step, WallTime, &V)> {
+        self.basin
+            .as_slice()
+            .iter()
+            .filter_map(|(step, (wall_time, v))| Some((*step, *wall_time, v.as_ref().ok()?)))
+    }
+}
+
+/// A value in a time series is corrupt and should be ignored.
+///
+/// This is used when a point looks superficially reasonable when it's offered to the reservoir,
+/// but at commit time we realize that it can't be enriched into a valid point. This might happen
+/// if, for instance, a point in a scalar time series has a tensor value containing a string. We
+/// don't care too much about what happens to these invalid values. Keeping them in the commit as
+/// `DataLoss` tombstones is convenient, and [`TimeSeries::valid_values`] offers a view that
+/// abstracts over this detail by only showing valid data.
+#[derive(Debug)]
+pub struct DataLoss;
+
+/// The value of a scalar time series at a single point.
+#[derive(Debug, Copy, Clone, PartialEq)]
+pub struct ScalarValue(pub f64);
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_valid_values() {
+        let mut ts = TimeSeries::<&str>::new(Box::new(pb::SummaryMetadata::default()));
+
+        let mut rsv = crate::reservoir::StageReservoir::new(10);
+        let wall_time = WallTime::new(0.0).unwrap(); // don't really care
+        rsv.offer(Step(0), "zero");
+        rsv.offer(Step(1), "one");
+        rsv.offer(Step(2), "two");
+        rsv.offer(Step(3), "three");
+        rsv.offer(Step(5), "five");
+        rsv.commit_map(&mut ts.basin, |s| {
+            (wall_time, if s == "three" { Err(DataLoss) } else { Ok(s) })
+        });
+
+        assert_eq!(
+            ts.valid_values().collect::<Vec<_>>(),
+            vec![
+                (Step(0), wall_time, &"zero"),
+                (Step(1), wall_time, &"one"),
+                (Step(2), wall_time, &"two"),
+                // missing: Step(3)
+                (Step(5), wall_time, &"five")
+            ]
+        );
+    }
+}
diff --git a/tensorboard/data/server/lib.rs b/tensorboard/data/server/lib.rs
index 5511250820..8d3f144423 100644
--- a/tensorboard/data/server/lib.rs
+++ b/tensorboard/data/server/lib.rs
@@ -15,6 +15,7 @@ limitations under the License.
 
 //! Core functionality for TensorBoard data loading.
 
+pub mod commit;
 pub mod data_compat;
 pub mod event_file;
 pub mod masked_crc;
diff --git a/tensorboard/data/server/types.rs b/tensorboard/data/server/types.rs
index 59a764504a..bb50972f3b 100644
--- a/tensorboard/data/server/types.rs
+++ b/tensorboard/data/server/types.rs
@@ -70,6 +70,19 @@ impl Borrow<str> for Tag {
     }
 }
 
+/// The name of a TensorBoard run.
+///
+/// Run names are derived from directory names relative to the logdir, but are lossily converted to
+/// valid Unicode strings.
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
+pub struct Run(pub String);
+
+impl Borrow<str> for Run {
+    fn borrow(&self) -> &str {
+        &self.0
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -85,6 +98,17 @@ mod tests {
         assert_eq!(m.get("xent"), None);
     }
 
+    #[test]
+    fn test_run_hash_map_str_access() {
+        use std::collections::HashMap;
+        let mut m: HashMap<Run, i32> = HashMap::new();
+        m.insert(Run("train".to_string()), 1);
+        m.insert(Run("test".to_string()), 2);
+        // We can call `get` given only a `&str`, not an owned `Run`.
+        assert_eq!(m.get("train"), Some(&1));
+        assert_eq!(m.get("val"), None);
+    }
+
     #[test]
     fn test_wall_time() {
         assert_eq!(WallTime::new(f64::INFINITY), None);