|
4 | 4 | #include "ggml-impl.h"
|
5 | 5 | #include "ggml-quants.h"
|
6 | 6 | #include "ggml.h"
|
| 7 | +#include "sgemm.h" |
7 | 8 |
|
8 | 9 | #if defined(_MSC_VER) || defined(__MINGW32__)
|
9 | 10 | #include <malloc.h> // using malloc.h with MSC/MINGW
|
@@ -10817,6 +10818,27 @@ static void ggml_compute_forward_mul_mat(
|
10817 | 10818 | }
|
10818 | 10819 | #endif
|
10819 | 10820 |
|
| 10821 | + if (src1_cont) { |
| 10822 | + for (int64_t j = 0; j < ne13; j++) |
| 10823 | + for (int64_t i = 0; i < ne12; i++) |
| 10824 | + if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), |
| 10825 | + (const char *)src0->data + i/r2*nb02 + j/r3*nb03, |
| 10826 | + nb01/ggml_type_size(src0->type), |
| 10827 | + (const char *)src1->data + i*nb12 + j*nb13, |
| 10828 | + nb11/ggml_type_size(src1->type), |
| 10829 | + (char *)dst->data + i*nb2 + j*nb3, |
| 10830 | + nb1/ggml_type_size(dst->type), |
| 10831 | + ith, nth, |
| 10832 | + params->type, |
| 10833 | + src0->type, |
| 10834 | + src1->type, |
| 10835 | + dst->type)) |
| 10836 | + goto UseGgmlGemm1; |
| 10837 | + return; |
| 10838 | + } |
| 10839 | +UseGgmlGemm1: |
| 10840 | + (void)0; |
| 10841 | + |
10820 | 10842 | if (params->type == GGML_TASK_TYPE_INIT) {
|
10821 | 10843 | if (ith != 0) {
|
10822 | 10844 | return;
|
@@ -10848,6 +10870,28 @@ static void ggml_compute_forward_mul_mat(
|
10848 | 10870 | const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
10849 | 10871 | const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
10850 | 10872 |
|
| 10873 | + if (src1_cont) { |
| 10874 | + for (int64_t j = 0; j < ne13; j++) |
| 10875 | + for (int64_t i = 0; i < ne12; i++) |
| 10876 | + if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), |
| 10877 | + (const char *)src0->data + i/r2*nb02 + j/r3*nb03, |
| 10878 | + nb01/ggml_type_size(src0->type), |
| 10879 | + (const char *)wdata + (nb12/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i + |
| 10880 | + nb13/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*j), |
| 10881 | + row_size/ggml_type_size(vec_dot_type), |
| 10882 | + (char *)dst->data + i*nb2 + j*nb3, |
| 10883 | + nb1/ggml_type_size(dst->type), |
| 10884 | + ith, nth, |
| 10885 | + params->type, |
| 10886 | + src0->type, |
| 10887 | + vec_dot_type, |
| 10888 | + dst->type)) |
| 10889 | + goto UseGgmlGemm2; |
| 10890 | + return; |
| 10891 | + } |
| 10892 | +UseGgmlGemm2: |
| 10893 | + (void)0; |
| 10894 | + |
10851 | 10895 | const int64_t nr0 = ne01; // src0 rows
|
10852 | 10896 | const int64_t nr1 = ne1*ne12*ne13; // src1 rows
|
10853 | 10897 |
|
|
0 commit comments