Skip to content

Commit cd29e95

Browse files
author
Ervin T
authored
Prevent tf.Session() from eating up all the GPU memory (#3219)
* Use soft placement and allow_growth for Session * Move config generation to tf utils * Re-add self.graph
1 parent f9d1faf commit cd29e95

File tree

4 files changed

+23
-10
lines changed

4 files changed

+23
-10
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from mlagents.tf_utils.tf import tf as tf # noqa
22
from mlagents.tf_utils.tf import set_warnings_enabled # noqa
3+
from mlagents.tf_utils.tf import generate_session_config # noqa

ml-agents/mlagents/tf_utils/tf.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,23 @@
2323

2424
def set_warnings_enabled(is_enabled: bool) -> None:
2525
"""
26-
Enable or disable tensorflow warnings (notabley, this disables deprecation warnings.
26+
Enable or disable tensorflow warnings (notably, this disables deprecation warnings).
2727
:param is_enabled:
2828
"""
2929
level = tf_logging.WARN if is_enabled else tf_logging.ERROR
3030
tf_logging.set_verbosity(level)
31+
32+
33+
def generate_session_config() -> tf.ConfigProto:
34+
"""
35+
Generate a ConfigProto to use for ML-Agents that doesn't consume all of the GPU memory
36+
and allows for soft placement in the case of multi-GPU.
37+
"""
38+
config = tf.ConfigProto()
39+
config.gpu_options.allow_growth = True
40+
# For multi-GPU training, set allow_soft_placement to True to allow
41+
# placing the operation into an alternative device automatically
42+
# to prevent exceptions if the device doesn't support the operation
43+
# or the device does not exist
44+
config.allow_soft_placement = True
45+
return config

ml-agents/mlagents/trainers/tf_policy.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
from mlagents.tf_utils import tf
6+
from mlagents import tf_utils
67

78
from mlagents_envs.exception import UnityException
89
from mlagents.trainers.policy import Policy
@@ -69,14 +70,9 @@ def __init__(self, seed, brain, trainer_parameters):
6970
self.model_path = trainer_parameters["model_path"]
7071
self.keep_checkpoints = trainer_parameters.get("keep_checkpoints", 5)
7172
self.graph = tf.Graph()
72-
config = tf.ConfigProto()
73-
config.gpu_options.allow_growth = True
74-
# For multi-GPU training, set allow_soft_placement to True to allow
75-
# placing the operation into an alternative device automatically
76-
# to prevent from exceptions if the device doesn't suppport the operation
77-
# or the device does not exist
78-
config.allow_soft_placement = True
79-
self.sess = tf.Session(config=config, graph=self.graph)
73+
self.sess = tf.Session(
74+
config=tf_utils.generate_session_config(), graph=self.graph
75+
)
8076
self.saver = None
8177
if self.use_recurrent:
8278
self.m_size = trainer_parameters["memory_size"]

ml-agents/mlagents/trainers/trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Dict, List, Deque, Any
44

55
from mlagents.tf_utils import tf
6+
from mlagents import tf_utils
67

78
from collections import deque
89

@@ -70,7 +71,7 @@ def write_tensorboard_text(self, key: str, input_dict: Dict[str, Any]) -> None:
7071
:param input_dict: A dictionary that will be displayed in a table on Tensorboard.
7172
"""
7273
try:
73-
with tf.Session() as sess:
74+
with tf.Session(config=tf_utils.generate_session_config()) as sess:
7475
s_op = tf.summary.text(
7576
key,
7677
tf.convert_to_tensor(

0 commit comments

Comments
 (0)