diff --git a/convert.py b/convert.py
index f3bf1798089cc..54dba5979cb38 100644
--- a/convert.py
+++ b/convert.py
@@ -142,6 +142,7 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
 @dataclass
 class Params:
     n_vocab:   int
+    n_vocab_base: int
     n_embd:    int
     n_mult:    int
     n_head:    int
@@ -169,6 +170,7 @@ def guessed(model: 'LazyModel') -> 'Params':
 
         return Params(
             n_vocab   = n_vocab,
+            n_vocab_base=n_vocab,
             n_embd    = n_embd,
             n_mult    = 256,
             n_head    = n_head,
@@ -191,6 +193,7 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
 
         return Params(
             n_vocab   = n_vocab,
+            n_vocab_base=n_vocab,
             n_embd    = n_embd,
             n_mult    = n_mult,
             n_head    = n_head,
@@ -215,6 +218,7 @@ def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
 
         return Params(
             n_vocab   = n_vocab,
+            n_vocab_base=n_vocab,
             n_embd    = n_embd,
             n_mult    = n_mult,
             n_head    = n_head,
@@ -239,7 +243,7 @@ def load(model_plus: 'ModelPlus') -> 'Params':
 
 
 class SentencePieceVocab:
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vocabtype: Optional[str]) -> None:
+    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], fname_special_tokens: Optional[Path], fname_tokenizer_config: Optional[Path], vocabtype: Optional[str]) -> None:
         self.vocabtype = vocabtype
         if self.vocabtype == "bpe":
           self.sentencepiece_tokenizer = json.loads(open(str(fname_tokenizer)).read())
@@ -264,35 +268,72 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vo
         self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
         self.fname_tokenizer = fname_tokenizer
         self.fname_added_tokens = fname_added_tokens
+        self.special_tokens_map: Dict[int, str] = {}
+
+        TOKEN_NAME_TO_ID: Dict[str, int] = {
+            "unk_token": self.sentencepiece_tokenizer.unk_id(),
+            "bos_token": self.sentencepiece_tokenizer.bos_id(),
+            "eos_token": self.sentencepiece_tokenizer.eos_id(),
+            "pad_token": self.sentencepiece_tokenizer.pad_id()
+        }
+
+        tokenizer_config: Dict[str, Any]
+        if fname_tokenizer_config is not None:
+            tokenizer_config = json.load(open(fname_tokenizer_config))
+        else:
+            tokenizer_config = {}
+        for key, value in tokenizer_config.items():
+            if not isinstance(value, dict) and not isinstance(value, str):
+                continue
+            token_id = TOKEN_NAME_TO_ID.get(key, -1)
+            if token_id == -1:
+                continue
+            self.special_tokens_map[token_id] = value["content"] if isinstance(value, dict) else value
+
+        special_tokens: Dict[str, Any]
+        if fname_special_tokens is not None:
+            special_tokens = json.load(open(fname_special_tokens))
+        else:
+            special_tokens = {}
+        for key, value in special_tokens.items():
+            if not isinstance(value, dict) and not isinstance(value, str):
+                continue
+            token_id = TOKEN_NAME_TO_ID.get(key, -1)
+            if token_id == -1 or token_id in self.special_tokens_map:
+                continue
+            self.special_tokens_map[token_id] = value["content"] if isinstance(value, dict) else value
 
     def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]:
         tokenizer = self.sentencepiece_tokenizer
         if self.vocabtype == "bpe":
-          from transformers.models.gpt2 import tokenization_gpt2
-          byte_encoder = tokenization_gpt2.bytes_to_unicode()
-          byte_decoder = {v: k for k, v in byte_encoder.items()}
-          for i, item in enumerate(tokenizer):
-            text: bytes
-            text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]])
-            score: float = -i
-            yield text, score
+            from transformers.models.gpt2 import tokenization_gpt2
+            byte_encoder = tokenization_gpt2.bytes_to_unicode()
+            byte_decoder = {v: k for k, v in byte_encoder.items()}
+            for i, item in enumerate(tokenizer):
+                text: bytes
+                text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]])
+                score: float = -i
+                yield text, score
         else:
