
Commit da98df0

use the fp16 revision of SD (#72)
1 parent e31b3e2 commit da98df0

File tree

mii/grpc_related/modelresponse_server.py
mii/models/providers/diffusers.py

2 files changed: +10 -4 lines changed

mii/grpc_related/modelresponse_server.py

Lines changed: 1 addition & 3 deletions
@@ -10,7 +10,6 @@
 import sys
 import time
 
-from torch import autocast
 from transformers import Conversation
 from mii.constants import GRPC_MAX_MSG_SIZE
 
@@ -75,8 +74,7 @@ def Txt2ImgReply(self, request, context):
         request = [r for r in request.request]
 
         start = time.time()
-        with autocast("cuda"):
-            response = self.inference_pipeline(request, **query_kwargs)
+        response = self.inference_pipeline(request, **query_kwargs)
         end = time.time()
 
         images_bytes = []
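
Why the autocast context can go: with the diffusers.py change below, the pipeline's weights are loaded in fp16 up front whenever half precision is requested, so wrapping inference in autocast("cuda") becomes redundant and the server can call the pipeline directly. A minimal sketch of the before/after behavior, assuming an fp16 Stable Diffusion pipeline on a CUDA device (the model id and prompt are placeholders; a Hugging Face auth token may be needed for the download):

import torch
from diffusers import DiffusionPipeline

# Load fp16 weights directly, mirroring what diffusers_provider now does.
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4",
                                         revision="fp16",
                                         torch_dtype=torch.float16)
pipe = pipe.to("cuda:0")

# Before this commit the server wrapped the call:
#     with autocast("cuda"):
#         response = pipe(prompts)
# After it, the fp16 modules run directly:
response = pipe(["an astronaut riding a horse"])
images = response.images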

mii/models/providers/diffusers.py

Lines changed: 9 additions & 1 deletion
@@ -1,10 +1,18 @@
 import os
+import torch
 
 
 def diffusers_provider(model_path, model_name, task_name, mii_config):
     from diffusers import DiffusionPipeline
     local_rank = int(os.getenv('LOCAL_RANK', '0'))
+
+    kwargs = {}
+    if mii_config.torch_dtype() == torch.half:
+        kwargs["torch_dtype"] = torch.float16
+        kwargs["revision"] = "fp16"
+
     pipeline = DiffusionPipeline.from_pretrained(model_name,
-                                                 use_auth_token=mii_config.hf_auth_token)
+                                                 use_auth_token=mii_config.hf_auth_token,
+                                                 **kwargs)
     pipeline = pipeline.to(f"cuda:{local_rank}")
     return pipeline
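
Gating both kwargs on mii_config.torch_dtype() == torch.half means fp32 deployments still load the default model branch, while half-precision deployments pull the dedicated "fp16" weight revision, which roughly halves the checkpoint download and GPU weight memory. A usage sketch with a hypothetical stand-in for MII's config object (StubMIIConfig and its field values are illustrative; only the two members the provider touches are modeled):

import torch

class StubMIIConfig:
    hf_auth_token = None              # or a Hugging Face access token string

    def torch_dtype(self):
        return torch.half             # fp16 requested -> revision="fp16" added

# model_path is unused in the provider body shown above, so None is fine here.
pipeline = diffusers_provider(model_path=None,
                              model_name="CompVis/stable-diffusion-v1-4",
                              task_name="text-to-image",
                              mii_config=StubMIIConfig())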
