Skip to content

Conversation

ChinmayK0607
Copy link
Member

Adds two files:

  1. train.cpp -> gpt2 implementation in c++
  2. train.py -> gpt2 implementation in python

Adds dataloaders as well as requirements file as well.

@kamatajinkya2
Copy link

kamatajinkya2 commented Jul 8, 2025

General

  1. Prefer pyproject.toml over requirements.txt. Refer to Why Should I Choose pyproject.toml over requirements.txt for managing dependencies?
  2. Use a nested folder structure. Aka train.py, dataloader.py go inside src or llm101 or scripts folder. This will help in adding a test folder
  3. Implement unit tests (Not applicable in this instance, but just typing out)
  4. Have appropriate white spaces after a class of a function ends. You can use autoformatters like Black to achieve this.

To be continued...

import numpy as np

# download the tiny shakespeare dataset
input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use if name main pattern to make this module reusable also to prevent wonky variable scoping

Comment on lines +7 to +30
input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
if not os.path.exists(input_file_path):
data_url = 'https://github.com/raw/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
with open(input_file_path, 'w', encoding='utf-8') as f:
f.write(requests.get(data_url).text)

with open(input_file_path, 'r', encoding='utf-8') as f:
data = f.read()
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# encode with tiktoken gpt2 bpe
enc = tiktoken.get_encoding("gpt2")
train_ids = enc.encode_ordinary(train_data)
val_ids = enc.encode_ordinary(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
if not os.path.exists(input_file_path):
data_url = 'https://github.com/raw/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
with open(input_file_path, 'w', encoding='utf-8') as f:
f.write(requests.get(data_url).text)
with open(input_file_path, 'r', encoding='utf-8') as f:
data = f.read()
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]
# encode with tiktoken gpt2 bpe
enc = tiktoken.get_encoding("gpt2")
train_ids = enc.encode_ordinary(train_data)
val_ids = enc.encode_ordinary(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")
# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))
def main():
input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
if not os.path.exists(input_file_path):
data_url = 'https://github.com/raw/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
with open(input_file_path, 'w', encoding='utf-8') as f:
f.write(requests.get(data_url).text)
with open(input_file_path, 'r', encoding='utf-8') as f:
data = f.read()
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]
# encode with tiktoken gpt2 bpe
enc = tiktoken.get_encoding("gpt2")
train_ids = enc.encode_ordinary(train_data)
val_ids = enc.encode_ordinary(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")
# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))
if __name__ == '__main__':
main()

from torch.nn import functional as F
import tiktoken

batch_size = 64

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dont use global variables


return logits,loss


Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use if __name__ == "__main__":


def __init__weights(self, module):
if isinstance(module, nn.Linear):
std = 0.02

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this to else block. This is causing confusion.

};

struct CausalSelfAttentionImpl : torch::nn::Module {
CausalSelfAttentionImpl(const Config& cfg) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer initizliser list. This way your compiler can warn you if there are uninitialized variables

mask = m;
register_buffer("mask", mask);
}
torch::Tensor forward(const torch::Tensor& x) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto out = y.permute({0, 2, 1, 3}).contiguous().view({B, T, n_embed});
return proj->forward(out);
}
int64_t n_embed, n_head, head_dim;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typically in C++ member variables have a m_ prefix. This prevents variable shadowing

};
TORCH_MODULE(CausalSelfAttention);

struct MLPImpl : torch::nn::Module {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mark this as final so no one accidently inherets.

TORCH_MODULE(GPT);

int main() {
Config cfg;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer auto initialization. This prevents uninitialized garbage values

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants