From afa246ed92feaa47b54d3899dfdf92c337418bc7 Mon Sep 17 00:00:00 2001 From: co63oc Date: Sun, 22 Oct 2023 14:55:20 +0800 Subject: [PATCH 1/3] Add poiseuille_flow hydra --- examples/pipe/conf/poiseuille_flow.yaml | 73 +++++++++++++++++++++ examples/pipe/poiseuille_flow.py | 86 ++++++++++++++----------- 2 files changed, 122 insertions(+), 37 deletions(-) create mode 100644 examples/pipe/conf/poiseuille_flow.yaml diff --git a/examples/pipe/conf/poiseuille_flow.yaml b/examples/pipe/conf/poiseuille_flow.yaml new file mode 100644 index 0000000000..186ef53221 --- /dev/null +++ b/examples/pipe/conf/poiseuille_flow.yaml @@ -0,0 +1,73 @@ +hydra: + run: + # dynamic output directory according to running time and override name + dir: outputs_poiseuille_flow/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + job: + name: ${mode} # name of logfile + chdir: false # keep current working direcotry unchaned + config: + override_dirname: + exclude_keys: + - TRAIN.checkpoint_path + - TRAIN.pretrained_model_path + - EVAL.pretrained_model_path + - mode + - output_dir + - log_freq + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: train # running mode: train/eval +seed: 42 +output_dir: ${hydra:run.dir} + +# set working condition +NU_MEAN: 0.001 +NU_STD: 0.9 +L: 1.0 # length of pipe +R: 0.05 # radius of pipe +RHO: 1 # density +P_OUT: 0 # pressure at the outlet of pipe +P_IN: 0.1 # pressure at the inlet of pipe +N_x: 10 +N_y: 50 +N_p: 50 +X_IN: 0 + +# model settings +MODEL: + u_net: + input_keys: ["sin(x)", "cos(x)", "y", "nu"] + output_keys: ["u"] + num_layers: 3 + hidden_size: 50 + activation: "swish" + v_net: + input_keys: ["sin(x)", "cos(x)", "y", "nu"] + output_keys: ["v"] + num_layers: 3 + hidden_size: 50 + activation: "swish" + p_net: + input_keys: ["sin(x)", "cos(x)", "y", "nu"] + output_keys: ["p"] + num_layers: 3 + hidden_size: 50 + activation: "swish" + +# training settings +TRAIN: + epochs: 3000 + batch_size: + pde_constraint: 128 + eval_during_train: false + save_freq: 10 + pretrained_model_path: null + checkpoint_path: null + +# evaluation settings +EVAL: + pretrained_model_path: null diff --git a/examples/pipe/poiseuille_flow.py b/examples/pipe/poiseuille_flow.py index 8b270bf542..115a62d7b4 100644 --- a/examples/pipe/poiseuille_flow.py +++ b/examples/pipe/poiseuille_flow.py @@ -18,47 +18,43 @@ import copy import os -import os.path as osp +from os import path as osp +import hydra import matplotlib.pyplot as plt import numpy as np import paddle +from omegaconf import DictConfig +import ppsci from ppsci.utils import checker +from ppsci.utils import logger if not checker.dynamic_import_to_globals("seaborn"): raise ModuleNotFoundError("Please install seaborn through pip first.") import seaborn as sns -import ppsci -from ppsci.utils import config -from ppsci.utils import logger -if __name__ == "__main__": - args = config.parse_args() +def train(cfg: DictConfig): # set random seed for reproducibility - ppsci.utils.misc.set_random_seed(42) - - # set output directory - OUTPUT_DIR = "./output_poiseuille_flow" - + ppsci.utils.misc.set_random_seed(cfg.seed) # initialize logger - logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") + logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info") - NU_MEAN = 0.001 - NU_STD = 0.9 - L = 1.0 # length of pipe - R = 0.05 # radius of pipe - RHO = 1 # density - P_OUT = 0 # pressure at the outlet of pipe - P_IN = 0.1 # pressure at the inlet of pipe + NU_MEAN = cfg.NU_MEAN + NU_STD = cfg.NU_STD + L = cfg.L # length of pipe + R = cfg.R # radius of pipe + RHO = cfg.RHO # density + P_OUT = cfg.P_OUT # pressure at the outlet of pipe + P_IN = cfg.P_IN # pressure at the inlet of pipe - N_x = 10 - N_y = 50 - N_p = 50 + N_x = cfg.N_x + N_y = cfg.N_y + N_p = cfg.N_p - X_IN = 0 + X_IN = cfg.X_IN X_OUT = X_IN + L Y_START = -R Y_END = Y_START + 2 * R @@ -86,16 +82,16 @@ input_y = data_2d_xy_shuffle[:, 1].reshape(data_2d_xy_shuffle.shape[0], 1) input_nu = data_2d_xy_shuffle[:, 2].reshape(data_2d_xy_shuffle.shape[0], 1) - interior_data = {"x": input_x, "y": input_y, "nu": input_nu} + interior_data = {"x": input_x, "y": input_y, "nu": input_nu} # noqa: F841 interior_geom = ppsci.geometry.PointCloud( interior={"x": input_x, "y": input_y, "nu": input_nu}, coord_keys=("x", "y", "nu"), ) # set model - model_u = ppsci.arch.MLP(("sin(x)", "cos(x)", "y", "nu"), ("u",), 3, 50, "swish") - model_v = ppsci.arch.MLP(("sin(x)", "cos(x)", "y", "nu"), ("v",), 3, 50, "swish") - model_p = ppsci.arch.MLP(("sin(x)", "cos(x)", "y", "nu"), ("p",), 3, 50, "swish") + model_u = ppsci.arch.MLP(**cfg.MODEL.u_net) + model_v = ppsci.arch.MLP(**cfg.MODEL.v_net) + model_p = ppsci.arch.MLP(**cfg.MODEL.p_net) def input_trans(input): x, y = input["x"], input["y"] @@ -137,7 +133,7 @@ def output_trans_p(input, out): } # set constraint - BATCH_SIZE = 128 + BATCH_SIZE = cfg.TRAIN.batch_size.pde_constraint ITERS_PER_EPOCH = int((N_x * N_y * N_p) / BATCH_SIZE) pde_constraint = ppsci.constraint.InteriorConstraint( @@ -163,18 +159,16 @@ def output_trans_p(input, out): # wrap constraints together constraint = {pde_constraint.name: pde_constraint} - EPOCHS = 3000 if not args.epochs else args.epochs - # initialize solver solver = ppsci.solver.Solver( model, constraint, - OUTPUT_DIR, + cfg.output_dir, optimizer, - epochs=EPOCHS, + epochs=cfg.TRAIN.epochs, iters_per_epoch=ITERS_PER_EPOCH, - eval_during_train=False, - save_freq=10, + eval_during_train=cfg.TRAIN.eval_during_train, + save_freq=cfg.TRAIN.save_freq, equation=equation, ) @@ -189,8 +183,8 @@ def output_trans_p(input, out): } output_dict = solver.predict(input_dict, return_numpy=True) u_pred = output_dict["u"].reshape(N_y, N_x, N_p) - v_pred = output_dict["v"].reshape(N_y, N_x, N_p) - p_pred = output_dict["p"].reshape(N_y, N_x, N_p) + v_pred = output_dict["v"].reshape(N_y, N_x, N_p) # noqa: F841 + p_pred = output_dict["p"].reshape(N_y, N_x, N_p) # noqa: F841 # Analytical result, y = data_1d_y u_analytical = np.zeros([N_y, N_x, N_p]) @@ -206,7 +200,7 @@ def output_trans_p(input, out): ytext = [0.45, 0.28, 0.1, 0.01] # Plot - PLOT_DIR = osp.join(OUTPUT_DIR, "visu") + PLOT_DIR = osp.join(cfg.output_dir, "visu") os.makedirs(PLOT_DIR, exist_ok=True) plt.figure(1) plt.clf() @@ -291,3 +285,21 @@ def output_trans_p(input, out): ax1.tick_params(axis="x", labelsize=fontsize) ax1.tick_params(axis="y", labelsize=fontsize) plt.savefig(osp.join(PLOT_DIR, "pipe_unformUQ.png"), bbox_inches="tight") + + +def evaluate(cfg: DictConfig): + print("Not supported.") + + +@hydra.main(version_base=None, config_path="./conf", config_name="poiseuille_flow.yaml") +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + evaluate(cfg) + else: + raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + + +if __name__ == "__main__": + main() From f9edb2f821c93e45e9a2813514d1efc55214a438 Mon Sep 17 00:00:00 2001 From: co63oc Date: Sun, 22 Oct 2023 15:12:04 +0800 Subject: [PATCH 2/3] Fix --- docs/zh/examples/labelfree_DNN_surrogate.md | 44 +++++++++++++-------- examples/pipe/conf/poiseuille_flow.yaml | 1 + examples/pipe/poiseuille_flow.py | 5 +-- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/docs/zh/examples/labelfree_DNN_surrogate.md b/docs/zh/examples/labelfree_DNN_surrogate.md index ba1eace278..bcf09987d3 100644 --- a/docs/zh/examples/labelfree_DNN_surrogate.md +++ b/docs/zh/examples/labelfree_DNN_surrogate.md @@ -1,5 +1,17 @@ # LabelFree-DNN-Surrogate (Aneurysm flow & Pipe flow) +=== "模型训练命令" + + ``` sh + python poiseuille_flow.py + ``` + +=== "模型评估命令" + + ``` sh + python poiseuille_flow.py mode=eval EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/poiseuille_flow/poiseuille_flow_pretrained.pdparams + ``` + ## 1. 背景简介 流体动力学问题的数值模拟主要依赖于使用多项式将控制方程在空间或/和时间上离散化为有限维代数系统。由于物理的多尺度特性和对复杂几何体进行网格划分的敏感性,这样的过程对于大多数实时应用程序(例如,临床诊断和手术计划)和多查询分析(例如,优化设计和不确定性量化)。在本文中,我们提供了一种物理约束的 DL 方法,用于在不依赖任何模拟数据的情况下对流体流动进行代理建模。 具体来说,设计了一种结构化深度神经网络 (DNN) 架构来强制执行初始条件和边界条件,并将控制偏微分方程(即 Navier-Stokes 方程)纳入 DNN的损失中以驱动训练。 对与血液动力学应用相关的许多内部流动进行了数值实验,并研究了流体特性和域几何中不确定性的前向传播。结果表明,DL 代理近似与第一原理数值模拟之间的流场和前向传播不确定性非常吻合。 @@ -79,15 +91,15 @@ $$ 上式中 $f_1, f_2, f_3$ 即为 MLP 模型本身,$transform_{input}, transform_{output}$, 表示施加额外的结构化自定义层,用于施加约束和丰富输入,用 PaddleScience 代码表示如下: -``` py linenums="95" +``` py linenums="91" --8<-- -examples/pipe/poiseuille_flow.py:95:98 +examples/pipe/poiseuille_flow.py:91:93 --8<-- ``` -``` py linenums="123" +``` py linenums="118" --8<-- -examples/pipe/poiseuille_flow.py:123:129 +examples/pipe/poiseuille_flow.py:118:124 --8<-- ``` @@ -99,9 +111,9 @@ examples/pipe/poiseuille_flow.py:123:129 由于本案例使用的是 Navier-Stokes 方程的2维稳态形式,因此可以直接使用 PaddleScience 内置的 `NavierStokes`。 -``` py linenums="134" +``` py linenums="130" --8<-- -examples/pipe/poiseuille_flow.py:134:137 +examples/pipe/poiseuille_flow.py:130:132 --8<-- ``` @@ -111,9 +123,9 @@ examples/pipe/poiseuille_flow.py:134:137 本文中本案例的计算域和参数自变量 $\nu$ 由`numpy`随机数生成的点云构成,因此可以直接使用 PaddleScience 内置的点云几何 `PointCloud` 组合成空间的 `Geometry` 计算域。 -``` py linenums="67" +``` py linenums="64" --8<-- -examples/pipe/poiseuille_flow.py:67:94 +examples/pipe/poiseuille_flow.py:64:88 --8<-- ``` @@ -175,9 +187,9 @@ examples/pipe/poiseuille_flow.py:67:94 以作用在流体域内部点上的 `InteriorConstraint` 为例,代码如下: - ``` py linenums="143" + ``` py linenums="138" --8<-- - examples/pipe/poiseuille_flow.py:143:164 + examples/pipe/poiseuille_flow.py:138:159 --8<-- ``` @@ -201,9 +213,9 @@ examples/pipe/poiseuille_flow.py:67:94 训练过程会调用优化器来更新模型参数,此处选择较为常用的 `Adam` 优化器。 -``` py linenums="131" +``` py linenums="127" --8<-- -examples/pipe/poiseuille_flow.py:131:132 +examples/pipe/poiseuille_flow.py:127:127 --8<-- ``` @@ -211,9 +223,9 @@ examples/pipe/poiseuille_flow.py:131:132 完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`,然后启动训练。 -``` py linenums="167" +``` py linenums="162" --8<-- -examples/pipe/poiseuille_flow.py:167:181 +examples/pipe/poiseuille_flow.py:162:174 --8<-- ``` @@ -223,9 +235,9 @@ examples/pipe/poiseuille_flow.py:167:181 2. 当我们选取截断高斯分布的动力粘性系数 ${\nu}$ 采样(均值为 $\hat{\nu} = 10^{−3}$, 方差 $\sigma_{\nu}​=2.67×10^{−4}$),中心处速度的概率密度函数和解析解对比 -``` py linenums="185" +``` py linenums="176" --8<-- -examples/pipe/poiseuille_flow.py:185:293 +examples/pipe/poiseuille_flow.py:176:284 --8<-- ``` diff --git a/examples/pipe/conf/poiseuille_flow.yaml b/examples/pipe/conf/poiseuille_flow.yaml index 186ef53221..09a26529f6 100644 --- a/examples/pipe/conf/poiseuille_flow.yaml +++ b/examples/pipe/conf/poiseuille_flow.yaml @@ -63,6 +63,7 @@ TRAIN: epochs: 3000 batch_size: pde_constraint: 128 + learning_rate: 5.0e-3 eval_during_train: false save_freq: 10 pretrained_model_path: null diff --git a/examples/pipe/poiseuille_flow.py b/examples/pipe/poiseuille_flow.py index 115a62d7b4..b2503ae826 100644 --- a/examples/pipe/poiseuille_flow.py +++ b/examples/pipe/poiseuille_flow.py @@ -82,7 +82,6 @@ def train(cfg: DictConfig): input_y = data_2d_xy_shuffle[:, 1].reshape(data_2d_xy_shuffle.shape[0], 1) input_nu = data_2d_xy_shuffle[:, 2].reshape(data_2d_xy_shuffle.shape[0], 1) - interior_data = {"x": input_x, "y": input_y, "nu": input_nu} # noqa: F841 interior_geom = ppsci.geometry.PointCloud( interior={"x": input_x, "y": input_y, "nu": input_nu}, coord_keys=("x", "y", "nu"), @@ -125,7 +124,7 @@ def output_trans_p(input, out): model = ppsci.arch.ModelList((model_u, model_v, model_p)) # set optimizer - optimizer = ppsci.optimizer.Adam(5e-3)(model) + optimizer = ppsci.optimizer.Adam(cfg.TRAIN.learning_rate)(model) # set euqation equation = { @@ -183,8 +182,6 @@ def output_trans_p(input, out): } output_dict = solver.predict(input_dict, return_numpy=True) u_pred = output_dict["u"].reshape(N_y, N_x, N_p) - v_pred = output_dict["v"].reshape(N_y, N_x, N_p) # noqa: F841 - p_pred = output_dict["p"].reshape(N_y, N_x, N_p) # noqa: F841 # Analytical result, y = data_1d_y u_analytical = np.zeros([N_y, N_x, N_p]) From 6e1206ad19d0196a3b4fcc557dfd6d842477012e Mon Sep 17 00:00:00 2001 From: co63oc Date: Wed, 25 Oct 2023 09:43:42 +0800 Subject: [PATCH 3/3] Fix --- docs/zh/examples/labelfree_DNN_surrogate.md | 30 ++++---- examples/pipe/poiseuille_flow.py | 76 +++++++++------------ 2 files changed, 48 insertions(+), 58 deletions(-) diff --git a/docs/zh/examples/labelfree_DNN_surrogate.md b/docs/zh/examples/labelfree_DNN_surrogate.md index bcf09987d3..cd81bd64c8 100644 --- a/docs/zh/examples/labelfree_DNN_surrogate.md +++ b/docs/zh/examples/labelfree_DNN_surrogate.md @@ -91,15 +91,15 @@ $$ 上式中 $f_1, f_2, f_3$ 即为 MLP 模型本身,$transform_{input}, transform_{output}$, 表示施加额外的结构化自定义层,用于施加约束和丰富输入,用 PaddleScience 代码表示如下: -``` py linenums="91" +``` py linenums="78" --8<-- -examples/pipe/poiseuille_flow.py:91:93 +examples/pipe/poiseuille_flow.py:78:80 --8<-- ``` -``` py linenums="118" +``` py linenums="105" --8<-- -examples/pipe/poiseuille_flow.py:118:124 +examples/pipe/poiseuille_flow.py:105:111 --8<-- ``` @@ -111,9 +111,9 @@ examples/pipe/poiseuille_flow.py:118:124 由于本案例使用的是 Navier-Stokes 方程的2维稳态形式,因此可以直接使用 PaddleScience 内置的 `NavierStokes`。 -``` py linenums="130" +``` py linenums="117" --8<-- -examples/pipe/poiseuille_flow.py:130:132 +examples/pipe/poiseuille_flow.py:117:121 --8<-- ``` @@ -123,9 +123,9 @@ examples/pipe/poiseuille_flow.py:130:132 本文中本案例的计算域和参数自变量 $\nu$ 由`numpy`随机数生成的点云构成,因此可以直接使用 PaddleScience 内置的点云几何 `PointCloud` 组合成空间的 `Geometry` 计算域。 -``` py linenums="64" +``` py linenums="52" --8<-- -examples/pipe/poiseuille_flow.py:64:88 +examples/pipe/poiseuille_flow.py:52:75 --8<-- ``` @@ -187,9 +187,9 @@ examples/pipe/poiseuille_flow.py:64:88 以作用在流体域内部点上的 `InteriorConstraint` 为例,代码如下: - ``` py linenums="138" + ``` py linenums="128" --8<-- - examples/pipe/poiseuille_flow.py:138:159 + examples/pipe/poiseuille_flow.py:128:146 --8<-- ``` @@ -215,7 +215,7 @@ examples/pipe/poiseuille_flow.py:64:88 ``` py linenums="127" --8<-- -examples/pipe/poiseuille_flow.py:127:127 +examples/pipe/poiseuille_flow.py:114:114 --8<-- ``` @@ -223,9 +223,9 @@ examples/pipe/poiseuille_flow.py:127:127 完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`,然后启动训练。 -``` py linenums="162" +``` py linenums="152" --8<-- -examples/pipe/poiseuille_flow.py:162:174 +examples/pipe/poiseuille_flow.py:152:164 --8<-- ``` @@ -235,9 +235,9 @@ examples/pipe/poiseuille_flow.py:162:174 2. 当我们选取截断高斯分布的动力粘性系数 ${\nu}$ 采样(均值为 $\hat{\nu} = 10^{−3}$, 方差 $\sigma_{\nu}​=2.67×10^{−4}$),中心处速度的概率密度函数和解析解对比 -``` py linenums="176" +``` py linenums="166" --8<-- -examples/pipe/poiseuille_flow.py:176:284 +examples/pipe/poiseuille_flow.py:166:274 --8<-- ``` diff --git a/examples/pipe/poiseuille_flow.py b/examples/pipe/poiseuille_flow.py index b2503ae826..fdea0eaba8 100644 --- a/examples/pipe/poiseuille_flow.py +++ b/examples/pipe/poiseuille_flow.py @@ -42,34 +42,21 @@ def train(cfg: DictConfig): # initialize logger logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info") - NU_MEAN = cfg.NU_MEAN - NU_STD = cfg.NU_STD - L = cfg.L # length of pipe - R = cfg.R # radius of pipe - RHO = cfg.RHO # density - P_OUT = cfg.P_OUT # pressure at the outlet of pipe - P_IN = cfg.P_IN # pressure at the inlet of pipe - - N_x = cfg.N_x - N_y = cfg.N_y - N_p = cfg.N_p - - X_IN = cfg.X_IN - X_OUT = X_IN + L - Y_START = -R - Y_END = Y_START + 2 * R - NU_START = NU_MEAN - NU_MEAN * NU_STD # 0.0001 - NU_END = NU_MEAN + NU_MEAN * NU_STD # 0.1 + X_OUT = cfg.X_IN + cfg.L + Y_START = -cfg.R + Y_END = Y_START + 2 * cfg.R + NU_START = cfg.NU_MEAN - cfg.NU_MEAN * cfg.NU_STD # 0.0001 + NU_END = cfg.NU_MEAN + cfg.NU_MEAN * cfg.NU_STD # 0.1 ## prepare data with (?, 2) data_1d_x = np.linspace( - X_IN, X_OUT, N_x, endpoint=True, dtype=paddle.get_default_dtype() + cfg.X_IN, X_OUT, cfg.N_x, endpoint=True, dtype=paddle.get_default_dtype() ) data_1d_y = np.linspace( - Y_START, Y_END, N_y, endpoint=True, dtype=paddle.get_default_dtype() + Y_START, Y_END, cfg.N_y, endpoint=True, dtype=paddle.get_default_dtype() ) data_1d_nu = np.linspace( - NU_START, NU_END, N_p, endpoint=True, dtype=paddle.get_default_dtype() + NU_START, NU_END, cfg.N_p, endpoint=True, dtype=paddle.get_default_dtype() ) data_2d_xy = ( @@ -95,23 +82,23 @@ def train(cfg: DictConfig): def input_trans(input): x, y = input["x"], input["y"] nu = input["nu"] - b = 2 * np.pi / (X_OUT - X_IN) - c = np.pi * (X_IN + X_OUT) / (X_IN - X_OUT) - sin_x = X_IN * paddle.sin(b * x + c) - cos_x = X_IN * paddle.cos(b * x + c) + b = 2 * np.pi / (X_OUT - cfg.X_IN) + c = np.pi * (cfg.X_IN + X_OUT) / (cfg.X_IN - X_OUT) + sin_x = cfg.X_IN * paddle.sin(b * x + c) + cos_x = cfg.X_IN * paddle.cos(b * x + c) return {"sin(x)": sin_x, "cos(x)": cos_x, "x": x, "y": y, "nu": nu} def output_trans_u(input, out): - return {"u": out["u"] * (R**2 - input["y"] ** 2)} + return {"u": out["u"] * (cfg.R**2 - input["y"] ** 2)} def output_trans_v(input, out): - return {"v": (R**2 - input["y"] ** 2) * out["v"]} + return {"v": (cfg.R**2 - input["y"] ** 2) * out["v"]} def output_trans_p(input, out): return { "p": ( - (P_IN - P_OUT) * (X_OUT - input["x"]) / L - + (X_IN - input["x"]) * (X_OUT - input["x"]) * out["p"] + (cfg.P_IN - cfg.P_OUT) * (X_OUT - input["x"]) / cfg.L + + (cfg.X_IN - input["x"]) * (X_OUT - input["x"]) * out["p"] ) } @@ -128,12 +115,15 @@ def output_trans_p(input, out): # set euqation equation = { - "NavierStokes": ppsci.equation.NavierStokes(nu="nu", rho=RHO, dim=2, time=False) + "NavierStokes": ppsci.equation.NavierStokes( + nu="nu", rho=cfg.RHO, dim=2, time=False + ) } # set constraint - BATCH_SIZE = cfg.TRAIN.batch_size.pde_constraint - ITERS_PER_EPOCH = int((N_x * N_y * N_p) / BATCH_SIZE) + ITERS_PER_EPOCH = int( + (cfg.N_x * cfg.N_y * cfg.N_p) / cfg.TRAIN.batch_size.pde_constraint + ) pde_constraint = ppsci.constraint.InteriorConstraint( equation["NavierStokes"].equations, @@ -142,7 +132,7 @@ def output_trans_p(input, out): dataloader_cfg={ "dataset": "NamedArrayDataset", "num_workers": 1, - "batch_size": BATCH_SIZE, + "batch_size": cfg.TRAIN.batch_size.pde_constraint, "iters_per_epoch": ITERS_PER_EPOCH, "sampler": { "name": "BatchSampler", @@ -181,18 +171,18 @@ def output_trans_p(input, out): "nu": data_2d_xy[:, 2:3], } output_dict = solver.predict(input_dict, return_numpy=True) - u_pred = output_dict["u"].reshape(N_y, N_x, N_p) + u_pred = output_dict["u"].reshape(cfg.N_y, cfg.N_x, cfg.N_p) # Analytical result, y = data_1d_y - u_analytical = np.zeros([N_y, N_x, N_p]) - dP = P_IN - P_OUT + u_analytical = np.zeros([cfg.N_y, cfg.N_x, cfg.N_p]) + dP = cfg.P_IN - cfg.P_OUT - for i in range(N_p): - uy = (R**2 - data_1d_y**2) * dP / (2 * L * data_1d_nu[i] * RHO) - u_analytical[:, :, i] = np.tile(uy.reshape([N_y, 1]), N_x) + for i in range(cfg.N_p): + uy = (cfg.R**2 - data_1d_y**2) * dP / (2 * cfg.L * data_1d_nu[i] * cfg.RHO) + u_analytical[:, :, i] = np.tile(uy.reshape([cfg.N_y, 1]), cfg.N_x) fontsize = 16 - idx_X = int(round(N_x / 2)) # pipe velocity section at L/2 + idx_X = int(round(cfg.N_x / 2)) # pipe velocity section at L/2 nu_index = [3, 6, 14, 49] # pick 4 nu samples ytext = [0.45, 0.28, 0.1, 0.01] @@ -238,9 +228,9 @@ def output_trans_p(input, out): # Distribution of center velocity # Predicted result num_test = 500 - data_1d_nu_distribution = np.random.normal(NU_MEAN, 0.2 * NU_MEAN, num_test) + data_1d_nu_distribution = np.random.normal(cfg.NU_MEAN, 0.2 * cfg.NU_MEAN, num_test) data_2d_xy_test = ( - np.array(np.meshgrid((X_IN - X_OUT) / 2.0, 0, data_1d_nu_distribution)) + np.array(np.meshgrid((cfg.X_IN - X_OUT) / 2.0, 0, data_1d_nu_distribution)) .reshape(3, -1) .T ) @@ -254,7 +244,7 @@ def output_trans_p(input, out): u_max_pred = output_dict_test["u"] # Analytical result, y = 0 - u_max_a = (R**2) * dP / (2 * L * data_1d_nu_distribution * RHO) + u_max_a = (cfg.R**2) * dP / (2 * cfg.L * data_1d_nu_distribution * cfg.RHO) # Plot plt.figure(2)