diff --git a/docs/zh/examples/ldc2d_unsteady.md b/docs/zh/examples/ldc2d_unsteady.md
index 4e219bb1ed..1ec57b39fd 100644
--- a/docs/zh/examples/ldc2d_unsteady.md
+++ b/docs/zh/examples/ldc2d_unsteady.md
@@ -2,6 +2,11 @@
 
 AI Studio quick experience
 
+=== "Model training command"
+    ``` sh
+    python ldc2d_unsteady_Re10.py
+    ```
+
 ## 1. Background
 
 The lid-driven cavity (LDC) flow problem appears in many fields; for example, it is often used in computational fluid dynamics (CFD) to verify the validity of numerical methods. Although its boundary conditions are relatively simple, the resulting flow is quite complex. In the lid-driven cavity, the top wall moves in the x direction with velocity U=1, while the other three walls are treated as no-slip boundaries, i.e. the velocity there is zero.
 
@@ -201,7 +206,7 @@ examples/ldc/ldc2d_unsteady_Re10.py:36:44
 
 ``` py linenums="46"
 --8<--
-examples/ldc/ldc2d_unsteady_Re10.py:46:60
+examples/ldc/ldc2d_unsteady_Re10.py:46:58
 --8<--
 ```
 
@@ -209,20 +214,16 @@ examples/ldc/ldc2d_unsteady_Re10.py:46:60
 
 Taking the `InteriorConstraint` applied to the interior points of the rectangle as an example, the code is as follows:
 
-``` py linenums="62"
+``` py linenums="60"
 # set constraint
-pde_constraint = ppsci.constraint.InteriorConstraint(
+pde = ppsci.constraint.InteriorConstraint(
     equation["NavierStokes"].equations,
     {"continuity": 0, "momentum_x": 0, "momentum_y": 0},
     geom["time_rect"],
     {**train_dataloader_cfg, "batch_size": NPOINT_PDE * NTIME_PDE},
     ppsci.loss.MSELoss("sum"),
     evenly=True,
-    weight_dict={
-        "continuity": 0.0001,  # (1)
-        "momentum_x": 0.0001,
-        "momentum_y": 0.0001,
-    },
+    weight_dict=cfg.TRAIN.weight.pde,  # (1)
     name="EQ",
 )
 ```
@@ -257,9 +258,9 @@ pde_constraint = ppsci.constraint.InteriorConstraint(
 
 Since `BoundaryConstraint` samples on all boundaries by default, while we need to impose a separate constraint on each of the four walls, the `criteria` argument is used to single out each wall; for example, the top wall is the set of boundary points satisfying $y = 0.05$.
 
-``` py linenums="77"
+``` py linenums="71"
 --8<--
-examples/ldc/ldc2d_unsteady_Re10.py:77:112
+examples/ldc/ldc2d_unsteady_Re10.py:71:106
 --8<--
 ```
 
@@ -267,27 +268,27 @@
 
 Finally, we also impose the N-S equation constraint on the interior points of the rectangle at the initial time $t=t_0$, as follows:
 
-``` py linenums="113"
+``` py linenums="107"
 --8<--
-examples/ldc/ldc2d_unsteady_Re10.py:113:121
+examples/ldc/ldc2d_unsteady_Re10.py:107:115
 --8<--
 ```
 
 After the PDE constraint, boundary constraints, and initial-value constraint are built, they are wrapped into a dictionary keyed by the names given above, for convenient access later.
 
-``` py linenums="122"
+``` py linenums="116"
 --8<--
-examples/ldc/ldc2d_unsteady_Re10.py:122:130
+examples/ldc/ldc2d_unsteady_Re10.py:116:124
 --8<--
 ```
 
 ### 3.5 Hyper-parameter settings
 
-Next we need to specify the number of training epochs and the learning rate. Based on experimental experience, we train for 20,000 epochs with a cosine-decay learning rate schedule with warmup.
+Next, the number of training epochs is specified in the configuration file. Based on experimental experience, we train for 20,000 epochs with a cosine-decay learning rate schedule with warmup.
 
-``` py linenums="132"
+``` yaml linenums="40"
 --8<--
-examples/ldc/ldc2d_unsteady_Re10.py:132:139
+examples/ldc/conf/ldc2d_unsteady_Re10.yaml:40:43
 --8<--
 ```
 
@@ -295,9 +296,9 @@
 
 The training process uses an optimizer to update the model parameters; here we choose the commonly used `Adam` optimizer.
 
-``` py linenums="141"
+``` py linenums="132"
 --8<--
-examples/ldc/ldc2d_unsteady_Re10.py:141:142
+examples/ldc/ldc2d_unsteady_Re10.py:132:133
 --8<--
 ```
 
@@ -305,9 +306,9 @@
 
 During training, the model is usually evaluated on a validation (test) set at a fixed epoch interval, so we build an evaluator with `ppsci.validate.GeometryValidator`.
 
-``` py linenums="144"
+``` py linenums="135"
 --8<--
-examples/ldc/ldc2d_unsteady_Re10.py:144:162
+examples/ldc/ldc2d_unsteady_Re10.py:135:153
 --8<--
 ```
 
@@ -329,9 +330,9 @@
 
 The output data in this example is a set of 2D points inside the domain: at each time $t$ the coordinates are $(x^t_i,y^t_i)$ and the corresponding values are $(u^t_i, v^t_i, p^t_i)$. We therefore only need to save the evaluation output as 16 **vtu** files, one per time step, and open them with a visualization tool. The code is as follows:
 
-``` py linenums="164"
+``` py linenums="155"
 --8<--
-examples/ldc/ldc2d_unsteady_Re10.py:164:195
+examples/ldc/ldc2d_unsteady_Re10.py:155:186
 --8<--
 ```
 
@@ -339,9 +340,9 @@
 
 With the above settings in place, simply pass the instantiated objects to `ppsci.solver.Solver` in order, then start training, evaluation, and visualization.
 
-``` py linenums="197"
+``` py linenums="188"
 --8<--
-examples/ldc/ldc2d_unsteady_Re10.py:197:
+examples/ldc/ldc2d_unsteady_Re10.py:188:209
 --8<--
 ```
 
diff --git a/examples/ldc/conf/ldc2d_unsteady_Re10.yaml b/examples/ldc/conf/ldc2d_unsteady_Re10.yaml
new file mode 100644
index 0000000000..20ff73f807
--- /dev/null
+++ b/examples/ldc/conf/ldc2d_unsteady_Re10.yaml
@@ -0,0 +1,59 @@
+hydra:
+  run:
+    # dynamic output directory according to running time and override name
+    dir: output_ldc2d_unsteady_Re10/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
+  job:
+    name: ${mode} # name of logfile
+    chdir: false # keep current working directory unchanged
+    config:
+      override_dirname:
+        exclude_keys:
+          - TRAIN.checkpoint_path
+          - TRAIN.pretrained_model_path
+          - EVAL.pretrained_model_path
+          - mode
+          - output_dir
+  sweep:
+    # output directory for multirun
+    dir: ${hydra.run.dir}
+    subdir: ./
+
+# general settings
+mode: train # running mode: train/eval
+seed: 42
+output_dir: ${hydra:run.dir}
+
+# set working condition
+NU: 0.01
+RHO: 1.0
+NTIME_ALL: 16
+
+# model settings
+MODEL:
+  model:
+    input_keys: ["t", "x", "y"]
+    output_keys: ["u", "v", "p"]
+    num_layers: 9
+    hidden_size: 50
+    activation: "tanh"
+
+# training settings
+TRAIN:
+  epochs: 20000
+  iters_per_epoch: 1
+  eval_during_train: true
+  eval_freq: 200
+  lr_scheduler:
+    epochs: ${TRAIN.epochs}
+    iters_per_epoch: ${TRAIN.iters_per_epoch}
+    learning_rate: 0.001
+  weight:
+    pde: {"continuity": 0.0001, "momentum_x": 0.0001, "momentum_y": 0.0001}
+  pretrained_model_path: null
+  checkpoint_path: null
+
+# evaluation settings
+EVAL:
+  pretrained_model_path: null
+  batch_size:
+    residual_validator: 8192
diff --git a/examples/ldc/ldc2d_unsteady_Re10.py b/examples/ldc/ldc2d_unsteady_Re10.py
index ac22f23de5..d7cc0e8f6f 100644
--- a/examples/ldc/ldc2d_unsteady_Re10.py
+++ b/examples/ldc/ldc2d_unsteady_Re10.py
@@ -11,30 +11,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
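+
+# NOTE: hyper-parameters for this example now come from conf/ldc2d_unsteady_Re10.yaml
+# and are injected by Hydra as the `cfg` object used below. Any existing key can be
+# overridden on the command line, e.g. (checkpoint path is a placeholder):
+#   python ldc2d_unsteady_Re10.py TRAIN.epochs=200
+#   python ldc2d_unsteady_Re10.py mode=eval EVAL.pretrained_model_path=/path/to/checkpoints/latest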
+from os import path as osp +import hydra import numpy as np +from omegaconf import DictConfig import ppsci -from ppsci.utils import config from ppsci.utils import logger -if __name__ == "__main__": - args = config.parse_args() + +def train(cfg: DictConfig): # set random seed for reproducibility - ppsci.utils.misc.set_random_seed(42) - # set output directory - OUTPUT_DIR = "./ldc2d_unsteady_Re10" if not args.output_dir else args.output_dir + ppsci.utils.misc.set_random_seed(cfg.seed) # initialize logger - logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") + logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info") # set model - model = ppsci.arch.MLP(("t", "x", "y"), ("u", "v", "p"), 9, 50, "tanh") + model = ppsci.arch.MLP(**cfg.MODEL.model) # set equation - equation = {"NavierStokes": ppsci.equation.NavierStokes(0.01, 1.0, 2, True)} + equation = {"NavierStokes": ppsci.equation.NavierStokes(cfg.NU, cfg.RHO, 2, True)} # set timestamps(including initial t0) - timestamps = np.linspace(0.0, 1.5, 16, endpoint=True) + timestamps = np.linspace(0.0, 1.5, cfg.NTIME_ALL, endpoint=True) # set time-geometry geom = { "time_rect": ppsci.geometry.TimeXGeometry( @@ -44,34 +44,28 @@ } # set dataloader config - ITERS_PER_EPOCH = 1 train_dataloader_cfg = { "dataset": "IterableNamedArrayDataset", - "iters_per_epoch": ITERS_PER_EPOCH, + "iters_per_epoch": cfg.TRAIN.iters_per_epoch, } # pde/bc constraint use t1~tn, initial constraint use t0 - NTIME_ALL = len(timestamps) - NPOINT_PDE, NTIME_PDE = 99**2, NTIME_ALL - 1 - NPOINT_TOP, NTIME_TOP = 101, NTIME_ALL - 1 - NPOINT_DOWN, NTIME_DOWN = 101, NTIME_ALL - 1 - NPOINT_LEFT, NTIME_LEFT = 99, NTIME_ALL - 1 - NPOINT_RIGHT, NTIME_RIGHT = 99, NTIME_ALL - 1 + NPOINT_PDE, NTIME_PDE = 99**2, cfg.NTIME_ALL - 1 + NPOINT_TOP, NTIME_TOP = 101, cfg.NTIME_ALL - 1 + NPOINT_DOWN, NTIME_DOWN = 101, cfg.NTIME_ALL - 1 + NPOINT_LEFT, NTIME_LEFT = 99, cfg.NTIME_ALL - 1 + NPOINT_RIGHT, NTIME_RIGHT = 99, cfg.NTIME_ALL - 1 NPOINT_IC, NTIME_IC = 99**2, 1 # set constraint - pde_constraint = ppsci.constraint.InteriorConstraint( + pde = ppsci.constraint.InteriorConstraint( equation["NavierStokes"].equations, {"continuity": 0, "momentum_x": 0, "momentum_y": 0}, geom["time_rect"], {**train_dataloader_cfg, "batch_size": NPOINT_PDE * NTIME_PDE}, ppsci.loss.MSELoss("sum"), evenly=True, - weight_dict={ - "continuity": 0.0001, - "momentum_x": 0.0001, - "momentum_y": 0.0001, - }, + weight_dict=cfg.TRAIN.weight.pde, # (1) name="EQ", ) bc_top = ppsci.constraint.BoundaryConstraint( @@ -121,7 +115,7 @@ ) # wrap constraints together constraint = { - pde_constraint.name: pde_constraint, + pde.name: pde, bc_top.name: bc_top, bc_down.name: bc_down, bc_left.name: bc_left, @@ -130,19 +124,16 @@ } # set training hyper-parameters - EPOCHS = 20000 if not args.epochs else args.epochs lr_scheduler = ppsci.optimizer.lr_scheduler.Cosine( - EPOCHS, - ITERS_PER_EPOCH, - 0.001, - warmup_epoch=int(0.05 * EPOCHS), + **cfg.TRAIN.lr_scheduler, + warmup_epoch=int(0.05 * cfg.TRAIN.epochs), )() # set optimizer optimizer = ppsci.optimizer.Adam(lr_scheduler)(model) # set validator - NPOINT_EVAL = NPOINT_PDE * NTIME_ALL + NPOINT_EVAL = NPOINT_PDE * cfg.NTIME_ALL residual_validator = ppsci.validate.GeometryValidator( equation["NavierStokes"].equations, {"momentum_x": 0, "continuity": 0, "momentum_y": 0}, @@ -150,7 +141,7 @@ { "dataset": "NamedArrayDataset", "total_size": NPOINT_EVAL, - "batch_size": 8192, + "batch_size": cfg.EVAL.batch_size.residual_validator, "sampler": {"name": "BatchSampler"}, }, 
         ppsci.loss.MSELoss("sum"),
@@ -189,7 +180,7 @@
         "visulzie_u_v": ppsci.visualize.VisualizerVtu(
             vis_points,
             {"u": lambda d: d["u"], "v": lambda d: d["v"], "p": lambda d: d["p"]},
-            num_timestamps=NTIME_ALL,
+            num_timestamps=cfg.NTIME_ALL,
             prefix="result_u_v",
         )
     }
@@ -198,13 +189,13 @@
     solver = ppsci.solver.Solver(
         model,
         constraint,
-        OUTPUT_DIR,
+        cfg.output_dir,
         optimizer,
         lr_scheduler,
-        EPOCHS,
-        ITERS_PER_EPOCH,
-        eval_during_train=True,
-        eval_freq=200,
+        cfg.TRAIN.epochs,
+        cfg.TRAIN.iters_per_epoch,
+        eval_during_train=cfg.TRAIN.eval_during_train,
+        eval_freq=cfg.TRAIN.eval_freq,
         equation=equation,
         geom=geom,
         validator=validator,
@@ -217,17 +208,117 @@
     # visualize prediction after finished training
     solver.visualize()
 
+
+def evaluate(cfg: DictConfig):
+    # set random seed for reproducibility
+    ppsci.utils.misc.set_random_seed(cfg.seed)
+    # initialize logger
+    logger.init_logger("ppsci", osp.join(cfg.output_dir, "eval.log"), "info")
+
+    # set model
+    model = ppsci.arch.MLP(**cfg.MODEL.model)
+
+    # set equation
+    equation = {"NavierStokes": ppsci.equation.NavierStokes(cfg.NU, cfg.RHO, 2, True)}
+
+    # set timestamps(including initial t0)
+    timestamps = np.linspace(0.0, 1.5, cfg.NTIME_ALL, endpoint=True)
+    # set time-geometry
+    geom = {
+        "time_rect": ppsci.geometry.TimeXGeometry(
+            ppsci.geometry.TimeDomain(0.0, 1.5, timestamps=timestamps),
+            ppsci.geometry.Rectangle((-0.05, -0.05), (0.05, 0.05)),
+        )
+    }
+
+    # pde/bc constraint use t1~tn, initial constraint use t0
+    NPOINT_PDE = 99**2
+    NPOINT_TOP = 101
+    NPOINT_DOWN = 101
+    NPOINT_LEFT = 99
+    NPOINT_RIGHT = 99
+    NPOINT_IC = 99**2
+    NTIME_PDE = cfg.NTIME_ALL - 1
+
+    # set validator
+    NPOINT_EVAL = NPOINT_PDE * cfg.NTIME_ALL
+    residual_validator = ppsci.validate.GeometryValidator(
+        equation["NavierStokes"].equations,
+        {"momentum_x": 0, "continuity": 0, "momentum_y": 0},
+        geom["time_rect"],
+        {
+            "dataset": "NamedArrayDataset",
+            "total_size": NPOINT_EVAL,
+            "batch_size": cfg.EVAL.batch_size.residual_validator,
+            "sampler": {"name": "BatchSampler"},
+        },
+        ppsci.loss.MSELoss("sum"),
+        evenly=True,
+        metric={"MSE": ppsci.metric.MSE()},
+        with_initial=True,
+        name="Residual",
+    )
+    validator = {residual_validator.name: residual_validator}
+
+    # set visualizer(optional)
+    NPOINT_BC = NPOINT_TOP + NPOINT_DOWN + NPOINT_LEFT + NPOINT_RIGHT
+    vis_initial_points = geom["time_rect"].sample_initial_interior(
+        (NPOINT_IC + NPOINT_BC), evenly=True
+    )
+    vis_pde_points = geom["time_rect"].sample_interior(
+        (NPOINT_PDE + NPOINT_BC) * NTIME_PDE, evenly=True
+    )
+    vis_points = vis_initial_points
+    # manually collate input data for visualization,
+    # (interior+boundary) x all timestamps
+    for t in range(NTIME_PDE):
+        for key in geom["time_rect"].dim_keys:
+            vis_points[key] = np.concatenate(
+                (
+                    vis_points[key],
+                    vis_pde_points[key][
+                        t
+                        * (NPOINT_PDE + NPOINT_BC) : (t + 1)
+                        * (NPOINT_PDE + NPOINT_BC)
+                    ],
+                )
+            )
+
+    visualizer = {
+        "visulzie_u_v": ppsci.visualize.VisualizerVtu(
+            vis_points,
+            {"u": lambda d: d["u"], "v": lambda d: d["v"], "p": lambda d: d["p"]},
+            num_timestamps=cfg.NTIME_ALL,
+            prefix="result_u_v",
+        )
+    }
+
     # directly evaluate pretrained model(optional)
     solver = ppsci.solver.Solver(
         model,
-        constraint,
-        OUTPUT_DIR,
+        output_dir=cfg.output_dir,
         equation=equation,
         geom=geom,
         validator=validator,
         visualizer=visualizer,
-        pretrained_model_path=f"{OUTPUT_DIR}/checkpoints/latest",
+        pretrained_model_path=cfg.EVAL.pretrained_model_path,
     )
     solver.eval()
     # visualize prediction for pretrained model(optional)
     solver.visualize()
+
+
+@hydra.main(
+    version_base=None, config_path="./conf", config_name="ldc2d_unsteady_Re10.yaml"
+)
+def main(cfg: DictConfig):
+    if cfg.mode == "train":
+        train(cfg)
+    elif cfg.mode == "eval":
+        evaluate(cfg)
+    else:
+        raise ValueError(f"cfg.mode should be in ['train', 'eval'], but got '{cfg.mode}'")
+
+
+if __name__ == "__main__":
+    main()
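As a quick illustration of how the new YAML keys map onto the refactored script, the sketch below loads `conf/ldc2d_unsteady_Re10.yaml` directly with OmegaConf and prints the values that Hydra hands to `train()` and `evaluate()`. This is a minimal sketch, assuming `omegaconf` is installed and that it is run from `examples/ldc/` with the config file above in place; it is not part of the example itself.

``` py
# Minimal sketch: inspect the Hydra config without launching training.
# Assumes omegaconf is installed and that this runs from examples/ldc/.
from omegaconf import OmegaConf

cfg = OmegaConf.load("conf/ldc2d_unsteady_Re10.yaml")

# scalar hyper-parameters consumed by train()/evaluate()
print(cfg.TRAIN.epochs)                        # 20000
print(cfg.TRAIN.weight.pde)                    # per-equation PDE loss weights
print(cfg.EVAL.batch_size.residual_validator)  # 8192

# keyword arguments forwarded to ppsci.arch.MLP(**cfg.MODEL.model)
print(OmegaConf.to_yaml(cfg.MODEL.model))
```

Keeping these values in the Hydra config removes the hard-coded constants from the Python code, and each run's output directory records the overrides used, which makes individual experiments easier to reproduce.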