Skip to content

Pytorch port of SAC #4219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jul 22, 2020
43 changes: 35 additions & 8 deletions experiment_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def run_experiment(
name: str,
steps: int,
use_torch: bool,
algo: str,
num_torch_threads: int,
use_gpu: bool,
num_envs: int = 1,
Expand All @@ -32,6 +33,7 @@ def run_experiment(
name,
str(steps),
str(use_torch),
algo,
str(num_torch_threads),
str(num_envs),
str(use_gpu),
Expand All @@ -46,7 +48,7 @@ def run_experiment(
if config_name is None:
config_name = name
run_options = parse_command_line(
[f"config/ppo/{config_name}.yaml", "--num-envs", f"{num_envs}"]
[f"config/{algo}/{config_name}.yaml", "--num-envs", f"{num_envs}"]
)
run_options.checkpoint_settings.run_id = (
f"{name}_test_" + str(steps) + "_" + ("torch" if use_torch else "tf")
Expand Down Expand Up @@ -87,20 +89,29 @@ def run_experiment(
tc_advance_total = tc_advance["total"]
tc_advance_count = tc_advance["count"]
if use_torch:
update_total = update["TorchPPOOptimizer.update"]["total"]
if algo == "ppo":
update_total = update["TorchPPOOptimizer.update"]["total"]
update_count = update["TorchPPOOptimizer.update"]["count"]
else:
update_total = update["SACTrainer._update_policy"]["total"]
update_count = update["SACTrainer._update_policy"]["count"]
evaluate_total = evaluate["TorchPolicy.evaluate"]["total"]
update_count = update["TorchPPOOptimizer.update"]["count"]
evaluate_count = evaluate["TorchPolicy.evaluate"]["count"]
else:
update_total = update["TFPPOOptimizer.update"]["total"]
if algo == "ppo":
update_total = update["TFPPOOptimizer.update"]["total"]
update_count = update["TFPPOOptimizer.update"]["count"]
else:
update_total = update["SACTrainer._update_policy"]["total"]
update_count = update["SACTrainer._update_policy"]["count"]
evaluate_total = evaluate["NNPolicy.evaluate"]["total"]
update_count = update["TFPPOOptimizer.update"]["count"]
evaluate_count = evaluate["NNPolicy.evaluate"]["count"]
# todo: do total / count
return (
name,
str(steps),
str(use_torch),
algo,
str(num_torch_threads),
str(num_envs),
str(use_gpu),
Expand Down Expand Up @@ -133,28 +144,41 @@ def main():
action="store_true",
help="If true, will only do 3dball",
)
parser.add_argument(
"--sac",
default=False,
action="store_true",
help="If true, will run sac instead of ppo",
)
args = parser.parse_args()

if args.gpu:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
else:
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

algo = "ppo"
if args.sac:
algo = "sac"

envs_config_tuples = [
("3DBall", "3DBall"),
("GridWorld", "GridWorld"),
("PushBlock", "PushBlock"),
("Hallway", "Hallway"),
("CrawlerStaticTarget", "CrawlerStatic"),
("VisualHallway", "VisualHallway"),
]
if algo == "ppo":
envs_config_tuples += [("Hallway", "Hallway"),
("VisualHallway", "VisualHallway")]
if args.ball:
envs_config_tuples = [("3DBall", "3DBall")]


labels = (
"name",
"steps",
"use_torch",
"algorithm",
"num_torch_threads",
"num_envs",
"use_gpu",
Expand All @@ -170,7 +194,7 @@ def main():
results = []
results.append(labels)
f = open(
f"result_data_steps_{args.steps}_envs_{args.num_envs}_gpu_{args.gpu}_thread_{args.threads}.txt",
f"result_data_steps_{args.steps}_algo_{algo}_envs_{args.num_envs}_gpu_{args.gpu}_thread_{args.threads}.txt",
"w",
)
f.write(" ".join(labels) + "\n")
Expand All @@ -180,6 +204,7 @@ def main():
name=env_config[0],
steps=args.steps,
use_torch=True,
algo=algo,
num_torch_threads=1,
use_gpu=args.gpu,
num_envs=args.num_envs,
Expand All @@ -193,6 +218,7 @@ def main():
name=env_config[0],
steps=args.steps,
use_torch=True,
algo=algo,
num_torch_threads=8,
use_gpu=args.gpu,
num_envs=args.num_envs,
Expand All @@ -205,6 +231,7 @@ def main():
name=env_config[0],
steps=args.steps,
use_torch=False,
algo=algo,
num_torch_threads=1,
use_gpu=args.gpu,
num_envs=args.num_envs,
Expand Down
50 changes: 43 additions & 7 deletions ml-agents/mlagents/trainers/distributions_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ def __init__(self, mean, std):
self.std = std

def sample(self):
    """Draw a sample as ``mean + noise * std``.

    The noise is standard-normal and has the same shape as ``self.mean``
    (it is produced by ``torch.randn_like(self.mean)``).

    :return: Tensor with the shape of ``self.mean``.
    """
    # Single return path: the earlier early-return made the remaining
    # statements of this method unreachable.
    noise = torch.randn_like(self.mean)
    return self.mean + noise * self.std

def log_prob(self, value):
    """Return the element-wise Gaussian log-density of ``value``.

    Evaluates::

        -(value - mean)^2 / (2 * std^2 + EPSILON)
        - log(std + EPSILON)
        - log(sqrt(2 * pi))

    EPSILON keeps both the division and the logarithm finite when
    ``self.std`` is zero.

    :param value: Point(s) at which to evaluate the density; combined
        element-wise with ``self.mean`` and ``self.std``.
    :return: Tensor of log-densities.
    """
    var = self.std ** 2
    log_scale = torch.log(self.std + EPSILON)
    squared_error = (value - self.mean) ** 2
    # The quadratic term must appear exactly once; subtracting both the
    # guarded and unguarded variants would double-count it.
    return (
        -squared_error / (2 * var + EPSILON)
        - log_scale
        - math.log(math.sqrt(2 * math.pi))
    )
Expand All @@ -29,7 +30,28 @@ def pdf(self, value):
return torch.exp(log_prob)

def entropy(self):
    """Return the element-wise term ``log(2 * pi * e * std + EPSILON)``.

    EPSILON keeps the logarithm finite when ``self.std`` is zero.

    :return: Tensor with the shape of ``self.std``.
    """
    # NOTE(review): the textbook differential entropy of a Gaussian is
    # 0.5 * log(2 * pi * e * std ** 2), which differs from this value by a
    # constant offset. Left as-is because callers may be tuned to the
    # current scale -- confirm before changing.
    return torch.log(2 * math.pi * math.e * self.std + EPSILON)


class TanhGaussianDistInstance(GaussianDistInstance):
    """Gaussian distribution instance whose samples are squashed by tanh.

    Sampling draws from the parent Gaussian and applies a
    ``TanhTransform``; ``log_prob`` maps a squashed value back through the
    transform's inverse and subtracts the transform's log-abs-det-Jacobian
    from the parent's log-density.
    """

    def __init__(self, mean, std):
        super().__init__(mean, std)
        # cache_size=1: per the torch.distributions documentation the
        # transform remembers its most recent (input, output) pair.
        self.transform = torch.distributions.transforms.TanhTransform(cache_size=1)

    def sample(self):
        """Return a parent-Gaussian sample passed through the tanh transform."""
        return self.transform(super().sample())

    def _inverse_tanh(self, value):
        """Return ``0.5 * log((1 + v) / (1 - v) + EPSILON)`` for clamped ``v``.

        ``value`` is first clamped to ``[-1 + EPSILON, 1 - EPSILON]``.
        This helper is not called by the other methods of this class.
        """
        bounded = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
        ratio = (1 + bounded) / (1 - bounded)
        return 0.5 * torch.log(ratio + EPSILON)

    def log_prob(self, value):
        """Return the log-density of the squashed ``value``.

        :param value: Value in the squashed (post-tanh) space.
        :return: Parent log-density of the un-squashed value minus the
            transform's log-abs-det-Jacobian.
        """
        pre_tanh = self.transform.inv(value)
        gaussian_log_prob = super().log_prob(pre_tanh)
        jacobian_term = self.transform.log_abs_det_jacobian(pre_tanh, value)
        return gaussian_log_prob - jacobian_term


class CategoricalDistInstance(nn.Module):
Expand All @@ -47,15 +69,26 @@ def pdf(self, value):
def log_prob(self, value):
    """Return the natural log of ``self.pdf(value)``.

    :param value: Forwarded unchanged to ``self.pdf``.
    :return: ``torch.log`` applied to the result of ``self.pdf(value)``.
    """
    probability = self.pdf(value)
    return torch.log(probability)

def all_log_prob(self):
    """Return the element-wise natural log of ``self.probs``.

    :return: Tensor with the shape of ``self.probs``.
    """
    probabilities = self.probs
    return torch.log(probabilities)

def entropy(self):
    """Return the Shannon entropy ``-sum(p * log(p))`` of ``self.probs``.

    The sum runs over the last dimension of ``self.probs``.

    :return: Tensor with the last dimension of ``self.probs`` reduced away.
    """
    # The leading minus is required: sum(p * log(p)) on its own is the
    # *negative* entropy (<= 0 for probabilities), which would flip the
    # direction of any entropy bonus built on top of this value.
    return -torch.sum(self.probs * torch.log(self.probs), dim=-1)


class GaussianDistribution(nn.Module):
def __init__(self, hidden_size, num_outputs, conditional_sigma=False, **kwargs):
def __init__(
self,
hidden_size,
num_outputs,
conditional_sigma=False,
tanh_squash=False,
**kwargs
):
super(GaussianDistribution, self).__init__(**kwargs)
self.conditional_sigma = conditional_sigma
self.mu = nn.Linear(hidden_size, num_outputs)
self.tanh_squash = tanh_squash
nn.init.xavier_uniform_(self.mu.weight, gain=0.01)
if conditional_sigma:
self.log_sigma = nn.Linear(hidden_size, num_outputs)
Expand All @@ -68,10 +101,13 @@ def __init__(self, hidden_size, num_outputs, conditional_sigma=False, **kwargs):
def forward(self, inputs):
mu = self.mu(inputs)
if self.conditional_sigma:
log_sigma = self.log_sigma(inputs)
log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these clamp bounds (-20 and 2) be named constants?

else:
log_sigma = self.log_sigma
return [GaussianDistInstance(mu, torch.exp(log_sigma))]
if self.tanh_squash:
return [TanhGaussianDistInstance(mu, torch.exp(log_sigma))]
else:
return [GaussianDistInstance(mu, torch.exp(log_sigma))]


class MultiCategoricalDistribution(nn.Module):
Expand Down
Loading