-          for i in range(tokenizer.vocab_size()):
-              text: bytes
-              if tokenizer.is_unknown(i):
-                  text = " \u2047 ".encode("utf-8")
-              elif tokenizer.is_control(i):
-                  text = b""
-              elif tokenizer.is_byte(i):
-                  piece = tokenizer.id_to_piece(i)
-                  if len(piece) != 6:
-                      raise Exception(f"Invalid token: {piece}")
-                  byte_value = int(piece[3:-1], 16)
-                  text = struct.pack("B", byte_value)
-              else:
-                  text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
-              score: float = tokenizer.get_score(i)
-              yield text, score
+            special_tokens = [tokenizer.bos_id(), tokenizer.eos_id(), tokenizer.pad_id()]
+            for i in range(tokenizer.vocab_size()):
+                text: bytes
+                if tokenizer.is_unknown(i):
+                    text = self.special_tokens_map.get(i, " \u2047 ").encode("utf-8")
+                elif i in special_tokens:
+                    text = self.special_tokens_map.get(i, "").encode("utf-8")
+                elif tokenizer.is_control(i):
+                    text = b""
+                elif tokenizer.is_byte(i):
+                    piece = tokenizer.id_to_piece(i)
+                    if len(piece) != 6:
+                        raise Exception(f"Invalid token: {piece}")
+                    byte_value = int(piece[3:-1], 16)
+                    text = struct.pack("B", byte_value)
+                else:
+                    text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
+                score: float = tokenizer.get_score(i)
+                yield text, score
 
     def added_tokens(self) -> Iterable[Tuple[bytes, float]]:
         for text in self.added_tokens_list:
@@ -303,6 +344,12 @@ def all_tokens(self) -> Iterable[Tuple[bytes, float]]:
         yield from self.sentencepiece_tokens()
         yield from self.added_tokens()
 
+    def all_special_tokens(self) -> Iterable[int]:
+        for token_id in self.special_tokens_map.keys():
+            yield token_id
+        for i in range(len(self.added_tokens_list)):
+            yield self.vocab_size_base + i
+
     def __repr__(self) -> str:
         return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
@@ -310,11 +357,16 @@ def __repr__(self) -> str:
 class GGMLVocab:
     def __init__(self, tokens: List[Tuple[bytes, float]]):
         self.tokens = tokens
+        self.special_tokens = []
         self.vocab_size = len(tokens)
+        self.vocab_size_base = 0
 
     def all_tokens(self) -> Iterable[Tuple[bytes, float]]:
         return self.tokens
 
+    def all_special_tokens(self) -> Iterable[int]:
+        return self.special_tokens
+
     def __repr__(self) -> str:
         return f"<GGMLVocab with {self.vocab_size} tokens>"
 
@@ -1072,10 +1124,10 @@ def write_file_header(self, params: Params, file_type: GGMLFileType) -> None:
             params.n_mult,
             params.n_head,
             params.n_layer,
-            params.n_embd // params.n_head,  # rot (obsolete)
+            params.n_vocab_base | 0xF0000000, # reuse obsolete rot value to store vocab_base
             file_type.value,
         ]
-        self.fout.write(struct.pack("i" * len(values), *values))
+        self.fout.write(struct.pack("I" * len(values), *values))
 
     def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataType) -> None:
         sname = name.encode('utf-8')
@@ -1093,7 +1145,8 @@ def write_vocab(self, vocab: Vocab) -> None:
     @staticmethod
     def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
         of = OutputFile(fname_out)
-        params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
+        params = Params(n_vocab=vocab.vocab_size, n_vocab_base=vocab.vocab_size_base, n_embd=0, n_mult=0,
+                        n_head=1, n_layer=0)
         of = OutputFile(fname_out)
         of.write_file_header(params, file_type=GGMLFileType.AllF32)
         of.write_vocab(vocab)
@@ -1249,8 +1302,10 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab:
                 f"Could not find tokenizer.model in {path} or its parent; "
                 "if it's in another directory, pass the directory as --vocab-dir")
     added_tokens_path = path.parent / "added_tokens.json"
+    special_tokens_path = path.parent / "special_tokens_map.json"
+    tokenizer_config_path = path.parent / "tokenizer_config.json"
     print(f"Loading vocab file {path}")
-    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None,
+    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, special_tokens_path if special_tokens_path.exists() else None, tokenizer_config_path if tokenizer_config_path.exists() else None,
                               vocabtype)
 
 
@@ -1313,6 +1368,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
             vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
             vocab = load_vocab(vocab_dir, args.vocabtype)
         params = Params.load(model_plus)
+        params.n_vocab_base = vocab.vocab_size_base
         model = model_plus.model
         model = do_necessary_conversions(model, params)
         output_type = pick_output_type(model, args.outtype)
