diff --git a/ggml-metal.m b/ggml-metal.m index 2810fa2a841c5..436d73081b3f6 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -358,643 +358,616 @@ void ggml_metal_graph_compute( struct ggml_cgraph * gf) { metal_printf("%s: evaluating graph\n", __func__); - // create multiple command buffers and enqueue them - // then, we encode the graph into the command buffers in parallel - - const int n_cb = ctx->n_cb; - - NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb]; - - for (int i = 0; i < n_cb; ++i) { - command_buffers[i] = [ctx->queue commandBuffer]; - - // enqueue the command buffers in order to specify their execution order - [command_buffers[i] enqueue]; - } - - // TODO: is this the best way to start threads? - dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT); - - for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb; - - dispatch_async(queue, ^{ - size_t offs_src0 = 0; - size_t offs_src1 = 0; - size_t offs_dst = 0; - - id command_buffer = command_buffers[cb_idx]; - - id encoder = nil; - - const int node_start = (cb_idx + 0) * n_nodes_per_cb; - const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb; - - for (int i = node_start; i < node_end; ++i) { - metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); - - struct ggml_tensor * src0 = gf->nodes[i]->src[0]; - struct ggml_tensor * src1 = gf->nodes[i]->src[1]; - struct ggml_tensor * dst = gf->nodes[i]; - - const int64_t ne00 = src0 ? src0->ne[0] : 0; - const int64_t ne01 = src0 ? src0->ne[1] : 0; - const int64_t ne02 = src0 ? src0->ne[2] : 0; - const int64_t ne03 = src0 ? src0->ne[3] : 0; - - const uint64_t nb00 = src0 ? src0->nb[0] : 0; - const uint64_t nb01 = src0 ? src0->nb[1] : 0; - const uint64_t nb02 = src0 ? src0->nb[2] : 0; - const uint64_t nb03 = src0 ? src0->nb[3] : 0; - - const int64_t ne10 = src1 ? src1->ne[0] : 0; - const int64_t ne11 = src1 ? src1->ne[1] : 0; - const int64_t ne12 = src1 ? src1->ne[2] : 0; - const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); - - const uint64_t nb10 = src1 ? src1->nb[0] : 0; - const uint64_t nb11 = src1 ? src1->nb[1] : 0; - const uint64_t nb12 = src1 ? src1->nb[2] : 0; - const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); - - const int64_t ne0 = dst ? dst->ne[0] : 0; - const int64_t ne1 = dst ? dst->ne[1] : 0; - const int64_t ne2 = dst ? dst->ne[2] : 0; - const int64_t ne3 = dst ? dst->ne[3] : 0; - - const uint64_t nb0 = dst ? dst->nb[0] : 0; - const uint64_t nb1 = dst ? dst->nb[1] : 0; - const uint64_t nb2 = dst ? dst->nb[2] : 0; - const uint64_t nb3 = dst ? dst->nb[3] : 0; - - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; - const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; - - id id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil; - id id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil; - id id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil; - - //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op)); - //if (src0) { - // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, - // ggml_is_contiguous(src0), src0->name); - //} - //if (src1) { - // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, - // ggml_is_contiguous(src1), src1->name); - //} - //if (dst) { - // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, - // dst->name); - //} - - switch (dst->op) { - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - { - // noop - } break; - case GGML_OP_ADD: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - [encoder setComputePipelineState:ctx->pipeline_add]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_MUL: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - if (ggml_nelements(src1) == ne10) { - // src1 is a row - [encoder setComputePipelineState:ctx->pipeline_mul_row]; - } else { - [encoder setComputePipelineState:ctx->pipeline_mul]; - } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SCALE: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - const float scale = *(const float *) src1->data; - - [encoder setComputePipelineState:ctx->pipeline_scale]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SILU: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - [encoder setComputePipelineState:ctx->pipeline_silu]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_RELU: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - [encoder setComputePipelineState:ctx->pipeline_relu]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_GELU: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - [encoder setComputePipelineState:ctx->pipeline_gelu]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SOFT_MAX: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - const int nth = 32; - - [encoder setComputePipelineState:ctx->pipeline_soft_max]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_DIAG_MASK_INF: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - const int n_past = ((int32_t *)(src1->data))[0]; - - [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_MUL_MAT: - { - // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224 - - GGML_ASSERT(ne00 == ne10); - GGML_ASSERT(ne02 == ne12); - - if (ggml_is_contiguous(src0) && - ggml_is_contiguous(src1) && - (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) { - - if (encoder != nil) { - [encoder endEncoding]; - encoder = nil; - } - - MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16; - MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16; - - // for F32 x F32 we use MPS - MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor - matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt]; - - MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor - matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt]; - - MPSMatrixDescriptor * desc = [MPSMatrixDescriptor - matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32]; - - MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc] - initWithDevice:ctx->device transposeLeft:false transposeRight:true - resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0]; - - // we need to do ne02 multiplications - // TODO: is there a way to do this in parallel - currently very slow .. - // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS - for (int64_t i02 = 0; i02 < ne02; ++i02) { - size_t offs_src0_cur = offs_src0 + i02*nb02; - size_t offs_src1_cur = offs_src1 + i02*nb12; - size_t offs_dst_cur = offs_dst + i02*nb2; - - MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0]; - MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1]; - MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ]; - - [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst]; - } - } else { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; + id command_buffer = [ctx->queue commandBuffer]; + + size_t offs_src0 = 0; + size_t offs_src1 = 0; + size_t offs_dst = 0; + + id encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + + for (int i = 0; i < gf->n_nodes; ++i) { + metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); + + struct ggml_tensor * src0 = gf->nodes[i]->src[0]; + struct ggml_tensor * src1 = gf->nodes[i]->src[1]; + struct ggml_tensor * dst = gf->nodes[i]; + + const int64_t ne00 = src0 ? src0->ne[0] : 0; + const int64_t ne01 = src0 ? src0->ne[1] : 0; + const int64_t ne02 = src0 ? src0->ne[2] : 0; + const int64_t ne03 = src0 ? src0->ne[3] : 0; + + const uint64_t nb00 = src0 ? src0->nb[0] : 0; + const uint64_t nb01 = src0 ? src0->nb[1] : 0; + const uint64_t nb02 = src0 ? src0->nb[2] : 0; + const uint64_t nb03 = src0 ? src0->nb[3] : 0; + + const int64_t ne10 = src1 ? src1->ne[0] : 0; + const int64_t ne11 = src1 ? src1->ne[1] : 0; + const int64_t ne12 = src1 ? src1->ne[2] : 0; + const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + + const uint64_t nb10 = src1 ? src1->nb[0] : 0; + const uint64_t nb11 = src1 ? src1->nb[1] : 0; + const uint64_t nb12 = src1 ? src1->nb[2] : 0; + const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); + + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; + + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; + + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; + const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; + + id id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil; + id id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil; + id id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil; + + //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op)); + //if (src0) { + // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, + // ggml_is_contiguous(src0), src0->name); + //} + //if (src1) { + // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, + // ggml_is_contiguous(src1), src1->name); + //} + //if (dst) { + // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, + // dst->name); + //} + + switch (dst->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: + { + // noop + } break; + case GGML_OP_BARRIER: + { + [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers | MTLBarrierScopeRenderTargets | MTLBarrierScopeTextures]; + } break; + case GGML_OP_ADD: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } + + [encoder setComputePipelineState:ctx->pipeline_add]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_MUL: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } + + if (ggml_nelements(src1) == ne10) { + // src1 is a row + [encoder setComputePipelineState:ctx->pipeline_mul_row]; + } else { + [encoder setComputePipelineState:ctx->pipeline_mul]; + } + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SCALE: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } + + const float scale = *(const float *) src1->data; + + [encoder setComputePipelineState:ctx->pipeline_scale]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SILU: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } + + [encoder setComputePipelineState:ctx->pipeline_silu]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_RELU: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } + + [encoder setComputePipelineState:ctx->pipeline_relu]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_GELU: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } + + [encoder setComputePipelineState:ctx->pipeline_gelu]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SOFT_MAX: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } + + const int nth = 32; + + [encoder setComputePipelineState:ctx->pipeline_soft_max]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_DIAG_MASK_INF: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } + + const int n_past = ((int32_t *)(src1->data))[0]; + + [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_MUL_MAT: + { + // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224 + + GGML_ASSERT(ne00 == ne10); + GGML_ASSERT(ne02 == ne12); + + if (ggml_is_contiguous(src0) && + ggml_is_contiguous(src1) && + (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) { + + if (encoder != nil) { + [encoder endEncoding]; + encoder = nil; + } + + MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16; + MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16; + + // for F32 x F32 we use MPS + MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor + matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt]; + + MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor + matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt]; + + MPSMatrixDescriptor * desc = [MPSMatrixDescriptor + matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32]; + + MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc] + initWithDevice:ctx->device transposeLeft:false transposeRight:true + resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0]; + + // we need to do ne02 multiplications + // TODO: is there a way to do this in parallel - currently very slow .. + // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS + for (int64_t i02 = 0; i02 < ne02; ++i02) { + size_t offs_src0_cur = offs_src0 + i02*nb02; + size_t offs_src1_cur = offs_src1 + i02*nb12; + size_t offs_dst_cur = offs_dst + i02*nb2; + + MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0]; + MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1]; + MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ]; + + [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst]; + } + } else { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } + + int nth0 = 32; + int nth1 = 1; + + // use custom matrix x vector kernel + switch (src0t) { + case GGML_TYPE_F16: + { + GGML_ASSERT(ne02 == ne12); + + nth0 = 64; + nth1 = 1; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; + } break; + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32]; + } break; + case GGML_TYPE_Q2_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32]; + } break; + case GGML_TYPE_Q3_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32]; + } break; + case GGML_TYPE_Q4_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32]; + } break; + case GGML_TYPE_Q5_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32]; + } break; + case GGML_TYPE_Q6_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32]; + } break; + default: + { + fprintf(stderr, "Asserting on type %d\n",(int)src0t); + GGML_ASSERT(false && "not implemented"); } - - int nth0 = 32; - int nth1 = 1; - - // use custom matrix x vector kernel - switch (src0t) { - case GGML_TYPE_F16: - { - GGML_ASSERT(ne02 == ne12); - - nth0 = 64; - nth1 = 1; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; - } break; - case GGML_TYPE_Q4_0: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 8; - nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; - } break; - case GGML_TYPE_Q4_1: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 8; - nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32]; - } break; - case GGML_TYPE_Q2_K: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32]; - } break; - case GGML_TYPE_Q3_K: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32]; - } break; - case GGML_TYPE_Q4_K: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32]; - } break; - case GGML_TYPE_Q5_K: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32]; - } break; - case GGML_TYPE_Q6_K: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32]; - } break; - default: - { - fprintf(stderr, "Asserting on type %d\n",(int)src0t); - GGML_ASSERT(false && "not implemented"); - } - }; - - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || - src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { + }; + + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; + + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q3_K) { #ifdef GGML_QKK_64 - [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #else - [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #endif - } - else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - } - } break; - case GGML_OP_GET_ROWS: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - switch (src0->type) { - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; - case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; - case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; - case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; - case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; - case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break; - case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break; - case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break; - default: GGML_ASSERT(false && "not implemented"); - } - - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5]; - - const int64_t n = ggml_nelements(src1); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_RMS_NORM: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - const float eps = 1e-6f; - - const int nth = 512; - - [encoder setComputePipelineState:ctx->pipeline_rms_norm]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_NORM: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - const float eps = 1e-5f; - - const int nth = 256; - - [encoder setComputePipelineState:ctx->pipeline_norm]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ALIBI: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - GGML_ASSERT((src0t == GGML_TYPE_F32)); - - const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past); - const int n_head = ((int32_t *) src1->data)[1]; - const float max_bias = ((float *) src1->data)[2]; - - if (__builtin_popcount(n_head) != 1) { - GGML_ASSERT(false && "only power-of-two n_head implemented"); - } - - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - - [encoder setComputePipelineState:ctx->pipeline_alibi_f32]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; - const int nth = 32; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ROPE: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; - - const int n_past = ((int32_t *)(src1->data))[0]; - - float freq_base; - float freq_scale; - memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float)); - - [encoder setComputePipelineState:ctx->pipeline_rope]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:18]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:19]; - [encoder setBytes:&mode length:sizeof( int) atIndex:20]; - [encoder setBytes:&freq_base length:sizeof(float) atIndex:21]; - [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_CPY: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - const int nth = 32; - - switch (src0t) { - case GGML_TYPE_F32: - { - switch (dstt) { - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break; - case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break; - default: GGML_ASSERT(false && "not implemented"); - }; - } break; - case GGML_TYPE_F16: - { - switch (dstt) { - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break; - case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break; - default: GGML_ASSERT(false && "not implemented"); - }; - } break; - default: GGML_ASSERT(false && "not implemented"); - } - - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - default: - fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ASSERT(false); - } - } - - if (encoder != nil) { - [encoder endEncoding]; - encoder = nil; - } + } + else if (src0t == GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else { + [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + } + } break; + case GGML_OP_GET_ROWS: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } + + switch (src0->type) { + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; + case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; + case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; + case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; + case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; + case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break; + case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break; + case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break; + default: GGML_ASSERT(false && "not implemented"); + } + + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5]; + + const int64_t n = ggml_nelements(src1); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_RMS_NORM: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } + + const float eps = 1e-6f; + + const int nth = 512; + + [encoder setComputePipelineState:ctx->pipeline_rms_norm]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_NORM: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } + + const float eps = 1e-5f; + + const int nth = 256; + + [encoder setComputePipelineState:ctx->pipeline_norm]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ALIBI: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } + + GGML_ASSERT((src0t == GGML_TYPE_F32)); + + const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past); + const int n_head = ((int32_t *) src1->data)[1]; + const float max_bias = ((float *) src1->data)[2]; + + if (__builtin_popcount(n_head) != 1) { + GGML_ASSERT(false && "only power-of-two n_head implemented"); + } + + const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + + [encoder setComputePipelineState:ctx->pipeline_alibi_f32]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; + const int nth = 32; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ROPE: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } + + const int n_dims = ((int32_t *) src1->data)[1]; + const int mode = ((int32_t *) src1->data)[2]; + + const int n_past = ((int32_t *)(src1->data))[0]; + + float freq_base; + float freq_scale; + memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float)); + + [encoder setComputePipelineState:ctx->pipeline_rope]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:18]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:19]; + [encoder setBytes:&mode length:sizeof( int) atIndex:20]; + [encoder setBytes:&freq_base length:sizeof(float) atIndex:21]; + [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_CPY: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; + } + + const int nth = 32; + + switch (src0t) { + case GGML_TYPE_F32: + { + switch (dstt) { + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break; + case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break; + default: GGML_ASSERT(false && "not implemented"); + }; + } break; + case GGML_TYPE_F16: + { + switch (dstt) { + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break; + case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break; + default: GGML_ASSERT(false && "not implemented"); + }; + } break; + default: GGML_ASSERT(false && "not implemented"); + } + + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + default: + fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); + GGML_ASSERT(false); + } + } - [command_buffer commit]; - }); + if (encoder != nil) { + [encoder endEncoding]; + encoder = nil; } - // wait for all threads to finish - dispatch_barrier_sync(queue, ^{}); + [command_buffer commit]; - [command_buffers[n_cb - 1] waitUntilCompleted]; + [command_buffer waitUntilCompleted]; // check status of command buffers // needed to detect if the device ran out-of-memory for example (#1881) - for (int i = 0; i < n_cb; i++) { - MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status]; - if (status != MTLCommandBufferStatusCompleted) { - fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status); - GGML_ASSERT(false); - } + MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffer status]; + if (status != MTLCommandBufferStatusCompleted) { + fprintf(stderr, "%s: command buffer failed with status %lu\n", __func__, status); + GGML_ASSERT(false); } } diff --git a/ggml.c b/ggml.c index 6055da867cb27..c0c847074bed2 100644 --- a/ggml.c +++ b/ggml.c @@ -3807,9 +3807,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS", "CROSS_ENTROPY_LOSS_BACK", + "BARRIER", }; -static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); +static_assert(GGML_OP_COUNT == 69, "GGML_OP_COUNT != 69"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3887,9 +3888,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss(x,y)", "cross_entropy_loss_back(x,y)", + "memory barrier", }; -static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); +static_assert(GGML_OP_COUNT == 69, "GGML_OP_COUNT != 69"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -15164,6 +15166,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { // nop } break; + case GGML_OP_BARRIER: + { + // nop + } break; case GGML_OP_COUNT: { GGML_ASSERT(false); @@ -15999,6 +16005,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // nop } break; + case GGML_OP_BARRIER: + { + // nop + } break; case GGML_OP_COUNT: { GGML_ASSERT(false); @@ -16077,6 +16087,66 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten } } +void ggml_graph_find_concurrency(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { + int search_depth = 40; //we only find concurrency in this range to avoiding waste to much time + struct ggml_tensor * nodes_bak[GGML_MAX_NODES]={NULL}; + struct ggml_tensor * barrier_node; + barrier_node = ggml_new_tensor_1d(ctx,GGML_TYPE_F32,0); + barrier_node->op=GGML_OP_BARRIER; + + for (int i=0; i < cgraph->n_nodes; i++) { + nodes_bak[i] = cgraph->nodes[i]; + cgraph->nodes[i] = NULL; + } + + int n_left = cgraph->n_nodes; + int n_start = 0; // all nodes before n_start at nodes_bak array have been sorted and store back to cgraph->nodes + int level_pos = 0; // at cgraph->nodes, the last layer (level) ends at level_pos + while (n_left > 0) { + // number of nodes at a layer (that can be issued concurrently) + int concurrency = 0; + for (int i = n_start; i < n_start + search_depth; i++) { + if (nodes_bak[i]) { + + // if the requirements for nodes_bak[i] are satisfied + int exe_flag=1; + // scan all srcs + for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) { + struct ggml_tensor * src_cur = nodes_bak[i]->src[src_ind]; + if (src_cur) { + // if is leaf nodes it's satisfied. + if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;} + // otherwise if this src is the output from previous nodes. + + int is_found = 0; + // scan 2*search_depth back because we insert barrier nodes. + for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) { + if (cgraph->nodes[j] == src_cur) {is_found = 1; break;} + } + if (is_found == 0) {exe_flag = 0; break;} + } + } + if (exe_flag) { + cgraph->nodes[level_pos + concurrency] = nodes_bak[i]; + nodes_bak[i] = NULL; + concurrency++; + } + } + } + n_left -= concurrency; + // adding a barrier between different layer + cgraph->nodes[level_pos + concurrency] = barrier_node; + cgraph->n_nodes++; + // jump all sorted nodes at nodes_bak + while (!nodes_bak[n_start]) {n_start++;} + level_pos += concurrency + 1; + } + //remove the last barrier after result_output + cgraph->nodes[cgraph->n_nodes-1] = NULL; + cgraph->n_nodes--; + +} + void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) { ggml_build_forward_impl(cgraph, tensor, true); } @@ -16721,6 +16791,10 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { { n_tasks = 1; } break; + case GGML_OP_BARRIER: + { + // nop + } break; case GGML_OP_COUNT: { GGML_ASSERT(false); diff --git a/ggml.h b/ggml.h index 5023b16528788..feea9ab0a81ba 100644 --- a/ggml.h +++ b/ggml.h @@ -194,7 +194,7 @@ #define GGML_QNT_VERSION_FACTOR 1000 // do not change this #define GGML_MAX_DIMS 4 -#define GGML_MAX_NODES 4096 +#define GGML_MAX_NODES 8192 #define GGML_MAX_PARAMS 256 #define GGML_MAX_CONTEXTS 64 #define GGML_MAX_SRC 6 @@ -387,6 +387,8 @@ extern "C" { GGML_OP_CROSS_ENTROPY_LOSS, GGML_OP_CROSS_ENTROPY_LOSS_BACK, + GGML_OP_BARRIER, // Any operation between two barriers can be issued concurrently. + GGML_OP_COUNT, }; @@ -1363,6 +1365,9 @@ extern "C" { GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname); GGML_API struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval); + //sort all nodes in a graph to find operations that can be issued concurrently, insert memory barrier if necessary + GGML_API void ggml_graph_find_concurrency(struct ggml_context * ctx, struct ggml_cgraph * cgraph); + // print info and performance information for the graph GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph); diff --git a/llama.cpp b/llama.cpp index 0a381afd5b726..cb04d4905bcc3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1662,6 +1662,7 @@ static bool llama_eval_internal( #ifdef GGML_USE_METAL if (lctx.ctx_metal && N == 1) { + ggml_graph_find_concurrency(ctx0,&gf); ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); ggml_metal_graph_compute(lctx.ctx_metal, &gf); ggml_metal_get_tensor (lctx.ctx_metal, cur);