Commit 396047f

Updated distributed Demos (#215)
* added simple cluster template
* sets correct backend for possible combinations of gpu inputs
* simple slurm example
1 parent 83b756f commit 396047f

File tree

9 files changed: +228 −18 lines changed

MANIFEST.in

Lines changed: 6 additions & 1 deletion
@@ -13,11 +13,16 @@ include LICENSE
 exclude *.sh
 exclude *.toml
 exclude *.svg
-recursive-include examples *.py
 recursive-include pytorch_lightning *.py

+# include examples
+recursive-include examples *.py
+recursive-include examples *.md
+recursive-include examples *.sh
+
 # exclude tests from package
 recursive-exclude tests *
+recursive-exclude site *
 exclude tests

 # Exclude the documentation files
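
The new rules pull the example `.py`, `.md` and `.sh` files into the source distribution while keeping `tests/` and `site/` out. One way to verify the effect is to build an sdist and list its contents; a minimal sketch, assuming the sdist has been built into `dist/` (the tarball name pattern is illustrative, not guaranteed):

```python
import glob
import tarfile

# Pick the most recently built sdist (name pattern is an assumption).
sdist = sorted(glob.glob('dist/pytorch-lightning-*.tar.gz'))[-1]

with tarfile.open(sdist, 'r:gz') as tar:
    members = tar.getnames()

example_files = [m for m in members if '/examples/' in m]
excluded = [m for m in members if '/tests/' in m or '/site/' in m]

print(f'{len(example_files)} example files packaged')
assert not excluded, 'tests/ and site/ should not ship in the sdist'
```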

examples/new_project_templates/lightning_module_template.py

Lines changed: 4 additions & 4 deletions
@@ -240,15 +240,15 @@ def add_model_specific_args(parent_parser, root_dir):  # pragma: no cover
         parser.add_argument('--out_features', default=10, type=int)
         # use 500 for CPU, 50000 for GPU to see speed difference
         parser.add_argument('--hidden_dim', default=50000, type=int)
-        parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)
+        parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=True)
+        parser.opt_list('--learning_rate', default=0.001 * 8, type=float,
+                        options=[0.0001, 0.0005, 0.001],
+                        tunable=True)

         # data
         parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str)

         # training params (opt)
-        parser.opt_list('--learning_rate', default=0.001 * 8, type=float,
-                        options=[0.0001, 0.0005, 0.001, 0.005],
-                        tunable=False)
         parser.opt_list('--optimizer_name', default='adam', type=str,
                         options=['adam'], tunable=False)
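
Making both `--drop_prob` and `--learning_rate` tunable defines a 2 × 3 grid, i.e. 6 combinations, which lines up with the 6 jobs the multi-node README below says get submitted (and with the new `--num_hyperparam_trials` default of 6). A standalone sketch of that grid using plain `itertools` rather than test-tube's own trial generation:

```python
from itertools import product

# The options declared above with parser.opt_list(..., tunable=True)
drop_probs = [0.2, 0.5]
learning_rates = [0.0001, 0.0005, 0.001]

trials = [{'drop_prob': dp, 'learning_rate': lr}
          for dp, lr in product(drop_probs, learning_rates)]

print(len(trials))  # 6 -> one cluster job per combination
```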

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+# Multi-node examples
+Use these templates for multi-node training.
+
+## Simplest example
+1. Modify this script with your CoolModel file.
+2. Update and submit [this bash script]()
+```bash
+sbatch minimal_multi_node_demo_script.sh
+```
+
+## Grid search on a cluster
+
+#### Option 1: Run on cluster using your own SLURM script
+The trainer and model will work on a cluster if you configure your SLURM script correctly.
+
+1. Update [this demo SLURM script]().
+2. Submit the script:
+```bash
+$ sbatch demo_script.sh
+```
+
+Most people have some way to automatically generate their own scripts.
+To run a grid search this way, you'd need to automatically generate scripts covering every combination of
+hyperparameters to search over (see the sketch after this diff).
+
+#### Option 2: Use test-tube to generate the SLURM scripts
+With test-tube we can automatically generate SLURM scripts for the different hyperparameter options.
+
+To run this demo:
+```bash
+source activate YourCondaEnv
+
+python multi_node_cluster_auto_slurm.py --email [email protected] --gpu_partition your_partition --conda_env YourCondaEnv
+```
+
+That will submit 6 jobs. Each job will have a specific combination of hyperparams, and each job will run on 2 nodes
+where each node has 8 GPUs.
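
For Option 1, a rough sketch of how those per-combination scripts could be generated; the `run.py` entry point, the output directory, and the SLURM resources shown are placeholders, not part of this commit:

```python
from itertools import product
from pathlib import Path

# Hypothetical grid matching the tunable options in the module template.
grid = list(product([0.2, 0.5], [0.0001, 0.0005, 0.001]))

template = """#!/bin/bash -l
#SBATCH --nodes=2
#SBATCH --gres=gpu:8
#SBATCH --ntasks-per-node=8
#SBATCH --mem=0
#SBATCH --time=02:00:00

srun python run.py --drop_prob {dp} --learning_rate {lr}
"""

out_dir = Path('generated_scripts')
out_dir.mkdir(exist_ok=True)
for i, (dp, lr) in enumerate(grid):
    script = out_dir / f'trial_{i}.sh'
    script.write_text(template.format(dp=dp, lr=lr))
    print(f'sbatch {script}')  # submit each generated script
```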

examples/new_project_templates/multi_node_examples/__init__.py

Whitespace-only changes.
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+#!/bin/bash
+#
+# Auto-generated by test-tube (https://github.com/williamFalcon/test-tube)
+#################
+
+# set a job name
+#SBATCH --job-name=lightning_test
+#################
+
+# a file for job output, you can check job progress
+#SBATCH --output=/slurm_output_%j.out
+#################
+
+# a file for errors
+#SBATCH --error=/slurm_output_%j.err
+#################
+
+# time needed for job
+#SBATCH --time=01:00:00
+#################
+
+# gpus per node
+#SBATCH --gres=gpu:8
+#################
+
+# cpus per job
+#SBATCH --cpus-per-task=10
+#################
+
+# number of requested nodes
+#SBATCH --nodes=2
+#################
+
+# memory per node (0 means all)
+#SBATCH --mem=0
+#################
+
+# slurm will send a signal this far out before it kills the job
+#SBATCH --signal=USR1@300
+#################
+
+# comment
+#SBATCH --comment=lightning_demo
+#################
+
+# 1 task per gpu
+#SBATCH --ntasks-per-node=8
+#################
+
+source activate YourEnv
+
+# debugging flags (optional)
+export NCCL_DEBUG=INFO
+export PYTHONFAULTHANDLER=1
+
+# random port between 12k and 20k
+export MASTER_PORT=$((12000 + RANDOM % 20000))
+
+srun python multi_node_own_slurm_script.py
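
This script starts 16 processes (2 nodes × 8 tasks), one per GPU, and exports `MASTER_PORT` for the distributed process group. Purely for orientation, a sketch of how each task's SLURM environment maps onto a global rank; Lightning derives this internally, so this is not its implementation:

```python
import os

# Per-task environment variables set by SLURM for this submission.
node_id = int(os.environ['SLURM_NODEID'])          # 0..1 with --nodes=2
local_rank = int(os.environ['SLURM_LOCALID'])      # 0..7 with --ntasks-per-node=8
tasks_per_node = int(os.environ['SLURM_NTASKS_PER_NODE'])
world_size = int(os.environ['SLURM_NTASKS'])       # 16 processes in total

global_rank = node_id * tasks_per_node + local_rank
master_port = os.environ['MASTER_PORT']            # exported by the script above

print(f'rank {global_rank}/{world_size} (node {node_id}, local {local_rank}), port {master_port}')
```
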
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+from pytorch_lightning import Trainer
+from test_tube import Experiment
+import os
+
+
+def main():
+    # use the cool model from the main README.md
+    model = CoolModel()  # noqa: F821
+    exp = Experiment(save_dir=os.getcwd())
+
+    # train on 4 GPUs across 4 nodes
+    trainer = Trainer(
+        experiment=exp,
+        distributed_backend='ddp',
+        max_nb_epochs=10,
+        gpus=4,
+        nb_gpu_nodes=4
+    )
+
+    trainer.fit(model)
+
+
+if __name__ == '__main__':
+    main()
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+#!/bin/bash -l
+
+# SLURM SUBMIT SCRIPT
+#SBATCH --nodes=4
+#SBATCH --gres=gpu:4
+#SBATCH --ntasks-per-node=4
+#SBATCH --mem=0
+#SBATCH --time=0-02:00:00
+
+# activate conda env
+conda activate my_env
+
+# run script from above
+python minimal_multi_node_demo.py
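
The resource requests here have to line up with the Trainer arguments in `minimal_multi_node_demo.py`: 4 nodes × 4 GPUs, one task per GPU. A tiny sanity-check sketch (plain variables, not Lightning API):

```python
# From minimal_multi_node_demo_script.sh
slurm_nodes = 4            # --nodes=4
slurm_gpus_per_node = 4    # --gres=gpu:4
slurm_tasks_per_node = 4   # --ntasks-per-node=4

# From the Trainer in minimal_multi_node_demo.py
trainer_gpus = 4
trainer_nodes = 4

assert slurm_nodes == trainer_nodes
assert slurm_gpus_per_node == trainer_gpus == slurm_tasks_per_node
print('world size:', trainer_nodes * trainer_gpus)  # 16 DDP processes
```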

examples/new_project_templates/multi_node_cluster_template.py renamed to examples/new_project_templates/multi_node_examples/multi_node_cluster_auto_slurm.py

Lines changed: 14 additions & 13 deletions
@@ -75,12 +75,12 @@ def main(hparams, cluster):
     # ------------------------
     # 4 INIT TRAINER
     # ------------------------
+    gpus = list(range(0, hparams.per_experiment_nb_gpus))
     trainer = Trainer(
         experiment=exp,
-        cluster=cluster,
         checkpoint_callback=checkpoint,
         early_stop_callback=early_stop,
-        gpus=hparams.gpus,
+        gpus=gpus,
         nb_gpu_nodes=hyperparams.nb_gpu_nodes
     )

@@ -99,7 +99,7 @@ def optimize_on_cluster(hyperparams):
     )

     # email for cluster coms
-    cluster.notify_job_status(email='add_email_here', on_done=True, on_fail=True)
+    cluster.notify_job_status(email=hyperparams.email, on_done=True, on_fail=True)

     # configure cluster
     cluster.per_experiment_nb_gpus = hyperparams.per_experiment_nb_gpus

@@ -109,7 +109,7 @@ def optimize_on_cluster(hyperparams):
     cluster.memory_mb_per_node = 0

     # any modules for code to run in env
-    cluster.add_command('source activate lightning')
+    cluster.add_command(f'source activate {hyperparams.conda_env}')

     # run only on 32GB voltas
     cluster.add_slurm_cmd(cmd='constraint', value='volta32gb',

@@ -121,7 +121,7 @@ def optimize_on_cluster(hyperparams):
     # creates and submits jobs to slurm
     cluster.optimize_parallel_cluster_gpu(
         main,
-        nb_trials=hyperparams.nb_hopt_trials,
+        nb_trials=hyperparams.num_hyperparam_trials,
         job_name=hyperparams.experiment_name
     )

@@ -139,15 +139,10 @@ def optimize_on_cluster(hyperparams):
     parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)

     # cluster args not defined inside the model
-    parent_parser.add_argument('--gpu_partition', type=str, help='consult your cluster manual')

-    # TODO: make 1 param
     parent_parser.add_argument('--per_experiment_nb_gpus', type=int,
-                               default=2, help='how many gpus to use in a node')
-    parent_parser.add_argument('--gpus', type=str, default='-1',
-                               help='how many gpus to use in the node')
-
-    parent_parser.add_argument('--nb_gpu_nodes', type=int, default=1,
+                               default=8, help='how many gpus to use in a node')
+    parent_parser.add_argument('--nb_gpu_nodes', type=int, default=2,
                                help='how many nodes to use in a cluster')
     parent_parser.add_argument('--test_tube_save_path', type=str, default=test_tube_dir,
                                help='where to save logs')

@@ -157,9 +152,15 @@ def optimize_on_cluster(hyperparams):
                                help='where to save model')
     parent_parser.add_argument('--experiment_name', type=str, default='pt_lightning_exp_a',
                                help='test tube exp name')
-    parent_parser.add_argument('--nb_hopt_trials', type=int, default=1,
+    parent_parser.add_argument('--num_hyperparam_trials', type=int, default=6,
                                help='how many grid search trials to run')

+    parent_parser.add_argument('--email', type=str, default='[email protected]',
+                               help='email for jobs')
+    parent_parser.add_argument('--conda_env', type=str, default='base',
+                               help='conda environment to activate for jobs')
+    parent_parser.add_argument('--gpu_partition', type=str, help='consult your cluster manual')
+
     # allow model to overwrite or extend args
     parser = LightningTemplateModel.add_model_specific_args(parent_parser, root_dir)
     hyperparams = parser.parse_args()
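
The hunks above are fragments of a test-tube `SlurmCluster` workflow. A condensed sketch of how they fit together end to end; the import path and the `SlurmCluster` constructor arguments are assumptions based on test-tube's documentation (verify against your version), and `main` here is a stub for the real training routine in this file:

```python
from test_tube import HyperOptArgumentParser
from test_tube.hpc import SlurmCluster  # import path assumed; check your test-tube version


def main(hparams, cluster):
    # Stand-in for the real main(hparams, cluster) training routine in this file.
    print('would train with', hparams)


if __name__ == '__main__':
    parser = HyperOptArgumentParser(strategy='grid_search')
    parser.add_argument('--email', type=str, default='[email protected]')  # placeholder address
    parser.add_argument('--conda_env', type=str, default='base')
    parser.add_argument('--experiment_name', type=str, default='pt_lightning_exp_a')
    parser.add_argument('--num_hyperparam_trials', type=int, default=6)
    parser.add_argument('--per_experiment_nb_gpus', type=int, default=8)
    hyperparams = parser.parse_args()

    # Constructor kwargs are an assumption based on test-tube examples.
    cluster = SlurmCluster(hyperparam_optimizer=hyperparams,
                           log_path='slurm_logs',
                           python_cmd='python')

    # The calls below all appear in the diff above.
    cluster.notify_job_status(email=hyperparams.email, on_done=True, on_fail=True)
    cluster.per_experiment_nb_gpus = hyperparams.per_experiment_nb_gpus
    cluster.add_command(f'source activate {hyperparams.conda_env}')

    # Creates and submits one SLURM job per hyperparameter combination.
    cluster.optimize_parallel_cluster_gpu(main,
                                          nb_trials=hyperparams.num_hyperparam_trials,
                                          job_name=hyperparams.experiment_name)
```
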
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+"""
+Multi-node example (GPU)
+"""
+import os
+import numpy as np
+import torch
+
+from test_tube import HyperOptArgumentParser, Experiment
+from pytorch_lightning import Trainer
+from examples.new_project_templates.lightning_module_template import LightningTemplateModel
+
+SEED = 2334
+torch.manual_seed(SEED)
+np.random.seed(SEED)
+
+
+def main(hparams):
+    """
+    Main training routine specific for this project
+    :param hparams:
+    :return:
+    """
+    # ------------------------
+    # 1 INIT LIGHTNING MODEL
+    # ------------------------
+    model = LightningTemplateModel(hparams)
+
+    # ------------------------
+    # 2 INIT TEST TUBE EXP
+    # ------------------------
+    # init experiment
+    exp = Experiment(
+        name='test_exp',
+        save_dir=hparams.log_dir,
+        autosave=False,
+        description='test demo'
+    )
+
+    # ------------------------
+    # 3 INIT TRAINER
+    # ------------------------
+    trainer = Trainer(
+        experiment=exp,
+        gpus=[0, 1, 2, 3, 4, 5, 6, 7],
+        nb_gpu_nodes=2
+    )
+
+    # ------------------------
+    # 4 START TRAINING
+    # ------------------------
+    trainer.fit(model)
+
+
+if __name__ == '__main__':
+    # use current dir for logging
+    root_dir = os.path.dirname(os.path.realpath(__file__))
+    log_dir = os.path.join(root_dir, 'pt_lightning_demo_logs')
+
+    parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
+    parent_parser.add_argument('--log_dir', type=str, default=log_dir,
+                               help='where to save logs')
+
+    # allow model to overwrite or extend args
+    parser = LightningTemplateModel.add_model_specific_args(parent_parser, root_dir)
+    hyperparams = parser.parse_args()
+
+    # ---------------------
+    # RUN TRAINING
+    # ---------------------
+    main(hyperparams)
