Skip to content

metal: matrix-matrix multiplication kernel #2615

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 16, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -298,7 +298,6 @@ if (LLAMA_METAL)
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)

set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)

@@ -315,7 +314,6 @@ if (LLAMA_METAL)
${FOUNDATION_LIBRARY}
${METAL_FRAMEWORK}
${METALKIT_FRAMEWORK}
${METALPERFORMANCE_FRAMEWORK}
)
endif()

2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -283,7 +283,7 @@ endif # LLAMA_CLBLAST
ifdef LLAMA_METAL
CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
CXXFLAGS += -DGGML_USE_METAL
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
OBJS += ggml-metal.o
endif # LLAMA_METAL

2 changes: 0 additions & 2 deletions flake.nix
Original file line number Diff line number Diff line change
@@ -14,8 +14,6 @@
with pkgs.darwin.apple_sdk_11_0.frameworks; [
Accelerate
MetalKit
MetalPerformanceShaders
MetalPerformanceShadersGraph
]
else if isAarch32 && isDarwin then
with pkgs.darwin.apple_sdk.frameworks; [
171 changes: 55 additions & 116 deletions ggml-metal.m

Large diffs are not rendered by default.

969 changes: 471 additions & 498 deletions ggml-metal.metal

Large diffs are not rendered by default.

18 changes: 1 addition & 17 deletions llama.cpp
Original file line number Diff line number Diff line change
@@ -1830,7 +1830,7 @@ static bool llama_eval_internal(
#endif

#ifdef GGML_USE_METAL
if (lctx.ctx_metal && N == 1) {
if (lctx.ctx_metal) {
// TODO: disabled until #2413 is resolved
//if (!ggml_metal_if_optimized(lctx.ctx_metal)) {
// ggml_metal_graph_find_concurrency(lctx.ctx_metal, gf);
@@ -1842,22 +1842,6 @@ static bool llama_eval_internal(
ggml_metal_get_tensor(lctx.ctx_metal, embeddings);
}
} else {
// IMPORTANT:
// Since we don't have efficient Matrix x Matrix Metal multiplication yet, we fallback to vanilla
// ggml_graph_compute(). It uses Apple's Accelerate CBLAS API which takes advantage of the ANE or the AMX
// coprocessor.
//
// When we implement Matrix x Matrix Metal multiplication, we can avoid this branch.
// But for now, we have focused only on Matrix x Vector Metal multiplication.
//
// TODO: avoid these syncs via shared memory (ref #1696)
//
if (lctx.ctx_metal) {
// We need to sync the GPU KV cache with the CPU KV cache
ggml_metal_get_tensor(lctx.ctx_metal, kv_self.k);
ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v);
}

ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
}
#else