Sam 13 softsign #380

Merged
@@ -22,7 +22,7 @@ Thresholded ReLU maintains the computational efficiency of standard ReLU while a
 ## Plots
 <img src="../../images/activation-functions/thresholded-relu.png" alt="Thresholded ReLU Function" width="500" height="auto">
 
-<img src="../../images/activation-functions/thresholded-derivative.png" alt="Thresholded ReLU Derivative" width="500" height="auto">
+<img src="../../images/activation-functions/thresholded-relu-derivative.png" alt="Thresholded ReLU Derivative" width="500" height="auto">
 
 ## Example
 ```php
61 changes: 35 additions & 26 deletions src/NeuralNet/ActivationFunctions/Softsign/Softsign.php
@@ -1,8 +1,13 @@
 <?php
 
-namespace Rubix\ML\NeuralNet\ActivationFunctions;
+declare(strict_types=1);
 
-use Tensor\Matrix;
+namespace Rubix\ML\NeuralNet\ActivationFunctions\Softsign;
 
+use NumPower;
+use NDArray;
+use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\ActivationFunction;
+use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\IBufferDerivative;
+
 /**
  * Softsign
@@ -17,52 +22,56 @@
  * @category Machine Learning
  * @package Rubix/ML
  * @author Andrew DalPino
+ * @author Samuel Akopyan <[email protected]>
  */
-class Softsign implements ActivationFunction
+class Softsign implements ActivationFunction, IBufferDerivative
 {
     /**
      * Compute the activation.
      *
-     * @internal
+     * f(x) = x / (1 + |x|)
      *
-     * @param Matrix $input
-     * @return Matrix
+     * @param NDArray $input
+     * @return NDArray
      */
-    public function activate(Matrix $input) : Matrix
+    public function activate(NDArray $input) : NDArray
     {
-        return $input / (1 + NumPower::abs($input));
+        // Calculate |x|
+        $absInput = NumPower::abs($input);
+
+        // Calculate 1 + |x|
+        $denominator = NumPower::add(1.0, $absInput);
+
+        // Calculate x / (1 + |x|)
+        return NumPower::divide($input, $denominator);
     }
 
     /**
      * Calculate the derivative of the activation.
      *
-     * @internal
+     * f'(x) = 1 / (1 + |x|)²
      *
-     * @param Matrix $input
-     * @param Matrix $output
-     * @return Matrix
+     * @param NDArray $input
+     * @return NDArray
      */
-    public function differentiate(Matrix $input, Matrix $output) : Matrix
+    public function differentiate(NDArray $input) : NDArray
     {
-        return $input->map([$this, '_differentiate']);
-    }
+        // Calculate |x|
+        $absInput = NumPower::abs($input);
 
-    /**
-     * @internal
-     *
-     * @param float $input
-     * @return float
-     */
-    public function _differentiate(float $input) : float
-    {
-        return 1 / (1 + NumPower::abs($input)) ** 2;
+        // Calculate 1 + |x|
+        $onePlusAbs = NumPower::add(1.0, $absInput);
+
+        // Calculate (1 + |x|)²
+        $denominator = NumPower::multiply($onePlusAbs, $onePlusAbs);
+
+        // Calculate 1 / (1 + |x|)²
+        return NumPower::divide(1.0, $denominator);
     }
 
     /**
      * Return the string representation of the object.
      *
-     * @internal
-     *
      * @return string
      */
     public function __toString() : string
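This change replaces Softsign's per-element `_differentiate()` callback with a fully vectorized derivative and narrows `differentiate()` to a single argument: applying the quotient rule to f(x) = x / (1 + |x|) gives f'(x) = 1 / (1 + |x|)², so only the raw input is needed, not the cached output. The sketch below is a usage example, not code from this PR; it assumes the NumPower extension's `NDArray::array()` constructor for building the input buffer, and the values in the comments are hand-computed from the two formulas.

```php
<?php

declare(strict_types=1);

use NDArray;
use Rubix\ML\NeuralNet\ActivationFunctions\Softsign\Softsign;

// Hypothetical 1-D input buffer; NDArray::array() is assumed to be
// the NumPower constructor for an NDArray of floats.
$input = NDArray::array([-2.0, -0.5, 0.0, 0.5, 2.0]);

$softsign = new Softsign();

// f(x) = x / (1 + |x|)  =>  ≈ [-0.6667, -0.3333, 0.0, 0.3333, 0.6667]
$activations = $softsign->activate($input);

// f'(x) = 1 / (1 + |x|)²  =>  ≈ [0.1111, 0.4444, 1.0, 0.4444, 0.1111]
$gradient = $softsign->differentiate($input);
```

Note that `differentiate()` now takes the raw input buffer rather than the cached output, which is what the single-argument `IBufferDerivative` signature reflects compared to the old two-argument `differentiate(Matrix $input, Matrix $output)`.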
19 changes: 19 additions & 0 deletions src/NeuralNet/ActivationFunctions/ThresholdedReLU/Exceptions/InvalidThresholdException.php
@@ -0,0 +1,19 @@
+<?php
+
+declare(strict_types=1);
+
+namespace Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\Exceptions;
+
+use InvalidArgumentException;
+
+/**
+ * Invalid Threshold Exception
+ *
+ * @category Machine Learning
+ * @package Rubix/ML
+ * @author Samuel Akopyan <[email protected]>
+ */
+class InvalidThresholdException extends InvalidArgumentException
+{
+    //
+}
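A quick sketch of where the new exception surfaces (illustrative only, not code from the PR): the `ThresholdedReLU` constructor shown in the next file throws `InvalidThresholdException` for a negative threshold, and since the class extends the SPL `InvalidArgumentException`, callers catching that parent type continue to work.

```php
<?php

declare(strict_types=1);

use Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\ThresholdedReLU;
use Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\Exceptions\InvalidThresholdException;

try {
    // Negative thresholds are rejected by the constructor.
    $activation = new ThresholdedReLU(-0.5);
} catch (InvalidThresholdException $e) {
    echo $e->getMessage(); // "Threshold must be positive, -0.5 given."
}
```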
src/NeuralNet/ActivationFunctions/ThresholdedReLU/ThresholdedReLU.php
@@ -1,9 +1,14 @@
 <?php
 
+declare(strict_types=1);
+
 namespace Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU;
 
-use Tensor\Matrix;
-use Rubix\ML\Exceptions\InvalidArgumentException;
+use NumPower;
+use NDArray;
+use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\ActivationFunction;
+use Rubix\ML\NeuralNet\ActivationFunctions\Base\Contracts\IBufferDerivative;
+use Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\Exceptions\InvalidThresholdException;
 
 /**
  * Thresholded ReLU
@@ -18,8 +23,9 @@
  * @category Machine Learning
  * @package Rubix/ML
  * @author Andrew DalPino
+ * @author Samuel Akopyan <[email protected]>
  */
-class ThresholdedReLU implements ActivationFunction
+class ThresholdedReLU implements ActivationFunction, IBufferDerivative
 {
     /**
      * The input value necessary to trigger an activation.
@@ -29,14 +35,17 @@ class ThresholdedReLU implements ActivationFunction
     protected float $threshold;
 
     /**
-     * @param float $threshold
-     * @throws InvalidArgumentException
+     * Class constructor.
+     *
+     * @param float $threshold The input value necessary to trigger an activation.
+     * @throws InvalidThresholdException
      */
     public function __construct(float $threshold = 1.0)
     {
         if ($threshold < 0.0) {
-            throw new InvalidArgumentException('Threshold must be'
-                . " positive, $threshold given.");
+            throw new InvalidThresholdException(
+                message: "Threshold must be positive, $threshold given."
+            );
         }
 
         $this->threshold = $threshold;
@@ -45,35 +54,37 @@ public function __construct(float $threshold = 1.0)
     /**
      * Compute the activation.
      *
-     * @internal
+     * f(x) = x if x > threshold, 0 otherwise
     *
-     * @param Matrix $input
-     * @return Matrix
+     * @param NDArray $input
+     * @return NDArray
      */
-    public function activate(Matrix $input) : Matrix
+    public function activate(NDArray $input) : NDArray
     {
-        return NumPower::greater($input, $this->threshold) * $input;
+        // Create a mask where input > threshold
+        $mask = NumPower::greater($input, $this->threshold);
+
+        // Apply the mask to the input
+        return NumPower::multiply($input, $mask);
     }
 
     /**
      * Calculate the derivative of the activation.
      *
-     * @internal
+     * f'(x) = 1 if x > threshold, 0 otherwise
      *
-     * @param Matrix $input
-     * @param Matrix $output
-     * @return Matrix
+     * @param NDArray $input
+     * @return NDArray
      */
-    public function differentiate(Matrix $input, Matrix $output) : Matrix
+    public function differentiate(NDArray $input) : NDArray
     {
+        // The derivative is 1 where input > threshold, 0 otherwise
         return NumPower::greater($input, $this->threshold);
     }
 
     /**
      * Return the string representation of the object.
      *
-     * @internal
-     *
      * @return string
      */
     public function __toString() : string
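To make the masking behavior above concrete, here is a minimal usage sketch (again assuming the NumPower `NDArray::array()` constructor, which this diff does not show). With the default threshold of 1.0 the comparison is strict, so a value exactly at the threshold is zeroed along with everything below it, and the derivative is the same 0/1 mask that `activate()` applies:

```php
<?php

declare(strict_types=1);

use NDArray;
use Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\ThresholdedReLU;

// Hypothetical 1-D input buffer; NDArray::array() is assumed here.
$input = NDArray::array([-1.0, 0.5, 1.0, 1.5, 3.0]);

$thresholdedReLU = new ThresholdedReLU(threshold: 1.0);

// f(x) = x if x > 1.0, 0 otherwise  =>  [0.0, 0.0, 0.0, 1.5, 3.0]
$activations = $thresholdedReLU->activate($input);

// f'(x) = 1 if x > 1.0, 0 otherwise  =>  [0.0, 0.0, 0.0, 1.0, 1.0]
$gradient = $thresholdedReLU->differentiate($input);
```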