@@ -52,8 +52,9 @@ pub fn build_alibi_tensor(
52
52
device : & Device ,
53
53
dtype : DType ,
54
54
) -> Result < Tensor > {
55
- let context_positions = Tensor :: arange ( 0.0 , num_positions as f64 , device) ?. unsqueeze ( 1 ) ?;
56
- let memory_positions = Tensor :: arange ( 0.0 , num_positions as f64 , device) ?. unsqueeze ( 0 ) ?;
55
+ let context_positions =
56
+ Tensor :: arange ( 0.0 , num_positions as f64 , & Device :: Cpu ) ?. unsqueeze ( 1 ) ?;
57
+ let memory_positions = Tensor :: arange ( 0.0 , num_positions as f64 , & Device :: Cpu ) ?. unsqueeze ( 0 ) ?;
57
58
58
59
let relative_positions = memory_positions. broadcast_sub ( & context_positions) ?. abs ( ) ?;
59
60
// [num_heads, num_positions, num_positions]
@@ -63,13 +64,17 @@ pub fn build_alibi_tensor(
63
64
. expand ( ( num_heads, num_positions, num_positions) ) ?;
64
65
65
66
// [num_heads, 1, 1]
66
- let slopes =
67
- ( Tensor :: from_vec ( alibi_head_slopes ( num_heads) , ( num_heads, 1 , 1 ) , device) ? * -1_f64 ) ?;
67
+ let slopes = ( Tensor :: from_vec (
68
+ alibi_head_slopes ( num_heads) ,
69
+ ( num_heads, 1 , 1 ) ,
70
+ & Device :: Cpu ,
71
+ ) ? * -1_f64 ) ?;
68
72
69
73
// [num_heads, num_positions, num_positions]
70
74
let alibi = relative_positions. broadcast_mul ( & slopes) ?;
71
75
72
76
alibi
73
77
. reshape ( ( 1 , num_heads, num_positions, num_positions) ) ?
74
- . to_dtype ( dtype)
78
+ . to_dtype ( dtype) ?
79
+ . to_device ( device)
75
80
}
0 commit comments