Precision, Recall and f1 score for multiclass classification #6507

Closed
callicles opened this issue May 4, 2017 · 17 comments

@callicles

callicles commented May 4, 2017

Hi!

Keras: 2.0.4

I recently spent some time trying to build metrics for multi-class classification that output per-class precision, recall and F1 score.

I want a metric that correctly aggregates values across batches and reports a result for the whole training process at per-class granularity.

The way I understand the current behaviour: the functions declared in the metrics argument of compile are called after every batch, and the resulting per-batch estimate is stored in a logs object.

training.py

            for metric in output_metrics:
                if metric == 'accuracy' or metric == 'acc':
                    # custom handling of accuracy
                    # (because of class mode duality)
                    output_shape = self.internal_output_shapes[i]
                    acc_fn = None
                    if (output_shape[-1] == 1 or
                       self.loss_functions[i] == losses.binary_crossentropy):
                        # case: binary accuracy
                        acc_fn = metrics_module.binary_accuracy
                    elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:
                        # case: categorical accuracy with sparse targets
                        acc_fn = metrics_module.sparse_categorical_accuracy
                    else:
                        acc_fn = metrics_module.categorical_accuracy

                    masked_fn = _masked_objective(acc_fn)
                    append_metric(i, 'acc', masked_fn(y_true, y_pred, mask=masks[i]))
                else:
                    metric_fn = metrics_module.get(metric)
                    masked_metric_fn = _masked_objective(metric_fn)
                    metric_result = masked_metric_fn(y_true, y_pred, mask=masks[i])
                    metric_result = {
                        metric_fn.__name__: metric_result
                    }
                    for name, tensor in six.iteritems(metric_result):
                        append_metric(i, name, tensor)

Here there is a call to _masked_objective, which is defined as:

_masked_objective

def _masked_objective(fn):
    """Adds support for masking to an objective function.

    It transforms an objective function `fn(y_true, y_pred)`
    into a cost-masked objective function
    `fn(y_true, y_pred, mask)`.

    # Arguments
        fn: The objective function to wrap,
            with signature `fn(y_true, y_pred)`.

    # Returns
        A function with signature `fn(y_true, y_pred, mask)`.
    """
    def masked(y_true, y_pred, mask=None):
        """Wrapper function.

        # Arguments
            y_true: `y_true` argument of `fn`.
            y_pred: `y_pred` argument of `fn`.
            mask: Mask tensor.

        # Returns
            Scalar tensor.
        """
        # score_array has ndim >= 2
        score_array = fn(y_true, y_pred)
        if mask is not None:
            # Cast the mask to floatX to avoid float64 upcasting in theano
            mask = K.cast(mask, K.floatx())
            # mask should have the same shape as score_array
            score_array *= mask
            #  the loss per batch should be proportional
            #  to the number of unmasked samples.
            score_array /= K.mean(mask)

        return K.mean(score_array)
    return masked

This averages whatever tensor comes out of the metric function.

Here is how I was thinking about implementing precision, recall and F1 score.

I was planning to use the metrics mechanism to emit per-class counts of true positives, false positives and false negatives, accumulate them in the logs, and then compute precision, recall and F1 score within a callback, roughly as in the sketch below.
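Roughly the callback side I have in mind, as a sketch; it assumes the per-batch logs carry raw per-class counts under keys like 'tp_0', 'fp_0', 'fn_0' (the names and num_classes are placeholders), which is exactly what the current averaging prevents:

import numpy as np
from keras.callbacks import Callback

class PerClassF1(Callback):
    def __init__(self, num_classes):
        super(PerClassF1, self).__init__()
        self.num_classes = num_classes

    def on_epoch_begin(self, epoch, logs=None):
        # reset the raw counts at the start of every epoch
        self.tp = np.zeros(self.num_classes)
        self.fp = np.zeros(self.num_classes)
        self.fn = np.zeros(self.num_classes)

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        # accumulate the per-class counts emitted by the metric functions
        for c in range(self.num_classes):
            self.tp[c] += logs.get('tp_%d' % c, 0.)
            self.fp[c] += logs.get('fp_%d' % c, 0.)
            self.fn[c] += logs.get('fn_%d' % c, 0.)

    def on_epoch_end(self, epoch, logs=None):
        precision = self.tp / (self.tp + self.fp + 1e-7)
        recall = self.tp / (self.tp + self.fn + 1e-7)
        f1 = 2. * precision * recall / (precision + recall + 1e-7)
        for c in range(self.num_classes):
            print('class %d: precision %.3f, recall %.3f, f1 %.3f'
                  % (c, precision[c], recall[c], f1[c]))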

The problem with that approach is that the count tensor I output from the metric gets averaged before it ever reaches the callback.

My change request is therefore: could we remove that averaging from the core metrics handling and let callbacks process whatever the metric functions return, however they want?

I really think this is important, since it currently feels a bit like flying blind without per-class metrics for multi-class classification.

I can also contribute code on whatever solution we come up with.

Thank you

@ribx

ribx commented Jun 16, 2017

I tried to do the same thing. Maybe a "callback" added to the "fit" function could be a solution?

@callicles
Author

callicles commented Jun 16, 2017

Let's say you want per-class accuracy.

The way we have hacked it internally is a function that generates an accuracy metric for each class; we pass those as the metrics argument when calling compile.

It is kind of crappy, but it works.
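Roughly like this (a sketch from memory; model and num_classes are placeholders, and it assumes one-hot encoded targets):

from keras import backend as K

def class_accuracy(class_id):
    # accuracy restricted to samples whose true label is class_id
    def acc(y_true, y_pred):
        true_cls = K.argmax(y_true, axis=-1)
        pred_cls = K.argmax(y_pred, axis=-1)
        mask = K.cast(K.equal(true_cls, class_id), K.floatx())
        matches = K.cast(K.equal(true_cls, pred_cls), K.floatx()) * mask
        return K.sum(matches) / (K.sum(mask) + K.epsilon())
    acc.__name__ = 'acc_class_%d' % class_id
    return acc

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=[class_accuracy(c) for c in range(num_classes)])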

@ribx

ribx commented Jun 17, 2017

I added the F1 metrics (note that this works only for binary problems so far!):

        if metric == 'accuracy' or metric == 'acc':
          # custom handling of accuracy
          # (because of class mode duality)
          output_shape = self.internal_output_shapes[i]
          acc_fn = None
          if (output_shape[-1] == 1 or
              self.loss_functions[i] == losses.binary_crossentropy):
            # case: binary accuracy
            acc_fn = metrics_module.binary_accuracy
          elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:
            # case: categorical accuracy with sparse targets
            acc_fn = metrics_module.sparse_categorical_accuracy
          else:
            acc_fn = metrics_module.categorical_accuracy

          masked_fn = _masked_objective(acc_fn)
          append_metric(i, 'acc', masked_fn(y_true, y_pred, mask=masks[i]))
        elif metric in ['f1','f1-score']:
          if (output_shape[-1] == 1 or
              self.loss_functions[i] == losses.binary_crossentropy):

            def true_pos(y_true, y_pred):
              return K.sum(y_true * K.round(y_pred))

            def false_pos(y_true, y_pred):
              # predicted positive, actually negative
              return K.sum((1. - y_true) * K.round(y_pred))

            def false_neg(y_true, y_pred):
              # actually positive, predicted negative
              return K.sum(y_true * (1. - K.round(y_pred)))

            def precision(y_true, y_pred):
              return true_pos(y_true, y_pred) / \
                     (true_pos(y_true, y_pred) + false_pos(y_true, y_pred) + K.epsilon())

            def recall(y_true, y_pred):
              return true_pos(y_true, y_pred) / \
                     (true_pos(y_true, y_pred) + false_neg(y_true, y_pred) + K.epsilon())

            def f1_score(y_true, y_pred):
              p = precision(y_true, y_pred)
              r = recall(y_true, y_pred)
              # harmonic mean, guarded against division by zero
              return 2. * p * r / (p + r + K.epsilon())

            for fn in [precision, recall, f1_score]:
              append_metric(i, fn.__name__, fn(y_true, y_pred))
        else:
          metric_fn = metrics_module.get(metric)
          masked_metric_fn = _masked_objective(metric_fn)
          metric_result = masked_metric_fn(y_true, y_pred, mask=masks[i])
          metric_result = {metric_fn.__name__: metric_result}
          for name, tensor in six.iteritems(metric_result):
            append_metric(i, name, tensor)

@trevorwelch

trevorwelch commented Aug 16, 2017

I use these custom metrics for binary classification in Keras:

def precision(y_true, y_pred):
    # Calculates the precision
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision


def recall(y_true, y_pred):
    # Calculates the recall
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

def fbeta_score(y_true, y_pred, beta=1):
    # Calculates the F score, the weighted harmonic mean of precision and recall.

    if beta < 0:
        raise ValueError('The lowest choosable beta is zero (only precision).')
        
    # If there are no true positives, fix the F score at 0 like sklearn.
    if K.sum(K.round(K.clip(y_true, 0, 1))) == 0:
        return 0

    p = precision(y_true, y_pred)
    r = recall(y_true, y_pred)
    bb = beta ** 2
    fbeta_score = (1 + bb) * (p * r) / (bb * p + r + K.epsilon())
    return fbeta_score

def fmeasure(y_true, y_pred):
    # Calculates the f-measure, the harmonic mean of precision and recall.
    return fbeta_score(y_true, y_pred, beta=1)

Source

And then in your compile:

model.compile(
    optimizer=Adam(),
    loss='binary_crossentropy',
    metrics=['accuracy',
             custom_metrics.fmeasure,
             custom_metrics.recall,
             custom_metrics.precision])

But what I would really like to have is a custom loss function that optimizes the F1 score on the minority class only, for binary classification. Something like:

from sklearn.metrics import precision_recall_fscore_support

def f_score_obj(y_true, y_pred):
    y_true = K.eval(y_true)
    y_pred = K.eval(y_pred)
    precision, recall, f_score, support = precision_recall_fscore_support(y_true, y_pred)
    return K.variable(1.-f_score[1])

However, I know this is a mathematically invalid way of computing loss with regards to gradients and differentiability...
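One direction I've seen suggested (just a sketch of the idea, not something I've validated) is a "soft" F1 loss that skips the rounding and works directly on the predicted probabilities, so everything stays differentiable:

from keras import backend as K

def soft_f1_loss(y_true, y_pred):
    # Differentiable surrogate for 1 - F1 on the positive class:
    # uses raw probabilities instead of rounded predictions.
    tp = K.sum(y_true * y_pred)
    fp = K.sum((1. - y_true) * y_pred)
    fn = K.sum(y_true * (1. - y_pred))
    soft_f1 = 2. * tp / (2. * tp + fp + fn + K.epsilon())
    return 1. - soft_f1

You would then pass it as loss=soft_f1_loss in compile; whether it actually trains better than cross-entropy on an imbalanced problem is something I haven't tested.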

@micklexqg

micklexqg commented Sep 4, 2017

@trevorwelch, that's batch-wise, not the global, final value.
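A tiny made-up example of why that matters (the counts are hypothetical):

import numpy as np

# hypothetical true-positive / false-positive counts from two batches
tp = np.array([1., 90.])
fp = np.array([1., 10.])

per_batch_precision = tp / (tp + fp)      # [0.5, 0.9]
print(per_batch_precision.mean())         # 0.7   -> roughly what averaging per-batch values reports
print(tp.sum() / (tp.sum() + fp.sum()))   # ~0.89 -> the true epoch-level precision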

@stale

stale bot commented Dec 4, 2017

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

@mushahrukhkhan

Hello Everyone,

@trevorwelch, how could I customize these custom metrics to compute Precision@k and Recall@k?

@SamuelMarks
Contributor

@trevorwelch Really interested in the answer to this also 👍

@trevorwelch

@trevorwelch, how could I customize these custom metrics to compute Precision@k and Recall@k

@trevorwelch Really interested in the answer to this also 👍

The code snippets that I shared above (and the code I was hoping to find [optimizing F1 score for the minority class]) were for a binary classification problem.

Are you asking if the code snippets I shared above could be adapted for multilabel classification with ranking?

@SamuelMarks
Contributor

@trevorwelch I'm actually interested in the binary case for now, and in the multilabel classification problem later.

One thing I want to know: you have this comment:

However, I know this is a mathematically invalid way of computing loss with regards to gradients and differentiability...

Did you end up figuring out a mathematically valid approach?

@JanderHungrige

This is still interesting. Does anyone know whether per-label performance metrics for multilabel classification have been solved?

@puranjayr96

Has the problem of "Precision, Recall and F1 score for multiclass classification" been solved?

@romanbsd

You can use the metrics which were removed, if it helps:
https://github.com/keras-team/keras/blob/1c630c3e3c8969b40a47d07b9f2edda50ec69720/keras/metrics.py

@puranjayr96

Thanks, but I used callbacks in model.fit. Here is the code I used:

import numpy as np
import keras
import sklearn.metrics as sklm

class Metrics(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.confusion = []
        self.precision = []
        self.recall = []
        self.f1s = []

    def on_epoch_end(self, epoch, logs={}):
        # predict once on the validation data, then round to hard labels
        predict = np.round(np.asarray(self.model.predict(self.validation_data[0])))
        targ = self.validation_data[1]

        self.precision.append(sklm.precision_score(targ, predict, average='micro'))
        self.recall.append(sklm.recall_score(targ, predict, average='micro'))
        self.f1s.append(sklm.f1_score(targ, predict, average='micro'))
        self.confusion.append(sklm.confusion_matrix(targ.argmax(axis=1), predict.argmax(axis=1)))
        return

metrics = Metrics()
history = model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size,
                    validation_data=(X_val, y_val), shuffle=True, verbose=1,
                    callbacks=[metrics])

The article in which I saw this code:
https://medium.com/@thongonary/how-to-compute-f1-score-for-each-epoch-in-keras-a1acd17715a2

@faerols

faerols commented May 28, 2019

@romanbsd I don't think this is correct, since they use round, which leads to errors for multi-class classification when no predicted value is > 0.5...

@puranjayr96 your code looks correct, but as far as I know you cannot save the best weights when the metric lives in a callback; metrics need to be passed when you compile the model.

I think this question still needs an answer, which I can't provide because of my limited skills :(

@isjjhang

I think maybe the following code will work

import tensorflow_addons as tfa 
f1 = tfa.metrics.F1Score(num_classes, 'macro')
...
model.compile(..., metrics=[f1])

ref: https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score

@ying-hao

I think maybe the following code will work

import tensorflow_addons as tfa 
f1 = tfa.metrics.F1Score(num_classes, 'macro')
...
model.compile(..., metrics=[f1])

ref: https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score

This does not work on my side. It returns TypeError: array() takes 1 positional argument but 2 were given

Has this problem been solved yet?
