diff --git a/src/integration_tests/algorithmic_style_test.py b/src/integration_tests/algorithmic_style_test.py new file mode 100644 index 0000000..047d6f9 --- /dev/null +++ b/src/integration_tests/algorithmic_style_test.py @@ -0,0 +1,77 @@ +"""End-to-end test cases of algorithmic style.""" + +from __future__ import annotations + +from typing import Any, Callable + +from latexify import frontend + + +def check_algorithm( + fn: Callable[..., Any], + latex: str, + **kwargs, +) -> None: + """Helper to check if the obtained function has the expected LaTeX form. + + Args: + fn: Function to check. + latex: LaTeX form of `fn`. + **kwargs: Arguments passed to `frontend.get_latex`. + """ + # Checks the syntax: + # def fn(...): + # ... + # latexified = get_latex(fn, style=ALGORITHM, **kwargs) + latexified = frontend.get_latex(fn, style=frontend.Style.ALGORITHMIC, **kwargs) + assert latexified == latex + + +def test_factorial() -> None: + def fact(n): + if n == 0: + return 1 + else: + return n * fact(n - 1) + + latex = ( + r"\begin{algorithmic}" + r" \Function{fact}{$n$}" + r" \If{$n = 0$}" + r" \State \Return $1$" + r" \Else" + r" \State \Return $n \mathrm{fact} \mathopen{}\left( n - 1 \mathclose{}\right)$" + r" \EndIf" + r" \EndFunction" + r" \end{algorithmic}" + ) + check_algorithm(fact, latex) + + +def test_collatz() -> None: + def collatz(n): + iterations = 0 + while n > 1: + if n % 2 == 0: + n = n // 2 + else: + n = 3 * n + 1 + iterations = iterations + 1 + return iterations + + latex = ( + r"\begin{algorithmic}" + r" \Function{collatz}{$n$}" + r" \State $\mathrm{iterations} \gets 0$" + r" \While{$n > 1$}" + r" \If{$n \mathbin{\%} 2 = 0$}" + r" \State $n \gets \left\lfloor\frac{n}{2}\right\rfloor$" + r" \Else \State $n \gets 3 n + 1$" + r" \EndIf" + r" \State $\mathrm{iterations} \gets \mathrm{iterations} + 1$" + r" \EndWhile" + r" \State \Return $\mathrm{iterations}$" + r" \EndFunction" + r" \end{algorithmic}" + ) + check_algorithm(collatz, latex) diff --git a/src/integration_tests/function_expansion_test.py b/src/integration_tests/function_expansion_test.py index 5cc7ab0..d3c373d 100644 --- a/src/integration_tests/function_expansion_test.py +++ b/src/integration_tests/function_expansion_test.py @@ -1,6 +1,10 @@ +"""End-to-end test cases of function expansion.""" + +from __future__ import annotations + import math -from integration_tests import utils +from integration_tests import integration_utils def test_atan2() -> None: @@ -11,7 +15,7 @@ def solve(x, y): r"\mathrm{solve}(x, y) =" r" \arctan \mathopen{}\left( \frac{y}{x} \mathclose{}\right)" ) - utils.check_function(solve, latex, expand_functions={"atan2"}) + integration_utils.check_function(solve, latex, expand_functions={"atan2"}) def test_atan2_nested() -> None: @@ -22,7 +26,7 @@ def solve(x, y): r"\mathrm{solve}(x, y) =" r" \arctan \mathopen{}\left( \frac{e^{y}}{e^{x}} \mathclose{}\right)" ) - utils.check_function(solve, latex, expand_functions={"atan2", "exp"}) + integration_utils.check_function(solve, latex, expand_functions={"atan2", "exp"}) def test_exp() -> None: @@ -30,7 +34,7 @@ def solve(x): return math.exp(x) latex = r"\mathrm{solve}(x) = e^{x}" - utils.check_function(solve, latex, expand_functions={"exp"}) + integration_utils.check_function(solve, latex, expand_functions={"exp"}) def test_exp_nested() -> None: @@ -38,7 +42,7 @@ def solve(x): return math.exp(math.exp(x)) latex = r"\mathrm{solve}(x) = e^{e^{x}}" - utils.check_function(solve, latex, expand_functions={"exp"}) + integration_utils.check_function(solve, latex, expand_functions={"exp"}) def test_exp2() -> None: @@ -46,7 +50,7 @@ def solve(x): return math.exp2(x) latex = r"\mathrm{solve}(x) = 2^{x}" - utils.check_function(solve, latex, expand_functions={"exp2"}) + integration_utils.check_function(solve, latex, expand_functions={"exp2"}) def test_exp2_nested() -> None: @@ -54,7 +58,7 @@ def solve(x): return math.exp2(math.exp2(x)) latex = r"\mathrm{solve}(x) = 2^{2^{x}}" - utils.check_function(solve, latex, expand_functions={"exp2"}) + integration_utils.check_function(solve, latex, expand_functions={"exp2"}) def test_expm1() -> None: @@ -62,7 +66,7 @@ def solve(x): return math.expm1(x) latex = r"\mathrm{solve}(x) = \exp x - 1" - utils.check_function(solve, latex, expand_functions={"expm1"}) + integration_utils.check_function(solve, latex, expand_functions={"expm1"}) def test_expm1_nested() -> None: @@ -70,7 +74,9 @@ def solve(x, y, z): return math.expm1(math.pow(y, z)) latex = r"\mathrm{solve}(x, y, z) = e^{y^{z}} - 1" - utils.check_function(solve, latex, expand_functions={"expm1", "exp", "pow"}) + integration_utils.check_function( + solve, latex, expand_functions={"expm1", "exp", "pow"} + ) def test_hypot_without_attribute() -> None: @@ -80,7 +86,7 @@ def solve(x, y, z): return hypot(x, y, z) latex = r"\mathrm{solve}(x, y, z) = \sqrt{ x^{2} + y^{2} + z^{2} }" - utils.check_function(solve, latex, expand_functions={"hypot"}) + integration_utils.check_function(solve, latex, expand_functions={"hypot"}) def test_hypot() -> None: @@ -88,7 +94,7 @@ def solve(x, y, z): return math.hypot(x, y, z) latex = r"\mathrm{solve}(x, y, z) = \sqrt{ x^{2} + y^{2} + z^{2} }" - utils.check_function(solve, latex, expand_functions={"hypot"}) + integration_utils.check_function(solve, latex, expand_functions={"hypot"}) def test_hypot_nested() -> None: @@ -99,7 +105,7 @@ def solve(a, b, x, y): r"\mathrm{solve}(a, b, x, y) =" r" \sqrt{ \sqrt{ a^{2} + b^{2} }^{2} + x^{2} + y^{2} }" ) - utils.check_function(solve, latex, expand_functions={"hypot"}) + integration_utils.check_function(solve, latex, expand_functions={"hypot"}) def test_log1p() -> None: @@ -107,7 +113,7 @@ def solve(x): return math.log1p(x) latex = r"\mathrm{solve}(x) = \log \mathopen{}\left( 1 + x \mathclose{}\right)" - utils.check_function(solve, latex, expand_functions={"log1p"}) + integration_utils.check_function(solve, latex, expand_functions={"log1p"}) def test_log1p_nested() -> None: @@ -115,7 +121,7 @@ def solve(x): return math.log1p(math.exp(x)) latex = r"\mathrm{solve}(x) = \log \mathopen{}\left( 1 + e^{x} \mathclose{}\right)" - utils.check_function(solve, latex, expand_functions={"log1p", "exp"}) + integration_utils.check_function(solve, latex, expand_functions={"log1p", "exp"}) def test_pow_nested() -> None: @@ -126,7 +132,7 @@ def solve(w, x, y, z): r"\mathrm{solve}(w, x, y, z) = " r"\mathopen{}\left( w^{x} \mathclose{}\right)^{y^{z}}" ) - utils.check_function(solve, latex, expand_functions={"pow"}) + integration_utils.check_function(solve, latex, expand_functions={"pow"}) def test_pow() -> None: @@ -134,4 +140,4 @@ def solve(x, y): return math.pow(x, y) latex = r"\mathrm{solve}(x, y) = x^{y}" - utils.check_function(solve, latex, expand_functions={"pow"}) + integration_utils.check_function(solve, latex, expand_functions={"pow"}) diff --git a/src/integration_tests/utils.py b/src/integration_tests/integration_utils.py similarity index 93% rename from src/integration_tests/utils.py rename to src/integration_tests/integration_utils.py index ecb4d5d..92ada9a 100644 --- a/src/integration_tests/utils.py +++ b/src/integration_tests/integration_utils.py @@ -1,3 +1,7 @@ +"""Utilities for integration tests.""" + +from __future__ import annotations + from typing import Any, Callable from latexify import frontend diff --git a/src/integration_tests/regression_test.py b/src/integration_tests/regression_test.py index e698cdf..f421413 100644 --- a/src/integration_tests/regression_test.py +++ b/src/integration_tests/regression_test.py @@ -4,7 +4,7 @@ import math -from integration_tests import utils +from integration_tests import integration_utils def test_quadratic_solution() -> None: @@ -12,7 +12,7 @@ def solve(a, b, c): return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a) latex = r"\mathrm{solve}(a, b, c) = \frac{-b + \sqrt{ b^{2} - 4 a c }}{2 a}" - utils.check_function(solve, latex) + integration_utils.check_function(solve, latex) def test_sinc() -> None: @@ -29,7 +29,7 @@ def sinc(x): r" \frac{\sin x}{x}, & \mathrm{otherwise}" r" \end{array} \right." ) - utils.check_function(sinc, latex) + integration_utils.check_function(sinc, latex) def test_x_times_beta() -> None: @@ -37,11 +37,15 @@ def xtimesbeta(x, beta): return x * beta latex_without_symbols = r"\mathrm{xtimesbeta}(x, \mathrm{beta}) = x \mathrm{beta}" - utils.check_function(xtimesbeta, latex_without_symbols) - utils.check_function(xtimesbeta, latex_without_symbols, use_math_symbols=False) + integration_utils.check_function(xtimesbeta, latex_without_symbols) + integration_utils.check_function( + xtimesbeta, latex_without_symbols, use_math_symbols=False + ) latex_with_symbols = r"\mathrm{xtimesbeta}(x, \beta) = x \beta" - utils.check_function(xtimesbeta, latex_with_symbols, use_math_symbols=True) + integration_utils.check_function( + xtimesbeta, latex_with_symbols, use_math_symbols=True + ) def test_sum_with_limit_1arg() -> None: @@ -52,7 +56,7 @@ def sum_with_limit(n): r"\mathrm{sum\_with\_limit}(n) = \sum_{i = 0}^{n - 1}" r" \mathopen{}\left({i^{2}}\mathclose{}\right)" ) - utils.check_function(sum_with_limit, latex) + integration_utils.check_function(sum_with_limit, latex) def test_sum_with_limit_2args() -> None: @@ -63,7 +67,7 @@ def sum_with_limit(a, n): r"\mathrm{sum\_with\_limit}(a, n) = \sum_{i = a}^{n - 1} " r"\mathopen{}\left({i^{2}}\mathclose{}\right)" ) - utils.check_function(sum_with_limit, latex) + integration_utils.check_function(sum_with_limit, latex) def test_sum_with_reducible_limit() -> None: @@ -74,7 +78,7 @@ def sum_with_limit(n): r"\mathrm{sum\_with\_limit}(n) = \sum_{i = 0}^{n} " r"\mathopen{}\left({i}\mathclose{}\right)" ) - utils.check_function(sum_with_limit, latex) + integration_utils.check_function(sum_with_limit, latex) def test_sum_with_irreducible_limit() -> None: @@ -85,7 +89,7 @@ def sum_with_limit(n): r"\mathrm{sum\_with\_limit}(n) = \sum_{i = 0}^{n 3 - 1} " r"\mathopen{}\left({i}\mathclose{}\right)" ) - utils.check_function(sum_with_limit, latex) + integration_utils.check_function(sum_with_limit, latex) def test_prod_with_limit_1arg() -> None: @@ -96,7 +100,7 @@ def prod_with_limit(n): r"\mathrm{prod\_with\_limit}(n) = " r"\prod_{i = 0}^{n - 1} \mathopen{}\left({i^{2}}\mathclose{}\right)" ) - utils.check_function(prod_with_limit, latex) + integration_utils.check_function(prod_with_limit, latex) def test_prod_with_limit_2args() -> None: @@ -107,7 +111,7 @@ def prod_with_limit(a, n): r"\mathrm{prod\_with\_limit}(a, n) = " r"\prod_{i = a}^{n - 1} \mathopen{}\left({i^{2}}\mathclose{}\right)" ) - utils.check_function(prod_with_limit, latex) + integration_utils.check_function(prod_with_limit, latex) def test_prod_with_reducible_limits() -> None: @@ -118,7 +122,7 @@ def prod_with_limit(n): r"\mathrm{prod\_with\_limit}(n) = " r"\prod_{i = 0}^{n - 2} \mathopen{}\left({i}\mathclose{}\right)" ) - utils.check_function(prod_with_limit, latex) + integration_utils.check_function(prod_with_limit, latex) def test_prod_with_irreducible_limit() -> None: @@ -129,14 +133,14 @@ def prod_with_limit(n): r"\mathrm{prod\_with\_limit}(n) = " r"\prod_{i = 0}^{n 3 - 1} \mathopen{}\left({i}\mathclose{}\right)" ) - utils.check_function(prod_with_limit, latex) + integration_utils.check_function(prod_with_limit, latex) def test_nested_function() -> None: def nested(x): return 3 * x - utils.check_function(nested, r"\mathrm{nested}(x) = 3 x") + integration_utils.check_function(nested, r"\mathrm{nested}(x) = 3 x") def test_double_nested_function() -> None: @@ -146,7 +150,7 @@ def inner(y): return inner - utils.check_function(nested(3), r"\mathrm{inner}(y) = x y") + integration_utils.check_function(nested(3), r"\mathrm{inner}(y) = x y") def test_reduce_assignments() -> None: @@ -154,11 +158,11 @@ def f(x): a = x + x return 3 * a - utils.check_function( + integration_utils.check_function( f, r"\begin{array}{l} a = x + x \\ f(x) = 3 a \end{array}", ) - utils.check_function( + integration_utils.check_function( f, r"f(x) = 3 \mathopen{}\left( x + x \mathclose{}\right)", reduce_assignments=True, @@ -179,9 +183,9 @@ def f(x): r"\end{array}" ) - utils.check_function(f, latex_without_option) - utils.check_function(f, latex_without_option, reduce_assignments=False) - utils.check_function( + integration_utils.check_function(f, latex_without_option) + integration_utils.check_function(f, latex_without_option, reduce_assignments=False) + integration_utils.check_function( f, r"f(x) = 3 \mathopen{}\left( x^{2} + x^{2} \mathclose{}\right)", reduce_assignments=True, @@ -197,7 +201,7 @@ def sigmoid(x): else: return n - utils.check_function( + integration_utils.check_function( sigmoid, ( r"\mathrm{sigmoid}(x) = \left\{ \begin{array}{ll} " @@ -221,7 +225,7 @@ def solve(a, b): r"a + b \mathclose{}\right) - \mathopen{}\left( " r"a - b \mathclose{}\right) - a b" ) - utils.check_function(solve, latex) + integration_utils.check_function(solve, latex) def test_docstring_allowed() -> None: @@ -230,7 +234,7 @@ def solve(x): return x latex = r"\mathrm{solve}(x) = x" - utils.check_function(solve, latex) + integration_utils.check_function(solve, latex) def test_multiple_constants_allowed() -> None: @@ -241,4 +245,4 @@ def solve(x): return x latex = r"\mathrm{solve}(x) = x" - utils.check_function(solve, latex) + integration_utils.check_function(solve, latex) diff --git a/src/latexify/__init__.py b/src/latexify/__init__.py index 0d567f0..f25f699 100644 --- a/src/latexify/__init__.py +++ b/src/latexify/__init__.py @@ -9,6 +9,8 @@ from latexify import frontend +Style = frontend.Style + get_latex = frontend.get_latex function = frontend.function diff --git a/src/latexify/codegen/__init__.py b/src/latexify/codegen/__init__.py index cddad8d..1aea2c3 100644 --- a/src/latexify/codegen/__init__.py +++ b/src/latexify/codegen/__init__.py @@ -1,6 +1,7 @@ """Package latexify.codegen.""" -from latexify.codegen import expression_codegen, function_codegen +from latexify.codegen import algorithmic_codegen, expression_codegen, function_codegen +AlgorithmicCodegen = algorithmic_codegen.AlgorithmicCodegen ExpressionCodegen = expression_codegen.ExpressionCodegen FunctionCodegen = function_codegen.FunctionCodegen diff --git a/src/latexify/codegen/algorithmic_codegen.py b/src/latexify/codegen/algorithmic_codegen.py new file mode 100644 index 0000000..2fcb16b --- /dev/null +++ b/src/latexify/codegen/algorithmic_codegen.py @@ -0,0 +1,106 @@ +"""Codegen for single algorithms.""" + +from __future__ import annotations + +import ast + +from latexify import exceptions +from latexify.codegen import expression_codegen, identifier_converter + + +class AlgorithmicCodegen(ast.NodeVisitor): + """Codegen for single algorithms. + + This codegen works for Module with single FunctionDef node to generate a single + LaTeX expression of the given algorithm. + """ + + _identifier_converter: identifier_converter.IdentifierConverter + + def __init__( + self, *, use_math_symbols: bool = False, use_set_symbols: bool = False + ) -> None: + """Initializer. + + Args: + use_math_symbols: Whether to convert identifiers with a math symbol surface + (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha"). + use_set_symbols: Whether to use set symbols or not. + """ + self._expression_codegen = expression_codegen.ExpressionCodegen( + use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols + ) + self._identifier_converter = identifier_converter.IdentifierConverter( + use_math_symbols=use_math_symbols + ) + + def generic_visit(self, node: ast.AST) -> str: + raise exceptions.LatexifyNotSupportedError( + f"Unsupported AST: {type(node).__name__}" + ) + + def visit_Assign(self, node: ast.Assign) -> str: + """Visit an Assign node.""" + operands: list[str] = [ + self._expression_codegen.visit(target) for target in node.targets + ] + operands.append(self._expression_codegen.visit(node.value)) + operands_latex = r" \gets ".join(operands) + return rf"\State ${operands_latex}$" + + def visit_Expr(self, node: ast.Expr) -> str: + """Visit an Expr node.""" + return rf"\State ${self._expression_codegen.visit(node.value)}$" + + def visit_FunctionDef(self, node: ast.FunctionDef) -> str: + """Visit a FunctionDef node.""" + # Arguments + arg_strs = [ + self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args + ] + # Body + body_strs: list[str] = [self.visit(stmt) for stmt in node.body] + return ( + rf"\begin{{algorithmic}}" + rf" \Function{{{node.name}}}{{${', '.join(arg_strs)}$}}" + f" {' '.join(body_strs)}" + r" \EndFunction" + rf" \end{{algorithmic}}" + ) + + # TODO(ZibingZhang): support \ELSIF + def visit_If(self, node: ast.If) -> str: + """Visit an If node.""" + cond_latex = self._expression_codegen.visit(node.test) + body_latex = " ".join(self.visit(stmt) for stmt in node.body) + + latex = rf"\If{{${cond_latex}$}} {body_latex}" + + if node.orelse: + latex += r" \Else " + latex += " ".join(self.visit(stmt) for stmt in node.orelse) + + return latex + r" \EndIf" + + def visit_Module(self, node: ast.Module) -> str: + """Visit a Module node.""" + return self.visit(node.body[0]) + + def visit_Return(self, node: ast.Return) -> str: + """Visit a Return node.""" + return ( + rf"\State \Return ${self._expression_codegen.visit(node.value)}$" + if node.value is not None + else r"\State \Return" + ) + + def visit_While(self, node: ast.While) -> str: + """Visit a While node.""" + if node.orelse: + raise exceptions.LatexifyNotSupportedError( + "While statement with the else clause is not supported" + ) + + cond_latex = self._expression_codegen.visit(node.test) + body_latex = " ".join(self.visit(stmt) for stmt in node.body) + return rf"\While{{${cond_latex}$}} {body_latex} \EndWhile" diff --git a/src/latexify/codegen/algorithmic_codegen_test.py b/src/latexify/codegen/algorithmic_codegen_test.py new file mode 100644 index 0000000..a972beb --- /dev/null +++ b/src/latexify/codegen/algorithmic_codegen_test.py @@ -0,0 +1,135 @@ +"""Tests for latexify.codegen.algorithmic_codegen.""" + +from __future__ import annotations + +import ast +import textwrap + +import pytest + +from latexify import exceptions +from latexify.codegen import algorithmic_codegen + + +def test_generic_visit() -> None: + class UnknownNode(ast.AST): + pass + + with pytest.raises( + exceptions.LatexifyNotSupportedError, + match=r"^Unsupported AST: UnknownNode$", + ): + algorithmic_codegen.AlgorithmicCodegen().visit(UnknownNode()) + + +@pytest.mark.parametrize( + "code,latex", + [ + ("x = 3", r"\State $x \gets 3$"), + ("a = b = 0", r"\State $a \gets b \gets 0$"), + ], +) +def test_visit_assign(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.Assign) + assert algorithmic_codegen.AlgorithmicCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ( + "def f(x): return x", + ( + r"\begin{algorithmic}" + r" \Function{f}{$x$}" + r" \State \Return $x$" + r" \EndFunction" + r" \end{algorithmic}" + ), + ), + ( + "def xyz(a, b, c): return 3", + ( + r"\begin{algorithmic}" + r" \Function{xyz}{$a, b, c$}" + r" \State \Return $3$" + r" \EndFunction" + r" \end{algorithmic}" + ), + ), + ], +) +def test_visit_functiondef(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.FunctionDef) + assert algorithmic_codegen.AlgorithmicCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("if x < y: return x", r"\If{$x < y$} \State \Return $x$ \EndIf"), + ( + "if True: x\nelse: y", + r"\If{$\mathrm{True}$} \State $x$ \Else \State $y$ \EndIf", + ), + ], +) +def test_visit_if(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.If) + assert algorithmic_codegen.AlgorithmicCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ( + "return x + y", + r"\State \Return $x + y$", + ), + ( + "return", + r"\State \Return", + ), + ], +) +def test_visit_return(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.Return) + assert algorithmic_codegen.AlgorithmicCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ( + "while x < y: x = x + 1", + r"\While{$x < y$} \State $x \gets x + 1$ \EndWhile", + ) + ], +) +def test_visit_while(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.While) + assert algorithmic_codegen.AlgorithmicCodegen().visit(node) == latex + + +def test_visit_while_with_else() -> None: + node = ast.parse( + textwrap.dedent( + """ + while True: + x = x + else: + x = y + """ + ) + ).body[0] + assert isinstance(node, ast.While) + with pytest.raises( + exceptions.LatexifyNotSupportedError, + match="^While statement with the else clause is not supported$", + ): + algorithmic_codegen.AlgorithmicCodegen().visit(node) diff --git a/src/latexify/codegen/function_codegen_test.py b/src/latexify/codegen/function_codegen_test.py index b8e363a..d8d4f19 100644 --- a/src/latexify/codegen/function_codegen_test.py +++ b/src/latexify/codegen/function_codegen_test.py @@ -26,8 +26,8 @@ def test_visit_functiondef_use_signature() -> None: tree = ast.parse( textwrap.dedent( """ - def f(x): - return x + def f(x): + return x """ ) ).body[0] @@ -50,9 +50,9 @@ def test_visit_functiondef_ignore_docstring() -> None: tree = ast.parse( textwrap.dedent( """ - def f(x): - '''docstring''' - return x + def f(x): + '''docstring''' + return x """ ) ).body[0] @@ -66,11 +66,11 @@ def test_visit_functiondef_ignore_multiple_constants() -> None: tree = ast.parse( textwrap.dedent( """ - def f(x): - '''docstring''' - 3 - True - return x + def f(x): + '''docstring''' + 3 + True + return x """ ) ).body[0] diff --git a/src/latexify/frontend.py b/src/latexify/frontend.py index e0896bd..83a5d7e 100644 --- a/src/latexify/frontend.py +++ b/src/latexify/frontend.py @@ -2,6 +2,7 @@ from __future__ import annotations +import enum from collections.abc import Callable from typing import Any, overload @@ -16,10 +17,16 @@ _COMMON_PREFIXES = {"math", "numpy", "np"} -# TODO(odashi): move expand_functions to Config. +class Style(enum.Enum): + EXPRESSION = "expression" + FUNCTION = "function" + ALGORITHMIC = "algorithmic" + + def get_latex( fn: Callable[..., Any], *, + style: Style = Style.FUNCTION, config: cfg.Config | None = None, **kwargs, ) -> str: @@ -27,9 +34,10 @@ def get_latex( Args: fn: Reference to a function to analyze. - config: use defined Config object, if it is None, it will be automatic assigned + style: Style of the LaTeX description, the default is FUNCTION. + config: Use defined Config object, if it is None, it will be automatic assigned with default value. - **kwargs: dict of Config field values that could be defined individually + **kwargs: Dict of Config field values that could be defined individually by users. Returns: @@ -38,6 +46,9 @@ def get_latex( Raises: latexify.exceptions.LatexifyError: Something went wrong during conversion. """ + if style == Style.EXPRESSION: + kwargs["use_signature"] = kwargs.get("use_signature", False) + merged_config = cfg.Config.defaults().merge(config=config, **kwargs) # Obtains the source AST. @@ -56,11 +67,17 @@ def get_latex( tree = transformers.FunctionExpander(merged_config.expand_functions).visit(tree) # Generates LaTeX. - return codegen.FunctionCodegen( - use_math_symbols=merged_config.use_math_symbols, - use_signature=merged_config.use_signature, - use_set_symbols=merged_config.use_set_symbols, - ).visit(tree) + if style == Style.ALGORITHMIC: + return codegen.AlgorithmicCodegen( + use_math_symbols=merged_config.use_math_symbols, + use_set_symbols=merged_config.use_set_symbols, + ).visit(tree) + else: + return codegen.FunctionCodegen( + use_math_symbols=merged_config.use_math_symbols, + use_signature=merged_config.use_signature, + use_set_symbols=merged_config.use_set_symbols, + ).visit(tree) class LatexifiedFunction: @@ -173,7 +190,7 @@ def expression( This function is a shortcut for `latexify.function` with the default parameter `use_signature=False`. """ - kwargs["use_signature"] = kwargs.get("use_signature", False) + kwargs["style"] = Style.EXPRESSION if fn is not None: return function(fn, **kwargs) else: