Skip to content

Commit d5a159f

Browse files
author
Chris Elion
authored
inference - fill 0s for done Agents (#3232)
* fill 0s for done agents * docstrings
1 parent 959b738 commit d5a159f

File tree

2 files changed

+58
-14
lines changed

2 files changed

+58
-14
lines changed

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -101,20 +101,30 @@ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentIn
101101
var agentIndex = 0;
102102
foreach (var info in infos)
103103
{
104-
var tensorOffset = 0;
105-
// Write each sensor consecutively to the tensor
106-
foreach (var sensorIndex in m_SensorIndices)
104+
if (info.agentInfo.done)
107105
{
108-
var sensor = info.sensors[sensorIndex];
109-
m_WriteAdapter.SetTarget(tensorProxy, agentIndex, tensorOffset);
110-
var numWritten = sensor.Write(m_WriteAdapter);
111-
tensorOffset += numWritten;
106+
// If the agent is done, we might have a stale reference to the sensors
107+
// e.g. a dependent object might have been disposed.
108+
// To avoid this, just fill observation with zeroes instead of calling sensor.Write.
109+
TensorUtils.FillTensorBatch(tensorProxy, agentIndex, 0.0f);
110+
}
111+
else
112+
{
113+
var tensorOffset = 0;
114+
// Write each sensor consecutively to the tensor
115+
foreach (var sensorIndex in m_SensorIndices)
116+
{
117+
var sensor = info.sensors[sensorIndex];
118+
m_WriteAdapter.SetTarget(tensorProxy, agentIndex, tensorOffset);
119+
var numWritten = sensor.Write(m_WriteAdapter);
120+
tensorOffset += numWritten;
121+
}
122+
Debug.AssertFormat(
123+
tensorOffset == vecObsSizeT,
124+
"mismatch between vector observation size ({0}) and number of observations written ({1})",
125+
vecObsSizeT, tensorOffset
126+
);
112127
}
113-
Debug.AssertFormat(
114-
tensorOffset == vecObsSizeT,
115-
"mismatch between vector observation size ({0}) and number of observations written ({1})",
116-
vecObsSizeT, tensorOffset
117-
);
118128

119129
agentIndex++;
120130
}
@@ -356,8 +366,19 @@ public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentIn
356366
foreach (var infoSensorPair in infos)
357367
{
358368
var sensor = infoSensorPair.sensors[m_SensorIndex];
359-
m_WriteAdapter.SetTarget(tensorProxy, agentIndex, 0);
360-
sensor.Write(m_WriteAdapter);
369+
if (infoSensorPair.agentInfo.done)
370+
{
371+
// If the agent is done, we might have a stale reference to the sensors
372+
// e.g. a dependent object might have been disposed.
373+
// To avoid this, just fill observation with zeroes instead of calling sensor.Write.
374+
TensorUtils.FillTensorBatch(tensorProxy, agentIndex, 0.0f);
375+
}
376+
else
377+
{
378+
m_WriteAdapter.SetTarget(tensorProxy, agentIndex, 0);
379+
sensor.Write(m_WriteAdapter);
380+
381+
}
361382
agentIndex++;
362383
}
363384
}

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorProxy.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,29 @@ public static TensorProxy TensorProxyFromBarracuda(Tensor src, string nameOverri
9090
};
9191
}
9292

93+
/// <summary>
94+
/// Fill a specific batch of a TensorProxy with a given value
95+
/// </summary>
96+
/// <param name="tensorProxy"></param>
97+
/// <param name="batch">The batch index to fill.</param>
98+
/// <param name="fillValue"></param>
99+
public static void FillTensorBatch(TensorProxy tensorProxy, int batch, float fillValue)
100+
{
101+
var height = tensorProxy.data.height;
102+
var width = tensorProxy.data.width;
103+
var channels = tensorProxy.data.channels;
104+
for (var h = 0; h < height; h++)
105+
{
106+
for (var w = 0; w < width; w++)
107+
{
108+
for (var c = 0; c < channels; c++)
109+
{
110+
tensorProxy.data[batch, h, w, c] = fillValue;
111+
}
112+
}
113+
}
114+
}
115+
93116
/// <summary>
94117
/// Fill a pre-allocated Tensor with random numbers
95118
/// </summary>

0 commit comments

Comments
 (0)