Skip to content

Mla 1145 cherrypick #4206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DevProject/Packages/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"com.unity.ide.vscode": "1.1.4",
"com.unity.ml-agents": "file:../../com.unity.ml-agents",
"com.unity.ml-agents.extensions": "file:../../com.unity.ml-agents.extensions",
"com.unity.barracuda": "file:../../../UnityInferenceEngine/UnityProject/Assets/Barracuda",
"com.unity.multiplayer-hlapi": "1.0.6",
"com.unity.package-manager-doctools": "1.1.1-preview.3",
"com.unity.package-validation-suite": "0.11.0-preview",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using UnityEngine;
using Unity.Barracuda;
using System.IO;
using Unity.Barracuda.ONNX;
using Unity.MLAgents;
using Unity.MLAgents.Policies;
#if UNITY_EDITOR
Expand Down Expand Up @@ -113,9 +114,7 @@ void GetAssetPathFromCommandLine()
{
m_OverrideExtension = args[i + 1].Trim().ToLower();
var isKnownExtension = k_SupportedExtensions.Contains(m_OverrideExtension);
// Not supported yet - need to update the model loading code to support
var isOnnx = m_OverrideExtension.Equals("onnx");
if (!isKnownExtension || isOnnx)
if (!isKnownExtension)
{
Debug.LogError($"loading unsupported format: {m_OverrideExtension}");
Application.Quit(1);
Expand Down Expand Up @@ -209,10 +208,10 @@ public NNModel GetModelForBehaviorName(string behaviorName)
return null;
}

byte[] model = null;
byte[] rawModel = null;
try
{
model = File.ReadAllBytes(assetPath);
rawModel = File.ReadAllBytes(assetPath);
}
catch(IOException)
{
Expand All @@ -222,11 +221,34 @@ public NNModel GetModelForBehaviorName(string behaviorName)
return null;
}

// Note - this approach doesn't work for onnx files. Need to replace with
// the equivalent of ONNXModelImporter.OnImportAsset()
var asset = ScriptableObject.CreateInstance<NNModel>();
asset.modelData = ScriptableObject.CreateInstance<NNModelData>();
asset.modelData.Value = model;
NNModel asset;
var isOnnx = m_OverrideExtension.Equals("onnx");
if (isOnnx)
{
var importer = new ONNXModelImporterRuntime(true);
var onnxModel = importer.Import(rawModel);

NNModelData assetData = ScriptableObject.CreateInstance<NNModelData>();
using (var memoryStream = new MemoryStream())
using (var writer = new BinaryWriter(memoryStream))
{
ModelWriter.Save(writer, onnxModel);
assetData.Value = memoryStream.ToArray();
}
assetData.name = "Data";
assetData.hideFlags = HideFlags.HideInHierarchy;

asset = ScriptableObject.CreateInstance<NNModel>();
asset.modelData = assetData;
}
else
{
// Note - this approach doesn't work for onnx files. Need to replace with
// the equivalent of ONNXModelImporter.OnImportAsset()
asset = ScriptableObject.CreateInstance<NNModel>();
asset.modelData = ScriptableObject.CreateInstance<NNModelData>();
asset.modelData.Value = rawModel;
}

asset.name = "Override - " + Path.GetFileName(assetPath);
m_CachedModels[behaviorName] = asset;
Expand Down
1 change: 1 addition & 0 deletions Project/Packages/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"com.unity.collab-proxy": "1.2.15",
"com.unity.ml-agents": "file:../../com.unity.ml-agents",
"com.unity.ml-agents.extensions": "file:../../com.unity.ml-agents.extensions",
"com.unity.barracuda": "file:../../../UnityInferenceEngine/UnityProject/Assets/Barracuda",
"com.unity.package-manager-ui": "2.0.8",
"com.unity.purchasing": "2.0.3",
"com.unity.textmeshpro": "1.4.1",
Expand Down
2 changes: 2 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ empty string). (#4155)
- Fixed issue with FoodCollector, Soccer, and WallJump when playing with keyboard. (#4147, #4174)
- Fixed a crash in StatsReporter when using threaded trainers with very frequent summary writes
(#4201)
- `mlagents-learn` will now raise an error immediately if `--num-envs` is greater than 1 without setting the `--env`
argument. (#4203)

## [1.1.0-preview] - 2020-06-10
### Major Changes
Expand Down
2 changes: 1 addition & 1 deletion com.unity.ml-agents/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
"unity": "2018.4",
"description": "Use state-of-the-art machine learning to create intelligent character behaviors in any Unity environment (games, robotics, film, etc.).",
"dependencies": {
"com.unity.barracuda": "1.0.1"
"com.unity.barracuda": "1.1.0-preview"
}
}
7 changes: 6 additions & 1 deletion ml-agents/mlagents/trainers/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,9 +600,14 @@ class EnvironmentSettings:
env_path: Optional[str] = parser.get_default("env_path")
env_args: Optional[List[str]] = parser.get_default("env_args")
base_port: int = parser.get_default("base_port")
num_envs: int = parser.get_default("num_envs")
num_envs: int = attr.ib(default=parser.get_default("num_envs"))
seed: int = parser.get_default("seed")

@num_envs.validator
def validate_num_envs(self, attribute, value):
if value > 1 and self.env_path is None:
raise ValueError("num_envs must be 1 if env_path is not set.")


@attr.s(auto_attribs=True)
class EngineSettings:
Expand Down
16 changes: 16 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
RewardSignalType,
RewardSignalSettings,
CuriositySettings,
EnvironmentSettings,
EnvironmentParameterSettings,
ConstantSettings,
UniformSettings,
Expand Down Expand Up @@ -452,3 +453,18 @@ def test_exportable_settings(use_defaults):
check_dict_is_at_least(second_export, dict_export)
# Check that the two exports are the same
assert dict_export == second_export


def test_environment_settings():
# default args
EnvironmentSettings()

# 1 env is OK if no env_path
EnvironmentSettings(num_envs=1)

# multiple envs is OK if env_path is set
EnvironmentSettings(num_envs=42, env_path="/foo/bar.exe")

# Multiple environments with no env_path is an error
with pytest.raises(ValueError):
EnvironmentSettings(num_envs=2)