diff --git a/docs/neural-network/activation-functions/thresholded-relu.md b/docs/neural-network/activation-functions/thresholded-relu.md
index 03346e0a4..0fecb02cc 100644
--- a/docs/neural-network/activation-functions/thresholded-relu.md
+++ b/docs/neural-network/activation-functions/thresholded-relu.md
@@ -22,7 +22,7 @@ Thresholded ReLU maintains the computational efficiency of standard ReLU while a
## Plots
-
+
## Example
```php
diff --git a/src/NeuralNet/ActivationFunctions/Softsign/Softsign.php b/src/NeuralNet/ActivationFunctions/Softsign/Softsign.php
index 5102a08f4..77a28c5c2 100644
--- a/src/NeuralNet/ActivationFunctions/Softsign/Softsign.php
+++ b/src/NeuralNet/ActivationFunctions/Softsign/Softsign.php
@@ -1,8 +1,13 @@
*/
-class Softsign implements ActivationFunction
+class Softsign implements ActivationFunction, IBufferDerivative
{
/**
* Compute the activation.
*
- * @internal
+ * f(x) = x / (1 + |x|)
*
- * @param Matrix $input
- * @return Matrix
+ * @param NDArray $input
+ * @return NDArray
*/
- public function activate(Matrix $input) : Matrix
+ public function activate(NDArray $input) : NDArray
{
- return $input / (1 + NumPower::abs($input));
+ // Calculate |x|
+ $absInput = NumPower::abs($input);
+
+ // Calculate 1 + |x|
+ $denominator = NumPower::add(1.0, $absInput);
+
+ // Calculate x / (1 + |x|)
+ return NumPower::divide($input, $denominator);
}
/**
* Calculate the derivative of the activation.
*
- * @internal
+ * f'(x) = 1 / (1 + |x|)²
*
- * @param Matrix $input
- * @param Matrix $output
- * @return Matrix
+ * @param NDArray $input
+ * @return NDArray
*/
- public function differentiate(Matrix $input, Matrix $output) : Matrix
+ public function differentiate(NDArray $input) : NDArray
{
- return $input->map([$this, '_differentiate']);
- }
+ // Calculate |x|
+ $absInput = NumPower::abs($input);
- /**
- * @internal
- *
- * @param float $input
- * @return float
- */
- public function _differentiate(float $input) : float
- {
- return 1 / (1 + NumPower::abs($input)) ** 2;
+ // Calculate 1 + |x|
+ $onePlusAbs = NumPower::add(1.0, $absInput);
+
+ // Calculate (1 + |x|)²
+ $denominator = NumPower::multiply($onePlusAbs, $onePlusAbs);
+
+ // Calculate 1 / (1 + |x|)²
+ return NumPower::divide(1.0, $denominator);
}
/**
* Return the string representation of the object.
*
- * @internal
- *
* @return string
*/
public function __toString() : string
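
Note: Softsign now operates on NumPower NDArrays end to end and computes its derivative from the raw input alone. A minimal usage sketch; the namespace is assumed from the file layout, and input construction follows the tests below:

```php
use Rubix\ML\NeuralNet\ActivationFunctions\Softsign\Softsign;

$activationFn = new Softsign();

$input = NumPower::array([
    [2.0, 1.0, -0.5, 0.0],
]);

// f(x) = x / (1 + |x|)
$activations = $activationFn->activate($input)->toArray();
// [[0.6666667, 0.5, -0.3333333, 0.0]]

// f'(x) = 1 / (1 + |x|)², computed from the input alone, so the
// forward output no longer needs to be passed in
$derivatives = $activationFn->differentiate($input)->toArray();
// [[0.1111111, 0.25, 0.4444444, 1.0]]
```
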
diff --git a/src/NeuralNet/ActivationFunctions/ThresholdedReLU/Exceptions/InvalidThresholdException.php b/src/NeuralNet/ActivationFunctions/ThresholdedReLU/Exceptions/InvalidThresholdException.php
new file mode 100644
index 000000000..a375419c4
--- /dev/null
+++ b/src/NeuralNet/ActivationFunctions/ThresholdedReLU/Exceptions/InvalidThresholdException.php
@@ -0,0 +1,19 @@
+
+ */
+class InvalidThresholdException extends InvalidArgumentException
+{
+ //
+}
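
Since InvalidThresholdException extends InvalidArgumentException, call sites that catch the broader type continue to work unchanged. A hypothetical catch sketch, with namespaces assumed from the file paths:

```php
use Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\ThresholdedReLU;
use Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\Exceptions\InvalidThresholdException;

try {
    $activationFn = new ThresholdedReLU(-1.0);
} catch (InvalidThresholdException $e) {
    // "Threshold must be non-negative, -1 given."
}
```
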
diff --git a/src/NeuralNet/ActivationFunctions/ThresholdedReLU/ThresholdedReLU.php b/src/NeuralNet/ActivationFunctions/ThresholdedReLU/ThresholdedReLU.php
index 890b8c123..483744025 100644
--- a/src/NeuralNet/ActivationFunctions/ThresholdedReLU/ThresholdedReLU.php
+++ b/src/NeuralNet/ActivationFunctions/ThresholdedReLU/ThresholdedReLU.php
@@ -1,9 +1,14 @@
*/
-class ThresholdedReLU implements ActivationFunction
+class ThresholdedReLU implements ActivationFunction, IBufferDerivative
{
/**
* The input value necessary to trigger an activation.
@@ -29,14 +35,17 @@ class ThresholdedReLU implements ActivationFunction
protected float $threshold;
/**
- * @param float $threshold
- * @throws InvalidArgumentException
+ * Class constructor.
+ *
+ * @param float $threshold The input value necessary to trigger an activation.
+ * @throws InvalidThresholdException
*/
public function __construct(float $threshold = 1.0)
{
if ($threshold < 0.0) {
- throw new InvalidArgumentException('Threshold must be'
- . " positive, $threshold given.");
+ throw new InvalidThresholdException(
+                message: "Threshold must be non-negative, $threshold given."
+ );
}
$this->threshold = $threshold;
@@ -45,35 +54,37 @@ public function __construct(float $threshold = 1.0)
/**
* Compute the activation.
*
- * @internal
+ * f(x) = x if x > threshold, 0 otherwise
*
- * @param Matrix $input
- * @return Matrix
+ * @param NDArray $input
+ * @return NDArray
*/
- public function activate(Matrix $input) : Matrix
+ public function activate(NDArray $input) : NDArray
{
- return NumPower::greater($input, $this->threshold) * $input;
+ // Create a mask where input > threshold
+ $mask = NumPower::greater($input, $this->threshold);
+
+ // Apply the mask to the input
+ return NumPower::multiply($input, $mask);
}
/**
* Calculate the derivative of the activation.
*
- * @internal
+ * f'(x) = 1 if x > threshold, 0 otherwise
*
- * @param Matrix $input
- * @param Matrix $output
- * @return Matrix
+ * @param NDArray $input
+ * @return NDArray
*/
- public function differentiate(Matrix $input, Matrix $output) : Matrix
+ public function differentiate(NDArray $input) : NDArray
{
+ // The derivative is 1 where input > threshold, 0 otherwise
return NumPower::greater($input, $this->threshold);
}
/**
* Return the string representation of the object.
*
- * @internal
- *
* @return string
*/
public function __toString() : string
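
As with Softsign, a minimal usage sketch under the same assumptions (namespace inferred from the file layout, inputs built as in the tests). Note the strict inequality: values equal to the threshold are zeroed.

```php
use Rubix\ML\NeuralNet\ActivationFunctions\ThresholdedReLU\ThresholdedReLU;

$activationFn = new ThresholdedReLU(threshold: 1.0);

$input = NumPower::array([
    [2.0, 1.0, 0.5, 0.0, -1.0, 1.5, -0.5],
]);

// f(x) = x if x > threshold, 0 otherwise (strict: 1.0 itself maps to 0.0)
$activations = $activationFn->activate($input)->toArray();
// [[2.0, 0.0, 0.0, 0.0, 0.0, 1.5, 0.0]]

// f'(x) = 1 if x > threshold, 0 otherwise
$derivatives = $activationFn->differentiate($input)->toArray();
// [[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]]
```
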
diff --git a/tests/NeuralNet/ActivationFunctions/Softsign/SoftsignTest.php b/tests/NeuralNet/ActivationFunctions/Softsign/SoftsignTest.php
new file mode 100644
index 000000000..ab1d5d9d3
--- /dev/null
+++ b/tests/NeuralNet/ActivationFunctions/Softsign/SoftsignTest.php
@@ -0,0 +1,190 @@
+
+ */
+ public static function computeProvider() : Generator
+ {
+ yield [
+ NumPower::array([
+ [2.0, 1.0, -0.5, 0.0, 20.0, -10.0],
+ ]),
+ [
+ [0.6666667, 0.5000000, -0.3333333, 0.0000000, 0.9523810, -0.9090909],
+ ],
+ ];
+
+ yield [
+ NumPower::array([
+ [-0.12, 0.31, -0.49],
+ [0.99, 0.08, -0.03],
+ [0.05, -0.52, 0.54],
+ ]),
+ [
+ [-0.1071429, 0.2366412, -0.3288591],
+ [0.4974874, 0.0740741, -0.0291262],
+ [0.0476190, -0.3421053, 0.3506494],
+ ],
+ ];
+ }
+
+ /**
+ * @return Generator
+ */
+ public static function differentiateProvider() : Generator
+ {
+ yield [
+ NumPower::array([
+ [2.0, 1.0, -0.5, 0.0, 20.0, -10.0],
+ ]),
+ [
+ [0.1111111, 0.2500000, 0.4444444, 1.0000000, 0.0022676, 0.0082645],
+ ],
+ ];
+
+ yield [
+ NumPower::array([
+ [-0.12, 0.31, -0.49],
+ [0.99, 0.08, -0.03],
+ [0.05, -0.52, 0.54],
+ ]),
+ [
+ [0.7971938, 0.5827166, 0.4504301],
+ [0.2525188, 0.8573387, 0.9425959],
+ [0.9070296, 0.4328254, 0.4216562],
+ ],
+ ];
+ }
+
+ /**
+ * @return Generator
+ */
+ public static function zeroRegionProvider() : Generator
+ {
+ // Test exactly at zero
+ yield [
+ NumPower::array([[0.0]]),
+ [[0.0]],
+ [[1.0]],
+ ];
+
+ // Test very small values
+ yield [
+ NumPower::array([[0.0000001, -0.0000001]]),
+ [[0.000000099999999, -0.000000099999999]],
+ [[0.9999998, 0.9999998]],
+ ];
+
+ // Test values around machine epsilon
+ yield [
+ NumPower::array([[PHP_FLOAT_EPSILON, -PHP_FLOAT_EPSILON]]),
+ [[PHP_FLOAT_EPSILON / (1 + PHP_FLOAT_EPSILON), -PHP_FLOAT_EPSILON / (1 + PHP_FLOAT_EPSILON)]],
+ [[1 / (1 + PHP_FLOAT_EPSILON) ** 2, 1 / (1 + PHP_FLOAT_EPSILON) ** 2]],
+ ];
+ }
+
+ /**
+ * @return Generator
+ */
+ public static function extremeValuesProvider() : Generator
+ {
+ // Test with large positive values
+ yield [
+ NumPower::array([[10.0, 100.0, 1000.0]]),
+ [[0.9090909, 0.9900990, 0.9990010]],
+            [[0.0082645, 0.0000980, 0.0000009]],
+ ];
+
+ // Test with large negative values
+ yield [
+ NumPower::array([[-10.0, -100.0, -1000.0]]),
+ [[-0.9090909, -0.9900990, -0.9990010]],
+            [[0.0082645, 0.0000980, 0.0000009]],
+ ];
+ }
+
+ /**
+ * Set up the test case.
+ */
+ protected function setUp() : void
+ {
+ parent::setUp();
+
+ $this->activationFn = new Softsign();
+ }
+
+ #[Test]
+ #[TestDox('Can be cast to a string')]
+ public function testToString() : void
+ {
+ static::assertEquals('Softsign', (string) $this->activationFn);
+ }
+
+ #[Test]
+ #[TestDox('Correctly activates the input')]
+ #[DataProvider('computeProvider')]
+ public function testActivate(NDArray $input, array $expected) : void
+ {
+ $activations = $this->activationFn->activate($input)->toArray();
+
+ static::assertEqualsWithDelta($expected, $activations, 1e-7);
+ }
+
+ #[Test]
+ #[TestDox('Correctly differentiates the input')]
+ #[DataProvider('differentiateProvider')]
+ public function testDifferentiate(NDArray $input, array $expected) : void
+ {
+ $derivatives = $this->activationFn->differentiate($input)->toArray();
+
+ static::assertEqualsWithDelta($expected, $derivatives, 1e-7);
+ }
+
+ #[Test]
+ #[TestDox('Correctly handles values around zero')]
+ #[DataProvider('zeroRegionProvider')]
+ public function testZeroRegion(NDArray $input, array $expectedActivation, array $expectedDerivative) : void
+ {
+ $activations = $this->activationFn->activate($input)->toArray();
+ $derivatives = $this->activationFn->differentiate($input)->toArray();
+
+ static::assertEqualsWithDelta($expectedActivation, $activations, 1e-7);
+ static::assertEqualsWithDelta($expectedDerivative, $derivatives, 1e-7);
+ }
+
+ #[Test]
+ #[TestDox('Correctly handles extreme values')]
+ #[DataProvider('extremeValuesProvider')]
+ public function testExtremeValues(NDArray $input, array $expectedActivation, array $expectedDerivative) : void
+ {
+ $activations = $this->activationFn->activate($input)->toArray();
+ $derivatives = $this->activationFn->differentiate($input)->toArray();
+
+ static::assertEqualsWithDelta($expectedActivation, $activations, 1e-7);
+ static::assertEqualsWithDelta($expectedDerivative, $derivatives, 1e-7);
+ }
+}
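
The IBufferDerivative interface both classes now implement is not included in this diff. The sketch below is a hypothetical reconstruction inferred from the two call sites above (a one-argument differentiate() that works from the buffered input); the actual definition may differ:

```php
interface IBufferDerivative
{
    /**
     * Calculate the derivative of the activation from the buffered input alone.
     *
     * @param NDArray $input
     * @return NDArray
     */
    public function differentiate(NDArray $input) : NDArray;
}
```
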
diff --git a/tests/NeuralNet/ActivationFunctions/ThresholdedReLU/ThresholdedReLUTest.php b/tests/NeuralNet/ActivationFunctions/ThresholdedReLU/ThresholdedReLUTest.php
new file mode 100644
index 000000000..e6d89ee95
--- /dev/null
+++ b/tests/NeuralNet/ActivationFunctions/ThresholdedReLU/ThresholdedReLUTest.php
@@ -0,0 +1,240 @@
+
+ */
+ public static function computeProvider() : Generator
+ {
+ yield [
+ NumPower::array([
+ [2.0, 1.0, 0.5, 0.0, -1.0, 1.5, -0.5],
+ ]),
+ [
+ [2.0, 0.0, 0.0, 0.0, 0.0, 1.5, 0.0],
+ ],
+ ];
+
+ yield [
+ NumPower::array([
+ [1.2, 0.31, 1.49],
+ [0.99, 1.08, 0.03],
+ [1.05, 0.52, 1.54],
+ ]),
+ [
+ [1.2, 0.0, 1.49],
+ [0.0, 1.08, 0.0],
+ [1.05, 0.0, 1.54],
+ ],
+ ];
+ }
+
+ /**
+ * @return Generator
+ */
+ public static function differentiateProvider() : Generator
+ {
+ yield [
+ NumPower::array([
+ [2.0, 1.0, 0.5, 0.0, -1.0, 1.5, -0.5],
+ ]),
+ [
+ [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
+ ],
+ ];
+
+ yield [
+ NumPower::array([
+ [1.2, 0.31, 1.49],
+ [0.99, 1.08, 0.03],
+ [1.05, 0.52, 1.54],
+ ]),
+ [
+ [1.0, 0.0, 1.0],
+ [0.0, 1.0, 0.0],
+ [1.0, 0.0, 1.0],
+ ],
+ ];
+ }
+
+ /**
+ * @return Generator
+ */
+ public static function thresholdValuesProvider() : Generator
+ {
+ yield [
+ 0.5,
+ NumPower::array([
+ [2.0, 1.0, 0.5, 0.0, -1.0],
+ ]),
+ [
+ [2.0, 1.0, 0.0, 0.0, 0.0],
+ ],
+ [
+ [1.0, 1.0, 0.0, 0.0, 0.0],
+ ],
+ ];
+
+ yield [
+ 2.0,
+ NumPower::array([
+ [2.0, 1.0, 3.0, 0.0, 2.5],
+ ]),
+ [
+ [0.0, 0.0, 3.0, 0.0, 2.5],
+ ],
+ [
+ [0.0, 0.0, 1.0, 0.0, 1.0],
+ ],
+ ];
+ }
+
+ /**
+ * @return Generator
+ */
+ public static function zeroRegionProvider() : Generator
+ {
+ yield [
+ NumPower::array([[0.0]]),
+ [[0.0]],
+ [[0.0]],
+ ];
+
+ yield [
+ NumPower::array([[0.5, 0.9, 0.99, 1.0, 1.01]]),
+ [[0.0, 0.0, 0.0, 0.0, 1.01]],
+ [[0.0, 0.0, 0.0, 0.0, 1.0]],
+ ];
+ }
+
+ /**
+ * @return Generator
+ */
+ public static function extremeValuesProvider() : Generator
+ {
+ yield [
+ NumPower::array([[10.0, 100.0, 1000.0]]),
+ [[10.0, 100.0, 1000.0]],
+ [[1.0, 1.0, 1.0]],
+ ];
+
+ yield [
+ NumPower::array([[-10.0, -100.0, -1000.0]]),
+ [[0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 0.0]],
+ ];
+ }
+
+ /**
+ * Set up the test case.
+ */
+ protected function setUp() : void
+ {
+ parent::setUp();
+
+ $this->activationFn = new ThresholdedReLU($this->threshold);
+ }
+
+ #[Test]
+ #[TestDox('Can be cast to a string')]
+ public function testToString() : void
+ {
+ static::assertEquals('Thresholded ReLU (threshold: 1)', (string) $this->activationFn);
+ }
+
+ #[Test]
+ #[TestDox('It throws an exception when threshold is negative')]
+ public function testInvalidThresholdException() : void
+ {
+ $this->expectException(InvalidThresholdException::class);
+
+ new ThresholdedReLU(-1.0);
+ }
+
+ #[Test]
+ #[TestDox('Correctly activates the input')]
+ #[DataProvider('computeProvider')]
+ public function testActivate(NDArray $input, array $expected) : void
+ {
+ $activations = $this->activationFn->activate($input)->toArray();
+
+ static::assertEqualsWithDelta($expected, $activations, 1e-7);
+ }
+
+ #[Test]
+ #[TestDox('Correctly differentiates the input')]
+ #[DataProvider('differentiateProvider')]
+ public function testDifferentiate(NDArray $input, array $expected) : void
+ {
+ $derivatives = $this->activationFn->differentiate($input)->toArray();
+
+ static::assertEqualsWithDelta($expected, $derivatives, 1e-7);
+ }
+
+ #[Test]
+ #[TestDox('Correctly handles different threshold values')]
+ #[DataProvider('thresholdValuesProvider')]
+ public function testThresholdValues(float $threshold, NDArray $input, array $expectedActivation, array $expectedDerivative) : void
+ {
+ $activationFn = new ThresholdedReLU($threshold);
+
+ $activations = $activationFn->activate($input)->toArray();
+ $derivatives = $activationFn->differentiate($input)->toArray();
+
+ static::assertEqualsWithDelta($expectedActivation, $activations, 1e-7);
+ static::assertEqualsWithDelta($expectedDerivative, $derivatives, 1e-7);
+ }
+
+ #[Test]
+ #[TestDox('Correctly handles values around zero')]
+ #[DataProvider('zeroRegionProvider')]
+ public function testZeroRegion(NDArray $input, array $expectedActivation, array $expectedDerivative) : void
+ {
+ $activations = $this->activationFn->activate($input)->toArray();
+ $derivatives = $this->activationFn->differentiate($input)->toArray();
+
+ static::assertEqualsWithDelta($expectedActivation, $activations, 1e-7);
+ static::assertEqualsWithDelta($expectedDerivative, $derivatives, 1e-7);
+ }
+
+ #[Test]
+ #[TestDox('Correctly handles extreme values')]
+ #[DataProvider('extremeValuesProvider')]
+ public function testExtremeValues(NDArray $input, array $expectedActivation, array $expectedDerivative) : void
+ {
+ $activations = $this->activationFn->activate($input)->toArray();
+ $derivatives = $this->activationFn->differentiate($input)->toArray();
+
+ static::assertEqualsWithDelta($expectedActivation, $activations, 1e-7);
+ static::assertEqualsWithDelta($expectedDerivative, $derivatives, 1e-7);
+ }
+}