
Feature Request: Nemotron-4-340B-Instruct Support #7966


Closed
rankaiyx opened this issue Jun 17, 2024 · 13 comments
Labels: enhancement (New feature or request), stale

Comments

@rankaiyx (Contributor) commented Jun 17, 2024

Prerequisites

  • I am running the latest code. Mention the version if possible as well.
  • I carefully followed the README.md.
  • I searched using keywords relevant to my issue to make sure that I am creating a new issue that is not already open (or closed).
  • I reviewed the Discussions, and have a new and useful enhancement to share.

Feature Description

A super-huge new model from Nvidia
https://huggingface.co/nvidia/Nemotron-4-340B-Instruct

Nemotron-4-340B-Instruct is a large language model (LLM) that can be used as part of a synthetic data generation pipeline to create training data that helps researchers and developers build their own LLMs. It is a fine-tuned version of the Nemotron-4-340B-Base model, optimized for English-based single and multi-turn chat use-cases. It supports a context length of 4,096 tokens.

Motivation

Because the mountain was there.
But it may have little practical value, given its performance-to-price ratio.

Possible Implementation

No response

rankaiyx added the enhancement label Jun 17, 2024

@Yorizuka commented Jun 17, 2024

Well, even if it's not something most people can run at home, it would still be really useful for people who can deploy it. Big GPUs can be rented in the cloud. This model feels to me like it's going to be a game changer!

llama.cpp is simply the least headache-inducing way of running any LLM. Renting hardware for this model is going to be expensive, and not having to fiddle with jank is nice. I also wonder how well the AMD MI300X would perform.

@rankaiyx (Contributor, Author)

Given that a Q4-quantized 34B model requires about 20 GB of RAM, a Q4-quantized 340B model should be able to run on a computer with 256 GB of RAM.
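
For a rough sanity check of that estimate, here's a back-of-the-envelope calculation in Python (assuming roughly 4.5 bits per weight for Q4_0-style quantization; the exact overhead figures are assumptions):

```python
# Back-of-the-envelope RAM estimate for a Q4-quantized 340B model.
# Assumption: ~4.5 bits per weight, as in Q4_0 (18 bytes per block of 32 weights).
params = 340e9
bits_per_weight = 4.5
weights_gib = params * bits_per_weight / 8 / 2**30
print(f"quantized weights: ~{weights_gib:.0f} GiB")  # ~178 GiB
# Add a few GiB for the KV cache and compute buffers; the total should still
# fit comfortably in 256 GiB of system RAM.
```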

@fairydreaming (Collaborator)

I started working on this a few days ago and so far it's going well. I'll post the code on a GitHub branch after cleaning it up a bit.
https://www.youtube.com/watch?v=TX0eppc88TU

@fairydreaming (Collaborator)

My code for brave souls:

  1. Model conversion script based on @FailSpy's earlier work: https://github.com/fairydreaming/export-nemo-to-safetensors
  2. llama.cpp Nemotron 4 branch: https://github.com/fairydreaming/llama.cpp/tree/nemotron

There is a new tokenizer in the code (SentencePiece BPE modified to handle user tokens). It's done this way because when I ran the original model in the NeMo framework, it passed the whole prompt to the SentencePiece tokenizer as a single string without any special-token preprocessing, so I did the same (parse_special is currently hardcoded to false). I wonder if it's possible to do this in a simpler way without adding a new tokenizer; I need to research this a bit more.
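
For illustration, a minimal sketch of the tokenization behaviour described above, using the sentencepiece Python package (the tokenizer filename and prompt contents are assumptions):

```python
import sentencepiece as spm

# Load the Nemotron SentencePiece model (filename is an assumption).
sp = spm.SentencePieceProcessor(model_file="nemotron_tokenizer.model")

# The whole chat prompt, including the <extra_id_*> turn markers, is encoded as
# one plain string -- no special-token preprocessing (parse_special = false).
prompt = "<extra_id_0>System\n\n<extra_id_1>User\nWhat is 2 + 2?\n<extra_id_1>Assistant\n"
ids = sp.encode(prompt, out_type=int)
print(len(ids), ids[:10])
```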

./llama-cli --temp 0.01 --numa distribute -t 32 -m /mnt/md0/models/nemotron-4-340b-Q8_0.gguf -f ../prompts/prompt-nemotron-5.txt
Log start
main: build = 3375 (a815e06c)
main: built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
main: seed  = 1720905020
llama_model_loader: loaded meta data with 27 key-value pairs and 772 tensors from /mnt/md0/models/nemotron-4-340b-Q8_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = nemotron4
llama_model_loader: - kv   1:                               general.name str              = output_dir
llama_model_loader: - kv   2:                      nemotron4.block_count u32              = 96
llama_model_loader: - kv   3:                   nemotron4.context_length u32              = 4096
llama_model_loader: - kv   4:                 nemotron4.embedding_length u32              = 18432
llama_model_loader: - kv   5:              nemotron4.feed_forward_length u32              = 73728
llama_model_loader: - kv   6:             nemotron4.rope.dimension_count u32              = 96
llama_model_loader: - kv   7:             nemotron4.attention.head_count u32              = 96
llama_model_loader: - kv   8:          nemotron4.attention.head_count_kv u32              = 8
llama_model_loader: - kv   9:                   nemotron4.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  10:     nemotron4.attention.layer_norm_epsilon f32              = 0.000010
llama_model_loader: - kv  11:                          general.file_type u32              = 7
llama_model_loader: - kv  12:                       tokenizer.ggml.model str              = nemotron
llama_model_loader: - kv  13:                         tokenizer.ggml.pre str              = default
llama_model_loader: - kv  14:                      tokenizer.ggml.tokens arr[str,256000]  = ["<pad>", "<unk>", "<s>", "</s>", "<e...
llama_model_loader: - kv  15:                      tokenizer.ggml.scores arr[f32,256000]  = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  16:                  tokenizer.ggml.token_type arr[i32,256000]  = [3, 2, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, ...
llama_model_loader: - kv  17:            tokenizer.ggml.add_space_prefix bool             = true
llama_model_loader: - kv  18:    tokenizer.ggml.remove_extra_whitespaces bool             = false
llama_model_loader: - kv  19:                tokenizer.ggml.bos_token_id u32              = 2
llama_model_loader: - kv  20:                tokenizer.ggml.eos_token_id u32              = 3
llama_model_loader: - kv  21:            tokenizer.ggml.unknown_token_id u32              = 1
llama_model_loader: - kv  22:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  23:                tokenizer.ggml.eot_token_id u32              = 5
llama_model_loader: - kv  24:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  25:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  26:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:  386 tensors
llama_model_loader: - type q8_0:  386 tensors
llm_load_vocab: special tokens cache size = 1260
llm_load_vocab: token to piece cache size = 1.7854 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = nemotron4
llm_load_print_meta: vocab type       = NTN
llm_load_print_meta: n_vocab          = 256000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 4096
llm_load_print_meta: n_embd           = 18432
llm_load_print_meta: n_layer          = 96
llm_load_print_meta: n_head           = 96
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_rot            = 96
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 192
llm_load_print_meta: n_embd_head_v    = 192
llm_load_print_meta: n_gqa            = 12
llm_load_print_meta: n_embd_k_gqa     = 1536
llm_load_print_meta: n_embd_v_gqa     = 1536
llm_load_print_meta: f_norm_eps       = 1.0e-05
llm_load_print_meta: f_norm_rms_eps   = 0.0e+00
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 73728
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 2
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 4096
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: model type       = ?B
llm_load_print_meta: model ftype      = Q8_0
llm_load_print_meta: model params     = 341.03 B
llm_load_print_meta: model size       = 337.48 GiB (8.50 BPW) 
llm_load_print_meta: general.name     = output_dir
llm_load_print_meta: BOS token        = 2 '<s>'
llm_load_print_meta: EOS token        = 3 '</s>'
llm_load_print_meta: UNK token        = 1 '<unk>'
llm_load_print_meta: PAD token        = 0 '<pad>'
llm_load_print_meta: LF token         = 1014 '<0x0A>'
llm_load_print_meta: EOT token        = 5 '<extra_id_1>'
llm_load_print_meta: max token length = 48
llm_load_tensors: ggml ctx size =    0.37 MiB
llm_load_tensors:        CPU buffer size = 345577.64 MiB
....................................................................................................
llama_new_context_with_model: n_ctx      = 4096
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:        CPU KV buffer size =  2304.00 MiB
llama_new_context_with_model: KV self size  = 2304.00 MiB, K (f16): 1152.00 MiB, V (f16): 1152.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.98 MiB
ggml_gallocr_reserve_n: reallocating CPU buffer from size 0.00 MiB to 926.01 MiB
llama_new_context_with_model:        CPU compute buffer size =   926.01 MiB
llama_new_context_with_model: graph nodes  = 3559
llama_new_context_with_model: graph splits = 1

system_info: n_threads = 32 / 64 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 1 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
sampling: 
	repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
	top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.010
	mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 4096, n_batch = 2048, n_predict = -1, n_keep = 1


 <extra_id_0>System

<extra_id_1>User
What is 2 + 2?
<extra_id_1>Assistant
The answer is 4.
<extra_id_1> [end of text]

llama_print_timings:        load time =    1971.72 ms
llama_print_timings:      sample time =       0.72 ms /     8 runs   (    0.09 ms per token, 11142.06 tokens per second)
llama_print_timings: prompt eval time =    5752.64 ms /    21 tokens (  273.94 ms per token,     3.65 tokens per second)
llama_print_timings:        eval time =    7444.00 ms /     7 runs   ( 1063.43 ms per token,     0.94 tokens per second)
llama_print_timings:       total time =   13205.42 ms /    28 tokens
Log end
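
For reference, the turn markers visible in the generation above suggest the Nemotron-4 chat format; a minimal sketch of building such a prompt file (the exact whitespace and the contents of prompt-nemotron-5.txt are assumptions):

```python
def build_nemotron_prompt(system: str, user: str) -> str:
    # <extra_id_0> opens the system turn, <extra_id_1> opens user/assistant turns;
    # the model emits <extra_id_1> again (its EOT token, id 5) when it's done.
    return (
        f"<extra_id_0>System\n{system}\n"
        f"<extra_id_1>User\n{user}\n"
        f"<extra_id_1>Assistant\n"
    )

with open("prompt-nemotron.txt", "w") as f:
    f.write(build_nemotron_prompt("", "What is 2 + 2?"))
```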

@leafspark

  1. Model conversion script based on @FailSpy's earlier work: https://github.com/fairydreaming/export-nemo-to-safetensors

Getting an error when trying to run the NeMo > safetensors conversion script:

F:\LocalLLMs\nvidia>python convert-nemo.py Nemotron-4-340B-Instruct Nemotron-4-safetensors
Reading model_config.yaml
Reading tokenizer model from Nemotron-4-340B-Instruct\eb5528fdec5c4083affa2c97958eeef7_megatron_2.model
Writing tokenizer model
Writing added_tokens.json
Creating config.json
Writing config.json
Exporting 96 layers
converting embedding.word_embeddings.weight of shape (1, 256000, 18432) to embedding.word_embeddings.weight
Traceback (most recent call last):
  File "F:\LocalLLMs\nvidia\convert-nemo.py", line 205, in <module>
    convert_nemo_model(args.model_dir, args.output_dir)
  File "F:\LocalLLMs\nvidia\convert-nemo.py", line 178, in convert_nemo_model
    save_file(sharded_state_dict,output_dir/fname)
  File "C:\Users\chenx\miniconda3\Lib\site-packages\safetensors\torch.py", line 284, in save_file
    serialize_file(_flatten(tensors), filename, metadata=metadata)
safetensors_rust.SafetensorError: Error preparing tensor view: InvalidTensorView(BF16, [1, 256000, 18432], 847249408)

@fairydreaming (Collaborator)

@leafspark I have no idea what's wrong; maybe try installing the exact versions of the packages that I used:
convert-nemo-conda-pkgs.txt
I installed the latest versions of the packages from conda-forge.

@leafspark

Ended up fixing it by bypassing the model.embedding.word_embeddings.weight tensor and handling it separately; thanks!
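
One possible shape for that kind of workaround, assuming the embedding tensor just needs to be copied into its own contiguous storage before safetensors will serialize it (the actual fix used is an assumption):

```python
from safetensors.torch import save_file

def save_shard(state_dict: dict, path: str) -> None:
    # safetensors refuses tensors whose storage size doesn't match their shape
    # (the InvalidTensorView error above), so clone the problematic embedding
    # tensor into its own contiguous buffer before serializing the shard.
    fixed = {}
    for name, tensor in state_dict.items():
        if name == "model.embedding.word_embeddings.weight":
            tensor = tensor.contiguous().clone()
        fixed[name] = tensor
    save_file(fixed, path)
```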

@fairydreaming (Collaborator) commented Jul 14, 2024

@leafspark But this 847249408 number looks worrying (it's the length of the tensor data buffer); make sure that your model is fully downloaded. This tensor should have a buffer size of 9437184000 bytes (256000 × 18432 BF16 values at 2 bytes each; there are 8 files in the model.embedding.word_embeddings.weight directory, each 1179648000 bytes).

@leafspark commented Jul 15, 2024

I verified the sha256 of all the files; they matched, but unfortunately I was unable to find the issue (I assume the Windows build of safetensors has a problem). For anyone else hitting the same issue: I used FailSpy's original safetensors and wrote a script to rename the tensors.
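
A minimal sketch of that kind of rename pass over safetensors shards (the directory name and the rename rule below are hypothetical; the actual script was not posted):

```python
from pathlib import Path
from safetensors.torch import load_file, save_file

def rename(key: str) -> str:
    # Hypothetical rule mapping FailSpy's tensor names to the ones the
    # llama.cpp conversion branch expects; adjust to the real naming scheme.
    return key.replace("model.", "", 1)

for shard in sorted(Path("Nemotron-4-safetensors").glob("*.safetensors")):
    tensors = load_file(shard)  # load all tensors in the shard
    save_file({rename(k): v for k, v in tensors.items()}, shard)
```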

@HanClinto (Collaborator)

How much overlap is there between Nemotron and Mistral NeMo?

https://mistral.ai/news/mistral-nemo/

The Mistral blog post says that the model was developed in conjunction with NVIDIA, so it looks like it might be related...? I'm rather unfamiliar with these, so I'm having a hard time telling how much overlap there is between the two models.

#8577 is set up to track NeMo support in llama.cpp.

@fairydreaming (Collaborator)

How much overlap is there between Nemotron and Mistral NeMo?

I think there's basically no overlap between the two. The only thing they have in common is that both can be run in the NVIDIA NeMo framework, but that doesn't imply anything specific.

github-actions bot added the stale label Aug 19, 2024

github-actions bot commented Sep 2, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.

github-actions bot closed this as completed Sep 2, 2024
@pandruszkow

Any plans to add support at some point in the future? Or should this be considered a WONTFIX?
