@@ -27,6 +27,7 @@ def __init__(
27
27
init_path : str = None ,
28
28
multi_gpu : bool = False ,
29
29
force_torch : bool = False ,
30
+ force_tensorflow : bool = False ,
30
31
):
31
32
"""
32
33
The TrainerFactory generates the Trainers based on the configuration passed as
@@ -45,7 +46,9 @@ def __init__(
45
46
:param init_path: Path from which to load model.
46
47
:param multi_gpu: If True, multi-gpu will be used. (currently not available)
47
48
:param force_torch: If True, the Trainers will all use the PyTorch framework
48
- instead of the TensorFlow framework.
49
+ instead of what is specified in the config YAML.
50
+ :param force_tensorflow: If True, thee Trainers will all use the TensorFlow
51
+ framework.
49
52
"""
50
53
self .trainer_config = trainer_config
51
54
self .output_path = output_path
@@ -57,6 +60,7 @@ def __init__(
57
60
self .multi_gpu = multi_gpu
58
61
self .ghost_controller = GhostController ()
59
62
self ._force_torch = force_torch
63
+ self ._force_tf = force_tensorflow
60
64
61
65
def generate (self , behavior_name : str ) -> Trainer :
62
66
if behavior_name not in self .trainer_config .keys ():
@@ -67,6 +71,18 @@ def generate(self, behavior_name: str) -> Trainer:
67
71
trainer_settings = self .trainer_config [behavior_name ]
68
72
if self ._force_torch :
69
73
trainer_settings .framework = FrameworkType .PYTORCH
74
+ logger .warning (
75
+ "Note that specifying --torch is not required anymore as PyTorch is the default framework."
76
+ )
77
+ if self ._force_tf :
78
+ trainer_settings .framework = FrameworkType .TENSORFLOW
79
+ logger .warning (
80
+ "Setting the framework to TensorFlow. TensorFlow trainers will be deprecated in the future."
81
+ )
82
+ if self ._force_torch :
83
+ logger .warning (
84
+ "Both --torch and --tensorflow CLI options were specified. Using TensorFlow."
85
+ )
70
86
return TrainerFactory ._initialize_trainer (
71
87
trainer_settings ,
72
88
behavior_name ,
0 commit comments