diff --git a/applications/KWS_Phoneme/README.md b/applications/KWS_Phoneme/README.md
new file mode 100644
index 000000000..ab3ba986a
--- /dev/null
+++ b/applications/KWS_Phoneme/README.md
@@ -0,0 +1,51 @@
+# Phoneme-based Keyword Spotting (KWS)
+
+# Project Description
+There are two major issues in existing KWS systems: (a) they are not robust to heavy background noise and random utterances, and (b) they require collecting a lot of data, hampering the ease of adding a new keyword. Tackling these issues from a different perspective, we propose a new two-stage scheme: a model that predicts phonemes, which are in turn used for phoneme-based keyword classification.
+
+First, we train a phoneme classification model that gives the phoneme transcription of the input speech snippet. For training this phoneme classifier, we use a large public speech dataset such as LibriSpeech. The public dataset can be aligned (meaning we can get the phoneme labels for each speech snippet in the data) using the Montreal Forced Aligner. We also add reverberation and additive noise to the speech samples from the public dataset to make the phoneme classifier robust to various accents, background noise and varied environments. In this project, we predict phonemes every 10 ms, which is standard practice. You can find the aligned LibriSpeech dataset we used for training here.
+
+In the second stage, we use the predicted phoneme outputs from the phoneme classifier to predict the keyword. We train a 1-layer FastGRNN classifier that predicts the keyword from the phoneme transcription. Since the phoneme classifier has been trained to account for diverse accents, background noise and environments, the keyword classifier can be trained with a small number of Text-To-Speech (TTS) samples generated using any standard TTS API from cloud services such as Azure, Google Cloud or AWS.
+
+This gives two advantages: (a) the phoneme model is trained to account for diverse accents and background-noise settings, so training the flexible keyword classifier requires only a small number of keyword samples, and (b) empirically, this method was able to detect keywords from as far as 9 ft away. Further, the phoneme model is small, around 250k parameters, and can fit on a Cortex-M7 micro-controller.
+
+# Training the Phoneme Classifier
+1) Train a phoneme classification model on a public speech dataset such as LibriSpeech.
+2) The training speech dataset can be labelled using the Montreal Forced Aligner.
+3) Speech snippets are convolved with reverberation files, and additive noise from YouTube or other open sources is added.
+4) We also add white Gaussian noise at various SNRs.
+
+# Training the KWS Model
+1) Our method takes the speech snippet as input and passes it through the phoneme classifier.
+2) Keywords are detected by training a keyword classifier on the predicted phonemes.
+3) For training the keyword classifier, we use the Azure and Google Text-To-Speech APIs to get the training data (keyword snippets).
+4) For example, if you want to train a keyword classifier for the keywords in the Google30 dataset, generate TTS samples from the Azure/Google-Cloud/AWS API for each of the 30 keywords. The TTS samples for each keyword must be stored in a separate folder named after the keyword, as illustrated below. More details are given in the sample use case for classifier model training.
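+
+For illustration, the classifier data loader expects one subfolder per keyword inside each data folder. A minimal sketch of the expected layout is shown below; the folder names match the sample command in the next section, and the individual .wav file names are arbitrary:
+```
+train_and_test_data_folders/
+    google30_azure_tts/
+        bed/
+            bed_azure_0.wav
+            ...
+        cat/
+            cat_azure_0.wav
+            ...
+    google30_google_tts/
+        ...
+    google30_test/
+        bed/
+        cat/
+        ...
+```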
+
+# Sample Use Cases
+
+## Phoneme Model Training
+The following command can be used to instantiate and train the phoneme model.
+```
+python train_phoneme.py --base_path=/path/to/librispeech_data/ --rir_base_path=/path/to/reverb_files/ --additive_base_path=/path/to/additive_noises/ --snr_samples="0,5,10,25,100,100" --rir_chance=0.5
+```
+Some important command line arguments:
+1) base_path : Path of the speech data folder. The data in this folder should be organized in accordance with the data-loader code written here.
+2) rir_base_path, additive_base_path : Paths to the reverb and additive noise files.
+3) snr_samples : List of SNRs at which the additive noise is to be added.
+4) rir_chance : Probability of applying reverberation to a given speech sample.
+
+## Classifier Model Training
+The following command can be used to instantiate and train the classifier model.
+```
+python train_classifier.py --base_path=/path/to/train_and_test_data_folders/ --train_data_folders=google30_azure_tts,google30_google_tts --test_data_folders=google30_test --phoneme_model_load_ckpt=/path/to/checkpoint/x.pt --rir_base_path=/mnt/reverb_noise_sampled/ --additive_base_path=/mnt/add_noises_sampled/ --synth
+```
+Some important command line arguments:
+
+1) base_path : Path to the train and test data folders.
+2) train_data_folders, test_data_folders : These folders should contain the .wav files for each keyword in a separate subfolder, in accordance with the data-loader here.
+3) phoneme_model_load_ckpt : Full path of the checkpoint file used to load the weights into the instantiated phoneme model.
+4) rir_base_path, additive_base_path : Paths to the reverb and additive noise files.
+5) synth : Boolean flag specifying whether reverberation and noise addition should be applied.
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+Licensed under the MIT license.
\ No newline at end of file
diff --git a/applications/KWS_Phoneme/auxiliary_files/README.md b/applications/KWS_Phoneme/auxiliary_files/README.md
new file mode 100644
index 000000000..6521fed4a
--- /dev/null
+++ b/applications/KWS_Phoneme/auxiliary_files/README.md
@@ -0,0 +1,22 @@
+# Python scripts to help download and down-sample the additive noise data from YouTube videos
+
+Run the following command to download the CSV file that lists the YouTube additive-noise data:
+
+```
+wget http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/balanced_train_segments.csv
+```
+After downloading the CSV file, run the extraction script to download the actual audio data:
+```
+python download_youtube_data.py --csv_file=/path/to/csv_file.csv --target_folder=/path/to/target/folder/
+```
+
+Please check [Google's Audioset data page](https://research.google.com/audioset/download.html) for further details.
+
+The downloaded files need to be converted to 16 kHz for our pipeline. Run the following to do so:
+```
+python convert_sampling_rate.py --source_folder=/path/to/downloaded/wav/folder/ --target_folder=/path/to/target/16KHz_folder/ --fs=16000 --log_rate=100
+```
+The script can convert any .wav file to the sampling rate specified by --fs; for our application we use 16 kHz only. The --log_rate argument controls how often progress is logged: a line is printed every log_rate files.
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+Licensed under the MIT license.
\ No newline at end of file diff --git a/applications/KWS_Phoneme/auxiliary_files/convert_sampling_rate.py b/applications/KWS_Phoneme/auxiliary_files/convert_sampling_rate.py new file mode 100644 index 000000000..1685ab168 --- /dev/null +++ b/applications/KWS_Phoneme/auxiliary_files/convert_sampling_rate.py @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import os +import librosa +import numpy as np +import soundfile as sf +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--source_folder', default=None, required=True) +parser.add_argument('--target_folder', default=None, required=True) +parser.add_argument('--fs', type=int, default=16000) +parser.add_argument('--log_rate', type=int, default=1000) +args = parser.parse_args() + +source_folder = args.source_folder +target_folder = args.target_folder +fs = args.fs +log_rate = args.log_rate +print(f'Source Folder :: {source_folder}\nTarget Folder :: {target_folder}\nSampling Frequency :: {fs}', flush=True) + +source_files = [] +target_files = [] +list_completed = [] + +# Get the list of list of wav files from source folder and create target file names (full paths) +for i, f in enumerate(os.listdir(source_folder)): + if f[-4:].lower() == '.wav': + source_files.append(os.path.join(source_folder, f)) + target_files.append(os.path.join(target_folder, f)) +print(f'Saved all the file paths, Number of files = {len(source_files)}', flush=True) + +# Convert the files to args.fs +# Read with librosa and write the mono channel audio using soundfile +print(f'Converting all files to {fs/1000} Khz', flush=True) +for i, file_path in enumerate(source_files): + y, sr = librosa.load(file_path, sr=fs, mono=True) + sf.write(target_files[i], y, sr) + list_completed.append(target_files[i]) + if i % log_rate == 0: + print(f'File Number {i+1}, Shape of Audio {y.shape}, Sampling Frequency {sr}', flush=True) + +print(f'Number of Files saved {len(list_completed)}') +print('Done', flush=True) diff --git a/applications/KWS_Phoneme/auxiliary_files/download_youtube_data.py b/applications/KWS_Phoneme/auxiliary_files/download_youtube_data.py new file mode 100644 index 000000000..b26efe8e7 --- /dev/null +++ b/applications/KWS_Phoneme/auxiliary_files/download_youtube_data.py @@ -0,0 +1,42 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
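+
+# Overview: each row of the AudioSet CSV gives a YouTube ID plus start and end
+# times in seconds. For every row, this script builds the video link, converts the
+# times to hrs:min:sec, and invokes youtube-dl (which must be installed and on the
+# PATH) to download and trim the segment into a .wav file under --target_folder.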
+ +import csv +import os +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--csv_file', default=None, required=True) +parser.add_argument('--target_folder', default=None, required=True) +args = parser.parse_args() + +with open(args.csv_file, 'r') as csv_f: + reader = csv.reader(csv_f, skipinitialspace=True) + # Skip 3 lines ; Header + next(reader) + next(reader) + next(reader) + for row in reader: + # Logging + print(row, flush=True) + # Link for the Youtube Video + YouTube_ID = row[0] # "-0RWZT-miFs" + start_time = int(float(row[1])) # 420 + end_time = int(float(row[2])) # 430 + # Construct downloadable link + YouTube_link = "https://youtu.be/" + YouTube_ID + # Output Filename + output_file = f"{args.target_folder}/ID_{YouTube_ID}.wav" + # Start time in hrs:min:sec format + start_sec = start_time % 60 + start_min = (start_time // 60) % 60 + start_hrs = start_time // 3600 + # End time in hrs:min:sec format + end_sec = end_time % 60 + end_min = (end_time // 60) % 60 + end_hrs = end_time // 3600 + # Start and End time args + time_args = f"-ss {start_hrs}:{start_min}:{start_sec} -to {end_hrs}:{end_min}:{end_sec}" + # Command Line Execution + os.system(f"youtube-dl -x -q --audio-format wav --postprocessor-args '{time_args}' {YouTube_link}" + " --exec 'mv {} " + f"{output_file}'") + print('', flush=True) diff --git a/applications/KWS_Phoneme/data_pipe.py b/applications/KWS_Phoneme/data_pipe.py new file mode 100644 index 000000000..af11cda12 --- /dev/null +++ b/applications/KWS_Phoneme/data_pipe.py @@ -0,0 +1,534 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import torch +import argparse +import torch.utils.data +import os, glob, random +from collections import Counter +import soundfile as sf +import scipy.signal +import scipy.io.wavfile +import numpy as np +import textgrid +import multiprocessing +from subprocess import call +import librosa +import math +from numpy import random + +def synthesize_wave(sigx, snr, wgn_snr, gain, do_rir, args): + """ + Synth Block - Used to process the input audio. + The input is convolved with room reverberation recording. + Adds noise in the form of white Gaussian noise and regular audio clips (eg:piano, people talking, car engine etc). + + Input: + sigx : input signal to the block. + snr : signal-to-noise ratio of the input and additive noise (regular audio). + wg_snr : signal-to-noise ratio of the input and additive noise (white Gaussian noise). + gain : gain of the output signal. + do_rir : boolean flag, if reverbration needs to be incorporated. + args : args object (contains info about model and training). + + Output: + clipped version of the audio post-processing. + """ + beta = np.random.choice([0.1, 0.25, 0.5, 0.75, 1]) + sigx = beta * sigx + x_power = np.sum(sigx * sigx) + + # Do RIR and normalize back to original power. + if do_rir: + rir_base_path = args.rir_base_path + rir_fname = random.choice(os.listdir(rir_base_path)) + rir_full_fname = rir_base_path + rir_fname + rir_sample, fs = sf.read(rir_full_fname) + if rir_sample.ndim > 1: + rir_sample = rir_sample[:,0] + # We cut the tail of the RIR signal at 99% energy. 
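+            # Truncating at 99% cumulative energy keeps the informative early part of
+            # the impulse response and drops the long low-energy tail, which keeps the
+            # convolution below cheap without noticeably changing the reverberation.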
+ cum_en = np.cumsum(np.power(rir_sample, 2)) + cum_en = cum_en / cum_en[-1] + rir_sample = rir_sample[cum_en <= 0.99] + + max_spike = np.argmax(np.abs(rir_sample)) + sigy = scipy.signal.fftconvolve(sigx, rir_sample)[max_spike:] + sigy = sigy[0:len(sigx)] + + y_power = np.sum(sigy * sigy) + sigy *= math.sqrt(x_power / y_power) # normalize so y has same total power. + else: + sigy = sigx + y_rmse = math.sqrt(x_power / len(sigy)) + + # Only bother with noise addition if the SNR is low enough. + if snr < 50: + add_sample = get_add_noise(args) + noise_rmse = math.sqrt(np.sum(add_sample * add_sample) / len(add_sample)) #+ 0.000000000000000001 + if len(add_sample) < len(sigy): + padded = np.zeros(len(sigy), dtype=np.float32) + padded[0:len(add_sample)] = add_sample + else: + padded = add_sample[0:len(sigy)] + add_sample = padded + noise_scale = y_rmse / noise_rmse * math.pow(10, -snr / 20) + sigy = sigy + add_sample * noise_scale + + # Only bother with white Gaussian noise addition if the WG_SNR is low enough. + if wgn_snr < 50: + wgn_samps = np.random.normal(size=(len(sigy))).astype(np.float32) + noise_scale = y_rmse * math.pow(10, -wgn_snr / 20) + sigy = sigy + wgn_samps * noise_scale + + # Apply gain & clipping. + return np.clip(sigy * gain, -1.0, 1.0) + +def get_add_noise(args): + """ + Extracts the additive noise file from the defined path. + + Input: + args: args object (contains info about model and training). + + Output: + add_sample: additive noise audio. + """ + additive_base_path = args.additive_base_path + add_fname = random.choice(os.listdir(additive_base_path)) + add_full_fname = additive_base_path + add_fname + add_sample, fs = sf.read(add_full_fname) + + return add_sample + +def get_ASR_datasets(args): + """ + Function for preparing the data samples for the phoneme pipeline. + + Input: + args: args object (contains info about model and training). + + Output: + train_dataset: dataset class used for loading the samples into the training pipeline. + """ + base_path = args.base_path + + # Load the speech data. This code snippet (till line 121) depends on the data format in base_path. + train_textgrid_paths = glob.glob(base_path + + "/text/train-clean*/*/*/*.TextGrid") + + train_wav_paths = [path.replace("text", "audio").replace(".TextGrid", ".wav") + for path in train_textgrid_paths] + + if args.pre_phone_list: + # If there is a list of phonemes in the dataset, use this flag. + Sy_phoneme = [] + with open(args.phoneme_text_file, "r") as f: + for line in f.readlines(): + if line.rstrip("\n") != "": Sy_phoneme.append(line.rstrip("\n")) + args.num_phonemes = len(Sy_phoneme) + print("**************", flush=True) + print("Phoneme List", flush=True) + print(Sy_phoneme, flush=True) + print("**************", flush=True) + print("**********************", flush=True) + print("Total Num of Phonemes", flush=True) + print(len(Sy_phoneme), flush=True) + print("**********************", flush=True) + else: + # No list of phonemes specified. Count from the input dataset. + phoneme_counter = Counter() + for path in train_textgrid_paths: + tg = textgrid.TextGrid() + tg.read(path) + phoneme_counter.update([phone.mark.rstrip("0123456789") + for phone in tg.getList("phones")[0] + if phone.mark not in ['', 'sp', 'spn']]) + + # Display and store the phonemes extracted. 
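+        # The phoneme set discovered here is written to args.phoneme_text_file below,
+        # so later runs (e.g. classifier training) can reuse it via --pre_phone_list.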
+ Sy_phoneme = list(phoneme_counter) + args.num_phonemes = len(Sy_phoneme) + print("**************", flush=True) + print("Phoneme List", flush=True) + print(Sy_phoneme, flush=True) + print("**************", flush=True) + print("**********************", flush=True) + print("Total Num of Phonemes", flush=True) + print(len(Sy_phoneme), flush=True) + print("**********************", flush=True) + with open(args.phoneme_text_file, "w") as f: + for phoneme in Sy_phoneme: + f.write(phoneme + "\n") + + print("Data Path Prep Done.", flush=True) + + # Create dataset objects. + train_dataset = ASRDataset(train_wav_paths, train_textgrid_paths, Sy_phoneme, args) + + return train_dataset + +class ASRDataset(torch.utils.data.Dataset): + def __init__(self, wav_paths, textgrid_paths, Sy_phoneme, args): + """ + Dataset iterator for the phoneme detection model. + + Input: + wav_paths : list of strings (wav file paths). + textgrid_paths : list of strings (textgrid for each wav file). + Sy_phoneme : list of strings (all possible phonemes). + args : args object (contains info about model and training). + """ + self.wav_paths = wav_paths + self.textgrid_paths = textgrid_paths + self.length_mean = args.pretraining_length_mean + self.length_var = args.pretraining_length_var + self.Sy_phoneme = Sy_phoneme + self.args = args + # Dataset Loader for the iterator. + self.loader = torch.utils.data.DataLoader(self, batch_size=args.pretraining_batch_size, + num_workers=args.workers, shuffle=True, + collate_fn=CollateWavsASR()) + + def __len__(self): + """ + Number of audio samples available. + """ + return len(self.wav_paths) + + def __getitem__(self, idx): + """ + Gives one sample from the dataset. Data is read in this snippet. + (refer to the collate function for pre-processing). + + Input: + idx: index for the sample. + + Output: + x : audio sample obtained from the synth block (if used, else input audio) after time-domain clipping. + y_phoneme : the output phonemes sampled at 30ms. + """ + x, fs = sf.read(self.wav_paths[idx]) + + tg = textgrid.TextGrid() + tg.read(self.textgrid_paths[idx]) + + y_phoneme = [] + for phoneme in tg.getList("phones")[0]: + duration = phoneme.maxTime - phoneme.minTime + phoneme_index = self.Sy_phoneme.index(phoneme.mark.rstrip("0123456789")) if phoneme.mark.rstrip("0123456789") in self.Sy_phoneme else -1 + if phoneme.mark == '': phoneme_index = -1 + y_phoneme += [phoneme_index] * round(duration * fs) + + # Cut a snippet of length random_length from the audio. + random_length = round(fs * (self.length_mean + self.length_var * torch.randn(1).item())) + if len(x) <= random_length: + start = 0 + else: + start = torch.randint(low=0, high=len(x)-random_length, size=(1,)).item() + end = start + random_length + + x = x[start:end] + + if np.random.random() < self.args.synth_chance: + x = synthesize_wave(x, np.random.choice(self.args.snr_samples), + np.random.choice(self.args.wgn_snr_samples), np.random.choice(self.args.gain_samples), + np.random.random() < self.args.rir_chance, self.args) + + self.phone_downsample_factor = 160 + y_phoneme = y_phoneme[start:end:self.phone_downsample_factor] + + # feature = librosa.feature.mfcc(x,sr=16000,n_mfcc=80,win_length=25*16,hop_length=10*16) + + return (x, y_phoneme) + +class CollateWavsASR: + def __call__(self, batch): + """ + Pre-processing and padding, followed by batching the set of inputs. + + Input: + batch: list of tuples (input wav, phoneme labels). + + Output: + feature_tensor : the melspectogram features of the input audio. 
The features are padded for batching. + y_phoneme_tensor : the phonemes sequences in a tensor format. The phoneme sequences are padded for batching. + """ + x = []; y_phoneme = [] + batch_size = len(batch) + for index in range(batch_size): + x_,y_phoneme_, = batch[index] + + x.append(x_) + y_phoneme.append(y_phoneme_) + + # pad all sequences to have same length and get features. + features=[] + T = max([len(x_) for x_ in x]) + U_phoneme = max([len(y_phoneme_) for y_phoneme_ in y_phoneme]) + for index in range(batch_size): + # pad audio to same length for all the audio samples in the batch. + x_pad_length = (T - len(x[index])) + x[index] = np.pad(x[index], (x_pad_length,0), 'constant', constant_values=(0, 0)) + + # Extract Mel-Spectogram from padded audio. + feature = librosa.feature.melspectrogram(y=x[index],sr=16000,n_mels=80, + win_length=25*16,hop_length=10*16, n_fft=512) + + feature = librosa.core.power_to_db(feature) + # Normalize the features. + max_value = np.max(feature) + min_value = np.min(feature) + feature = (feature - min_value) / (max_value - min_value) + features.append(feature) + + # Pad the labels to same length for all samples in the batch. + y_pad_length = (U_phoneme - len(y_phoneme[index])) + y_phoneme[index] = np.pad(y_phoneme[index], (y_pad_length,0), 'constant', constant_values=(-1, -1)) + + features_tensor = []; y_phoneme_tensor = [] + batch_size = len(batch) + for index in range(batch_size): + # x_,y_phoneme_, = batch[index] + x_ = features[index] + y_phoneme_ = y_phoneme[index] + + features_tensor.append(torch.tensor(x_).float()) + y_phoneme_tensor.append(torch.tensor(y_phoneme_).long()) + + features_tensor = torch.stack(features_tensor) + y_phoneme_tensor = torch.stack(y_phoneme_tensor) + + return (features_tensor,y_phoneme_tensor) + +def get_classification_dataset(args): + """ + Function for preparing the data samples for the classification pipeline. + + Input: + args: args object (contains info about model and training). + + Output: + train_dataset : dataset class used for loading the samples into the training pipeline. + test_dataset : dataset class used for loading the samples into the testing pipeline. + """ + base_path = args.base_path + + # Train Data. + train_wav_paths = [] + train_labels = [] + + # We assign data_folder_list = ["google30_train"] or ["google30_azure_tts", "google30_google_tts"]. + data_folder_list = args.train_data_folders + for data_folder in data_folder_list: + # For each of the folder, iterate through the words and get the files and the labels. + for (label, word) in enumerate(args.words): + curr_word_files = glob.glob(base_path + f"/{data_folder}/" + word + "/*.wav") + train_wav_paths += curr_word_files + train_labels += [label]*len(curr_word_files) + print(f"Number of Train Files {len(train_wav_paths)}", flush=True) + temp = list(zip(train_wav_paths, train_labels)) + random.shuffle(temp) + train_wav_paths, train_labels = zip(*temp) + print(f"Train Data Folders Used {data_folder_list}", flush=True) + # Create dataset objects. + train_dataset = ClassificationDataset(wav_paths=train_wav_paths, labels=train_labels, args=args, is_train=True) + + # Test Data. + test_wav_paths = [] + test_labels = [] + + # We assign data_folder_list = ["google30_test"]. + data_folder_list = args.test_data_folders + for data_folder in data_folder_list: + # For each of the folder, iterate through the words and get the files and the labels. 
+ for (label, word) in enumerate(args.words): + curr_word_files = glob.glob(base_path + f"/{data_folder}/" + word + "/*.wav") + test_wav_paths += curr_word_files + test_labels += [label]*len(curr_word_files) + print(f"Number of Test Files {len(test_wav_paths)}", flush=True) + temp = list(zip(test_wav_paths, test_labels)) + random.shuffle(temp) + test_wav_paths, test_labels = zip(*temp) + print(f"Test Data Folders Used {data_folder_list}", flush=True) + # Create dataset objects. + test_dataset = ClassificationDataset(wav_paths=test_wav_paths, labels=test_labels, args=args, is_train=False) + + return train_dataset, test_dataset + +class ClassificationDataset(torch.utils.data.Dataset): + def __init__(self, wav_paths, labels, args, is_train=True): + """ + Dataset iterator for the classifier model. + + Input: + wav_paths : list of strings (wav file paths). + labels : list of classification labels for the corresponding audio wav files. + is_train : boolean flag, if the dataset loader is for the train or test pipeline. + args : args object (contains info about model and training). + """ + self.wav_paths = wav_paths + self.labels = labels + self.args = args + self.is_train = is_train + self.loader = torch.utils.data.DataLoader(self, batch_size=args.pretraining_batch_size, + num_workers=args.workers, shuffle=is_train, + collate_fn=CollateWavsClassifier()) + + def __len__(self): + """ + Number of audio samples available. + """ + return len(self.wav_paths) + + def one_hot_encoder(self, lab): + """ + Label index to one-hot encoder. + + Input: + lab: label index. + + Output: + one_hot: label in the one-hot format. + """ + one_hot = np.zeros(len(self.args.words)) + one_hot[lab]=1 + return one_hot + + def __getitem__(self, idx): + """ + Gives one sample from the dataset. Data is read in this snippet. (refer to the collate function for pre-processing). + + Input: + idx: index for the sample. + + Output: + x : audio sample obtained from the synth block (if used, else input audio) after time-domain clipping. + one_hot_label : one-hot encoded label. + seqlen : length of the audio file. + This value will be dropped and seqlen after feature extraction will be used. Refer to the collate function. + """ + x, fs = sf.read(self.wav_paths[idx]) + + label = self.labels[idx] + one_hot_label = self.one_hot_encoder(label) + seqlen = len(x) + + if self.is_train: + # Use synth only for train files + if self.args.synth: + if np.random.random() < self.args.synth_chance: + x = synthesize_wave(x, np.random.choice(self.args.snr_samples), + np.random.choice(self.args.wgn_snr_samples), np.random.choice(self.args.gain_samples), + np.random.random() < self.args.rir_chance, self.args) + + return (x, one_hot_label, seqlen) + +class CollateWavsClassifier: + def __call__(self, batch): + """ + Pre-processing and padding, followed by batching the set of inputs. + + Input: + batch: list of tuples (input wav, one hot classification label, sequence length). + + Output: + feature_tensor : the melspectogram features of the input audio. The features are padded for batching. + one_hot_label_tensor : the on-hot label in a tensor format. + seqlen_tensor : the sequence length of the features in a minibatch. + """ + x = []; one_hot_label = []; seqlen = [] + batch_size = len(batch) + for index in range(batch_size): + x_,one_hot_label_,_ = batch[index] + x.append(x_) + one_hot_label.append(one_hot_label_) + + # pad all sequences to have same length and get features. 
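+        # Clips are front-padded to the longest clip in the batch, with a floor of
+        # 48000 samples (3 s at 16 kHz), so every batch has a minimum feature length.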
+ features=[] + T = max([len(x_) for x_ in x]) + T = max([T, 48000]) + for index in range(batch_size): + # pad audio to same length for all the audio samples in the batch. + x_pad_length = (T - len(x[index])) + x[index] = np.pad(x[index], (x_pad_length,0), 'constant', constant_values=(0, 0)) + + # Extract Mel-Spectogram from padded audio. + feature = librosa.feature.melspectrogram(y=x[index],sr=16000,n_mels=80,win_length=25*16,hop_length=10*16, n_fft=512) + feature = librosa.core.power_to_db(feature) + # Normalize the features. + max_value = np.max(feature) + min_value = np.min(feature) + if min_value == max_value: + feature = feature - min_value + else: + feature = (feature - min_value) / (max_value - min_value) + features.append(feature) + seqlen.append(feature.shape[1]) + + features_tensor = []; one_hot_label_tensor = []; seqlen_tensor = [] + batch_size = len(batch) + for index in range(batch_size): + x_ = features[index] + one_hot_label_ = one_hot_label[index] + seqlen_ = seqlen[index] + + features_tensor.append(torch.tensor(x_).float()) + one_hot_label_tensor.append(torch.tensor(one_hot_label_).long()) + seqlen_tensor.append(torch.tensor(seqlen_)) + + features_tensor = torch.stack(features_tensor) + one_hot_label_tensor = torch.stack(one_hot_label_tensor) + seqlen_tensor = torch.stack(seqlen_tensor) + + return (features_tensor,one_hot_label_tensor,seqlen_tensor) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--base_path', type=str, required=True, help="Path of the speech data folder. The data in this folder should be in accordance to the dataloader code written here.") + parser.add_argument('--train_data_folders', type=str, default="google30_train", help="List of training folders in base path. Each folder is a dataset in the prescribed format") + parser.add_argument('--test_data_folders', type=str, default="google30_test", help="List of testing folders in base path. Each folder is a dataset in the prescribed format") + parser.add_argument('--rir_base_path', type=str, required=True, help="Folder with the reverbration files") + parser.add_argument('--additive_base_path', type=str, required=True, help="Folder with additive noise files") + parser.add_argument('--phoneme_text_file', type=str, required=True, help="Text files with pre-fixed phons") + parser.add_argument('--workers', type=int, default=-1, help="Number of workers. Give -1 for all workers") + parser.add_argument("--word_model_name", default='google30', help="Name of the word list chosen. Will be used in conjunction with the data loader") + parser.add_argument('--words', type=str, default="all") + parser.add_argument("--synth", action='store_true', help="Use Synth block or not") + parser.add_argument('--pretraining_length_mean', type=int, default=9) + parser.add_argument('--pretraining_length_var', type=int, default=1) + parser.add_argument('--pretraining_batch_size', type=int, default=64) + parser.add_argument('--snr_samples', type=str, default="0,5,10,25,100,100") + parser.add_argument('--wgn_snr_samples', type=str, default="5,10,15,100,100") + parser.add_argument('--gain_samples', type=str, default="1.0,0.25,0.5,0.75") + parser.add_argument('--rir_chance', type=float, default=0.25) + parser.add_argument('--synth_chance', type=float, default=0.5) + parser.add_argument('--pre_phone_list', action='store_true') + args = parser.parse_args() + + # SNRs. 
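+    # Parse the comma-separated SNR/gain strings into numeric lists. Entries of 100
+    # (or anything >= 50) effectively disable that noise source for a given draw,
+    # since synthesize_wave only adds noise when the sampled SNR is below 50.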
+ args.snr_samples = [int(samp) for samp in args.snr_samples.split(',')] + args.wgn_snr_samples = [int(samp) for samp in args.wgn_snr_samples.split(',')] + args.gain_samples = [float(samp) for samp in args.gain_samples.split(',')] + + # Workers. + if args.workers == -1: + args.workers = multiprocessing.cpu_count() + + # Words. + if args.word_model_name == 'google30': + args.words = ["bed", "bird", "cat", "dog", "down", "eight", "five", "four", "go", + "happy", "house", "left", "marvin", "nine", "no", "off", "on", "one", "right", + "seven", "sheila", "six", "stop", "three", "tree", "two", "up", "wow", "yes", "zero" + ] + elif args.word_model_name == 'google10': + args.words = ["yes", "no", "up", "down", "left", "right", "on", "off", + "stop", "go", "allsilence", "unknown"] + else: + raise ValueError('Incorrect Word Model Name') + + # Data Folders. + args.train_data_folders = [folder_idx for folder_idx in args.train_data_folders.split(',')] + args.test_data_folders = [folder_idx for folder_idx in args.test_data_folders.split(',')] + + print(args.pre_phone_list, flush=True) + # get_ASR_datasets(args) + dset = get_classification_dataset(args) diff --git a/applications/KWS_Phoneme/kwscnn.py b/applications/KWS_Phoneme/kwscnn.py new file mode 100644 index 000000000..a80532728 --- /dev/null +++ b/applications/KWS_Phoneme/kwscnn.py @@ -0,0 +1,673 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import torch +import torch.nn as nn +from torch.autograd import Variable +import torch.nn.functional as F +from torch.nn.parameter import Parameter +from utils import _single +import math +import numpy as np +from edgeml_pytorch.graph.rnn import * + +class _IndexSelect(torch.nn.Module): + def __init__(self, channels, direction, groups): + """ + Channel permutation module. The purpose of this is to allow mixing across the CNN groups. + """ + super(_IndexSelect, self).__init__() + + if channels % groups != 0: + raise ValueError('Channels should be a multiple of the groups') + + self._index = torch.zeros((channels), dtype=torch.int64) + count = 0 + + if direction > 0: + for gidx in range(groups): + for nidx in range(gidx, channels, groups): + self._index[count] = nidx + count += 1 + else: + for gidx in range(groups): + for nidx in range(gidx, channels, groups): + self._index[nidx] = count + count += 1 + + def forward(self, value): + if value.device != self._index.device: + self._index = self._index.to(value.device) + + return torch.index_select(value, 1, self._index) + + +class _TanhGate(torch.nn.Module): + def __init__(self): + super(_TanhGate, self).__init__() + + def forward(self, value): + """ + Applies a custom activation function. + The first half of the channels are passed through sigmoid layer and the next half through a tanh. + The outputs are multiplied and returned. + + Input: + value: A tensor of shape (batch, channels, *). + + Output: + activation output of shape (batch, channels/2, *). 
+ """ + channels = value.shape[1] + piv = int(channels/2) + + sig_data = value[:, 0:piv, :] + tanh_data = value[:, piv:, :] + + sig_data = torch.sigmoid(sig_data) + tanh_data = torch.tanh(tanh_data) + return sig_data * tanh_data + + +class LR_conv(nn.Conv1d): + + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, groups=1, + bias=True, padding_mode='zeros', rank=50): + super().__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode) + """ + A convolution layer with the weight matrix subjected to a low-rank decomposition. Currently for kernel size of 5. + + Input: + rank : The rank used for the low-rank decomposition on the weight/kernel tensor. + All other parameters are similar to that of a convolution layer. + Only change is the decomposition of the output channels into low-rank tensors. + """ + self.kernel_size = kernel_size + self.rank = rank + self.W1 = Parameter(torch.Tensor(self.out_channels, rank)) + # As per PyTorch Standard + nn.init.kaiming_uniform_(self.W1, a=math.sqrt(5)) + self.W2 = Parameter(torch.Tensor(rank, self.in_channels * self.kernel_size)) + nn.init.kaiming_uniform_(self.W2, a=math.sqrt(5)) + self.weight = None + + def forward(self, input): + """ + The decomposed weights are multiplied to enforce the low-rank constraint. + The conv1d is performed as usual post multiplication. + + Input: + input: Input of shape similar to that of which is fed to a conv layer. + + Output: + convolution output. + """ + lr_weight = torch.matmul(self.W1, self.W2) + lr_weight = torch.reshape(lr_weight, (self.out_channels, self.in_channels, self.kernel_size)) + if self.padding_mode != 'zeros': + return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + lr_weight, self.bias, self.stride, + _single(0), self.dilation, self.groups) + return F.conv1d(input, lr_weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + +class PreRNNConvBlock(torch.nn.Module): + def __init__( + self, in_channels, out_channels, kernel, + stride=1, groups=1, avg_pool=2, dropout=0.1, + batch_norm=0.1, shuffle=0, + activation='sigmoid', rank=50): + super(PreRNNConvBlock, self).__init__() + """ + A low-rank convolution layer combination with pooling and activation layers. Currently for kernel size of 5. + + Input: + in_channels : number of input channels for the conv layer. + out_channels : number of output channels for the conv layer. + kernel : conv kernel size. + stride : conv stride. + groups : number of groups for conv layer. + avg_pool : kernel size for average pooling layer. + dropout : dropout layer probability. + batch_norm : momentum for batch norm. + activation : activation layer. + rank : rank for low-rank decomposition for conv layer weights. 
+ """ + activators = { + 'sigmoid': torch.nn.Sigmoid(), + 'relu': torch.nn.ReLU(), + 'leakyrelu': torch.nn.LeakyReLU(), + 'tanhgate': _TanhGate(), + 'none': None + } + + if activation not in activators: + raise ValueError('Available activations are: %s' % ', '.join(activators.keys())) + + if activation == 'tanhgate': + in_channels = int(in_channels/2) + + nonlin = activators[activation] + + if batch_norm > 0.0: + batch_block = torch.nn.BatchNorm1d(in_channels, affine=False, momentum=batch_norm) + else: + batch_block = None + + depth_cnn = None + point_cnn = LR_conv(in_channels, out_channels, kernel_size=kernel, stride=stride, groups=groups, rank=rank, padding=2) + + if shuffle != 0 and groups > 1: + shuffler = _IndexSelect(in_channels, shuffle, groups) + else: + shuffler = None + + if avg_pool > 0: + pool = torch.nn.AvgPool1d(kernel_size=avg_pool, stride=1) + else: + pool = None + + if dropout > 0: + dropout_block = torch.nn.Dropout(p=dropout) + else: + dropout_block = None + + seq1 = [nonlin, batch_block, depth_cnn, shuffler, point_cnn, dropout_block, pool] + seq_f1 = [item for item in seq1 if item is not None] + if len(seq_f1) == 1: + self._op1 = seq_f1[0] + else: + self._op1 = torch.nn.Sequential(*seq_f1) + + def forward(self, x): + """ + Apply the set of layers initialized in __init__. + + Input: + x: A tensor of shape (batch, channels, length). + + Output: + network block output of shape (batch, channels, length). + """ + x = self._op1(x) + return x + +class DSCNNBlockLR(torch.nn.Module): + def __init__( + self, in_channels, out_channels, kernel, + stride=1, groups=1, avg_pool=2, dropout=0.1, + batch_norm=0.1, shuffle=0, + activation='sigmoid', rank=50): + super(DSCNNBlockLR, self).__init__() + """ + A depthwise separable low-rank convolution layer combination with pooling and activation layers. + + Input: + in_channels : number of input channels for the pointwise conv layer. + out_channels : number of output channels for the pointwise conv layer. + kernel : conv kernel size for depthwise layer. + stride : conv stride. + groups : number of groups for conv layer. + avg_pool : kernel size for average pooling layer. + dropout : dropout layer probability. + batch_norm : momentum for batch norm. + activation : activation layer. + rank : rank for low-rank decomposition for conv layer weights. 
+ """ + activators = { + 'sigmoid': torch.nn.Sigmoid(), + 'relu': torch.nn.ReLU(), + 'leakyrelu': torch.nn.LeakyReLU(), + 'tanhgate': _TanhGate(), + 'none': None + } + + if activation not in activators: + raise ValueError('Available activations are: %s' % ', '.join(activators.keys())) + + if activation == 'tanhgate': + in_channels = int(in_channels/2) + + nonlin = activators[activation] + + if batch_norm > 0.0: + batch_block = torch.nn.BatchNorm1d(in_channels, affine=False, momentum=batch_norm) + else: + batch_block = None + + depth_cnn = torch.nn.Conv1d(in_channels, in_channels, kernel_size=kernel, stride=1, groups=in_channels, padding=2) + point_cnn = LR_conv(in_channels, out_channels, kernel_size=1, stride=stride, groups=groups, rank=rank) + + if shuffle != 0 and groups > 1: + shuffler = _IndexSelect(in_channels, shuffle, groups) + else: + shuffler = None + + if avg_pool > 0: + pool = torch.nn.AvgPool1d(kernel_size=avg_pool, stride=1) + else: + pool = None + + if dropout > 0: + dropout_block = torch.nn.Dropout(p=dropout) + else: + dropout_block = None + + seq = [nonlin, batch_block, depth_cnn, shuffler, point_cnn, dropout_block, pool] + seq_f = [item for item in seq if item is not None] + if len(seq_f) == 1: + self._op = seq_f[0] + else: + self._op = torch.nn.Sequential(*seq_f) + + def forward(self, x): + """ + Apply the set of layers initialized in __init__. + + Input: + x: A tensor of shape (batch, channels, length). + + Output: + network block output of shape (batch, channels, length). + """ + x = self._op(x) + return x + +class BiFastGRNN(nn.Module): + """ + Bi Directional FastGRNN. + + Parameters and arguments are similar to the torch RNN counterparts. + """ + def __init__(self, inputDims, hiddenDims, gate_nonlinearity, + update_nonlinearity, rank): + super(BiFastGRNN, self).__init__() + + self.cell_fwd = FastGRNNCUDA(inputDims, + hiddenDims, + gate_nonlinearity, + update_nonlinearity, + batch_first=True, + wRank=rank, + uRank=rank) + + self.cell_bwd = FastGRNNCUDA(inputDims, + hiddenDims, + gate_nonlinearity, + update_nonlinearity, + batch_first=True, + wRank=rank, + uRank=rank) + + def forward(self, input_f, input_b): + """ + Pass the inputs to forward and backward layers. + Please note the backward layer is similar to the forward layer and the input needs to be fed reversed accordingly. + Tensors are of the shape (batch, length, channels). + + Input: + input_f : input to the forward layer. + input_b : input to the backward layer. Input needs to be reversed before passing through this forward method. + + Output: + output1 : output of the forward layer. + output2 : output of the backward layer. + """ + # Bidirectional FastGRNN. + output1 = self.cell_fwd(input_f) + output2 = self.cell_bwd(input_b) + #Returning the flipped output only for the bwd pass. + #Will align it in the post processing. + return output1, output2 + + +def X_preRNN_process(X, fwd_context, bwd_context): + """ + A depthwise separable low-rank convolution layer combination with pooling and activation layers. + + Input: + in_channels : number of input channels for the pointwise conv layer. + out_channels : number of output channels for the pointwise conv layer. + kernel : conv kernel size for depthwise layer. + stride : conv stride. + groups : number of groups for conv layer. + avg_pool : kernel size for average pooling layer. + dropout : dropout layer probability. + batch_norm : momentum for batch norm. + activation : activation layer. + rank : rank for low-rank decomposition for conv layer weights. 
+ """ + # Forward bricking. + brickLength = fwd_context + hopLength = 3 + X_bricked_f1 = X.unfold(1, brickLength, hopLength) + X_bricked_f2 = X_bricked_f1.permute(0, 1, 3, 2) + # X_bricked_f [batch, num_bricks, brickLen, inpDim]. + oldShape_f = X_bricked_f2.shape + X_bricked_f = torch.reshape( + X_bricked_f2, [oldShape_f[0] * oldShape_f[1], oldShape_f[2], -1]) + # X_bricked_f [batch*num_bricks, brickLen, inpDim]. + + # Backward bricking. + brickLength = bwd_context + hopLength = 3 + X_bricked_b = X.unfold(1, brickLength, hopLength) + X_bricked_b = X_bricked_b.permute(0, 1, 3, 2) + # X_bricked_b [batch, num_bricks, brickLen, inpDim]. + oldShape_b = X_bricked_b.shape + X_bricked_b = torch.reshape( + X_bricked_b, [oldShape_b[0] * oldShape_b[1], oldShape_b[2], -1]) + # X_bricked_b [batch*num_bricks, brickLen, inpDim]. + return X_bricked_f, oldShape_f, X_bricked_b, oldShape_b + + +def X_postRNN_process(X_f, oldShape_f, X_b, oldShape_b): + """ + A depthwise separable low-rank convolution layer combination with pooling and activation layers. + + Input: + in_channels : number of input channels for the pointwise conv layer. + out_channels : number of output channels for the pointwise conv layer. + kernel : conv kernel size for depthwise layer. + stride : conv stride. + groups : number of groups for conv layer. + avg_pool : kernel size for average pooling layer. + dropout : dropout layer probability. + batch_norm : momentum for batch norm. + activation : activation layer. + rank : rank for low-rank decomposition for conv layer weights. + """ + # Forward bricks folding. + X_f = torch.reshape(X_f, [oldShape_f[0], oldShape_f[1], oldShape_f[2], -1]) + # batch, num_bricks, brickLen, hiddenDim. + X_new_f = X_f[:, 0, ::3, :] + # batch, brickLen, hiddenDim. + X_new_f_rest = X_f[:, :, -1, :].squeeze(2) + # batch, numBricks - 1, hiddenDim. + shape = X_new_f_rest.shape + X_new_f = torch.cat((X_new_f, X_new_f_rest), dim=1) + # batch, seqLen, hiddenDim. + + # Backward Bricks folding. + X_b = torch.reshape(X_b, [oldShape_b[0], oldShape_b[1], oldShape_b[2], -1]) + # batch, num_bricks, brickLen, hiddenDim. + X_b = torch.flip(X_b, [1]) + # Reverse the ordering of the bricks (bring last brick to start). + X_new_b = X_b[:, 0, ::3, :] + # batch, brickLen, inpDim. + X_new_b_rest = X_b[:, :, -1, :].squeeze(2) + # batch, seqlen - brickLen, hiddenDim. + X_new_b = torch.cat((X_new_b, X_new_b_rest), dim=1) + # batch, seqLen, hiddenDim. + X_new_b = torch.flip(X_new_b, [1]) + # inverting the flip operation. + X_new = torch.cat((X_new_f, X_new_b), dim=2) + # batch, seqLen, 2 * hiddenDim. + return X_new + + +class DSCNN_RNN_Block(torch.nn.Module): + def __init__(self, cnn_channels, rnn_hidden_size, rnn_num_layers, + device, gate_nonlinearity="sigmoid", update_nonlinearity="tanh", + isBi=True, num_labels=41, rank=None, fwd_context=15, bwd_context=9): + super(DSCNN_RNN_Block, self).__init__() + """ + A depthwise separable low-rank convolution layer combination with pooling and activation layers. + + Input: + cnn_channels : number of the output channels for the first CNN block. + rnn_hidden_size : hidden dimensions of the FastGRNN. + rnn_num_layers : number of FastGRNN layers. + device : device on which the tensors would placed. + gate_nonlinearity : activation function for the gating in the FastGRNN. + update_nonlinearity : activation function for the update function in the FastGRNN. + isBi : boolean flag to use bi-directional FastGRNN. + fwd_context : window for the forward pass. + bwd_context : window for the backward pass. 
+ """ + self.cnn_channels = cnn_channels + self.rnn_hidden_size = rnn_hidden_size + self.rnn_num_layers = rnn_num_layers + self.gate_nonlinearity = gate_nonlinearity + self.update_nonlinearity = update_nonlinearity + self.num_labels = num_labels + self.device = device + self.fwd_context = fwd_context + self.bwd_context = bwd_context + self.isBi = isBi + if self.isBi: + self.direction_param = 2 + else: + self.direction_param = 1 + + self.declare_network(cnn_channels, rnn_hidden_size, rnn_num_layers, + num_labels, rank) + + self.__name__ = 'DSCNN_RNN_Block' + + def declare_network(self, cnn_channels, rnn_hidden_size, rnn_num_layers, + num_labels, rank): + """ + Declare the netwok layers. + Arguments can be inferred from the __init__. + """ + self.CNN1 = torch.nn.Sequential( + PreRNNConvBlock(80, cnn_channels, 5, 1, 1, + 0, 0, batch_norm=1e-2, + activation='none', rank=rank)) + + #torch tanh layer directly in the forward pass + self.bnorm_rnn = torch.nn.BatchNorm1d(cnn_channels, affine=False, momentum=1e-2) + self.RNN0 = BiFastGRNN(cnn_channels, rnn_hidden_size, + self.gate_nonlinearity, + self.update_nonlinearity, rank) + + self.CNN2 = DSCNNBlockLR(2 * rnn_hidden_size, + 2 * rnn_hidden_size, + batch_norm=1e-2, + dropout=0, + kernel=5, + activation='tanhgate', rank=rank) + + self.CNN3 = DSCNNBlockLR(2 * rnn_hidden_size, + 2 * rnn_hidden_size, + batch_norm=1e-2, + dropout=0, + kernel=5, + activation='tanhgate', rank=rank) + + self.CNN4 = DSCNNBlockLR(2 * rnn_hidden_size, + 2 * rnn_hidden_size, + batch_norm=1e-2, + dropout=0, + kernel=5, + activation='tanhgate', rank=rank) + + self.CNN5 = DSCNNBlockLR(2 * rnn_hidden_size, + num_labels, + batch_norm=1e-2, + dropout=0, + kernel=5, + activation='tanhgate', rank=rank) + + def forward(self, features): + """ + Apply the set of layers initialized in __init__. + + Input: + features: a tensor of shape (batch, channels, length). + + Output: + network block output in the form (batch, channels, length). + """ + batch, _, max_seq_len = features.shape + X = self.CNN1(features) # Down to 30ms inference / 250ms window + X = torch.tanh(X) + X = self.bnorm_rnn(X) + X = X.permute((0, 2, 1)) # NCL to NLC + + X = X.contiguous() + assert X.shape[1] % 3 == 0 + X_f, oldShape_f, X_b, oldShape_b = X_preRNN_process(X, self.fwd_context, self.bwd_context) + #X [batch * num_bricks, brickLen, inpDim] + X_b_f = torch.flip(X_b, [1]) + + X_f, X_b = self.RNN0(X_f, X_b_f) + + X = X_postRNN_process(X_f, oldShape_f, X_b, oldShape_b) + + # re-permute to get [batch, channels, max_seq_len/3 ] + X = X.permute((0, 2, 1)) # NLC to NCL + X = self.CNN2(X) + X = self.CNN3(X) + X = self.CNN4(X) + X = self.CNN5(X) + return X + + +class Binary_Classification_Block(torch.nn.Module): + def __init__(self, in_size, rnn_hidden_size, rnn_num_layers, + device, islstm=False, isBi=True, momentum=1e-2, + num_labels=2, dropout=0, batch_assertion=False): + super(Binary_Classification_Block, self).__init__() + """ + A depthwise separable low-rank convolution layer combination with pooling and activation layers. + + Input: + in_size : number of input channels to the layer. + rnn_hidden_size : hidden dimensions of the RNN layer. + rnn_num_layers : number of layers for the RNN layer. + device : device on which the tensors would placed. + islstm : boolean flag to use the LSTM. False would use GRU. + isBi : boolean flag to use the bi-directional variant of the RNN. + momentum : momentum for the batch-norm layer. + num_labels : number of output labels. + dropout : probability for the dropout layer. 
+ """ + self.in_size = in_size + self.rnn_hidden_size = rnn_hidden_size + self.rnn_num_layers = rnn_num_layers + self.num_labels = num_labels + self.device = device + self.islstm = islstm + self.isBi = isBi + self.momentum = momentum + self.dropout = dropout + self.batch_assertion = batch_assertion + + if self.isBi: + self.direction_param = 2 + else: + self.direction_param = 1 + + self.declare_network(in_size, rnn_hidden_size, rnn_num_layers, num_labels) + + self.__name__ = 'Binary_Classification_Block_2lay' + + def declare_network(self, in_size, rnn_hidden_size, rnn_num_layers, num_labels): + """ + Declare the netwok layers. + Arguments can be inferred from the __init__. + """ + self.CNN1 = torch.nn.Sequential( + torch.nn.LeakyReLU(negative_slope=0.01), + torch.nn.BatchNorm1d(in_size, affine=False, + momentum=self.momentum), + torch.nn.Dropout(self.dropout)) + + if self.islstm: + self.RNN = nn.LSTM(input_size=in_size, + hidden_size=rnn_hidden_size, + num_layers=rnn_num_layers, + batch_first=True, + bidirectional=self.isBi) + else: + self.RNN = nn.GRU(input_size=in_size, + hidden_size=rnn_hidden_size, + num_layers=rnn_num_layers, + batch_first=True, + bidirectional=self.isBi) + + self.FCN = torch.nn.Sequential( + torch.nn.Dropout(self.dropout), + torch.nn.LeakyReLU(negative_slope=0.01), + torch.nn.Linear(self.direction_param * self.rnn_hidden_size, + num_labels)) + + def forward(self, features, seqlen): + """ + Apply the set of layers initialized in __init__. + + Input: + features: A tensor of shape (batch, channels, length). + + Output: + network block output in the form (batch, length, channels). length will be 1. + """ + batch, _, _ = features.shape + + if self.islstm: + hidden1 = self.init_hidden(batch, self.rnn_hidden_size, + self.rnn_num_layers) + hidden2 = self.init_hidden(batch, self.rnn_hidden_size, + self.rnn_num_layers) + else: + hidden = self.init_hidden(batch, self.rnn_hidden_size, + self.rnn_num_layers) + + X = self.CNN1(features) # Down to 30ms inference / 250ms window. + + X = X.permute((0, 2, 1)) # NCL to NLC. + + max_seq_len = X.shape[1] + + # modify seqlen. + max_seq_len = min(torch.max(seqlen).item(), max_seq_len) + seqlen = torch.clamp(seqlen, max=max_seq_len) + self.seqlen = seqlen + + # pad according to seqlen. + X = torch.nn.utils.rnn.pack_padded_sequence(X, + seqlen, + batch_first=True, + enforce_sorted=False) + + self.RNN.flatten_parameters() + if self.islstm: + X, (hh, _) = self.RNN(X, (hidden1, hidden2)) + else: + X, hh = self.RNN(X, hidden) + + X, _ = torch.nn.utils.rnn.pad_packed_sequence(X, batch_first=True) + + X = X.view(batch, max_seq_len, + self.direction_param * self.rnn_hidden_size) + + X = X[torch.arange(batch).long(), + seqlen.long() - 1, :].view( + batch, 1, self.direction_param * self.rnn_hidden_size) + + X = self.FCN(X) + + return X + + def init_hidden(self, batch, rnn_hidden_size, rnn_num_layers): + """ + Used to initialize the first hidden state of the RNN. It is currently zero. The user is free to edit this function for additional analysis. + """ + # the weights are of the form (batch, num_layers * num_directions , hidden_size). 
+ if self.batch_assertion: + hidden = torch.zeros(rnn_num_layers * self.direction_param, batch, + rnn_hidden_size) + else: + hidden = torch.zeros(rnn_num_layers * self.direction_param, batch, + rnn_hidden_size) + + hidden = hidden.to(self.device) + + hidden = Variable(hidden) + + return hidden diff --git a/applications/KWS_Phoneme/train_classifier.py b/applications/KWS_Phoneme/train_classifier.py new file mode 100644 index 000000000..bd8a59b35 --- /dev/null +++ b/applications/KWS_Phoneme/train_classifier.py @@ -0,0 +1,296 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import argparse +import os +import re +import numpy as np +import torch +# Aux scripts. +import kwscnn +import multiprocessing +from data_pipe import get_ASR_datasets, get_classification_dataset + +def parseArgs(): + """ + Parse the command line arguments + Describes the architecture and the hyper-parameters + """ + parser = argparse.ArgumentParser() + # Args for Model Traning. + parser.add_argument('--phoneme_model_load_ckpt', type=str, required=True, help="Phoneme checkpoint file to be loaded") + parser.add_argument('--classifier_model_save_folder', type=str, default='./classifier_model', help="Folder to save the classifier checkpoint") + parser.add_argument('--classifier_model_load_ckpt', type=str, default=None, help="Classifier checkpoint to be loaded") + parser.add_argument('--optim', type=str, default='adam', help="Optimizer to be used") + parser.add_argument('--lr', type=float, default=0.001, help="Optimizer learning rate") + parser.add_argument('--epochs', type=int, default=200, help="Number of epochs for training") + parser.add_argument('--save_tick', type=int, default=1, help="Number of epochs to wait to save") + parser.add_argument('--workers', type=int, default=-1, help="Number of workers. Give -1 for all workers") + parser.add_argument("--gpu", type=str, default='0', help="GPU indices Eg: --gpu=0,1,2,3 for 4 gpus. -1 for CPU") + parser.add_argument("--word_model_name", default='google30', help="Name of the list of words used") + parser.add_argument('--words', type=str, help="List of words to be used. This will be assigned in the code. User input will not affect the result") + parser.add_argument("--is_training", action='store_true', help="True for training") + parser.add_argument("--synth", action='store_true', help="Use Synth block or not") + # Args for DataLoader. + parser.add_argument('--base_path', type=str, required=True, help="path to train and test data folders") + parser.add_argument('--train_data_folders', type=str, default="google30_train", help="List of training folders in base path. Each folder is a dataset in the prescribed format") + parser.add_argument('--test_data_folders', type=str, default="google30_test", help="List of testing folders in base path. 
Each folder is a dataset in the prescribed format") + parser.add_argument('--rir_base_path', type=str, required=True, help="Folder with the reverbration files") + parser.add_argument('--additive_base_path', type=str, required=True, help="Folder with additive noise files") + parser.add_argument('--phoneme_text_file', type=str, help="Text files with pre-fixed phons") + parser.add_argument('--pretraining_length_mean', type=int, default=6, help="Mean of the audio clips lengths") + parser.add_argument('--pretraining_length_var', type=int, default=1, help="variance of the audio clip lengths") + parser.add_argument('--pretraining_batch_size', type=int, default=256, help="Batch size for the pipeline") + parser.add_argument('--snr_samples', type=str, default="-5,0,0,5,10,15,40,100,100", help="SNR values for additive noise files") + parser.add_argument('--wgn_snr_samples', type=str, default="5,10,20,40,60", help="SNR values for white gaussian noise") + parser.add_argument('--gain_samples', type=str, default="1.0,0.25,0.5,0.75", help="Gain values for processed signal") + parser.add_argument('--rir_chance', type=float, default=0.9, help="Probability of performing reverbration") + parser.add_argument('--synth_chance', type=float, default=0.9, help="Probability of pre-processing the input with reverb and noise") + parser.add_argument('--pre_phone_list', action='store_true', help="use pre-fixed set of phonemes") + # Args for Phoneme. + parser.add_argument('--phoneme_cnn_channels', type=int, default=400, help="Number od channels for the CNN layers") + parser.add_argument('--phoneme_rnn_hidden_size', type=int, default=200, help="Number of RNN hidden states") + parser.add_argument('--phoneme_rnn_layers', type=int, default=1, help="Number of RNN layers") + parser.add_argument('--phoneme_rank', type=int, default=50, help="Rank of the CNN layers weights") + parser.add_argument('--phoneme_fwd_context', type=int, default=15, help="RNN forward window context") + parser.add_argument('--phoneme_bwd_context', type=int, default=9, help="RNN backward window context") + parser.add_argument('--phoneme_phoneme_isBi', action='store_true', help="Use Bi-Directional RNN") + parser.add_argument('--phoneme_num_labels', type=int, default=41, help="Number og phoneme labels") + # Args for Classifier. + parser.add_argument('--classifier_rnn_hidden_size', type=int, default=100, help="Classifier RNN hidden dimensions") + parser.add_argument('--classifier_rnn_num_layers', type=int, default=1, help="Classifier RNN number of layers") + parser.add_argument('--classifier_dropout', type=float, default=0.2, help="Classifier dropout layer probability") + parser.add_argument('--classifier_islstm', action='store_true', help="Use LSTM in the classifier") + parser.add_argument('--classifier_isBi', action='store_true', help="Use Bi-Directional RNN in classifier") + + args = parser.parse_args() + + # Parse the gain and SNR values to a float format. + args.snr_samples = [int(samp) for samp in args.snr_samples.split(',')] + args.wgn_snr_samples = [int(samp) for samp in args.wgn_snr_samples.split(',')] + args.gain_samples = [float(samp) for samp in args.gain_samples.split(',')] + + # Fix the number of workers for the data Loader. If == -1 then use all possible workers. + if args.workers == -1: + args.workers = multiprocessing.cpu_count() + + # Choose the word list to be used. For custom word lists, please add an elif condition. 
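+    # Hypothetical example of a custom list:
+    #   elif args.word_model_name == 'my_custom_list':
+    #       args.words = ["lights", "fan", "door"]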
+ if args.word_model_name == 'google30': + args.words = ["bed", "bird", "cat", "dog", "down", "eight", "five", "four", "go", + "happy", "house", "left", "marvin", "nine", "no", "off", "on", "one", "right", + "seven", "sheila", "six", "stop", "three", "tree", "two", "up", "wow", "yes", "zero"] + elif args.word_model_name == 'google10': + args.words = ["yes", "no", "up", "down", "left", "right", "on", "off", + "stop", "go", "allsilence", "unknown"] + else: + raise ValueError('Incorrect Word Model Name') + + # The data-folder in args.base_path that contain the data. + # Refer to data_pipe.py for loading format. + args.train_data_folders = [folder_idx for folder_idx in args.train_data_folders.split(',')] + args.test_data_folders = [folder_idx for folder_idx in args.test_data_folders.split(',')] + + print(f"Args : {args}", flush=True) + return args + +def train_classifier_model(args): + """ + Train the Classifier Model on the designated dataset. + The Dataset loader is defined in data_pipe.py. + Default dataset used is Google30. Change the paths and file reader to change datasets. + + args: args object (contains info about model and training). + """ + # GPU Settings. + gpu_str = str() + for gpu in args.gpu.split(','): + gpu_str = gpu_str + str(gpu) + "," + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str + use_cuda = torch.cuda.is_available() and (args.gpu != -1) + device = torch.device("cuda" if use_cuda else "cpu") + + # Instantiate Phoneme Model. + phoneme_model = kwscnn.DSCNN_RNN_Block(cnn_channels=args.phoneme_cnn_channels, + rnn_hidden_size=args.phoneme_rnn_hidden_size, + rnn_num_layers=args.phoneme_rnn_layers, + device=device, rank=args.phoneme_rank, + fwd_context=args.phoneme_fwd_context, + bwd_context=args.phoneme_bwd_context, + num_labels=args.phoneme_num_labels) + + # Freeze Phoneme Model and Deactivate BatchNorm and Dropout Layers. + for param in phoneme_model.parameters(): + param.requires_grad = False + phoneme_model.train(False) + + # Instantiate Classifier Model. + classifier_model = kwscnn.Binary_Classification_Block(in_size=args.phoneme_num_labels, + rnn_hidden_size=args.classifier_rnn_hidden_size, + rnn_num_layers=args.classifier_rnn_num_layers, + device=device, islstm=args.classifier_islstm, + isBi=args.classifier_isBi, dropout=args.classifier_dropout, + num_labels=len(args.words)) + + # Transfer to specified device. + phoneme_model.to(device) + phoneme_model = torch.nn.DataParallel(phoneme_model) + classifier_model.to(device) + classifier_model = torch.nn.DataParallel(classifier_model) + model = {'name': phoneme_model.module.__name__, 'phoneme': phoneme_model, + 'classifier_name': classifier_model.module.__name__, 'classifier': classifier_model} + + # Optimizer. + if args.optim == "adam": + model['opt'] = torch.optim.Adam(model['classifier'].parameters(), lr=args.lr) + if args.optim == "sgd": + model['opt'] = torch.optim.SGD(model['classifier'].parameters(), lr=args.lr) + + # Load the specified phoneme checkpoint. 'phoneme_model_load_ckpt' must point to a checkpoint and not folder. + if args.phoneme_model_load_ckpt is not None: + if os.path.exists(args.phoneme_model_load_ckpt): + # Load Checkpoint. + latest_phoneme_ckpt = torch.load(args.phoneme_model_load_ckpt, map_location=device) + # Load specific state_dicts() and print the latest stats. 
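+ # Note: checkpoints written by train_phoneme.py are dictionaries of the form + # {'phoneme_state_dict': ..., 'opt_state_dict': ..., 'train_stats': ...}; only the + # 'phoneme_state_dict' entry is consumed here, since the phoneme weights stay frozen and + # the phoneme optimizer state is not needed for classifier training.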
+ print(f"Model Phoneme Location : {args.phoneme_model_load_ckpt}", flush=True) + model['phoneme'].load_state_dict(latest_phoneme_ckpt['phoneme_state_dict']) + print(f"Checkpoint Stats : {latest_phoneme_ckpt['train_stats']}", flush=True) + else: + raise ValueError("Invalid Phoneme Checkpoint Path") + else: + print("No Phoneme Checkpoint Given", flush=True) + + # Load the specified classifier checkpoint. 'classifier_model_load_ckpt' must point to a checkpoint and not folder. + if args.classifier_model_load_ckpt is not None: + if os.path.exists(args.classifier_model_load_ckpt): + # Get the number from the classifier checkpoint path. + start_epoch = args.classifier_model_load_ckpt # Temporarily store the full ckpt path. + start_epoch = start_epoch.split('/')[-1] # retain only the *.pt from the path (Linux). + start_epoch = start_epoch.split('\\')[-1] # retain only the *.pt from the path (Windows). + start_epoch = int(start_epoch.split('.')[0]) # retain the integers. + # Load Checkpoint. + latest_classifier_ckpt = torch.load(args.classifier_model_load_ckpt, map_location=device) + # Load specific state_dicts() and print the latest stats. + model['classifier'].load_state_dict(latest_classifier_ckpt['classifier_state_dict']) + model['opt'].load_state_dict(latest_classifier_ckpt['opt_state_dict']) + print(f"Checkpoint Stats : {latest_classifier_ckpt['train_stats']}", flush=True) + else: + raise ValueError("Invalid Classifier Checkpoint Path") + else: + start_epoch = 0 + + # Instantiate all Essential Variables and utils. + train_dataset, test_dataset = get_classification_dataset(args) + train_loader = train_dataset.loader + test_loader = test_dataset.loader + total_batches = len(train_loader) + output_frame_rate = 3 + save_path = args.classifier_model_save_folder + os.makedirs(args.classifier_model_save_folder, exist_ok=True) + # Print for cross-checking. + print(f"Pre Phone List {args.pre_phone_list}", flush=True) + print(f"Start Epoch : {start_epoch}", flush=True) + print(f"Device : {device}", flush=True) + print(f"Output Frame Rate (multiple of 10ms): {output_frame_rate}", flush=True) + print(f"Number of Batches: {total_batches}", flush=True) + print(f"Synth: {args.synth}", flush=True) + print(f"Words: {args.words}", flush=True) + print(f"Optimizer : {model['opt']}", flush=True) + + # Train Loop + for epoch in range(start_epoch + 1, args.epochs): + model['train_stats'] = {'loss': 0, 'correct': 0, 'total': 0} + model['classifier'].train(True) + for train_features, train_label, train_seqlen in train_loader: + train_seqlen_classifier = train_seqlen.clone() / output_frame_rate + train_features = train_features.to(device) + train_label = train_label.to(device) + train_seqlen_classifier = train_seqlen_classifier.to(device) + model['opt'].zero_grad() + + # Data-padding for bricking. + train_features = train_features.permute((0, 2, 1)) # NCL to NLC. + mod_len = train_features.shape[1] + pad_len_mod = (output_frame_rate - mod_len % output_frame_rate) % output_frame_rate + pad_len_feature = pad_len_mod + pad_data = torch.zeros(train_features.shape[0], pad_len_feature, + train_features.shape[2]).to(device) + train_features = torch.cat((train_features, pad_data), dim=1) + + assert (train_features.shape[1]) % output_frame_rate == 0 + + # Get the posterior predictions and trim the labels to the same length as the predictions. + train_features = train_features.permute((0, 2, 1)) # NLC to NCL. 
+ train_posteriors = model['phoneme'](train_features) + train_posteriors = model['classifier'](train_posteriors, train_seqlen_classifier) + N, L, C = train_posteriors.shape + + # Permute and ready the final and pred labels values. + train_flat_posteriors = train_posteriors.reshape((-1, C)) # to [NL] x C. + + # Loss and backward step. + train_label = train_label.type(torch.float32) + loss_classifier_model = torch.nn.functional.binary_cross_entropy_with_logits(train_flat_posteriors, train_label) + loss_classifier_model.backward() + torch.nn.utils.clip_grad_norm_(model['classifier'].parameters(), 10.0) + model['opt'].step() + + # Stats. + model['train_stats']['loss'] += loss_classifier_model.detach() + _, train_idx_pred = torch.max(train_flat_posteriors, dim=1) + _, train_idx_label = torch.max(train_label, dim=1) + model['train_stats']['correct'] += float(np.sum((train_idx_pred == train_idx_label).detach().cpu().numpy())) + model['train_stats']['total'] += train_idx_label.shape[0] + + if epoch % args.save_tick == 0: + # Save the model. + torch.save({'classifier_state_dict': model['classifier'].state_dict(), + 'opt_state_dict': model['opt'].state_dict(), 'train_stats' : model['train_stats']}, + os.path.join(save_path, f'{epoch}.pt')) + + avg_ce = model['train_stats']['loss'].cpu() / total_batches + train_accuracy = 100.0 * model['train_stats']['correct'] / model['train_stats']['total'] + print(f"Summary for Epoch {epoch} for Model {model['classifier_name']}; Loss: {avg_ce}", flush=True) + print(f"TRAIN => Accuracy: {train_accuracy}; Correct: {model['train_stats']['correct']}; Total: {model['train_stats']['total']}", flush=True) + + model['test_stats'] = {'correct': 0,'total': 0} + model['classifier'].eval() + with torch.no_grad(): + for test_features, test_label, test_seqlen in test_loader: + test_seqlen_classifier = test_seqlen.clone() / output_frame_rate + test_features = test_features.to(device) + test_label = test_label.to(device) + test_seqlen_classifier = test_seqlen_classifier.to(device) + + # Data-padding for bricking. + test_features = test_features.permute((0, 2, 1)) # NCL to NLC. + mod_len = test_features.shape[1] + pad_len_mod = (output_frame_rate - mod_len % output_frame_rate) % output_frame_rate + pad_len_feature = pad_len_mod + pad_data = torch.zeros(test_features.shape[0], pad_len_feature, + test_features.shape[2]).to(device) + test_features = torch.cat((test_features, pad_data), dim=1) + + assert (test_features.shape[1]) % output_frame_rate == 0 + + # Get the posterior predictions and trim the labels to the same length as the predictions. + test_features = test_features.permute((0, 2, 1)) # NLC to NCL. + test_posteriors = model['phoneme'](test_features) + test_posteriors = model['classifier'](test_posteriors, test_seqlen_classifier) + N, L, C = test_posteriors.shape + + # Permute and ready the final and pred labels values. + test_flat_posteriors = test_posteriors.reshape((-1, C)) # to [NL] x C. + + # Stats. 
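+ # The predicted keyword is the argmax over the per-class scores, and the target is the argmax of the + # (assumed one-hot) label tensor; the test accuracy below is the fraction of clips where the two agree.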
+ _, test_idx_pred = torch.max(test_flat_posteriors, dim=1) + _, test_idx_label = torch.max(test_label, dim=1) + model['test_stats']['correct'] += float(np.sum((test_idx_pred == test_idx_label).detach().cpu().numpy())) + model['test_stats']['total'] += test_idx_label.shape[0] + + test_accuracy = 100.0 * model['test_stats']['correct'] / model['test_stats']['total'] + print(f"TEST => Accuracy: {test_accuracy}; Correct: {model['test_stats']['correct']}; Total: {model['test_stats']['total']}", flush=True) + return + +if __name__ == '__main__': + args = parseArgs() + train_classifier_model(args) diff --git a/applications/KWS_Phoneme/train_phoneme.py b/applications/KWS_Phoneme/train_phoneme.py new file mode 100644 index 000000000..44dc36a55 --- /dev/null +++ b/applications/KWS_Phoneme/train_phoneme.py @@ -0,0 +1,211 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import argparse +import os +import re +import numpy as np +import torch +# Aux scripts. +import kwscnn +import multiprocessing +from data_pipe import get_ASR_datasets + +def parseArgs(): + """ + Parse the command line arguments. + Describes the architecture and the hyper-parameters. + """ + parser = argparse.ArgumentParser() + # Args for Model Training. + parser.add_argument('--phoneme_model_save_folder', type=str, default='./phoneme_model', help="Folder to save the checkpoint") + parser.add_argument('--phoneme_model_load_ckpt', type=str, default=None, help="Checkpoint file to be loaded") + parser.add_argument('--optim', type=str, default='adam', help="Optimizer to be used") + parser.add_argument('--lr', type=float, default=0.001, help="Optimizer learning rate") + parser.add_argument('--epochs', type=int, default=200, help="Number of epochs for training") + parser.add_argument('--save_tick', type=int, default=1, help="Number of epochs to wait between checkpoint saves") + parser.add_argument('--workers', type=int, default=-1, help="Number of workers. Give -1 for all workers") + parser.add_argument("--gpu", type=str, default='0', help="GPU indices, e.g.: --gpu=0,1,2,3 for 4 GPUs. -1 for CPU") + # Args for DataLoader. + parser.add_argument('--base_path', type=str, required=True, help="Path of the speech data folder.
The data in this folder should be in accordance with the dataloader code written here.") + parser.add_argument('--rir_base_path', type=str, required=True, help="Folder with the reverberation files") + parser.add_argument('--additive_base_path', type=str, required=True, help="Folder with additive noise files") + parser.add_argument('--phoneme_text_file', type=str, required=True, help="Text file with the pre-fixed phonemes") + parser.add_argument('--pretraining_length_mean', type=int, default=6, help="Mean of the audio clip lengths") + parser.add_argument('--pretraining_length_var', type=int, default=1, help="Variance of the audio clip lengths") + parser.add_argument('--pretraining_batch_size', type=int, default=256, help="Batch size for the pipeline") + parser.add_argument('--snr_samples', type=str, default="0,5,10,25,100,100", help="SNR values for additive noise files") + parser.add_argument('--wgn_snr_samples', type=str, default="5,10,15,100,100", help="SNR values for white Gaussian noise") + parser.add_argument('--gain_samples', type=str, default="1.0,0.25,0.5,0.75", help="Gain values for the processed signal") + parser.add_argument('--rir_chance', type=float, default=0.25, help="Probability of performing reverberation") + parser.add_argument('--synth_chance', type=float, default=0.5, help="Probability of pre-processing the signal with noise and reverb") + parser.add_argument('--pre_phone_list', action='store_true', help="Use pre-fixed list of phonemes") + # Args for Phoneme. + parser.add_argument('--phoneme_cnn_channels', type=int, default=400, help="Number of channels for the CNN layers") + parser.add_argument('--phoneme_rnn_hidden_size', type=int, default=200, help="Number of RNN hidden states") + parser.add_argument('--phoneme_rnn_layers', type=int, default=1, help="Number of RNN layers") + parser.add_argument('--phoneme_rank', type=int, default=50, help="Rank of the CNN layer weights") + parser.add_argument('--phoneme_fwd_context', type=int, default=15, help="RNN forward window context") + parser.add_argument('--phoneme_bwd_context', type=int, default=9, help="RNN backward window context") + parser.add_argument('--phoneme_phoneme_isBi', action='store_true', help="Use Bi-Directional RNN") + parser.add_argument('--phoneme_num_labels', type=int, default=41, help="Number of phoneme labels") + + args = parser.parse_args() + + # Parse the gain and SNR strings into numeric lists. + args.snr_samples = [int(samp) for samp in args.snr_samples.split(',')] + args.wgn_snr_samples = [int(samp) for samp in args.wgn_snr_samples.split(',')] + args.gain_samples = [float(samp) for samp in args.gain_samples.split(',')] + + # Fix the number of workers for the data loader. If == -1 then use all possible workers. + if args.workers == -1: + args.workers = multiprocessing.cpu_count() + + print(f"Args : {args}", flush=True) + return args + +def train_phoneme_model(args): + """ + Train the Phoneme Model on the designated dataset. + The Dataset loader is defined in data_pipe.py. + Default dataset used is LibriSpeech. Change the paths and file reader to change datasets. + + args: args object (contains info about model and training). + """ + # GPU Settings. + gpu_str = str() + for gpu in args.gpu.split(','): + gpu_str = gpu_str + str(gpu) + "," + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str + use_cuda = torch.cuda.is_available() and (args.gpu != -1) + device = torch.device("cuda" if use_cuda else "cpu") + + # Instantiate model.
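+ # Note: train_classifier.py re-instantiates this same DSCNN_RNN_Block from its own --phoneme_* flags + # before loading the checkpoint saved by this script, so the phoneme hyper-parameters (CNN channels, + # RNN hidden size, rank, forward/backward contexts, number of labels) should be kept identical across the two runs.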
+ phoneme_model = kwscnn.DSCNN_RNN_Block(cnn_channels=args.phoneme_cnn_channels, + rnn_hidden_size=args.phoneme_rnn_hidden_size, + rnn_num_layers=args.phoneme_rnn_layers, + device=device, rank=args.phoneme_rank, + fwd_context=args.phoneme_fwd_context, + bwd_context=args.phoneme_bwd_context, + num_labels=args.phoneme_num_labels) + + # Transfer to specified device. + phoneme_model.to(device) + phoneme_model = torch.nn.DataParallel(phoneme_model) + model = {'name': phoneme_model.module.__name__, 'phoneme': phoneme_model} + + + # Optimizer. + if args.optim == "adam": + model['opt'] = torch.optim.Adam(model['phoneme'].parameters(), lr=args.lr) + if args.optim == "sgd": + model['opt'] = torch.optim.SGD(model['phoneme'].parameters(), lr=args.lr) + + # Load the specified checkpoint. 'phoneme_model_load_ckpt' must point to a checkpoint and not folder. + if args.phoneme_model_load_ckpt is not None: + if os.path.exists(args.phoneme_model_load_ckpt): + # Get the number from the phoneme checkpoint path. + start_epoch = args.phoneme_model_load_ckpt # Temporarily store the full ckpt path. + start_epoch = start_epoch.split('/')[-1] # retain only the *.pt from the path (Linux). + start_epoch = start_epoch.split('\\')[-1] # retain only the *.pt from the path (Windows). + start_epoch = int(start_epoch.split('.')[0]) # retain the integers. + # Load Checkpoint. + latest_ckpt = torch.load(args.phoneme_model_load_ckpt, map_location=device) + # Load specific state_dicts() and print the latest stats. + model['phoneme'].load_state_dict(latest_ckpt['phoneme_state_dict']) + model['opt'].load_state_dict(latest_ckpt['opt_state_dict']) + print(f"Checkpoint Stats : {latest_ckpt['train_stats']}", flush=True) + else: + raise ValueError("Invalid Checkpoint Path") + else: + start_epoch = 0 + + # Instantiate dataloaders, essential variables and save folders. + train_dataset = get_ASR_datasets(args) + train_loader = train_dataset.loader + total_batches = len(train_loader) + output_frame_rate = 3 + save_path = args.phoneme_model_save_folder + os.makedirs(args.phoneme_model_save_folder, exist_ok=True) + + print(f"Pre Phone List {args.pre_phone_list}", flush=True) + print(f"Start Epoch : {start_epoch}", flush=True) + print(f"Device : {device}", flush=True) + print(f"Output Frame Rate (multiple of 10ms): {output_frame_rate}", flush=True) + print(f"Number of Batches: {total_batches}", flush=True) + + # Train Loop. + for epoch in range(start_epoch + 1, args.epochs): + model['train_stats'] = {'loss': 0, 'predstd': 0, 'correct': 0, 'valid': 0} + for features, label in train_loader: + features = features.to(device) + label = label.to(device) + model['opt'].zero_grad() + + # Data-padding for bricking. + features = features.permute((0, 2, 1)) # NCL to NLC. + mod_len = features.shape[1] + pad_len_mod = (output_frame_rate - mod_len % output_frame_rate) % output_frame_rate + pad_len_feature = pad_len_mod + pad_data = torch.zeros(features.shape[0], pad_len_feature, + features.shape[2]).to(device) + features = torch.cat((features, pad_data), dim=1) + + assert (features.shape[1]) % output_frame_rate == 0 + # Augmenting the label accordingly. + pad_len_label = pad_len_feature + pad_data = torch.ones(label.shape[0], pad_len_label).to(device) * (-1) + pad_data = pad_data.type(torch.long) + label = torch.cat((label, pad_data), dim=1) + + # Get the posterior predictions and trim the labels to the same length as the predictions. + features = features.permute((0, 2, 1)) # NLC to NCL. 
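+ # Worked example of the padding above and the label trimming that follows (illustrative numbers): + # for 100 input frames, pad_len_feature = (3 - 100 % 3) % 3 = 2, so features are zero-padded and labels + # padded with the ignore index -1 to 102 frames; assuming the model emits one posterior per + # output_frame_rate (= 3) input frames (i.e. every 30ms), label[:, ::3][:, :34] yields the 34 labels + # aligned with the 34 posterior frames.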
+ posteriors = model['phoneme'](features) + N, C, L = posteriors.shape + trim_label = label[:, ::output_frame_rate] # 30ms frame_rate. + trim_label = trim_label[:, :L] + + # Permute and ready the final and pred labels values. + flat_posteriors = posteriors.permute((0, 2, 1)) # TO NLC. + flat_posteriors = flat_posteriors.reshape((-1, C)) # to [NL] x C. + flat_labels = trim_label.reshape((-1)) + + _, idx = torch.max(flat_posteriors, dim=1) + correct_count = (idx == flat_labels).detach().sum() + valid_count = (flat_labels >= 0).detach().sum() + + # Loss and backward step. + loss_phoneme_model = torch.nn.functional.cross_entropy(flat_posteriors, + flat_labels, ignore_index=-1) + loss_phoneme_model.backward() + torch.nn.utils.clip_grad_norm_(model['phoneme'].parameters(), 10.0) + model['opt'].step() + + # Stats. + pred_std = idx.to(torch.float32).std() + + model['train_stats']['loss'] += loss_phoneme_model.detach() + model['train_stats']['correct'] += correct_count + model['train_stats']['predstd'] += pred_std + model['train_stats']['valid'] += valid_count + + if epoch % args.save_tick == 0: + # Save the model. + torch.save({'phoneme_state_dict': model['phoneme'].state_dict(), + 'opt_state_dict': model['opt'].state_dict(), + 'train_stats' : model['train_stats']}, os.path.join(save_path, f'{epoch}.pt')) + + valid_frames = model['train_stats']['valid'].cpu() + correct_frames = model['train_stats']['correct'].cpu().to(torch.float32) + epoch_prestd = model['train_stats']['predstd'].cpu() + + avg_ce = model['train_stats']['loss'].cpu() / total_batches + avg_err = 100 - 100.0 * (correct_frames / valid_frames) + + print(f"Summary for Epoch {epoch} for Model {model['name']}", flush=True) + print(f"CE: {avg_ce}, ERR: {avg_err}, FRAMES {correct_frames} / {valid_frames}, PREDSTD: {epoch_prestd / total_batches}", flush=True) + return + +if __name__ == '__main__': + args = parseArgs() + train_phoneme_model(args) diff --git a/applications/KWS_Phoneme/utils.py b/applications/KWS_Phoneme/utils.py new file mode 100644 index 000000000..a267a14b0 --- /dev/null +++ b/applications/KWS_Phoneme/utils.py @@ -0,0 +1,43 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +""" +Some Torch Functions required for overwriting +Convolutions to Low Rank Convolutions +""" + +from typing import List + +from torch._six import container_abcs +from itertools import repeat + + +def _ntuple(n): + def parse(x): + if isinstance(x, container_abcs.Iterable): + return x + return tuple(repeat(x, n)) + return parse + +_single = _ntuple(1) +_pair = _ntuple(2) +_triple = _ntuple(3) +_quadruple = _ntuple(4) + + +def _reverse_repeat_tuple(t, n): + r"""Reverse the order of `t` and repeat each element for `n` times. + + This can be used to translate padding arg used by Conv and Pooling modules + to the ones used by `F.pad`. + """ + return tuple(x for x in reversed(t) for _ in range(n)) + + +def _list_with_default(out_size, defaults): + # type: (List[int], List[int]) -> List[int] + if isinstance(out_size, int): + return out_size + if len(defaults) <= len(out_size): + raise ValueError('Input dimension should be at least {}'.format(len(out_size) + 1)) + return [v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size):])]
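+ +# Illustrative usage of the helpers above (example values; assuming they are consumed by the +# low-rank convolution wrappers in kwscnn.py): +#   _pair(3) -> (3, 3): a scalar kernel/stride/padding argument is broadcast to both spatial dims. +#   _pair((3, 5)) -> (3, 5): iterable arguments pass through unchanged. +#   _reverse_repeat_tuple((1, 2), 2) -> (2, 2, 1, 1): padding reversed and repeated in the order F.pad expects.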