diff --git a/llama.cpp b/llama.cpp
index 39aefd499dd0c..44104be66d710 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -181,13 +181,13 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
 // default hparams (LLaMA 7B)
 struct llama_hparams {
     uint32_t n_vocab   = 32000;
+    uint32_t n_vocab_base = 32000;
     uint32_t n_ctx     = 512;   // this is provided as user input?
     uint32_t n_embd    = 4096;
     uint32_t n_mult    = 256;
     uint32_t n_head    = 32;
     uint32_t n_head_kv = 32;
     uint32_t n_layer   = 32;
-    uint32_t n_rot     = 64;
 
     // LLaMAv2
     // TODO: load from model data hparams
@@ -277,6 +277,12 @@ struct llama_vocab {
 
     std::unordered_map<token, id> token_to_id;
     std::vector<token_score> id_to_token;
+
+    std::unordered_map<token, id> special_token_to_id;
+
+    void add_special_token(const token & word, id token_id) {
+        special_token_to_id[word] = token_id;
+    }
 };
 
 struct llama_model {
@@ -509,6 +515,7 @@ struct llama_file_loader {
         read_hparams();
         read_vocab();
         read_tensor_metadata(tensors_map);
+        set_vocab_sp();
     }
     void read_magic() {
         uint32_t magic = file.read_u32();
@@ -543,7 +550,8 @@ struct llama_file_loader {
         hparams.n_mult  = file.read_u32();
         hparams.n_head  = file.read_u32();
         hparams.n_layer = file.read_u32();
-        hparams.n_rot   = file.read_u32();
+        hparams.n_vocab_base = file.read_u32();
+        hparams.n_vocab_base = (hparams.n_vocab_base & 0xF0000000) == 0 ? hparams.n_vocab : (hparams.n_vocab_base & ~0xF0000000); // this bitwise operation is necessary for compatibility with older models
         hparams.ftype   = (enum llama_ftype) file.read_u32();
 
         // LLaMAv2
@@ -612,6 +620,17 @@ struct llama_file_loader {
             tensors_map.name_to_idx[name] = tensors_map.tensors.size() - 1;
         }
     }
+    void set_vocab_sp() {
+        uint32_t vocab_sp = 3 + hparams.n_vocab - hparams.n_vocab_base;
+        vocab.special_token_to_id.reserve(vocab_sp);
+        for (uint32_t i = 0; i < vocab_sp; i++) {
+            llama_vocab::id token_id = i > 2 ? hparams.n_vocab_base + i : i;
+            const auto & word = vocab.id_to_token[token_id].tok;
+            if (!word.empty()) {
+                vocab.add_special_token(word, token_id);
+            }
+        }
+    }
 };
 
 struct llama_file_saver {
@@ -635,7 +654,7 @@ struct llama_file_saver {
         file.write_u32(hparams.n_mult);
         file.write_u32(hparams.n_head);
         file.write_u32(hparams.n_layer);
-        file.write_u32(hparams.n_rot);
+        file.write_u32(hparams.n_vocab_base | 0xF0000000); // this bitwise operation is necessary for compatibility with older models
         file.write_u32(new_ftype);
     }
     void write_vocab() {
@@ -1100,7 +1119,7 @@ static void llama_model_load_internal(
         fprintf(stderr, "%s: n_head     = %u\n",   __func__, hparams.n_head);
         fprintf(stderr, "%s: n_head_kv  = %u\n",   __func__, hparams.n_head_kv);
         fprintf(stderr, "%s: n_layer    = %u\n",   __func__, hparams.n_layer);
-        fprintf(stderr, "%s: n_rot      = %u\n",   __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
+        fprintf(stderr, "%s: n_rot      = %u\n",   __func__, hparams.n_embd/hparams.n_head); // a.k.a. n_embd_head, n_head_dim
         fprintf(stderr, "%s: n_gqa      = %u\n",   __func__, hparams.n_gqa());
         fprintf(stderr, "%s: rnorm_eps  = %.1e\n", __func__, hparams.f_rms_norm_eps);
         fprintf(stderr, "%s: n_ff       = %u\n",   __func__, n_ff);
@@ -1418,7 +1437,7 @@ static struct ggml_cgraph * llama_build_graph(
     const int64_t n_embd_head = hparams.n_embd_head();
     const int64_t n_embd_gqa  = hparams.n_embd_gqa();
 
-    LLAMA_ASSERT(n_embd_head == hparams.n_rot);
+    LLAMA_ASSERT(n_embd_head == hparams.n_embd/hparams.n_head);
 
     const float freq_base  = hparams.rope_freq_base;
     const float freq_scale = hparams.rope_freq_scale;
@@ -1960,18 +1979,20 @@ struct llama_sp_bigram {
 struct llama_tokenizer {
     llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+    void tokenize(const char * text, size_t len, std::vector<llama_vocab::id> & output) {
+        symbols_.clear();
+
         // split string into utf8 chars
         int index = 0;
         size_t offs = 0;
-        while (offs < text.size()) {
+        while (offs < len) {
             llama_sp_symbol sym;
-            size_t char_len = std::min(text.size() - offs, utf8_len(text[offs]));
-            sym.text = text.c_str() + offs;
+            size_t char_len = std::min(len - offs, utf8_len(text[offs]));
+            sym.text = text + offs;
             sym.n = char_len;
             offs += char_len;
             sym.prev = index - 1;
-            sym.next = offs == text.size() ? -1 : index + 1;
+            sym.next = offs == len ? -1 : index + 1;
             index++;
             symbols_.emplace_back(sym);
         }
@@ -2074,7 +2095,45 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
         output.push_back(llama_token_bos());
     }
 
-    tokenizer.tokenize(text, output);
+    if (vocab.special_token_to_id.empty()) {
+        tokenizer.tokenize(text.c_str(), text.size(), output);
+        return output;
+    }
+
+    size_t delim_start = 0;
+    size_t last_delim_end = 0;
+
+    while (delim_start < text.size()) {
+        size_t delim_end = 0;
+        llama_vocab::id token_id = -1;
+
+        for (const auto & mit : vocab.special_token_to_id) {
+            const std::string & delimiter = mit.first;
+            size_t end = delim_start + delimiter.size();
+            if (end <= text.size() && text.compare(delim_start, delimiter.size(), delimiter) == 0) {
+                if (token_id == -1 || end > delim_end) {
+                    token_id = mit.second;
+                    delim_end = end;
+                }
+            }
+        }
+
+        if (token_id != -1) {
+            if (last_delim_end < delim_start) {
+                tokenizer.tokenize(text.c_str() + last_delim_end, delim_start - last_delim_end, output);
+            }
+            output.push_back(token_id);
+            delim_start = delim_end;
+            last_delim_end = delim_end;
+        } else {
+            delim_start++;
+        }
+    }
+
+    if (last_delim_end < text.size()) {
+        tokenizer.tokenize(text.c_str() + last_delim_end, text.size() - last_delim_end, output);
+    }
+
     return output;
 }
 
@@ -4212,6 +4271,10 @@ llama_token llama_token_nl() {
     return 13;
 }
 
+void llama_add_special_token(struct llama_model * model, const char * token, llama_token token_id) {
+    model->vocab.add_special_token(token, token_id);
+}
+
 struct llama_timings llama_get_timings(struct llama_context * ctx) {
     struct llama_timings result = {
         /*.t_start_ms  =*/ 1e-3 * ctx->t_start_us,
diff --git a/llama.h b/llama.h
index fa1977f2d9492..519ee716d0e63 100644
--- a/llama.h
+++ b/llama.h
@@ -373,6 +373,11 @@ extern "C" {
     LLAMA_API llama_token llama_token_eos();  // end-of-sentence
     LLAMA_API llama_token llama_token_nl();   // next-line
 
+    LLAMA_API void llama_add_special_token(
+              struct llama_model * model,
+                      const char * token,
+                      llama_token token_id);
+
     // Grammar
     //
     LLAMA_API struct llama_grammar * llama_grammar_init(
diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp
index 87fde16453d25..3472180343c24 100644
--- a/tests/test-tokenizer-0.cpp
+++ b/tests/test-tokenizer-0.cpp
@@ -14,6 +14,9 @@ static const std::map<std::string, std::vector<llama_token>> & k_tests()
         { " this is 🦙.cpp",    { 1,    445,    338,  29871,    243,    162,    169,    156,  29889,   8223, }, },
         { "w048 7tuijk dsdfhu", { 1,  29893,  29900,  29946,  29947,  29871,  29955,   9161,  13535,  18031,   2176,   6905, }, },
         { "нещо на Български",  { 1,    821,   4851,    665,   1386,  29713,   1305, }, },
+        { "<🦙>test extra_id_1   test",  { 1, 32004,  1688,  29871,  32001,    259,   1243, }, },
+        { "<🦙>test extra_id_100 test",  { 1, 32004,  1688,  29871,  32002,   1243, }, },
+        { "<🦙>test extra_id_200 test",  { 1, 32004,  1688,  321,    32003,   1243, }, },
     };
     return _k_tests;
 };
@@ -46,6 +49,11 @@ int main(int argc, char **argv) {
             return 1;
         }
 
+        llama_add_special_token(model, "extra_id_1", 32001);
+        llama_add_special_token(model, "extra_id_100", 32002);
+        llama_add_special_token(model, "xtra_id_200", 32003);
+        llama_add_special_token(model, "<🦙>", 32004);
+
         ctx = llama_new_context_with_model(model, lparams);
 
         if (ctx == NULL) {