9 | 9 | #include <map>
10 | 10 | #include <string>
11 | 11 | #include <vector>
| 12 | +#include <regex> |
| 13 | + |
| 14 | +// TODO: move somewhere else |
| 15 | +#define QK 32 |
| 16 | + |
12 | 17 |
13 | 18 | // determine number of model parts based on the dimension
14 | 19 | static const std::map<int, int> LLAMA_N_PARTS = {
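The new `QK` constant is the quantization block size: tensor values are grouped into blocks of 32, and each block shares a single scale factor. As a rough sketch only (the authoritative block layout is defined in ggml.c and is not part of this diff; the field names below are hypothetical), a Q4_0 block for QK = 32 can be pictured as:

```cpp
// Illustration only: assumed shape of a Q4_0 quantization block for QK = 32.
// The real definition lives in ggml.c; names and layout here are a sketch.
#include <cstdint>

constexpr int kQK = 32;

struct block_q4_0_sketch {
    float   d;            // per-block scale factor
    uint8_t qs[kQK / 2];  // 32 quantized values, 4 bits each, two per byte
};
// 4 bytes of scale + 16 bytes of quants = 20 bytes per 32 weights,
// i.e. roughly 5 bits per weight instead of 16 (f16) or 32 (f32).
```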
@@ -681,3 +686,258 @@ bool llama_eval(
681 | 686 |
682 | 687 | return true;
683 | 688 | }
| 689 | +bool llama_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype) { |
| 690 | + ggml_type type = GGML_TYPE_Q4_1; |
| 691 | + |
| 692 | + switch (itype) { |
| 693 | + case 2: type = GGML_TYPE_Q4_0; break; |
| 694 | + case 3: type = GGML_TYPE_Q4_1; break; |
| 695 | + default: fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); return false; |
| 696 | + } |
| 697 | + |
| 698 | + if (type != GGML_TYPE_Q4_0 && type != GGML_TYPE_Q4_1) { |
| 699 | + fprintf(stderr, "%s: invalid quantization type %d\n", __func__, type); |
| 700 | + return false; |
| 701 | + } |
| 702 | + |
| 703 | + gpt_vocab vocab; |
| 704 | + |
| 705 | + printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str()); |
| 706 | + |
| 707 | + auto finp = std::ifstream(fname_inp, std::ios::binary); |
| 708 | + if (!finp) { |
| 709 | + fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str()); |
| 710 | + return false; |
| 711 | + } |
| 712 | + |
| 713 | + auto fout = std::ofstream(fname_out, std::ios::binary); |
| 714 | + if (!fout) { |
| 715 | + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str()); |
| 716 | + return false; |
| 717 | + } |
| 718 | + |
| 719 | + // verify magic |
| 720 | + { |
| 721 | + uint32_t magic; |
| 722 | + finp.read((char *) &magic, sizeof(magic)); |
| 723 | + if (magic != 0x67676d6c) { |
| 724 | + fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str()); |
| 725 | + return false; |
| 726 | + } |
| 727 | + |
| 728 | + fout.write((char *) &magic, sizeof(magic)); |
| 729 | + } |
| 730 | + |
| 731 | + llama_hparams hparams; |
| 732 | + |
| 733 | + // load hparams |
| 734 | + { |
| 735 | + finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); |
| 736 | + //finp.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx)); |
| 737 | + finp.read((char *) &hparams.n_embd, sizeof(hparams.n_embd)); |
| 738 | + finp.read((char *) &hparams.n_mult, sizeof(hparams.n_mult)); |
| 739 | + finp.read((char *) &hparams.n_head, sizeof(hparams.n_head)); |
| 740 | + finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer)); |
| 741 | + finp.read((char *) &hparams.n_rot, sizeof(hparams.n_rot)); |
| 742 | + finp.read((char *) &hparams.f16, sizeof(hparams.f16)); |
| 743 | + |
| 744 | + printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab); |
| 745 | + printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); |
| 746 | + printf("%s: n_embd = %d\n", __func__, hparams.n_embd); |
| 747 | + printf("%s: n_mult = %d\n", __func__, hparams.n_mult); |
| 748 | + printf("%s: n_head = %d\n", __func__, hparams.n_head); |
| 749 | + printf("%s: n_layer = %d\n", __func__, hparams.n_layer); |
| 750 | + printf("%s: f16 = %d\n", __func__, hparams.f16); |
| 751 | + |
| 752 | + fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); |
| 753 | + //fout.write((char *) &hparams.n_ctx, sizeof(hparams.n_ctx)); |
| 754 | + fout.write((char *) &hparams.n_embd, sizeof(hparams.n_embd)); |
| 755 | + fout.write((char *) &hparams.n_mult, sizeof(hparams.n_mult)); |
| 756 | + fout.write((char *) &hparams.n_head, sizeof(hparams.n_head)); |
| 757 | + fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer)); |
| 758 | + fout.write((char *) &hparams.n_rot, sizeof(hparams.n_rot)); |
| 759 | + fout.write((char *) &itype, sizeof(hparams.f16)); |
| 760 | + } |
| 761 | + |
| 762 | + // load vocab |
| 763 | + { |
| 764 | + const int32_t n_vocab = hparams.n_vocab; |
| 765 | + |
| 766 | + if (n_vocab != hparams.n_vocab) { |
| 767 | + fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n", |
| 768 | + __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab); |
| 769 | + return false; |
| 770 | + } |
| 771 | + |
| 772 | + std::string word; |
| 773 | + for (int i = 0; i < n_vocab; i++) { |
| 774 | + uint32_t len; |
| 775 | + finp.read ((char *) &len, sizeof(len)); |
| 776 | + fout.write((char *) &len, sizeof(len)); |
| 777 | + |
| 778 | + word.resize(len); |
| 779 | + finp.read ((char *) word.data(), len); |
| 780 | + fout.write((char *) word.data(), len); |
| 781 | + |
| 782 | + vocab.token_to_id[word] = i; |
| 783 | + vocab.id_to_token[i] = word; |
| 784 | + } |
| 785 | + } |
| 786 | + |
| 787 | + // load weights |
| 788 | + { |
| 789 | + size_t total_size_org = 0; |
| 790 | + size_t total_size_new = 0; |
| 791 | + |
| 792 | + std::vector<float> work; |
| 793 | + |
| 794 | + std::vector<uint8_t> data_u8; |
| 795 | + std::vector<ggml_fp16_t> data_f16; |
| 796 | + std::vector<float> data_f32; |
| 797 | + |
| 798 | + std::vector<int64_t> hist_all(1 << 4, 0); |
| 799 | + |
| 800 | + while (true) { |
| 801 | + int32_t n_dims; |
| 802 | + int32_t length; |
| 803 | + int32_t ftype; |
| 804 | + |
| 805 | + finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims)); |
| 806 | + finp.read(reinterpret_cast<char *>(&length), sizeof(length)); |
| 807 | + finp.read(reinterpret_cast<char *>(&ftype), sizeof(ftype)); |
| 808 | + |
| 809 | + if (finp.eof()) { |
| 810 | + break; |
| 811 | + } |
| 812 | + |
| 813 | + int32_t nelements = 1; |
| 814 | + int32_t ne[2] = { 1, 1 }; |
| 815 | + for (int i = 0; i < n_dims; ++i) { |
| 816 | + finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i])); |
| 817 | + nelements *= ne[i]; |
| 818 | + } |
| 819 | + |
| 820 | + std::string name(length, 0); |
| 821 | + finp.read (&name[0], length); |
| 822 | + |
| 823 | + { |
| 824 | + static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", }; |
| 825 | + printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]); |
| 826 | + } |
| 827 | + |
| 828 | + // regexes of tensor names to be quantized |
| 829 | + const std::vector<std::string> k_names = { |
| 830 | + ".*weight", |
| 831 | + }; |
| 832 | + |
| 833 | + bool quantize = false; |
| 834 | + for (const auto & s : k_names) { |
| 835 | + if (std::regex_match(name, std::regex(s))) { |
| 836 | + quantize = true; |
| 837 | + break; |
| 838 | + } |
| 839 | + } |
| 840 | + |
| 841 | + // quantize only 2D tensors |
| 842 | + quantize &= (n_dims == 2); |
| 843 | + |
| 844 | + if (quantize) { |
| 845 | + if (ftype != 0 && ftype != 1) { |
| 846 | + fprintf(stderr, "%s: unsupported ftype %d for integer quantization\n", __func__, ftype); |
| 847 | + return false; |
| 848 | + } |
| 849 | + |
| 850 | + if (ftype == 1) { |
| 851 | + data_f16.resize(nelements); |
| 852 | + finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t)); |
| 853 | + data_f32.resize(nelements); |
| 854 | + for (int i = 0; i < nelements; ++i) { |
| 855 | + data_f32[i] = ggml_fp16_to_fp32(data_f16[i]); |
| 856 | + } |
| 857 | + } else { |
| 858 | + data_f32.resize(nelements); |
| 859 | + finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float)); |
| 860 | + } |
| 861 | + |
| 862 | + ftype = itype; |
| 863 | + } else { |
| 864 | + const int bpe = (ftype == 0) ? sizeof(float) : sizeof(uint16_t); |
| 865 | + |
| 866 | + data_u8.resize(nelements*bpe); |
| 867 | + finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe); |
| 868 | + } |
| 869 | + |
| 870 | + fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims)); |
| 871 | + fout.write(reinterpret_cast<char *>(&length), sizeof(length)); |
| 872 | + fout.write(reinterpret_cast<char *>(&ftype), sizeof(ftype)); |
| 873 | + for (int i = 0; i < n_dims; ++i) { |
| 874 | + fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i])); |
| 875 | + } |
| 876 | + fout.write(&name[0], length); |
| 877 | + |
| 878 | + if (quantize) { |
| 879 | + printf("quantizing .. "); |
| 880 | + work.resize(nelements); // for quantization |
| 881 | + |
| 882 | + size_t cur_size = 0; |
| 883 | + std::vector<int64_t> hist_cur(1 << 4, 0); |
| 884 | + |
| 885 | + switch (type) { |
| 886 | + case GGML_TYPE_Q4_0: |
| 887 | + { |
| 888 | + cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], QK, hist_cur.data()); |
| 889 | + } break; |
| 890 | + case GGML_TYPE_Q4_1: |
| 891 | + { |
| 892 | + cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], QK, hist_cur.data()); |
| 893 | + } break; |
| 894 | + default: |
| 895 | + { |
| 896 | + fprintf(stderr, "%s: unsupported quantization type %d\n", __func__, type); |
| 897 | + return false; |
| 898 | + } |
| 899 | + } |
| 900 | + |
| 901 | + fout.write(reinterpret_cast<char *>(work.data()), cur_size); |
| 902 | + total_size_new += cur_size; |
| 903 | + |
| 904 | + printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0); |
| 905 | + for (int i = 0; i < hist_cur.size(); ++i) { |
| 906 | + hist_all[i] += hist_cur[i]; |
| 907 | + } |
| 908 | + |
| 909 | + for (int i = 0; i < hist_cur.size(); ++i) { |
| 910 | + printf("%5.3f ", hist_cur[i] / (float)nelements); |
| 911 | + } |
| 912 | + printf("\n"); |
| 913 | + } else { |
| 914 | + printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0); |
| 915 | + fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size()); |
| 916 | + total_size_new += data_u8.size(); |
| 917 | + } |
| 918 | + |
| 919 | + total_size_org += nelements * sizeof(float); |
| 920 | + } |
| 921 | + |
| 922 | + printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); |
| 923 | + printf("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0); |
| 924 | + |
| 925 | + { |
| 926 | + int64_t sum_all = 0; |
| 927 | + for (int i = 0; i < hist_all.size(); ++i) { |
| 928 | + sum_all += hist_all[i]; |
| 929 | + } |
| 930 | + |
| 931 | + printf("%s: hist: ", __func__); |
| 932 | + for (int i = 0; i < hist_all.size(); ++i) { |
| 933 | + printf("%5.3f ", hist_all[i] / (float)sum_all); |
| 934 | + } |
| 935 | + printf("\n"); |
| 936 | + } |
| 937 | + } |
| 938 | + |
| 939 | + finp.close(); |
| 940 | + fout.close(); |
| 941 | + |
| 942 | + return true; |
| 943 | +} |
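Taken together, `llama_model_quantize` streams the f16/f32 model file tensor by tensor, re-encodes the 2D weight tensors as Q4_0 or Q4_1, and copies everything else through unchanged. The output file keeps the same magic, hparams, and vocab layout; only the `f16` field is overwritten with `itype`, presumably so the loader can tell the tensors are quantized. The `hist_cur`/`hist_all` counters tally how often each of the 16 possible 4-bit values occurs, and the printed fractions are those counts divided by the number of elements. A minimal command-line driver might look like the sketch below; it is hypothetical, and only the itype mapping (2 = Q4_0, 3 = Q4_1) is taken from the switch at the top of the function:

```cpp
// Hypothetical driver for llama_model_quantize; not part of this diff.
#include <cstdio>
#include <cstdlib>
#include <string>

// declared in the file changed by this diff
bool llama_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype);

int main(int argc, char ** argv) {
    if (argc != 4) {
        fprintf(stderr, "usage: %s model-f16.bin model-quant.bin type\n", argv[0]);
        fprintf(stderr, "  type = 2 -> Q4_0, type = 3 -> Q4_1\n");
        return 1;
    }

    const std::string fname_inp = argv[1];
    const std::string fname_out = argv[2];
    const int itype = atoi(argv[3]);

    if (!llama_model_quantize(fname_inp, fname_out, itype)) {
        fprintf(stderr, "failed to quantize model from '%s'\n", fname_inp.c_str());
        return 1;
    }

    return 0;
}
```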