From bfcdaea47b6a442ae25dc1af3fc9116404484d93 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 14 Mar 2024 17:51:56 +0000 Subject: [PATCH 01/19] while_loop --- .../inference_tpu_single_device.py | 51 ++++++++++++++----- 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index 9e8021d08..2d51ad3e2 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -42,9 +42,9 @@ def main(args): device = xm.xla_device() pipe.to(device) - bs = args.batch_size - inference_steps = args.inf_steps - height = width = args.width + bs = args.batch_size # 1 + inference_steps = args.inf_steps # 2 + height = width = args.width # 512 prompts = ["a photo of an astronaut riding a horse on mars"] * bs print(f'batch size = {bs}, inference steps = {inference_steps}', @@ -52,16 +52,41 @@ def main(args): flush=True ) - iters = 15 - print('starting inference', flush=True) - for i in range(iters): - start = time() - image = pipe(prompts, - num_inference_steps=inference_steps, - height=height, - width=width, - ).images[0] - print(f'Step {i} inference time {time()-start} sec', flush=True) + import torch + import torch_xla.experimental.fori_loop + from torch._higher_order_ops.while_loop import while_loop + def cond_fn(init, limit_value): + return limit_value[0] <= init[0] + + def body_fn(init, limit_value): + one_value = torch.ones(1, dtype=torch.int32, device=device) + two_value = limit_value.clone() + # start = time() + image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts, + num_inference_steps=2, # inference_steps, + height=512, # height, + width=512, # width, + ).images[0] + # print(f'Step {i} inference time {time()-start} sec', flush=True) + return (torch.sub(init, one_value), two_value) + + init = torch.tensor([3], dtype=torch.int32, device=device) + # iters = 3 + limit_value = torch.tensor([0], dtype=torch.int32, device=device) + res = while_loop(cond_fn, body_fn, (init, limit_value)) + # expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) + # self.assertEqual(expected, res) + + # iters = 1 # 15 + # print('starting inference', flush=True) + # for i in range(iters): + # start = time() + # image = pipe(prompts, + # num_inference_steps=inference_steps, + # height=height, + # width=width, + # ).images[0] + # print(f'Step {i} inference time {time()-start} sec', flush=True) if __name__ == '__main__': From 0fe20597964a1959fd347be04af4a3ce62aa7fe9 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 14 Mar 2024 17:59:58 +0000 Subject: [PATCH 02/19] while_loop --- .../text_to_image/inference_tpu_single_device.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index 2d51ad3e2..4702d7244 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -12,7 +12,7 @@ def parser(args): parser.add_argument( '--batch-size', type=int, - default=8, + default=2, # 8, help='Number of images to generate' ) @@ -26,7 +26,7 @@ def parser(args): parser.add_argument( '--inf-steps', type=int, - default=30, + default=2, # 30, help='Number of itterations to run the benchmark.' ) @@ -62,11 +62,11 @@ def body_fn(init, limit_value): one_value = torch.ones(1, dtype=torch.int32, device=device) two_value = limit_value.clone() # start = time() - image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts, - num_inference_steps=2, # inference_steps, - height=512, # height, - width=512, # width, - ).images[0] + # image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts, + # num_inference_steps=2, # inference_steps, + # height=512, # height, + # width=512, # width, + # ).images[0] # print(f'Step {i} inference time {time()-start} sec', flush=True) return (torch.sub(init, one_value), two_value) From 68d7f7ddc74303c73f3bf6c2d09817a559946469 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 14 Mar 2024 19:49:44 +0000 Subject: [PATCH 03/19] while_loop --- .../text_to_image/inference_tpu_single_device.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index 4702d7244..6f0cdfdd9 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -35,12 +35,12 @@ def parser(args): def main(args): server = xp.start_server(9012) - pipe = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-0.9", - use_safetensors=True, - ) + # pipe = DiffusionPipeline.from_pretrained( + # "stabilityai/stable-diffusion-xl-base-0.9", + # use_safetensors=True, + # ) device = xm.xla_device() - pipe.to(device) + # pipe.to(device) bs = args.batch_size # 1 inference_steps = args.inf_steps # 2 @@ -62,6 +62,12 @@ def body_fn(init, limit_value): one_value = torch.ones(1, dtype=torch.int32, device=device) two_value = limit_value.clone() # start = time() + pipe = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-0.9", + use_safetensors=True, + ) + # device = xm.xla_device() + pipe.to(device) # image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts, # num_inference_steps=2, # inference_steps, # height=512, # height, From bcdb9ad1cfd48903917a7b6875e01865a2890548 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 14 Mar 2024 20:00:11 +0000 Subject: [PATCH 04/19] while_loop --- examples/text_to_image/inference_tpu_single_device.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index 6f0cdfdd9..09ade9602 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -67,7 +67,7 @@ def body_fn(init, limit_value): use_safetensors=True, ) # device = xm.xla_device() - pipe.to(device) + # pipe.to(device) # image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts, # num_inference_steps=2, # inference_steps, # height=512, # height, From eef8eab0451d0804506925f92650a89029c04bba Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 14 Mar 2024 20:02:23 +0000 Subject: [PATCH 05/19] while_loop --- examples/text_to_image/inference_tpu_single_device.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index 09ade9602..4efa21ff9 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -79,7 +79,9 @@ def body_fn(init, limit_value): init = torch.tensor([3], dtype=torch.int32, device=device) # iters = 3 limit_value = torch.tensor([0], dtype=torch.int32, device=device) - res = while_loop(cond_fn, body_fn, (init, limit_value)) + # res = while_loop(cond_fn, body_fn, (init, limit_value)) + from torch_xla.experimental.fori_loop import _xla_while_loop + res = _xla_while_loop(cond_fn, body_fn, (init, limit_value)) # expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) # self.assertEqual(expected, res) From 2a8066bb41915a268714290363bbf5435fcbd1f5 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 14 Mar 2024 20:03:44 +0000 Subject: [PATCH 06/19] while_loop --- examples/text_to_image/inference_tpu_single_device.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index 4efa21ff9..0f274bf86 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -82,6 +82,7 @@ def body_fn(init, limit_value): # res = while_loop(cond_fn, body_fn, (init, limit_value)) from torch_xla.experimental.fori_loop import _xla_while_loop res = _xla_while_loop(cond_fn, body_fn, (init, limit_value)) + print("result of while_loop: ", res) # expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) # self.assertEqual(expected, res) From 08be70364138fb6328ae67fd51f1e2da7d471764 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 14 Mar 2024 20:08:08 +0000 Subject: [PATCH 07/19] while_loop --- examples/text_to_image/inference_tpu_single_device.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index 0f274bf86..6fde68ad9 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -75,17 +75,26 @@ def body_fn(init, limit_value): # ).images[0] # print(f'Step {i} inference time {time()-start} sec', flush=True) return (torch.sub(init, one_value), two_value) - + + start = time() init = torch.tensor([3], dtype=torch.int32, device=device) # iters = 3 limit_value = torch.tensor([0], dtype=torch.int32, device=device) # res = while_loop(cond_fn, body_fn, (init, limit_value)) from torch_xla.experimental.fori_loop import _xla_while_loop res = _xla_while_loop(cond_fn, body_fn, (init, limit_value)) + print(f'Call pipeline with _xla_while_loop used {time()-start} sec', flush=True) print("result of while_loop: ", res) # expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) # self.assertEqual(expected, res) + start2 = time() + pipe2 = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-0.9", + use_safetensors=True, + ) + print(f'Call pipeline without _xla_while_loop used {time()-start2} sec', flush=True) + # iters = 1 # 15 # print('starting inference', flush=True) # for i in range(iters): From 6d9aab48c7ffe78cff3b219e617c951fe6b93a59 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 14 Mar 2024 20:10:52 +0000 Subject: [PATCH 08/19] while_loop --- .../text_to_image/inference_tpu_single_device.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index 6fde68ad9..995a0823c 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -77,23 +77,24 @@ def body_fn(init, limit_value): return (torch.sub(init, one_value), two_value) start = time() - init = torch.tensor([3], dtype=torch.int32, device=device) # iters = 3 + init = torch.tensor([3], dtype=torch.int32, device=device) limit_value = torch.tensor([0], dtype=torch.int32, device=device) # res = while_loop(cond_fn, body_fn, (init, limit_value)) from torch_xla.experimental.fori_loop import _xla_while_loop res = _xla_while_loop(cond_fn, body_fn, (init, limit_value)) - print(f'Call pipeline with _xla_while_loop used {time()-start} sec', flush=True) + print(f'Call pipeline with _xla_while_loop for three times used {time()-start} sec', flush=True) print("result of while_loop: ", res) # expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) # self.assertEqual(expected, res) start2 = time() - pipe2 = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-0.9", - use_safetensors=True, - ) - print(f'Call pipeline without _xla_while_loop used {time()-start2} sec', flush=True) + for i in range(iters): + pipe2 = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-0.9", + use_safetensors=True, + ) + print(f'Call pipeline without _xla_while_loop for three times used {time()-start2} sec', flush=True) # iters = 1 # 15 # print('starting inference', flush=True) From 058c4a346f863d431c528e85adca7880faf44807 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 14 Mar 2024 20:11:55 +0000 Subject: [PATCH 09/19] while_loop --- examples/text_to_image/inference_tpu_single_device.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index 995a0823c..0ed13ed40 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -89,6 +89,7 @@ def body_fn(init, limit_value): # self.assertEqual(expected, res) start2 = time() + iters = 3 for i in range(iters): pipe2 = DiffusionPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-0.9", From b43f93771a2de33bf3aa507811a243c33753726d Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 14 Mar 2024 20:16:10 +0000 Subject: [PATCH 10/19] while_loop --- .../inference_tpu_single_device.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index 0ed13ed40..7b4df2632 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -67,13 +67,13 @@ def body_fn(init, limit_value): use_safetensors=True, ) # device = xm.xla_device() - # pipe.to(device) - # image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts, - # num_inference_steps=2, # inference_steps, - # height=512, # height, - # width=512, # width, - # ).images[0] - # print(f'Step {i} inference time {time()-start} sec', flush=True) + pipe.to(device) + image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts, + num_inference_steps=2, # inference_steps, + height=512, # height, + width=512, # width, + ).images[0] + print(f'Step {i} inference time {time()-start} sec', flush=True) return (torch.sub(init, one_value), two_value) start = time() @@ -88,14 +88,14 @@ def body_fn(init, limit_value): # expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) # self.assertEqual(expected, res) - start2 = time() - iters = 3 - for i in range(iters): - pipe2 = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-0.9", - use_safetensors=True, - ) - print(f'Call pipeline without _xla_while_loop for three times used {time()-start2} sec', flush=True) + # start2 = time() + # iters = 3 + # for i in range(iters): + # pipe2 = DiffusionPipeline.from_pretrained( + # "stabilityai/stable-diffusion-xl-base-0.9", + # use_safetensors=True, + # ) + # print(f'Call pipeline without _xla_while_loop for three times used {time()-start2} sec', flush=True) # iters = 1 # 15 # print('starting inference', flush=True) From 38e8602bffaf1ee48f242111f778e2baa44cfed4 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 14 Mar 2024 20:29:44 +0000 Subject: [PATCH 11/19] while_loop --- examples/text_to_image/inference_tpu_single_device.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index 7b4df2632..ec61128b2 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -73,7 +73,7 @@ def body_fn(init, limit_value): height=512, # height, width=512, # width, ).images[0] - print(f'Step {i} inference time {time()-start} sec', flush=True) + # print(f'Step {i} inference time {time()-start} sec', flush=True) return (torch.sub(init, one_value), two_value) start = time() From 41788d17dae72828627bcb3aa78025fc788d6dda Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 14 Mar 2024 20:37:49 +0000 Subject: [PATCH 12/19] while_loop --- examples/text_to_image/inference_tpu_single_device.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index ec61128b2..6790b0c82 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -73,6 +73,7 @@ def body_fn(init, limit_value): height=512, # height, width=512, # width, ).images[0] + print("type of image: ", type(image)) # print(f'Step {i} inference time {time()-start} sec', flush=True) return (torch.sub(init, one_value), two_value) From 34b58c012257473f2ce8708ed5f798c664ebc07f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 14 Mar 2024 21:02:25 +0000 Subject: [PATCH 13/19] while_loop --- .../text_to_image/inference_tpu_single_device.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index 6790b0c82..29ca00057 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -68,12 +68,12 @@ def body_fn(init, limit_value): ) # device = xm.xla_device() pipe.to(device) - image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts, - num_inference_steps=2, # inference_steps, - height=512, # height, - width=512, # width, - ).images[0] - print("type of image: ", type(image)) + # image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts, + # num_inference_steps=2, # inference_steps, + # height=512, # height, + # width=512, # width, + # ).images[0] + # print("type of image: ", type(image)) # print(f'Step {i} inference time {time()-start} sec', flush=True) return (torch.sub(init, one_value), two_value) From e5230602e5527f6ea5e5f68951efaae9f6fab56b Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 14 Mar 2024 21:09:15 +0000 Subject: [PATCH 14/19] while_loop --- examples/text_to_image/inference_tpu_single_device.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index 29ca00057..4420bb4e9 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -68,6 +68,11 @@ def body_fn(init, limit_value): ) # device = xm.xla_device() pipe.to(device) + image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts, + num_inference_steps=2, # inference_steps, + height=512, # height, + width=512, # width, + ) # image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts, # num_inference_steps=2, # inference_steps, # height=512, # height, From 18364086957fbb8088d26688d99bc653360114cb Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 14 Mar 2024 21:30:23 +0000 Subject: [PATCH 15/19] while_loop --- examples/text_to_image/inference_tpu_single_device.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index 4420bb4e9..f874b28e0 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -59,8 +59,8 @@ def cond_fn(init, limit_value): return limit_value[0] <= init[0] def body_fn(init, limit_value): - one_value = torch.ones(1, dtype=torch.int32, device=device) - two_value = limit_value.clone() + # one_value = torch.ones(1, dtype=torch.int32, device=device) + # two_value = limit_value.clone() # start = time() pipe = DiffusionPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-0.9", @@ -80,6 +80,8 @@ def body_fn(init, limit_value): # ).images[0] # print("type of image: ", type(image)) # print(f'Step {i} inference time {time()-start} sec', flush=True) + one_value = torch.ones(1, dtype=torch.int32, device=device) + two_value = limit_value.clone() return (torch.sub(init, one_value), two_value) start = time() From b9fd2d7d8feb2eb538739ee96b0126062197ca93 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Thu, 14 Mar 2024 22:05:38 +0000 Subject: [PATCH 16/19] while_loop --- examples/text_to_image/inference_tpu_single_device.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index f874b28e0..3be250e4f 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -72,7 +72,7 @@ def body_fn(init, limit_value): num_inference_steps=2, # inference_steps, height=512, # height, width=512, # width, - ) + ).images[0] # image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts, # num_inference_steps=2, # inference_steps, # height=512, # height, From ba56ce6713a10a064c9301f2d8eafd15246a1a0f Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 15 Mar 2024 05:02:56 +0000 Subject: [PATCH 17/19] while_loop --- .../text_to_image/inference_tpu_single_device.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index 3be250e4f..e5cb38b3f 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -96,14 +96,14 @@ def body_fn(init, limit_value): # expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) # self.assertEqual(expected, res) - # start2 = time() - # iters = 3 - # for i in range(iters): - # pipe2 = DiffusionPipeline.from_pretrained( - # "stabilityai/stable-diffusion-xl-base-0.9", - # use_safetensors=True, - # ) - # print(f'Call pipeline without _xla_while_loop for three times used {time()-start2} sec', flush=True) + start2 = time() + iters = 3 + for i in range(iters): + pipe2 = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-0.9", + use_safetensors=True, + ) + print(f'Call pipeline without _xla_while_loop for three times used {time()-start2} sec', flush=True) # iters = 1 # 15 # print('starting inference', flush=True) From 66fecae3513adc4401d5908c027090b9ad15a5cb Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 15 Mar 2024 05:10:13 +0000 Subject: [PATCH 18/19] while_loop --- examples/text_to_image/inference_tpu_single_device.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index e5cb38b3f..7a7490c1b 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -103,6 +103,12 @@ def body_fn(init, limit_value): "stabilityai/stable-diffusion-xl-base-0.9", use_safetensors=True, ) + pipe2.to(device) + image2 = pipe2(["a photo of an astronaut riding a horse on mars"], # prompts, + num_inference_steps=2, # inference_steps, + height=512, # height, + width=512, # width, + ).images[0] print(f'Call pipeline without _xla_while_loop for three times used {time()-start2} sec', flush=True) # iters = 1 # 15 From c41298e6142a5415bcc4a22e50d015edbefa9718 Mon Sep 17 00:00:00 2001 From: manfeibaigithub Date: Fri, 15 Mar 2024 05:38:57 +0000 Subject: [PATCH 19/19] while_loop --- .../inference_tpu_single_device.py | 62 +++++++++++++------ 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/examples/text_to_image/inference_tpu_single_device.py b/examples/text_to_image/inference_tpu_single_device.py index 7a7490c1b..fee2170a8 100644 --- a/examples/text_to_image/inference_tpu_single_device.py +++ b/examples/text_to_image/inference_tpu_single_device.py @@ -51,6 +51,28 @@ def main(args): f'height = width = {width}', flush=True ) + + pipe = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-0.9", + use_safetensors=True, + ) + pipe.to(device) + + start2 = time() + iters = 3 + for i in range(iters): + # pipe2 = DiffusionPipeline.from_pretrained( + # "stabilityai/stable-diffusion-xl-base-0.9", + # use_safetensors=True, + # ) + # pipe2.to(device) + image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts, + num_inference_steps=2, # inference_steps, + height=512, # height, + width=512, # width, + ).images[0] + print(f'Call pipeline without _xla_while_loop for three times used {time()-start2} sec', flush=True) + import torch import torch_xla.experimental.fori_loop @@ -62,12 +84,12 @@ def body_fn(init, limit_value): # one_value = torch.ones(1, dtype=torch.int32, device=device) # two_value = limit_value.clone() # start = time() - pipe = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-0.9", - use_safetensors=True, - ) - # device = xm.xla_device() - pipe.to(device) + # pipe = DiffusionPipeline.from_pretrained( + # "stabilityai/stable-diffusion-xl-base-0.9", + # use_safetensors=True, + # ) + # # device = xm.xla_device() + # pipe.to(device) image = pipe(["a photo of an astronaut riding a horse on mars"], # prompts, num_inference_steps=2, # inference_steps, height=512, # height, @@ -96,20 +118,20 @@ def body_fn(init, limit_value): # expected = _fake_while_loop(cond_fn, body_fn, (init, limit_value)) # self.assertEqual(expected, res) - start2 = time() - iters = 3 - for i in range(iters): - pipe2 = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-0.9", - use_safetensors=True, - ) - pipe2.to(device) - image2 = pipe2(["a photo of an astronaut riding a horse on mars"], # prompts, - num_inference_steps=2, # inference_steps, - height=512, # height, - width=512, # width, - ).images[0] - print(f'Call pipeline without _xla_while_loop for three times used {time()-start2} sec', flush=True) + # start2 = time() + # iters = 3 + # for i in range(iters): + # # pipe2 = DiffusionPipeline.from_pretrained( + # # "stabilityai/stable-diffusion-xl-base-0.9", + # # use_safetensors=True, + # # ) + # pipe2.to(device) + # image2 = pipe2(["a photo of an astronaut riding a horse on mars"], # prompts, + # num_inference_steps=2, # inference_steps, + # height=512, # height, + # width=512, # width, + # ).images[0] + # print(f'Call pipeline without _xla_while_loop for three times used {time()-start2} sec', flush=True) # iters = 1 # 15 # print('starting inference', flush=True)