2
2
from typing import List , Dict , NamedTuple
3
3
import numpy as np
4
4
import abc
5
+ import csv
5
6
import os
6
7
7
8
from mlagents .tf_utils import tf
8
9
9
10
11
+ class StatsSummary (NamedTuple ):
12
+ mean : float
13
+ std : float
14
+ num : int
15
+
16
+
10
17
class StatsWriter (abc .ABC ):
11
18
"""
12
19
A StatsWriter abstract class. A StatsWriter takes in a category, key, scalar value, and step
13
20
and writes it out by some method.
14
21
"""
15
22
16
23
@abc .abstractmethod
17
- def write_stats (self , category : str , key : str , value : float , step : int ) -> None :
24
+ def write_stats (
25
+ self , category : str , values : Dict [str , StatsSummary ], step : int
26
+ ) -> None :
18
27
pass
19
28
20
29
@abc .abstractmethod
@@ -24,15 +33,23 @@ def write_text(self, category: str, text: str, step: int) -> None:
24
33
25
34
class TensorboardWriter (StatsWriter ):
26
35
def __init__ (self , base_dir : str ):
36
+ """
37
+ A StatsWriter that writes to a Tensorboard summary.
38
+ :param base_dir: The directory within which to place all the summaries. Tensorboard files will be written to a
39
+ {base_dir}/{category} directory.
40
+ """
27
41
self .summary_writers : Dict [str , tf .summary .FileWriter ] = {}
28
42
self .base_dir : str = base_dir
29
43
30
- def write_stats (self , category : str , key : str , value : float , step : int ) -> None :
44
+ def write_stats (
45
+ self , category : str , values : Dict [str , StatsSummary ], step : int
46
+ ) -> None :
31
47
self ._maybe_create_summary_writer (category )
32
- summary = tf .Summary ()
33
- summary .value .add (tag = "{}" .format (key ), simple_value = value )
34
- self .summary_writers [category ].add_summary (summary , step )
35
- self .summary_writers [category ].flush ()
48
+ for key , value in values .items ():
49
+ summary = tf .Summary ()
50
+ summary .value .add (tag = "{}" .format (key ), simple_value = value .mean )
51
+ self .summary_writers [category ].add_summary (summary , step )
52
+ self .summary_writers [category ].flush ()
36
53
37
54
def _maybe_create_summary_writer (self , category : str ) -> None :
38
55
if category not in self .summary_writers :
@@ -47,10 +64,59 @@ def write_text(self, category: str, text: str, step: int) -> None:
47
64
self .summary_writers [category ].add_summary (text , step )
48
65
49
66
50
- class StatsSummary (NamedTuple ):
51
- mean : float
52
- std : float
53
- num : int
67
+ class CSVWriter (StatsWriter ):
68
+ def __init__ (self , base_dir : str , required_fields : List [str ] = None ):
69
+ """
70
+ A StatsWriter that writes to a Tensorboard summary.
71
+ :param base_dir: The directory within which to place the CSV file, which will be {base_dir}/{category}.csv.
72
+ :param required_fields: If provided, the CSV writer won't write until these fields have statistics to write for
73
+ them.
74
+ """
75
+ # We need to keep track of the fields in the CSV, as all rows need the same fields.
76
+ self .csv_fields : Dict [str , List [str ]] = {}
77
+ self .required_fields = required_fields if required_fields else []
78
+ self .base_dir : str = base_dir
79
+
80
+ def write_stats (
81
+ self , category : str , values : Dict [str , StatsSummary ], step : int
82
+ ) -> None :
83
+ if self ._maybe_create_csv_file (category , list (values .keys ())):
84
+ row = [str (step )]
85
+ # Only record the stats that showed up in the first valid row
86
+ for key in self .csv_fields [category ]:
87
+ _val = values .get (key , None )
88
+ row .append (str (_val .mean ) if _val else "None" )
89
+ with open (self ._get_filepath (category ), "a" ) as file :
90
+ writer = csv .writer (file )
91
+ writer .writerow (row )
92
+
93
+ def _maybe_create_csv_file (self , category : str , keys : List [str ]) -> bool :
94
+ """
95
+ If no CSV file exists and the keys have the required values,
96
+ make the CSV file and write hte title row.
97
+ Returns True if there is now (or already is) a valid CSV file.
98
+ """
99
+ if category not in self .csv_fields :
100
+ summary_dir = self .base_dir
101
+ os .makedirs (summary_dir , exist_ok = True )
102
+ # Only store if the row contains the required fields
103
+ if all (item in keys for item in self .required_fields ):
104
+ self .csv_fields [category ] = keys
105
+ with open (self ._get_filepath (category ), "w" ) as file :
106
+ title_row = ["Steps" ]
107
+ title_row .extend (keys )
108
+ writer = csv .writer (file )
109
+ writer .writerow (title_row )
110
+ return True
111
+ return False
112
+ return True
113
+
114
+ def _get_filepath (self , category : str ) -> str :
115
+ file_dir = os .path .join (self .base_dir , category + ".csv" )
116
+ return file_dir
117
+
118
+ def write_text (self , category : str , text : str , step : int ) -> None :
119
+ pass
54
120
55
121
56
122
class StatsReporter :
@@ -87,11 +153,13 @@ def write_stats(self, step: int) -> None:
87
153
:param category: The category which to write out the stats.
88
154
:param step: Training step which to write these stats as.
89
155
"""
156
+ values : Dict [str , StatsSummary ] = {}
90
157
for key in StatsReporter .stats_dict [self .category ]:
91
158
if len (StatsReporter .stats_dict [self .category ][key ]) > 0 :
92
- stat_mean = float (np .mean (StatsReporter .stats_dict [self .category ][key ]))
93
- for writer in StatsReporter .writers :
94
- writer .write_stats (self .category , key , stat_mean , step )
159
+ stat_summary = self .get_stats_summaries (key )
160
+ values [key ] = stat_summary
161
+ for writer in StatsReporter .writers :
162
+ writer .write_stats (self .category , values , step )
95
163
del StatsReporter .stats_dict [self .category ]
96
164
97
165
def write_text (self , text : str , step : int ) -> None :
0 commit comments