diff --git a/Files/mnist_keras_data_handler.py b/Files/mnist_keras_data_handler.py index 4fe79d38..9d1bef94 100644 --- a/Files/mnist_keras_data_handler.py +++ b/Files/mnist_keras_data_handler.py @@ -3,7 +3,7 @@ import pickle import numpy as np -from ibmfl.data.data_handler import DataHandler +from ibm_watson_machine_learning.federated_learning.data_handler import DataHandler logger = logging.getLogger(__name__) @@ -38,14 +38,14 @@ def get_data(self, nb_points=500): logger.info( 'Loaded training data from ' + str(self.train_file_name)) with open(self.train_file_name, 'rb') as f: - (x_train, y_train)= pickle.load(f) + (self.x_train, self.y_train)= pickle.load(f) logger.info( 'Loaded test data from ' + str(self.test_file_name)) with open(self.test_file_name, 'rb') as f: - (x_test, y_test)= pickle.load(f) + (self.x_test, self.y_test)= pickle.load(f) - x_train = x_train / 255.0 - x_test = x_test / 255.0 + self.x_train = self.x_train / 255.0 + self.x_test = self.x_test / 255.0 except Exception: @@ -55,11 +55,11 @@ def get_data(self, nb_points=500): # Add a channels dimension import tensorflow as tf - x_train = x_train[..., tf.newaxis] - x_test = x_test[..., tf.newaxis] + self.x_train = self.x_train[..., tf.newaxis] + self.x_test = self.x_test[..., tf.newaxis] - print('x_train shape:', x_train.shape) - print(x_train.shape[0], 'train samples') - print(x_test.shape[0], 'test samples') + print('self.x_train shape:', self.x_train.shape) + print(self.x_train.shape[0], 'train samples') + print(self.x_test.shape[0], 'test samples') - return (x_train, y_train), (x_test, y_test) + return (self.x_train, self.y_train), (self.x_test, self.y_test)