Add gpt2 implementations in python and c++ #1
Conversation
To be continued...
import numpy as np

# download the tiny shakespeare dataset
input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
Use the if __name__ == '__main__' pattern to make this module reusable, and to prevent wonky variable scoping.
input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
if not os.path.exists(input_file_path):
    data_url = 'https://github.com/raw/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(requests.get(data_url).text)

with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# encode with tiktoken gpt2 bpe
enc = tiktoken.get_encoding("gpt2")
train_ids = enc.encode_ordinary(train_data)
val_ids = enc.encode_ordinary(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))
Suggested change (wrap the whole script in a main() function):

def main():
    input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
    if not os.path.exists(input_file_path):
        data_url = 'https://github.com/raw/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
        with open(input_file_path, 'w', encoding='utf-8') as f:
            f.write(requests.get(data_url).text)
    with open(input_file_path, 'r', encoding='utf-8') as f:
        data = f.read()
    n = len(data)
    train_data = data[:int(n*0.9)]
    val_data = data[int(n*0.9):]
    # encode with tiktoken gpt2 bpe
    enc = tiktoken.get_encoding("gpt2")
    train_ids = enc.encode_ordinary(train_data)
    val_ids = enc.encode_ordinary(val_data)
    print(f"train has {len(train_ids):,} tokens")
    print(f"val has {len(val_ids):,} tokens")
    # export to bin files
    train_ids = np.array(train_ids, dtype=np.uint16)
    val_ids = np.array(val_ids, dtype=np.uint16)
    train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
    val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))

if __name__ == '__main__':
    main()
from torch.nn import functional as F
import tiktoken

batch_size = 64
Don't use global variables.
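The same advice applies on the C++ side of this PR. A minimal sketch of the alternative (TrainConfig is a hypothetical name, not from the PR): settings travel in a config object that is passed down explicitly.

// hypothetical: settings live in one struct instead of file-scope globals
struct TrainConfig {
    int64_t batch_size = 64;   // was a module-level global in the Python file
    int64_t block_size = 256;
};

// the dependency is visible in the signature instead of hidden
// in shared mutable state
void train(TrainConfig const& cfg);

int main() {
    auto cfg = TrainConfig{};
    train(cfg);
}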
    return logits, loss
Use if __name__ == "__main__": here as well.
def __init__weights(self, module):
    if isinstance(module, nn.Linear):
        std = 0.02
Move this to the else block. As written, it's confusing.
};

struct CausalSelfAttentionImpl : torch::nn::Module {
    CausalSelfAttentionImpl(const Config& cfg) {
Prefer an initializer list. That way your compiler can warn you if there are uninitialized variables.
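A minimal sketch of that, assuming Config exposes n_embed and n_head fields (names guessed from the members shown below):

#include <torch/torch.h>

struct Config {                // assumed shape of the PR's Config
    int64_t n_embed = 768;
    int64_t n_head = 12;
};

struct CausalSelfAttentionImpl : torch::nn::Module {
    // member-initializer list: every member is set before the constructor
    // body runs, so -Wuninitialized / clang-tidy can flag anything left out
    explicit CausalSelfAttentionImpl(Config const& cfg)
        : n_embed(cfg.n_embed),
          n_head(cfg.n_head),
          head_dim(cfg.n_embed / cfg.n_head) {
        // submodule / buffer registration stays in the body, as before
    }

    int64_t n_embed, n_head, head_dim;
};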
        mask = m;
        register_buffer("mask", mask);
    }
    torch::Tensor forward(const torch::Tensor& x) {
I like east const. Check out https://hackingcpp.com/cpp/design/east_vs_west_const.html
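Applied to the signature above, the two spellings are equivalent; east const just keeps const immediately to the right of what it qualifies:

// west const, as written in the PR:
torch::Tensor forward(const torch::Tensor& x);
// east const, same meaning ("reference to const Tensor"):
torch::Tensor forward(torch::Tensor const& x);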
        auto out = y.permute({0, 2, 1, 3}).contiguous().view({B, T, n_embed});
        return proj->forward(out);
    }
    int64_t n_embed, n_head, head_dim;
Typically in C++, member variables have an m_ prefix. This prevents variable shadowing.
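For example, a sketch with the members above renamed:

struct CausalSelfAttentionImpl : torch::nn::Module {
    explicit CausalSelfAttentionImpl(int64_t n_embed, int64_t n_head)
        : m_n_embed(n_embed),   // parameter n_embed can no longer shadow
          m_n_head(n_head) {}   // the member: they have distinct names

    int64_t m_n_embed, m_n_head;
};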
};
TORCH_MODULE(CausalSelfAttention);

struct MLPImpl : torch::nn::Module {
Mark this as final so no one accidentally inherits from it.
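That's a one-word change:

// final: the compiler now rejects any attempt to derive from MLPImpl
struct MLPImpl final : torch::nn::Module {
    // ...
};

// struct FancyMLP : MLPImpl {};  // error: base 'MLPImpl' is marked 'final'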
TORCH_MODULE(GPT);

int main() {
    Config cfg;
Prefer auto initialization. This prevents uninitialized garbage values.
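Concretely, assuming Config is an aggregate without default member initializers (which is what makes the plain declaration risky):

// hypothetical aggregate with no default member initializers
struct Config {
    int64_t n_layer, n_head, n_embed;
};

Config cfg1;            // default-initialized: members hold indeterminate values
auto cfg2 = Config{};   // value-initialized: members are zeroed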
Adds two files: GPT-2 implementations in Python and C++. Also adds dataloaders and a requirements file.