From c43e777256804b720c89b40342481f499de13671 Mon Sep 17 00:00:00 2001
From: Anirudh0707
Date: Mon, 5 Jul 2021 00:51:01 +0530
Subject: [PATCH 1/8] Phoneme detection and classifier model codes

---
 applications/KWS_Phoneme/README.md | 43 ++
 .../KWS_Phoneme/auxiliary_files/README.md | 29 +
 .../auxiliary_files/convert_sampling_rate.py | 42 ++
 .../auxiliary_files/download_youtube_data.py | 39 ++
 applications/KWS_Phoneme/data_pipe.py | 535 +++++++++++++++++
 applications/KWS_Phoneme/kwscnn.py | 559 ++++++++++++++++++
 applications/KWS_Phoneme/train_classifier.py | 293 +++++++++
 applications/KWS_Phoneme/train_phoneme.py | 209 +++++++
 applications/KWS_Phoneme/utils.py | 40 ++
 9 files changed, 1789 insertions(+)
 create mode 100644 applications/KWS_Phoneme/README.md
 create mode 100644 applications/KWS_Phoneme/auxiliary_files/README.md
 create mode 100644 applications/KWS_Phoneme/auxiliary_files/convert_sampling_rate.py
 create mode 100644 applications/KWS_Phoneme/auxiliary_files/download_youtube_data.py
 create mode 100644 applications/KWS_Phoneme/data_pipe.py
 create mode 100644 applications/KWS_Phoneme/kwscnn.py
 create mode 100644 applications/KWS_Phoneme/train_classifier.py
 create mode 100644 applications/KWS_Phoneme/train_phoneme.py
 create mode 100644 applications/KWS_Phoneme/utils.py

diff --git a/applications/KWS_Phoneme/README.md b/applications/KWS_Phoneme/README.md
new file mode 100644
index 000000000..b610621ce
--- /dev/null
+++ b/applications/KWS_Phoneme/README.md
@@ -0,0 +1,43 @@
+# Phoneme-based Keyword Spotting (KWS)
+
+# Project Description
+There are two major issues in existing KWS systems: (a) they are not robust to heavy background noise and random utterances, and (b) they require collecting a large amount of data, which makes adding a new keyword cumbersome. Tackling these issues from a different perspective, we propose a two-stage scheme: a model first predicts phonemes, which are then used for phoneme-based keyword classification.
+
+This gives two advantages: (a) the phoneme model is trained to account for diverse accents and background-noise settings, so the flexible keyword classifier requires only a small number of keyword samples for training, and (b) empirically, this method detects keywords from as far as 9 ft away. Further, the phoneme model is small, at around 250K parameters, and can fit on a Cortex-M7 microcontroller.
+
+# Training the Phoneme Classifier
+1) Train a phoneme classification model on a public speech dataset such as LibriSpeech
+2) The training speech data can be labelled using the Montreal Forced Aligner
+3) Speech snippets are convolved with reverberation (room impulse response) files, and additive noise from YouTube or other open sources is mixed in
+4) We also add white Gaussian noise at various SNRs
+
+# Training the KWS Model
+1) Our method takes a speech snippet as input and passes it through the phoneme classifier
+2) Keywords are detected by training a keyword classifier on top of the predicted phonemes
+3) To train the keyword classifier, we use the Azure and Google Text-to-Speech APIs to generate the training data (keyword snippets)
+
+# Sample Use Cases
+
+## Phoneme Model Training
+The following command can be used to instantiate and train the phoneme model.
+```
+python train_phoneme.py --base_path=/path/to/librespeech_data/ --rir_base_path=/path/to/reverb_files/ --additive_base_path=/path/to/additive_noises/ --snr_samples="0,5,10,25,100,100" --rir_chance=0.5
+```
+Some important command line arguments (see the sketch below for how the augmentation arguments are used):
+1) base_path : Path of the speech data folder. The data in this folder should follow the layout expected by the dataloader code here (data_pipe.py).
+2) rir_base_path, additive_base_path : Paths to the reverberation and additive noise files.
+3) snr_samples : List of SNRs at which the additive noise is mixed in.
+4) rir_chance : Probability of applying reverberation to each speech sample.
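+
+For reference, here is a minimal sketch of how these augmentation arguments are consumed per utterance by the data pipeline (it mirrors the sampling logic around synthesize_wave in data_pipe.py; the values shown are just the defaults from the command above and from train_phoneme.py):
+```
+import numpy as np
+
+snr_samples = [0, 5, 10, 25, 100, 100]   # --snr_samples; entries >= 50 mean no additive noise for that draw
+wgn_snr_samples = [5, 10, 15, 100, 100]  # --wgn_snr_samples (white Gaussian noise SNRs)
+gain_samples = [1.0, 0.25, 0.5, 0.75]    # --gain_samples
+rir_chance = 0.5                         # --rir_chance
+synth_chance = 0.5                       # --synth_chance (probability of augmenting a given sample at all)
+
+if np.random.random() < synth_chance:
+    snr = np.random.choice(snr_samples)          # SNR for the additive real-world noise clip
+    wgn_snr = np.random.choice(wgn_snr_samples)  # SNR for the white Gaussian noise
+    gain = np.random.choice(gain_samples)        # output gain applied before clipping
+    do_rir = np.random.random() < rir_chance     # whether to convolve with a reverberation file
+    # x = synthesize_wave(x, snr, wgn_snr, gain, do_rir, args)
+```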
+
+## Classifier Model Training
+The following command can be used to instantiate and train the classifier model.
+```
+python train_classifier.py --base_path=/path/to/train_and_test_data_folders/ --train_data_folders=google30_azure_tts,google30_google_tts --test_data_folders=google30_test --phoneme_model_load_ckpt=/path/to/checkpoint/x.pt --rir_base_path=/mnt/reverb_noise_sampled/ --additive_base_path=/mnt/add_noises_sampled/ --synth
+```
+Some important command line arguments:
+
+1) base_path : Path to the train and test data folders.
+2) train_data_folders, test_data_folders : Each of these folders should contain one subfolder per keyword holding that keyword's .wav files, in the layout expected by the dataloader here.
+3) phoneme_model_load_ckpt : Full path of the checkpoint file whose weights are loaded into the instantiated phoneme model.
+4) rir_base_path, additive_base_path : Paths to the reverberation and additive noise files.
+5) synth : Boolean flag specifying whether reverberation and noise addition should be applied.
diff --git a/applications/KWS_Phoneme/auxiliary_files/README.md b/applications/KWS_Phoneme/auxiliary_files/README.md
new file mode 100644
index 000000000..a56a617ae
--- /dev/null
+++ b/applications/KWS_Phoneme/auxiliary_files/README.md
@@ -0,0 +1,29 @@
+# Auxiliary Files to help Download and Prepare the Data
+
+## Note
+When running the commands below, it is recommended to use the following format so that the scripts run uninterrupted (detached) and their output is logged.
+```
+nohup python script_execution args > log.txt &
+```
+Please replace script_execution with the python commands given below.
+Alternatively, tmux or other tools can be used in place of the above format.
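+
+For example, the YouTube download script described below could be run detached like this (all paths are placeholders):
+```
+nohup python download_youtube_data.py --csv_file=/path/to/balanced_train_segments.csv --target_folder=/path/to/raw_audio/ > download_log.txt &
+```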
+
+## YouTube Additive Noise
+Run the following command to download the CSV file that lists the YouTube additive-noise clips (there is no need to use nohup for wget):
+```
+wget http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/balanced_train_segments.csv
+```
+Then run the extraction script to download the actual audio:
+```
+python download_youtube_data.py --csv_file=/path/to/csv_file.csv --target_folder=/path/to/target/folder/
+```
+
+Please check [Google's Audioset data page](https://research.google.com/audioset/download.html) for further details.
+
+The downloaded files need to be converted to 16 kHz for our pipeline. Please run the following to do so:
+```
+python convert_sampling_rate.py --source_folder=/path/to/downloaded_wav_folder/ --target_folder=/path/to/target/16KHz_folder/ --fs=16000 --log_rate=100
+```
+The script can convert any wav file to the sampling rate specified by --fs, but for our application we use 16 kHz only.
+Choose the log rate for how often the log should be printed for the sample rate converion. This will print a strinng ever log_rate iterations. \ No newline at end of file diff --git a/applications/KWS_Phoneme/auxiliary_files/convert_sampling_rate.py b/applications/KWS_Phoneme/auxiliary_files/convert_sampling_rate.py new file mode 100644 index 000000000..effb35f5c --- /dev/null +++ b/applications/KWS_Phoneme/auxiliary_files/convert_sampling_rate.py @@ -0,0 +1,42 @@ +import os +import librosa +import numpy as np +import soundfile as sf +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--source_folder', default=None, required=True) +parser.add_argument('--target_folder', default=None, required=True) +parser.add_argument('--fs', type=int, default=16000) +parser.add_argument('--log_rate', type=int, default=1000) +args = parser.parse_args() + +source_folder = args.source_folder +target_folder = args.target_folder +fs = args.fs +log_rate = args.log_rate +print(f'Source Folder :: {source_folder}\nTarget Folder :: {target_folder}\nSampling Frequency :: {fs}', flush=True) + +source_files = [] +target_files = [] +list_completed = [] + +# Get the list of list of wav files from source folder and create target file names (full paths) +for i, f in enumerate(os.listdir(source_folder)): + if f[-4:].lower() == '.wav': + source_files.append(os.path.join(source_folder, f)) + target_files.append(os.path.join(target_folder, f)) +print(f'Saved all the file paths, Number of files = {len(source_files)}', flush=True) + +# Convert the files to args.fs +# Read with librosa and write the mono channel audio using soundfile +print(f'Converting all files to {fs/1000} Khz', flush=True) +for i, file_path in enumerate(source_files): + y, sr = librosa.load(file_path, sr=fs, mono=True) + sf.write(target_files[i], y, sr) + list_completed.append(target_files[i]) + if i % log_rate == 0: + print(f'File Number {i+1}, Shape of Audio {y.shape}, Sampling Frequency {sr}', flush=True) + +print(f'Number of Files saved {len(list_completed)}') +print('Done', flush=True) diff --git a/applications/KWS_Phoneme/auxiliary_files/download_youtube_data.py b/applications/KWS_Phoneme/auxiliary_files/download_youtube_data.py new file mode 100644 index 000000000..eede85d57 --- /dev/null +++ b/applications/KWS_Phoneme/auxiliary_files/download_youtube_data.py @@ -0,0 +1,39 @@ +import csv +import os +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--csv_file', default=None, required=True) +parser.add_argument('--target_folder', default=None, required=True) +args = parser.parse_args() + +with open(args.csv_file, 'r') as csv_f: + reader = csv.reader(csv_f, skipinitialspace=True) + # Skip 3 lines ; Header + next(reader) + next(reader) + next(reader) + for row in reader: + # Logging + print(row, flush=True) + # Link for the Youtube Video + YouTube_ID = row[0] # "-0RWZT-miFs" + start_time = int(float(row[1])) # 420 + end_time = int(float(row[2])) # 430 + # Construct downloadable link + YouTube_link = "https://youtu.be/" + YouTube_ID + # Output Filename + output_file = f"{args.target_folder}/ID_{YouTube_ID}.wav" + # Start time in hrs:min:sec format + start_sec = start_time % 60 + start_min = (start_time // 60) % 60 + start_hrs = start_time // 3600 + # End time in hrs:min:sec format + end_sec = end_time % 60 + end_min = (end_time // 60) % 60 + end_hrs = end_time // 3600 + # Start and End time args + time_args = f"-ss {start_hrs}:{start_min}:{start_sec} -to {end_hrs}:{end_min}:{end_sec}" + # Command Line 
Execution + os.system(f"youtube-dl -x -q --audio-format wav --postprocessor-args '{time_args}' {YouTube_link}" + " --exec 'mv {} " + f"{output_file}'") + print('', flush=True) diff --git a/applications/KWS_Phoneme/data_pipe.py b/applications/KWS_Phoneme/data_pipe.py new file mode 100644 index 000000000..c9337e2be --- /dev/null +++ b/applications/KWS_Phoneme/data_pipe.py @@ -0,0 +1,535 @@ +import torch +import argparse +import torch.utils.data +import os, glob, random +from collections import Counter +import soundfile as sf +import scipy.signal +import scipy.io.wavfile +import numpy as np +import textgrid +import multiprocessing +from subprocess import call +import librosa +import math +from numpy import random + +def synthesize_wave(sigx, snr, wgn_snr, gain, do_rir, args): + """ + Synth Block - Used to process the input audio. + The input is convolved with room reverbration recording + Adds noise in the form of white gaussian noise and regular audio clips (eg:piano, people talking, car engine etc) + + Input + sigx : input signal to the block + snr : signal-to-noise ratio of the input and additive noise (regular audio) + wg_snr : signal-to-noise ratio of the input and additive noise (white gaussian noise) + gain : gain of the output signal + do_rir : boolean flag, if reverbration needs to be incorporated + args : args object (contains info about model and training) + + Output + clipped version of the audio post-processing + """ + beta = np.random.choice([0.1, 0.25, 0.5, 0.75, 1]) + sigx = beta * sigx + x_power = np.sum(sigx * sigx) + + # Do RIR and normalize back to original power + if do_rir: + rir_base_path = args.rir_base_path + rir_fname = random.choice(os.listdir(rir_base_path)) + rir_full_fname = rir_base_path + rir_fname + rir_sample, fs = sf.read(rir_full_fname) + if rir_sample.ndim > 1: + rir_sample = rir_sample[:,0] + # We cut the tail of the RIR signal at 99% energy + cum_en = np.cumsum(np.power(rir_sample, 2)) + cum_en = cum_en / cum_en[-1] + rir_sample = rir_sample[cum_en <= 0.99] + + max_spike = np.argmax(np.abs(rir_sample)) + sigy = scipy.signal.fftconvolve(sigx, rir_sample)[max_spike:] + sigy = sigy[0:len(sigx)] + + y_power = np.sum(sigy * sigy) + sigy *= math.sqrt(x_power / y_power) # normalize so y has same total power + + else: + sigy = sigx + + y_rmse = math.sqrt(x_power / len(sigy)) + + # Only bother with noise addition if the SNR is low enough + if snr < 50: + add_sample = get_add_noise(args) + noise_rmse = math.sqrt(np.sum(add_sample * add_sample) / len(add_sample)) #+ 0.000000000000000001 + + if len(add_sample) < len(sigy): + padded = np.zeros(len(sigy), dtype=np.float32) + padded[0:len(add_sample)] = add_sample + else: + padded = add_sample[0:len(sigy)] + + add_sample = padded + + noise_scale = y_rmse / noise_rmse * math.pow(10, -snr / 20) + sigy = sigy + add_sample * noise_scale + + if wgn_snr < 50: + wgn_samps = np.random.normal(size=(len(sigy))).astype(np.float32) + noise_scale = y_rmse * math.pow(10, -wgn_snr / 20) + sigy = sigy + wgn_samps * noise_scale + + # Apply gain & clipping + return np.clip(sigy * gain, -1.0, 1.0) + +def get_add_noise(args): + """ + Extracts the additive noise file from the defined path + + Input + args: args object (contains info about model and training) + + Output + add_sample: addtive noise audio + """ + additive_base_path = args.additive_base_path + add_fname = random.choice(os.listdir(additive_base_path)) + add_full_fname = additive_base_path + add_fname + add_sample, fs = sf.read(add_full_fname) + + return add_sample + +def 
get_ASR_datasets(args): + """ + Function for preparing the data samples for the phoneme pipeline + + Input + args: args object (contains info about model and training) + + Output + train_dataset: dataset class used for loading the samples into the training pipeline + """ + base_path = args.base_path + + # Load the speech data. This code snippet (till line 121) depends on the data format in base_path + train_textgrid_paths = glob.glob(base_path + + "/text/train-clean*/*/*/*.TextGrid") + + train_wav_paths = [path.replace("text", "audio").replace(".TextGrid", ".wav") + for path in train_textgrid_paths] + + if args.pre_phone_list: + # If there is a list of phons in the dataset, use this flag + Sy_phoneme = [] + with open(args.phoneme_text_file, "r") as f: + for line in f.readlines(): + if line.rstrip("\n") != "": Sy_phoneme.append(line.rstrip("\n")) + args.num_phonemes = len(Sy_phoneme) + print("**************", flush=True) + print("Phoneme List", flush=True) + print(Sy_phoneme, flush=True) + print("**************", flush=True) + print("**********************", flush=True) + print("Total Num of Phonemes", flush=True) + print(len(Sy_phoneme), flush=True) + print("**********************", flush=True) + else: + # No list of phons specified. Count from the input dataset + phoneme_counter = Counter() + for path in train_textgrid_paths: + tg = textgrid.TextGrid() + tg.read(path) + phoneme_counter.update([phone.mark.rstrip("0123456789") + for phone in tg.getList("phones")[0] + if phone.mark not in ['', 'sp', 'spn']]) + + # Display and store the phons extracted + Sy_phoneme = list(phoneme_counter) + args.num_phonemes = len(Sy_phoneme) + print("**************", flush=True) + print("Phoneme List", flush=True) + print(Sy_phoneme, flush=True) + print("**************", flush=True) + print("**********************", flush=True) + print("Total Num of Phonemes", flush=True) + print(len(Sy_phoneme), flush=True) + print("**********************", flush=True) + with open(args.phoneme_text_file, "w") as f: + for phoneme in Sy_phoneme: + f.write(phoneme + "\n") + + print("Data Path Prep Done.", flush=True) + + # Create dataset objects + train_dataset = ASRDataset(train_wav_paths, train_textgrid_paths, Sy_phoneme, args) + + return train_dataset + +class ASRDataset(torch.utils.data.Dataset): + def __init__(self, wav_paths, textgrid_paths, Sy_phoneme, args): + """ + Dataset iterator for the phoneme detection model + + Input + wav_paths : list of strings (wav file paths) + textgrid_paths : list of strings (textgrid for each wav file) + Sy_phoneme : list of strings (all possible phonemes) + args : args object (contains info about model and training) + """ + self.wav_paths = wav_paths + self.textgrid_paths = textgrid_paths + self.length_mean = args.pretraining_length_mean + self.length_var = args.pretraining_length_var + self.Sy_phoneme = Sy_phoneme + self.args = args + # Dataset Loader for the iterator + self.loader = torch.utils.data.DataLoader(self, batch_size=args.pretraining_batch_size, + num_workers=args.workers, shuffle=True, + collate_fn=CollateWavsASR()) + + def __len__(self): + """ + Number of audio samples available + """ + return len(self.wav_paths) + + def __getitem__(self, idx): + """ + Gives one sample from the dataset. Data is read in this snippet. 
+ (refer to the coalate function for pre-procesing) + + Input: + idx: index for the sample + + Output: + x : audio sample obtained from the synth block (if used, else input audio) after time-domain clipping + y_phoneme : the output phonemes sampled at 30ms + """ + x, fs = sf.read(self.wav_paths[idx]) + + tg = textgrid.TextGrid() + tg.read(self.textgrid_paths[idx]) + + y_phoneme = [] + for phoneme in tg.getList("phones")[0]: + duration = phoneme.maxTime - phoneme.minTime + phoneme_index = self.Sy_phoneme.index(phoneme.mark.rstrip("0123456789")) if phoneme.mark.rstrip("0123456789") in self.Sy_phoneme else -1 + if phoneme.mark == '': phoneme_index = -1 + y_phoneme += [phoneme_index] * round(duration * fs) + + # Cut a snippet of length random_length from the audio + random_length = round(fs * (self.length_mean + self.length_var * torch.randn(1).item())) + if len(x) <= random_length: + start = 0 + else: + start = torch.randint(low=0, high=len(x)-random_length, size=(1,)).item() + end = start + random_length + + x = x[start:end] + + if np.random.random() < self.args.synth_chance: + x = synthesize_wave(x, np.random.choice(self.args.snr_samples), + np.random.choice(self.args.wgn_snr_samples), np.random.choice(self.args.gain_samples), + np.random.random() < self.args.rir_chance, self.args) + + self.phone_downsample_factor = 160 + y_phoneme = y_phoneme[start:end:self.phone_downsample_factor] + + # feature = librosa.feature.mfcc(x,sr=16000,n_mfcc=80,win_length=25*16,hop_length=10*16) + + return (x, y_phoneme) + +class CollateWavsASR: + def __call__(self, batch): + """ + Pre-processing and padding, followed by batching the set of inputs + + Input: + batch: list of tuples (input wav, phoneme labels) + + Output: + feature_tensor : the melspectogram features of the input audio. The features are padded for batching. + y_phoneme_tensor : the phonemes sequences in a tensor format. The phoneme sequences are padded for batching. 
+ """ + x = []; y_phoneme = [] + batch_size = len(batch) + for index in range(batch_size): + x_,y_phoneme_, = batch[index] + + x.append(x_) + y_phoneme.append(y_phoneme_) + + # pad all sequences to have same length and get features + features=[] + T = max([len(x_) for x_ in x]) + U_phoneme = max([len(y_phoneme_) for y_phoneme_ in y_phoneme]) + for index in range(batch_size): + # pad audio to same length for all the audio samples in the batch + x_pad_length = (T - len(x[index])) + x[index] = np.pad(x[index], (x_pad_length,0), 'constant', constant_values=(0, 0)) + + # Extract Mel-Spectogram from padded audio + feature = librosa.feature.melspectrogram(y=x[index],sr=16000,n_mels=80, + win_length=25*16,hop_length=10*16, n_fft=512) + + feature = librosa.core.power_to_db(feature) + # Normalize the features + max_value = np.max(feature) + min_value = np.min(feature) + feature = (feature - min_value) / (max_value - min_value) + features.append(feature) + + # Pad the labels to same length for all samples in the batch + y_pad_length = (U_phoneme - len(y_phoneme[index])) + y_phoneme[index] = np.pad(y_phoneme[index], (y_pad_length,0), 'constant', constant_values=(-1, -1)) + + features_tensor = []; y_phoneme_tensor = [] + batch_size = len(batch) + for index in range(batch_size): + # x_,y_phoneme_, = batch[index] + x_ = features[index] + y_phoneme_ = y_phoneme[index] + + features_tensor.append(torch.tensor(x_).float()) + y_phoneme_tensor.append(torch.tensor(y_phoneme_).long()) + + features_tensor = torch.stack(features_tensor) + y_phoneme_tensor = torch.stack(y_phoneme_tensor) + + return (features_tensor,y_phoneme_tensor) + +def get_classification_dataset(args): + """ + Function for preparing the data samples for the classification pipeline + + Input + args: args object (contains info about model and training) + + Output + train_dataset : dataset class used for loading the samples into the training pipeline + test_dataset : dataset class used for loading the samples into the testing pipeline + """ + base_path = args.base_path + + # Train Data + train_wav_paths = [] + train_labels = [] + + # data_folder_list = ["google30_train"] or ["google30_azure_tts", "google30_google_tts"] + data_folder_list = args.train_data_folders + for data_folder in data_folder_list: + # For each of the folder, iterate through the words and get the files and the labels + for (label, word) in enumerate(args.words): + curr_word_files = glob.glob(base_path + f"/{data_folder}/" + word + "/*.wav") + train_wav_paths += curr_word_files + train_labels += [label]*len(curr_word_files) + print(f"Number of Train Files {len(train_wav_paths)}", flush=True) + temp = list(zip(train_wav_paths, train_labels)) + random.shuffle(temp) + train_wav_paths, train_labels = zip(*temp) + print(f"Train Data Folders Used {data_folder_list}", flush=True) + # Create dataset objects + train_dataset = ClassificationDataset(wav_paths=train_wav_paths, labels=train_labels, args=args, is_train=True) + + # Test Data + test_wav_paths = [] + test_labels = [] + + # data_folder_list = ["google30_test"] + data_folder_list = args.test_data_folders + for data_folder in data_folder_list: + # For each of the folder, iterate through the words and get the files and the labels + for (label, word) in enumerate(args.words): + curr_word_files = glob.glob(base_path + f"/{data_folder}/" + word + "/*.wav") + test_wav_paths += curr_word_files + test_labels += [label]*len(curr_word_files) + print(f"Number of Test Files {len(test_wav_paths)}", flush=True) + temp = list(zip(test_wav_paths, 
test_labels)) + random.shuffle(temp) + test_wav_paths, test_labels = zip(*temp) + print(f"Test Data Folders Used {data_folder_list}", flush=True) + # Create dataset objects + test_dataset = ClassificationDataset(wav_paths=test_wav_paths, labels=test_labels, args=args, is_train=False) + + return train_dataset, test_dataset + +class ClassificationDataset(torch.utils.data.Dataset): + def __init__(self, wav_paths, labels, args, is_train=True): + """ + Dataset iterator for the classifier model + + Input + wav_paths : list of strings (wav file paths) + labels : list of classification labels for the corresponding audio wav filess + is_train : boolean flag, if the dataset loader is for the train or test pipeline + args : args object (contains info about model and training) + """ + self.wav_paths = wav_paths + self.labels = labels + self.args = args + self.is_train = is_train + self.loader = torch.utils.data.DataLoader(self, batch_size=args.pretraining_batch_size, + num_workers=args.workers, shuffle=is_train, + collate_fn=CollateWavsClassifier()) + + def __len__(self): + """ + Number of audio samples available + """ + return len(self.wav_paths) + + def one_hot_encoder(self, lab): + """ + Label index to one-hot encoder + + Input: + lab: label index + + Output: + one_hot: label in the one-hot format + """ + one_hot = np.zeros(len(self.args.words)) + one_hot[lab]=1 + return one_hot + + def __getitem__(self, idx): + """ + Gives one sample from the dataset. Data is read in this snippet. (refer to the coalate function for pre-procesing) + + Input: + idx: index for the sample + + Output: + x : audio sample obtained from the synth block (if used, else input audio) after time-domain clipping + one_hot_label : one hot encoded label + seqlen : length of the audio file. + This value will be dropped and seqlen after feature extraction will be used. Refer to the coallate func + """ + x, fs = sf.read(self.wav_paths[idx]) + + label = self.labels[idx] + one_hot_label = self.one_hot_encoder(label) + seqlen = len(x) + + if self.is_train: + # Use synth only for train files + if self.args.synth: + if np.random.random() < self.args.synth_chance: + x = synthesize_wave(x, np.random.choice(self.args.snr_samples), + np.random.choice(self.args.wgn_snr_samples), np.random.choice(self.args.gain_samples), + np.random.random() < self.args.rir_chance, self.args) + + return (x, one_hot_label, seqlen) + +class CollateWavsClassifier: + def __call__(self, batch): + """ + Pre-processing and padding, followed by batching the set of inputs + + Input: + batch: list of tuples (input wav, one hot classification label, sequence length) + + Output: + feature_tensor : the melspectogram features of the input audio. The features are padded for batching. 
+ one_hot_label_tensor : the on-hot label in a tensor format + seqlen_tensor : the sequence length of the features in a minibatch + """ + x = []; one_hot_label = []; seqlen = [] + batch_size = len(batch) + for index in range(batch_size): + x_,one_hot_label_,_ = batch[index] + x.append(x_) + one_hot_label.append(one_hot_label_) + + # pad all sequences to have same length and get features + features=[] + T = max([len(x_) for x_ in x]) + T = max([T, 48000]) + for index in range(batch_size): + # pad audio to same length for all the audio samples in the batch + x_pad_length = (T - len(x[index])) + x[index] = np.pad(x[index], (x_pad_length,0), 'constant', constant_values=(0, 0)) + + # Extract MFCC from padded audio + feature = librosa.feature.melspectrogram(y=x[index],sr=16000,n_mels=80,win_length=25*16,hop_length=10*16, n_fft=512) + feature = librosa.core.power_to_db(feature) + # Normalize the features + max_value = np.max(feature) + min_value = np.min(feature) + if min_value == max_value: + feature = feature - min_value + else: + feature = (feature - min_value) / (max_value - min_value) + features.append(feature) + seqlen.append(feature.shape[1]) + + features_tensor = []; one_hot_label_tensor = []; seqlen_tensor = [] + batch_size = len(batch) + for index in range(batch_size): + x_ = features[index] + one_hot_label_ = one_hot_label[index] + seqlen_ = seqlen[index] + + features_tensor.append(torch.tensor(x_).float()) + one_hot_label_tensor.append(torch.tensor(one_hot_label_).long()) + seqlen_tensor.append(torch.tensor(seqlen_)) + + features_tensor = torch.stack(features_tensor) + one_hot_label_tensor = torch.stack(one_hot_label_tensor) + seqlen_tensor = torch.stack(seqlen_tensor) + + return (features_tensor,one_hot_label_tensor,seqlen_tensor) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--base_path', type=str, default="/mnt/kws_data/data/") + parser.add_argument('--train_data_folders', type=str, default="google30_train") + parser.add_argument('--test_data_folders', type=str, default="google30_test") + parser.add_argument('--rir_base_path', type=str, default="/mnt/kws_data/data/noises_sachin/iir/") + parser.add_argument('--additive_base_path', type=str, default="/mnt/kws_data/data/noises_sachin/additive/") + parser.add_argument('--phoneme_text_file', type=str, default="/mnt/kws_data/data/LibriSpeech/text/phonemes.txt") + parser.add_argument('--workers', type=int, default=-1, help="Number of workers. Give -1 for all workers") + parser.add_argument("--word_model_name", default='google30', help="Name of the word list chosen. 
Will be used in conjunction with the data loader") + parser.add_argument('--words', type=str, default="all") + parser.add_argument("--synth", action='store_true', help="Use Synth block or not") + parser.add_argument('--pretraining_length_mean', type=int, default=9) + parser.add_argument('--pretraining_length_var', type=int, default=1) + parser.add_argument('--pretraining_batch_size', type=int, default=64) + parser.add_argument('--snr_samples', type=str, default="0,5,10,25,100,100") + parser.add_argument('--wgn_snr_samples', type=str, default="5,10,15,100,100") + parser.add_argument('--gain_samples', type=str, default="1.0,0.25,0.5,0.75") + parser.add_argument('--rir_chance', type=float, default=0.25) + parser.add_argument('--synth_chance', type=float, default=0.5) + parser.add_argument('--pre_phone_list', action='store_true') + args = parser.parse_args() + + # SNRs + args.snr_samples = [int(samp) for samp in args.snr_samples.split(',')] + args.wgn_snr_samples = [int(samp) for samp in args.wgn_snr_samples.split(',')] + args.gain_samples = [float(samp) for samp in args.gain_samples.split(',')] + + # Workers + if args.workers == -1: + args.workers = multiprocessing.cpu_count() + + # Words + if args.word_model_name == 'google30': + args.words = ["bed", "bird", "cat", "dog", "down", "eight", "five", "four", "go", + "happy", "house", "left", "marvin", "nine", "no", "off", "on", "one", "right", + "seven", "sheila", "six", "stop", "three", "tree", "two", "up", "wow", "yes", "zero" + ] + elif args.word_model_name == 'google10': + args.words = ["yes", "no", "up", "down", "left", "right", "on", "off", + "stop", "go", "allsilence", "unknown"] + else: + raise ValueError('Incorrect Word Model Name') + + # Data Folders + args.train_data_folders = [folder_idx for folder_idx in args.train_data_folders.split(',')] + args.test_data_folders = [folder_idx for folder_idx in args.test_data_folders.split(',')] + + print(args.pre_phone_list, flush=True) + # get_ASR_datasets(args) + dset = get_classification_dataset(args) diff --git a/applications/KWS_Phoneme/kwscnn.py b/applications/KWS_Phoneme/kwscnn.py new file mode 100644 index 000000000..3f3a2710f --- /dev/null +++ b/applications/KWS_Phoneme/kwscnn.py @@ -0,0 +1,559 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +import torch.nn.functional as F +from torch.nn.parameter import Parameter +from utils import _single +import math +import numpy as np +from edgeml_pytorch.graph.rnn import * + +class _IndexSelect(torch.nn.Module): + """Channel permutation module. 
The purpose of this is to allow mixing across the CNN groups.""" + def __init__(self, channels, direction, groups): + super(_IndexSelect, self).__init__() + + if channels % groups != 0: + raise ValueError('Channels should be a multiple of the groups') + + self._index = torch.zeros((channels), dtype=torch.int64) + count = 0 + + if direction > 0: + for gidx in range(groups): + for nidx in range(gidx, channels, groups): + self._index[count] = nidx + count += 1 + else: + for gidx in range(groups): + for nidx in range(gidx, channels, groups): + self._index[nidx] = count + count += 1 + + def forward(self, value): + if value.device != self._index.device: + self._index = self._index.to(value.device) + + return torch.index_select(value, 1, self._index) + + +class _TanhGate(torch.nn.Module): + def __init__(self): + super(_TanhGate, self).__init__() + + def forward(self, value): + channels = value.shape[1] + piv = int(channels/2) + + sig_data = value[:, 0:piv, :] + tanh_data = value[:, piv:, :] + + sig_data = torch.sigmoid(sig_data) + tanh_data = torch.tanh(tanh_data) + return sig_data * tanh_data + + +class LR_conv(nn.Conv1d): + + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, + bias=True, padding_mode='zeros', rank=50): + + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode) + + assert kernel_size == 5 + rank = rank + self.W1 = Parameter(torch.Tensor(self.out_channels, rank)) + nn.init.kaiming_uniform_(self.W1, a=math.sqrt(5)) + self.W2 = Parameter(torch.Tensor(rank, self.in_channels * 5)) + nn.init.kaiming_uniform_(self.W2, a=math.sqrt(5)) + self.weight = None + + def forward(self, input): + lr_weight = torch.matmul(self.W1, self.W2) + lr_weight = torch.reshape(lr_weight, (self.out_channels, self.in_channels, 5)) + if self.padding_mode != 'zeros': + return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + lr_weight, self.bias, self.stride, + _single(0), self.dilation, self.groups) + return F.conv1d(input, lr_weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + +class DSCNNBlockLR_k5(torch.nn.Module): + """A depth-separate CNN block""" + + def __init__( + self, in_channels, out_channels, kernel, + stride=1, groups=1, avg_pool=2, dropout=0.1, + batch_norm=0.1, do_depth=True, shuffle=0, + activation='sigmoid', rank=50): + + super(DSCNNBlockLR_k5, self).__init__() + + activators = { + 'sigmoid': torch.nn.Sigmoid(), + 'relu': torch.nn.ReLU(), + 'leakyrelu': torch.nn.LeakyReLU(), + 'tanhgate': _TanhGate(), + 'none': None + } + + if activation not in activators: + raise ValueError('Available activations are: %s' % ', '.join(activators.keys())) + + if activation == 'tanhgate': + in_channels = int(in_channels/2) + + nonlin = activators[activation] + + if batch_norm > 0.0: + batch_block = torch.nn.BatchNorm1d(in_channels, affine=False, momentum=batch_norm) + else: + batch_block = None + + #do_depth = False + # depth_cnn = torch.nn.Conv1d(in_channels, in_channels, kernel_size=kernel, stride=1, groups=in_channels) + depth_cnn = None + # point_cnn = torch.nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, groups=groups) + point_cnn = LR_conv(in_channels, out_channels, kernel_size=kernel, stride=stride, groups=groups, rank=rank, padding=2) + + if shuffle != 0 and groups > 1: + shuffler = _IndexSelect(in_channels, shuffle, groups) + else: + shuffler = None + + if avg_pool > 0: + pool = 
torch.nn.AvgPool1d(kernel_size=avg_pool, stride=1) + else: + pool = None + + if dropout > 0: + dropout_block = torch.nn.Dropout(p=dropout) + else: + dropout_block = None + + seq1 = [nonlin, batch_block, depth_cnn, shuffler, point_cnn, dropout_block, pool] + # seq1 = [nonlin, batch_block, depth_cnn]#, shuffler, point_cnn, dropout_block, pool] + seq_f1 = [item for item in seq1 if item is not None] + if len(seq_f1) == 1: + self._op1 = seq_f1[0] + else: + self._op1 = torch.nn.Sequential(*seq_f1) + + seq2 = [shuffler, dropout_block, pool] + seq_f2 = [item for item in seq2 if item is not None] + if len(seq_f2) == 1: + print("Only 1 op in seq2") + self._op2 = seq_f2[0] + else: + self._op2 = torch.nn.Sequential(*seq_f2) + + def forward(self, x): + xb = self._op1[0](x) + x_cnn = self._op1[1](xb) + return x_cnn + + +class LR_pointcnn(nn.Conv1d): + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, + bias=True, padding_mode='zeros', rank=50): + super().__init__( + in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias, padding_mode) + + # assert stride == 3 + assert kernel_size == 1 + rank = rank + self.W1 = Parameter(torch.Tensor(self.out_channels, rank)) + nn.init.kaiming_uniform_(self.W1, a=math.sqrt(5)) + self.W2 = Parameter(torch.Tensor(rank, self.in_channels)) + nn.init.kaiming_uniform_(self.W2, a=math.sqrt(5)) + self.weight = None + + def forward(self, input): + lr_weight = torch.matmul(self.W1, self.W2) + lr_weight = torch.reshape(lr_weight, (self.out_channels, self.in_channels, 1)) + if self.padding_mode != 'zeros': + return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + lr_weight, self.bias, self.stride, + _single(0), self.dilation, self.groups) + return F.conv1d(input, lr_weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + +class DSCNNBlockLR_better(torch.nn.Module): + """A depth-separate CNN block""" + + def __init__( + self, in_channels, out_channels, kernel, + stride=1, groups=1, avg_pool=2, dropout=0.1, + batch_norm=0.1, do_depth=True, shuffle=0, + activation='sigmoid', rank=50): + + super(DSCNNBlockLR_better, self).__init__() + + activators = { + 'sigmoid': torch.nn.Sigmoid(), + 'relu': torch.nn.ReLU(), + 'leakyrelu': torch.nn.LeakyReLU(), + 'tanhgate': _TanhGate(), + 'none': None + } + + if activation not in activators: + raise ValueError('Available activations are: %s' % ', '.join(activators.keys())) + + if activation == 'tanhgate': + in_channels = int(in_channels/2) + + nonlin = activators[activation] + + if batch_norm > 0.0: + batch_block = torch.nn.BatchNorm1d(in_channels, affine=False, momentum=batch_norm) + else: + batch_block = None + + depth_cnn = torch.nn.Conv1d(in_channels, in_channels, kernel_size=kernel, stride=1, groups=in_channels, padding=2) + point_cnn = LR_pointcnn(in_channels, out_channels, kernel_size=1, stride=stride, groups=groups, rank=rank) + + if shuffle != 0 and groups > 1: + shuffler = _IndexSelect(in_channels, shuffle, groups) + else: + shuffler = None + + if avg_pool > 0: + pool = torch.nn.AvgPool1d(kernel_size=avg_pool, stride=1) + else: + pool = None + + if dropout > 0: + dropout_block = torch.nn.Dropout(p=dropout) + else: + dropout_block = None + + seq = [nonlin, batch_block, depth_cnn, shuffler, point_cnn, dropout_block, pool] + # seq1 = [nonlin, batch_block, depth_cnn]#, shuffler, point_cnn, dropout_block, pool] + seq_f = [item for item in seq if item is not None] + if len(seq_f) == 1: + self._op = seq_f[0] 
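+            # Only one module remains after dropping the disabled (None) entries, so use it directly instead of wrapping it in nn.Sequential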
+ else: + self._op = torch.nn.Sequential(*seq_f) + + def forward(self, x): + x = self._op(x) + return x + +class BiFastGRNN(nn.Module): + "Bi Directional FastGRNN" + + def __init__(self, inputDims, hiddenDims, gate_nonlinearity, + update_nonlinearity, rank): + super(BiFastGRNN, self).__init__() + + self.cell_fwd = FastGRNNCUDA(inputDims, + hiddenDims, + gate_nonlinearity, + update_nonlinearity, + batch_first=True, + wRank=rank, + uRank=rank) + + self.cell_bwd = FastGRNNCUDA(inputDims, + hiddenDims, + gate_nonlinearity, + update_nonlinearity, + batch_first=True, + wRank=rank, + uRank=rank) + + def forward(self, input_f, input_b): + # Bidirectional FastGRNN + output1 = self.cell_fwd(input_f) + output2 = self.cell_bwd(input_b) + #Returning the flipped output only for the bwd pass + #Will align it in the post processing + return output1, output2 + + +def X_preRNN_process(X, fwd_context, bwd_context): + "Bricked RNN Architecture" + + #FWD bricking + brickLength = fwd_context + hopLength = 3 + X_bricked_f1 = X.unfold(1, brickLength, hopLength) + X_bricked_f2 = X_bricked_f1.permute(0, 1, 3, 2) + #X_bricked_f [batch, num_bricks, brickLen, inpDim] + oldShape_f = X_bricked_f2.shape + X_bricked_f = torch.reshape( + X_bricked_f2, [oldShape_f[0] * oldShape_f[1], oldShape_f[2], -1]) + # X_bricked_f [batch*num_bricks, brickLen, inpDim] + + #BWD bricking + brickLength = bwd_context + hopLength = 3 + X_bricked_b = X.unfold(1, brickLength, hopLength) + X_bricked_b = X_bricked_b.permute(0, 1, 3, 2) + #X_bricked_f [batch, num_bricks, brickLen, inpDim] + oldShape_b = X_bricked_b.shape + X_bricked_b = torch.reshape( + X_bricked_b, [oldShape_b[0] * oldShape_b[1], oldShape_b[2], -1]) + # X_bricked_f [batch*num_bricks, brickLen, inpDim] + return X_bricked_f, oldShape_f, X_bricked_b, oldShape_b + + +def X_postRNN_process(X_f, oldShape_f, X_b, oldShape_b): + + #Forward bricks folding + X_f = torch.reshape(X_f, [oldShape_f[0], oldShape_f[1], oldShape_f[2], -1]) + #X_f [batch, num_bricks, brickLen, hiddenDim] + X_new_f = X_f[:, 0, ::3, :] + #batch,brickLen,hiddenDim + X_new_f_rest = X_f[:, :, -1, :].squeeze(2) + #batch, numBricks-1,hiddenDim + shape = X_new_f_rest.shape + X_new_f = torch.cat((X_new_f, X_new_f_rest), dim=1) + #batch,seqLen,hiddenDim + #X_new_f [batch, seqLen, hiddenDim] + + #Backward Bricks folding + X_b = torch.reshape(X_b, [oldShape_b[0], oldShape_b[1], oldShape_b[2], -1]) + #X_b [batch, num_bricks, brickLen, hiddenDim] + X_b = torch.flip(X_b, [1]) + #Reverse the ordering of the bricks (bring last brick to start) + + X_new_b = X_b[:, 0, ::3, :] + #batch,brickLen,inpDim + X_new_b_rest = X_b[:, :, -1, :].squeeze(2) + #batch,(seqlen-brickLen),hiddenDim + X_new_b = torch.cat((X_new_b, X_new_b_rest), dim=1) + #batch,seqLen,hiddenDim + X_new_b = torch.flip(X_new_b, [1]) + #inverting the flip operation + + X_new = torch.cat((X_new_f, X_new_b), dim=2) + #batch,seqLen,2*hiddenDim + + return X_new + + +class DSCNN_RNN_Block(torch.nn.Module): + def __init__(self, cnn_channels, rnn_hidden_size, rnn_num_layers, + device, gate_nonlinearity="sigmoid", update_nonlinearity="tanh", + isBi=True, num_labels=41, rank=None, fwd_context=15, bwd_context=9): + + super(DSCNN_RNN_Block, self).__init__() + + self.cnn_channels = cnn_channels + self.rnn_hidden_size = rnn_hidden_size + self.rnn_num_layers = rnn_num_layers + self.gate_nonlinearity = gate_nonlinearity + self.update_nonlinearity = update_nonlinearity + self.num_labels = num_labels + self.device = device + self.fwd_context = fwd_context + self.bwd_context = 
bwd_context + self.isBi = isBi + if self.isBi: + self.direction_param = 2 + else: + self.direction_param = 1 + + self.declare_network(cnn_channels, rnn_hidden_size, rnn_num_layers, + num_labels, rank) + + self.__name__ = 'DSCNN_RNN_Block' + + def declare_network(self, cnn_channels, rnn_hidden_size, rnn_num_layers, + num_labels, rank): + + self.CNN1 = torch.nn.Sequential( + DSCNNBlockLR_k5(80, cnn_channels, 5, 1, 1, + 0, 0, do_depth=False, batch_norm=1e-2, + activation='none', rank=rank)) + + #torch tanh layer directly in the forward pass + self.bnorm_rnn = torch.nn.BatchNorm1d(cnn_channels, affine=False, momentum=1e-2) + self.RNN0 = BiFastGRNN(cnn_channels, rnn_hidden_size, + self.gate_nonlinearity, + self.update_nonlinearity, rank) + + self.CNN2 = DSCNNBlockLR_better(2 * rnn_hidden_size, + 2 * rnn_hidden_size, + batch_norm=1e-2, + dropout=0, + kernel=5, + activation='tanhgate', rank=rank) + + self.CNN3 = DSCNNBlockLR_better(2 * rnn_hidden_size, + 2 * rnn_hidden_size, + batch_norm=1e-2, + dropout=0, + kernel=5, + activation='tanhgate', rank=rank) + + self.CNN4 = DSCNNBlockLR_better(2 * rnn_hidden_size, + 2 * rnn_hidden_size, + batch_norm=1e-2, + dropout=0, + kernel=5, + activation='tanhgate', rank=rank) + + self.CNN5 = DSCNNBlockLR_better(2 * rnn_hidden_size, + num_labels, + batch_norm=1e-2, + dropout=0, + kernel=5, + activation='tanhgate', rank=rank) + + def forward(self, features): + + batch, _, max_seq_len = features.shape + X = self.CNN1(features) # Down to 30ms inference / 250ms window + X = torch.tanh(X) + X = self.bnorm_rnn(X) + X = X.permute((0, 2, 1)) # NCL to NLC + + X = X.contiguous() + assert X.shape[1] % 3 == 0 + X_f, oldShape_f, X_b, oldShape_b = X_preRNN_process(X, self.fwd_context, self.bwd_context) + #X [batch * num_bricks, brickLen, inpDim] + X_b_f = torch.flip(X_b, [1]) + + X_f, X_b = self.RNN0(X_f, X_b_f) + + X = X_postRNN_process(X_f, oldShape_f, X_b, oldShape_b) + # X, _ = torch.nn.utils.rnn.pad_packed_sequence(X, batch_first=True) + + # X = X.view(batch, max_seq_len, 2* self.rnn_hidden_size) + + # re-permute to get [batch, channels, max_seq_len/3 ] + X = X.permute((0, 2, 1)) # NLC to NCL + # print("Pre CNN2 Shape : ", X.shape) + X = self.CNN2(X) + X = self.CNN3(X) + X = self.CNN4(X) + X = self.CNN5(X) + return X + + +class Binary_Classification_Block(torch.nn.Module): + def __init__(self, in_size, rnn_hidden_size, rnn_num_layers, + device, islstm=False, isBi=True, momentum=1e-2, + num_labels=2, dropout=0, batch_assertion=False): + + super(Binary_Classification_Block, self).__init__() + + self.in_size = in_size + self.rnn_hidden_size = rnn_hidden_size + self.rnn_num_layers = rnn_num_layers + self.num_labels = num_labels + self.device = device + self.islstm = islstm + self.isBi = isBi + self.momentum = momentum + self.dropout = dropout + self.batch_assertion = batch_assertion + + if self.isBi: + self.direction_param = 2 + else: + self.direction_param = 1 + + self.declare_network(in_size, rnn_hidden_size, rnn_num_layers, num_labels) + + self.__name__ = 'Binary_Classification_Block_2lay' + + def declare_network(self, in_size, rnn_hidden_size, rnn_num_layers, num_labels): + + self.CNN1 = torch.nn.Sequential( + torch.nn.LeakyReLU(negative_slope=0.01), + torch.nn.BatchNorm1d(in_size, affine=False, + momentum=self.momentum), + torch.nn.Dropout(self.dropout)) + + if self.islstm: + self.RNN = nn.LSTM(input_size=in_size, + hidden_size=rnn_hidden_size, + num_layers=rnn_num_layers, + batch_first=True, + bidirectional=self.isBi) + else: + self.RNN = nn.GRU(input_size=in_size, + 
hidden_size=rnn_hidden_size, + num_layers=rnn_num_layers, + batch_first=True, + bidirectional=self.isBi) + + self.FCN = torch.nn.Sequential( + torch.nn.Dropout(self.dropout), + torch.nn.LeakyReLU(negative_slope=0.01), + torch.nn.Linear(self.direction_param * self.rnn_hidden_size, + num_labels)) + + def forward(self, features, seqlen): + batch, _, _ = features.shape + + if self.islstm: + hidden1 = self.init_hidden(batch, self.rnn_hidden_size, + self.rnn_num_layers) + hidden2 = self.init_hidden(batch, self.rnn_hidden_size, + self.rnn_num_layers) + else: + hidden = self.init_hidden(batch, self.rnn_hidden_size, + self.rnn_num_layers) + + X = self.CNN1(features) # Down to 30ms inference / 250ms window + + X = X.permute((0, 2, 1)) # NCL to NLC + + max_seq_len = X.shape[1] + + # modify seqlen + max_seq_len = min(torch.max(seqlen).item(), max_seq_len) + seqlen = torch.clamp(seqlen, max=max_seq_len) + self.seqlen = seqlen + + # pad according to seqlen + X = torch.nn.utils.rnn.pack_padded_sequence(X, + seqlen, + batch_first=True, + enforce_sorted=False) + + self.RNN.flatten_parameters() + if self.islstm: + X, (hh, _) = self.RNN(X, (hidden1, hidden2)) + else: + X, hh = self.RNN(X, hidden) + + X, _ = torch.nn.utils.rnn.pad_packed_sequence(X, batch_first=True) + + X = X.view(batch, max_seq_len, + self.direction_param * self.rnn_hidden_size) + + X = X[torch.arange(batch).long(), + seqlen.long() - 1, :].view( + batch, 1, self.direction_param * self.rnn_hidden_size) + + X = self.FCN(X) + + return X + + def init_hidden(self, batch, rnn_hidden_size, rnn_num_layers): + + # the weights are of the form (batch, num_layers * num_directions , hidden_size) + if self.batch_assertion: + hidden = torch.zeros(rnn_num_layers * self.direction_param, batch, + rnn_hidden_size) + else: + hidden = torch.zeros(rnn_num_layers * self.direction_param, batch, + rnn_hidden_size) + + hidden = hidden.to(self.device) + + hidden = Variable(hidden) + + return hidden + diff --git a/applications/KWS_Phoneme/train_classifier.py b/applications/KWS_Phoneme/train_classifier.py new file mode 100644 index 000000000..dd0623475 --- /dev/null +++ b/applications/KWS_Phoneme/train_classifier.py @@ -0,0 +1,293 @@ +import argparse +import os +import re +import numpy as np +import torch +# Aux scripts +import kwscnn +import multiprocessing +from data_pipe import get_ASR_datasets, get_classification_dataset + +def parseArgs(): + """ + Parse the command line arguments + Describes the architecture and the hyper-parameters + """ + parser = argparse.ArgumentParser() + # Args for Model Traning + parser.add_argument('--phoneme_model_load_ckpt', type=str, required=True, help="Phoneme checkpoint file to be loaded") + parser.add_argument('--classifier_model_save_folder', type=str, default='./classifier_model', help="Folder to save the classifier checkpoint") + parser.add_argument('--classifier_model_load_ckpt', type=str, default=None, help="Classifier checkpoint to be loaded") + parser.add_argument('--optim', type=str, default='adam', help="Optimizer to be used") + parser.add_argument('--lr', type=float, default=0.001, help="Optimizer learning rate") + parser.add_argument('--epochs', type=int, default=200, help="Number of epochs for training") + parser.add_argument('--save_tick', type=int, default=1, help="Number of epochs to wait to save") + parser.add_argument('--workers', type=int, default=-1, help="Number of workers. Give -1 for all workers") + parser.add_argument("--gpu", type=str, default='0', help="GPU indicies Eg: --gpu=0,1,2,3 for 4 gpus. 
-1 for CPU") + parser.add_argument("--word_model_name", default='google30', help="Name of the list of words used") + parser.add_argument('--words', type=str, help="List of words to be used. This will be assigned in the code. User input will not affect the result") + parser.add_argument("--is_training", action='store_true', help="True for training") + parser.add_argument("--synth", action='store_true', help="Use Synth block or not") + # Args for DataLoader + parser.add_argument('--base_path', type=str, required=True, help="path to train and test data folders") + parser.add_argument('--train_data_folders', type=str, default="google30_train", help="List of training folders in base path. Each folder is a dataset in the precribed format") + parser.add_argument('--test_data_folders', type=str, default="google30_test", help="List of testing folders in base path. Each folder is a dataset in the precribed format") + parser.add_argument('--rir_base_path', type=str, required=True, help="Folder with the reverbration files") + parser.add_argument('--additive_base_path', type=str, required=True, help="Folder with additive noise files") + parser.add_argument('--phoneme_text_file', type=str, help="Text files with pre-fixed phons") + parser.add_argument('--pretraining_length_mean', type=int, default=6, help="Mean of the audio clips lengths") + parser.add_argument('--pretraining_length_var', type=int, default=1, help="variance of the audio clip lengths") + parser.add_argument('--pretraining_batch_size', type=int, default=256, help="Batch size for the pirpeline") + parser.add_argument('--snr_samples', type=str, default="-5,0,0,5,10,15,40,100,100", help="SNR values for additive noise files") + parser.add_argument('--wgn_snr_samples', type=str, default="5,10,20,40,60", help="SNR values for white gaussian noise") + parser.add_argument('--gain_samples', type=str, default="1.0,0.25,0.5,0.75", help="Gain values for processed signal") + parser.add_argument('--rir_chance', type=float, default=0.9, help="Probability of performing reverbration") + parser.add_argument('--synth_chance', type=float, default=0.9, help="Probability of pre-processing the input with reverb and noise") + parser.add_argument('--pre_phone_list', action='store_true', help="use pre-fixed set of phonemes") + # Args for Phoneme + parser.add_argument('--phoneme_cnn_channels', type=int, default=400, help="Number od channels for the CNN layers") + parser.add_argument('--phoneme_rnn_hidden_size', type=int, default=200, help="Number of RNN hidden states") + parser.add_argument('--phoneme_rnn_layers', type=int, default=1, help="Number of RNN layers") + parser.add_argument('--phoneme_rank', type=int, default=50, help="Rank of the CNN layers weights") + parser.add_argument('--phoneme_fwd_context', type=int, default=15, help="RNN forward window context") + parser.add_argument('--phoneme_bwd_context', type=int, default=9, help="RNN backward window context") + parser.add_argument('--phoneme_phoneme_isBi', action='store_true', help="Use Bi-Directional RNN") + parser.add_argument('--phoneme_num_labels', type=int, default=41, help="Number og phoneme labels") + # Args for Classifier + parser.add_argument('--classifier_rnn_hidden_size', type=int, default=100, help="Calssifier RNN hidden dimensions") + parser.add_argument('--classifier_rnn_num_layers', type=int, default=1, help="Classifier RNN number of layers") + parser.add_argument('--classifier_dropout', type=float, default=0.2, help="Classifier dropout layer probability") + 
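+    # Classifier RNN variant flags: LSTM vs. GRU cell, and uni- vs. bi-directional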
parser.add_argument('--classifier_islstm', action='store_true', help="Use LSTM in the classifier") + parser.add_argument('--classifier_isBi', action='store_true', help="Use Bi-Directional RNN in classifier") + + args = parser.parse_args() + + # Parse the gain and SNR values to a float format + args.snr_samples = [int(samp) for samp in args.snr_samples.split(',')] + args.wgn_snr_samples = [int(samp) for samp in args.wgn_snr_samples.split(',')] + args.gain_samples = [float(samp) for samp in args.gain_samples.split(',')] + + # Fix the number of workers for the data Loader. If == -1 then use all possible workers + if args.workers == -1: + args.workers = multiprocessing.cpu_count() + + # Choose the word list to be used. For custom word lists, please add an elif condition + if args.word_model_name == 'google30': + args.words = ["bed", "bird", "cat", "dog", "down", "eight", "five", "four", "go", + "happy", "house", "left", "marvin", "nine", "no", "off", "on", "one", "right", + "seven", "sheila", "six", "stop", "three", "tree", "two", "up", "wow", "yes", "zero"] + elif args.word_model_name == 'google10': + args.words = ["yes", "no", "up", "down", "left", "right", "on", "off", + "stop", "go", "allsilence", "unknown"] + else: + raise ValueError('Incorrect Word Model Name') + + # The data-folder in args.base_path that contain the data + # Refer to data_pipe.py for loading format + args.train_data_folders = [folder_idx for folder_idx in args.train_data_folders.split(',')] + args.test_data_folders = [folder_idx for folder_idx in args.test_data_folders.split(',')] + + print(f"Args : {args}", flush=True) + return args + +def train_classifier_model(args): + """ + Train the Classifier Model on the designated dataset + The Dataset loader is defined in data_pipe.py + Default dataset used is Google30. 
Change the paths and file reader to change datasets + + args: args object (contains info about model and training) + """ + # GPU Settings + gpu_str = str() + for gpu in args.gpu.split(','): + gpu_str = gpu_str + str(gpu) + "," + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str + use_cuda = torch.cuda.is_available() and (args.gpu != -1) + device = torch.device("cuda" if use_cuda else "cpu") + + # Instantiate Phoneme Model + phoneme_model = kwscnn.DSCNN_RNN_Block(cnn_channels=args.phoneme_cnn_channels, + rnn_hidden_size=args.phoneme_rnn_hidden_size, + rnn_num_layers=args.phoneme_rnn_layers, + device=device, rank=args.phoneme_rank, + fwd_context=args.phoneme_fwd_context, + bwd_context=args.phoneme_bwd_context, + num_labels=args.phoneme_num_labels) + + # Freeze Phoneme Model and Deactivate BatchNorm and Dropout Layers + for param in phoneme_model.parameters(): + param.requires_grad = False + phoneme_model.train(False) + + # Instantiate Classifier Model + classifier_model = kwscnn.Binary_Classification_Block(in_size=args.phoneme_num_labels, + rnn_hidden_size=args.classifier_rnn_hidden_size, + rnn_num_layers=args.classifier_rnn_num_layers, + device=device, islstm=args.classifier_islstm, + isBi=args.classifier_isBi, dropout=args.classifier_dropout, + num_labels=len(args.words)) + + # Transfer to specified device + phoneme_model.to(device) + phoneme_model = torch.nn.DataParallel(phoneme_model) + classifier_model.to(device) + classifier_model = torch.nn.DataParallel(classifier_model) + model = {'name': phoneme_model.module.__name__, 'phoneme': phoneme_model, + 'classifier_name': classifier_model.module.__name__, 'classifier': classifier_model} + + # Optimizer + if args.optim == "adam": + model['opt'] = torch.optim.Adam(model['classifier'].parameters(), lr=args.lr) + if args.optim == "sgd": + model['opt'] = torch.optim.SGD(model['classifier'].parameters(), lr=args.lr) + + # Load the sepcified phoneme checkpoint. 'phoneme_model_load_ckpt' must point to a chcekpoint and not folder + if args.phoneme_model_load_ckpt is not None: + if os.path.exists(args.phoneme_model_load_ckpt): + # Load Checkpoint + latest_phoneme_ckpt = torch.load(args.phoneme_model_load_ckpt, map_location=device) + # Load specific state_dicts() and print the latest stats + print(f"Model Phoneme Location : {args.phoneme_model_load_ckpt}", flush=True) + model['phoneme'].load_state_dict(latest_phoneme_ckpt['phoneme_state_dict']) + print(f"Checkpoint Stats : {latest_phoneme_ckpt['train_stats']}", flush=True) + else: + raise ValueError("Invalid Phoneme Checkpoint Path") + else: + print("No Phoneme Checkpoint Given", flush=True) + + # Load the sepcified classifier checkpoint. 
'classifier_model_load_ckpt' must point to a chcekpoint and not folder + if args.classifier_model_load_ckpt is not None: + if os.path.exists(args.classifier_model_load_ckpt): + # Get the number from the classifier checkpoint path + start_epoch = args.classifier_model_load_ckpt # Temporarlity store the full ckpt path + start_epoch = start_epoch.split('/')[-1] # retain only the *.pt from the path (Linux) + start_epoch = start_epoch.split('\\')[-1] # retain only the *.pt from the path (Windows) + start_epoch = int(start_epoch.split('.')[0]) # reatin the integers + # Load Checkpoint + latest_classifier_ckpt = torch.load(args.classifier_model_load_ckpt, map_location=device) + # Load specific state_dicts() and print the latest stats + model['classifier'].load_state_dict(latest_classifier_ckpt['classifier_state_dict']) + model['opt'].load_state_dict(latest_classifier_ckpt['opt_state_dict']) + print(f"Checkpoint Stats : {latest_classifier_ckpt['train_stats']}", flush=True) + else: + raise ValueError("Invalid Classifier Checkpoint Path") + else: + start_epoch = 0 + + # Instantiate all Essential Variables and utils + train_dataset, test_dataset = get_classification_dataset(args) + train_loader = train_dataset.loader + test_loader = test_dataset.loader + total_batches = len(train_loader) + output_frame_rate = 3 + save_path = args.classifier_model_save_folder + os.makedirs(args.classifier_model_save_folder, exist_ok=True) + # Print for cros-checking + print(f"Pre Phone List {args.pre_phone_list}", flush=True) + print(f"Start Epoch : {start_epoch}", flush=True) + print(f"Device : {device}", flush=True) + print(f"Output Frame Rate (multiple of 10ms): {output_frame_rate}", flush=True) + print(f"Number of Batches: {total_batches}", flush=True) + print(f"Synth: {args.synth}", flush=True) + print(f"Words: {args.words}", flush=True) + print(f"Optimizer : {model['opt']}", flush=True) + + # Train Loop + for epoch in range(start_epoch + 1, args.epochs): + model['train_stats'] = {'loss': 0, 'correct': 0, 'total': 0} + model['classifier'].train(True) + for train_features, train_label, train_seqlen in train_loader: + train_seqlen_classifier = train_seqlen.clone() / output_frame_rate + train_features = train_features.to(device) + train_label = train_label.to(device) + train_seqlen_classifier = train_seqlen_classifier.to(device) + model['opt'].zero_grad() + + # Data-padding for bricking + train_features = train_features.permute((0, 2, 1)) # NCL to NLC + mod_len = train_features.shape[1] + pad_len_mod = (output_frame_rate - mod_len % output_frame_rate) % output_frame_rate + pad_len_feature = pad_len_mod + pad_data = torch.zeros(train_features.shape[0], pad_len_feature, + train_features.shape[2]).to(device) + train_features = torch.cat((train_features, pad_data), dim=1) + + assert (train_features.shape[1]) % output_frame_rate == 0 + + # Get the posterior predictions and trim the labels to the same length as the predictions + train_features = train_features.permute((0, 2, 1)) # NLC to NCL + train_posteriors = model['phoneme'](train_features) + train_posteriors = model['classifier'](train_posteriors, train_seqlen_classifier) + N, L, C = train_posteriors.shape + + # Permute and ready the final and pred labels values + train_flat_posteriors = train_posteriors.reshape((-1, C)) # to [NL] x C + + # Loss and backward step + train_label = train_label.type(torch.float32) + loss_classifier_model = torch.nn.functional.binary_cross_entropy_with_logits(train_flat_posteriors, train_label) + loss_classifier_model.backward() + 
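+            # Clip the classifier's gradient norm (max 10.0) before the optimizer step to keep updates stable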
torch.nn.utils.clip_grad_norm_(model['classifier'].parameters(), 10.0) + model['opt'].step() + + # Stats + model['train_stats']['loss'] += loss_classifier_model.detach() + _, train_idx_pred = torch.max(train_flat_posteriors, dim=1) + _, train_idx_label = torch.max(train_label, dim=1) + model['train_stats']['correct'] += float(np.sum((train_idx_pred == train_idx_label).detach().cpu().numpy())) + model['train_stats']['total'] += train_idx_label.shape[0] + + if epoch % args.save_tick == 0: + # Save the model + torch.save({'classifier_state_dict': model['classifier'].state_dict(), + 'opt_state_dict': model['opt'].state_dict(), 'train_stats' : model['train_stats']}, + os.path.join(save_path, f'{epoch}.pt')) + + avg_ce = model['train_stats']['loss'].cpu() / total_batches + train_accuracy = 100.0 * model['train_stats']['correct'] / model['train_stats']['total'] + print(f"Summary for Epoch {epoch} for Model {model['classifier_name']}; Loss: {avg_ce}", flush=True) + print(f"TRAIN => Accuracy: {train_accuracy}; Correct: {model['train_stats']['correct']}; Total: {model['train_stats']['total']}", flush=True) + + model['test_stats'] = {'correct': 0,'total': 0} + model['classifier'].eval() + with torch.no_grad(): + for test_features, test_label, test_seqlen in test_loader: + test_seqlen_classifier = test_seqlen.clone() / output_frame_rate + test_features = test_features.to(device) + test_label = test_label.to(device) + test_seqlen_classifier = test_seqlen_classifier.to(device) + + # Data-padding for bricking + test_features = test_features.permute((0, 2, 1)) # NCL to NLC + mod_len = test_features.shape[1] + pad_len_mod = (output_frame_rate - mod_len % output_frame_rate) % output_frame_rate + pad_len_feature = pad_len_mod + pad_data = torch.zeros(test_features.shape[0], pad_len_feature, + test_features.shape[2]).to(device) + test_features = torch.cat((test_features, pad_data), dim=1) + + assert (test_features.shape[1]) % output_frame_rate == 0 + + # Get the posterior predictions and trim the labels to the same length as the predictions + test_features = test_features.permute((0, 2, 1)) # NLC to NCL + test_posteriors = model['phoneme'](test_features) + test_posteriors = model['classifier'](test_posteriors, test_seqlen_classifier) + N, L, C = test_posteriors.shape + + # Permute and ready the final and pred labels values + test_flat_posteriors = test_posteriors.reshape((-1, C)) # to [NL] x C + + # Stats + _, test_idx_pred = torch.max(test_flat_posteriors, dim=1) + _, test_idx_label = torch.max(test_label, dim=1) + model['test_stats']['correct'] += float(np.sum((test_idx_pred == test_idx_label).detach().cpu().numpy())) + model['test_stats']['total'] += test_idx_label.shape[0] + + test_accuracy = 100.0 * model['test_stats']['correct'] / model['test_stats']['total'] + print(f"TEST => Accuracy: {test_accuracy}; Correct: {model['test_stats']['correct']}; Total: {model['test_stats']['total']}", flush=True) + return + +if __name__ == '__main__': + args = parseArgs() + train_classifier_model(args) \ No newline at end of file diff --git a/applications/KWS_Phoneme/train_phoneme.py b/applications/KWS_Phoneme/train_phoneme.py new file mode 100644 index 000000000..4a0e52162 --- /dev/null +++ b/applications/KWS_Phoneme/train_phoneme.py @@ -0,0 +1,209 @@ +import argparse +import os +import re +import numpy as np +import torch +# Aux scripts +import kwscnn +import multiprocessing +from data_pipe import get_ASR_datasets + +def parseArgs(): + """ + Parse the command line arguments + Describes the architecture and the 
hyper-parameters
+    """
+    parser = argparse.ArgumentParser()
+    # Args for Model Training
+    parser.add_argument('--phoneme_model_save_folder', type=str, default='./phoneme_model', help="Folder to save the checkpoint")
+    parser.add_argument('--phoneme_model_load_ckpt', type=str, default=None, help="Checkpoint file to be loaded")
+    parser.add_argument('--optim', type=str, default='adam', help="Optimizer to be used")
+    parser.add_argument('--lr', type=float, default=0.001, help="Optimizer learning rate")
+    parser.add_argument('--epochs', type=int, default=200, help="Number of epochs for training")
+    parser.add_argument('--save_tick', type=int, default=1, help="Number of epochs to wait to save")
+    parser.add_argument('--workers', type=int, default=-1, help="Number of workers. Give -1 for all workers")
+    parser.add_argument("--gpu", type=str, default='0', help="GPU indices Eg: --gpu=0,1,2,3 for 4 gpus. -1 for CPU")
+
+    # Args for DataLoader
+    parser.add_argument('--base_path', type=str, required=True, help="Path of the speech data folder. The data in this folder should be in accordance to the dataloader code written here.")
+    parser.add_argument('--rir_base_path', type=str, required=True, help="Folder with the reverberation files")
+    parser.add_argument('--additive_base_path', type=str, required=True, help="Folder with additive noise files")
+    parser.add_argument('--phoneme_text_file', type=str, required=True, help="Text files with pre-fixed phonemes")
+    parser.add_argument('--pretraining_length_mean', type=int, default=6, help="Mean of the audio clip lengths")
+    parser.add_argument('--pretraining_length_var', type=int, default=1, help="Variance of the audio clip lengths")
+    parser.add_argument('--pretraining_batch_size', type=int, default=256, help="Batch size for the pipeline")
+    parser.add_argument('--snr_samples', type=str, default="0,5,10,25,100,100", help="SNR values for additive noise files")
+    parser.add_argument('--wgn_snr_samples', type=str, default="5,10,15,100,100", help="SNR values for white Gaussian noise")
+    parser.add_argument('--gain_samples', type=str, default="1.0,0.25,0.5,0.75", help="Gain values for the processed signal")
+    parser.add_argument('--rir_chance', type=float, default=0.25, help="Probability of performing reverberation")
+    parser.add_argument('--synth_chance', type=float, default=0.5, help="Probability of pre-processing the signal with noise and reverb")
+    parser.add_argument('--pre_phone_list', action='store_true', help="Use pre-fixed list of phonemes")
+    # Args for Phoneme
+    parser.add_argument('--phoneme_cnn_channels', type=int, default=400, help="Number of channels for the CNN layers")
+    parser.add_argument('--phoneme_rnn_hidden_size', type=int, default=200, help="Number of RNN hidden states")
+    parser.add_argument('--phoneme_rnn_layers', type=int, default=1, help="Number of RNN layers")
+    parser.add_argument('--phoneme_rank', type=int, default=50, help="Rank of the CNN layer weights")
+    parser.add_argument('--phoneme_fwd_context', type=int, default=15, help="RNN forward window context")
+    parser.add_argument('--phoneme_bwd_context', type=int, default=9, help="RNN backward window context")
+    parser.add_argument('--phoneme_phoneme_isBi', action='store_true', help="Use Bi-Directional RNN")
+    parser.add_argument('--phoneme_num_labels', type=int, default=41, help="Number of phoneme labels")
+
+    args = parser.parse_args()
+
+    # Parse the gain and SNR values to a float format
+    args.snr_samples = [int(samp) for samp in args.snr_samples.split(',')]
+    args.wgn_snr_samples
= [int(samp) for samp in args.wgn_snr_samples.split(',')] + args.gain_samples = [float(samp) for samp in args.gain_samples.split(',')] + + # Fix the number of workers for the data Loader. If == -1 then use all possible workers + if args.workers == -1: + args.workers = multiprocessing.cpu_count() + + print(f"Args : {args}", flush=True) + return args + +def train_phoneme_model(args): + """ + Train the Phoneme Model on the designated dataset + The Dataset loader is defined in data_pipe.py + Default dataset used is LibriSpeeech. Change the paths and file reader to change datasets + + args: args object (contains info about model and training) + """ + # GPU Settings + gpu_str = str() + for gpu in args.gpu.split(','): + gpu_str = gpu_str + str(gpu) + "," + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str + use_cuda = torch.cuda.is_available() and (args.gpu != -1) + device = torch.device("cuda" if use_cuda else "cpu") + + # Instantiate model + phoneme_model = kwscnn.DSCNN_RNN_Block(cnn_channels=args.phoneme_cnn_channels, + rnn_hidden_size=args.phoneme_rnn_hidden_size, + rnn_num_layers=args.phoneme_rnn_layers, + device=device, rank=args.phoneme_rank, + fwd_context=args.phoneme_fwd_context, + bwd_context=args.phoneme_bwd_context, + num_labels=args.phoneme_num_labels) + + # Transfer to specified device + phoneme_model.to(device) + phoneme_model = torch.nn.DataParallel(phoneme_model) + model = {'name': phoneme_model.module.__name__, 'phoneme': phoneme_model} + + + # Optimizer + if args.optim == "adam": + model['opt'] = torch.optim.Adam(model['phoneme'].parameters(), lr=args.lr) + if args.optim == "sgd": + model['opt'] = torch.optim.SGD(model['phoneme'].parameters(), lr=args.lr) + + # Load the sepcified checkpoint. 'phoneme_model_load_ckpt' must point to a chcekpoint and not folder + if args.phoneme_model_load_ckpt is not None: + if os.path.exists(args.phoneme_model_load_ckpt): + # Get the number from the phoneme checkpoint path + start_epoch = args.phoneme_model_load_ckpt # Temporarlity store the full ckpt path + start_epoch = start_epoch.split('/')[-1] # retain only the *.pt from the path (Linux) + start_epoch = start_epoch.split('\\')[-1] # retain only the *.pt from the path (Windows) + start_epoch = int(start_epoch.split('.')[0]) # reatin the integers + # Load Checkpoint + latest_ckpt = torch.load(args.phoneme_model_load_ckpt, map_location=device) + # Load specific state_dicts() and print the latest stats + model['phoneme'].load_state_dict(latest_ckpt['phoneme_state_dict']) + model['opt'].load_state_dict(latest_ckpt['opt_state_dict']) + print(f"Checkpoint Stats : {latest_ckpt['train_stats']}", flush=True) + else: + raise ValueError("Invalid Checkpoint Path") + else: + start_epoch = 0 + + # Instantiate dataloaders, essential variables and save folders + train_dataset = get_ASR_datasets(args) + train_loader = train_dataset.loader + total_batches = len(train_loader) + output_frame_rate = 3 + save_path = args.phoneme_model_save_folder + os.makedirs(args.phoneme_model_save_folder, exist_ok=True) + + print(f"Pre Phone List {args.pre_phone_list}", flush=True) + print(f"Start Epoch : {start_epoch}", flush=True) + print(f"Device : {device}", flush=True) + print(f"Output Frame Rate (multiple of 10ms): {output_frame_rate}", flush=True) + print(f"Number of Batches: {total_batches}", flush=True) + + # Train Loop + for epoch in range(start_epoch + 1, args.epochs): + model['train_stats'] = {'loss': 0, 'predstd': 0, 'correct': 0, 'valid': 0} + for features, label in train_loader: + features = features.to(device) + 
label = label.to(device) + model['opt'].zero_grad() + + # Data-padding for bricking + features = features.permute((0, 2, 1)) # NCL to NLC + mod_len = features.shape[1] + pad_len_mod = (output_frame_rate - mod_len % output_frame_rate) % output_frame_rate + pad_len_feature = pad_len_mod + pad_data = torch.zeros(features.shape[0], pad_len_feature, + features.shape[2]).to(device) + features = torch.cat((features, pad_data), dim=1) + + assert (features.shape[1]) % output_frame_rate == 0 + # Augmenting the label accordingly + pad_len_label = pad_len_feature + pad_data = torch.ones(label.shape[0], pad_len_label).to(device) * (-1) + pad_data = pad_data.type(torch.long) + label = torch.cat((label, pad_data), dim=1) + + # Get the posterior predictions and trim the labels to the same length as the predictions + features = features.permute((0, 2, 1)) # NLC to NCL + posteriors = model['phoneme'](features) + N, C, L = posteriors.shape + trim_label = label[:, ::output_frame_rate] # 30ms frame_rate + trim_label = trim_label[:, :L] + + # Permute and ready the final and pred labels values + flat_posteriors = posteriors.permute((0, 2, 1)) # TO NLC + flat_posteriors = flat_posteriors.reshape((-1, C)) # to [NL] x C + flat_labels = trim_label.reshape((-1)) + + _, idx = torch.max(flat_posteriors, dim=1) + correct_count = (idx == flat_labels).detach().sum() + valid_count = (flat_labels >= 0).detach().sum() + + # Loss and backward step + loss_phoneme_model = torch.nn.functional.cross_entropy(flat_posteriors, + flat_labels, ignore_index=-1) + loss_phoneme_model.backward() + torch.nn.utils.clip_grad_norm_(model['phoneme'].parameters(), 10.0) + model['opt'].step() + + # Stats + pred_std = idx.to(torch.float32).std() + + model['train_stats']['loss'] += loss_phoneme_model.detach() + model['train_stats']['correct'] += correct_count + model['train_stats']['predstd'] += pred_std + model['train_stats']['valid'] += valid_count + + if epoch % args.save_tick == 0: + # Save the model + torch.save({'phoneme_state_dict': model['phoneme'].state_dict(), + 'opt_state_dict': model['opt'].state_dict(), + 'train_stats' : model['train_stats']}, os.path.join(save_path, f'{epoch}.pt')) + + valid_frames = model['train_stats']['valid'].cpu() + correct_frames = model['train_stats']['correct'].cpu().to(torch.float32) + epoch_prestd = model['train_stats']['predstd'].cpu() + + avg_ce = model['train_stats']['loss'].cpu() / total_batches + avg_err = 100 - 100.0 * (correct_frames / valid_frames) + + print(f"Summary for Epoch {epoch} for Model {model['name']}", flush=True) + print(f"CE: {avg_ce}, ERR: {avg_err}, FRAMES {correct_frames} / {valid_frames}, PREDSTD: {epoch_prestd / total_batches}", flush=True) + return + +if __name__ == '__main__': + args = parseArgs() + train_phoneme_model(args) \ No newline at end of file diff --git a/applications/KWS_Phoneme/utils.py b/applications/KWS_Phoneme/utils.py new file mode 100644 index 000000000..a5f5723cb --- /dev/null +++ b/applications/KWS_Phoneme/utils.py @@ -0,0 +1,40 @@ +""" +Some Torch Functions required for overwriting +Convolutions to Low Rank Convolutions +""" + +from typing import List + +from torch._six import container_abcs +from itertools import repeat + + +def _ntuple(n): + def parse(x): + if isinstance(x, container_abcs.Iterable): + return x + return tuple(repeat(x, n)) + return parse + +_single = _ntuple(1) +_pair = _ntuple(2) +_triple = _ntuple(3) +_quadruple = _ntuple(4) + + +def _reverse_repeat_tuple(t, n): + r"""Reverse the order of `t` and repeat each element for `n` times. 
+ + This can be used to translate padding arg used by Conv and Pooling modules + to the ones used by `F.pad`. + """ + return tuple(x for x in reversed(t) for _ in range(n)) + + +def _list_with_default(out_size, defaults): + # type: (List[int], List[int]) -> List[int] + if isinstance(out_size, int): + return out_size + if len(defaults) <= len(out_size): + raise ValueError('Input dimension should be at least {}'.format(len(out_size) + 1)) + return [v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size):])] From 7203332f0caaacdbb4c42ecd4cb849c5ec623c1e Mon Sep 17 00:00:00 2001 From: Anirudh0707 Date: Wed, 7 Jul 2021 19:12:15 +0530 Subject: [PATCH 2/8] Add license --- applications/KWS_Phoneme/README.md | 5 +++++ .../KWS_Phoneme/auxiliary_files/convert_sampling_rate.py | 3 +++ .../KWS_Phoneme/auxiliary_files/download_youtube_data.py | 3 +++ applications/KWS_Phoneme/data_pipe.py | 3 +++ applications/KWS_Phoneme/kwscnn.py | 3 +++ applications/KWS_Phoneme/train_classifier.py | 3 +++ applications/KWS_Phoneme/train_phoneme.py | 3 +++ applications/KWS_Phoneme/utils.py | 3 +++ 8 files changed, 26 insertions(+) diff --git a/applications/KWS_Phoneme/README.md b/applications/KWS_Phoneme/README.md index b610621ce..12d0a1643 100644 --- a/applications/KWS_Phoneme/README.md +++ b/applications/KWS_Phoneme/README.md @@ -3,6 +3,10 @@ # Project Description There are two major issues in the existing KWS systems (a) they are not robust to heavy background noise and random utterances, and (b) they require collecting a lot of data, hampering the ease of adding a new keyword. Tackling these issues from a different perspective, we propose a new two staged scheme with a model for predicting phonemes which are in turn used for phoneme based keyword classification. +First we train a phoneme classification model which gives the phoneme transcription of the input speech snippet. For training this phoneme classifier, we use a large public speech dataset like LibreSpeech. The public dataset can be aligned (meaning get the phoneme labels for each speech snippet in the data) using Montreal Forced Aligner. We also add reverberations and additive noise to the speech samples from the public dataset to make the phoneme classifier training robust to various accents, background noise and varied environment. In this project, we predict phonemes at every 10ms which is the standard way. You can find the aligned LibreSpeech dataset we used for training here. + +In the second part, we use the predicted phoneme outputs from the phoneme classifier for predicting the input keyword. We train a 1 layer FastGRNN classifier to predict the keyword based on the phoneme transcription as input. Since the phoneme classifier training has been done to account for diverse accent, background noise and environments, the keyword classifier can be trained using a small number of Text-To-Speech(TTS) samples generated using any standard TTS api from cloud services like Azure, Google Cloud or AWS. + This gives two advantages: (a) The phoneme model is trained to account for diverse accents and background noise settings, thus the flexible keyword classifier training requires only a small number of keyword samples, and (b) Empirically this method was able to detect keywords from as far as 9ft of distance. Further, the phoneme model has a small size of around 250k parameters and can fit on a Cortex M7 microcontroller. 
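
To make the two-stage flow described above concrete, the sketch below wires the two blocks from `kwscnn.py` together for a single utterance. It is only illustrative: the phoneme-model arguments mirror the defaults exposed by `train_phoneme.py` (400 CNN channels, hidden size 200, rank 50, 41 phoneme labels), while the classifier hidden size and the keyword count are placeholders, and checkpoint loading and feature extraction are omitted.
```
# Illustrative sketch only -- not part of the training scripts.
# Assumes 80-dim log-mel features `feats` of shape (batch, 80, frames),
# with `frames` a multiple of 3 (the 30 ms output frame rate).
import torch
import kwscnn

device = torch.device("cpu")
phoneme_model = kwscnn.DSCNN_RNN_Block(
    cnn_channels=400, rnn_hidden_size=200, rnn_num_layers=1,
    device=device, rank=50, fwd_context=15, bwd_context=9, num_labels=41)
keyword_model = kwscnn.Binary_Classification_Block(
    in_size=41, rnn_hidden_size=100, rnn_num_layers=1,  # hidden size is a placeholder
    device=device, islstm=False, isBi=True, dropout=0,
    num_labels=30)                                      # e.g. the 30 Google-30 keywords

phoneme_model.eval()
keyword_model.eval()
with torch.no_grad():
    feats = torch.randn(1, 80, 300)               # stand-in for real log-mel features
    seqlen = torch.tensor([feats.shape[2] // 3])  # sequence length at the 30 ms rate
    posteriors = phoneme_model(feats)             # roughly (batch, 41, frames // 3)
    logits = keyword_model(posteriors, seqlen)    # (batch, 1, num_keywords)
    predicted_keyword = logits.reshape(-1).argmax()
```
In the actual scripts both blocks are additionally wrapped in `torch.nn.DataParallel`, and the phoneme weights are loaded from `--phoneme_model_load_ckpt` and frozen before the classifier is trained.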
# Training the Phoneme Classifier @@ -15,6 +19,7 @@ This gives two advantages: (a) The phoneme model is trained to account for diver 1) Our method takes as input the speech snippet and passes it through the phoneme classifier 2) Keywords are detected by training a keyword classifier over the detected phonemes 3) For training the keyword classifier, we use Azure and Google Text-To-Speech API to get the training data (keyword snippets) +4) For example, if you want to train a Keyword classifier for the keywords in the Google30 dataset, generate TTS samples from the Azure/Google-Cloud/AWS API for each of the 30 keywords. The TTS samples for each keyword must be stored in a separate folder named according to the keyword. More details about how the generated TTS data should be stored are mentioned below in sample use case for classifier model training. # Sample Use Cases diff --git a/applications/KWS_Phoneme/auxiliary_files/convert_sampling_rate.py b/applications/KWS_Phoneme/auxiliary_files/convert_sampling_rate.py index effb35f5c..1685ab168 100644 --- a/applications/KWS_Phoneme/auxiliary_files/convert_sampling_rate.py +++ b/applications/KWS_Phoneme/auxiliary_files/convert_sampling_rate.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + import os import librosa import numpy as np diff --git a/applications/KWS_Phoneme/auxiliary_files/download_youtube_data.py b/applications/KWS_Phoneme/auxiliary_files/download_youtube_data.py index eede85d57..b26efe8e7 100644 --- a/applications/KWS_Phoneme/auxiliary_files/download_youtube_data.py +++ b/applications/KWS_Phoneme/auxiliary_files/download_youtube_data.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + import csv import os import argparse diff --git a/applications/KWS_Phoneme/data_pipe.py b/applications/KWS_Phoneme/data_pipe.py index c9337e2be..eec669934 100644 --- a/applications/KWS_Phoneme/data_pipe.py +++ b/applications/KWS_Phoneme/data_pipe.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + import torch import argparse import torch.utils.data diff --git a/applications/KWS_Phoneme/kwscnn.py b/applications/KWS_Phoneme/kwscnn.py index 3f3a2710f..eafd07043 100644 --- a/applications/KWS_Phoneme/kwscnn.py +++ b/applications/KWS_Phoneme/kwscnn.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + import torch import torch.nn as nn from torch.autograd import Variable diff --git a/applications/KWS_Phoneme/train_classifier.py b/applications/KWS_Phoneme/train_classifier.py index dd0623475..1a9dbf2fe 100644 --- a/applications/KWS_Phoneme/train_classifier.py +++ b/applications/KWS_Phoneme/train_classifier.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + import argparse import os import re diff --git a/applications/KWS_Phoneme/train_phoneme.py b/applications/KWS_Phoneme/train_phoneme.py index 4a0e52162..7c2ee2649 100644 --- a/applications/KWS_Phoneme/train_phoneme.py +++ b/applications/KWS_Phoneme/train_phoneme.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
+ import argparse import os import re diff --git a/applications/KWS_Phoneme/utils.py b/applications/KWS_Phoneme/utils.py index a5f5723cb..a267a14b0 100644 --- a/applications/KWS_Phoneme/utils.py +++ b/applications/KWS_Phoneme/utils.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + """ Some Torch Functions required for overwriting Convolutions to Low Rank Convolutions From 43a9b377e21a911313321168e8e5830e86cffca9 Mon Sep 17 00:00:00 2001 From: Anirudh0707 Date: Thu, 8 Jul 2021 17:39:45 +0530 Subject: [PATCH 3/8] Remove redundant functions --- applications/KWS_Phoneme/kwscnn.py | 147 +++++++++++++++-------------- 1 file changed, 74 insertions(+), 73 deletions(-) diff --git a/applications/KWS_Phoneme/kwscnn.py b/applications/KWS_Phoneme/kwscnn.py index eafd07043..a754c3b41 100644 --- a/applications/KWS_Phoneme/kwscnn.py +++ b/applications/KWS_Phoneme/kwscnn.py @@ -12,8 +12,10 @@ from edgeml_pytorch.graph.rnn import * class _IndexSelect(torch.nn.Module): - """Channel permutation module. The purpose of this is to allow mixing across the CNN groups.""" def __init__(self, channels, direction, groups): + """ + Channel permutation module. The purpose of this is to allow mixing across the CNN groups + """ super(_IndexSelect, self).__init__() if channels % groups != 0: @@ -45,6 +47,17 @@ def __init__(self): super(_TanhGate, self).__init__() def forward(self, value): + """ + Applies a custom activation function + The first half of the channels are passed through sigmoid layer and the next hlaf though a tanh + The outputs are multiplied and returned + + Input + value: A tensor of shape (batch, channels, *) + + Output + activation output + """ channels = value.shape[1] piv = int(channels/2) @@ -61,22 +74,29 @@ class LR_conv(nn.Conv1d): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', rank=50): - super().__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode) + """ + A convolution layer with the weight matrix subjected to a low-rank decomposition. 
Currently for kernel size of 5 - assert kernel_size == 5 - rank = rank + Input + rank : The rank used for the low-rank decomposition on the weight/kernl tensor + All other parameters are similar to that of a convolution layer + Only change is the decomposition of the output channels into low-rank tensors + """ + self.kernel_size = kernel_size + self.rank = rank self.W1 = Parameter(torch.Tensor(self.out_channels, rank)) + # As per PyTorch Standard nn.init.kaiming_uniform_(self.W1, a=math.sqrt(5)) - self.W2 = Parameter(torch.Tensor(rank, self.in_channels * 5)) + self.W2 = Parameter(torch.Tensor(rank, self.in_channels * self.kernel_size)) nn.init.kaiming_uniform_(self.W2, a=math.sqrt(5)) self.weight = None def forward(self, input): lr_weight = torch.matmul(self.W1, self.W2) - lr_weight = torch.reshape(lr_weight, (self.out_channels, self.in_channels, 5)) + lr_weight = torch.reshape(lr_weight, (self.out_channels, self.in_channels, self.kernel_size)) if self.padding_mode != 'zeros': return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), lr_weight, self.bias, self.stride, @@ -84,17 +104,28 @@ def forward(self, input): return F.conv1d(input, lr_weight, self.bias, self.stride, self.padding, self.dilation, self.groups) -class DSCNNBlockLR_k5(torch.nn.Module): - """A depth-separate CNN block""" - +class PreRNNConvBlock(torch.nn.Module): def __init__( self, in_channels, out_channels, kernel, stride=1, groups=1, avg_pool=2, dropout=0.1, - batch_norm=0.1, do_depth=True, shuffle=0, + batch_norm=0.1, shuffle=0, activation='sigmoid', rank=50): - - super(DSCNNBlockLR_k5, self).__init__() - + super(PreRNNConvBlock, self).__init__() + """ + A low-rank convolution layer combination with pooling and activation layers. Currently for kernel size of 5 + + Input + in_channels : number of input channels for the conv layer + out_channels : number of output channels for the conv layer + kernel : conv kernel size + stride : conv stride + groups : number of groups for conv layer + avg_pool : kernel size for avgerage pooling layer + dropout : dropout layer probability + batch_norm : momemtum for batch norm + activation : activation layer + rank : rank for low-rank decomposition for conv layer weights + """ activators = { 'sigmoid': torch.nn.Sigmoid(), 'relu': torch.nn.ReLU(), @@ -116,10 +147,7 @@ def __init__( else: batch_block = None - #do_depth = False - # depth_cnn = torch.nn.Conv1d(in_channels, in_channels, kernel_size=kernel, stride=1, groups=in_channels) depth_cnn = None - # point_cnn = torch.nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, groups=groups) point_cnn = LR_conv(in_channels, out_channels, kernel_size=kernel, stride=stride, groups=groups, rank=rank, padding=2) if shuffle != 0 and groups > 1: @@ -138,65 +166,38 @@ def __init__( dropout_block = None seq1 = [nonlin, batch_block, depth_cnn, shuffler, point_cnn, dropout_block, pool] - # seq1 = [nonlin, batch_block, depth_cnn]#, shuffler, point_cnn, dropout_block, pool] seq_f1 = [item for item in seq1 if item is not None] if len(seq_f1) == 1: self._op1 = seq_f1[0] else: self._op1 = torch.nn.Sequential(*seq_f1) - seq2 = [shuffler, dropout_block, pool] - seq_f2 = [item for item in seq2 if item is not None] - if len(seq_f2) == 1: - print("Only 1 op in seq2") - self._op2 = seq_f2[0] - else: - self._op2 = torch.nn.Sequential(*seq_f2) - def forward(self, x): - xb = self._op1[0](x) - x_cnn = self._op1[1](xb) - return x_cnn - - -class LR_pointcnn(nn.Conv1d): - def __init__(self, in_channels, out_channels, 
kernel_size, - stride=1, padding=0, dilation=1, groups=1, - bias=True, padding_mode='zeros', rank=50): - super().__init__( - in_channels, out_channels, kernel_size, stride, - padding, dilation, groups, bias, padding_mode) - - # assert stride == 3 - assert kernel_size == 1 - rank = rank - self.W1 = Parameter(torch.Tensor(self.out_channels, rank)) - nn.init.kaiming_uniform_(self.W1, a=math.sqrt(5)) - self.W2 = Parameter(torch.Tensor(rank, self.in_channels)) - nn.init.kaiming_uniform_(self.W2, a=math.sqrt(5)) - self.weight = None - - def forward(self, input): - lr_weight = torch.matmul(self.W1, self.W2) - lr_weight = torch.reshape(lr_weight, (self.out_channels, self.in_channels, 1)) - if self.padding_mode != 'zeros': - return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - lr_weight, self.bias, self.stride, - _single(0), self.dilation, self.groups) - return F.conv1d(input, lr_weight, self.bias, self.stride, - self.padding, self.dilation, self.groups) - -class DSCNNBlockLR_better(torch.nn.Module): - """A depth-separate CNN block""" + x = self._op1(x) + return x +class DSCNNBlockLR(torch.nn.Module): def __init__( self, in_channels, out_channels, kernel, stride=1, groups=1, avg_pool=2, dropout=0.1, - batch_norm=0.1, do_depth=True, shuffle=0, + batch_norm=0.1, shuffle=0, activation='sigmoid', rank=50): - - super(DSCNNBlockLR_better, self).__init__() - + super(DSCNNBlockLR, self).__init__() + """ + A depthwise seeprable low-rank convolution layer combination with pooling and activation layers + + Input + in_channels : number of input channels for the pointwise conv layer + out_channels : number of output channels for the pointwise conv layer + kernel : conv kernel size for depthwise layer + stride : conv stride + groups : number of groups for conv layer + avg_pool : kernel size for avgerage pooling layer + dropout : dropout layer probability + batch_norm : momemtum for batch norm + activation : activation layer + rank : rank for low-rank decomposition for conv layer weights + """ activators = { 'sigmoid': torch.nn.Sigmoid(), 'relu': torch.nn.ReLU(), @@ -219,7 +220,7 @@ def __init__( batch_block = None depth_cnn = torch.nn.Conv1d(in_channels, in_channels, kernel_size=kernel, stride=1, groups=in_channels, padding=2) - point_cnn = LR_pointcnn(in_channels, out_channels, kernel_size=1, stride=stride, groups=groups, rank=rank) + point_cnn = LR_conv(in_channels, out_channels, kernel_size=1, stride=stride, groups=groups, rank=rank) if shuffle != 0 and groups > 1: shuffler = _IndexSelect(in_channels, shuffle, groups) @@ -237,7 +238,6 @@ def __init__( dropout_block = None seq = [nonlin, batch_block, depth_cnn, shuffler, point_cnn, dropout_block, pool] - # seq1 = [nonlin, batch_block, depth_cnn]#, shuffler, point_cnn, dropout_block, pool] seq_f = [item for item in seq if item is not None] if len(seq_f) == 1: self._op = seq_f[0] @@ -249,8 +249,9 @@ def forward(self, x): return x class BiFastGRNN(nn.Module): - "Bi Directional FastGRNN" - + """ + Bi Directional FastGRNN + """ def __init__(self, inputDims, hiddenDims, gate_nonlinearity, update_nonlinearity, rank): super(BiFastGRNN, self).__init__() @@ -373,8 +374,8 @@ def declare_network(self, cnn_channels, rnn_hidden_size, rnn_num_layers, num_labels, rank): self.CNN1 = torch.nn.Sequential( - DSCNNBlockLR_k5(80, cnn_channels, 5, 1, 1, - 0, 0, do_depth=False, batch_norm=1e-2, + PreRNNConvBlock(80, cnn_channels, 5, 1, 1, + 0, 0, batch_norm=1e-2, activation='none', rank=rank)) #torch tanh layer directly in the forward pass @@ 
-383,28 +384,28 @@ def declare_network(self, cnn_channels, rnn_hidden_size, rnn_num_layers, self.gate_nonlinearity, self.update_nonlinearity, rank) - self.CNN2 = DSCNNBlockLR_better(2 * rnn_hidden_size, + self.CNN2 = DSCNNBlockLR(2 * rnn_hidden_size, 2 * rnn_hidden_size, batch_norm=1e-2, dropout=0, kernel=5, activation='tanhgate', rank=rank) - self.CNN3 = DSCNNBlockLR_better(2 * rnn_hidden_size, + self.CNN3 = DSCNNBlockLR(2 * rnn_hidden_size, 2 * rnn_hidden_size, batch_norm=1e-2, dropout=0, kernel=5, activation='tanhgate', rank=rank) - self.CNN4 = DSCNNBlockLR_better(2 * rnn_hidden_size, + self.CNN4 = DSCNNBlockLR(2 * rnn_hidden_size, 2 * rnn_hidden_size, batch_norm=1e-2, dropout=0, kernel=5, activation='tanhgate', rank=rank) - self.CNN5 = DSCNNBlockLR_better(2 * rnn_hidden_size, + self.CNN5 = DSCNNBlockLR(2 * rnn_hidden_size, num_labels, batch_norm=1e-2, dropout=0, From 113ab2311a13a7c2a435442270376e049dcaf175 Mon Sep 17 00:00:00 2001 From: Anirudh0707 Date: Fri, 9 Jul 2021 00:27:46 +0530 Subject: [PATCH 4/8] finish documenting kwscnn --- applications/KWS_Phoneme/kwscnn.py | 141 ++++++++++++++++++++++++++--- 1 file changed, 128 insertions(+), 13 deletions(-) diff --git a/applications/KWS_Phoneme/kwscnn.py b/applications/KWS_Phoneme/kwscnn.py index a754c3b41..d966cf6d0 100644 --- a/applications/KWS_Phoneme/kwscnn.py +++ b/applications/KWS_Phoneme/kwscnn.py @@ -56,7 +56,7 @@ def forward(self, value): value: A tensor of shape (batch, channels, *) Output - activation output + activation output of shape (batch, channels/2, *) """ channels = value.shape[1] piv = int(channels/2) @@ -95,6 +95,16 @@ def __init__(self, in_channels, out_channels, kernel_size, self.weight = None def forward(self, input): + """ + The decomposed weights are multiplied to enforce the low-rank constraint + The conv1d is performed as usual post multiplication + + Input + input: Input of shape similar to that of which is fed to a conv layer + + Output + convolution output + """ lr_weight = torch.matmul(self.W1, self.W2) lr_weight = torch.reshape(lr_weight, (self.out_channels, self.in_channels, self.kernel_size)) if self.padding_mode != 'zeros': @@ -173,6 +183,15 @@ def __init__( self._op1 = torch.nn.Sequential(*seq_f1) def forward(self, x): + """ + Apply the set of layers initilized in __init__ + + Input + x: A tensor of shape (batch, channels, length) + + Output + network block output of shape (batch, channels, length) + """ x = self._op1(x) return x @@ -245,12 +264,23 @@ def __init__( self._op = torch.nn.Sequential(*seq_f) def forward(self, x): + """ + Apply the set of layers initilized in __init__ + + Input + x: A tensor of shape (batch, channels, length) + + Output + network block output of shape (batch, channels, length) + """ x = self._op(x) return x class BiFastGRNN(nn.Module): """ Bi Directional FastGRNN + + Parameters and arguments are similar to the torch RNN counterparts """ def __init__(self, inputDims, hiddenDims, gate_nonlinearity, update_nonlinearity, rank): @@ -273,6 +303,19 @@ def __init__(self, inputDims, hiddenDims, gate_nonlinearity, uRank=rank) def forward(self, input_f, input_b): + """ + Pass the inputs to forward and backward layers + Please note the backward layer is similar to the forward layer and the input needs to be fed reversed accordingly + Tensors are of the shape (batch, length, channels) + + Input + input_f : input to the forward layer + input_b : input to the backward layer. 
Input needs to be reversed before passing through this forward method + + Output + output1 : output of the forward layer + output2 : output of the backward layer + """ # Bidirectional FastGRNN output1 = self.cell_fwd(input_f) output2 = self.cell_bwd(input_b) @@ -282,8 +325,21 @@ def forward(self, input_f, input_b): def X_preRNN_process(X, fwd_context, bwd_context): - "Bricked RNN Architecture" - + """ + A depthwise seeprable low-rank convolution layer combination with pooling and activation layers + + Input + in_channels : number of input channels for the pointwise conv layer + out_channels : number of output channels for the pointwise conv layer + kernel : conv kernel size for depthwise layer + stride : conv stride + groups : number of groups for conv layer + avg_pool : kernel size for avgerage pooling layer + dropout : dropout layer probability + batch_norm : momemtum for batch norm + activation : activation layer + rank : rank for low-rank decomposition for conv layer weights + """ #FWD bricking brickLength = fwd_context hopLength = 3 @@ -309,7 +365,21 @@ def X_preRNN_process(X, fwd_context, bwd_context): def X_postRNN_process(X_f, oldShape_f, X_b, oldShape_b): - + """ + A depthwise seeprable low-rank convolution layer combination with pooling and activation layers + + Input + in_channels : number of input channels for the pointwise conv layer + out_channels : number of output channels for the pointwise conv layer + kernel : conv kernel size for depthwise layer + stride : conv stride + groups : number of groups for conv layer + avg_pool : kernel size for avgerage pooling layer + dropout : dropout layer probability + batch_norm : momemtum for batch norm + activation : activation layer + rank : rank for low-rank decomposition for conv layer weights + """ #Forward bricks folding X_f = torch.reshape(X_f, [oldShape_f[0], oldShape_f[1], oldShape_f[2], -1]) #X_f [batch, num_bricks, brickLen, hiddenDim] @@ -347,9 +417,21 @@ class DSCNN_RNN_Block(torch.nn.Module): def __init__(self, cnn_channels, rnn_hidden_size, rnn_num_layers, device, gate_nonlinearity="sigmoid", update_nonlinearity="tanh", isBi=True, num_labels=41, rank=None, fwd_context=15, bwd_context=9): - super(DSCNN_RNN_Block, self).__init__() + """ + A depthwise seeprable low-rank convolution layer combination with pooling and activation layers + Input + cnn_channels : number of the output channels for the first CNN block + rnn_hidden_size : hidden dimensions of the FastGRNN + rnn_num_layers : number of FastGRNN layers + device : device on which the tensors would placed + gate_nonlinearity : activation function for the gating in the FastGRNN + update_nonlinearity : activation function for the update fucntion in the FastGRNN + isBi : boolean flag to use bi-directional FastGRNN + fwd_context : window for the forward pass + bwd_context : window for the backward pass + """ self.cnn_channels = cnn_channels self.rnn_hidden_size = rnn_hidden_size self.rnn_num_layers = rnn_num_layers @@ -372,7 +454,10 @@ def __init__(self, cnn_channels, rnn_hidden_size, rnn_num_layers, def declare_network(self, cnn_channels, rnn_hidden_size, rnn_num_layers, num_labels, rank): - + """ + Declare the netwok layers + Arguments can be infered from the __init__ + """ self.CNN1 = torch.nn.Sequential( PreRNNConvBlock(80, cnn_channels, 5, 1, 1, 0, 0, batch_norm=1e-2, @@ -413,7 +498,15 @@ def declare_network(self, cnn_channels, rnn_hidden_size, rnn_num_layers, activation='tanhgate', rank=rank) def forward(self, features): + """ + Apply the set of layers initilized in 
__init__ + + Input + features: A tensor of shape (batch, channels, length) + Output + network block output in the form (batch, channels, length) + """ batch, _, max_seq_len = features.shape X = self.CNN1(features) # Down to 30ms inference / 250ms window X = torch.tanh(X) @@ -429,13 +522,9 @@ def forward(self, features): X_f, X_b = self.RNN0(X_f, X_b_f) X = X_postRNN_process(X_f, oldShape_f, X_b, oldShape_b) - # X, _ = torch.nn.utils.rnn.pad_packed_sequence(X, batch_first=True) - - # X = X.view(batch, max_seq_len, 2* self.rnn_hidden_size) # re-permute to get [batch, channels, max_seq_len/3 ] X = X.permute((0, 2, 1)) # NLC to NCL - # print("Pre CNN2 Shape : ", X.shape) X = self.CNN2(X) X = self.CNN3(X) X = self.CNN4(X) @@ -447,9 +536,21 @@ class Binary_Classification_Block(torch.nn.Module): def __init__(self, in_size, rnn_hidden_size, rnn_num_layers, device, islstm=False, isBi=True, momentum=1e-2, num_labels=2, dropout=0, batch_assertion=False): - super(Binary_Classification_Block, self).__init__() + """ + A depthwise seeprable low-rank convolution layer combination with pooling and activation layers + Input + in_size : number of input channels to the layer + rnn_hidden_size : hidden dimensions of the RNN layer + rnn_num_layers : number of layers for the RNN layer + device : device on which the tensors would placed + islstm : boolean flag to use the LSTM. Flase would use GRU + isBi : boolean flag to use the bi-directional variat of the RNN + momentum : momentum for the batch-norm layer + num_labels : number of output labels + dropout : probability for the dropout layer + """ self.in_size = in_size self.rnn_hidden_size = rnn_hidden_size self.rnn_num_layers = rnn_num_layers @@ -471,7 +572,10 @@ def __init__(self, in_size, rnn_hidden_size, rnn_num_layers, self.__name__ = 'Binary_Classification_Block_2lay' def declare_network(self, in_size, rnn_hidden_size, rnn_num_layers, num_labels): - + """ + Declare the netwok layers + Arguments can be infered from the __init__ + """ self.CNN1 = torch.nn.Sequential( torch.nn.LeakyReLU(negative_slope=0.01), torch.nn.BatchNorm1d(in_size, affine=False, @@ -498,6 +602,15 @@ def declare_network(self, in_size, rnn_hidden_size, rnn_num_layers, num_labels): num_labels)) def forward(self, features, seqlen): + """ + Apply the set of layers initilized in __init__ + + Input + features: A tensor of shape (batch, channels, length) + + Output + network block output in the form (batch, length, channels). length will be 1 + """ batch, _, _ = features.shape if self.islstm: @@ -546,7 +659,9 @@ def forward(self, features, seqlen): return X def init_hidden(self, batch, rnn_hidden_size, rnn_num_layers): - + """ + Used to initialize the first hidden state of the RNN. It is currrently zero. 
THe user is free to edit this function for addtionaly analysis + """ # the weights are of the form (batch, num_layers * num_directions , hidden_size) if self.batch_assertion: hidden = torch.zeros(rnn_num_layers * self.direction_param, batch, From 35e0159cb3e2c011f667c3e1981e3395bae09aa2 Mon Sep 17 00:00:00 2001 From: Anirudh0707 Date: Sat, 10 Jul 2021 20:41:23 +0530 Subject: [PATCH 5/8] Fix typos --- applications/KWS_Phoneme/README.md | 13 +++-- .../KWS_Phoneme/auxiliary_files/README.md | 15 ++---- applications/KWS_Phoneme/data_pipe.py | 34 ++++++------- applications/KWS_Phoneme/kwscnn.py | 50 +++++++++---------- applications/KWS_Phoneme/train_classifier.py | 20 ++++---- applications/KWS_Phoneme/train_phoneme.py | 12 ++--- 6 files changed, 71 insertions(+), 73 deletions(-) diff --git a/applications/KWS_Phoneme/README.md b/applications/KWS_Phoneme/README.md index 12d0a1643..15f41603f 100644 --- a/applications/KWS_Phoneme/README.md +++ b/applications/KWS_Phoneme/README.md @@ -3,16 +3,16 @@ # Project Description There are two major issues in the existing KWS systems (a) they are not robust to heavy background noise and random utterances, and (b) they require collecting a lot of data, hampering the ease of adding a new keyword. Tackling these issues from a different perspective, we propose a new two staged scheme with a model for predicting phonemes which are in turn used for phoneme based keyword classification. -First we train a phoneme classification model which gives the phoneme transcription of the input speech snippet. For training this phoneme classifier, we use a large public speech dataset like LibreSpeech. The public dataset can be aligned (meaning get the phoneme labels for each speech snippet in the data) using Montreal Forced Aligner. We also add reverberations and additive noise to the speech samples from the public dataset to make the phoneme classifier training robust to various accents, background noise and varied environment. In this project, we predict phonemes at every 10ms which is the standard way. You can find the aligned LibreSpeech dataset we used for training here. +First we train a phoneme classification model which gives the phoneme transcription of the input speech snippet. For training this phoneme classifier, we use a large public speech dataset like LibriSpeech. The public dataset can be aligned (meaning get the phoneme labels for each speech snippet in the data) using Montreal Forced Aligner. We also add reverberations and additive noise to the speech samples from the public dataset to make the phoneme classifier training robust to various accents, background noise and varied environment. In this project, we predict phonemes at every 10ms which is the standard way. You can find the aligned LibriSpeech dataset we used for training here. In the second part, we use the predicted phoneme outputs from the phoneme classifier for predicting the input keyword. We train a 1 layer FastGRNN classifier to predict the keyword based on the phoneme transcription as input. Since the phoneme classifier training has been done to account for diverse accent, background noise and environments, the keyword classifier can be trained using a small number of Text-To-Speech(TTS) samples generated using any standard TTS api from cloud services like Azure, Google Cloud or AWS. 
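
For reference, the 10 ms frame rate mentioned above comes from the front end in `data_pipe.py`, which computes 80-band log-mel features with a 25 ms window and 10 ms hop at 16 kHz. A minimal sketch of that front end (the per-feature normalization applied in the pipeline is omitted here):
```
# Minimal sketch of the log-mel front end used in data_pipe.py
# (80 mel bands, 25 ms window, 10 ms hop, 16 kHz audio); normalization omitted.
import librosa
import soundfile as sf

def extract_features(wav_path):
    x, fs = sf.read(wav_path)  # the pipeline expects 16 kHz mono audio
    feature = librosa.feature.melspectrogram(y=x, sr=16000, n_mels=80,
                                             win_length=25 * 16, hop_length=10 * 16,
                                             n_fft=512)
    feature = librosa.core.power_to_db(feature)
    return feature             # shape: (80, number_of_10ms_frames)
```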
-This gives two advantages: (a) The phoneme model is trained to account for diverse accents and background noise settings, thus the flexible keyword classifier training requires only a small number of keyword samples, and (b) Empirically this method was able to detect keywords from as far as 9ft of distance. Further, the phoneme model has a small size of around 250k parameters and can fit on a Cortex M7 microcontroller. +This gives two advantages: (a) The phoneme model is trained to account for diverse accents and background noise settings, thus the flexible keyword classifier training requires only a small number of keyword samples, and (b) Empirically this method was able to detect keywords from as far as 9ft of distance. Further, the phoneme model has a small size of around 250k parameters and can fit on a Cortex M7 micro-controller. # Training the Phoneme Classifier -1) Train a phoneme classification model on some public speech dataset like Librespeech +1) Train a phoneme classification model on some public speech dataset like LibriSpeech 2) Training speech dataset can be labelled using Montreal Force Aligner -3) Speech snippets are convolved with reverberation files, and additive noises from youtube or other open source are added +3) Speech snippets are convolved with reverberation files, and additive noises from YouTube or other open source are added 4) We also add white gaussian noise of various SNRs # Training the KWS Model @@ -26,7 +26,7 @@ This gives two advantages: (a) The phoneme model is trained to account for diver ## Phoneme Model Training The following command can be used to instantiate and train the phoneme model. ``` -python train_phoneme.py --base_path=/path/to/librespeech_data/ --rir_base_path=/path/to/reverb_files/ --additive_base_path=/path/to/additive_noises/ --snr_samples="0,5,10,25,100,100" --rir_chance=0.5 +python train_phoneme.py --base_path=/path/to/librispeech_data/ --rir_base_path=/path/to/reverb_files/ --additive_base_path=/path/to/additive_noises/ --snr_samples="0,5,10,25,100,100" --rir_chance=0.5 ``` Some important command line arguments: 1) base_path : Path of the speech data folder. The data in this folder should be in accordance to the dataloader code written here. @@ -46,3 +46,6 @@ Some important command line arguments: 3) phoneme_model_load_ckpt : The full path of the checkpoint file that would be used to load the weights to the instantiated phoneme model 4) rir_base_path, additive_base_path : Path to the reverb and additive noise files 5) synth : Boolean flag for specifying if reverberations and noise addition has to be done + +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT license. \ No newline at end of file diff --git a/applications/KWS_Phoneme/auxiliary_files/README.md b/applications/KWS_Phoneme/auxiliary_files/README.md index a56a617ae..64d5ba400 100644 --- a/applications/KWS_Phoneme/auxiliary_files/README.md +++ b/applications/KWS_Phoneme/auxiliary_files/README.md @@ -1,15 +1,7 @@ # Auxiliary Files to help Download and Prepare the Data -## Note -When running commands it is recommended to use the following format to run the files uninterrupted (detached) and log the output. -``` -nohup python srcipt_execution args > log.txt & -``` -Please replace script_execution with the python commands below.
-Alternately tmux or other commands can be used in place of the above format. - ## YouTube Additive Noise -Run the following commands to download the CSV Files to download the YouTube Additive Noise Data (there is no need to use nohup for the wget file) : +Run the following commands to download the CSV Files to download the YouTube Additive Noise Data : ``` wget http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/balanced_train_segments.csv @@ -26,4 +18,7 @@ The downloaded files would need to be converted to 16KHz for our pipeline. Pleas python convert_sampling_rate.py --source_folder=/path/to/csv_file.csv --target_folder=/path/to/target/16KHz_folder/ --fs=16000 --log_rate=100 ``` The script can convert the sampling rate of any wav file to the specified --fs. But for our applications, we use 16KHz only.
-Choose the log rate for how often the log should be printed for the sample rate converion. This will print a strinng ever log_rate iterations. \ No newline at end of file +Choose the log rate for how often the log should be printed for the sample rate conversion. This will print a string ever log_rate iterations. + +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT license. \ No newline at end of file diff --git a/applications/KWS_Phoneme/data_pipe.py b/applications/KWS_Phoneme/data_pipe.py index eec669934..5614538ab 100644 --- a/applications/KWS_Phoneme/data_pipe.py +++ b/applications/KWS_Phoneme/data_pipe.py @@ -20,7 +20,7 @@ def synthesize_wave(sigx, snr, wgn_snr, gain, do_rir, args): """ Synth Block - Used to process the input audio. - The input is convolved with room reverbration recording + The input is convolved with room reverberation recording Adds noise in the form of white gaussian noise and regular audio clips (eg:piano, people talking, car engine etc) Input @@ -95,7 +95,7 @@ def get_add_noise(args): args: args object (contains info about model and training) Output - add_sample: addtive noise audio + add_sample: additive noise audio """ additive_base_path = args.additive_base_path add_fname = random.choice(os.listdir(additive_base_path)) @@ -124,7 +124,7 @@ def get_ASR_datasets(args): for path in train_textgrid_paths] if args.pre_phone_list: - # If there is a list of phons in the dataset, use this flag + # If there is a list of phonemes in the dataset, use this flag Sy_phoneme = [] with open(args.phoneme_text_file, "r") as f: for line in f.readlines(): @@ -139,7 +139,7 @@ def get_ASR_datasets(args): print(len(Sy_phoneme), flush=True) print("**********************", flush=True) else: - # No list of phons specified. Count from the input dataset + # No list of phonemes specified. Count from the input dataset phoneme_counter = Counter() for path in train_textgrid_paths: tg = textgrid.TextGrid() @@ -148,7 +148,7 @@ def get_ASR_datasets(args): for phone in tg.getList("phones")[0] if phone.mark not in ['', 'sp', 'spn']]) - # Display and store the phons extracted + # Display and store the phonemes extracted Sy_phoneme = list(phoneme_counter) args.num_phonemes = len(Sy_phoneme) print("**************", flush=True) @@ -201,7 +201,7 @@ def __len__(self): def __getitem__(self, idx): """ Gives one sample from the dataset. Data is read in this snippet. - (refer to the coalate function for pre-procesing) + (refer to the collate function for pre-processing) Input: idx: index for the sample @@ -365,7 +365,7 @@ def __init__(self, wav_paths, labels, args, is_train=True): Input wav_paths : list of strings (wav file paths) - labels : list of classification labels for the corresponding audio wav filess + labels : list of classification labels for the corresponding audio wav files is_train : boolean flag, if the dataset loader is for the train or test pipeline args : args object (contains info about model and training) """ @@ -399,16 +399,16 @@ def one_hot_encoder(self, lab): def __getitem__(self, idx): """ - Gives one sample from the dataset. Data is read in this snippet. (refer to the coalate function for pre-procesing) + Gives one sample from the dataset. Data is read in this snippet. 
(refer to the collate function for pre-processing) Input: idx: index for the sample Output: x : audio sample obtained from the synth block (if used, else input audio) after time-domain clipping - one_hot_label : one hot encoded label + one_hot_label : one-hot encoded label seqlen : length of the audio file. - This value will be dropped and seqlen after feature extraction will be used. Refer to the coallate func + This value will be dropped and seqlen after feature extraction will be used. Refer to the collate func """ x, fs = sf.read(self.wav_paths[idx]) @@ -455,7 +455,7 @@ def __call__(self, batch): x_pad_length = (T - len(x[index])) x[index] = np.pad(x[index], (x_pad_length,0), 'constant', constant_values=(0, 0)) - # Extract MFCC from padded audio + # Extract Mel-Spectogram from padded audio feature = librosa.feature.melspectrogram(y=x[index],sr=16000,n_mels=80,win_length=25*16,hop_length=10*16, n_fft=512) feature = librosa.core.power_to_db(feature) # Normalize the features @@ -487,12 +487,12 @@ def __call__(self, batch): if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--base_path', type=str, default="/mnt/kws_data/data/") - parser.add_argument('--train_data_folders', type=str, default="google30_train") - parser.add_argument('--test_data_folders', type=str, default="google30_test") - parser.add_argument('--rir_base_path', type=str, default="/mnt/kws_data/data/noises_sachin/iir/") - parser.add_argument('--additive_base_path', type=str, default="/mnt/kws_data/data/noises_sachin/additive/") - parser.add_argument('--phoneme_text_file', type=str, default="/mnt/kws_data/data/LibriSpeech/text/phonemes.txt") + parser.add_argument('--base_path', type=str, required=True, help="Path of the speech data folder. The data in this folder should be in accordance to the dataloader code written here.") + parser.add_argument('--train_data_folders', type=str, default="google30_train", help="List of training folders in base path. Each folder is a dataset in the prescribed format") + parser.add_argument('--test_data_folders', type=str, default="google30_test", help="List of testing folders in base path. Each folder is a dataset in the prescribed format") + parser.add_argument('--rir_base_path', type=str, required=True, help="Folder with the reverbration files") + parser.add_argument('--additive_base_path', type=str, required=True, help="Folder with additive noise files") + parser.add_argument('--phoneme_text_file', type=str, required=True, help="Text files with pre-fixed phons") parser.add_argument('--workers', type=int, default=-1, help="Number of workers. Give -1 for all workers") parser.add_argument("--word_model_name", default='google30', help="Name of the word list chosen. 
Will be used in conjunction with the data loader") parser.add_argument('--words', type=str, default="all") diff --git a/applications/KWS_Phoneme/kwscnn.py b/applications/KWS_Phoneme/kwscnn.py index d966cf6d0..bb79f5aac 100644 --- a/applications/KWS_Phoneme/kwscnn.py +++ b/applications/KWS_Phoneme/kwscnn.py @@ -49,7 +49,7 @@ def __init__(self): def forward(self, value): """ Applies a custom activation function - The first half of the channels are passed through sigmoid layer and the next hlaf though a tanh + The first half of the channels are passed through sigmoid layer and the next half though a tanh The outputs are multiplied and returned Input @@ -81,7 +81,7 @@ def __init__(self, in_channels, out_channels, kernel_size, A convolution layer with the weight matrix subjected to a low-rank decomposition. Currently for kernel size of 5 Input - rank : The rank used for the low-rank decomposition on the weight/kernl tensor + rank : The rank used for the low-rank decomposition on the weight/kernel tensor All other parameters are similar to that of a convolution layer Only change is the decomposition of the output channels into low-rank tensors """ @@ -130,9 +130,9 @@ def __init__( kernel : conv kernel size stride : conv stride groups : number of groups for conv layer - avg_pool : kernel size for avgerage pooling layer + avg_pool : kernel size for average pooling layer dropout : dropout layer probability - batch_norm : momemtum for batch norm + batch_norm : momentum for batch norm activation : activation layer rank : rank for low-rank decomposition for conv layer weights """ @@ -184,7 +184,7 @@ def __init__( def forward(self, x): """ - Apply the set of layers initilized in __init__ + Apply the set of layers initialized in __init__ Input x: A tensor of shape (batch, channels, length) @@ -203,7 +203,7 @@ def __init__( activation='sigmoid', rank=50): super(DSCNNBlockLR, self).__init__() """ - A depthwise seeprable low-rank convolution layer combination with pooling and activation layers + A depthwise separable low-rank convolution layer combination with pooling and activation layers Input in_channels : number of input channels for the pointwise conv layer @@ -211,9 +211,9 @@ def __init__( kernel : conv kernel size for depthwise layer stride : conv stride groups : number of groups for conv layer - avg_pool : kernel size for avgerage pooling layer + avg_pool : kernel size for average pooling layer dropout : dropout layer probability - batch_norm : momemtum for batch norm + batch_norm : momentum for batch norm activation : activation layer rank : rank for low-rank decomposition for conv layer weights """ @@ -265,7 +265,7 @@ def __init__( def forward(self, x): """ - Apply the set of layers initilized in __init__ + Apply the set of layers initialized in __init__ Input x: A tensor of shape (batch, channels, length) @@ -326,7 +326,7 @@ def forward(self, input_f, input_b): def X_preRNN_process(X, fwd_context, bwd_context): """ - A depthwise seeprable low-rank convolution layer combination with pooling and activation layers + A depthwise separable low-rank convolution layer combination with pooling and activation layers Input in_channels : number of input channels for the pointwise conv layer @@ -334,9 +334,9 @@ def X_preRNN_process(X, fwd_context, bwd_context): kernel : conv kernel size for depthwise layer stride : conv stride groups : number of groups for conv layer - avg_pool : kernel size for avgerage pooling layer + avg_pool : kernel size for average pooling layer dropout : dropout layer probability - 
batch_norm : momemtum for batch norm + batch_norm : momentum for batch norm activation : activation layer rank : rank for low-rank decomposition for conv layer weights """ @@ -366,7 +366,7 @@ def X_preRNN_process(X, fwd_context, bwd_context): def X_postRNN_process(X_f, oldShape_f, X_b, oldShape_b): """ - A depthwise seeprable low-rank convolution layer combination with pooling and activation layers + A depthwise separable low-rank convolution layer combination with pooling and activation layers Input in_channels : number of input channels for the pointwise conv layer @@ -374,9 +374,9 @@ def X_postRNN_process(X_f, oldShape_f, X_b, oldShape_b): kernel : conv kernel size for depthwise layer stride : conv stride groups : number of groups for conv layer - avg_pool : kernel size for avgerage pooling layer + avg_pool : kernel size for average pooling layer dropout : dropout layer probability - batch_norm : momemtum for batch norm + batch_norm : momentum for batch norm activation : activation layer rank : rank for low-rank decomposition for conv layer weights """ @@ -419,7 +419,7 @@ def __init__(self, cnn_channels, rnn_hidden_size, rnn_num_layers, isBi=True, num_labels=41, rank=None, fwd_context=15, bwd_context=9): super(DSCNN_RNN_Block, self).__init__() """ - A depthwise seeprable low-rank convolution layer combination with pooling and activation layers + A depthwise separable low-rank convolution layer combination with pooling and activation layers Input cnn_channels : number of the output channels for the first CNN block @@ -427,7 +427,7 @@ def __init__(self, cnn_channels, rnn_hidden_size, rnn_num_layers, rnn_num_layers : number of FastGRNN layers device : device on which the tensors would placed gate_nonlinearity : activation function for the gating in the FastGRNN - update_nonlinearity : activation function for the update fucntion in the FastGRNN + update_nonlinearity : activation function for the update function in the FastGRNN isBi : boolean flag to use bi-directional FastGRNN fwd_context : window for the forward pass bwd_context : window for the backward pass @@ -456,7 +456,7 @@ def declare_network(self, cnn_channels, rnn_hidden_size, rnn_num_layers, num_labels, rank): """ Declare the netwok layers - Arguments can be infered from the __init__ + Arguments can be inferred from the __init__ """ self.CNN1 = torch.nn.Sequential( PreRNNConvBlock(80, cnn_channels, 5, 1, 1, @@ -499,7 +499,7 @@ def declare_network(self, cnn_channels, rnn_hidden_size, rnn_num_layers, def forward(self, features): """ - Apply the set of layers initilized in __init__ + Apply the set of layers initialized in __init__ Input features: A tensor of shape (batch, channels, length) @@ -538,15 +538,15 @@ def __init__(self, in_size, rnn_hidden_size, rnn_num_layers, num_labels=2, dropout=0, batch_assertion=False): super(Binary_Classification_Block, self).__init__() """ - A depthwise seeprable low-rank convolution layer combination with pooling and activation layers + A depthwise separable low-rank convolution layer combination with pooling and activation layers Input in_size : number of input channels to the layer rnn_hidden_size : hidden dimensions of the RNN layer rnn_num_layers : number of layers for the RNN layer device : device on which the tensors would placed - islstm : boolean flag to use the LSTM. Flase would use GRU - isBi : boolean flag to use the bi-directional variat of the RNN + islstm : boolean flag to use the LSTM. 
False would use GRU + isBi : boolean flag to use the bi-directional variant of the RNN momentum : momentum for the batch-norm layer num_labels : number of output labels dropout : probability for the dropout layer @@ -574,7 +574,7 @@ def __init__(self, in_size, rnn_hidden_size, rnn_num_layers, def declare_network(self, in_size, rnn_hidden_size, rnn_num_layers, num_labels): """ Declare the netwok layers - Arguments can be infered from the __init__ + Arguments can be inferred from the __init__ """ self.CNN1 = torch.nn.Sequential( torch.nn.LeakyReLU(negative_slope=0.01), @@ -603,7 +603,7 @@ def declare_network(self, in_size, rnn_hidden_size, rnn_num_layers, num_labels): def forward(self, features, seqlen): """ - Apply the set of layers initilized in __init__ + Apply the set of layers initialized in __init__ Input features: A tensor of shape (batch, channels, length) @@ -660,7 +660,7 @@ def forward(self, features, seqlen): def init_hidden(self, batch, rnn_hidden_size, rnn_num_layers): """ - Used to initialize the first hidden state of the RNN. It is currrently zero. THe user is free to edit this function for addtionaly analysis + Used to initialize the first hidden state of the RNN. It is currently zero. THe user is free to edit this function for additional analysis """ # the weights are of the form (batch, num_layers * num_directions , hidden_size) if self.batch_assertion: diff --git a/applications/KWS_Phoneme/train_classifier.py b/applications/KWS_Phoneme/train_classifier.py index 1a9dbf2fe..bc6901bad 100644 --- a/applications/KWS_Phoneme/train_classifier.py +++ b/applications/KWS_Phoneme/train_classifier.py @@ -26,21 +26,21 @@ def parseArgs(): parser.add_argument('--epochs', type=int, default=200, help="Number of epochs for training") parser.add_argument('--save_tick', type=int, default=1, help="Number of epochs to wait to save") parser.add_argument('--workers', type=int, default=-1, help="Number of workers. Give -1 for all workers") - parser.add_argument("--gpu", type=str, default='0', help="GPU indicies Eg: --gpu=0,1,2,3 for 4 gpus. -1 for CPU") + parser.add_argument("--gpu", type=str, default='0', help="GPU indices Eg: --gpu=0,1,2,3 for 4 gpus. -1 for CPU") parser.add_argument("--word_model_name", default='google30', help="Name of the list of words used") parser.add_argument('--words', type=str, help="List of words to be used. This will be assigned in the code. User input will not affect the result") parser.add_argument("--is_training", action='store_true', help="True for training") parser.add_argument("--synth", action='store_true', help="Use Synth block or not") # Args for DataLoader parser.add_argument('--base_path', type=str, required=True, help="path to train and test data folders") - parser.add_argument('--train_data_folders', type=str, default="google30_train", help="List of training folders in base path. Each folder is a dataset in the precribed format") - parser.add_argument('--test_data_folders', type=str, default="google30_test", help="List of testing folders in base path. Each folder is a dataset in the precribed format") + parser.add_argument('--train_data_folders', type=str, default="google30_train", help="List of training folders in base path. Each folder is a dataset in the prescribed format") + parser.add_argument('--test_data_folders', type=str, default="google30_test", help="List of testing folders in base path. 
Each folder is a dataset in the prescribed format") parser.add_argument('--rir_base_path', type=str, required=True, help="Folder with the reverbration files") parser.add_argument('--additive_base_path', type=str, required=True, help="Folder with additive noise files") parser.add_argument('--phoneme_text_file', type=str, help="Text files with pre-fixed phons") parser.add_argument('--pretraining_length_mean', type=int, default=6, help="Mean of the audio clips lengths") parser.add_argument('--pretraining_length_var', type=int, default=1, help="variance of the audio clip lengths") - parser.add_argument('--pretraining_batch_size', type=int, default=256, help="Batch size for the pirpeline") + parser.add_argument('--pretraining_batch_size', type=int, default=256, help="Batch size for the pipeline") parser.add_argument('--snr_samples', type=str, default="-5,0,0,5,10,15,40,100,100", help="SNR values for additive noise files") parser.add_argument('--wgn_snr_samples', type=str, default="5,10,20,40,60", help="SNR values for white gaussian noise") parser.add_argument('--gain_samples', type=str, default="1.0,0.25,0.5,0.75", help="Gain values for processed signal") @@ -57,7 +57,7 @@ def parseArgs(): parser.add_argument('--phoneme_phoneme_isBi', action='store_true', help="Use Bi-Directional RNN") parser.add_argument('--phoneme_num_labels', type=int, default=41, help="Number og phoneme labels") # Args for Classifier - parser.add_argument('--classifier_rnn_hidden_size', type=int, default=100, help="Calssifier RNN hidden dimensions") + parser.add_argument('--classifier_rnn_hidden_size', type=int, default=100, help="Classifier RNN hidden dimensions") parser.add_argument('--classifier_rnn_num_layers', type=int, default=1, help="Classifier RNN number of layers") parser.add_argument('--classifier_dropout', type=float, default=0.2, help="Classifier dropout layer probability") parser.add_argument('--classifier_islstm', action='store_true', help="Use LSTM in the classifier") @@ -145,7 +145,7 @@ def train_classifier_model(args): if args.optim == "sgd": model['opt'] = torch.optim.SGD(model['classifier'].parameters(), lr=args.lr) - # Load the sepcified phoneme checkpoint. 'phoneme_model_load_ckpt' must point to a chcekpoint and not folder + # Load the specified phoneme checkpoint. 'phoneme_model_load_ckpt' must point to a checkpoint and not folder if args.phoneme_model_load_ckpt is not None: if os.path.exists(args.phoneme_model_load_ckpt): # Load Checkpoint @@ -159,14 +159,14 @@ def train_classifier_model(args): else: print("No Phoneme Checkpoint Given", flush=True) - # Load the sepcified classifier checkpoint. 'classifier_model_load_ckpt' must point to a chcekpoint and not folder + # Load the specified classifier checkpoint. 
'classifier_model_load_ckpt' must point to a checkpoint and not folder if args.classifier_model_load_ckpt is not None: if os.path.exists(args.classifier_model_load_ckpt): # Get the number from the classifier checkpoint path - start_epoch = args.classifier_model_load_ckpt # Temporarlity store the full ckpt path + start_epoch = args.classifier_model_load_ckpt # Temporarily store the full ckpt path start_epoch = start_epoch.split('/')[-1] # retain only the *.pt from the path (Linux) start_epoch = start_epoch.split('\\')[-1] # retain only the *.pt from the path (Windows) - start_epoch = int(start_epoch.split('.')[0]) # reatin the integers + start_epoch = int(start_epoch.split('.')[0]) # retain the integers # Load Checkpoint latest_classifier_ckpt = torch.load(args.classifier_model_load_ckpt, map_location=device) # Load specific state_dicts() and print the latest stats @@ -186,7 +186,7 @@ def train_classifier_model(args): output_frame_rate = 3 save_path = args.classifier_model_save_folder os.makedirs(args.classifier_model_save_folder, exist_ok=True) - # Print for cros-checking + # Print for cross-checking print(f"Pre Phone List {args.pre_phone_list}", flush=True) print(f"Start Epoch : {start_epoch}", flush=True) print(f"Device : {device}", flush=True) diff --git a/applications/KWS_Phoneme/train_phoneme.py b/applications/KWS_Phoneme/train_phoneme.py index 7c2ee2649..390c9e07e 100644 --- a/applications/KWS_Phoneme/train_phoneme.py +++ b/applications/KWS_Phoneme/train_phoneme.py @@ -25,7 +25,7 @@ def parseArgs(): parser.add_argument('--epochs', type=int, default=200, help="Number of epochs for training") parser.add_argument('--save_tick', type=int, default=1, help="Number of epochs to wait to save") parser.add_argument('--workers', type=int, default=-1, help="Number of workers. Give -1 for all workers") - parser.add_argument("--gpu", type=str, default='0', help="GPU indicies Eg: --gpu=0,1,2,3 for 4 gpus. -1 for CPU") + parser.add_argument("--gpu", type=str, default='0', help="GPU indices Eg: --gpu=0,1,2,3 for 4 gpus. -1 for CPU") # Args for DataLoader parser.add_argument('--base_path', type=str, required=True, help="Path of the speech data folder. 
The data in this folder should be in accordance to the dataloader code written here.") @@ -34,12 +34,12 @@ def parseArgs(): parser.add_argument('--phoneme_text_file', type=str, required=True, help="Text files with pre-fixed phons") parser.add_argument('--pretraining_length_mean', type=int, default=6, help="Mean of the audio clips lengths") parser.add_argument('--pretraining_length_var', type=int, default=1, help="variance of the audio clip lengths") - parser.add_argument('--pretraining_batch_size', type=int, default=256, help="Batch size for the pirpeline") + parser.add_argument('--pretraining_batch_size', type=int, default=256, help="Batch size for the pipeline") parser.add_argument('--snr_samples', type=str, default="0,5,10,25,100,100", help="SNR values for additive noise files") parser.add_argument('--wgn_snr_samples', type=str, default="5,10,15,100,100", help="SNR values for white gaussian noise") parser.add_argument('--gain_samples', type=str, default="1.0,0.25,0.5,0.75", help="Gain values for the processed signal") parser.add_argument('--rir_chance', type=float, default=0.25, help="Probability of performing reverbration") - parser.add_argument('--synth_chance', type=float, default=0.5, help="Probablity of pre-processing the signal with noise and reverb") + parser.add_argument('--synth_chance', type=float, default=0.5, help="Probability of pre-processing the signal with noise and reverb") parser.add_argument('--pre_phone_list', action='store_true', help="Use pre-fixed list of phonemes") # Args for Phoneme parser.add_argument('--phoneme_cnn_channels', type=int, default=400, help="Number od channels for the CNN layers") @@ -102,14 +102,14 @@ def train_phoneme_model(args): if args.optim == "sgd": model['opt'] = torch.optim.SGD(model['phoneme'].parameters(), lr=args.lr) - # Load the sepcified checkpoint. 'phoneme_model_load_ckpt' must point to a chcekpoint and not folder + # Load the specified checkpoint. 
'phoneme_model_load_ckpt' must point to a checkpoint and not folder
    if args.phoneme_model_load_ckpt is not None:
        if os.path.exists(args.phoneme_model_load_ckpt):
            # Get the number from the phoneme checkpoint path
-            start_epoch = args.phoneme_model_load_ckpt # Temporarlity store the full ckpt path
+            start_epoch = args.phoneme_model_load_ckpt # Temporarily store the full ckpt path
            start_epoch = start_epoch.split('/')[-1] # retain only the *.pt from the path (Linux)
            start_epoch = start_epoch.split('\\')[-1] # retain only the *.pt from the path (Windows)
-            start_epoch = int(start_epoch.split('.')[0]) # reatin the integers
+            start_epoch = int(start_epoch.split('.')[0]) # retain the integers
            # Load Checkpoint
            latest_ckpt = torch.load(args.phoneme_model_load_ckpt, map_location=device)
            # Load specific state_dicts() and print the latest stats

From 11b718eaadd0f39ee9926da892ccea61d539cc25 Mon Sep 17 00:00:00 2001
From: Anirudh0707
Date: Tue, 20 Jul 2021 13:45:32 +0530
Subject: [PATCH 6/8] Fix typos and punctuation

---
 applications/KWS_Phoneme/README.md            | 40 +++++++++----------
 .../KWS_Phoneme/auxiliary_files/README.md     |  4 +-
 applications/KWS_Phoneme/kwscnn.py            |  1 -
 applications/KWS_Phoneme/train_classifier.py  |  2 +-
 applications/KWS_Phoneme/train_phoneme.py     |  2 +-
 5 files changed, 24 insertions(+), 25 deletions(-)

diff --git a/applications/KWS_Phoneme/README.md b/applications/KWS_Phoneme/README.md
index 15f41603f..ab3ba986a 100644
--- a/applications/KWS_Phoneme/README.md
+++ b/applications/KWS_Phoneme/README.md
@@ -1,25 +1,25 @@
-# Phoneme based Keyword Spotting(KWS)
+# Phoneme-based Keyword Spotting (KWS)
 
 # Project Description
-There are two major issues in the existing KWS systems (a) they are not robust to heavy background noise and random utterances, and (b) they require collecting a lot of data, hampering the ease of adding a new keyword. Tackling these issues from a different perspective, we propose a new two staged scheme with a model for predicting phonemes which are in turn used for phoneme based keyword classification.
+There are two major issues in the existing KWS systems: (a) they are not robust to heavy background noise and random utterances, and (b) they require collecting a lot of data, hampering the ease of adding a new keyword. Tackling these issues from a different perspective, we propose a new two-stage scheme with a model for predicting phonemes, which are in turn used for phoneme-based keyword classification.
 
-First we train a phoneme classification model which gives the phoneme transcription of the input speech snippet. For training this phoneme classifier, we use a large public speech dataset like LibriSpeech. The public dataset can be aligned (meaning get the phoneme labels for each speech snippet in the data) using Montreal Forced Aligner. We also add reverberations and additive noise to the speech samples from the public dataset to make the phoneme classifier training robust to various accents, background noise and varied environment. In this project, we predict phonemes at every 10ms which is the standard way. You can find the aligned LibriSpeech dataset we used for training here.
+First, we train a phoneme classification model which gives the phoneme transcription of the input speech snippet. For training this phoneme classifier, we use a large public speech dataset like LibriSpeech. The public dataset can be aligned (meaning we can get the phoneme labels for each speech snippet in the data) using Montreal Forced Aligner. We also add reverberations and additive noise to the speech samples from the public dataset to make the phoneme classifier training robust to various accents, background noise and varied environments. In this project, we predict phonemes at every 10ms, which is the standard way. You can find the aligned LibriSpeech dataset we used for training here.
 
 In the second part, we use the predicted phoneme outputs from the phoneme classifier for predicting the input keyword. We train a 1 layer FastGRNN classifier to predict the keyword based on the phoneme transcription as input. Since the phoneme classifier training has been done to account for diverse accents, background noise and environments, the keyword classifier can be trained using a small number of Text-To-Speech(TTS) samples generated using any standard TTS API from cloud services like Azure, Google Cloud or AWS.
 
 This gives two advantages: (a) The phoneme model is trained to account for diverse accents and background noise settings, thus the flexible keyword classifier training requires only a small number of keyword samples, and (b) Empirically this method was able to detect keywords from as far as 9ft of distance. Further, the phoneme model has a small size of around 250k parameters and can fit on a Cortex M7 micro-controller.
 
 # Training the Phoneme Classifier
-1) Train a phoneme classification model on some public speech dataset like LibriSpeech
-2) Training speech dataset can be labelled using Montreal Force Aligner
-3) Speech snippets are convolved with reverberation files, and additive noises from YouTube or other open source are added
-4) We also add white gaussian noise of various SNRs
+1) Train a phoneme classification model on some public speech dataset like LibriSpeech.
+2) The training speech dataset can be labelled using Montreal Forced Aligner.
+3) Speech snippets are convolved with reverberation files, and additive noises from YouTube or other open sources are added.
+4) We also add white gaussian noise of various SNRs.
 
 # Training the KWS Model
-1) Our method takes as input the speech snippet and passes it through the phoneme classifier
-2) Keywords are detected by training a keyword classifier over the detected phonemes
-3) For training the keyword classifier, we use Azure and Google Text-To-Speech API to get the training data (keyword snippets)
+1) Our method takes as input the speech snippet and passes it through the phoneme classifier.
+2) Keywords are detected by training a keyword classifier over the detected phonemes.
+3) For training the keyword classifier, we use Azure and Google Text-To-Speech API to get the training data (keyword snippets). +4) For example, if you want to train a keyword classifier for the keywords in the Google30 dataset, generate TTS samples from the Azure/Google-Cloud/AWS API for each of the 30 keywords. The TTS samples for each keyword must be stored in a separate folder named according to the keyword. More details about how the generated TTS data should be stored are mentioned below in sample use case for classifier model training. # Sample Use Cases @@ -29,10 +29,10 @@ The following command can be used to instantiate and train the phoneme model. python train_phoneme.py --base_path=/path/to/librispeech_data/ --rir_base_path=/path/to/reverb_files/ --additive_base_path=/path/to/additive_noises/ --snr_samples="0,5,10,25,100,100" --rir_chance=0.5 ``` Some important command line arguments: -1) base_path : Path of the speech data folder. The data in this folder should be in accordance to the dataloader code written here. -2) rir_base_path, additive_base_path : Path to the reverb and additive noise files +1) base_path : Path of the speech data folder. The data in this folder should be in accordance to the data-loader code written here. +2) rir_base_path, additive_base_path : Path to the reverb and additive noise files. 3) snr_samples : List of various SNRs at which the additive noise is to be added. -4) rir_chance : Probability at which reverberation has to be done for each speech sample +4) rir_chance : Probability that would decide if the reverberation operation has to be performed for a given speech sample. ## Classifier Model Training The following command can be used to instantiate and train the classifier model. @@ -41,11 +41,11 @@ python train_classifier.py --base_path=/path/to/train_and_test_data_folders/ --t ``` Some important command line arguments: -1) base_path : path to train and test data folders -2) train_data_folders, test_data_folders : These folders should have the .wav files for each keyword in a separate subfolder inside according to the dataloader here -3) phoneme_model_load_ckpt : The full path of the checkpoint file that would be used to load the weights to the instantiated phoneme model -4) rir_base_path, additive_base_path : Path to the reverb and additive noise files -5) synth : Boolean flag for specifying if reverberations and noise addition has to be done +1) base_path : Path to train and test data folders. +2) train_data_folders, test_data_folders : These folders should have the .wav files for each keyword in a separate subfolder inside according to the data-loader here. +3) phoneme_model_load_ckpt : The full path of the checkpoint file that would be used to load the weights to the instantiated phoneme model. +4) rir_base_path, additive_base_path : Path to the reverb and additive noise files. +5) synth : Boolean flag for specifying if reverberations and noise addition has to be done. Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT license. \ No newline at end of file diff --git a/applications/KWS_Phoneme/auxiliary_files/README.md b/applications/KWS_Phoneme/auxiliary_files/README.md index 64d5ba400..ed6ea9265 100644 --- a/applications/KWS_Phoneme/auxiliary_files/README.md +++ b/applications/KWS_Phoneme/auxiliary_files/README.md @@ -17,8 +17,8 @@ The downloaded files would need to be converted to 16KHz for our pipeline. 
Pleas ``` python convert_sampling_rate.py --source_folder=/path/to/csv_file.csv --target_folder=/path/to/target/16KHz_folder/ --fs=16000 --log_rate=100 ``` -The script can convert the sampling rate of any wav file to the specified --fs. But for our applications, we use 16KHz only.
-Choose the log rate for how often the log should be printed for the sample rate conversion. This will print a string ever log_rate iterations. +The script can convert the sampling rate of any .wav file to the specified --fs. But for our applications, we use 16KHz only.
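For a single file, the conversion amounts to something like the sketch below (assuming `librosa` and `soundfile` are available; the helper name is illustrative only, and convert_sampling_rate.py remains the reference implementation):
```
# Illustrative sketch only: resample one .wav file to 16 kHz.
# convert_sampling_rate.py applies the same idea to every file in --source_folder.
import librosa
import soundfile as sf

def resample_wav_to_16k(src_path, dst_path, fs=16000):
    # librosa.load decodes the file and resamples it to the requested rate.
    audio, _ = librosa.load(src_path, sr=fs)
    # Write the resampled signal back out at the new sampling rate.
    sf.write(dst_path, audio, fs)
```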
+Choose the log rate for how often the log should be printed for the sample rate conversion. This will print a string every log_rate iterations. Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT license. \ No newline at end of file diff --git a/applications/KWS_Phoneme/kwscnn.py b/applications/KWS_Phoneme/kwscnn.py index bb79f5aac..ab7e4c558 100644 --- a/applications/KWS_Phoneme/kwscnn.py +++ b/applications/KWS_Phoneme/kwscnn.py @@ -675,4 +675,3 @@ def init_hidden(self, batch, rnn_hidden_size, rnn_num_layers): hidden = Variable(hidden) return hidden - diff --git a/applications/KWS_Phoneme/train_classifier.py b/applications/KWS_Phoneme/train_classifier.py index bc6901bad..1bfaa37f9 100644 --- a/applications/KWS_Phoneme/train_classifier.py +++ b/applications/KWS_Phoneme/train_classifier.py @@ -293,4 +293,4 @@ def train_classifier_model(args): if __name__ == '__main__': args = parseArgs() - train_classifier_model(args) \ No newline at end of file + train_classifier_model(args) diff --git a/applications/KWS_Phoneme/train_phoneme.py b/applications/KWS_Phoneme/train_phoneme.py index 390c9e07e..69857f012 100644 --- a/applications/KWS_Phoneme/train_phoneme.py +++ b/applications/KWS_Phoneme/train_phoneme.py @@ -209,4 +209,4 @@ def train_phoneme_model(args): if __name__ == '__main__': args = parseArgs() - train_phoneme_model(args) \ No newline at end of file + train_phoneme_model(args) From ecd1d0964fd5a8f838e3672df02087290ee6b3b3 Mon Sep 17 00:00:00 2001 From: Anirudh0707 Date: Tue, 20 Jul 2021 21:34:54 +0530 Subject: [PATCH 7/8] Minor modifications to comments and punctuation --- .../KWS_Phoneme/auxiliary_files/README.md | 8 +- applications/KWS_Phoneme/data_pipe.py | 190 ++++++----- applications/KWS_Phoneme/kwscnn.py | 312 +++++++++--------- applications/KWS_Phoneme/train_classifier.py | 98 +++--- applications/KWS_Phoneme/train_phoneme.py | 77 +++-- 5 files changed, 337 insertions(+), 348 deletions(-) diff --git a/applications/KWS_Phoneme/auxiliary_files/README.md b/applications/KWS_Phoneme/auxiliary_files/README.md index ed6ea9265..6521fed4a 100644 --- a/applications/KWS_Phoneme/auxiliary_files/README.md +++ b/applications/KWS_Phoneme/auxiliary_files/README.md @@ -1,12 +1,11 @@ -# Auxiliary Files to help Download and Prepare the Data +# Python scripts to help download and down-sample the additive noise data from YouTube videos -## YouTube Additive Noise Run the following commands to download the CSV Files to download the YouTube Additive Noise Data : ``` wget http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/balanced_train_segments.csv ``` -Followed by the extraction script to download the actual data : +Following the download of the CSV file, run the extraction script to download the actual audio data : ``` python download_youtube_data.py --csv_file=/path/to/csv_file.csv --target_folder=/path/to/target/folder/ ``` @@ -17,8 +16,7 @@ The downloaded files would need to be converted to 16KHz for our pipeline. Pleas ``` python convert_sampling_rate.py --source_folder=/path/to/csv_file.csv --target_folder=/path/to/target/16KHz_folder/ --fs=16000 --log_rate=100 ``` -The script can convert the sampling rate of any .wav file to the specified --fs. But for our applications, we use 16KHz only.
-Choose the log rate for how often the log should be printed for the sample rate conversion. This will print a string every log_rate iterations. +The script can convert the sampling rate of any .wav file to the specified --fs. But for our applications, we use 16KHz only. Choose the log rate for how often the log should be printed for the sample rate conversion. This will print a string every log_rate iterations. Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT license. \ No newline at end of file diff --git a/applications/KWS_Phoneme/data_pipe.py b/applications/KWS_Phoneme/data_pipe.py index 5614538ab..9439c35de 100644 --- a/applications/KWS_Phoneme/data_pipe.py +++ b/applications/KWS_Phoneme/data_pipe.py @@ -20,25 +20,25 @@ def synthesize_wave(sigx, snr, wgn_snr, gain, do_rir, args): """ Synth Block - Used to process the input audio. - The input is convolved with room reverberation recording - Adds noise in the form of white gaussian noise and regular audio clips (eg:piano, people talking, car engine etc) + The input is convolved with room reverberation recording. + Adds noise in the form of white gaussian noise and regular audio clips (eg:piano, people talking, car engine etc). - Input - sigx : input signal to the block - snr : signal-to-noise ratio of the input and additive noise (regular audio) - wg_snr : signal-to-noise ratio of the input and additive noise (white gaussian noise) - gain : gain of the output signal - do_rir : boolean flag, if reverbration needs to be incorporated - args : args object (contains info about model and training) + Input: + sigx : input signal to the block. + snr : signal-to-noise ratio of the input and additive noise (regular audio). + wg_snr : signal-to-noise ratio of the input and additive noise (white gaussian noise). + gain : gain of the output signal. + do_rir : boolean flag, if reverbration needs to be incorporated. + args : args object (contains info about model and training). - Output - clipped version of the audio post-processing + Output: + clipped version of the audio post-processing. """ beta = np.random.choice([0.1, 0.25, 0.5, 0.75, 1]) sigx = beta * sigx x_power = np.sum(sigx * sigx) - # Do RIR and normalize back to original power + # Do RIR and normalize back to original power. if do_rir: rir_base_path = args.rir_base_path rir_fname = random.choice(os.listdir(rir_base_path)) @@ -46,7 +46,7 @@ def synthesize_wave(sigx, snr, wgn_snr, gain, do_rir, args): rir_sample, fs = sf.read(rir_full_fname) if rir_sample.ndim > 1: rir_sample = rir_sample[:,0] - # We cut the tail of the RIR signal at 99% energy + # We cut the tail of the RIR signal at 99% energy. cum_en = np.cumsum(np.power(rir_sample, 2)) cum_en = cum_en / cum_en[-1] rir_sample = rir_sample[cum_en <= 0.99] @@ -56,46 +56,42 @@ def synthesize_wave(sigx, snr, wgn_snr, gain, do_rir, args): sigy = sigy[0:len(sigx)] y_power = np.sum(sigy * sigy) - sigy *= math.sqrt(x_power / y_power) # normalize so y has same total power - + sigy *= math.sqrt(x_power / y_power) # normalize so y has same total power. else: sigy = sigx - y_rmse = math.sqrt(x_power / len(sigy)) - # Only bother with noise addition if the SNR is low enough + # Only bother with noise addition if the SNR is low enough. 
if snr < 50: add_sample = get_add_noise(args) noise_rmse = math.sqrt(np.sum(add_sample * add_sample) / len(add_sample)) #+ 0.000000000000000001 - if len(add_sample) < len(sigy): padded = np.zeros(len(sigy), dtype=np.float32) padded[0:len(add_sample)] = add_sample else: padded = add_sample[0:len(sigy)] - add_sample = padded - noise_scale = y_rmse / noise_rmse * math.pow(10, -snr / 20) sigy = sigy + add_sample * noise_scale + # Only bother with white gasussian noise addition if the WG_SNR is low enough. if wgn_snr < 50: wgn_samps = np.random.normal(size=(len(sigy))).astype(np.float32) noise_scale = y_rmse * math.pow(10, -wgn_snr / 20) sigy = sigy + wgn_samps * noise_scale - # Apply gain & clipping + # Apply gain & clipping. return np.clip(sigy * gain, -1.0, 1.0) def get_add_noise(args): """ - Extracts the additive noise file from the defined path + Extracts the additive noise file from the defined path. - Input - args: args object (contains info about model and training) + Input: + args: args object (contains info about model and training). - Output - add_sample: additive noise audio + Output: + add_sample: additive noise audio. """ additive_base_path = args.additive_base_path add_fname = random.choice(os.listdir(additive_base_path)) @@ -106,17 +102,17 @@ def get_add_noise(args): def get_ASR_datasets(args): """ - Function for preparing the data samples for the phoneme pipeline + Function for preparing the data samples for the phoneme pipeline. - Input - args: args object (contains info about model and training) + Input: + args: args object (contains info about model and training). - Output - train_dataset: dataset class used for loading the samples into the training pipeline + Output: + train_dataset: dataset class used for loading the samples into the training pipeline. """ base_path = args.base_path - # Load the speech data. This code snippet (till line 121) depends on the data format in base_path + # Load the speech data. This code snippet (till line 121) depends on the data format in base_path. train_textgrid_paths = glob.glob(base_path + "/text/train-clean*/*/*/*.TextGrid") @@ -124,7 +120,7 @@ def get_ASR_datasets(args): for path in train_textgrid_paths] if args.pre_phone_list: - # If there is a list of phonemes in the dataset, use this flag + # If there is a list of phonemes in the dataset, use this flag. Sy_phoneme = [] with open(args.phoneme_text_file, "r") as f: for line in f.readlines(): @@ -139,7 +135,7 @@ def get_ASR_datasets(args): print(len(Sy_phoneme), flush=True) print("**********************", flush=True) else: - # No list of phonemes specified. Count from the input dataset + # No list of phonemes specified. Count from the input dataset. phoneme_counter = Counter() for path in train_textgrid_paths: tg = textgrid.TextGrid() @@ -148,7 +144,7 @@ def get_ASR_datasets(args): for phone in tg.getList("phones")[0] if phone.mark not in ['', 'sp', 'spn']]) - # Display and store the phonemes extracted + # Display and store the phonemes extracted. Sy_phoneme = list(phoneme_counter) args.num_phonemes = len(Sy_phoneme) print("**************", flush=True) @@ -165,7 +161,7 @@ def get_ASR_datasets(args): print("Data Path Prep Done.", flush=True) - # Create dataset objects + # Create dataset objects. 
train_dataset = ASRDataset(train_wav_paths, train_textgrid_paths, Sy_phoneme, args) return train_dataset @@ -173,13 +169,13 @@ def get_ASR_datasets(args): class ASRDataset(torch.utils.data.Dataset): def __init__(self, wav_paths, textgrid_paths, Sy_phoneme, args): """ - Dataset iterator for the phoneme detection model + Dataset iterator for the phoneme detection model. - Input - wav_paths : list of strings (wav file paths) - textgrid_paths : list of strings (textgrid for each wav file) - Sy_phoneme : list of strings (all possible phonemes) - args : args object (contains info about model and training) + Input: + wav_paths : list of strings (wav file paths). + textgrid_paths : list of strings (textgrid for each wav file). + Sy_phoneme : list of strings (all possible phonemes). + args : args object (contains info about model and training). """ self.wav_paths = wav_paths self.textgrid_paths = textgrid_paths @@ -187,28 +183,28 @@ def __init__(self, wav_paths, textgrid_paths, Sy_phoneme, args): self.length_var = args.pretraining_length_var self.Sy_phoneme = Sy_phoneme self.args = args - # Dataset Loader for the iterator + # Dataset Loader for the iterator. self.loader = torch.utils.data.DataLoader(self, batch_size=args.pretraining_batch_size, num_workers=args.workers, shuffle=True, collate_fn=CollateWavsASR()) def __len__(self): """ - Number of audio samples available + Number of audio samples available. """ return len(self.wav_paths) def __getitem__(self, idx): """ Gives one sample from the dataset. Data is read in this snippet. - (refer to the collate function for pre-processing) + (refer to the collate function for pre-processing). Input: - idx: index for the sample + idx: index for the sample. Output: - x : audio sample obtained from the synth block (if used, else input audio) after time-domain clipping - y_phoneme : the output phonemes sampled at 30ms + x : audio sample obtained from the synth block (if used, else input audio) after time-domain clipping. + y_phoneme : the output phonemes sampled at 30ms. """ x, fs = sf.read(self.wav_paths[idx]) @@ -222,7 +218,7 @@ def __getitem__(self, idx): if phoneme.mark == '': phoneme_index = -1 y_phoneme += [phoneme_index] * round(duration * fs) - # Cut a snippet of length random_length from the audio + # Cut a snippet of length random_length from the audio. random_length = round(fs * (self.length_mean + self.length_var * torch.randn(1).item())) if len(x) <= random_length: start = 0 @@ -247,10 +243,10 @@ def __getitem__(self, idx): class CollateWavsASR: def __call__(self, batch): """ - Pre-processing and padding, followed by batching the set of inputs + Pre-processing and padding, followed by batching the set of inputs. Input: - batch: list of tuples (input wav, phoneme labels) + batch: list of tuples (input wav, phoneme labels). Output: feature_tensor : the melspectogram features of the input audio. The features are padded for batching. @@ -264,27 +260,27 @@ def __call__(self, batch): x.append(x_) y_phoneme.append(y_phoneme_) - # pad all sequences to have same length and get features + # pad all sequences to have same length and get features. features=[] T = max([len(x_) for x_ in x]) U_phoneme = max([len(y_phoneme_) for y_phoneme_ in y_phoneme]) for index in range(batch_size): - # pad audio to same length for all the audio samples in the batch + # pad audio to same length for all the audio samples in the batch. 
x_pad_length = (T - len(x[index])) x[index] = np.pad(x[index], (x_pad_length,0), 'constant', constant_values=(0, 0)) - # Extract Mel-Spectogram from padded audio + # Extract Mel-Spectogram from padded audio. feature = librosa.feature.melspectrogram(y=x[index],sr=16000,n_mels=80, win_length=25*16,hop_length=10*16, n_fft=512) feature = librosa.core.power_to_db(feature) - # Normalize the features + # Normalize the features. max_value = np.max(feature) min_value = np.min(feature) feature = (feature - min_value) / (max_value - min_value) features.append(feature) - # Pad the labels to same length for all samples in the batch + # Pad the labels to same length for all samples in the batch. y_pad_length = (U_phoneme - len(y_phoneme[index])) y_phoneme[index] = np.pad(y_phoneme[index], (y_pad_length,0), 'constant', constant_values=(-1, -1)) @@ -305,25 +301,25 @@ def __call__(self, batch): def get_classification_dataset(args): """ - Function for preparing the data samples for the classification pipeline + Function for preparing the data samples for the classification pipeline. - Input - args: args object (contains info about model and training) + Input: + args: args object (contains info about model and training). - Output - train_dataset : dataset class used for loading the samples into the training pipeline - test_dataset : dataset class used for loading the samples into the testing pipeline + Output: + train_dataset : dataset class used for loading the samples into the training pipeline. + test_dataset : dataset class used for loading the samples into the testing pipeline. """ base_path = args.base_path - # Train Data + # Train Data. train_wav_paths = [] train_labels = [] - # data_folder_list = ["google30_train"] or ["google30_azure_tts", "google30_google_tts"] + # We assign data_folder_list = ["google30_train"] or ["google30_azure_tts", "google30_google_tts"]. data_folder_list = args.train_data_folders for data_folder in data_folder_list: - # For each of the folder, iterate through the words and get the files and the labels + # For each of the folder, iterate through the words and get the files and the labels. for (label, word) in enumerate(args.words): curr_word_files = glob.glob(base_path + f"/{data_folder}/" + word + "/*.wav") train_wav_paths += curr_word_files @@ -333,17 +329,17 @@ def get_classification_dataset(args): random.shuffle(temp) train_wav_paths, train_labels = zip(*temp) print(f"Train Data Folders Used {data_folder_list}", flush=True) - # Create dataset objects + # Create dataset objects. train_dataset = ClassificationDataset(wav_paths=train_wav_paths, labels=train_labels, args=args, is_train=True) - # Test Data + # Test Data. test_wav_paths = [] test_labels = [] - # data_folder_list = ["google30_test"] + # We assign data_folder_list = ["google30_test"]. data_folder_list = args.test_data_folders for data_folder in data_folder_list: - # For each of the folder, iterate through the words and get the files and the labels + # For each of the folder, iterate through the words and get the files and the labels. for (label, word) in enumerate(args.words): curr_word_files = glob.glob(base_path + f"/{data_folder}/" + word + "/*.wav") test_wav_paths += curr_word_files @@ -353,7 +349,7 @@ def get_classification_dataset(args): random.shuffle(temp) test_wav_paths, test_labels = zip(*temp) print(f"Test Data Folders Used {data_folder_list}", flush=True) - # Create dataset objects + # Create dataset objects. 
test_dataset = ClassificationDataset(wav_paths=test_wav_paths, labels=test_labels, args=args, is_train=False)

    return train_dataset, test_dataset
@@ -361,13 +357,13 @@ def get_classification_dataset(args):
 class ClassificationDataset(torch.utils.data.Dataset):
    def __init__(self, wav_paths, labels, args, is_train=True):
        """
-        Dataset iterator for the classifier model
+        Dataset iterator for the classifier model.
 
-        Input
-            wav_paths : list of strings (wav file paths)
-            labels : list of classification labels for the corresponding audio wav files
-            is_train : boolean flag, if the dataset loader is for the train or test pipeline
-            args : args object (contains info about model and training)
+        Input:
+            wav_paths : list of strings (wav file paths).
+            labels : list of classification labels for the corresponding audio wav files.
+            is_train : boolean flag, if the dataset loader is for the train or test pipeline.
+            args : args object (contains info about model and training).
        """
        self.wav_paths = wav_paths
        self.labels = labels
@@ -379,19 +375,19 @@ def __init__(self, wav_paths, labels, args, is_train=True):
 
    def __len__(self):
        """
-        Number of audio samples available
+        Number of audio samples available.
        """
        return len(self.wav_paths)
 
    def one_hot_encoder(self, lab):
        """
-        Label index to one-hot encoder
+        Label index to one-hot encoder.
 
        Input:
-            lab: label index
+            lab: label index.
 
        Output:
-            one_hot: label in the one-hot format
+            one_hot: label in the one-hot format.
        """
        one_hot = np.zeros(len(self.args.words))
        one_hot[lab]=1
@@ -399,16 +395,16 @@ def one_hot_encoder(self, lab):
 
    def __getitem__(self, idx):
        """
-        Gives one sample from the dataset. Data is read in this snippet. (refer to the collate function for pre-processing)
+        Gives one sample from the dataset. Data is read in this snippet. (refer to the collate function for pre-processing).
 
        Input:
-            idx: index for the sample
+            idx: index for the sample.
 
        Output:
-            x : audio sample obtained from the synth block (if used, else input audio) after time-domain clipping
-            one_hot_label : one-hot encoded label
+            x : audio sample obtained from the synth block (if used, else input audio) after time-domain clipping.
+            one_hot_label : one-hot encoded label.
            seqlen : length of the audio file.
-                    This value will be dropped and seqlen after feature extraction will be used. Refer to the collate func
+                    This value will be dropped and seqlen after feature extraction will be used. Refer to the collate function.
        """
        x, fs = sf.read(self.wav_paths[idx])
@@ -429,15 +425,15 @@ def __getitem__(self, idx):
 
 class CollateWavsClassifier:
    def __call__(self, batch):
        """
-        Pre-processing and padding, followed by batching the set of inputs
+        Pre-processing and padding, followed by batching the set of inputs.
 
        Input:
-            batch: list of tuples (input wav, one hot classification label, sequence length)
+            batch: list of tuples (input wav, one hot classification label, sequence length).
 
        Output:
            feature_tensor : the melspectogram features of the input audio. The features are padded for batching.
-            one_hot_label_tensor : the on-hot label in a tensor format
-            seqlen_tensor : the sequence length of the features in a minibatch
+            one_hot_label_tensor : the one-hot label in a tensor format.
+            seqlen_tensor : the sequence length of the features in a minibatch.
""" x = []; one_hot_label = []; seqlen = [] batch_size = len(batch) @@ -446,19 +442,19 @@ def __call__(self, batch): x.append(x_) one_hot_label.append(one_hot_label_) - # pad all sequences to have same length and get features + # pad all sequences to have same length and get features. features=[] T = max([len(x_) for x_ in x]) T = max([T, 48000]) for index in range(batch_size): - # pad audio to same length for all the audio samples in the batch + # pad audio to same length for all the audio samples in the batch. x_pad_length = (T - len(x[index])) x[index] = np.pad(x[index], (x_pad_length,0), 'constant', constant_values=(0, 0)) - # Extract Mel-Spectogram from padded audio + # Extract Mel-Spectogram from padded audio. feature = librosa.feature.melspectrogram(y=x[index],sr=16000,n_mels=80,win_length=25*16,hop_length=10*16, n_fft=512) feature = librosa.core.power_to_db(feature) - # Normalize the features + # Normalize the features. max_value = np.max(feature) min_value = np.min(feature) if min_value == max_value: @@ -508,16 +504,16 @@ def __call__(self, batch): parser.add_argument('--pre_phone_list', action='store_true') args = parser.parse_args() - # SNRs + # SNRs. args.snr_samples = [int(samp) for samp in args.snr_samples.split(',')] args.wgn_snr_samples = [int(samp) for samp in args.wgn_snr_samples.split(',')] args.gain_samples = [float(samp) for samp in args.gain_samples.split(',')] - # Workers + # Workers. if args.workers == -1: args.workers = multiprocessing.cpu_count() - # Words + # Words. if args.word_model_name == 'google30': args.words = ["bed", "bird", "cat", "dog", "down", "eight", "five", "four", "go", "happy", "house", "left", "marvin", "nine", "no", "off", "on", "one", "right", @@ -529,7 +525,7 @@ def __call__(self, batch): else: raise ValueError('Incorrect Word Model Name') - # Data Folders + # Data Folders. args.train_data_folders = [folder_idx for folder_idx in args.train_data_folders.split(',')] args.test_data_folders = [folder_idx for folder_idx in args.test_data_folders.split(',')] diff --git a/applications/KWS_Phoneme/kwscnn.py b/applications/KWS_Phoneme/kwscnn.py index ab7e4c558..204eee5df 100644 --- a/applications/KWS_Phoneme/kwscnn.py +++ b/applications/KWS_Phoneme/kwscnn.py @@ -14,7 +14,7 @@ class _IndexSelect(torch.nn.Module): def __init__(self, channels, direction, groups): """ - Channel permutation module. The purpose of this is to allow mixing across the CNN groups + Channel permutation module. The purpose of this is to allow mixing across the CNN groups. """ super(_IndexSelect, self).__init__() @@ -48,15 +48,15 @@ def __init__(self): def forward(self, value): """ - Applies a custom activation function - The first half of the channels are passed through sigmoid layer and the next half though a tanh - The outputs are multiplied and returned + Applies a custom activation function. + The first half of the channels are passed through sigmoid layer and the next half though a tanh. + The outputs are multiplied and returned. - Input - value: A tensor of shape (batch, channels, *) + Input: + value: A tensor of shape (batch, channels, *). - Output - activation output of shape (batch, channels/2, *) + Output: + activation output of shape (batch, channels/2, *). """ channels = value.shape[1] piv = int(channels/2) @@ -78,12 +78,12 @@ def __init__(self, in_channels, out_channels, kernel_size, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode) """ - A convolution layer with the weight matrix subjected to a low-rank decomposition. 
Currently for kernel size of 5 + A convolution layer with the weight matrix subjected to a low-rank decomposition. Currently for kernel size of 5. - Input - rank : The rank used for the low-rank decomposition on the weight/kernel tensor - All other parameters are similar to that of a convolution layer - Only change is the decomposition of the output channels into low-rank tensors + Input: + rank : The rank used for the low-rank decomposition on the weight/kernel tensor. + All other parameters are similar to that of a convolution layer. + Only change is the decomposition of the output channels into low-rank tensors. """ self.kernel_size = kernel_size self.rank = rank @@ -96,14 +96,14 @@ def __init__(self, in_channels, out_channels, kernel_size, def forward(self, input): """ - The decomposed weights are multiplied to enforce the low-rank constraint - The conv1d is performed as usual post multiplication + The decomposed weights are multiplied to enforce the low-rank constraint. + The conv1d is performed as usual post multiplication. - Input - input: Input of shape similar to that of which is fed to a conv layer + Input: + input: Input of shape similar to that of which is fed to a conv layer. - Output - convolution output + Output: + convolution output. """ lr_weight = torch.matmul(self.W1, self.W2) lr_weight = torch.reshape(lr_weight, (self.out_channels, self.in_channels, self.kernel_size)) @@ -122,19 +122,19 @@ def __init__( activation='sigmoid', rank=50): super(PreRNNConvBlock, self).__init__() """ - A low-rank convolution layer combination with pooling and activation layers. Currently for kernel size of 5 + A low-rank convolution layer combination with pooling and activation layers. Currently for kernel size of 5. - Input - in_channels : number of input channels for the conv layer - out_channels : number of output channels for the conv layer - kernel : conv kernel size - stride : conv stride - groups : number of groups for conv layer - avg_pool : kernel size for average pooling layer - dropout : dropout layer probability - batch_norm : momentum for batch norm - activation : activation layer - rank : rank for low-rank decomposition for conv layer weights + Input: + in_channels : number of input channels for the conv layer. + out_channels : number of output channels for the conv layer. + kernel : conv kernel size. + stride : conv stride. + groups : number of groups for conv layer. + avg_pool : kernel size for average pooling layer. + dropout : dropout layer probability. + batch_norm : momentum for batch norm. + activation : activation layer. + rank : rank for low-rank decomposition for conv layer weights. """ activators = { 'sigmoid': torch.nn.Sigmoid(), @@ -184,13 +184,13 @@ def __init__( def forward(self, x): """ - Apply the set of layers initialized in __init__ + Apply the set of layers initialized in __init__. - Input - x: A tensor of shape (batch, channels, length) + Input: + x: A tensor of shape (batch, channels, length). - Output - network block output of shape (batch, channels, length) + Output: + network block output of shape (batch, channels, length). """ x = self._op1(x) return x @@ -203,19 +203,19 @@ def __init__( activation='sigmoid', rank=50): super(DSCNNBlockLR, self).__init__() """ - A depthwise separable low-rank convolution layer combination with pooling and activation layers + A depthwise separable low-rank convolution layer combination with pooling and activation layers. 
- Input - in_channels : number of input channels for the pointwise conv layer - out_channels : number of output channels for the pointwise conv layer - kernel : conv kernel size for depthwise layer - stride : conv stride - groups : number of groups for conv layer - avg_pool : kernel size for average pooling layer - dropout : dropout layer probability - batch_norm : momentum for batch norm - activation : activation layer - rank : rank for low-rank decomposition for conv layer weights + Input: + in_channels : number of input channels for the pointwise conv layer. + out_channels : number of output channels for the pointwise conv layer. + kernel : conv kernel size for depthwise layer. + stride : conv stride. + groups : number of groups for conv layer. + avg_pool : kernel size for average pooling layer. + dropout : dropout layer probability. + batch_norm : momentum for batch norm. + activation : activation layer. + rank : rank for low-rank decomposition for conv layer weights. """ activators = { 'sigmoid': torch.nn.Sigmoid(), @@ -265,22 +265,22 @@ def __init__( def forward(self, x): """ - Apply the set of layers initialized in __init__ + Apply the set of layers initialized in __init__. - Input - x: A tensor of shape (batch, channels, length) + Input: + x: A tensor of shape (batch, channels, length). - Output - network block output of shape (batch, channels, length) + Output: + network block output of shape (batch, channels, length). """ x = self._op(x) return x class BiFastGRNN(nn.Module): """ - Bi Directional FastGRNN + Bi Directional FastGRNN. - Parameters and arguments are similar to the torch RNN counterparts + Parameters and arguments are similar to the torch RNN counterparts. """ def __init__(self, inputDims, hiddenDims, gate_nonlinearity, update_nonlinearity, rank): @@ -304,112 +304,108 @@ def __init__(self, inputDims, hiddenDims, gate_nonlinearity, def forward(self, input_f, input_b): """ - Pass the inputs to forward and backward layers - Please note the backward layer is similar to the forward layer and the input needs to be fed reversed accordingly - Tensors are of the shape (batch, length, channels) + Pass the inputs to forward and backward layers. + Please note the backward layer is similar to the forward layer and the input needs to be fed reversed accordingly. + Tensors are of the shape (batch, length, channels). - Input - input_f : input to the forward layer - input_b : input to the backward layer. Input needs to be reversed before passing through this forward method + Input: + input_f : input to the forward layer. + input_b : input to the backward layer. Input needs to be reversed before passing through this forward method. - Output - output1 : output of the forward layer - output2 : output of the backward layer + Output: + output1 : output of the forward layer. + output2 : output of the backward layer. """ - # Bidirectional FastGRNN + # Bidirectional FastGRNN. output1 = self.cell_fwd(input_f) output2 = self.cell_bwd(input_b) - #Returning the flipped output only for the bwd pass - #Will align it in the post processing + #Returning the flipped output only for the bwd pass. + #Will align it in the post processing. 
return output1, output2 def X_preRNN_process(X, fwd_context, bwd_context): """ - A depthwise separable low-rank convolution layer combination with pooling and activation layers - - Input - in_channels : number of input channels for the pointwise conv layer - out_channels : number of output channels for the pointwise conv layer - kernel : conv kernel size for depthwise layer - stride : conv stride - groups : number of groups for conv layer - avg_pool : kernel size for average pooling layer - dropout : dropout layer probability - batch_norm : momentum for batch norm - activation : activation layer - rank : rank for low-rank decomposition for conv layer weights + A depthwise separable low-rank convolution layer combination with pooling and activation layers. + + Input: + in_channels : number of input channels for the pointwise conv layer. + out_channels : number of output channels for the pointwise conv layer. + kernel : conv kernel size for depthwise layer. + stride : conv stride. + groups : number of groups for conv layer. + avg_pool : kernel size for average pooling layer. + dropout : dropout layer probability. + batch_norm : momentum for batch norm. + activation : activation layer. + rank : rank for low-rank decomposition for conv layer weights. """ - #FWD bricking + # Forward bricking. brickLength = fwd_context hopLength = 3 X_bricked_f1 = X.unfold(1, brickLength, hopLength) X_bricked_f2 = X_bricked_f1.permute(0, 1, 3, 2) - #X_bricked_f [batch, num_bricks, brickLen, inpDim] + # X_bricked_f [batch, num_bricks, brickLen, inpDim]. oldShape_f = X_bricked_f2.shape X_bricked_f = torch.reshape( X_bricked_f2, [oldShape_f[0] * oldShape_f[1], oldShape_f[2], -1]) - # X_bricked_f [batch*num_bricks, brickLen, inpDim] + # X_bricked_f [batch*num_bricks, brickLen, inpDim]. - #BWD bricking + # Backward bricking. brickLength = bwd_context hopLength = 3 X_bricked_b = X.unfold(1, brickLength, hopLength) X_bricked_b = X_bricked_b.permute(0, 1, 3, 2) - #X_bricked_f [batch, num_bricks, brickLen, inpDim] + # X_bricked_b [batch, num_bricks, brickLen, inpDim]. oldShape_b = X_bricked_b.shape X_bricked_b = torch.reshape( X_bricked_b, [oldShape_b[0] * oldShape_b[1], oldShape_b[2], -1]) - # X_bricked_f [batch*num_bricks, brickLen, inpDim] + # X_bricked_b [batch*num_bricks, brickLen, inpDim]. return X_bricked_f, oldShape_f, X_bricked_b, oldShape_b def X_postRNN_process(X_f, oldShape_f, X_b, oldShape_b): """ - A depthwise separable low-rank convolution layer combination with pooling and activation layers - - Input - in_channels : number of input channels for the pointwise conv layer - out_channels : number of output channels for the pointwise conv layer - kernel : conv kernel size for depthwise layer - stride : conv stride - groups : number of groups for conv layer - avg_pool : kernel size for average pooling layer - dropout : dropout layer probability - batch_norm : momentum for batch norm - activation : activation layer - rank : rank for low-rank decomposition for conv layer weights + A depthwise separable low-rank convolution layer combination with pooling and activation layers. + + Input: + in_channels : number of input channels for the pointwise conv layer. + out_channels : number of output channels for the pointwise conv layer. + kernel : conv kernel size for depthwise layer. + stride : conv stride. + groups : number of groups for conv layer. + avg_pool : kernel size for average pooling layer. + dropout : dropout layer probability. + batch_norm : momentum for batch norm. + activation : activation layer. 
+ rank : rank for low-rank decomposition for conv layer weights. """ - #Forward bricks folding + # Forward bricks folding. X_f = torch.reshape(X_f, [oldShape_f[0], oldShape_f[1], oldShape_f[2], -1]) - #X_f [batch, num_bricks, brickLen, hiddenDim] + # batch, num_bricks, brickLen, hiddenDim. X_new_f = X_f[:, 0, ::3, :] - #batch,brickLen,hiddenDim + # batch, brickLen, hiddenDim. X_new_f_rest = X_f[:, :, -1, :].squeeze(2) - #batch, numBricks-1,hiddenDim + # batch, numBricks - 1, hiddenDim. shape = X_new_f_rest.shape X_new_f = torch.cat((X_new_f, X_new_f_rest), dim=1) - #batch,seqLen,hiddenDim - #X_new_f [batch, seqLen, hiddenDim] + # batch, seqLen, hiddenDim. - #Backward Bricks folding + # Backward Bricks folding. X_b = torch.reshape(X_b, [oldShape_b[0], oldShape_b[1], oldShape_b[2], -1]) - #X_b [batch, num_bricks, brickLen, hiddenDim] + # batch, num_bricks, brickLen, hiddenDim. X_b = torch.flip(X_b, [1]) - #Reverse the ordering of the bricks (bring last brick to start) - + # Reverse the ordering of the bricks (bring last brick to start). X_new_b = X_b[:, 0, ::3, :] - #batch,brickLen,inpDim + # batch, brickLen, inpDim. X_new_b_rest = X_b[:, :, -1, :].squeeze(2) - #batch,(seqlen-brickLen),hiddenDim + # batch, seqlen - brickLen, hiddenDim. X_new_b = torch.cat((X_new_b, X_new_b_rest), dim=1) - #batch,seqLen,hiddenDim + # batch, seqLen, hiddenDim. X_new_b = torch.flip(X_new_b, [1]) - #inverting the flip operation - + # inverting the flip operation. X_new = torch.cat((X_new_f, X_new_b), dim=2) - #batch,seqLen,2*hiddenDim - + # batch, seqLen, 2 * hiddenDim. return X_new @@ -419,18 +415,18 @@ def __init__(self, cnn_channels, rnn_hidden_size, rnn_num_layers, isBi=True, num_labels=41, rank=None, fwd_context=15, bwd_context=9): super(DSCNN_RNN_Block, self).__init__() """ - A depthwise separable low-rank convolution layer combination with pooling and activation layers + A depthwise separable low-rank convolution layer combination with pooling and activation layers. - Input - cnn_channels : number of the output channels for the first CNN block - rnn_hidden_size : hidden dimensions of the FastGRNN - rnn_num_layers : number of FastGRNN layers - device : device on which the tensors would placed - gate_nonlinearity : activation function for the gating in the FastGRNN - update_nonlinearity : activation function for the update function in the FastGRNN - isBi : boolean flag to use bi-directional FastGRNN - fwd_context : window for the forward pass - bwd_context : window for the backward pass + Input: + cnn_channels : number of the output channels for the first CNN block. + rnn_hidden_size : hidden dimensions of the FastGRNN. + rnn_num_layers : number of FastGRNN layers. + device : device on which the tensors would placed. + gate_nonlinearity : activation function for the gating in the FastGRNN. + update_nonlinearity : activation function for the update function in the FastGRNN. + isBi : boolean flag to use bi-directional FastGRNN. + fwd_context : window for the forward pass. + bwd_context : window for the backward pass. """ self.cnn_channels = cnn_channels self.rnn_hidden_size = rnn_hidden_size @@ -455,8 +451,8 @@ def __init__(self, cnn_channels, rnn_hidden_size, rnn_num_layers, def declare_network(self, cnn_channels, rnn_hidden_size, rnn_num_layers, num_labels, rank): """ - Declare the netwok layers - Arguments can be inferred from the __init__ + Declare the netwok layers. + Arguments can be inferred from the __init__. 
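The bricking in `X_preRNN_process` above, and the matching fold-back in `X_postRNN_process`, are easier to follow with a small standalone example of `Tensor.unfold`; the brick length, hop length, and feature dimension below are made-up values, not the ones the model uses.
```
import torch

batch, seq_len, inp_dim = 2, 30, 4        # illustrative sizes
brick_len, hop_len = 9, 3

X = torch.randn(batch, seq_len, inp_dim)

# unfold(dim=1, size=brick_len, step=hop_len) slides a window over time and
# returns (batch, num_bricks, inp_dim, brick_len).
bricks = X.unfold(1, brick_len, hop_len)
print(bricks.shape)                       # torch.Size([2, 8, 4, 9]); (30 - 9) // 3 + 1 = 8 bricks

# Permute to (batch, num_bricks, brick_len, inp_dim), then merge the batch and
# brick dimensions so the RNN treats every brick as an independent sequence.
bricks = bricks.permute(0, 1, 3, 2)
merged = bricks.reshape(batch * bricks.shape[1], brick_len, inp_dim)
print(merged.shape)                       # torch.Size([16, 9, 4])
```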
""" self.CNN1 = torch.nn.Sequential( PreRNNConvBlock(80, cnn_channels, 5, 1, 1, @@ -499,13 +495,13 @@ def declare_network(self, cnn_channels, rnn_hidden_size, rnn_num_layers, def forward(self, features): """ - Apply the set of layers initialized in __init__ + Apply the set of layers initialized in __init__. - Input - features: A tensor of shape (batch, channels, length) + Input: + features: a tensor of shape (batch, channels, length). - Output - network block output in the form (batch, channels, length) + Output: + network block output in the form (batch, channels, length). """ batch, _, max_seq_len = features.shape X = self.CNN1(features) # Down to 30ms inference / 250ms window @@ -538,18 +534,18 @@ def __init__(self, in_size, rnn_hidden_size, rnn_num_layers, num_labels=2, dropout=0, batch_assertion=False): super(Binary_Classification_Block, self).__init__() """ - A depthwise separable low-rank convolution layer combination with pooling and activation layers + A depthwise separable low-rank convolution layer combination with pooling and activation layers. - Input - in_size : number of input channels to the layer - rnn_hidden_size : hidden dimensions of the RNN layer - rnn_num_layers : number of layers for the RNN layer - device : device on which the tensors would placed - islstm : boolean flag to use the LSTM. False would use GRU - isBi : boolean flag to use the bi-directional variant of the RNN - momentum : momentum for the batch-norm layer - num_labels : number of output labels - dropout : probability for the dropout layer + Input: + in_size : number of input channels to the layer. + rnn_hidden_size : hidden dimensions of the RNN layer. + rnn_num_layers : number of layers for the RNN layer. + device : device on which the tensors would placed. + islstm : boolean flag to use the LSTM. False would use GRU. + isBi : boolean flag to use the bi-directional variant of the RNN. + momentum : momentum for the batch-norm layer. + num_labels : number of output labels. + dropout : probability for the dropout layer. """ self.in_size = in_size self.rnn_hidden_size = rnn_hidden_size @@ -573,8 +569,8 @@ def __init__(self, in_size, rnn_hidden_size, rnn_num_layers, def declare_network(self, in_size, rnn_hidden_size, rnn_num_layers, num_labels): """ - Declare the netwok layers - Arguments can be inferred from the __init__ + Declare the netwok layers. + Arguments can be inferred from the __init__. """ self.CNN1 = torch.nn.Sequential( torch.nn.LeakyReLU(negative_slope=0.01), @@ -603,13 +599,13 @@ def declare_network(self, in_size, rnn_hidden_size, rnn_num_layers, num_labels): def forward(self, features, seqlen): """ - Apply the set of layers initialized in __init__ + Apply the set of layers initialized in __init__. - Input - features: A tensor of shape (batch, channels, length) + Input: + features: A tensor of shape (batch, channels, length). - Output - network block output in the form (batch, length, channels). length will be 1 + Output: + network block output in the form (batch, length, channels). length will be 1. """ batch, _, _ = features.shape @@ -622,18 +618,18 @@ def forward(self, features, seqlen): hidden = self.init_hidden(batch, self.rnn_hidden_size, self.rnn_num_layers) - X = self.CNN1(features) # Down to 30ms inference / 250ms window + X = self.CNN1(features) # Down to 30ms inference / 250ms window. - X = X.permute((0, 2, 1)) # NCL to NLC + X = X.permute((0, 2, 1)) # NCL to NLC. max_seq_len = X.shape[1] - # modify seqlen + # modify seqlen. 
max_seq_len = min(torch.max(seqlen).item(), max_seq_len) seqlen = torch.clamp(seqlen, max=max_seq_len) self.seqlen = seqlen - # pad according to seqlen + # pad according to seqlen. X = torch.nn.utils.rnn.pack_padded_sequence(X, seqlen, batch_first=True, @@ -660,9 +656,9 @@ def forward(self, features, seqlen): def init_hidden(self, batch, rnn_hidden_size, rnn_num_layers): """ - Used to initialize the first hidden state of the RNN. It is currently zero. THe user is free to edit this function for additional analysis + Used to initialize the first hidden state of the RNN. It is currently zero. The user is free to edit this function for additional analysis. """ - # the weights are of the form (batch, num_layers * num_directions , hidden_size) + # the weights are of the form (batch, num_layers * num_directions , hidden_size). if self.batch_assertion: hidden = torch.zeros(rnn_num_layers * self.direction_param, batch, rnn_hidden_size) diff --git a/applications/KWS_Phoneme/train_classifier.py b/applications/KWS_Phoneme/train_classifier.py index 1bfaa37f9..bd8a59b35 100644 --- a/applications/KWS_Phoneme/train_classifier.py +++ b/applications/KWS_Phoneme/train_classifier.py @@ -6,7 +6,7 @@ import re import numpy as np import torch -# Aux scripts +# Aux scripts. import kwscnn import multiprocessing from data_pipe import get_ASR_datasets, get_classification_dataset @@ -17,7 +17,7 @@ def parseArgs(): Describes the architecture and the hyper-parameters """ parser = argparse.ArgumentParser() - # Args for Model Traning + # Args for Model Traning. parser.add_argument('--phoneme_model_load_ckpt', type=str, required=True, help="Phoneme checkpoint file to be loaded") parser.add_argument('--classifier_model_save_folder', type=str, default='./classifier_model', help="Folder to save the classifier checkpoint") parser.add_argument('--classifier_model_load_ckpt', type=str, default=None, help="Classifier checkpoint to be loaded") @@ -31,7 +31,7 @@ def parseArgs(): parser.add_argument('--words', type=str, help="List of words to be used. This will be assigned in the code. User input will not affect the result") parser.add_argument("--is_training", action='store_true', help="True for training") parser.add_argument("--synth", action='store_true', help="Use Synth block or not") - # Args for DataLoader + # Args for DataLoader. parser.add_argument('--base_path', type=str, required=True, help="path to train and test data folders") parser.add_argument('--train_data_folders', type=str, default="google30_train", help="List of training folders in base path. Each folder is a dataset in the prescribed format") parser.add_argument('--test_data_folders', type=str, default="google30_test", help="List of testing folders in base path. Each folder is a dataset in the prescribed format") @@ -47,7 +47,7 @@ def parseArgs(): parser.add_argument('--rir_chance', type=float, default=0.9, help="Probability of performing reverbration") parser.add_argument('--synth_chance', type=float, default=0.9, help="Probability of pre-processing the input with reverb and noise") parser.add_argument('--pre_phone_list', action='store_true', help="use pre-fixed set of phonemes") - # Args for Phoneme + # Args for Phoneme. 
parser.add_argument('--phoneme_cnn_channels', type=int, default=400, help="Number od channels for the CNN layers") parser.add_argument('--phoneme_rnn_hidden_size', type=int, default=200, help="Number of RNN hidden states") parser.add_argument('--phoneme_rnn_layers', type=int, default=1, help="Number of RNN layers") @@ -56,7 +56,7 @@ def parseArgs(): parser.add_argument('--phoneme_bwd_context', type=int, default=9, help="RNN backward window context") parser.add_argument('--phoneme_phoneme_isBi', action='store_true', help="Use Bi-Directional RNN") parser.add_argument('--phoneme_num_labels', type=int, default=41, help="Number og phoneme labels") - # Args for Classifier + # Args for Classifier. parser.add_argument('--classifier_rnn_hidden_size', type=int, default=100, help="Classifier RNN hidden dimensions") parser.add_argument('--classifier_rnn_num_layers', type=int, default=1, help="Classifier RNN number of layers") parser.add_argument('--classifier_dropout', type=float, default=0.2, help="Classifier dropout layer probability") @@ -65,16 +65,16 @@ def parseArgs(): args = parser.parse_args() - # Parse the gain and SNR values to a float format + # Parse the gain and SNR values to a float format. args.snr_samples = [int(samp) for samp in args.snr_samples.split(',')] args.wgn_snr_samples = [int(samp) for samp in args.wgn_snr_samples.split(',')] args.gain_samples = [float(samp) for samp in args.gain_samples.split(',')] - # Fix the number of workers for the data Loader. If == -1 then use all possible workers + # Fix the number of workers for the data Loader. If == -1 then use all possible workers. if args.workers == -1: args.workers = multiprocessing.cpu_count() - # Choose the word list to be used. For custom word lists, please add an elif condition + # Choose the word list to be used. For custom word lists, please add an elif condition. if args.word_model_name == 'google30': args.words = ["bed", "bird", "cat", "dog", "down", "eight", "five", "four", "go", "happy", "house", "left", "marvin", "nine", "no", "off", "on", "one", "right", @@ -85,8 +85,8 @@ def parseArgs(): else: raise ValueError('Incorrect Word Model Name') - # The data-folder in args.base_path that contain the data - # Refer to data_pipe.py for loading format + # The data-folder in args.base_path that contain the data. + # Refer to data_pipe.py for loading format. args.train_data_folders = [folder_idx for folder_idx in args.train_data_folders.split(',')] args.test_data_folders = [folder_idx for folder_idx in args.test_data_folders.split(',')] @@ -95,13 +95,13 @@ def parseArgs(): def train_classifier_model(args): """ - Train the Classifier Model on the designated dataset - The Dataset loader is defined in data_pipe.py - Default dataset used is Google30. Change the paths and file reader to change datasets + Train the Classifier Model on the designated dataset. + The Dataset loader is defined in data_pipe.py. + Default dataset used is Google30. Change the paths and file reader to change datasets. - args: args object (contains info about model and training) + args: args object (contains info about model and training). """ - # GPU Settings + # GPU Settings. gpu_str = str() for gpu in args.gpu.split(','): gpu_str = gpu_str + str(gpu) + "," @@ -109,7 +109,7 @@ def train_classifier_model(args): use_cuda = torch.cuda.is_available() and (args.gpu != -1) device = torch.device("cuda" if use_cuda else "cpu") - # Instantiate Phoneme Model + # Instantiate Phoneme Model. 
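For orientation, a minimal sketch of the two-stage setup instantiated below, with plain `torch.nn.Linear` layers standing in for the kwscnn blocks: the phoneme stage is frozen and kept out of training mode, and only the keyword classifier's parameters go to the optimizer.
```
import torch

# Plain linear layers standing in for the phoneme and classifier blocks.
phoneme_net = torch.nn.Linear(80, 41)     # 80 features -> 41 phoneme posteriors
keyword_net = torch.nn.Linear(41, 30)     # posteriors -> keyword scores

# Freeze the phoneme stage: no gradients; train(False) keeps BatchNorm and
# Dropout layers of the real blocks in inference mode.
for p in phoneme_net.parameters():
    p.requires_grad = False
phoneme_net.train(False)

# Only the classifier's parameters are optimized.
opt = torch.optim.Adam(keyword_net.parameters(), lr=1e-3)

x = torch.randn(4, 80)
logits = keyword_net(phoneme_net(x))      # frozen posteriors feed the trainable classifier
```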
phoneme_model = kwscnn.DSCNN_RNN_Block(cnn_channels=args.phoneme_cnn_channels, rnn_hidden_size=args.phoneme_rnn_hidden_size, rnn_num_layers=args.phoneme_rnn_layers, @@ -118,12 +118,12 @@ def train_classifier_model(args): bwd_context=args.phoneme_bwd_context, num_labels=args.phoneme_num_labels) - # Freeze Phoneme Model and Deactivate BatchNorm and Dropout Layers + # Freeze Phoneme Model and Deactivate BatchNorm and Dropout Layers. for param in phoneme_model.parameters(): param.requires_grad = False phoneme_model.train(False) - # Instantiate Classifier Model + # Instantiate Classifier Model. classifier_model = kwscnn.Binary_Classification_Block(in_size=args.phoneme_num_labels, rnn_hidden_size=args.classifier_rnn_hidden_size, rnn_num_layers=args.classifier_rnn_num_layers, @@ -131,7 +131,7 @@ def train_classifier_model(args): isBi=args.classifier_isBi, dropout=args.classifier_dropout, num_labels=len(args.words)) - # Transfer to specified device + # Transfer to specified device. phoneme_model.to(device) phoneme_model = torch.nn.DataParallel(phoneme_model) classifier_model.to(device) @@ -139,18 +139,18 @@ def train_classifier_model(args): model = {'name': phoneme_model.module.__name__, 'phoneme': phoneme_model, 'classifier_name': classifier_model.module.__name__, 'classifier': classifier_model} - # Optimizer + # Optimizer. if args.optim == "adam": model['opt'] = torch.optim.Adam(model['classifier'].parameters(), lr=args.lr) if args.optim == "sgd": model['opt'] = torch.optim.SGD(model['classifier'].parameters(), lr=args.lr) - # Load the specified phoneme checkpoint. 'phoneme_model_load_ckpt' must point to a checkpoint and not folder + # Load the specified phoneme checkpoint. 'phoneme_model_load_ckpt' must point to a checkpoint and not folder. if args.phoneme_model_load_ckpt is not None: if os.path.exists(args.phoneme_model_load_ckpt): - # Load Checkpoint + # Load Checkpoint. latest_phoneme_ckpt = torch.load(args.phoneme_model_load_ckpt, map_location=device) - # Load specific state_dicts() and print the latest stats + # Load specific state_dicts() and print the latest stats. print(f"Model Phoneme Location : {args.phoneme_model_load_ckpt}", flush=True) model['phoneme'].load_state_dict(latest_phoneme_ckpt['phoneme_state_dict']) print(f"Checkpoint Stats : {latest_phoneme_ckpt['train_stats']}", flush=True) @@ -159,17 +159,17 @@ def train_classifier_model(args): else: print("No Phoneme Checkpoint Given", flush=True) - # Load the specified classifier checkpoint. 'classifier_model_load_ckpt' must point to a checkpoint and not folder + # Load the specified classifier checkpoint. 'classifier_model_load_ckpt' must point to a checkpoint and not folder. if args.classifier_model_load_ckpt is not None: if os.path.exists(args.classifier_model_load_ckpt): - # Get the number from the classifier checkpoint path - start_epoch = args.classifier_model_load_ckpt # Temporarily store the full ckpt path - start_epoch = start_epoch.split('/')[-1] # retain only the *.pt from the path (Linux) - start_epoch = start_epoch.split('\\')[-1] # retain only the *.pt from the path (Windows) - start_epoch = int(start_epoch.split('.')[0]) # retain the integers - # Load Checkpoint + # Get the number from the classifier checkpoint path. + start_epoch = args.classifier_model_load_ckpt # Temporarily store the full ckpt path. + start_epoch = start_epoch.split('/')[-1] # retain only the *.pt from the path (Linux). + start_epoch = start_epoch.split('\\')[-1] # retain only the *.pt from the path (Windows). 
+ start_epoch = int(start_epoch.split('.')[0]) # retain the integers. + # Load Checkpoint. latest_classifier_ckpt = torch.load(args.classifier_model_load_ckpt, map_location=device) - # Load specific state_dicts() and print the latest stats + # Load specific state_dicts() and print the latest stats. model['classifier'].load_state_dict(latest_classifier_ckpt['classifier_state_dict']) model['opt'].load_state_dict(latest_classifier_ckpt['opt_state_dict']) print(f"Checkpoint Stats : {latest_classifier_ckpt['train_stats']}", flush=True) @@ -178,7 +178,7 @@ def train_classifier_model(args): else: start_epoch = 0 - # Instantiate all Essential Variables and utils + # Instantiate all Essential Variables and utils. train_dataset, test_dataset = get_classification_dataset(args) train_loader = train_dataset.loader test_loader = test_dataset.loader @@ -186,7 +186,7 @@ def train_classifier_model(args): output_frame_rate = 3 save_path = args.classifier_model_save_folder os.makedirs(args.classifier_model_save_folder, exist_ok=True) - # Print for cross-checking + # Print for cross-checking. print(f"Pre Phone List {args.pre_phone_list}", flush=True) print(f"Start Epoch : {start_epoch}", flush=True) print(f"Device : {device}", flush=True) @@ -207,8 +207,8 @@ def train_classifier_model(args): train_seqlen_classifier = train_seqlen_classifier.to(device) model['opt'].zero_grad() - # Data-padding for bricking - train_features = train_features.permute((0, 2, 1)) # NCL to NLC + # Data-padding for bricking. + train_features = train_features.permute((0, 2, 1)) # NCL to NLC. mod_len = train_features.shape[1] pad_len_mod = (output_frame_rate - mod_len % output_frame_rate) % output_frame_rate pad_len_feature = pad_len_mod @@ -218,23 +218,23 @@ def train_classifier_model(args): assert (train_features.shape[1]) % output_frame_rate == 0 - # Get the posterior predictions and trim the labels to the same length as the predictions - train_features = train_features.permute((0, 2, 1)) # NLC to NCL + # Get the posterior predictions and trim the labels to the same length as the predictions. + train_features = train_features.permute((0, 2, 1)) # NLC to NCL. train_posteriors = model['phoneme'](train_features) train_posteriors = model['classifier'](train_posteriors, train_seqlen_classifier) N, L, C = train_posteriors.shape - # Permute and ready the final and pred labels values - train_flat_posteriors = train_posteriors.reshape((-1, C)) # to [NL] x C + # Permute and ready the final and pred labels values. + train_flat_posteriors = train_posteriors.reshape((-1, C)) # to [NL] x C. - # Loss and backward step + # Loss and backward step. train_label = train_label.type(torch.float32) loss_classifier_model = torch.nn.functional.binary_cross_entropy_with_logits(train_flat_posteriors, train_label) loss_classifier_model.backward() torch.nn.utils.clip_grad_norm_(model['classifier'].parameters(), 10.0) model['opt'].step() - # Stats + # Stats. model['train_stats']['loss'] += loss_classifier_model.detach() _, train_idx_pred = torch.max(train_flat_posteriors, dim=1) _, train_idx_label = torch.max(train_label, dim=1) @@ -242,7 +242,7 @@ def train_classifier_model(args): model['train_stats']['total'] += train_idx_label.shape[0] if epoch % args.save_tick == 0: - # Save the model + # Save the model. 
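The epoch bookkeeping above and the `torch.save` call just below assume checkpoints named `<epoch>.pt`. A short sketch of that round trip, using a hypothetical save folder:
```
import os

save_path = "./classifier_model"          # hypothetical save folder
epoch = 42

# Checkpoints are written as '<epoch>.pt' ...
ckpt_path = os.path.join(save_path, f"{epoch}.pt")

# ... so resuming can recover the starting epoch from the file name alone
# (the scripts split on both '/' and '\\' to cover Linux and Windows paths).
start_epoch = int(os.path.splitext(os.path.basename(ckpt_path))[0])
print(ckpt_path, start_epoch)             # ./classifier_model/42.pt 42
```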
torch.save({'classifier_state_dict': model['classifier'].state_dict(), 'opt_state_dict': model['opt'].state_dict(), 'train_stats' : model['train_stats']}, os.path.join(save_path, f'{epoch}.pt')) @@ -261,8 +261,8 @@ def train_classifier_model(args): test_label = test_label.to(device) test_seqlen_classifier = test_seqlen_classifier.to(device) - # Data-padding for bricking - test_features = test_features.permute((0, 2, 1)) # NCL to NLC + # Data-padding for bricking. + test_features = test_features.permute((0, 2, 1)) # NCL to NLC. mod_len = test_features.shape[1] pad_len_mod = (output_frame_rate - mod_len % output_frame_rate) % output_frame_rate pad_len_feature = pad_len_mod @@ -272,16 +272,16 @@ def train_classifier_model(args): assert (test_features.shape[1]) % output_frame_rate == 0 - # Get the posterior predictions and trim the labels to the same length as the predictions - test_features = test_features.permute((0, 2, 1)) # NLC to NCL + # Get the posterior predictions and trim the labels to the same length as the predictions. + test_features = test_features.permute((0, 2, 1)) # NLC to NCL. test_posteriors = model['phoneme'](test_features) test_posteriors = model['classifier'](test_posteriors, test_seqlen_classifier) N, L, C = test_posteriors.shape - # Permute and ready the final and pred labels values - test_flat_posteriors = test_posteriors.reshape((-1, C)) # to [NL] x C + # Permute and ready the final and pred labels values. + test_flat_posteriors = test_posteriors.reshape((-1, C)) # to [NL] x C. - # Stats + # Stats. _, test_idx_pred = torch.max(test_flat_posteriors, dim=1) _, test_idx_label = torch.max(test_label, dim=1) model['test_stats']['correct'] += float(np.sum((test_idx_pred == test_idx_label).detach().cpu().numpy())) diff --git a/applications/KWS_Phoneme/train_phoneme.py b/applications/KWS_Phoneme/train_phoneme.py index 69857f012..44dc36a55 100644 --- a/applications/KWS_Phoneme/train_phoneme.py +++ b/applications/KWS_Phoneme/train_phoneme.py @@ -6,18 +6,18 @@ import re import numpy as np import torch -# Aux scripts +# Aux scripts. import kwscnn import multiprocessing from data_pipe import get_ASR_datasets def parseArgs(): """ - Parse the command line arguments - Describes the architecture and the hyper-parameters + Parse the command line arguments. + Describes the architecture and the hyper-parameters. """ parser = argparse.ArgumentParser() - # Args for Model Traning + # Args for Model Traning. parser.add_argument('--phoneme_model_save_folder', type=str, default='./phoneme_model', help="Folder to save the checkpoint") parser.add_argument('--phoneme_model_load_ckpt', type=str, default=None, help="Checkpoint file to be loaded") parser.add_argument('--optim', type=str, default='adam', help="Optimizer to be used") @@ -26,8 +26,7 @@ def parseArgs(): parser.add_argument('--save_tick', type=int, default=1, help="Number of epochs to wait to save") parser.add_argument('--workers', type=int, default=-1, help="Number of workers. Give -1 for all workers") parser.add_argument("--gpu", type=str, default='0', help="GPU indices Eg: --gpu=0,1,2,3 for 4 gpus. -1 for CPU") - - # Args for DataLoader + # Args for DataLoader. parser.add_argument('--base_path', type=str, required=True, help="Path of the speech data folder. 
The data in this folder should be in accordance to the dataloader code written here.") parser.add_argument('--rir_base_path', type=str, required=True, help="Folder with the reverbration files") parser.add_argument('--additive_base_path', type=str, required=True, help="Folder with additive noise files") @@ -41,7 +40,7 @@ def parseArgs(): parser.add_argument('--rir_chance', type=float, default=0.25, help="Probability of performing reverbration") parser.add_argument('--synth_chance', type=float, default=0.5, help="Probability of pre-processing the signal with noise and reverb") parser.add_argument('--pre_phone_list', action='store_true', help="Use pre-fixed list of phonemes") - # Args for Phoneme + # Args for Phoneme. parser.add_argument('--phoneme_cnn_channels', type=int, default=400, help="Number od channels for the CNN layers") parser.add_argument('--phoneme_rnn_hidden_size', type=int, default=200, help="Number of RNN hidden states") parser.add_argument('--phoneme_rnn_layers', type=int, default=1, help="Number of RNN layers") @@ -53,12 +52,12 @@ def parseArgs(): args = parser.parse_args() - # Parse the gain and SNR values to a float format + # Parse the gain and SNR values to a float format. args.snr_samples = [int(samp) for samp in args.snr_samples.split(',')] args.wgn_snr_samples = [int(samp) for samp in args.wgn_snr_samples.split(',')] args.gain_samples = [float(samp) for samp in args.gain_samples.split(',')] - # Fix the number of workers for the data Loader. If == -1 then use all possible workers + # Fix the number of workers for the data Loader. If == -1 then use all possible workers. if args.workers == -1: args.workers = multiprocessing.cpu_count() @@ -67,13 +66,13 @@ def parseArgs(): def train_phoneme_model(args): """ - Train the Phoneme Model on the designated dataset - The Dataset loader is defined in data_pipe.py - Default dataset used is LibriSpeeech. Change the paths and file reader to change datasets + Train the Phoneme Model on the designated dataset. + The Dataset loader is defined in data_pipe.py. + Default dataset used is LibriSpeeech. Change the paths and file reader to change datasets. - args: args object (contains info about model and training) + args: args object (contains info about model and training). """ - # GPU Settings + # GPU Settings. gpu_str = str() for gpu in args.gpu.split(','): gpu_str = gpu_str + str(gpu) + "," @@ -81,7 +80,7 @@ def train_phoneme_model(args): use_cuda = torch.cuda.is_available() and (args.gpu != -1) device = torch.device("cuda" if use_cuda else "cpu") - # Instantiate model + # Instantiate model. phoneme_model = kwscnn.DSCNN_RNN_Block(cnn_channels=args.phoneme_cnn_channels, rnn_hidden_size=args.phoneme_rnn_hidden_size, rnn_num_layers=args.phoneme_rnn_layers, @@ -90,29 +89,29 @@ def train_phoneme_model(args): bwd_context=args.phoneme_bwd_context, num_labels=args.phoneme_num_labels) - # Transfer to specified device + # Transfer to specified device. phoneme_model.to(device) phoneme_model = torch.nn.DataParallel(phoneme_model) model = {'name': phoneme_model.module.__name__, 'phoneme': phoneme_model} - # Optimizer + # Optimizer. if args.optim == "adam": model['opt'] = torch.optim.Adam(model['phoneme'].parameters(), lr=args.lr) if args.optim == "sgd": model['opt'] = torch.optim.SGD(model['phoneme'].parameters(), lr=args.lr) - # Load the specified checkpoint. 'phoneme_model_load_ckpt' must point to a checkpoint and not folder + # Load the specified checkpoint. 'phoneme_model_load_ckpt' must point to a checkpoint and not folder. 
if args.phoneme_model_load_ckpt is not None: if os.path.exists(args.phoneme_model_load_ckpt): - # Get the number from the phoneme checkpoint path - start_epoch = args.phoneme_model_load_ckpt # Temporarily store the full ckpt path - start_epoch = start_epoch.split('/')[-1] # retain only the *.pt from the path (Linux) - start_epoch = start_epoch.split('\\')[-1] # retain only the *.pt from the path (Windows) - start_epoch = int(start_epoch.split('.')[0]) # retain the integers - # Load Checkpoint + # Get the number from the phoneme checkpoint path. + start_epoch = args.phoneme_model_load_ckpt # Temporarily store the full ckpt path. + start_epoch = start_epoch.split('/')[-1] # retain only the *.pt from the path (Linux). + start_epoch = start_epoch.split('\\')[-1] # retain only the *.pt from the path (Windows). + start_epoch = int(start_epoch.split('.')[0]) # retain the integers. + # Load Checkpoint. latest_ckpt = torch.load(args.phoneme_model_load_ckpt, map_location=device) - # Load specific state_dicts() and print the latest stats + # Load specific state_dicts() and print the latest stats. model['phoneme'].load_state_dict(latest_ckpt['phoneme_state_dict']) model['opt'].load_state_dict(latest_ckpt['opt_state_dict']) print(f"Checkpoint Stats : {latest_ckpt['train_stats']}", flush=True) @@ -121,7 +120,7 @@ def train_phoneme_model(args): else: start_epoch = 0 - # Instantiate dataloaders, essential variables and save folders + # Instantiate dataloaders, essential variables and save folders. train_dataset = get_ASR_datasets(args) train_loader = train_dataset.loader total_batches = len(train_loader) @@ -135,7 +134,7 @@ def train_phoneme_model(args): print(f"Output Frame Rate (multiple of 10ms): {output_frame_rate}", flush=True) print(f"Number of Batches: {total_batches}", flush=True) - # Train Loop + # Train Loop. for epoch in range(start_epoch + 1, args.epochs): model['train_stats'] = {'loss': 0, 'predstd': 0, 'correct': 0, 'valid': 0} for features, label in train_loader: @@ -143,8 +142,8 @@ def train_phoneme_model(args): label = label.to(device) model['opt'].zero_grad() - # Data-padding for bricking - features = features.permute((0, 2, 1)) # NCL to NLC + # Data-padding for bricking. + features = features.permute((0, 2, 1)) # NCL to NLC. mod_len = features.shape[1] pad_len_mod = (output_frame_rate - mod_len % output_frame_rate) % output_frame_rate pad_len_feature = pad_len_mod @@ -153,36 +152,36 @@ def train_phoneme_model(args): features = torch.cat((features, pad_data), dim=1) assert (features.shape[1]) % output_frame_rate == 0 - # Augmenting the label accordingly + # Augmenting the label accordingly. pad_len_label = pad_len_feature pad_data = torch.ones(label.shape[0], pad_len_label).to(device) * (-1) pad_data = pad_data.type(torch.long) label = torch.cat((label, pad_data), dim=1) - # Get the posterior predictions and trim the labels to the same length as the predictions - features = features.permute((0, 2, 1)) # NLC to NCL + # Get the posterior predictions and trim the labels to the same length as the predictions. + features = features.permute((0, 2, 1)) # NLC to NCL. posteriors = model['phoneme'](features) N, C, L = posteriors.shape - trim_label = label[:, ::output_frame_rate] # 30ms frame_rate + trim_label = label[:, ::output_frame_rate] # 30ms frame_rate. 
trim_label = trim_label[:, :L] - # Permute and ready the final and pred labels values - flat_posteriors = posteriors.permute((0, 2, 1)) # TO NLC - flat_posteriors = flat_posteriors.reshape((-1, C)) # to [NL] x C + # Permute and ready the final and pred labels values. + flat_posteriors = posteriors.permute((0, 2, 1)) # TO NLC. + flat_posteriors = flat_posteriors.reshape((-1, C)) # to [NL] x C. flat_labels = trim_label.reshape((-1)) _, idx = torch.max(flat_posteriors, dim=1) correct_count = (idx == flat_labels).detach().sum() valid_count = (flat_labels >= 0).detach().sum() - # Loss and backward step + # Loss and backward step. loss_phoneme_model = torch.nn.functional.cross_entropy(flat_posteriors, flat_labels, ignore_index=-1) loss_phoneme_model.backward() torch.nn.utils.clip_grad_norm_(model['phoneme'].parameters(), 10.0) model['opt'].step() - # Stats + # Stats. pred_std = idx.to(torch.float32).std() model['train_stats']['loss'] += loss_phoneme_model.detach() @@ -191,7 +190,7 @@ def train_phoneme_model(args): model['train_stats']['valid'] += valid_count if epoch % args.save_tick == 0: - # Save the model + # Save the model. torch.save({'phoneme_state_dict': model['phoneme'].state_dict(), 'opt_state_dict': model['opt'].state_dict(), 'train_stats' : model['train_stats']}, os.path.join(save_path, f'{epoch}.pt')) From 8c78dad597eefd3189fef3a424bbedf62259c98e Mon Sep 17 00:00:00 2001 From: Anirudh0707 Date: Fri, 22 Oct 2021 04:51:22 -0700 Subject: [PATCH 8/8] Incorporate reviewer comments --- applications/KWS_Phoneme/data_pipe.py | 6 +++--- applications/KWS_Phoneme/kwscnn.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/applications/KWS_Phoneme/data_pipe.py b/applications/KWS_Phoneme/data_pipe.py index 9439c35de..af11cda12 100644 --- a/applications/KWS_Phoneme/data_pipe.py +++ b/applications/KWS_Phoneme/data_pipe.py @@ -21,12 +21,12 @@ def synthesize_wave(sigx, snr, wgn_snr, gain, do_rir, args): """ Synth Block - Used to process the input audio. The input is convolved with room reverberation recording. - Adds noise in the form of white gaussian noise and regular audio clips (eg:piano, people talking, car engine etc). + Adds noise in the form of white Gaussian noise and regular audio clips (eg:piano, people talking, car engine etc). Input: sigx : input signal to the block. snr : signal-to-noise ratio of the input and additive noise (regular audio). - wg_snr : signal-to-noise ratio of the input and additive noise (white gaussian noise). + wg_snr : signal-to-noise ratio of the input and additive noise (white Gaussian noise). gain : gain of the output signal. do_rir : boolean flag, if reverbration needs to be incorporated. args : args object (contains info about model and training). @@ -74,7 +74,7 @@ def synthesize_wave(sigx, snr, wgn_snr, gain, do_rir, args): noise_scale = y_rmse / noise_rmse * math.pow(10, -snr / 20) sigy = sigy + add_sample * noise_scale - # Only bother with white gasussian noise addition if the WG_SNR is low enough. + # Only bother with white Gaussian noise addition if the WG_SNR is low enough. if wgn_snr < 50: wgn_samps = np.random.normal(size=(len(sigy))).astype(np.float32) noise_scale = y_rmse * math.pow(10, -wgn_snr / 20) diff --git a/applications/KWS_Phoneme/kwscnn.py b/applications/KWS_Phoneme/kwscnn.py index 204eee5df..a80532728 100644 --- a/applications/KWS_Phoneme/kwscnn.py +++ b/applications/KWS_Phoneme/kwscnn.py @@ -49,7 +49,7 @@ def __init__(self): def forward(self, value): """ Applies a custom activation function. 
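The noise mixing in `synthesize_wave` (the data_pipe.py hunk above) scales the additive noise by `rms(signal) / rms(noise) * 10^(-SNR / 20)` so that the requested SNR in dB is met. A standalone sketch of that arithmetic with synthetic signals:
```
import numpy as np

def rms(x):
    return np.sqrt(np.mean(np.square(x)))

rng = np.random.default_rng(0)
signal = rng.standard_normal(16000).astype(np.float32)   # synthetic 1 s clip at 16 kHz
noise = rng.standard_normal(16000).astype(np.float32)
snr_db = 10

# Scale the noise so that rms(signal) / rms(scaled noise) == 10 ** (snr_db / 20).
scale = rms(signal) / rms(noise) * 10 ** (-snr_db / 20)
noisy = signal + scale * noise

achieved = 20 * np.log10(rms(signal) / rms(scale * noise))
print(round(float(achieved), 2))          # ~10.0 dB
```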
- The first half of the channels are passed through sigmoid layer and the next half though a tanh. + The first half of the channels is passed through a sigmoid layer and the next half through a tanh. The outputs are multiplied and returned. Input: