diff --git a/ComputerAlgebra/LinqCompiler/CompileExpression.cs b/ComputerAlgebra/LinqCompiler/CompileExpression.cs index 520c0a2..1b73508 100644 --- a/ComputerAlgebra/LinqCompiler/CompileExpression.cs +++ b/ComputerAlgebra/LinqCompiler/CompileExpression.cs @@ -110,6 +110,35 @@ protected override LinqExpr VisitPower(Power P) protected override LinqExpr VisitCall(Call C) { + if (C.Target.Name == "Ln") + { + var arg = C.Arguments.Single(); + if (arg is Sum sum) + { + if (sum.Terms.Count() == 2) + { + var terms = sum.Terms.ToArray(); + var a = terms[0]; + var b = terms[1]; + + var (match, lnArgs) = (a, b) switch + { + (Constant c, Call f) when c.Value == 1 && f.Target.Name == "Exp" => (true, f.Arguments), + (Call f, Constant c) when c.Value == 1 && f.Target.Name == "Exp" => (true, f.Arguments), + _ => (false, null), + }; + + if (match) + { + var compiledArgs = lnArgs.Select(i => Visit(i)).ToArray(); + return Int(C, LinqExpr.Call( + target.Module.GetFunction("Ln1Exp", compiledArgs.Select(i => i.Type).ToArray()), + compiledArgs)); + } + } + } + + } LinqExpr[] args = C.Arguments.Select(i => Visit(i)).ToArray(); return Int(C, LinqExpr.Call( target.Module.Compile(C.Target, args.Select(i => i.Type).ToArray()), diff --git a/ComputerAlgebra/LinqCompiler/StandardMath.cs b/ComputerAlgebra/LinqCompiler/StandardMath.cs index 4395f1a..f56db47 100644 --- a/ComputerAlgebra/LinqCompiler/StandardMath.cs +++ b/ComputerAlgebra/LinqCompiler/StandardMath.cs @@ -7,6 +7,13 @@ namespace ComputerAlgebra.LinqCompiler /// public class StandardMath { + private static double ExpKnee = 50; + static StandardMath() + { + _knee = Math.Exp(ExpKnee); + _b = _knee - _knee * ExpKnee; + } + public static double Ln1Exp(double x) => x > 50d ? x : Math.Log(1d + Math.Exp(x)); public static double Abs(double x) { return x < 0 ? -x : x; } public static double Sign(double x) { return x > 0 ? 1 : (x < 0 ? -1 : 0); } @@ -42,7 +49,11 @@ public class StandardMath public static double ArcCoth(double x) { return ArcTanh(1 / x); } public static double Sqrt(double x) { return Math.Sqrt(x); } - public static double Exp(double x) { return Math.Exp(x); } + + private static double _knee; + private static double _b; + + public static double Exp(double x) => x > ExpKnee ? _knee * x + _b : Math.Exp(x); public static double Ln(double x) { return Math.Log(x); } public static double Log(double x, double b) { return Math.Log(x, b); } public static double Pow(double x, double y) { return Math.Pow(x, y); }