Skip to content

Commit 82ba63d

Browse files
committed
Disable half2 by default when using HIP
1 parent b992719 commit 82ba63d

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

model_init.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from model import ExLlama, ExLlamaCache, ExLlamaConfig
22
from tokenizer import ExLlamaTokenizer
33
import argparse, sys, os, glob
4+
from torch import version as torch_version
45

56
def add_args(parser):
67

@@ -23,11 +24,12 @@ def add_args(parser):
2324
parser.add_argument("-mmnh2", "--matmul_no_half2", action = "store_true", help = "Don't use half2 in Q4 matmul kernel")
2425
parser.add_argument("-snh2", "--silu_no_half2", action = "store_true", help = "Don't use half2 in SiLU kernel")
2526
parser.add_argument("-nh2", "--no_half2", action = "store_true", help = "(All of the above) disable half2 in all kernela")
27+
parser.add_argument("-fh2", "--force_half2", action = "store_true", help = "Force enable half2 even if unsupported")
2628

2729

2830
def post_parse(args):
2931

30-
if args.no_half2:
32+
if args.no_half2 or torch_version.hip and not args.force_half2:
3133
args.rmsnorm_no_half2 = True
3234
args.rope_no_half2 = True
3335
args.matmul_no_half2 = True

0 commit comments

Comments
 (0)