@@ -80,7 +80,8 @@ def set_default_configs(self):
80
80
'time-stride' :1 ,
81
81
'l2-regularize' :0.0 ,
82
82
'max-change' : 0.75 ,
83
- 'self-repair-scale' : 1.0e-05 }
83
+ 'self-repair-scale' : 1.0e-05 ,
84
+ 'context' : 'default' }
84
85
85
86
def set_derived_configs (self ):
86
87
pass
@@ -104,6 +105,10 @@ def check_configs(self):
104
105
raise RuntimeError ('bypass-scale is nonzero but output-dim != input-dim: {0} != {1}'
105
106
'' .format (output_dim , input_dim ))
106
107
108
+ if not self .config ['context' ] in ['default' , 'left-only' , 'shift-left' , 'none' ]:
109
+ raise RuntimeError ('context must be default, left-only shift-left or none, got {}' .format (
110
+ self .config ['context' ]))
111
+
107
112
108
113
def output_name (self , auxiliary_output = None ):
109
114
assert auxiliary_output is None
@@ -142,9 +147,16 @@ def _generate_config(self):
142
147
bypass_scale = self .config ['bypass-scale' ]
143
148
dropout_proportion = self .config ['dropout-proportion' ]
144
149
time_stride = self .config ['time-stride' ]
145
- if time_stride != 0 :
150
+ context = self .config ['context' ]
151
+ if time_stride != 0 and context != 'none' :
146
152
time_offsets1 = '{0},0' .format (- time_stride )
147
- time_offsets2 = '0,{0}' .format (time_stride )
153
+ if context == 'default' :
154
+ time_offsets2 = '0,{0}' .format (time_stride )
155
+ elif context == 'shift-left' :
156
+ time_offsets2 = '{0},0' .format (- time_stride )
157
+ else :
158
+ assert context == 'left-only'
159
+ time_offsets2 = '0'
148
160
else :
149
161
time_offsets1 = '0'
150
162
time_offsets2 = '0'
0 commit comments