diff --git a/image_segmentation/pytorch/brats_evaluation_cases.txt b/image_segmentation/pytorch/brats_evaluation_cases.txt new file mode 100644 index 000000000..547d570dc --- /dev/null +++ b/image_segmentation/pytorch/brats_evaluation_cases.txt @@ -0,0 +1,251 @@ +00000 +00009 +00016 +00024 +00028 +00031 +00035 +00045 +00046 +00051 +00070 +00078 +00085 +00087 +00088 +00089 +00099 +00102 +00104 +00106 +00107 +00110 +00117 +00123 +00126 +00127 +00130 +00136 +00137 +00140 +00144 +00152 +00160 +00185 +00192 +00194 +00211 +00212 +00214 +00217 +00218 +00231 +00243 +00251 +00266 +00270 +00274 +00283 +00286 +00292 +00294 +00304 +00311 +00313 +00334 +00336 +00350 +00353 +00366 +00375 +00376 +00379 +00391 +00399 +00402 +00419 +00423 +00426 +00430 +00433 +00436 +00442 +00464 +00480 +00485 +00495 +00500 +00514 +00516 +00549 +00558 +00583 +00586 +00587 +00597 +00602 +00605 +00607 +00630 +00639 +00641 +00654 +00693 +00709 +00715 +00716 +00727 +00733 +00735 +00753 +00757 +00759 +00772 +00775 +00777 +00782 +00791 +00793 +00810 +00818 +00834 +00836 +01001 +01004 +01005 +01018 +01037 +01043 +01050 +01051 +01058 +01063 +01066 +01072 +01079 +01081 +01088 +01089 +01096 +01101 +01110 +01118 +01123 +01124 +01137 +01141 +01142 +01146 +01147 +01152 +01162 +01169 +01170 +01173 +01181 +01183 +01191 +01198 +01206 +01208 +01209 +01210 +01211 +01212 +01225 +01231 +01237 +01238 +01239 +01240 +01262 +01265 +01273 +01274 +01282 +01283 +01284 +01289 +01294 +01296 +01305 +01306 +01312 +01316 +01319 +01320 +01326 +01328 +01331 +01333 +01335 +01339 +01340 +01342 +01345 +01349 +01350 +01356 +01357 +01359 +01361 +01363 +01367 +01375 +01387 +01388 +01389 +01398 +01399 +01402 +01423 +01425 +01440 +01443 +01450 +01461 +01465 +01467 +01476 +01480 +01481 +01488 +01489 +01490 +01492 +01494 +01496 +01500 +01501 +01507 +01513 +01518 +01521 +01522 +01524 +01530 +01538 +01540 +01547 +01560 +01562 +01566 +01567 +01582 +01584 +01593 +01594 +01599 +01601 +01608 +01609 +01613 +01614 +01626 +01634 +01651 +01652 +01653 
+01660 +01661 +01662 diff --git a/image_segmentation/pytorch/brats_evaluation_cases_0.txt b/image_segmentation/pytorch/brats_evaluation_cases_0.txt new file mode 100644 index 000000000..a88e31c0e --- /dev/null +++ b/image_segmentation/pytorch/brats_evaluation_cases_0.txt @@ -0,0 +1,251 @@ +00000 +00003 +00005 +00031 +00044 +00046 +00063 +00066 +00070 +00084 +00087 +00099 +00102 +00103 +00107 +00116 +00118 +00126 +00130 +00140 +00143 +00148 +00150 +00154 +00155 +00176 +00177 +00183 +00186 +00196 +00199 +00201 +00203 +00222 +00227 +00235 +00242 +00260 +00262 +00266 +00267 +00273 +00281 +00285 +00286 +00288 +00303 +00305 +00314 +00316 +00320 +00322 +00324 +00327 +00338 +00340 +00343 +00351 +00369 +00370 +00373 +00382 +00388 +00389 +00391 +00395 +00401 +00409 +00419 +00421 +00432 +00444 +00456 +00457 +00479 +00480 +00481 +00488 +00491 +00504 +00518 +00524 +00538 +00544 +00547 +00556 +00557 +00565 +00571 +00581 +00586 +00587 +00598 +00608 +00611 +00631 +00639 +00645 +00649 +00652 +00654 +00656 +00663 +00668 +00682 +00683 +00691 +00693 +00703 +00704 +00706 +00707 +00708 +00714 +00715 +00725 +00728 +00734 +00765 +00772 +00780 +00792 +00797 +00800 +00802 +00804 +00806 +00807 +00820 +00823 +00837 +00999 +01007 +01010 +01017 +01019 +01023 +01033 +01034 +01040 +01049 +01053 +01062 +01067 +01070 +01073 +01079 +01087 +01091 +01092 +01097 +01101 +01111 +01113 +01115 +01117 +01126 +01140 +01149 +01153 +01165 +01168 +01182 +01183 +01185 +01186 +01194 +01210 +01212 +01213 +01221 +01222 +01229 +01230 +01233 +01243 +01245 +01248 +01253 +01254 +01255 +01258 +01261 +01263 +01274 +01283 +01291 +01307 +01310 +01313 +01317 +01328 +01333 +01337 +01345 +01347 +01358 +01360 +01365 +01366 +01370 +01380 +01381 +01384 +01397 +01399 +01401 +01402 +01404 +01406 +01416 +01420 +01428 +01434 +01443 +01456 +01464 +01465 +01469 +01488 +01494 +01495 +01498 +01501 +01502 +01504 +01509 +01515 +01518 +01523 +01524 +01526 +01530 +01537 +01540 +01542 +01547 +01550 +01557 +01588 +01594 +01598 +01614 +01619 
+01629 +01633 +01635 +01637 +01653 +01658 +01662 diff --git a/image_segmentation/pytorch/brats_evaluation_cases_1.txt b/image_segmentation/pytorch/brats_evaluation_cases_1.txt new file mode 100644 index 000000000..a14e5abad --- /dev/null +++ b/image_segmentation/pytorch/brats_evaluation_cases_1.txt @@ -0,0 +1,250 @@ +00009 +00016 +00019 +00021 +00024 +00026 +00028 +00033 +00045 +00049 +00051 +00053 +00072 +00081 +00095 +00100 +00101 +00105 +00106 +00108 +00115 +00121 +00123 +00124 +00133 +00138 +00149 +00156 +00159 +00160 +00170 +00171 +00178 +00191 +00195 +00207 +00210 +00211 +00217 +00218 +00220 +00238 +00243 +00247 +00250 +00263 +00274 +00282 +00284 +00290 +00296 +00298 +00299 +00300 +00312 +00318 +00321 +00325 +00339 +00350 +00352 +00353 +00356 +00366 +00375 +00376 +00377 +00380 +00383 +00392 +00397 +00399 +00413 +00414 +00425 +00436 +00441 +00443 +00453 +00468 +00478 +00495 +00499 +00501 +00511 +00533 +00537 +00539 +00548 +00552 +00554 +00555 +00574 +00580 +00583 +00591 +00594 +00604 +00606 +00610 +00612 +00624 +00650 +00655 +00687 +00690 +00705 +00716 +00723 +00746 +00747 +00756 +00764 +00768 +00773 +00784 +00795 +00805 +00809 +00810 +00814 +01011 +01012 +01014 +01016 +01018 +01027 +01028 +01031 +01036 +01041 +01045 +01054 +01058 +01063 +01064 +01072 +01083 +01084 +01089 +01095 +01096 +01100 +01104 +01107 +01114 +01116 +01119 +01122 +01124 +01136 +01137 +01141 +01143 +01147 +01148 +01152 +01157 +01162 +01166 +01170 +01173 +01178 +01189 +01190 +01192 +01193 +01196 +01198 +01199 +01200 +01203 +01204 +01207 +01215 +01220 +01227 +01236 +01239 +01246 +01247 +01267 +01281 +01288 +01290 +01299 +01303 +01309 +01311 +01312 +01324 +01327 +01332 +01334 +01335 +01349 +01363 +01372 +01373 +01375 +01396 +01415 +01419 +01425 +01440 +01446 +01449 +01480 +01486 +01487 +01491 +01492 +01497 +01500 +01503 +01506 +01507 +01511 +01512 +01513 +01517 +01528 +01531 +01532 +01536 +01538 +01544 +01552 +01556 +01562 +01566 +01577 +01578 +01582 +01584 +01585 +01593 +01595 +01596 +01597 
+01607 +01613 +01624 +01640 +01642 +01645 +01659 +01664 +01665 +01666 diff --git a/image_segmentation/pytorch/brats_evaluation_cases_2.txt b/image_segmentation/pytorch/brats_evaluation_cases_2.txt new file mode 100644 index 000000000..41bf905a2 --- /dev/null +++ b/image_segmentation/pytorch/brats_evaluation_cases_2.txt @@ -0,0 +1,250 @@ +00025 +00048 +00056 +00060 +00071 +00096 +00097 +00112 +00127 +00128 +00132 +00139 +00146 +00151 +00152 +00158 +00162 +00172 +00185 +00192 +00206 +00214 +00231 +00234 +00237 +00241 +00246 +00249 +00259 +00269 +00283 +00289 +00294 +00297 +00301 +00304 +00334 +00341 +00348 +00364 +00386 +00400 +00404 +00410 +00412 +00417 +00429 +00430 +00431 +00440 +00442 +00466 +00469 +00485 +00498 +00500 +00507 +00512 +00528 +00530 +00540 +00551 +00561 +00568 +00569 +00570 +00572 +00575 +00576 +00579 +00582 +00590 +00601 +00607 +00613 +00616 +00620 +00622 +00623 +00630 +00636 +00659 +00667 +00675 +00679 +00680 +00689 +00709 +00718 +00724 +00727 +00733 +00740 +00744 +00757 +00758 +00759 +00774 +00775 +00778 +00788 +00799 +00808 +00818 +00824 +00831 +01001 +01002 +01005 +01013 +01043 +01044 +01046 +01055 +01057 +01059 +01060 +01061 +01071 +01082 +01102 +01112 +01120 +01123 +01129 +01132 +01135 +01138 +01146 +01151 +01154 +01159 +01160 +01161 +01164 +01169 +01172 +01177 +01179 +01180 +01191 +01206 +01209 +01214 +01216 +01223 +01224 +01232 +01237 +01240 +01242 +01244 +01250 +01256 +01262 +01265 +01268 +01270 +01271 +01272 +01275 +01278 +01280 +01284 +01286 +01289 +01298 +01308 +01316 +01320 +01322 +01323 +01326 +01336 +01338 +01340 +01342 +01343 +01351 +01352 +01353 +01354 +01356 +01361 +01364 +01368 +01369 +01378 +01379 +01387 +01392 +01394 +01405 +01407 +01412 +01423 +01424 +01436 +01441 +01447 +01448 +01450 +01451 +01452 +01459 +01460 +01467 +01468 +01470 +01473 +01475 +01479 +01482 +01485 +01489 +01490 +01493 +01514 +01516 +01520 +01525 +01539 +01549 +01553 +01554 +01560 +01561 +01564 +01567 +01571 +01574 +01576 +01579 +01583 +01591 +01592 +01599 
+01601 +01602 +01604 +01605 +01608 +01610 +01611 +01626 +01628 +01631 +01643 +01654 +01663 diff --git a/image_segmentation/pytorch/brats_evaluation_cases_3.txt b/image_segmentation/pytorch/brats_evaluation_cases_3.txt new file mode 100644 index 000000000..f1200c57d --- /dev/null +++ b/image_segmentation/pytorch/brats_evaluation_cases_3.txt @@ -0,0 +1,250 @@ +00012 +00014 +00017 +00018 +00020 +00022 +00035 +00036 +00043 +00052 +00062 +00068 +00074 +00077 +00090 +00109 +00111 +00117 +00120 +00122 +00131 +00134 +00136 +00137 +00142 +00147 +00166 +00184 +00188 +00193 +00194 +00216 +00219 +00221 +00228 +00230 +00239 +00253 +00261 +00270 +00275 +00280 +00292 +00293 +00306 +00309 +00310 +00313 +00317 +00329 +00331 +00332 +00336 +00346 +00347 +00360 +00378 +00379 +00402 +00405 +00416 +00418 +00445 +00446 +00449 +00477 +00483 +00493 +00494 +00506 +00510 +00517 +00519 +00520 +00523 +00525 +00529 +00542 +00549 +00550 +00559 +00563 +00567 +00588 +00589 +00593 +00596 +00599 +00602 +00619 +00621 +00638 +00640 +00641 +00642 +00646 +00657 +00658 +00661 +00674 +00676 +00685 +00692 +00697 +00731 +00735 +00737 +00751 +00753 +00767 +00781 +00791 +00793 +00796 +00801 +00819 +00834 +00836 +01009 +01015 +01022 +01024 +01029 +01030 +01035 +01038 +01039 +01050 +01051 +01052 +01056 +01066 +01075 +01076 +01077 +01081 +01088 +01093 +01098 +01103 +01108 +01110 +01118 +01128 +01131 +01134 +01139 +01145 +01163 +01167 +01174 +01181 +01184 +01197 +01201 +01208 +01211 +01217 +01218 +01219 +01225 +01228 +01251 +01266 +01269 +01276 +01277 +01279 +01285 +01293 +01294 +01296 +01302 +01304 +01305 +01306 +01315 +01331 +01339 +01341 +01344 +01348 +01350 +01355 +01362 +01371 +01377 +01383 +01385 +01386 +01389 +01393 +01395 +01398 +01400 +01403 +01410 +01411 +01417 +01418 +01421 +01422 +01426 +01429 +01430 +01432 +01433 +01435 +01437 +01444 +01453 +01458 +01463 +01466 +01471 +01481 +01483 +01499 +01533 +01534 +01543 +01545 +01546 +01548 +01551 +01555 +01558 +01568 +01569 +01570 +01572 +01573 +01575 +01581 
+01600 +01606 +01609 +01615 +01618 +01621 +01623 +01625 +01632 +01638 +01641 +01644 +01646 +01647 +01651 +01652 diff --git a/image_segmentation/pytorch/brats_evaluation_cases_4.txt b/image_segmentation/pytorch/brats_evaluation_cases_4.txt new file mode 100644 index 000000000..ea176a70b --- /dev/null +++ b/image_segmentation/pytorch/brats_evaluation_cases_4.txt @@ -0,0 +1,250 @@ +00002 +00006 +00008 +00011 +00030 +00032 +00054 +00058 +00059 +00061 +00064 +00078 +00085 +00088 +00089 +00094 +00098 +00104 +00110 +00113 +00144 +00157 +00165 +00167 +00187 +00204 +00209 +00212 +00233 +00236 +00240 +00251 +00254 +00258 +00271 +00291 +00311 +00328 +00344 +00349 +00359 +00367 +00371 +00387 +00390 +00403 +00406 +00407 +00423 +00426 +00433 +00448 +00451 +00452 +00454 +00455 +00459 +00464 +00470 +00472 +00496 +00502 +00505 +00513 +00514 +00516 +00526 +00532 +00543 +00545 +00558 +00577 +00578 +00584 +00597 +00605 +00615 +00618 +00625 +00626 +00628 +00651 +00677 +00684 +00686 +00688 +00694 +00698 +00729 +00730 +00732 +00736 +00739 +00742 +00750 +00760 +00777 +00782 +00787 +00789 +00803 +00811 +00816 +00828 +00830 +00838 +00839 +00840 +01000 +01003 +01004 +01008 +01020 +01021 +01025 +01026 +01032 +01037 +01042 +01047 +01048 +01065 +01068 +01069 +01074 +01078 +01080 +01085 +01086 +01090 +01094 +01099 +01105 +01106 +01109 +01121 +01125 +01127 +01130 +01133 +01142 +01144 +01150 +01155 +01156 +01158 +01171 +01175 +01176 +01187 +01188 +01195 +01202 +01205 +01226 +01231 +01234 +01235 +01238 +01241 +01249 +01252 +01257 +01259 +01260 +01264 +01273 +01282 +01287 +01292 +01295 +01297 +01300 +01301 +01314 +01318 +01319 +01321 +01325 +01329 +01330 +01346 +01357 +01359 +01367 +01374 +01376 +01382 +01388 +01390 +01391 +01408 +01409 +01413 +01414 +01427 +01431 +01438 +01439 +01442 +01445 +01454 +01455 +01457 +01461 +01462 +01472 +01474 +01476 +01477 +01478 +01484 +01496 +01505 +01508 +01510 +01519 +01521 +01522 +01527 +01529 +01535 +01541 +01559 +01563 +01565 +01580 +01586 +01587 +01589 +01590 
+01603
+01612
+01616
+01617
+01620
+01622
+01627
+01630
+01634
+01636
+01639
+01648
+01649
+01650
+01655
+01656
+01657
+01660
+01661
diff --git a/image_segmentation/pytorch/data_loading/data_loader/split_dataset.py b/image_segmentation/pytorch/data_loading/data_loader/split_dataset.py
new file mode 100644
index 000000000..e6410237e
--- /dev/null
+++ b/image_segmentation/pytorch/data_loading/data_loader/split_dataset.py
@@ -0,0 +1,24 @@
+from sklearn.model_selection import KFold
+import tensorflow.io as io
+# importing random module
+import random
+import os
+import re
+def load_data(path, files_pattern):
+    data = sorted(io.gfile.glob((os.path.join(path, files_pattern))))
+    assert len(data) > 0, f"Found no data at {path}"
+    return data
+
+def fold_split(path: str, fold_num : int):
+    imgs = load_data(path, "*_x.npy")
+    split = KFold(n_splits=fold_num, random_state=3, shuffle=True)
+
+    for idx, item in enumerate(split.split(imgs)):
+        with open('brats_evaluation_cases_{}.txt'.format(idx), 'w') as f:
+            for i in item[1]:
+                filename = imgs[i]
+                f.write(re.split('/', filename)[-1].split('_')[1])
+                f.write('\n')
+
+if __name__ == "__main__":
+    fold_split("gs://mlperf-dataset/data/2021_Brats_np/11_3d", 5)
\ No newline at end of file
diff --git a/image_segmentation/pytorch/data_loading/data_loader/unet3d_data_loader.py b/image_segmentation/pytorch/data_loading/data_loader/unet3d_data_loader.py index 3060d288c..fcca56355 100644 --- a/image_segmentation/pytorch/data_loading/data_loader/unet3d_data_loader.py +++ b/image_segmentation/pytorch/data_loading/data_loader/unet3d_data_loader.py @@ -9,16 +9,11 @@ from data_loading.pytorch_loader import PytTrain, PytVal from runtime.logging import mllog_event from torch.utils.data import Dataset - - -def list_files_with_pattern(path, files_pattern): - data = sorted(glob.glob(os.path.join(path, files_pattern))) - assert len(data) > 0, f"Found no data at {path}" - return data - +import glob +import tensorflow.io as io def 
load_data(path, files_pattern): - data = sorted(glob.glob(os.path.join(path, files_pattern))) + data = sorted(io.gfile.glob((os.path.join(path, files_pattern)))) assert len(data) > 0, f"Found no data at {path}" return data @@ -34,13 +29,19 @@ def split_eval_data(x_val, y_val, num_shards, shard_id): y = [a.tolist() for a in np.array_split(y_val, num_shards)] return x[shard_id], y[shard_id] +def get_data_split(path: str, num_shards: int, shard_id: int, use_brats: bool, foldidx: int): + if use_brats: + listfile = "brats_evaluation_cases_{}.txt".format(foldidx) + else: + listfile = "evaluation_cases.txt" -def get_data_split(path: str, num_shards: int, shard_id: int): - with open("evaluation_cases.txt", "r") as f: + with open(listfile, "r") as f: val_cases_list = f.readlines() val_cases_list = [case.rstrip("\n") for case in val_cases_list] imgs = load_data(path, "*_x.npy") lbls = load_data(path, "*_y.npy") + imgs = [name.split('/')[-1] for name in imgs] + lbls = [name.split('/')[-1] for name in lbls] assert len(imgs) == len(lbls), f"Found {len(imgs)} volumes but {len(lbls)} corresponding masks" imgs_train, lbls_train, imgs_val, lbls_val = [], [], [], [] for (case_img, case_lbl) in zip(imgs, lbls): @@ -50,12 +51,11 @@ def get_data_split(path: str, num_shards: int, shard_id: int): else: imgs_train.append(case_img) lbls_train.append(case_lbl) + mllog_event(key="train_samples", value=len(imgs_train), sync=False) mllog_event(key="eval_samples", value=len(imgs_val), sync=False) imgs_val, lbls_val = split_eval_data(imgs_val, lbls_val, num_shards, shard_id) return imgs_train, imgs_val, lbls_train, lbls_val - - class SyntheticDataset(Dataset): def __init__( self, @@ -105,22 +105,22 @@ def get_data_loaders(flags: Namespace, num_shards: int, global_rank: int, device :rtype: Union[Tuple[pl.MpDeviceLoader, pl.MpDeviceLoader], Tuple[DataLoader, DataLoader]] """ if flags.loader == "synthetic": - train_dataset = SyntheticDataset(scalar=True, shape=flags.input_shape, 
layout=flags.layout) - val_dataset = SyntheticDataset( + train_dataset = SyntheticDataset(channels_in=4, channels_out=4, scalar=True, shape=flags.input_shape, layout=flags.layout) + val_dataset = SyntheticDataset(channels_in=4, channels_out=4, scalar=True, shape=flags.val_input_shape, layout=flags.layout ) elif flags.loader == "pytorch": x_train, x_val, y_train, y_val = get_data_split( - flags.data_dir, num_shards, shard_id=global_rank + flags.data_dir, num_shards, shard_id=global_rank, use_brats=flags.use_brats, foldidx = flags.fold_idx ) train_data_kwargs = { "patch_size": flags.input_shape, "oversampling": flags.oversampling, "seed": flags.seed, } - train_dataset = PytTrain(x_train, y_train, **train_data_kwargs) - val_dataset = PytVal(x_val, y_val) + train_dataset = PytTrain(x_train, y_train, flags.data_dir, **train_data_kwargs) + val_dataset = PytVal(x_val, y_val, flags.data_dir) else: raise ValueError(f"Loader {flags.loader} unknown. Valid loaders are: synthetic, pytorch") diff --git a/image_segmentation/pytorch/data_loading/data_loader/xla_data_loader.py b/image_segmentation/pytorch/data_loading/data_loader/xla_data_loader.py index f8baee486..2de29e484 100644 --- a/image_segmentation/pytorch/data_loading/data_loader/xla_data_loader.py +++ b/image_segmentation/pytorch/data_loading/data_loader/xla_data_loader.py @@ -52,7 +52,7 @@ def get_data_loaders( val_loader = DataLoader( val_dataset, batch_size=1, - shuffle=not flags.benchmark and val_sampler is None, + shuffle=flags.use_brats and not flags.benchmark and val_sampler is None, sampler=val_sampler, num_workers=flags.num_workers, pin_memory=False, diff --git a/image_segmentation/pytorch/data_loading/pytorch_loader.py b/image_segmentation/pytorch/data_loading/pytorch_loader.py index bb71d32f7..841c92a67 100644 --- a/image_segmentation/pytorch/data_loading/pytorch_loader.py +++ b/image_segmentation/pytorch/data_loading/pytorch_loader.py @@ -1,9 +1,11 @@ import random import numpy as np +import os import 
scipy.ndimage from torch.utils.data import Dataset from torchvision import transforms - +import gcsfs +fs = gcsfs.GCSFileSystem() def get_train_transforms(): rand_flip = RandFlip() @@ -135,7 +137,8 @@ def __call__(self, data): class PytTrain(Dataset): - def __init__(self, images, labels, **kwargs): + def __init__(self, images, labels, dataset, **kwargs): + self.dataset = dataset self.images, self.labels = images, labels self.train_transforms = get_train_transforms() patch_size, oversampling = kwargs["patch_size"], kwargs["oversampling"] @@ -146,24 +149,20 @@ def __len__(self): return len(self.images) def __getitem__(self, idx): - data = {"image": np.load(self.images[idx]), "label": np.load(self.labels[idx])} + with fs.open(os.path.join(self.dataset, self.images[idx]), 'rb') as f, fs.open(os.path.join(self.dataset, self.labels[idx]), 'rb') as g: + data = {"image": np.load(f), "label": np.load(g)} data = self.rand_crop(data) - data = self.train_transforms(data) + data = self.train_transforms(data) return data["image"], data["label"] - -class PytVal(Dataset): - def __init__(self, images, labels): +class PytVal(Dataset): + def __init__(self, images, labels, dataset): self.images, self.labels = images, labels + self.dataset = dataset - def __len__(self): - return len(self.images) - - def __getitem__(self, idx): - return np.load(self.images[idx]), np.load(self.labels[idx]) - - - - - + def __len__(self): + return len(self.images) + def __getitem__(self, idx): + with fs.open(os.path.join(self.dataset, self.images[idx]), 'rb') as f, fs.open(os.path.join(self.dataset, self.labels[idx]), 'rb') as g: + return np.load(f), np.load(g) \ No newline at end of file diff --git a/image_segmentation/pytorch/main.py b/image_segmentation/pytorch/main.py index 8f94aaadb..fbf14e591 100644 --- a/image_segmentation/pytorch/main.py +++ b/image_segmentation/pytorch/main.py @@ -8,7 +8,8 @@ from mlperf_logging.mllog import constants from data_loading.data_loader.unet3d_data_loader import 
get_data_loaders -from model.losses import DiceCELoss, DiceScore +from model.losses import DiceCELoss, DiceScore, LossBraTS +from model.metrics import Dice as DiceMetric from model.unet3d import Unet3D from runtime.arguments import PARSER from runtime.callbacks import get_callbacks @@ -53,7 +54,7 @@ def main(local_rank, flags): world_size = get_world_size() local_rank = get_rank() worker_seeds, shuffling_seeds = setup_seeds( - master_seed=flags.seed, epochs=flags.epochs, device=device + master_seed= flags.seed if flags.seed != -1 else None, epochs=flags.epochs, device=device ) worker_seed = worker_seeds[local_rank] seed_everything(worker_seed) @@ -68,8 +69,12 @@ def main(local_rank, flags): mlperf_run_param_log(flags) callbacks = get_callbacks(flags, dllogger, local_rank, world_size) + + if flags.use_brats: + model = Unet3D(4, 4, normalization=flags.normalization, activation=flags.activation) + else: + model = Unet3D(1, 3, normalization=flags.normalization, activation=flags.activation) flags.seed = worker_seed - model = Unet3D(1, 3, normalization=flags.normalization, activation=flags.activation) if flags.use_fsdp: import torch_xla.core.xla_model as xm @@ -123,10 +128,6 @@ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): model = model.to(device) - param_nums = sum(p.numel() for p in model.parameters().values()) - - mllog_event(key="per-TPU (sharded) parameter num", value=param_nums) - mllog_end(key=constants.INIT_STOP, sync=True) mllog_start(key=constants.RUN_START, sync=True) mllog_event(key="training_params", value=str(flags), sync=True) @@ -152,12 +153,14 @@ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): use_softmax=True, layout=flags.layout, include_background=flags.include_background, + num_classes=4 if flags.use_brats else 3 ) score_fn = DiceScore( to_onehot_y=True, use_argmax=True, layout=flags.layout, include_background=flags.include_background, + num_classes=4 if flags.use_brats else 3 ) if flags.exec_mode 
== "train": @@ -183,13 +186,11 @@ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): print("Invalid exec_mode.") pass - if __name__ == "__main__": flags = PARSER.parse_args() # record the program start time, which is later used for # calculating the training start-up time flags.program_start_time = time.time() - if flags.device == "xla": xmp.spawn(xla_main, args=(flags,)) elif flags.device == "cuda": diff --git a/image_segmentation/pytorch/model/losses.py b/image_segmentation/pytorch/model/losses.py index dec82a55f..238a36a89 100644 --- a/image_segmentation/pytorch/model/losses.py +++ b/image_segmentation/pytorch/model/losses.py @@ -2,7 +2,6 @@ import torch.nn as nn import torch.nn.functional as F - class Dice: def __init__(self, to_onehot_y: bool = True, @@ -10,7 +9,8 @@ def __init__(self, use_softmax: bool = True, use_argmax: bool = False, include_background: bool = False, - layout: str = "NCDHW"): + layout: str = "NCDHW", + num_classes: int = 3): self.include_background = include_background self.to_onehot_y = to_onehot_y self.to_onehot_x = to_onehot_x @@ -24,9 +24,11 @@ def __call__(self, prediction, target): if self.layout == "NCDHW": channel_axis = 1 reduce_axis = list(range(2, len(prediction.shape))) + num_classes = prediction.shape[1] else: channel_axis = -1 reduce_axis = list(range(1, len(prediction.shape) - 1)) + num_classes = prediction.shape[-1] num_pred_ch = prediction.shape[channel_axis] if self.use_softmax: @@ -35,10 +37,10 @@ def __call__(self, prediction, target): prediction = torch.argmax(prediction, dim=channel_axis) if self.to_onehot_y: - target = to_one_hot(target, self.layout, channel_axis) + target = to_one_hot(target, self.layout, channel_axis, num_classes) if self.to_onehot_x: - prediction = to_one_hot(prediction, self.layout, channel_axis) + prediction = to_one_hot(prediction, self.layout, channel_axis, num_classes) if not self.include_background: assert num_pred_ch > 1, \ @@ -60,20 +62,20 @@ def __call__(self, 
prediction, target): return (2.0 * intersection + self.smooth_nr) / (target_sum + prediction_sum + self.smooth_dr) -def to_one_hot(array, layout, channel_axis): +def to_one_hot(array, layout, channel_axis, num_classes): if len(array.shape) >= 5: array = torch.squeeze(array, dim=channel_axis) - array = F.one_hot(array.long(), num_classes=3) + array = F.one_hot(array.long(), num_classes=num_classes) if layout == "NCDHW": - array = array.permute(0, 4, 1, 2, 3).float() + array = array.permute(0, 4, 1, 2, 3) return array class DiceCELoss(nn.Module): - def __init__(self, to_onehot_y, use_softmax, layout, include_background): + def __init__(self, to_onehot_y, use_softmax, layout, include_background, num_classes): super(DiceCELoss, self).__init__() self.dice = Dice(to_onehot_y=to_onehot_y, use_softmax=use_softmax, layout=layout, - include_background=include_background) + include_background=include_background, num_classes=num_classes) self.cross_entropy = nn.CrossEntropyLoss() def forward(self, y_pred, y_true): @@ -84,9 +86,25 @@ def forward(self, y_pred, y_true): class DiceScore: def __init__(self, to_onehot_y: bool = True, use_argmax: bool = True, layout: str = "NCDHW", - include_background: bool = False): + include_background: bool = False, num_classes = 3): self.dice = Dice(to_onehot_y=to_onehot_y, to_onehot_x=True, use_softmax=False, - use_argmax=use_argmax, layout=layout, include_background=include_background) + use_argmax=use_argmax, layout=layout, include_background=include_background, num_classes=num_classes) def __call__(self, y_pred, y_true): return torch.mean(self.dice(y_pred, y_true), dim=0) + +from monai.losses import DiceLoss +class LossBraTS(nn.Module): + def __init__(self): + super(LossBraTS, self).__init__() + self.dice = DiceLoss(sigmoid=True, batch=True) + self.ce = nn.BCEWithLogitsLoss() + + def _loss(self, p, y): + return self.dice(p, y) + self.ce(p, y.float()) + + def forward(self, p, y): + y_wt, y_tc, y_et = y > 0, ((y == 1) + (y == 3)) > 0, y == 3 + 
p_wt, p_tc, p_et = p[:, 0].unsqueeze(1), p[:, 1].unsqueeze(1), p[:, 2].unsqueeze(1) + l_wt, l_tc, l_et = self._loss(p_wt, y_wt), self._loss(p_tc, y_tc), self._loss(p_et, y_et) + return l_wt + l_tc + l_et \ No newline at end of file diff --git a/image_segmentation/pytorch/model/metrics.py b/image_segmentation/pytorch/model/metrics.py new file mode 100644 index 000000000..4c60c9648 --- /dev/null +++ b/image_segmentation/pytorch/model/metrics.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import torch
+from torchmetrics import Metric
+
+
+class Dice(Metric):
+    full_state_update = False
+
+    def __init__(self, n_class, brats):
+        super().__init__(dist_sync_on_step=False)
+        self.n_class = n_class
+        self.brats = brats
+        self.add_state("steps", default=torch.zeros(1), dist_reduce_fx="sum")
+        self.add_state("dice", default=torch.zeros((n_class,)), dist_reduce_fx="sum")
+        self.add_state("loss", default=torch.zeros(1), dist_reduce_fx="sum")
+
+    def __call__(self, p, y):
+        return self.compute_stats_brats(p, y)
+
+    def update(self, p, y, l):
+        self.steps += 1
+        self.dice += self.compute_stats_brats(p, y) if self.brats else self.compute_stats(p, y)
+        self.loss += l
+
+    def compute(self):
+        return 100 * self.dice / self.steps, self.loss / self.steps
+
+    def compute_stats_brats(self, p, y):
+        scores = torch.zeros(self.n_class, device=p.device, dtype=torch.float32)
+        p = (torch.sigmoid(p) > 0.5).int()
+        y_wt, y_tc, y_et = y > 0, ((y == 1) + (y == 3)) > 0, y == 3
+        y = torch.stack([y_wt, y_tc, y_et], dim=1)
+
+        for i in range(self.n_class):
+            p_i, y_i = p[:, i], y[:, i]
+            if y_i.dim() == 5 and y_i.shape[0] == 1:
+                y_i = torch.squeeze(y_i, 0)
+            else:
+                assert("unexpected label shape")
+            if (y_i != 1).all():
+                # no foreground class
+                scores[i] += 1 if (p_i != 1).all() else 0
+                continue
+            tp, fn, fp = self.get_stats(p_i, y_i, 1)
+            denom = (2 * tp + fp + fn).to(torch.float)
+            score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0
+            scores[i] += score_cls
+        return scores
+
+    def compute_stats(self, p, y):
+        scores = torch.zeros(self.n_class, device=p.device, dtype=torch.float32)
+        p = torch.argmax(p, dim=1)
+        for i in range(1, self.n_class + 1):
+            if (y != i).all():
+                # no foreground class
+                scores[i - 1] += 1 if (p != i).all() else 0
+                continue
+            tp, fn, fp = self.get_stats(p, y, i)
+            denom = (2 * tp + fp + fn).to(torch.float)
+            score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0
+            scores[i - 1] += 
score_cls + return scores + + @staticmethod + def get_stats(p, y, c): + tp = torch.logical_and(p == c, y == c).sum() + fn = torch.logical_and(p != c, y == c).sum() + fp = torch.logical_and(p == c, y != c).sum() + return tp, fn, fp \ No newline at end of file diff --git a/image_segmentation/pytorch/requirements.txt b/image_segmentation/pytorch/requirements.txt index 6ac350dcf..09137f6df 100644 --- a/image_segmentation/pytorch/requirements.txt +++ b/image_segmentation/pytorch/requirements.txt @@ -1,4 +1,7 @@ git+https://github.com/NVIDIA/dllogger https://github.com/mlcommons/logging/archive/refs/tags/1.1.0-rc4.zip nibabel==3.2.1 -scipy==1.5.2 \ No newline at end of file +scipy==1.5.2 +tensorflow-cpu==2.9.1 +monai +torchmetrics \ No newline at end of file diff --git a/image_segmentation/pytorch/run_brats.sh b/image_segmentation/pytorch/run_brats.sh new file mode 100755 index 000000000..7160ebe47 --- /dev/null +++ b/image_segmentation/pytorch/run_brats.sh @@ -0,0 +1,71 @@ +#!/bin/bash +set -e + +# runs benchmark and reports time to convergence +# to use the script: +# run_and_time.sh + +SEED=${1:--1} + +MAX_EPOCHS=${10:-1000} +QUALITY_THRESHOLD=${9:-"0.84"} +START_EVAL_AT=${6:-200} +EVALUATE_EVERY=${11:-20} +LEARNING_RATE=${2:-"3e-4"} +OPTMIZER=${3:-"adam"} +INIT_LEARNING_RATE=${4:-"3e-4"} +LR_WARMUP_EPOCHS=${5:-5} +DATASET_DIR="gs://mlperf-dataset/data/2021_Brats_np/11_3d" +BATCH_SIZE=1 +GRADIENT_ACCUMULATION_STEPS=1 +ACTIVATION=${7:-"leaky_relu"} +FOLD=${8:-0} + +if [ true ] +then + # start timing + start=$(date +%s) + start_fmt=$(date +%Y-%m-%d\ %r) + echo "STARTING TIMING RUN AT $start_fmt" + +# # CLEAR YOUR CACHE HERE +# python3 -c " +# from mlperf_logging.mllog import constants +# from runtime.logging import mllog_event +# mllog_event(key=constants.CACHE_CLEAR, value=True)" + + PJRT_DEVICE=TPU python3 main.py --data_dir ${DATASET_DIR} \ + --tb_dir "" \ + --epochs ${MAX_EPOCHS} \ + --evaluate_every ${EVALUATE_EVERY} \ + --start_eval_at ${START_EVAL_AT} \ + 
--quality_threshold ${QUALITY_THRESHOLD} \ + --batch_size ${BATCH_SIZE} \ + --optimizer ${OPTMIZER} \ + --activation ${ACTIVATION} \ + --ga_steps ${GRADIENT_ACCUMULATION_STEPS} \ + --init_learning_rate ${INIT_LEARNING_RATE} \ + --learning_rate ${LEARNING_RATE} \ + --seed ${SEED} \ + --lr_warmup_epochs ${LR_WARMUP_EPOCHS} \ + --use_brats \ + --fold_idx ${FOLD} \ + --input_shape 128 128 128 \ + --profile_port 9229 \ + --device xla 2>&1 | tee -a ~/result.txt + + # end timing + end=$(date +%s) + end_fmt=$(date +%Y-%m-%d\ %r) + echo "ENDING TIMING RUN AT $end_fmt" + + + # report result + result=$(( $end - $start )) + result_name="image_segmentation" + + + echo "RESULT,$result_name,$SEED,$result,$USER,$start_fmt" +else + echo "Directory ${DATASET_DIR} does not exist" +fi \ No newline at end of file diff --git a/image_segmentation/pytorch/run_kitts.sh b/image_segmentation/pytorch/run_kitts.sh new file mode 100755 index 000000000..bbdbd56a2 --- /dev/null +++ b/image_segmentation/pytorch/run_kitts.sh @@ -0,0 +1,69 @@ +#!/bin/bash +set -e + +# runs benchmark and reports time to convergence +# to use the script: +# run_and_time.sh + +SEED=${1:--1} + +MAX_EPOCHS=3000 +QUALITY_THRESHOLD="0.908" +START_EVAL_AT=${6:-1000} +EVALUATE_EVERY=20 +LEARNING_RATE=${2:-"0.8"} +OPTMIZER=${3:-"sgd"} +INIT_LEARNING_RATE=${4:-"1e-4"} +LR_WARMUP_EPOCHS=${5:-1000} +DATASET_DIR="gs://mlperf-dataset/data/kits19" +BATCH_SIZE=1 +GRADIENT_ACCUMULATION_STEPS=1 +ACTIVATION=${7:-"relu"} +FOLD=${8:-2} + +if [ true ] +then + # start timing + start=$(date +%s) + start_fmt=$(date +%Y-%m-%d\ %r) + echo "STARTING TIMING RUN AT $start_fmt" + +# # CLEAR YOUR CACHE HERE +# python3 -c " +# from mlperf_logging.mllog import constants +# from runtime.logging import mllog_event +# mllog_event(key=constants.CACHE_CLEAR, value=True)" + + PJRT_DEVICE=TPU python3 main.py --data_dir ${DATASET_DIR} \ + --tb_dir "" \ + --epochs ${MAX_EPOCHS} \ + --evaluate_every ${EVALUATE_EVERY} \ + --start_eval_at ${START_EVAL_AT} \ + 
--quality_threshold ${QUALITY_THRESHOLD} \ + --batch_size ${BATCH_SIZE} \ + --optimizer ${OPTMIZER} \ + --activation ${ACTIVATION} \ + --ga_steps ${GRADIENT_ACCUMULATION_STEPS} \ + --init_learning_rate ${INIT_LEARNING_RATE} \ + --learning_rate ${LEARNING_RATE} \ + --seed ${SEED} \ + --lr_warmup_epochs ${LR_WARMUP_EPOCHS} \ + --num_workers 8 \ + --input_shape 128 128 128 \ + --device xla 2>&1 | tee -a ~/result.txt + + # end timing + end=$(date +%s) + end_fmt=$(date +%Y-%m-%d\ %r) + echo "ENDING TIMING RUN AT $end_fmt" + + + # report result + result=$(( $end - $start )) + result_name="image_segmentation" + + + echo "RESULT,$result_name,$SEED,$result,$USER,$start_fmt" +else + echo "Directory ${DATASET_DIR} does not exist" +fi \ No newline at end of file diff --git a/image_segmentation/pytorch/runtime/arguments.py b/image_segmentation/pytorch/runtime/arguments.py index 58414e035..f04f3131a 100644 --- a/image_segmentation/pytorch/runtime/arguments.py +++ b/image_segmentation/pytorch/runtime/arguments.py @@ -5,6 +5,7 @@ PARSER.add_argument('--data_dir', dest='data_dir', required=True) PARSER.add_argument('--log_dir', dest='log_dir', type=str, default="/tmp") +PARSER.add_argument('--tb_dir', dest='tb_dir', type=str, default="/tmp/tb") PARSER.add_argument('--save_ckpt_path', dest='save_ckpt_path', type=str, default="") PARSER.add_argument('--load_ckpt_path', dest='load_ckpt_path', type=str, default="") PARSER.add_argument('--loader', dest='loader', default="pytorch", type=str) @@ -55,3 +56,7 @@ PARSER.add_argument('--use_nested_fsdp', action='store_true', default=False) PARSER.add_argument('--use_grad_ckpt', action='store_true', default=False) PARSER.add_argument('--use_bf16', action='store_true', default=True) + +PARSER.add_argument('--use_brats', action='store_true', default=False) +PARSER.add_argument('--fold_idx', dest='fold_idx', type=int, default=0) +PARSER.add_argument('--deep_supervision', action='store_true', default=False) diff --git 
a/image_segmentation/pytorch/runtime/inference.py b/image_segmentation/pytorch/runtime/inference.py index b80fe0719..8434e1261 100644 --- a/image_segmentation/pytorch/runtime/inference.py +++ b/image_segmentation/pytorch/runtime/inference.py @@ -11,6 +11,9 @@ reduce_tensor, ) +from runtime.logging import ( + mllog_event +) def evaluate(flags, model, loader, loss_fn, score_fn, device, epoch=0, is_distributed=False): rank = get_rank() @@ -44,26 +47,43 @@ def evaluate(flags, model, loader, loss_fn, score_fn, device, epoch=0, is_distri model=model, overlap=flags.overlap, mode="gaussian", - padding_val=-2.2 + padding_val=0 if flags.use_brats else -2.2, + out_dim=4 if flags.use_brats else 3 ) eval_loss_value = loss_fn(output, label) scores.append(score_fn(output, label)) - eval_loss.append(eval_loss_value) - del output - del label - - scores = reduce_tensor(torch.mean(torch.stack(scores, dim=0), dim=0), world_size) - eval_loss = reduce_tensor(torch.mean(torch.stack(eval_loss, dim=0), dim=0), world_size) + eval_loss.append(eval_loss_value) + del output + del label + + mllog_event( + key="eval_step", + value=eval_loss_value, + metadata={ + "iteration_num": i, + }, + sync=False, + ) + + scores = reduce_tensor(torch.mean(torch.stack(scores, dim=0), dim=0)) + eval_loss = reduce_tensor(torch.mean(torch.stack(eval_loss, dim=0), dim=0)) # scores = torch.mean(torch.stack(scores, dim=0), dim=0) # eval_loss = torch.mean(torch.stack(eval_loss, dim=0), dim=0) scores, eval_loss = scores.cpu().numpy(), float(eval_loss.cpu().numpy()) - eval_metrics = {"epoch": epoch, - "L1 dice": scores[-2], - "L2 dice": scores[-1], - "mean_dice": (scores[-1] + scores[-2]) / 2, - "eval_loss": eval_loss} - + if flags.use_brats: + eval_metrics = {"epoch": epoch, + "L1 dice": scores[-3], + "L2 dice": scores[-2], + "L3 dice": scores[-1], + "mean_dice": (scores[-1] + scores[-2] + scores[-3]) / 3, + "eval_loss": eval_loss} + else: + eval_metrics = {"epoch": epoch, + "L1 dice": scores[-2], + "L2 dice": 
scores[-1], + "mean_dice": (scores[-1] + scores[-2]) / 2, + "eval_loss": eval_loss} return eval_metrics @@ -71,8 +91,8 @@ def pad_input(volume, roi_shape, strides, padding_mode, padding_val, dim=3): """ mode: constant, reflect, replicate, circular """ - bounds = [(strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim)] - bounds = [bounds[i] if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i] else bounds[i] + strides[i] + bounds = [(strides[i] - volume.shape[-dim:][i] % strides[i]) % strides[i] for i in range(dim)] + bounds = [bounds[i] if (volume.shape[-dim:][i] + bounds[i]) >= roi_shape[i] else bounds[i] + strides[i] for i in range(dim)] paddings = [bounds[2] // 2, bounds[2] - bounds[2] // 2, bounds[1] // 2, bounds[1] - bounds[1] // 2, @@ -94,7 +114,9 @@ def gaussian_kernel(n, std): def sliding_window_inference(inputs, labels, roi_shape, model, overlap=0.5, mode="gaussian", - padding_mode="constant", padding_val=0.0, **kwargs): + padding_mode="constant", padding_val=0.0, out_dim=3): + labels = labels.type(torch.int8) + inputs = inputs.type(torch.bfloat16) image_shape = list(inputs.shape[2:]) dim = len(image_shape) strides = [int(roi_shape[i] * (1 - overlap)) for i in range(dim)] @@ -114,7 +136,7 @@ def sliding_window_inference(inputs, labels, roi_shape, model, overlap=0.5, mode padded_shape = inputs.shape[2:] size = [(inputs.shape[2:][i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)] - result = torch.zeros(size=(1, 3, *padded_shape), dtype=inputs.dtype, device=inputs.device) + result = torch.zeros(size=(1, out_dim, *padded_shape), dtype=torch.bfloat16, device=inputs.device) norm_map = torch.zeros_like(result) if mode == "constant": norm_patch = torch.ones(size=roi_shape, dtype=norm_map.dtype, device=norm_map.device) diff --git a/image_segmentation/pytorch/runtime/trainer/unet3d_trainer.py b/image_segmentation/pytorch/runtime/trainer/unet3d_trainer.py index b78b3d547..863c5208a 100644 --- 
a/image_segmentation/pytorch/runtime/trainer/unet3d_trainer.py +++ b/image_segmentation/pytorch/runtime/trainer/unet3d_trainer.py @@ -11,7 +11,7 @@ from torch.nn import Parameter from torch.optim import SGD, Adam - +import torch_xla.test.test_utils as test_utils class UNet3DTrainer(ABC): """Base class for training UNet3D in PyTorch""" @@ -49,6 +49,8 @@ def __init__( gamma=flags.lr_decay_factor, ) + self.summary_writer = None + def train(self): """Trains the UNet3D model""" is_successful = False @@ -103,8 +105,15 @@ def train(self): loss_value = self.train_step(iteration=iteration, images=images, labels=labels) - #loss_value = reduce_tensor(loss_value).detach().cpu().numpy() + loss_value = reduce_tensor(loss_value) cumulative_loss.append(loss_value) + if self.summary_writer: + test_utils.write_to_summary( + self.summary_writer, + global_step=epoch * len(train_loader._loader) // self.flags.batch_size + iteration, + dict_to_write={ + 'loss': loss_value.detach().cpu().numpy() + }) # in debug mode, log the train loss on each iteration if self.flags.debug: mllog_event( @@ -116,7 +125,11 @@ def train(self): }, sync=False, ) - + if self.summary_writer: + test_utils.write_to_summary( + self.summary_writer, + global_step = epoch, + dict_to_write={'learning rate':self.optimizer.param_groups[0]["lr"]}) mllog_end( key=CONSTANTS.EPOCH_STOP, metadata={ @@ -126,6 +139,17 @@ def train(self): }, sync=False, ) + try: + mllog_end( + key=CONSTANTS.EPOCH_STOP, + metadata={ + CONSTANTS.EPOCH_NUM: epoch, + "loss": sum(cumulative_loss) / len(cumulative_loss) + }, + sync=False, + ) + except: + pass # startup time is defined as the time between the program # starting and the 1st epoch ending @@ -166,6 +190,15 @@ def train(self): metadata={CONSTANTS.EPOCH_NUM: epoch}, sync=False, ) + if self.summary_writer: + test_utils.write_to_summary( + self.summary_writer, + global_step = epoch, + dict_to_write={ + 'eval_loss':eval_metrics["eval_loss"], + 'mean_dice':eval_metrics["mean_dice"], + }) + 
mllog_end( key=CONSTANTS.EVAL_STOP, metadata={CONSTANTS.EPOCH_NUM: epoch}, diff --git a/image_segmentation/pytorch/runtime/trainer/xla_trainer.py b/image_segmentation/pytorch/runtime/trainer/xla_trainer.py index 66747f8af..357ece0c0 100644 --- a/image_segmentation/pytorch/runtime/trainer/xla_trainer.py +++ b/image_segmentation/pytorch/runtime/trainer/xla_trainer.py @@ -9,7 +9,8 @@ import torch_xla.debug.profiler as xp import torch_xla.distributed.parallel_loader as pl from runtime.trainer.unet3d_trainer import UNet3DTrainer - +from datetime import datetime +import torch_xla.test.test_utils as test_utils class XLATrainer(UNet3DTrainer): """Trains UNet3D in PyTorch/XLA""" @@ -49,6 +50,20 @@ def __init__( # Start and persist the profiler server if self.flags.profile_port: self.profile_server = xp.start_server(self.flags.profile_port) + + if flags.tb_dir !="": + dataset_name = "kitts/" if not flags.use_brats else "" + self.summary_dir = flags.tb_dir + dataset_name + datetime.now().strftime("%Y%m%d-%H%M%S") + self.summary_interval = 100 + if xm.is_master_ordinal(): + self.summary_writer = test_utils.get_summary_writer( + self.summary_dir) if self.summary_interval else None + + xm.master_print('tb_summery_dir is {}'.format(self.summary_dir)) + + def __del__(self): + if self.summary_writer: + test_utils.close_summary_writer(self.summary_writer) def train_step( self, iteration: int, images: torch.Tensor, labels: torch.Tensor