Port code to TensorFlow 2.0 #86

47 changes: 47 additions & 0 deletions tf2.0/README.md
@@ -0,0 +1,47 @@
## Edge Machine Learning: TensorFlow Library

This directory includes TensorFlow implementations of various techniques and
algorithms developed as part of EdgeML. Currently, the following algorithms are
available in TensorFlow:

1. [Bonsai](../docs/publications/Bonsai.pdf)
2. [EMI-RNN](../docs/publications/emi-rnn-nips18.pdf)
3. [FastRNN & FastGRNN](../docs/publications/FastGRNN.pdf)
4. [ProtoNN](../docs/publications/ProtoNN.pdf)

The TensorFlow compute graphs for these algorithms are packaged as
`edgeml.graph`. Trainers for these algorithms are in `edgeml.trainer`. Usage
directions and examples are provided in the `examples` directory. To get
started with any of the provided algorithms, please follow the notebooks in
the `examples` directory.
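
For instance, the Bonsai graph and its trainer are imported in
`examples/Bonsai/bonsai_example.py` (shown later in this diff) as follows; this
is just a pointer to the package layout, not a complete usage example:

```python
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

from edgeml.graph.bonsai import Bonsai                    # forward compute graph
from edgeml.trainer.bonsaiTrainer import BonsaiTrainer    # 3-phase training routine
```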

## Installation

Use pip and the provided requirements files to install the required
dependencies before installing the `edgeml` library. Details for CPU-based
and GPU-based installation are provided below.

It is highly recommended that EdgeML be installed in a virtual environment. Please create
a new virtual environment using your environment manager ([virtualenv](https://virtualenv.pypa.io/en/stable/userguide/#usage) or [Anaconda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-with-commands)).
Make sure the new environment is active before running the commands below.

### CPU

```
pip install -r requirements-cpu.txt
pip install -e .
```

Tested on Python 3.5 and Python 2.7 with TensorFlow >= 1.6.0.

### GPU

Install the appropriate CUDA and cuDNN versions [tested with CUDA >= 8.1 and cuDNN >= 6.1].

```
pip install -r requirements-gpu.txt
pip install -e .
```

Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT license.
57 changes: 57 additions & 0 deletions tf2.0/docs/FastCells.md
@@ -0,0 +1,57 @@
# FastRNN and FastGRNN - FastCells

This document explains and elaborates on specific details of the FastCells
implemented in `tf/edgeml/graph/rnn.py`. The endpoint use-case scripts with
3-phase training, along with an example notebook, are present in `tf/examples/FastCells/`.
One can use the endpoint script to test the RNN architectures on any dataset
while specifying budget constraints, in terms of sparsity and rank of the
weight matrices, as hyper-parameters.

# FastRNN
![FastRNN](img/FastRNN.png)
![FastRNN Equation](img/FastRNN_eq.png)

# FastGRNN
![FastGRNN Base Architecture](img/FastGRNN.png)
![FastGRNN Base Equation](img/FastGRNN_eq.png)

# Plug and Play Cells

`FastRNNCell` and `FastGRNNCell`, present in `edgeml.graph.rnn`, are very similar to
TensorFlow's built-in `RNNCell`, `GRUCell`, `BasicLSTMCell`, and `UGRNNCell`, allowing us to
replace any of the standard RNN cells in our architecture with FastCells.
One can see the plug-and-play nature in the endpoint script for FastCells, where the graph
building is very similar to LSTM/GRU in TensorFlow.

Script: [Endpoint Script](../examples/FastCells/fastcell_example.py)

Example Notebook: [iPython Notebook](../examples/FastCells/fastcell_example.ipynb)

Cells: [FastRNNCell](../edgeml/graph/rnn.py#L206) and [FastGRNNCell](../edgeml/graph/rnn.py#L31).
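
As a rough illustration of this plug-and-play usage, a FastGRNN cell can be dropped
into a standard static-RNN pipeline along the following lines. This is only a sketch:
the constructor arguments shown for `FastGRNNCell` (hidden size, `wRank`, `uRank`) and
the dense readout are assumptions based on the description above, and should be checked
against `edgeml/graph/rnn.py` and `fastcell_example.py`.

```python
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from edgeml.graph.rnn import FastGRNNCell

timeSteps, inputDims, hiddenDims, numClasses = 16, 16, 32, 10
X = tf.placeholder(tf.float32, [None, timeSteps, inputDims])
# Unstack into a per-timestep list, exactly as for any built-in RNNCell.
x = tf.unstack(X, timeSteps, axis=1)

# Low-rank hyper-parameters (wRank/uRank) control the budget of the cell.
cell = FastGRNNCell(hiddenDims, wRank=8, uRank=8)
outputs, state = tf.nn.static_rnn(cell, x, dtype=tf.float32)
logits = tf.layers.dense(outputs[-1], numClasses)  # simple classification readout
```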

# 3 phase Fast Training

`FastCells`, similar to `Bonsai`, use a 3-phase training routine to induce the right
support and sparsity in the weight matrices. With the low-rank parameterization of the
weights followed by the 3-phase training, we obtain FastRNN and FastGRNN models which are
compact and can be further compressed using byte quantization without significant loss in accuracy.

# Compression

1) Low-Rank Parameterization of Weight Matrices (L)
2) Sparsity (S)
3) Quantization (Q)

Low rank is induced directly into the FastCells during initialization, and training
proceeds with the targeted low-rank versions of the weight matrices. One can use the
`wRank` and `uRank` parameters of FastCells to achieve this.

Sparsity is passed as a hyper-parameter to `fastTrainer.py` during the 3-phase training,
which at the end produces a sparse, low-rank model.

Further compression is achieved by byte quantization and can be performed using the
`quantizeFastModels.py` script, which is part of `tf/examples/FastCells/`. This will give a
model-size reduction of up to 4x if 8-bit integers are used. Lastly, to facilitate all-integer
arithmetic, including the non-linearities, one could use `quantTanh` instead of `tanh` and
`quantSigm` instead of `sigmoid` as the non-linearities in the RNN cells, followed by byte
quantization. These non-linearities can be set using the appropriate parameters in
`FastRNNCell` and `FastGRNNCell`.
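
As a rough, standalone illustration of the byte-quantization idea (this is not the
`quantizeFastModels.py` script itself, just a minimal sketch of the same scale-round-clip
scheme), a weight matrix can be mapped to 8-bit integers with a single scale factor:

```python
import numpy as np

def byte_quantize(W, maxValue=127):
    # Scale so the largest magnitude maps close to maxValue, then round and clip.
    scale = np.round((2.0 * maxValue + 1.0) / (2.0 * np.max(np.abs(W))))
    Wq = np.clip(np.round(scale * W), -(maxValue + 1), maxValue).astype(np.int8)
    return Wq, scale  # keep the scale to de-quantize (Wq / scale) at inference

W = np.random.randn(64, 32).astype(np.float32)
Wq, scale = byte_quantize(W)
print(W.nbytes, "bytes ->", Wq.nbytes, "bytes")  # roughly 4x smaller
```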
Binary file added tf2.0/docs/img/3PartsGraph.png
Binary file added tf2.0/docs/img/FastGRNN.png
Binary file added tf2.0/docs/img/FastGRNN_eq.png
Binary file added tf2.0/docs/img/FastRNN.png
Binary file added tf2.0/docs/img/FastRNN_eq.png
Binary file added tf2.0/docs/img/MIML_illustration.png
13 changes: 13 additions & 0 deletions tf2.0/edgeml/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

'''
package edgeml
Provides the Bonsai and ProtoNN compute graphs, and trainer
routines for both.
'''

# TODO Override the __all__ variable for the package
# and limit the functions that are exposed.
# Do not expose functions in utils - can be dangerous
2 changes: 2 additions & 0 deletions tf2.0/edgeml/graph/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
181 changes: 181 additions & 0 deletions tf2.0/edgeml/graph/bonsai.py
@@ -0,0 +1,181 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import warnings


class Bonsai:
def __init__(self, numClasses, dataDimension, projectionDimension,
treeDepth, sigma,
isRegression=False, W=None, T=None, V=None, Z=None):
'''
Expected Dimensions:
Bonsai Params // Optional
W [numClasses*totalNodes, projectionDimension]
V [numClasses*totalNodes, projectionDimension]
Z [projectionDimension, dataDimension + 1]
T [internalNodes, projectionDimension]
internalNodes = 2**treeDepth - 1
totalNodes = 2*internalNodes + 1
sigma - tanh non-linearity
sigmaI - Indicator function for node probabilities
sigmaI - has to be set to infinity(1e9 for practicality)
while doing testing/inference
numClasses will be reset to 1 in binary case
'''
self.dataDimension = dataDimension
self.projectionDimension = projectionDimension
self.isRegression = isRegression

        if self.isRegression and numClasses != 1:
            # Regression is treated as a single-output problem.
            warnings.warn("Number of classes cannot be greater than 1 for regression")
            self.numClasses = 1
        elif numClasses == 2:
            self.numClasses = 1
        else:
            self.numClasses = numClasses

self.treeDepth = treeDepth
self.sigma = sigma

self.internalNodes = 2**self.treeDepth - 1
self.totalNodes = 2 * self.internalNodes + 1

self.W = self.initW(W)
self.V = self.initV(V)
self.T = self.initT(T)
self.Z = self.initZ(Z)

self.assertInit()

self.score = None
self.X_ = None
self.prediction = None

def initZ(self, Z):
if Z is None:
Z = tf.random_normal(
[self.projectionDimension, self.dataDimension])
Z = tf.Variable(Z, name='Z', dtype=tf.float32)
return Z

def initW(self, W):
if W is None:
W = tf.random_normal(
[self.numClasses * self.totalNodes, self.projectionDimension])
W = tf.Variable(W, name='W', dtype=tf.float32)
return W

def initV(self, V):
if V is None:
V = tf.random_normal(
[self.numClasses * self.totalNodes, self.projectionDimension])
V = tf.Variable(V, name='V', dtype=tf.float32)
return V

def initT(self, T):
if T is None:
T = tf.random_normal(
[self.internalNodes, self.projectionDimension])
T = tf.Variable(T, name='T', dtype=tf.float32)
return T

def __call__(self, X, sigmaI):
'''
Function to build the Bonsai Tree graph
Expected Dimensions
X is [_, self.dataDimension]
'''
errmsg = "Dimension Mismatch, X is [_, self.dataDimension]"
assert (len(X.shape) == 2 and int(
X.shape[1]) == self.dataDimension), errmsg
if self.score is not None:
return self.score, self.X_

X_ = tf.divide(tf.matmul(self.Z, X, transpose_b=True),
self.projectionDimension)

W_ = self.W[0:(self.numClasses)]
V_ = self.V[0:(self.numClasses)]

self.__nodeProb = []
self.__nodeProb.append(1)

score_ = self.__nodeProb[0] * tf.multiply(
tf.matmul(W_, X_), tf.tanh(self.sigma * tf.matmul(V_, X_)))
for i in range(1, self.totalNodes):
W_ = self.W[i * self.numClasses:((i + 1) * self.numClasses)]
V_ = self.V[i * self.numClasses:((i + 1) * self.numClasses)]

T_ = tf.reshape(self.T[int(np.ceil(i / 2.0) - 1.0)],
[-1, self.projectionDimension])
prob = (1 + ((-1)**(i + 1)) *
tf.tanh(tf.multiply(sigmaI, tf.matmul(T_, X_))))

prob = tf.divide(prob, 2.0)
prob = self.__nodeProb[int(np.ceil(i / 2.0) - 1.0)] * prob
self.__nodeProb.append(prob)
score_ += self.__nodeProb[i] * tf.multiply(
tf.matmul(W_, X_), tf.tanh(self.sigma * tf.matmul(V_, X_)))

self.score = score_
self.X_ = X_
return self.score, self.X_

def getPrediction(self):
'''
        Takes in a score tensor and outputs an integer class for each data point.
'''

# Classification.
if (self.isRegression == False):
if self.prediction is not None:
return self.prediction

if self.numClasses > 2:
self.prediction = tf.argmax(tf.transpose(self.score), 1)
else:
self.prediction = tf.argmax(
tf.concat([tf.transpose(self.score),
0 * tf.transpose(self.score)], 1), 1)
# Regression.
elif (self.isRegression == True):
            # For regression, the scores are the actual predictions; just return them.
self.prediction = self.score

return self.prediction

def assertInit(self):
errmsg = "Number of Classes for regression can only be 1."
if (self.isRegression == True):
assert (self.numClasses == 1), errmsg
        errRank = "All parameters must have only two dimensions, shape = [a, b]"
assert len(self.W.shape) == len(self.Z.shape), errRank
assert len(self.W.shape) == len(self.T.shape), errRank
assert len(self.W.shape) == 2, errRank
msg = "W and V should be of same Dimensions"
assert self.W.shape == self.V.shape, msg
errW = "W and V are [numClasses*totalNodes, projectionDimension]"
assert self.W.shape[0] == self.numClasses * self.totalNodes, errW
assert self.W.shape[1] == self.projectionDimension, errW
errZ = "Z is [projectionDimension, dataDimension]"
assert self.Z.shape[0] == self.projectionDimension, errZ
assert self.Z.shape[1] == self.dataDimension, errZ
errT = "T is [internalNodes, projectionDimension]"
assert self.T.shape[0] == self.internalNodes, errT
assert self.T.shape[1] == self.projectionDimension, errT
        assert int(self.numClasses) > 0, "numClasses should be > 0"
msg = "# of features in data should be > 0"
assert int(self.dataDimension) > 0, msg
msg = "Projection should be > 0 dims"
assert int(self.projectionDimension) > 0, msg
msg = "treeDepth should be >= 0"
assert int(self.treeDepth) >= 0, msg
191 changes: 191 additions & 0 deletions tf2.0/edgeml/graph/protoNN.py
@@ -0,0 +1,191 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

class ProtoNN:
def __init__(self, inputDimension, projectionDimension, numPrototypes,
numOutputLabels, gamma,
W = None, B = None, Z = None):
'''
Forward computation graph for ProtoNN.
inputDimension: Input data dimension or feature dimension.
projectionDimension: hyperparameter
numPrototypes: hyperparameter
numOutputLabels: The number of output labels or classes
        W, B, Z: Numpy matrices that can be used to initialize
        the projection matrix (W), the prototype matrix (B) and the
        prototype labels matrix (Z).
Expected Dimensions:
W inputDimension (d) x projectionDimension (d_cap)
B projectionDimension (d_cap) x numPrototypes (m)
Z numOutputLabels (L) x numPrototypes (m)
'''
with tf.name_scope('protoNN') as ns:
self.__nscope = ns
self.__d = inputDimension
self.__d_cap = projectionDimension
self.__m = numPrototypes
self.__L = numOutputLabels

self.__inW = W
self.__inB = B
self.__inZ = Z
self.__inGamma = gamma
self.W, self.B, self.Z = None, None, None
self.gamma = None

self.__validInit = False
self.__initWBZ()
self.__initGamma()
self.__validateInit()
self.protoNNOut = None
self.predictions = None
self.accuracy = None

def __validateInit(self):
self.__validInit = False
errmsg = "Dimensions mismatch! Should be W[d, d_cap]"
errmsg += ", B[d_cap, m] and Z[L, m]"
d, d_cap, m, L, _ = self.getHyperParams()
assert self.W.shape[0] == d, errmsg
assert self.W.shape[1] == d_cap, errmsg
assert self.B.shape[0] == d_cap, errmsg
assert self.B.shape[1] == m, errmsg
assert self.Z.shape[0] == L, errmsg
assert self.Z.shape[1] == m, errmsg
self.__validInit = True

def __initWBZ(self):
with tf.name_scope(self.__nscope):
W = self.__inW
if W is None:
W = tf.random_normal_initializer()
W = W([self.__d, self.__d_cap])
self.W = tf.Variable(W, name='W', dtype=tf.float32)

B = self.__inB
if B is None:
B = tf.random_uniform_initializer()
B = B([self.__d_cap, self.__m])
self.B = tf.Variable(B, name='B', dtype=tf.float32)

Z = self.__inZ
if Z is None:
Z = tf.random_normal_initializer()
Z = Z([self.__L, self.__m])
Z = tf.Variable(Z, name='Z', dtype=tf.float32)
self.Z = Z
return self.W, self.B, self.Z

def __initGamma(self):
with tf.name_scope(self.__nscope):
gamma = self.__inGamma
self.gamma = tf.constant(gamma, name='gamma')

def getHyperParams(self):
'''
Returns the model hyperparameters:
[inputDimension, projectionDimension,
numPrototypes, numOutputLabels, gamma]
'''
d = self.__d
dcap = self.__d_cap
m = self.__m
L = self.__L
return d, dcap, m, L, self.gamma

def getModelMatrices(self):
'''
Returns Tensorflow tensors of the model matrices, which
can then be evaluated to obtain corresponding numpy arrays.
        These can then be exported as part of other implementations of
        ProtoNN, for instance a C++ implementation or a pure Python
        implementation.
Returns
[ProjectionMatrix (W), prototypeMatrix (B),
prototypeLabelsMatrix (Z), gamma]
'''
return self.W, self.B, self.Z, self.gamma

def __call__(self, X, Y=None):
'''
This method is responsible for construction of the forward computation
graph. The end point of the computation graph, or in other words the
output operator for the forward computation is returned. Additionally,
if the argument Y is provided, a classification accuracy operator with
        Y as target will also be created. For this, Y is assumed to be in one-hot
encoded format and the class with the maximum prediction score is
compared to the encoded class in Y. This accuracy operator is returned
by getAccuracyOp() method. If a different accuracyOp is required, it
can be defined by overriding the createAccOp(protoNNScoresOut, Y)
method.
X: Input tensor or placeholder of shape [-1, inputDimension]
Y: Optional tensor or placeholder for targets (labels or classes).
Expected shape is [-1, numOutputLabels].
returns: The forward computation outputs, self.protoNNOut
'''
# This should never execute
assert self.__validInit is True, "Initialization failed!"
if self.protoNNOut is not None:
return self.protoNNOut

W, B, Z, gamma = self.W, self.B, self.Z, self.gamma
with tf.name_scope(self.__nscope):
WX = tf.matmul(X, W)
# Convert WX to tensor so that broadcasting can work
dim = [-1, WX.shape.as_list()[1], 1]
WX = tf.reshape(WX, dim)
dim = [1, B.shape.as_list()[0], -1]
B_ = tf.reshape(B, dim)
l2sim = B_ - WX
l2sim = tf.pow(l2sim, 2)
l2sim = tf.reduce_sum(l2sim, 1, keepdims=True)
self.l2sim = l2sim
gammal2sim = (-1 * gamma * gamma) * l2sim
M = tf.exp(gammal2sim)
dim = [1] + Z.shape.as_list()
Z_ = tf.reshape(Z, dim)
y = tf.multiply(Z_, M)
y = tf.reduce_sum(y, 2, name='protoNNScoreOut')
self.protoNNOut = y
self.predictions = tf.argmax(y, 1, name='protoNNPredictions')
if Y is not None:
self.createAccOp(self.protoNNOut, Y)
return y

def createAccOp(self, outputs, target):
'''
Define an accuracy operation on ProtoNN's output scores and targets.
Here a simple classification accuracy operator is defined. More
complicated operators (for multiple label problems and so forth) can be
defined by overriding this method
'''
assert self.predictions is not None
target = tf.argmax(target, 1)
correctPrediction = tf.equal(self.predictions, target)
acc = tf.reduce_mean(tf.cast(correctPrediction, tf.float32),
name='protoNNAccuracy')
self.accuracy = acc

def getPredictionsOp(self):
'''
The predictions operator is defined as argmax(protoNNScores) for each
prediction.
'''
return self.predictions

def getAccuracyOp(self):
'''
returns accuracyOp as defined by createAccOp. It defaults to
multi-class classification accuracy.
'''
msg = "Accuracy operator not defined in graph. Did you provide Y as an"
        msg += " argument to __call__?"
assert self.accuracy is not None, msg
return self.accuracy
2 changes: 2 additions & 0 deletions tf2.0/edgeml/trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
561 changes: 561 additions & 0 deletions tf2.0/edgeml/trainer/bonsaiTrainer.py

Large diffs are not rendered by default.

528 changes: 528 additions & 0 deletions tf2.0/edgeml/trainer/fastTrainer.py

Large diffs are not rendered by default.

220 changes: 220 additions & 0 deletions tf2.0/edgeml/trainer/protoNNTrainer.py
@@ -0,0 +1,220 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

from __future__ import print_function
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import sys
import edgeml.utils as utils


class ProtoNNTrainer:
def __init__(self, protoNNObj, regW, regB, regZ,
sparcityW, sparcityB, sparcityZ,
learningRate, X, Y, lossType='l2'):
'''
A wrapper for the various techniques used for training ProtoNN. This
subsumes both the responsibility of loss graph construction and
performing training. The original training routine that is part of the
C++ implementation of EdgeML used iterative hard thresholding (IHT),
gamma estimation through median heuristic and other tricks for
        training ProtoNN. This module implements the same in TensorFlow
        and Python.
protoNNObj: An instance of ProtoNN class defining the forward
computation graph. The loss functions and training routines will be
attached to this instance.
regW, regB, regZ: Regularization constants for W, B, and
Z matrices of protoNN.
sparcityW, sparcityB, sparcityZ: Sparsity constraints
for W, B and Z matrices. A value between 0 (exclusive) and 1
(inclusive) is expected. A value of 1 indicates dense training.
learningRate: Initial learning rate for ADAM optimizer.
X, Y : Placeholders for data and labels.
X [-1, featureDimension]
Y [-1, num Labels]
lossType: ['l2', 'xentropy']
'''
self.protoNNObj = protoNNObj
self.__regW = regW
self.__regB = regB
self.__regZ = regZ
self.__sW = sparcityW
self.__sB = sparcityB
self.__sZ = sparcityZ
self.__lR = learningRate
self.X = X
self.Y = Y
self.sparseTraining = True
if (sparcityW == 1.0) and (sparcityB == 1.0) and (sparcityZ == 1.0):
self.sparseTraining = False
print("Sparse training disabled.", file=sys.stderr)
# Define placeholders for sparse training
self.W_th = None
self.B_th = None
self.Z_th = None
self.__lossType = lossType
self.__validInit = False
self.__validInit = self.__validateInit()
self.__protoNNOut = protoNNObj(X, Y)
self.loss = self.__lossGraph()
self.trainStep = self.__trainGraph()
self.__hthOp = self.__getHardThresholdOp()
self.accuracy = protoNNObj.getAccuracyOp()

def __validateInit(self):
self.__validInit = False
msg = "Sparsity value should be between"
msg += " 0 and 1 (both inclusive)."
assert self.__sW >= 0. and self.__sW <= 1., 'W:' + msg
assert self.__sB >= 0. and self.__sB <= 1., 'B:' + msg
assert self.__sZ >= 0. and self.__sZ <= 1., 'Z:' + msg
d, dcap, m, L, _ = self.protoNNObj.getHyperParams()
msg = 'Y should be of dimension [-1, num labels/classes]'
msg += ' specified as part of ProtoNN object.'
assert (len(self.Y.shape)) == 2, msg
assert (self.Y.shape[1] == L), msg
msg = 'X should be of dimension [-1, featureDimension]'
msg += ' specified as part of ProtoNN object.'
assert (len(self.X.shape) == 2), msg
assert (self.X.shape[1] == d), msg
self.__validInit = True
msg = 'Values can be \'l2\', or \'xentropy\''
if self.__lossType not in ['l2', 'xentropy']:
raise ValueError(msg)
return True

def __lossGraph(self):
pnnOut = self.__protoNNOut
l1, l2, l3 = self.__regW, self.__regB, self.__regZ
W, B, Z, _ = self.protoNNObj.getModelMatrices()
if self.__lossType == 'l2':
with tf.name_scope('protonn-l2-loss'):
loss_0 = tf.nn.l2_loss(self.Y - pnnOut)
reg = l1 * tf.nn.l2_loss(W) + l2 * tf.nn.l2_loss(B)
reg += l3 * tf.nn.l2_loss(Z)
loss = loss_0 + reg
elif self.__lossType == 'xentropy':
with tf.name_scope('protonn-xentropy-loss'):
loss_0 = tf.nn.softmax_cross_entropy_with_logits_v2(logits=pnnOut,
labels=tf.stop_gradient(self.Y))
loss_0 = tf.reduce_mean(loss_0)
reg = l1 * tf.nn.l2_loss(W) + l2 * tf.nn.l2_loss(B)
reg += l3 * tf.nn.l2_loss(Z)
loss = loss_0 + reg
return loss

def __trainGraph(self):
with tf.name_scope('protonn-gradient-adam'):
trainStep = tf.train.AdamOptimizer(self.__lR)
trainStep = trainStep.minimize(self.loss)
return trainStep

def __getHardThresholdOp(self):
W, B, Z, _ = self.protoNNObj.getModelMatrices()
self.W_th = tf.placeholder(tf.float32, name='W_th')
self.B_th = tf.placeholder(tf.float32, name='B_th')
self.Z_th = tf.placeholder(tf.float32, name='Z_th')
with tf.name_scope('hard-threshold-assignments'):
# hard_thrsd_W = W.assign(self.W_th)
# hard_thrsd_B = B.assign(self.B_th)
# hard_thrsd_Z = Z.assign(self.Z_th)
# Code changes for tf 1.11
hard_thrsd_W = tf.assign(W, self.W_th)
hard_thrsd_B = tf.assign(B, self.B_th)
hard_thrsd_Z = tf.assign(Z, self.Z_th)
hard_thrsd_op = tf.group(hard_thrsd_W, hard_thrsd_B, hard_thrsd_Z)
return hard_thrsd_op

def train(self, batchSize, totalEpochs, sess,
x_train, x_val, y_train, y_val, noInit=False,
redirFile=None, printStep=10, valStep=3):
'''
Performs dense training of ProtoNN followed by iterative hard
thresholding to enforce sparsity constraints.
batchSize: Batch size per update
totalEpochs: The number of epochs to run training for. One epoch is
defined as one pass over the entire training data.
sess: The Tensorflow session to use for running various graph
operators.
        x_train, x_val, y_train, y_val: The numpy arrays containing the train
        and validation data. x data is assumed to be of shape [-1,
        featureDimension] while y should have shape [-1, numberLabels].
        noInit: By default, all the tensors of the computation graph are
        initialized at the start of the training session. Set noInit=True to
        disable this behaviour.
        printStep: Number of batches between echoing of loss and train accuracy.
        valStep: Number of epochs between evaluations on the validation set.
'''
d, d_cap, m, L, gamma = self.protoNNObj.getHyperParams()
assert batchSize >= 1, 'Batch size should be positive integer'
assert totalEpochs >= 1, 'Total epochs should be positive integer'
assert x_train.ndim == 2, 'Expected training data to be of rank 2'
assert x_train.shape[1] == d, 'Expected x_train to be [-1, %d]' % d
assert x_val.ndim == 2, 'Expected validation data to be of rank 2'
assert x_val.shape[1] == d, 'Expected x_val to be [-1, %d]' % d
assert y_train.ndim == 2, 'Expected training labels to be of rank 2'
assert y_train.shape[1] == L, 'Expected y_train to be [-1, %d]' % L
assert y_val.ndim == 2, 'Expected validation labels to be of rank 2'
assert y_val.shape[1] == L, 'Expected y_val to be [-1, %d]' % L

# Numpy will throw asserts for arrays
if sess is None:
raise ValueError('sess must be valid Tensorflow session.')

trainNumBatches = int(np.ceil(len(x_train) / batchSize))
valNumBatches = int(np.ceil(len(x_val) / batchSize))
x_train_batches = np.array_split(x_train, trainNumBatches)
y_train_batches = np.array_split(y_train, trainNumBatches)
x_val_batches = np.array_split(x_val, valNumBatches)
y_val_batches = np.array_split(y_val, valNumBatches)
if not noInit:
sess.run(tf.global_variables_initializer())
X, Y = self.X, self.Y
W, B, Z, _ = self.protoNNObj.getModelMatrices()
for epoch in range(totalEpochs):
for i in range(len(x_train_batches)):
batch_x = x_train_batches[i]
batch_y = y_train_batches[i]
feed_dict = {
X: batch_x,
Y: batch_y
}
sess.run(self.trainStep, feed_dict=feed_dict)
if i % printStep == 0:
loss, acc = sess.run([self.loss, self.accuracy],
feed_dict=feed_dict)
msg = "Epoch: %3d Batch: %3d" % (epoch, i)
msg += " Loss: %3.5f Accuracy: %2.5f" % (loss, acc)
print(msg, file=redirFile)

# Perform Hard thresholding
if self.sparseTraining:
W_, B_, Z_ = sess.run([W, B, Z])
fd_thrsd = {
self.W_th: utils.hardThreshold(W_, self.__sW),
self.B_th: utils.hardThreshold(B_, self.__sB),
self.Z_th: utils.hardThreshold(Z_, self.__sZ)
}
sess.run(self.__hthOp, feed_dict=fd_thrsd)

if (epoch + 1) % valStep == 0:
acc = 0.0
loss = 0.0
for j in range(len(x_val_batches)):
batch_x = x_val_batches[j]
batch_y = y_val_batches[j]
feed_dict = {
X: batch_x,
Y: batch_y
}
acc_, loss_ = sess.run([self.accuracy, self.loss],
feed_dict=feed_dict)
acc += acc_
loss += loss_
acc /= len(y_val_batches)
loss /= len(y_val_batches)
print("Test Loss: %2.5f Accuracy: %2.5f" % (loss, acc))

340 changes: 340 additions & 0 deletions tf2.0/edgeml/utils.py
@@ -0,0 +1,340 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

from __future__ import print_function
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import scipy.cluster
import scipy.spatial
import os


def medianHeuristic(data, projectionDimension, numPrototypes, W_init=None):
'''
This method can be used to estimate gamma for ProtoNN. An approximation to
median heuristic is used here.
1. First the data is collapsed into the projectionDimension by W_init. If
W_init is not provided, it is initialized from a random normal(0, 1). Hence
data normalization is essential.
    2. Prototypes are computed by running k-means clustering on the projected
    data.
3. The median distance is then estimated by calculating median distance
between prototypes and projected data points.
data needs to be [-1, numFeats]
If using this method to initialize gamma, please use the W and B as well.
    TODO: Return an estimate of Z (prototype labels) based on cluster centroids
    and labels.
TODO: Clustering fails due to singularity error if projecting upwards
W [dxd_cap]
B [d_cap, m]
returns gamma, W, B
'''
assert data.ndim == 2
X = data
featDim = data.shape[1]
if projectionDimension > featDim:
print("Warning: Projection dimension > feature dimension. Gamma")
print("\t estimation due to median heuristic could fail.")
print("\tTo retain the projection dataDimension, provide")
print("\ta value for gamma.")

if W_init is None:
W_init = np.random.normal(size=[featDim, projectionDimension])
W = W_init
XW = np.matmul(X, W)
assert XW.shape[1] == projectionDimension
assert XW.shape[0] == len(X)
    # kmeans2 requires an [N x d_cap] data matrix of N observations of
    # dimension d_cap and the number of centroids m. It returns the
    # [m x d_cap] centroids and a cluster label for each observation.
    B, _ = scipy.cluster.vq.kmeans2(XW, numPrototypes)
# Requires two matrices. Number of observations x dimension of observation
# space. Distances[i,j] is the distance between XW[i] and B[j]
distances = scipy.spatial.distance.cdist(XW, B, metric='euclidean')
distances = np.reshape(distances, [-1])
gamma = np.median(distances)
gamma = 1 / (2.5 * gamma)
return gamma.astype('float32'), W.astype('float32'), B.T.astype('float32')


def multiClassHingeLoss(logits, label, batch_th):
'''
MultiClassHingeLoss to match C++ Version - No TF internal version
'''
flatLogits = tf.reshape(logits, [-1, ])
label_ = tf.argmax(label, 1)

correctId = tf.range(0, batch_th) * label.shape[1] + label_
correctLogit = tf.gather(flatLogits, correctId)

maxLabel = tf.argmax(logits, 1)
top2, _ = tf.nn.top_k(logits, k=2, sorted=True)

wrongMaxLogit = tf.where(
tf.equal(maxLabel, label_), top2[:, 1], top2[:, 0])

return tf.reduce_mean(tf.nn.relu(1. + wrongMaxLogit - correctLogit))


def crossEntropyLoss(logits, label):
'''
Cross Entropy loss for MultiClass case in joint training for
faster convergence
'''
return tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,
labels=tf.stop_gradient(label)))


def mean_absolute_error(logits, label):
'''
Function to compute the mean absolute error.
'''
return tf.reduce_mean(tf.abs(tf.subtract(logits, label)))


def hardThreshold(A, s):
'''
Hard thresholding function on Tensor A with sparsity s
'''
A_ = np.copy(A)
A_ = A_.ravel()
if len(A_) > 0:
th = np.percentile(np.abs(A_), (1 - s) * 100.0, interpolation='higher')
A_[np.abs(A_) < th] = 0.0
A_ = A_.reshape(A.shape)
return A_


def copySupport(src, dest):
'''
copy support of src tensor to dest tensor
'''
support = np.nonzero(src)
dest_ = dest
dest = np.zeros(dest_.shape)
dest[support] = dest_[support]
return dest


def countnnZ(A, s, bytesPerVar=4):
'''
    Returns the number of non-zeros and the representative size of the tensor.
    Uses a dense representation (4 bytes per value) for s >= 0.5,
    else a sparse representation (8 bytes per value).
'''
params = 1
hasSparse = False
for i in range(0, len(A.shape)):
params *= int(A.shape[i])
if s < 0.5:
nnZ = np.ceil(params * s)
hasSparse = True
return nnZ, nnZ * 2 * bytesPerVar, hasSparse
else:
nnZ = params
return nnZ, nnZ * bytesPerVar, hasSparse


def getConfusionMatrix(predicted, target, numClasses):
'''
Returns a confusion matrix for a multiclass classification
problem. `predicted` is a 1-D array of integers representing
the predicted classes and `target` is the target classes.
confusion[i][j]: Number of elements of class j
predicted as class i
Labels are assumed to be in range(0, numClasses)
    Use `printFormattedConfusionMatrix` to echo the confusion matrix
    in a user-friendly form.
'''
assert(predicted.ndim == 1)
assert(target.ndim == 1)
arr = np.zeros([numClasses, numClasses])

for i in range(len(predicted)):
arr[predicted[i]][target[i]] += 1
return arr


def printFormattedConfusionMatrix(matrix):
'''
Given a 2D confusion matrix, prints it in a human readable way.
The confusion matrix is expected to be a 2D numpy array with
square dimensions
'''
assert(matrix.ndim == 2)
assert(matrix.shape[0] == matrix.shape[1])
RECALL = 'Recall'
PRECISION = 'PRECISION'
print("|%s|" % ('True->'), end='')
for i in range(matrix.shape[0]):
print("%7d|" % i, end='')
print("%s|" % 'Precision')

print("|%s|" % ('-' * len(RECALL)), end='')
for i in range(matrix.shape[0]):
print("%s|" % ('-' * 7), end='')
print("%s|" % ('-' * len(PRECISION)))

precisionlist = np.sum(matrix, axis=1)
recalllist = np.sum(matrix, axis=0)
precisionlist = [matrix[i][i] / x if x !=
0 else -1 for i, x in enumerate(precisionlist)]
recalllist = [matrix[i][i] / x if x !=
0 else -1 for i, x in enumerate(recalllist)]
for i in range(matrix.shape[0]):
# len recall = 6
print("|%6d|" % (i), end='')
for j in range(matrix.shape[0]):
print("%7d|" % (matrix[i][j]), end='')
print("%s" % (" " * (len(PRECISION) - 7)), end='')
if precisionlist[i] != -1:
print("%1.5f|" % precisionlist[i])
else:
print("%7s|" % "nan")

print("|%s|" % ('-' * len(RECALL)), end='')
for i in range(matrix.shape[0]):
print("%s|" % ('-' * 7), end='')
print("%s|" % ('-' * len(PRECISION)))
print("|%s|" % ('Recall'), end='')

for i in range(matrix.shape[0]):
if recalllist[i] != -1:
print("%1.5f|" % (recalllist[i]), end='')
else:
print("%7s|" % "nan", end='')

print('%s|' % (' ' * len(PRECISION)))


def getPrecisionRecall(cmatrix, label=1):
trueP = cmatrix[label][label]
denom = np.sum(cmatrix, axis=0)[label]
if denom == 0:
denom = 1
recall = trueP / denom
denom = np.sum(cmatrix, axis=1)[label]
if denom == 0:
denom = 1
precision = trueP / denom
return precision, recall


def getMacroPrecisionRecall(cmatrix):
# TP + FP
precisionlist = np.sum(cmatrix, axis=1)
# TP + FN
recalllist = np.sum(cmatrix, axis=0)
precisionlist__ = [cmatrix[i][i] / x if x !=
0 else 0 for i, x in enumerate(precisionlist)]
recalllist__ = [cmatrix[i][i] / x if x !=
0 else 0 for i, x in enumerate(recalllist)]
precision = np.sum(precisionlist__)
precision /= len(precisionlist__)
recall = np.sum(recalllist__)
recall /= len(recalllist__)
return precision, recall


def getMicroPrecisionRecall(cmatrix):
# TP + FP
precisionlist = np.sum(cmatrix, axis=1)
# TP + FN
recalllist = np.sum(cmatrix, axis=0)
num = 0.0
for i in range(len(cmatrix)):
num += cmatrix[i][i]

precision = num / np.sum(precisionlist)
recall = num / np.sum(recalllist)
return precision, recall


def getMacroMicroFScore(cmatrix):
'''
Returns macro and micro f-scores.
Refer: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.104.8244&rep=rep1&type=pdf
'''
precisionlist = np.sum(cmatrix, axis=1)
recalllist = np.sum(cmatrix, axis=0)
precisionlist__ = [cmatrix[i][i] / x if x !=
0 else 0 for i, x in enumerate(precisionlist)]
recalllist__ = [cmatrix[i][i] / x if x !=
0 else 0 for i, x in enumerate(recalllist)]
macro = 0.0
for i in range(len(precisionlist)):
denom = precisionlist__[i] + recalllist__[i]
numer = precisionlist__[i] * recalllist__[i] * 2
if denom == 0:
denom = 1
macro += numer / denom
macro /= len(precisionlist)

num = 0.0
for i in range(len(precisionlist)):
num += cmatrix[i][i]

denom1 = np.sum(precisionlist)
denom2 = np.sum(recalllist)
pi = num / denom1
rho = num / denom2
denom = pi + rho
if denom == 0:
denom = 1
micro = 2 * pi * rho / denom
return macro, micro


def restructreMatrixBonsaiSeeDot(A, nClasses, nNodes):
'''
Restructures a matrix from [nNodes*nClasses, Proj] to
[nClasses*nNodes, Proj] for SeeDot
'''
tempMatrix = np.zeros(A.shape)
rowIndex = 0

for i in range(0, nClasses):
for j in range(0, nNodes):
tempMatrix[rowIndex] = A[j * nClasses + i]
rowIndex += 1

return tempMatrix


class GraphManager:
'''
    Manages saving and restoring graphs. Designed to be used with EMI-RNN,
    though it is general enough to be useful otherwise as well.
'''

def __init__(self):
pass

def checkpointModel(self, saver, sess, modelPrefix,
globalStep=1000, redirFile=None):
saver.save(sess, modelPrefix, global_step=globalStep)
print('Model saved to %s, global_step %d' % (modelPrefix, globalStep),
file=redirFile)

def loadCheckpoint(self, sess, modelPrefix, globalStep,
redirFile=None):
metaname = modelPrefix + '-%d.meta' % globalStep
basename = os.path.basename(metaname)
fileList = os.listdir(os.path.dirname(modelPrefix))
fileList = [x for x in fileList if x.startswith(basename)]
assert len(fileList) > 0, 'Checkpoint file not found'
msg = 'Too many or too few checkpoint files for globalStep: %d' % globalStep
        assert len(fileList) == 1, msg
chkpt = basename + '/' + fileList[0]
saver = tf.train.import_meta_graph(metaname)
metaname = metaname[:-5]
saver.restore(sess, metaname)
graph = tf.get_default_graph()
return graph
67 changes: 67 additions & 0 deletions tf2.0/examples/Bonsai/README.md
@@ -0,0 +1,67 @@
# EdgeML Bonsai on a sample public dataset

This directory includes an example notebook and a general execution script for
Bonsai, developed as part of EdgeML. Also, we include a sample cleanup and
use-case on the USPS10 public dataset.

`edgeml.graph.bonsai` implements the Bonsai prediction graph in TensorFlow.
The three-phase training routine for Bonsai is decoupled from the forward graph
to facilitate a plug-and-play behaviour, wherein Bonsai can be combined with, or
used as a final-layer classifier for, other architectures (RNNs, CNNs).

Note that `bonsai_example.py` assumes that the data is in a specific format. It is
assumed that the train and test data are contained in two files, `train.npy` and
`test.npy`, each containing a 2D numpy array of dimension `[numberOfExamples,
numberOfFeatures + 1]`. The first column of each matrix is assumed to contain
the label information. For an N-class problem, we assume the labels are integers
from 0 through N-1. `bonsai_example.py` also supports univariate regression,
which can be accessed using the help options of the script. Multivariate regression
requires restructuring the input data format and can further help in extending
Bonsai to multi-label classification and multivariate regression. Lastly,
the training data, `train.npy`, is assumed to be well shuffled,
as the training routine doesn't shuffle internally. A sketch of the expected
layout is shown below.
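
A minimal sketch of generating synthetic data in this layout (the sizes below are
illustrative only; real data would come from your dataset):

```python
import numpy as np

numExamples, numFeatures, numClasses = 1000, 256, 10
X = np.random.randn(numExamples, numFeatures).astype(np.float32)
labels = np.random.randint(0, numClasses, size=(numExamples, 1)).astype(np.float32)

# First column holds the label; remaining columns hold the features.
train = np.hstack([labels, X])
np.save('train.npy', train)   # shape: [numberOfExamples, numberOfFeatures + 1]
```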

**Tested With:** Tensorflow >1.6 with Python 2 and Python 3

## Download and clean up sample dataset

We will validate the code using the USPS dataset. The download and cleanup of
the dataset to match the above-mentioned format are done by the scripts
[fetch_usps.py](fetch_usps.py) and [process_usps.py](process_usps.py):

```
python fetch_usps.py
python process_usps.py
```

## Sample command for Bonsai on USPS10
The following sample run on usps10 should validate your library:

```bash
python bonsai_example.py -dir usps10/ -d 3 -p 28 -rW 0.001 -rZ 0.0001 -rV 0.001 -rT 0.001 -sZ 0.2 -sW 0.3 -sV 0.3 -sT 0.62 -e 100 -s 1
```
This command should produce a final output that reads roughly like the following (the numbers might not match exactly due to version mismatches):
```
Maximum Test accuracy at compressed model size(including early stopping): 0.94369704 at Epoch: 66
Final Test Accuracy: 0.93024415
Non-Zeros: 4156.0 Model Size: 31.703125 KB hasSparse: True
```

The `usps10` directory will now have a consolidated results file called `TFBonsaiResults.txt` and a directory `TFBonsaiResults` containing the corresponding models from each run of the code on the usps10 dataset.

## Byte Quantization (Q) for model compression
If you wish to quantize the generated model to use byte-quantized integers, use `quantizeBonsaiModels.py`. Usage instructions:

```
python quantizeBonsaiModels.py -h
```

This will generate quantized models, with a `q` prefix added to every parameter file, stored in a new directory `QuantizedTFBonsaiModel` inside the model directory.
One can then use this model on edge devices; a sketch of loading the quantized parameters follows.
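
A minimal sketch of loading a quantized parameter back and de-quantizing it for a
sanity check. The helper below is hypothetical (not part of this PR); the
`paramScaleFactor.npy` file and the `q` prefix follow `quantizeBonsaiModels.py`,
while the exact parameter file names depend on what the trainer dumped:

```python
import os
import numpy as np

def load_quantized_param(quantDir, name):
    """Load a byte-quantized parameter (saved with a 'q' prefix) and de-quantize it."""
    scale = np.load(os.path.join(quantDir, 'paramScaleFactor.npy'))
    q = np.load(os.path.join(quantDir, 'q' + name))
    return q.astype(np.float32) / scale   # approximate float weights

# Example (path and parameter name are illustrative):
# W_approx = load_quantized_param('usps10/TFBonsaiResults/<run>/QuantizedTFBonsaiModel', 'W.npy')
```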


Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT license.
1,135 changes: 1,135 additions & 0 deletions tf2.0/examples/Bonsai/bonsai_example.ipynb

Large diffs are not rendered by default.

115 changes: 115 additions & 0 deletions tf2.0/examples/Bonsai/bonsai_example.py
@@ -0,0 +1,115 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import helpermethods
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import sys
from edgeml.trainer.bonsaiTrainer import BonsaiTrainer
from edgeml.graph.bonsai import Bonsai


def main():
# Fixing seeds for reproducibility
tf.set_random_seed(42)
np.random.seed(42)

# Hyper Param pre-processing
args = helpermethods.getArgs()

# Set 'isRegression' to be True, for regression. Default is 'False'.
isRegression = args.regression

sigma = args.sigma
depth = args.depth

projectionDimension = args.proj_dim
regZ = args.rZ
regT = args.rT
regW = args.rW
regV = args.rV

totalEpochs = args.epochs

learningRate = args.learning_rate

dataDir = args.data_dir

outFile = args.output_file

(dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest,
mean, std) = helpermethods.preProcessData(dataDir, isRegression)

sparZ = args.sZ

if numClasses > 2:
sparW = 0.2
sparV = 0.2
sparT = 0.2
else:
sparW = 1
sparV = 1
sparT = 1

if args.sW is not None:
sparW = args.sW
if args.sV is not None:
sparV = args.sV
if args.sT is not None:
sparT = args.sT

if args.batch_size is None:
batchSize = np.maximum(100, int(np.ceil(np.sqrt(Ytrain.shape[0]))))
else:
batchSize = args.batch_size

useMCHLoss = True

if numClasses == 2:
numClasses = 1

X = tf.placeholder("float32", [None, dataDimension])
Y = tf.placeholder("float32", [None, numClasses])

currDir = helpermethods.createTimeStampDir(dataDir)

helpermethods.dumpCommand(sys.argv, currDir)
helpermethods.saveMeanStd(mean, std, currDir)

# numClasses = 1 for binary case
bonsaiObj = Bonsai(numClasses, dataDimension,
projectionDimension, depth, sigma, isRegression)

bonsaiTrainer = BonsaiTrainer(bonsaiObj,
regW, regT, regV, regZ,
sparW, sparT, sparV, sparZ,
learningRate, X, Y, useMCHLoss, outFile)

sess = tf.InteractiveSession()

sess.run(tf.global_variables_initializer())

bonsaiTrainer.train(batchSize, totalEpochs, sess,
Xtrain, Xtest, Ytrain, Ytest, dataDir, currDir)

sess.close()
sys.stdout.close()


if __name__ == '__main__':
main()

# For the following command:
# Data - Curet
# python2 bonsai_example.py -dir ./curet/ -d 2 -p 22 -rW 0.00001 -rZ 0.0000001 -rV 0.00001 -rT 0.000001 -sZ 0.4 -sW 0.5 -sV 0.5 -sT 1 -e 300 -s 0.1 -b 20
# Final Output - useMCHLoss = True
# Maximum Test accuracy at compressed model size(including early stopping): 0.93727726 at Epoch: 297
# Final Test Accuracy: 0.9337135
# Non-Zeros: 24231.0 Model Size: 115.65625 KB hasSparse: True

# Data - usps2
# python2 bonsai_example.py -dir /mnt/c/Users/t-vekusu/Downloads/datasets/usps-binary/ -d 2 -p 22 -rW 0.00001 -rZ 0.0000001 -rV 0.00001 -rT 0.000001 -sZ 0.4 -sW 0.5 -sV 0.5 -sT 1 -e 300 -s 0.1 -b 20
# Maximum Test accuracy at compressed model size(including early stopping): 0.9521674 at Epoch: 246
# Final Test Accuracy: 0.94170403
# Non-Zeros: 2636.0 Model Size: 19.1328125 KB hasSparse: True
64 changes: 64 additions & 0 deletions tf2.0/examples/Bonsai/fetch_usps.py
@@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
#
# Setting up the USPS Data.

import subprocess
import os
import numpy as np
from sklearn.datasets import load_svmlight_file
import sys

def downloadData(workingDir, downloadDir, linkTrain, linkTest):
def runcommand(command):
p = subprocess.Popen(command.split(), stdout=subprocess.PIPE)
output, error = p.communicate()
assert(p.returncode == 0), 'Command failed: %s' % command

path = workingDir + '/' + downloadDir
path = os.path.abspath(path)
try:
os.mkdir(path)
except OSError:
print("Could not create %s. Make sure the path does" % path)
        print("not already exist and you have permissions to create it.")
return False
cwd = os.getcwd()
os.chdir(path)
print("Downloading data")
command = 'wget %s' % linkTrain
runcommand(command)
command = 'wget %s' % linkTest
runcommand(command)
print("Extracting data")
command = 'bzip2 -d usps.bz2'
runcommand(command)
command = 'bzip2 -d usps.t.bz2'
runcommand(command)
command = 'mv usps train.txt'
runcommand(command)
command = 'mv usps.t test.txt'
runcommand(command)
os.chdir(cwd)
return True

if __name__ == '__main__':
workingDir = './'
downloadDir = 'usps10'
linkTrain = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2'
linkTest = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2'
failureMsg = '''
Download Failed!
To manually perform the download
\t1. Create a new empty directory named `usps10`.
\t2. Download the data from the following links into the usps10 directory.
\t\tTest: %s
\t\tTrain: %s
\t3. Extract the downloaded files.
\t4. Rename `usps` to `train.txt` and,
\t5. Rename `usps.t` to `test.txt`
''' % (linkTrain, linkTest)

if not downloadData(workingDir, downloadDir, linkTrain, linkTest):
exit(failureMsg)
print("Done")
270 changes: 270 additions & 0 deletions tf2.0/examples/Bonsai/helpermethods.py
@@ -0,0 +1,270 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

'''
Functions to check sanity of input arguments
for the example script.
'''
import argparse
import datetime
import os
import numpy as np


def checkIntPos(value):
ivalue = int(value)
if ivalue <= 0:
raise argparse.ArgumentTypeError(
"%s is an invalid positive int value" % value)
return ivalue


def checkIntNneg(value):
ivalue = int(value)
if ivalue < 0:
raise argparse.ArgumentTypeError(
"%s is an invalid non-neg int value" % value)
return ivalue


def checkFloatNneg(value):
fvalue = float(value)
if fvalue < 0:
raise argparse.ArgumentTypeError(
"%s is an invalid non-neg float value" % value)
return fvalue


def checkFloatPos(value):
fvalue = float(value)
if fvalue <= 0:
raise argparse.ArgumentTypeError(
"%s is an invalid positive float value" % value)
return fvalue


def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')


def getArgs():
'''
Function to parse arguments for Bonsai Algorithm
'''
parser = argparse.ArgumentParser(
description='HyperParams for Bonsai Algorithm')
parser.add_argument('-dir', '--data-dir', required=True,
                        help='Data directory containing ' +
                        'train.npy and test.npy')

parser.add_argument('-d', '--depth', type=checkIntNneg, default=2,
help='Depth of Bonsai Tree ' +
'(default: 2 try: [0, 1, 3])')
parser.add_argument('-p', '--proj-dim', type=checkIntPos, default=10,
                        help='Projection Dimension ' +
                        '(default: 10 try: [5, 10, 30])')
parser.add_argument('-s', '--sigma', type=float, default=1.0,
help='Parameter for sigmoid sharpness ' +
                        '(default: 1.0 try: [3.0, 0.05, 0.1])')
parser.add_argument('-e', '--epochs', type=checkIntPos, default=42,
help='Total Epochs (default: 42 try:[100, 150, 60])')
parser.add_argument('-b', '--batch-size', type=checkIntPos,
help='Batch Size to be used ' +
'(default: max(100, sqrt(train_samples)))')
parser.add_argument('-lr', '--learning-rate', type=checkFloatPos,
default=0.01, help='Initial Learning rate for ' +
'Adam Optimizer (default: 0.01)')

parser.add_argument('-rW', type=float, default=0.0001,
help='Regularizer for predictor parameter W ' +
'(default: 0.0001 try: [0.01, 0.001, 0.00001])')
parser.add_argument('-rV', type=float, default=0.0001,
help='Regularizer for predictor parameter V ' +
'(default: 0.0001 try: [0.01, 0.001, 0.00001])')
parser.add_argument('-rT', type=float, default=0.0001,
help='Regularizer for branching parameter Theta ' +
'(default: 0.0001 try: [0.01, 0.001, 0.00001])')
parser.add_argument('-rZ', type=float, default=0.00001,
help='Regularizer for projection parameter Z ' +
'(default: 0.00001 try: [0.001, 0.0001, 0.000001])')

parser.add_argument('-sW', type=checkFloatPos,
help='Sparsity for predictor parameter W ' +
'(default: For Binary classification 1.0 else 0.2 ' +
'try: [0.1, 0.3, 0.5])')
parser.add_argument('-sV', type=checkFloatPos,
help='Sparsity for predictor parameter V ' +
'(default: For Binary classification 1.0 else 0.2 ' +
'try: [0.1, 0.3, 0.5])')
parser.add_argument('-sT', type=checkFloatPos,
help='Sparsity for branching parameter Theta ' +
'(default: For Binary classification 1.0 else 0.2 ' +
'try: [0.1, 0.3, 0.5])')
parser.add_argument('-sZ', type=checkFloatPos, default=0.2,
help='Sparsity for projection parameter Z ' +
'(default: 0.2 try: [0.1, 0.3, 0.5])')
parser.add_argument('-oF', '--output-file', default=None,
help='Output file for dumping the program output, ' +
'(default: stdout)')

parser.add_argument('-regression', type=str2bool, default=False,
                        help='Boolean argument which controls whether to perform ' +
                        'regression or classification. ' +
                        'Default: False (classification). Values: [True, False]')

return parser.parse_args()


def getQuantArgs():
'''
Function to parse arguments for Model Quantisation
'''
parser = argparse.ArgumentParser(
description='Arguments for quantizing Fast models. ' +
'Works only for piece-wise linear non-linearities, ' +
'like relu, quantTanh, quantSigm (check rnn.py for the definitions)')
parser.add_argument('-dir', '--model-dir', required=True,
help='model directory containing' +
'*.npy weight files dumped from the trained model')
parser.add_argument('-m', '--max-val', type=checkIntNneg, default=127,
help='this represents the maximum possible value ' +
'in model, essentially the byte complexity, ' +
'127=> 1 byte is default')

return parser.parse_args()


def createTimeStampDir(dataDir):
'''
    Creates a directory with a timestamp as its name
'''
if os.path.isdir(dataDir + '/TFBonsaiResults') is False:
try:
os.mkdir(dataDir + '/TFBonsaiResults')
except OSError:
            print("Creation of the directory %s failed" %
                  (dataDir + '/TFBonsaiResults'))

currDir = 'TFBonsaiResults/' + datetime.datetime.now().strftime("%H_%M_%S_%d_%m_%y")
if os.path.isdir(dataDir + '/' + currDir) is False:
try:
os.mkdir(dataDir + '/' + currDir)
except OSError:
            print("Creation of the directory %s failed" %
                  (dataDir + '/' + currDir))
else:
return (dataDir + '/' + currDir)
return None


def preProcessData(dataDir, isRegression=False):
'''
Function to pre-process input data
Expects a .npy file of form [lbl feats] for each datapoint
    Outputs train and test set datapoints, each appended with a 1 for bias induction
dataDimension, numClasses are inferred directly
'''
train = np.load(dataDir + '/train.npy')
test = np.load(dataDir + '/test.npy')

dataDimension = int(train.shape[1]) - 1

Xtrain = train[:, 1:dataDimension + 1]
Ytrain_ = train[:, 0]

Xtest = test[:, 1:dataDimension + 1]
Ytest_ = test[:, 0]

# Mean Var Normalisation
mean = np.mean(Xtrain, 0)
std = np.std(Xtrain, 0)
std[std[:] < 0.000001] = 1
Xtrain = (Xtrain - mean) / std
Xtest = (Xtest - mean) / std
# End Mean Var normalisation

# Classification.
if (isRegression == False):
numClasses = max(Ytrain_) - min(Ytrain_) + 1
numClasses = int(max(numClasses, max(Ytest_) - min(Ytest_) + 1))

lab = Ytrain_.astype('uint8')
lab = np.array(lab) - min(lab)

lab_ = np.zeros((Xtrain.shape[0], numClasses))
lab_[np.arange(Xtrain.shape[0]), lab] = 1
if (numClasses == 2):
Ytrain = np.reshape(lab, [-1, 1])
else:
Ytrain = lab_

lab = Ytest_.astype('uint8')
lab = np.array(lab) - min(lab)

lab_ = np.zeros((Xtest.shape[0], numClasses))
lab_[np.arange(Xtest.shape[0]), lab] = 1
if (numClasses == 2):
Ytest = np.reshape(lab, [-1, 1])
else:
Ytest = lab_

elif (isRegression == True):
# The number of classes is always 1, for regression.
numClasses = 1
Ytrain = Ytrain_
Ytest = Ytest_

trainBias = np.ones([Xtrain.shape[0], 1])
Xtrain = np.append(Xtrain, trainBias, axis=1)
testBias = np.ones([Xtest.shape[0], 1])
Xtest = np.append(Xtest, testBias, axis=1)

mean = np.append(mean, np.array([0]))
std = np.append(std, np.array([1]))

if (isRegression == False):
return dataDimension + 1, numClasses, Xtrain, Ytrain, Xtest, Ytest, mean, std
elif (isRegression == True):
return dataDimension + 1, numClasses, Xtrain, Ytrain.reshape((-1, 1)), Xtest, Ytest.reshape((-1, 1)), mean, std


def dumpCommand(list, currDir):
'''
Dumps the current command to a file for further use
'''
commandFile = open(currDir + '/command.txt', 'w')
command = "python"

command = command + " " + ' '.join(list)
commandFile.write(command)

commandFile.flush()
commandFile.close()


def saveMeanStd(mean, std, currDir):
'''
Function to save Mean and Std vectors
'''
np.save(currDir + '/mean.npy', mean)
np.save(currDir + '/std.npy', std)
saveMeanStdSeeDot(mean, std, currDir + "/SeeDot")


def saveMeanStdSeeDot(mean, std, seeDotDir):
'''
Function to save Mean and Std vectors
'''
if os.path.isdir(seeDotDir) is False:
try:
os.mkdir(seeDotDir)
except OSError:
print("Creation of the directory %s failed" %
seeDotDir)
np.savetxt(seeDotDir + '/Mean', mean, delimiter="\t")
np.savetxt(seeDotDir + '/Std', std, delimiter="\t")
54 changes: 54 additions & 0 deletions tf2.0/examples/Bonsai/process_usps.py
@@ -0,0 +1,54 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
#
# Processing the USPS Data. It is assumed that the data is already
# downloaded.

import subprocess
import os
import numpy as np
from sklearn.datasets import load_svmlight_file
import sys

def processData(workingDir, downloadDir):
def loadLibSVMFile(file):
data = load_svmlight_file(file)
features = data[0]
labels = data[1]
retMat = np.zeros([features.shape[0], features.shape[1] + 1])
retMat[:, 0] = labels
retMat[:, 1:] = features.todense()
return retMat

path = workingDir + '/' + downloadDir
path = os.path.abspath(path)
trf = path + '/train.txt'
tsf = path + '/test.txt'
assert os.path.isfile(trf), 'File not found: %s' % trf
assert os.path.isfile(tsf), 'File not found: %s' % tsf
train = loadLibSVMFile(trf)
test = loadLibSVMFile(tsf)

# Convert the labels from 0 to numClasses-1
y_train = train[:, 0]
y_test = test[:, 0]

lab = y_train.astype('uint8')
lab = np.array(lab) - min(lab)
train[:, 0] = lab

lab = y_test.astype('uint8')
lab = np.array(lab) - min(lab)
test[:, 0] = lab

np.save(path + '/train.npy', train)
np.save(path + '/test.npy', test)

if __name__ == '__main__':
# Configuration
workingDir = './'
downloadDir = 'usps10'
# End config
print("Processing data")
processData(workingDir, downloadDir)
print("Done")
72 changes: 72 additions & 0 deletions tf2.0/examples/Bonsai/quantizeBonsaiModels.py
@@ -0,0 +1,72 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import helpermethods
import os
import numpy as np


def min_max(A, name):
print(name + " has max: " + str(np.max(A)) + " min: " + str(np.min(A)))
return np.max([np.abs(np.max(A)), np.abs(np.min(A))])


def quantizeFastModels(modelDir, maxValue=127, scalarScaleFactor=1000):
ls = os.listdir(modelDir)
paramNameList = []
paramWeightList = []
paramLimitList = []

for file in ls:
if file.endswith("npy"):
if file.startswith("mean") or file.startswith("std") or file.startswith("hyperParam"):
continue
else:
paramNameList.append(file)
temp = np.load(modelDir + "/" + file)
paramWeightList.append(temp)
paramLimitList.append(min_max(temp, file))

paramLimit = np.max(paramLimitList)

paramScaleFactor = np.round((2.0 * maxValue + 1.0) / (2.0 * paramLimit))

quantParamWeights = []
for param in paramWeightList:
temp = np.round(paramScaleFactor * param)
temp[temp[:] > maxValue] = maxValue
temp[temp[:] < -maxValue] = -1 * (maxValue + 1)

if maxValue <= 127:
temp = temp.astype('int8')
elif maxValue <= 32767:
temp = temp.astype('int16')
else:
temp = temp.astype('int32')

quantParamWeights.append(temp)

    quantModelDir = modelDir + '/QuantizedTFBonsaiModel'
    if os.path.isdir(quantModelDir) is False:
        try:
            os.mkdir(quantModelDir)
        except OSError:
            print("Creation of the directory %s failed" % quantModelDir)

np.save(quantModelDir + "/paramScaleFactor.npy",
paramScaleFactor.astype('int32'))

for i in range(len(paramNameList)):
np.save(quantModelDir + "/q" + paramNameList[i], quantParamWeights[i])

print("\n\nQuantized Model Dir: " + quantModelDir)


def main():
args = helpermethods.getQuantArgs()
quantizeFastModels(args.model_dir, int(args.max_val))


if __name__ == '__main__':
main()
77 changes: 77 additions & 0 deletions tf2.0/examples/FastCells/README.md
@@ -0,0 +1,77 @@
# EdgeML FastCells on a sample public dataset

This directory includes an example notebook and a general execution script for
FastCells (FastRNN & FastGRNN), developed as part of EdgeML, along with modified
UGRNN, GRU and LSTM cells that support the LSQ training routine.
Also, we include a sample cleanup and use-case on the USPS10 public dataset.

`edgeml.graph.rnn` implements the custom RNN cells **FastRNN** ([`FastRNNCell`](../../edgeml/graph/rnn.py#L215)) and **FastGRNN** ([`FastGRNNCell`](../../edgeml/graph/rnn.py#L40)) with
multiple additional features like low-rank parameterisation, custom
non-linearities etc. Similar to Bonsai and ProtoNN, the three-phase training
routine for FastRNN and FastGRNN is decoupled from the custom cells to
facilitate plug-and-play behaviour of the custom RNN cells in other
architectures (NMT, Encoder-Decoder etc.) in place of the inbuilt `RNNCell`, `GRUCell`, `BasicLSTMCell` etc.
`edgeml.graph.rnn` also contains modified low-rank RNN cells for **UGRNN** ([`UGRNNLRCell`](../../edgeml/graph/rnn.py#L862)),
**GRU** ([`GRULRCell`](../../edgeml/graph/rnn.py#L635)) and **LSTM** ([`LSTMLRCell`](../../edgeml/graph/rnn.py#L376)). These cells can also be substituted for FastCells wherever feasible.
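
As a quick, illustrative sketch of instantiating one of these cells (the hyper-parameters below are placeholders, not tuned values; see `fastcell_example.py` for a full run):

```python
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

from edgeml.graph.rnn import FastGRNNCell

# Illustrative hyper-parameters; wRank/uRank=None keeps W and U full rank.
cell = FastGRNNCell(32,
                    gate_non_linearity="sigmoid",
                    update_non_linearity="tanh",
                    wRank=None, uRank=None)
```

Since the cell exposes the standard `RNNCell` interface, it should in principle drop into the usual TF 1.x RNN wrappers as well, although the example scripts in this directory drive it through `edgeml.trainer.fastTrainer.FastTrainer`.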

For training FastCells, `edgeml.trainer.fastTrainer` implements the three-phase
FastCell training routine in Tensorflow. A simple example,
`examples/fastcell_example.py` is provided to illustrate its usage.

Note that `fastcell_example.py` assumes that data is in a specific format. It
is assumed that train and test data are contained in two files, `train.npy` and
`test.npy`, each containing a 2D numpy array of dimension `[numberOfExamples,
numberOfFeatures + 1]`. The first column holds the integer label and the
remaining columns hold the features, where `numberOfFeatures` is
`timesteps x inputDims`, flattened across the timestep dimension: the input of
the 1st timestep, followed by the 2nd, and so on. For an N-class problem, the
labels are assumed to be integers from 0 through N-1. Lastly, the training
data, `train.npy`, is assumed to be well shuffled, as the training routine
doesn't shuffle internally.
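
For reference, a minimal sketch of writing data in this layout (the shapes, label count and file paths are illustrative only):

```python
import numpy as np

# Illustrative sizes: 1000 examples, 16 timesteps, 16 features per timestep.
numExamples, timesteps, inputDims = 1000, 16, 16
features = np.random.rand(numExamples, timesteps, inputDims)
labels = np.random.randint(0, 10, size=(numExamples, 1)).astype(float)  # 0 .. N-1

# Flatten across the timestep dimension and prepend the label column.
flat = features.reshape(numExamples, timesteps * inputDims)
data = np.concatenate([labels, flat], axis=1)

np.random.shuffle(data)      # the training routine does not shuffle internally
np.save('train.npy', data)   # and similarly for test.npy
```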

**Tested With:** Tensorflow >1.6 with Python 2 and Python 3

## Download and clean up sample dataset

We will be testing out the validation of the code by using the USPS dataset.
The download and cleanup of the dataset to match the above-mentioned format is
done by the script [fetch_usps.py](fetch_usps.py) and
[process_usps.py](process_usps.py)

```
python fetch_usps.py
python process_usps.py
```


## Sample command for FastCells on USPS10
The following sample run on usps10 should validate your library:

Note: even though usps10 is not a time-series dataset, it can be treated as one in which each row of the image arrives at a single timestep.
So the number of timesteps = 16 and inputDims = 16.

```bash
python fastcell_example.py -dir usps10/ -id 16 -hd 32
```
This command should produce final output that reads roughly like the following (exact numbers may differ across versions):

```
Maximum Test accuracy at compressed model size(including early stopping): 0.9407075 at Epoch: 262
Final Test Accuracy: 0.93721974
Non-Zeros: 1932 Model Size: 7.546875 KB hasSparse: False
```
The `usps10/` directory will now have a consolidated results file called `FastRNNResults.txt` or `FastGRNNResults.txt`, depending on the choice of the RNN cell, along with
a directory `FastRNNResults` or `FastGRNNResults` containing the corresponding models from each run of the code on the `usps10` dataset.

## Byte Quantization (Q) for model compression
If you wish to quantize the generated model to byte-quantized integers, use `quantizeFastModels.py`. Usage instructions:

```
python quantizeFastModels.py -h
```

This will generate quantized models, with a `q` prefixed to every param name, stored in a new directory `QuantizedFastModel` inside the model directory.
These quantized models can then be used further on edge devices.
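
As a rough sketch of how the quantized parameters relate to the original floats (the directory layout and `paramScaleFactor.npy` follow what `quantizeFastModels.py` saves; `qW.npy` assumes the trained model dumped a full-rank `W.npy`, and the timestamped path is a placeholder):

```python
import numpy as np

# Placeholder path; the actual run directory is timestamped by the trainer.
quantDir = 'usps10/FastGRNNResults/<timestamp>/QuantizedFastModel'

scale = np.load(quantDir + '/paramScaleFactor.npy')   # integer scale saved by the script
qW = np.load(quantDir + '/qW.npy')                     # quantized RNN parameter

# Each quantized matrix stores round(scale * W) clipped to the chosen integer
# range, so dividing by the scale recovers an approximation of the float weights.
W_approx = qW.astype(np.float32) / scale
```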

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT license.
1,557 changes: 1,557 additions & 0 deletions tf2.0/examples/FastCells/fastcell_example.ipynb

Large diffs are not rendered by default.

99 changes: 99 additions & 0 deletions tf2.0/examples/FastCells/fastcell_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import helpermethods
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import sys

from edgeml.trainer.fastTrainer import FastTrainer
from edgeml.graph.rnn import FastGRNNCell
from edgeml.graph.rnn import FastRNNCell
from edgeml.graph.rnn import UGRNNLRCell
from edgeml.graph.rnn import GRULRCell
from edgeml.graph.rnn import LSTMLRCell


def main():
# Fixing seeds for reproducibility
tf.set_random_seed(42)
np.random.seed(42)

# Hyper Param pre-processing
args = helpermethods.getArgs()

dataDir = args.data_dir
cell = args.cell
inputDims = args.input_dim
hiddenDims = args.hidden_dim

totalEpochs = args.epochs
learningRate = args.learning_rate
outFile = args.output_file
batchSize = args.batch_size
decayStep = args.decay_step
decayRate = args.decay_rate

wRank = args.wRank
uRank = args.uRank

sW = args.sW
sU = args.sU

update_non_linearity = args.update_nl
gate_non_linearity = args.gate_nl

(dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest,
mean, std) = helpermethods.preProcessData(dataDir)

assert dataDimension % inputDims == 0, "Infeasible per step input, " + \
"Timesteps have to be integer"

X = tf.placeholder(
"float", [None, int(dataDimension / inputDims), inputDims])
Y = tf.placeholder("float", [None, numClasses])

currDir = helpermethods.createTimeStampDir(dataDir, cell)

helpermethods.dumpCommand(sys.argv, currDir)
helpermethods.saveMeanStd(mean, std, currDir)

if cell == "FastGRNN":
FastCell = FastGRNNCell(hiddenDims,
gate_non_linearity=gate_non_linearity,
update_non_linearity=update_non_linearity,
wRank=wRank, uRank=uRank)
elif cell == "FastRNN":
FastCell = FastRNNCell(hiddenDims,
update_non_linearity=update_non_linearity,
wRank=wRank, uRank=uRank)
elif cell == "UGRNN":
FastCell = UGRNNLRCell(hiddenDims,
update_non_linearity=update_non_linearity,
wRank=wRank, uRank=uRank)
elif cell == "GRU":
FastCell = GRULRCell(hiddenDims,
update_non_linearity=update_non_linearity,
wRank=wRank, uRank=uRank)
elif cell == "LSTM":
FastCell = LSTMLRCell(hiddenDims,
update_non_linearity=update_non_linearity,
wRank=wRank, uRank=uRank)
else:
sys.exit('Exiting: No Such Cell as ' + cell)

FastCellTrainer = FastTrainer(
FastCell, X, Y, sW=sW, sU=sU,
learningRate=learningRate, outFile=outFile)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

FastCellTrainer.train(batchSize, totalEpochs, sess, Xtrain, Xtest,
Ytrain, Ytest, decayStep, decayRate,
dataDir, currDir)


if __name__ == '__main__':
main()
66 changes: 66 additions & 0 deletions tf2.0/examples/FastCells/fetch_usps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
#
# Setting up the USPS Data.

import bz2
import os
import subprocess
import sys

import requests
import numpy as np
from sklearn.datasets import load_svmlight_file
from helpermethods import download_file, decompress



def downloadData(workingDir, downloadDir, linkTrain, linkTest):
path = workingDir + '/' + downloadDir
path = os.path.abspath(path)
try:
os.makedirs(path, exist_ok=True)
except OSError:
print("Could not create %s. Make sure the path does" % path)
print("not already exist and you have permissions to create it.")
return False

training_data_bz2 = download_file(linkTrain, path)
test_data_bz2 = download_file(linkTest, path)

training_data = decompress(training_data_bz2)
test_data = decompress(test_data_bz2)

train = os.path.join(path, "train.txt")
test = os.path.join(path, "test.txt")
if os.path.isfile(train):
os.remove(train)
if os.path.isfile(test):
os.remove(test)

os.rename(training_data, train)
os.rename(test_data, test)
os.remove(training_data_bz2)
os.remove(test_data_bz2)
return True

if __name__ == '__main__':
workingDir = './'
downloadDir = 'usps10'
linkTrain = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2'
linkTest = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2'
failureMsg = '''
Download Failed!
To manually perform the download
\t1. Create a new empty directory named `usps10`.
\t2. Download the data from the following links into the usps10 directory.
\t\tTest: %s
\t\tTrain: %s
\t3. Extract the downloaded files.
\t4. Rename `usps` to `train.txt` and,
    \t5. Rename `usps.t` to `test.txt`
''' % (linkTrain, linkTest)

if not downloadData(workingDir, downloadDir, linkTrain, linkTest):
exit(failureMsg)
print("Done: see ", downloadDir)
273 changes: 273 additions & 0 deletions tf2.0/examples/FastCells/helpermethods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

'''
Functions to check sanity of input arguments
for the example script.
'''
import argparse
import bz2
import datetime
import json
import os

import numpy as np
import requests


def decompress(filepath):
print("extracting: ", filepath)
zipfile = bz2.BZ2File(filepath) # open the file
data = zipfile.read() # get the decompressed data
newfilepath = os.path.splitext(filepath)[0] # assuming the filepath ends with .bz2
with open(newfilepath, 'wb') as f:
        f.write(data)  # write an uncompressed file
return newfilepath


def download_file(url, local_folder=None):
"""Downloads file pointed to by `url`.
If `local_folder` is not supplied, downloads to the current folder.
"""
filename = os.path.basename(url)
if local_folder:
filename = os.path.join(local_folder, filename)

# Download the file
print("Downloading: " + url)
response = requests.get(url, stream=True)
if response.status_code != 200:
raise Exception("download file failed with status code: %d, fetching url '%s'" % (response.status_code, url))

# Write the file to disk
with open(filename, "wb") as handle:
handle.write(response.content)
return filename


def checkIntPos(value):
ivalue = int(value)
if ivalue <= 0:
raise argparse.ArgumentTypeError(
"%s is an invalid positive int value" % value)
return ivalue


def checkIntNneg(value):
ivalue = int(value)
if ivalue < 0:
raise argparse.ArgumentTypeError(
"%s is an invalid non-neg int value" % value)
return ivalue


def checkFloatNneg(value):
fvalue = float(value)
if fvalue < 0:
raise argparse.ArgumentTypeError(
"%s is an invalid non-neg float value" % value)
return fvalue


def checkFloatPos(value):
fvalue = float(value)
if fvalue <= 0:
raise argparse.ArgumentTypeError(
"%s is an invalid positive float value" % value)
return fvalue


def getArgs():
'''
Function to parse arguments for FastCells
'''
parser = argparse.ArgumentParser(
description='HyperParams for Fast(G)RNN')
parser.add_argument('-dir', '--data-dir', required=True,
                        help='Data directory containing ' +
                        'train.npy and test.npy')

parser.add_argument('-c', '--cell', type=str, default="FastGRNN",
help='Choose between [FastGRNN, FastRNN, UGRNN' +
', GRU, LSTM], default: FastGRNN')

parser.add_argument('-id', '--input-dim', type=checkIntNneg, required=True,
help='Input Dimension of RNN, each timestep will ' +
'feed input-dim features to RNN. ' +
'Total Feature length = Input Dim * Total Timestep')
parser.add_argument('-hd', '--hidden-dim', type=checkIntNneg,
required=True, help='Hidden Dimension of RNN')

parser.add_argument('-e', '--epochs', type=checkIntPos, default=300,
help='Total Epochs (default: 300 try:[100, 150, 600])')
parser.add_argument('-b', '--batch-size', type=checkIntPos, default=100,
help='Batch Size to be used (default: 100)')
parser.add_argument('-lr', '--learning-rate', type=checkFloatPos,
default=0.01, help='Initial Learning rate for ' +
'Adam Optimizer (default: 0.01)')

parser.add_argument('-rW', '--wRank', type=checkIntPos, default=None,
help='Rank for the low-rank parameterisation of W, ' +
'None => Full Rank')
parser.add_argument('-rU', '--uRank', type=checkIntPos, default=None,
help='Rank for the low-rank parameterisation of U, ' +
'None => Full Rank')

parser.add_argument('-sW', type=checkFloatPos, default=1.0,
help='Sparsity for predictor parameter W(and both ' +
'W1 and W2 in low-rank) ' +
'(default: 1.0(Dense) try: [0.1, 0.2, 0.3])')
parser.add_argument('-sU', type=checkFloatPos, default=1.0,
help='Sparsity for predictor parameter U(and both ' +
'U1 and U2 in low-rank) ' +
'(default: 1.0(Dense) try: [0.1, 0.2, 0.3])')

parser.add_argument('-unl', '--update-nl', type=str, default="tanh",
help='Update non linearity. Choose between ' +
'[tanh, sigmoid, relu, quantTanh, quantSigm]. ' +
'default => tanh. Can add more in edgeml/graph/rnn.py')
parser.add_argument('-gnl', '--gate-nl', type=str, default="sigmoid",
help='Gate non linearity. Choose between ' +
'[tanh, sigmoid, relu, quantTanh, quantSigm]. ' +
'default => sigmoid. Can add more in ' +
'edgeml/graph/rnn.py. Only Applicable to FastGRNN')

parser.add_argument('-dS', '--decay-step', type=checkIntPos, default=200,
help='The interval (in epochs) after which the ' +
'learning rate should decay. ' +
'Default is 200 for 300 epochs')

parser.add_argument('-dR', '--decay-rate', type=checkFloatPos, default=0.1,
help='The factor by which learning rate ' +
'should decay after each interval. Default 0.1')

parser.add_argument('-oF', '--output-file', default=None,
help='Output file for dumping the program output, ' +
'(default: stdout)')

return parser.parse_args()


def getQuantArgs():
'''
Function to parse arguments for Model Quantisation
'''
parser = argparse.ArgumentParser(
description='Arguments for quantizing Fast models. ' +
'Works only for piece-wise linear non-linearities, ' +
'like relu, quantTanh, quantSigm (check rnn.py for the definitions)')
parser.add_argument('-dir', '--model-dir', required=True,
                        help='model directory containing ' +
                        '*.npy weight files dumped from the trained model')
parser.add_argument('-m', '--max-val', type=checkIntNneg, default=127,
help='this represents the maximum possible value ' +
'in model, essentially the byte complexity, ' +
'127=> 1 byte is default')
parser.add_argument('-s', '--scalar-scale', type=checkIntNneg,
default=1000, help='maximum granularity/decimals ' +
                        'you wish to get when quantising simple scalars ' +
'involved. Default is 1000')

return parser.parse_args()


def createTimeStampDir(dataDir, cell):
'''
    Creates a directory with the current timestamp as its name
'''
if os.path.isdir(os.path.join(dataDir, str(cell) + 'Results')) is False:
try:
os.mkdir(os.path.join(dataDir, str(cell) + 'Results'))
except OSError:
print("Creation of the directory %s failed" %
os.path.join(dataDir, str(cell) + 'Results'))

currDir = os.path.join(str(cell) + 'Results',
datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
if os.path.isdir(os.path.join(dataDir, currDir)) is False:
try:
os.mkdir(os.path.join(dataDir, currDir))
except OSError:
print("Creation of the directory %s failed" %
os.path.join(dataDir, currDir))
else:
return (os.path.join(dataDir, currDir))
return None


def preProcessData(dataDir):
'''
Function to pre-process input data
Expects a .npy file of form [lbl feats] for each datapoint,
feats is timesteps*inputDims, flattened across timestep dimension.
So input of 1st timestep followed by second and so on.
Outputs train and test set datapoints
dataDimension, numClasses are inferred directly
'''
train = np.load(os.path.join(dataDir, 'train.npy'))
test = np.load(os.path.join(dataDir, 'test.npy'))

dataDimension = int(train.shape[1]) - 1

Xtrain = train[:, 1:dataDimension + 1]
Ytrain_ = train[:, 0]
numClasses = max(Ytrain_) - min(Ytrain_) + 1

Xtest = test[:, 1:dataDimension + 1]
Ytest_ = test[:, 0]

numClasses = int(max(numClasses, max(Ytest_) - min(Ytest_) + 1))

# Mean Var Normalisation
mean = np.mean(Xtrain, 0)
std = np.std(Xtrain, 0)
std[std[:] < 0.000001] = 1
Xtrain = (Xtrain - mean) / std

Xtest = (Xtest - mean) / std
# End Mean Var normalisation

lab = Ytrain_.astype('uint8')
lab = np.array(lab) - min(lab)

lab_ = np.zeros((Xtrain.shape[0], numClasses))
lab_[np.arange(Xtrain.shape[0]), lab] = 1
Ytrain = lab_

lab = Ytest_.astype('uint8')
lab = np.array(lab) - min(lab)

lab_ = np.zeros((Xtest.shape[0], numClasses))
lab_[np.arange(Xtest.shape[0]), lab] = 1
Ytest = lab_

return dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest, mean, std


def dumpCommand(argList, currDir):
    '''
    Dumps the current command to a file for further use
    '''
    commandFile = open(os.path.join(currDir, 'command.txt'), 'w')
    command = "python"

    command = command + " " + ' '.join(argList)
    commandFile.write(command)

commandFile.flush()
commandFile.close()


def saveMeanStd(mean, std, currDir):
'''
Function to save Mean and Std vectors
'''
np.save(os.path.join(currDir, 'mean.npy'), mean)
np.save(os.path.join(currDir, 'std.npy'), std)


def saveJSon(data, filename):
with open(filename, "w") as outfile:
json.dump(data, outfile, indent=2)
41 changes: 41 additions & 0 deletions tf2.0/examples/FastCells/process_usps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
#
# Processing the USPS Data. It is assumed that the data is already
# downloaded.

import subprocess
import os
import numpy as np
from sklearn.datasets import load_svmlight_file
import sys

def processData(workingDir, downloadDir):
def loadLibSVMFile(file):
data = load_svmlight_file(file)
features = data[0]
labels = data[1]
retMat = np.zeros([features.shape[0], features.shape[1] + 1])
retMat[:, 0] = labels
retMat[:, 1:] = features.todense()
return retMat

path = workingDir + '/' + downloadDir
path = os.path.abspath(path)
trf = path + '/train.txt'
tsf = path + '/test.txt'
assert os.path.isfile(trf), 'File not found: %s' % trf
assert os.path.isfile(tsf), 'File not found: %s' % tsf
train = loadLibSVMFile(trf)
test = loadLibSVMFile(tsf)
np.save(path + '/train.npy', train)
np.save(path + '/test.npy', test)

if __name__ == '__main__':
# Configuration
workingDir = './'
downloadDir = 'usps10'
# End config
print("Processing data")
processData(workingDir, downloadDir)
print("Done")
135 changes: 135 additions & 0 deletions tf2.0/examples/FastCells/quantizeFastModels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import helpermethods
import os
import numpy as np


def sigmoid(x):
return 1 / (1 + np.exp(-x))


def min_max(A, name):
print(name + " has max: " + str(np.max(A)) + " min: " + str(np.min(A)))
return np.max([np.abs(np.max(A)), np.abs(np.min(A))])


def quantizeFastModels(modelDir, maxValue=127, scalarScaleFactor=1000):
ls = os.listdir(modelDir)
paramNameList = []
paramWeightList = []
paramLimitList = []

classifierNameList = []
classifierWeightList = []
classifierLimitList = []

scalarNameList = []
scalarWeightList = []

for file in ls:
if file.endswith("npy"):
if file.startswith("W"):
paramNameList.append(file)
temp = np.load(os.path.join(modelDir, file))
paramWeightList.append(temp)
paramLimitList.append(min_max(temp, file))
elif file.startswith("U"):
paramNameList.append(file)
temp = np.load(os.path.join(modelDir, file))
paramWeightList.append(temp)
paramLimitList.append(min_max(temp, file))
elif file.startswith("B"):
paramNameList.append(file)
temp = np.load(os.path.join(modelDir, file))
paramWeightList.append(temp)
paramLimitList.append(min_max(temp, file))
elif file.startswith("FC"):
classifierNameList.append(file)
temp = np.load(os.path.join(modelDir, file))
classifierWeightList.append(temp)
classifierLimitList.append(min_max(temp, file))
elif file.startswith("mean") or file.startswith("std"):
continue
else:
scalarNameList.append(file)
scalarWeightList.append(np.load(os.path.join(modelDir, file)))

paramLimit = np.max(paramLimitList)
classifierLimit = np.max(classifierLimitList)

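    # One symmetric scale per parameter group: the largest absolute entry is
    # mapped onto the signed integer range [-(maxValue + 1), maxValue]. The RNN
    # parameter scale is rounded to an integer; the classifier scale is kept as
    # a float.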
paramScaleFactor = np.round((2.0 * maxValue + 1.0) / (2.0 * paramLimit))
classifierScaleFactor = (2.0 * maxValue + 1.0) / (2.0 * classifierLimit)

quantParamWeights = []
for param in paramWeightList:
temp = np.round(paramScaleFactor * param)
temp[temp[:] > maxValue] = maxValue
temp[temp[:] < -maxValue] = -1 * (maxValue + 1)

if maxValue <= 127:
temp = temp.astype('int8')
elif maxValue <= 32767:
temp = temp.astype('int16')
else:
temp = temp.astype('int32')

quantParamWeights.append(temp)

quantClassifierWeights = []
for param in classifierWeightList:
temp = np.round(classifierScaleFactor * param)
temp[temp[:] > maxValue] = maxValue
temp[temp[:] < -maxValue] = -1 * (maxValue + 1)

if maxValue <= 127:
temp = temp.astype('int8')
elif maxValue <= 32767:
temp = temp.astype('int16')
else:
temp = temp.astype('int32')

quantClassifierWeights.append(temp)

quantScalarWeights = []
for scalar in scalarWeightList:
quantScalarWeights.append(
np.round(scalarScaleFactor * sigmoid(scalar)).astype('int32'))

quantModelDir = os.path.join(modelDir, 'QuantizedFastModel')
if not os.path.isdir(quantModelDir):
try:
os.makedirs(quantModelDir, exist_ok=True)
except OSError:
print("Creation of the directory %s failed" % quantModelDir)

np.save(os.path.join(quantModelDir, "paramScaleFactor.npy"),
paramScaleFactor.astype('int32'))
np.save(os.path.join(quantModelDir, "classifierScaleFactor.npy"),
classifierScaleFactor)
np.save(os.path.join(quantModelDir, "scalarScaleFactor"), scalarScaleFactor)

for i in range(0, len(scalarNameList)):
np.save(os.path.join(quantModelDir, "q" +
scalarNameList[i]), quantScalarWeights[i])

for i in range(len(classifierNameList)):
np.save(os.path.join(quantModelDir, "q" +
classifierNameList[i]), quantClassifierWeights[i])

for i in range(len(paramNameList)):
np.save(os.path.join(quantModelDir, "q" + paramNameList[i]),
quantParamWeights[i])

print("\n\nQuantized Model Dir: " + quantModelDir)


def main():
args = helpermethods.getQuantArgs()
quantizeFastModels(args.model_dir, int(
args.max_val), int(args.scalar_scale))


if __name__ == '__main__':
main()
54 changes: 54 additions & 0 deletions tf2.0/examples/ProtoNN/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Tensorflow ProtoNN Examples

This directory includes an example [notebook](protoNN_example.ipynb) and a
command line execution script of ProtoNN developed as part of EdgeML. The
example is based on the USPS dataset.

`edgeml.graph.protoNN` implements the ProtoNN prediction graph in Tensorflow.
The training routine for ProtoNN is decoupled from the forward graph to
facilitate a plug and play behaviour wherein ProtoNN can be combined with or
used as a final layer classifier for other architectures (RNNs, CNNs). The
training routine is implemented in `edgeml.trainer.protoNNTrainer`.

Note that `protoNN_example.py` assumes the data is in a specific format. It is
assumed that train and test data are contained in two files, `train.npy` and
`test.npy`, each containing a 2D numpy array of dimension `[numberOfExamples,
numberOfFeatures + 1]`. The first column of each matrix is assumed to contain
the label. For an N-class problem, the labels are assumed to be integers from 0
through N-1.
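
A minimal sketch of how this layout is consumed (paths and shapes are illustrative; `protoNN_example.py` performs the equivalent split and then one-hot encodes the labels via `helpermethods.to_onehot`):

```python
import numpy as np

train = np.load('usps10/train.npy')    # shape: [numberOfExamples, numberOfFeatures + 1]
x_train, y_train = train[:, 1:], train[:, 0]

numClasses = int(y_train.max()) + 1    # labels assumed to be 0 .. N-1
y_onehot = np.zeros((len(y_train), numClasses))
y_onehot[np.arange(len(y_train)), y_train.astype(int)] = 1
```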

**Tested With:** Tensorflow >1.6 with Python 2 and Python 3

## Fetching Data

The scripts [fetch_usps.py](fetch_usps.py) and [process_usps.py](process_usps.py)
can be used to automatically download the data and to process it into the
required format, respectively.
To run these scripts, please use:

python fetch_usps.py
python process_usps.py


## Running the ProtoNN execution script

Along with the example notebook, a command line execution script for ProtoNN is
provided in `protoNN_example.py`. After the USPS data has been setup, this
script can be used with the following command:

```
python protoNN_example.py \
--data-dir ./usps10 \
--projection-dim 60 \
--num-prototypes 80 \
--gamma 0.0015 \
--learning-rate 0.1 \
--epochs 200 \
--val-step 10 \
--output-dir ./
```

You can expect a test set accuracy of about 92.5%.

Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT license.
64 changes: 64 additions & 0 deletions tf2.0/examples/ProtoNN/fetch_usps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
#
# Setting up the USPS Data.

import subprocess
import os
import numpy as np
from sklearn.datasets import load_svmlight_file
import sys

def downloadData(workingDir, downloadDir, linkTrain, linkTest):
def runcommand(command):
p = subprocess.Popen(command.split(), stdout=subprocess.PIPE)
output, error = p.communicate()
assert(p.returncode == 0), 'Command failed: %s' % command

path = workingDir + '/' + downloadDir
path = os.path.abspath(path)
try:
os.mkdir(path)
except OSError:
print("Could not create %s. Make sure the path does" % path)
print("not already exist and you have permisions to create it.")
return False
cwd = os.getcwd()
os.chdir(path)
print("Downloading data")
command = 'wget %s' % linkTrain
runcommand(command)
command = 'wget %s' % linkTest
runcommand(command)
print("Extracting data")
command = 'bzip2 -d usps.bz2'
runcommand(command)
command = 'bzip2 -d usps.t.bz2'
runcommand(command)
command = 'mv usps train.txt'
runcommand(command)
command = 'mv usps.t test.txt'
runcommand(command)
os.chdir(cwd)
return True

if __name__ == '__main__':
workingDir = './'
downloadDir = 'usps10'
linkTrain = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2'
linkTest = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2'
failureMsg = '''
Download Failed!
To manually perform the download
\t1. Create a new empty directory named `usps10`.
\t2. Download the data from the following links into the usps10 directory.
\t\tTest: %s
\t\tTrain: %s
\t3. Extract the downloaded files.
\t4. Rename `usps` to `train.txt` and,
    \t5. Rename `usps.t` to `test.txt`
''' % (linkTrain, linkTest)

if not downloadData(workingDir, downloadDir, linkTrain, linkTest):
exit(failureMsg)
print("Done")
206 changes: 206 additions & 0 deletions tf2.0/examples/ProtoNN/helpermethods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

from __future__ import print_function
import sys
import os
import numpy as np
import tensorflow as tf
import edgeml.utils as utils
import argparse


def getModelSize(matrixList, sparcityList, expected=True, bytesPerVar=4):
'''
expected: Expected size according to the parameters set. The number of
        zeros could actually be more than what is required to satisfy the
sparsity constraint.
'''
nnzList, sizeList, isSparseList = [], [], []
hasSparse = False
for i in range(len(matrixList)):
A, s = matrixList[i], sparcityList[i]
assert A.ndim == 2
assert s >= 0
assert s <= 1
nnz, size, sparse = utils.countnnZ(A, s, bytesPerVar=bytesPerVar)
nnzList.append(nnz)
sizeList.append(size)
hasSparse = (hasSparse or sparse)

totalnnZ = np.sum(nnzList)
totalSize = np.sum(sizeList)
if expected:
return totalnnZ, totalSize, hasSparse
numNonZero = 0
totalSize = 0
hasSparse = False
for i in range(len(matrixList)):
A, s = matrixList[i], sparcityList[i]
numNonZero_ = np.count_nonzero(A)
numNonZero += numNonZero_
hasSparse = (hasSparse or (s < 0.5))
if s <= 0.5:
totalSize += numNonZero_ * 2 * bytesPerVar
else:
totalSize += A.size * bytesPerVar
return numNonZero, totalSize, hasSparse


def getGamma(gammaInit, projectionDim, dataDim, numPrototypes, x_train):
if gammaInit is None:
print("Using median heuristic to estimate gamma.")
gamma, W, B = utils.medianHeuristic(x_train, projectionDim,
numPrototypes)
print("Gamma estimate is: %f" % gamma)
return W, B, gamma
return None, None, gammaInit

def to_onehot(y, numClasses, minlabel=None):
    '''
    If the labels in y do not include the minimum label, provide it explicitly
    via minlabel.
    '''
lab = y.astype('uint8')
if minlabel is None:
minlabel = np.min(lab)
minlabel = int(minlabel)
lab = np.array(lab) - minlabel
lab_ = np.zeros((y.shape[0], numClasses))
lab_[np.arange(y.shape[0]), lab] = 1
return lab_

def preprocessData(train, test):
'''
Loads data from the dataDir and does some initial preprocessing
steps. Data is assumed to be contained in two files,
train.npy and test.npy. Each containing a 2D numpy array of dimension
[numberOfExamples, numberOfFeatures + 1]. The first column of each
matrix is assumed to contain label information.
For an N-Class problem, we assume the labels are integers from 0 through
N-1.
'''
dataDimension = int(train.shape[1]) - 1
x_train = train[:, 1:dataDimension + 1]
y_train_ = train[:, 0]
x_test = test[:, 1:dataDimension + 1]
y_test_ = test[:, 0]

numClasses = max(y_train_) - min(y_train_) + 1
numClasses = max(numClasses, max(y_test_) - min(y_test_) + 1)
numClasses = int(numClasses)

# mean-var
mean = np.mean(x_train, 0)
std = np.std(x_train, 0)
std[std[:] < 0.000001] = 1
x_train = (x_train - mean) / std
x_test = (x_test - mean) / std

# one hot y-train
lab = y_train_.astype('uint8')
lab = np.array(lab) - min(lab)
lab_ = np.zeros((x_train.shape[0], numClasses))
lab_[np.arange(x_train.shape[0]), lab] = 1
y_train = lab_

# one hot y-test
lab = y_test_.astype('uint8')
lab = np.array(lab) - min(lab)
lab_ = np.zeros((x_test.shape[0], numClasses))
lab_[np.arange(x_test.shape[0]), lab] = 1
y_test = lab_

return dataDimension, numClasses, x_train, y_train, x_test, y_test



def getProtoNNArgs():
def checkIntPos(value):
ivalue = int(value)
if ivalue <= 0:
raise argparse.ArgumentTypeError(
"%s is an invalid positive int value" % value)
return ivalue

def checkIntNneg(value):
ivalue = int(value)
if ivalue < 0:
raise argparse.ArgumentTypeError(
"%s is an invalid non-neg int value" % value)
return ivalue

def checkFloatNneg(value):
fvalue = float(value)
if fvalue < 0:
raise argparse.ArgumentTypeError(
"%s is an invalid non-neg float value" % value)
return fvalue

def checkFloatPos(value):
fvalue = float(value)
if fvalue <= 0:
raise argparse.ArgumentTypeError(
"%s is an invalid positive float value" % value)
return fvalue

'''
Parse protoNN commandline arguments
'''
parser = argparse.ArgumentParser(
description='Hyperparameters for ProtoNN Algorithm')

msg = 'Data directory containing train and test data. The '
msg += 'data is assumed to be saved as 2-D numpy matrices with '
msg += 'names `train.npy` and `test.npy`, of dimensions\n'
msg += '\t[numberOfInstances, numberOfFeatures + 1].\n'
msg += 'The first column of each file is assumed to contain label information.'
    msg += ' For an N-class problem, labels are assumed to be integers from 0 to'
msg += ' N-1 (inclusive).'
parser.add_argument('-d', '--data-dir', required=True, help=msg)
parser.add_argument('-l', '--projection-dim', type=checkIntPos, default=10,
help='Projection Dimension.')
parser.add_argument('-p', '--num-prototypes', type=checkIntPos, default=20,
help='Number of prototypes.')
parser.add_argument('-g', '--gamma', type=checkFloatPos, default=None,
help='Gamma for Gaussian kernel. If not provided, ' +
'median heuristic will be used to estimate gamma.')

parser.add_argument('-e', '--epochs', type=checkIntPos, default=100,
help='Total training epochs.')
parser.add_argument('-b', '--batch-size', type=checkIntPos, default=32,
help='Batch size for each pass.')
parser.add_argument('-r', '--learning-rate', type=checkFloatPos,
default=0.001,
help='Initial Learning rate for ADAM Optimizer.')

parser.add_argument('-rW', type=float, default=0.000,
help='Coefficient for l2 regularizer for predictor' +
' parameter W ' + '(default = 0.0).')
parser.add_argument('-rB', type=float, default=0.00,
help='Coefficient for l2 regularizer for predictor' +
' parameter B ' + '(default = 0.0).')
parser.add_argument('-rZ', type=float, default=0.00,
                        help='Coefficient for l2 regularizer for predictor ' +
                        'parameter Z ' +
'(default = 0.0).')

parser.add_argument('-sW', type=float, default=1.000,
help='Sparsity constraint for predictor parameter W ' +
'(default = 1.0, i.e. dense matrix).')
parser.add_argument('-sB', type=float, default=1.00,
help='Sparsity constraint for predictor parameter B ' +
'(default = 1.0, i.e. dense matrix).')
parser.add_argument('-sZ', type=float, default=1.00,
help='Sparsity constraint for predictor parameter Z ' +
'(default = 1.0, i.e. dense matrix).')
parser.add_argument('-pS', '--print-step', type=int, default=200,
help='The number of update steps between print ' +
'calls to console.')
parser.add_argument('-vS', '--val-step', type=int, default=3,
                        help='The number of epochs between validation ' +
                        'performance evaluation.')
parser.add_argument('-o', '--output-dir', type=str, default='./',
help='Output directory to dump model matrices.')
return parser.parse_args()
51 changes: 51 additions & 0 deletions tf2.0/examples/ProtoNN/process_usps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
#
# Processing the USPS Data. It is assumed that the data is already
# downloaded.

import subprocess
import os
import numpy as np
from sklearn.datasets import load_svmlight_file
import sys
from helpermethods import preprocessData

def processData(workingDir, downloadDir):
def loadLibSVMFile(file):
data = load_svmlight_file(file)
features = data[0]
labels = data[1]
retMat = np.zeros([features.shape[0], features.shape[1] + 1])
retMat[:, 0] = labels
retMat[:, 1:] = features.todense()
return retMat

path = workingDir + '/' + downloadDir
path = os.path.abspath(path)
trf = path + '/train.txt'
tsf = path + '/test.txt'
assert os.path.isfile(trf), 'File not found: %s' % trf
assert os.path.isfile(tsf), 'File not found: %s' % tsf
train = loadLibSVMFile(trf)
test = loadLibSVMFile(tsf)
np.save(path + '/train_unnormalized.npy', train)
np.save(path + '/test_unnormalized.npy', test)
_, _, x_train, y_train, x_test, y_test = preprocessData(train, test)

y_ = np.expand_dims(np.argmax(y_train, axis=1), axis=1)
train = np.concatenate([y_, x_train], axis=1)
np.save(path + '/train.npy', train)
y_ = np.expand_dims(np.argmax(y_test, axis=1), axis=1)
test = np.concatenate([y_, x_test], axis=1)
np.save(path + '/test.npy', test)


if __name__ == '__main__':
# Configuration
workingDir = './'
downloadDir = 'usps10'
# End config
print("Processing data")
processData(workingDir, downloadDir)
print("Done")
449 changes: 449 additions & 0 deletions tf2.0/examples/ProtoNN/protoNN_example.ipynb

Large diffs are not rendered by default.

88 changes: 88 additions & 0 deletions tf2.0/examples/ProtoNN/protoNN_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

from __future__ import print_function
import sys
import os
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from edgeml.trainer.protoNNTrainer import ProtoNNTrainer
from edgeml.graph.protoNN import ProtoNN
import edgeml.utils as utils
import helpermethods as helper


def main():
config = helper.getProtoNNArgs()
# Get hyper parameters
DATA_DIR = config.data_dir
PROJECTION_DIM = config.projection_dim
NUM_PROTOTYPES = config.num_prototypes
REG_W = config.rW
REG_B = config.rB
REG_Z = config.rZ
SPAR_W = config.sW
SPAR_B = config.sB
SPAR_Z = config.sZ
LEARNING_RATE = config.learning_rate
NUM_EPOCHS = config.epochs
BATCH_SIZE = config.batch_size
PRINT_STEP = config.print_step
VAL_STEP = config.val_step
OUT_DIR = config.output_dir

# Load data
train = np.load(DATA_DIR + '/train.npy')
test = np.load(DATA_DIR + '/test.npy')
x_train, y_train = train[:, 1:], train[:, 0]
x_test, y_test = test[:, 1:], test[:, 0]
# Convert y to one-hot
minval = min(min(y_train), min(y_test))
numClasses = max(y_train) - min(y_train) + 1
numClasses = max(numClasses, max(y_test) - min(y_test) + 1)
numClasses = int(numClasses)
y_train = helper.to_onehot(y_train, numClasses, minlabel=minval)
y_test = helper.to_onehot(y_test, numClasses, minlabel=minval)
dataDimension = x_train.shape[1]

W, B, gamma = helper.getGamma(config.gamma, PROJECTION_DIM, dataDimension,
NUM_PROTOTYPES, x_train)

# Setup input and train protoNN
X = tf.placeholder(tf.float32, [None, dataDimension], name='X')
Y = tf.placeholder(tf.float32, [None, numClasses], name='Y')
protoNN = ProtoNN(dataDimension, PROJECTION_DIM,
NUM_PROTOTYPES, numClasses,
gamma, W=W, B=B)
trainer = ProtoNNTrainer(protoNN, REG_W, REG_B, REG_Z,
SPAR_W, SPAR_B, SPAR_Z,
LEARNING_RATE, X, Y, lossType='xentropy')
sess = tf.Session()
trainer.train(BATCH_SIZE, NUM_EPOCHS, sess, x_train, x_test,
y_train, y_test, printStep=PRINT_STEP, valStep=VAL_STEP)

# Print some summary metrics
acc = sess.run(protoNN.accuracy, feed_dict={X: x_test, Y: y_test})
# W, B, Z are tensorflow graph nodes
W, B, Z, gamma = protoNN.getModelMatrices()
matrixList = sess.run([W, B, Z])
gamma = sess.run(gamma)
sparcityList = [SPAR_W, SPAR_B, SPAR_Z]
nnz, size, sparse = helper.getModelSize(matrixList, sparcityList)
print("Final test accuracy", acc)
print("Model size constraint (Bytes): ", size)
print("Number of non-zeros: ", nnz)
nnz, size, sparse = helper.getModelSize(matrixList, sparcityList,
expected=False)
print("Actual model size: ", size)
print("Actual non-zeros: ", nnz)
print("Saving model matrices to: ", OUT_DIR)
np.save(OUT_DIR + '/W.npy', matrixList[0])
np.save(OUT_DIR + '/B.npy', matrixList[1])
np.save(OUT_DIR + '/Z.npy', matrixList[2])
np.save(OUT_DIR + '/gamma.npy', gamma)


if __name__ == '__main__':
main()
7 changes: 7 additions & 0 deletions tf2.0/requirements-cpu.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
jupyter==1.0.0
numpy==1.14.5
pandas==0.23.4
scikit-learn==0.19.2
scipy==1.1.0
tensorflow==2.0.0-alpha0
requests
7 changes: 7 additions & 0 deletions tf2.0/requirements-gpu.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
jupyter==1.0.0
numpy==1.14.5
pandas==0.23.4
scikit-learn==0.19.2
scipy==1.1.0
tensorflow-gpu==2.0.0-alpha0
requests
9 changes: 9 additions & 0 deletions tf2.0/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from distutils.core import setup

setup(
name='edgeml',
version='0.2',
packages=['edgeml', ],
license='MIT License',
long_description=open('../License.txt').read(),
)