-
Notifications
You must be signed in to change notification settings - Fork 2
Add gpt2 implementations in python and c++ #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,5 @@ | ||
Initial commit | ||
# Experiments with GPT2 architecture | ||
|
||
1. for `train.cpp` tokenize the data using tiktoken or use bins of tokenized data | ||
2. for `train.py` run data.py in the same directory as train.py and then run it. | ||
3. Use `uv` package manager for dependency handling |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,33 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import os | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import requests | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import tiktoken | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# download the tiny shakespeare dataset | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt') | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use if name main pattern to make this module reusable also to prevent wonky variable scoping |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if not os.path.exists(input_file_path): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
data_url = 'https://github.com/raw/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
with open(input_file_path, 'w', encoding='utf-8') as f: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
f.write(requests.get(data_url).text) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
with open(input_file_path, 'r', encoding='utf-8') as f: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
data = f.read() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
n = len(data) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_data = data[:int(n*0.9)] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
val_data = data[int(n*0.9):] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# encode with tiktoken gpt2 bpe | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
enc = tiktoken.get_encoding("gpt2") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_ids = enc.encode_ordinary(train_data) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
val_ids = enc.encode_ordinary(val_data) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
print(f"train has {len(train_ids):,} tokens") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
print(f"val has {len(val_ids):,} tokens") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# export to bin files | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_ids = np.array(train_ids, dtype=np.uint16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
val_ids = np.array(val_ids, dtype=np.uint16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin')) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin')) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+7
to
+30
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# train.bin has 301,966 tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# val.bin has 36,059 tokens |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
uv | ||
torch | ||
pandas | ||
tiktoken | ||
mlflow | ||
wandb | ||
tqdm |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
#include <torch/torch.h> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
#include <fstream> | ||
#include <vector> | ||
#include <iostream> | ||
#include <cassert> | ||
|
||
struct Config { | ||
int64_t block_size = 256; | ||
int64_t vocab_size = 50257; | ||
int64_t n_layer = 6; | ||
int64_t n_head = 6; | ||
int64_t n_embed = 384; | ||
}; | ||
|
||
struct TextLoader { | ||
TextLoader(int64_t B, int64_t T) : B(B), T(T), cur_pos(0) { | ||
std::ifstream in("/teamspace/studios/this_studio/input.txt"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be a |
||
int64_t tok; | ||
while (in >> tok) tokens.push_back(tok); | ||
size_t N = tokens.size(); | ||
std::cout << "loaded " << N << " tokens\n"; | ||
std::cout << "processing " << (N / (B * T)) << " batches per epoch\n"; | ||
} | ||
std::pair<torch::Tensor, torch::Tensor> next_batch() { | ||
int64_t span = B * T + 1; | ||
if (cur_pos + span > tokens.size()) cur_pos = 0; | ||
auto start = tokens.begin() + cur_pos; | ||
std::vector<int64_t> buf(start, start + span); | ||
cur_pos += B * T; | ||
auto data = torch::from_blob(buf.data(), {span}, torch::kInt64).clone(); | ||
auto x = data.narrow(0, 0, B * T).view({B, T}); | ||
auto y = data.narrow(0, 1, B * T).view({B, T}); | ||
return {x, y}; | ||
} | ||
int64_t B, T, cur_pos; | ||
std::vector<int64_t> tokens; | ||
}; | ||
|
||
struct CausalSelfAttentionImpl : torch::nn::Module { | ||
CausalSelfAttentionImpl(const Config& cfg) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Prefer initizliser list. This way your compiler can warn you if there are uninitialized variables |
||
n_embed = cfg.n_embed; | ||
n_head = cfg.n_head; | ||
head_dim = n_embed / n_head; | ||
qkv = register_module("qkv", torch::nn::Linear(n_embed, 3 * n_embed)); | ||
proj = register_module("proj", torch::nn::Linear(n_embed, n_embed)); | ||
auto m = torch::ones({cfg.block_size, cfg.block_size}, torch::kBool).tril().view({1, 1, cfg.block_size, cfg.block_size}); | ||
mask = m; | ||
register_buffer("mask", mask); | ||
} | ||
torch::Tensor forward(const torch::Tensor& x) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like east const. Checkout https://hackingcpp.com/cpp/design/east_vs_west_const.html |
||
auto B = x.size(0), T = x.size(1); | ||
auto qkv_out = qkv->forward(x).view({B, T, 3, n_head, head_dim}); | ||
auto q = qkv_out.select(2, 0).permute({0, 2, 1, 3}); | ||
auto k = qkv_out.select(2, 1).permute({0, 2, 1, 3}); | ||
auto v = qkv_out.select(2, 2).permute({0, 2, 1, 3}); | ||
auto y = at::scaled_dot_product_attention(q, k, v, mask, 0.0, false); | ||
auto out = y.permute({0, 2, 1, 3}).contiguous().view({B, T, n_embed}); | ||
return proj->forward(out); | ||
} | ||
int64_t n_embed, n_head, head_dim; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Typically in C++ member variables have a |
||
torch::nn::Linear qkv{nullptr}, proj{nullptr}; | ||
torch::Tensor mask; | ||
}; | ||
TORCH_MODULE(CausalSelfAttention); | ||
|
||
struct MLPImpl : torch::nn::Module { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mark this as |
||
MLPImpl(const Config& cfg) { | ||
fc = register_module("fc", torch::nn::Linear(cfg.n_embed, 4 * cfg.n_embed)); | ||
act = torch::nn::GELU(); | ||
proj = register_module("proj", torch::nn::Linear(4 * cfg.n_embed, cfg.n_embed)); | ||
} | ||
torch::Tensor forward(torch::Tensor x) { | ||
x = fc->forward(x); | ||
x = act(x); | ||
return proj->forward(x); | ||
} | ||
torch::nn::Linear fc{nullptr}, proj{nullptr}; | ||
torch::nn::GELU act; | ||
}; | ||
TORCH_MODULE(MLP); | ||
|
||
struct BlockImpl : torch::nn::Module { | ||
BlockImpl(const Config& cfg) { | ||
ln1 = register_module("ln1", torch::nn::LayerNorm(torch::nn::LayerNormOptions({cfg.n_embed}))); | ||
attn = register_module("attn", CausalSelfAttention(cfg)); | ||
ln2 = register_module("ln2", torch::nn::LayerNorm(torch::nn::LayerNormOptions({cfg.n_embed}))); | ||
mlp = register_module("mlp", MLP(cfg)); | ||
} | ||
torch::Tensor forward(torch::Tensor x) { | ||
x = x + attn->forward(ln1->forward(x)); | ||
x = x + mlp->forward(ln2->forward(x)); | ||
return x; | ||
} | ||
torch::nn::LayerNorm ln1{nullptr}, ln2{nullptr}; | ||
CausalSelfAttention attn{nullptr}; | ||
MLP mlp{nullptr}; | ||
}; | ||
TORCH_MODULE(Block); | ||
|
||
struct GPTImpl : torch::nn::Module { | ||
GPTImpl(const Config& cfg) { | ||
wte = register_module("wte", torch::nn::Embedding(cfg.vocab_size, cfg.n_embed)); | ||
wpe = register_module("wpe", torch::nn::Embedding(cfg.block_size, cfg.n_embed)); | ||
for (int i = 0; i < cfg.n_layer; ++i) blocks->push_back(Block(cfg)); | ||
register_module("blocks", blocks); | ||
ln_f = register_module("ln_f", torch::nn::LayerNorm(torch::nn::LayerNormOptions({cfg.n_embed}))); | ||
lm_head = register_module("lm_head", torch::nn::Linear(torch::nn::LinearOptions(cfg.n_embed, cfg.vocab_size).bias(false))); | ||
lm_head->weight = wte->weight; | ||
apply([](torch::nn::Module& m) { | ||
if (auto* L = m.as<torch::nn::Linear>()) | ||
torch::nn::init::normal_(L->weight, 0.0, 0.02); | ||
}); | ||
} | ||
std::pair<torch::Tensor, torch::Tensor> forward(torch::Tensor idx, torch::Tensor targets) { | ||
auto B = idx.size(0), T = idx.size(1); | ||
auto pos = torch::arange(0, T, torch::kLong).to(idx.device()); | ||
auto x = wte->forward(idx) + wpe->forward(pos); | ||
for (auto& m : *blocks) x = m->as<Block>()->forward(x); | ||
x = ln_f->forward(x); | ||
auto logits = lm_head->forward(x); | ||
torch::Tensor loss; | ||
if (targets.defined()) { | ||
loss = torch::nn::functional::cross_entropy(logits.view({-1, logits.size(-1)}), targets.view(-1)); | ||
} | ||
return {logits, loss}; | ||
} | ||
torch::nn::Embedding wte{nullptr}, wpe{nullptr}; | ||
torch::nn::ModuleList blocks; | ||
torch::nn::LayerNorm ln_f{nullptr}; | ||
torch::nn::Linear lm_head{nullptr}; | ||
}; | ||
TORCH_MODULE(GPT); | ||
|
||
int main() { | ||
Config cfg; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Prefer auto initialization. This prevents uninitialized garbage values |
||
int64_t B = 64, T = cfg.block_size; | ||
TextLoader loader(B, T); | ||
auto device = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU; | ||
GPT model(cfg); | ||
model->to(device); | ||
torch::optim::AdamW opt(model->parameters(), 3e-4); | ||
for (int i = 0; i < 50; ++i) { | ||
auto [x, y] = loader.next_batch(); | ||
x = x.to(device); | ||
y = y.to(device); | ||
opt.zero_grad(); | ||
auto [logits, loss] = model->forward(x, y); | ||
loss.backward(); | ||
opt.step(); | ||
std::cout << "step " << i << " loss " << loss.item<double>() << '\n'; | ||
} | ||
return 0; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this script load data or does it prepare the data? Wont something like
prepare.py
be a better name?