@@ -139,6 +139,20 @@ if (GGML_METAL)
139
139
)
140
140
endif ()
141
141
142
+ if (GGML_MUSA)
143
+ set (CMAKE_C_COMPILER clang)
144
+ set (CMAKE_C_EXTENSIONS OFF )
145
+ set (CMAKE_CXX_COMPILER clang++)
146
+ set (CMAKE_CXX_EXTENSIONS OFF )
147
+
148
+ set (GGML_CUDA ON )
149
+ set (GGML_OPENMP OFF )
150
+
151
+ list (APPEND GGML_CDEF_PUBLIC GGML_USE_MUSA)
152
+
153
+ add_compile_definitions (GGML_USE_MUSA)
154
+ endif ()
155
+
142
156
if (GGML_OPENMP)
143
157
find_package (OpenMP)
144
158
if (OpenMP_FOUND)
@@ -249,7 +263,13 @@ endif()
249
263
if (GGML_CUDA)
250
264
cmake_minimum_required (VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
251
265
252
- find_package (CUDAToolkit)
266
+ if (GGML_MUSA)
267
+ list (APPEND CMAKE_MODULE_PATH "/usr/local/musa/cmake/" )
268
+ find_package (MUSAToolkit)
269
+ set (CUDAToolkit_FOUND ${MUSAToolkit_FOUND} )
270
+ else ()
271
+ find_package (CUDAToolkit)
272
+ endif ()
253
273
254
274
if (CUDAToolkit_FOUND)
255
275
message (STATUS "CUDA found" )
@@ -268,7 +288,11 @@ if (GGML_CUDA)
268
288
endif ()
269
289
message (STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES} " )
270
290
271
- enable_language (CUDA)
291
+ if (GGML_MUSA)
292
+ set (CMAKE_CUDA_COMPILER ${MUSAToolkit_MCC_EXECUTABLE} )
293
+ else ()
294
+ enable_language (CUDA)
295
+ endif ()
272
296
273
297
file (GLOB GGML_HEADERS_CUDA "ggml-cuda/*.cuh" )
274
298
list (APPEND GGML_HEADERS_CUDA "../include/ggml-cuda.h" )
@@ -332,21 +356,40 @@ if (GGML_CUDA)
332
356
add_compile_definitions (GGML_CUDA_NO_PEER_COPY)
333
357
endif ()
334
358
359
+ if (GGML_MUSA)
360
+ set_source_files_properties (${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX)
361
+ foreach (SOURCE ${GGML_SOURCES_CUDA} )
362
+ set_property (SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22" )
363
+ endforeach ()
364
+ endif ()
365
+
335
366
if (GGML_STATIC)
336
367
if (WIN32 )
337
368
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
338
369
set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
339
370
else ()
340
- set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
371
+ if (GGML_MUSA)
372
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart_static MUSA::mublas_static)
373
+ else ()
374
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
375
+ endif ()
341
376
endif ()
342
377
else ()
343
- set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
378
+ if (GGML_MUSA)
379
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart MUSA::mublas)
380
+ else ()
381
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
382
+ endif ()
344
383
endif ()
345
384
346
385
if (GGML_CUDA_NO_VMM)
347
386
# No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
348
387
else ()
349
- set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
388
+ if (GGML_MUSA)
389
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musa_driver) # required by muDeviceGetAttribute(), muMemGetAllocationGranularity(...), ...
390
+ else ()
391
+ set (GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
392
+ endif ()
350
393
endif ()
351
394
else ()
352
395
message (WARNING "CUDA not found" )
@@ -757,8 +800,10 @@ function(get_flags CCID CCVER)
757
800
set (C_FLAGS -Wdouble-promotion)
758
801
set (CXX_FLAGS -Wno-array-bounds)
759
802
760
- if (CCVER VERSION_GREATER_EQUAL 7.1.0)
761
- list (APPEND CXX_FLAGS -Wno-format-truncation)
803
+ if (NOT GGML_MUSA)
804
+ if (CCVER VERSION_GREATER_EQUAL 7.1.0)
805
+ list (APPEND CXX_FLAGS -Wno-format-truncation)
806
+ endif ()
762
807
endif ()
763
808
if (CCVER VERSION_GREATER_EQUAL 8.1.0)
764
809
list (APPEND CXX_FLAGS -Wextra-semi)
@@ -1059,7 +1104,9 @@ if (GGML_CUDA)
1059
1104
list (JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument
1060
1105
1061
1106
if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL "" )
1062
- list (APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED} )
1107
+ # list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED})
1108
+ # XXX: Removed flags: -Xcompiler
1109
+ list (APPEND CUDA_FLAGS ${CUDA_CXX_FLAGS_JOINED} )
1063
1110
endif ()
1064
1111
1065
1112
add_compile_options ("$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS} >" )
0 commit comments