From 4ada4b570c988b14bb59c340f098c0b87de33f42 Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sat, 10 Dec 2022 07:48:52 +0000 Subject: [PATCH 01/14] init alg --- .../algorithmic_style_test.py | 69 ++++++++++++++++++ src/latexify/codegen/__init__.py | 3 +- src/latexify/codegen/algorithmic_codegen.py | 73 +++++++++++++++++++ src/latexify/frontend.py | 35 ++++++--- 4 files changed, 170 insertions(+), 10 deletions(-) create mode 100644 src/integration_tests/algorithmic_style_test.py create mode 100644 src/latexify/codegen/algorithmic_codegen.py diff --git a/src/integration_tests/algorithmic_style_test.py b/src/integration_tests/algorithmic_style_test.py new file mode 100644 index 0000000..55b4e15 --- /dev/null +++ b/src/integration_tests/algorithmic_style_test.py @@ -0,0 +1,69 @@ +from typing import Any, Callable + +from latexify import frontend + + +def check_algorithm( + fn: Callable[..., Any], + latex: str, + **kwargs, +) -> None: + """Helper to check if the obtained function has the expected LaTeX form. + Args: + fn: Function to check. + latex: LaTeX form of `fn`. + **kwargs: Arguments passed to `frontend.get_latex`. + """ + # Checks the syntax: + # def fn(...): + # ... + # latexified = get_latex(fn, style=ALGORITHM, **kwargs) + latexified = frontend.get_latex(fn, style=frontend.Style.ALGORITHMIC, **kwargs) + assert latexified == latex + + +def test_factorial() -> None: + def fact(n): + if n == 0: + return 1 + else: + return n * fact(n - 1) + + latex = ( + r"\begin{algorithmic} " + r"\If{$n = 0$} " + r"\State \Return $1$ " + r"\Else " + r"\State \Return $n \mathrm{fact} \mathopen{}\left( n - 1 \mathclose{}\right)$ " + r"\EndIf " + r"\end{algorithmic}" + ) + check_algorithm(fact, latex) + + +def test_collatz() -> None: + def collatz(n): + iterations = 0 + while n > 1: + if n % 2 == 0: + n = n / 2 + else: + n = 3 * n + 1 + iterations = iterations + 1 + return iterations + + latex = ( + r"\begin{algorithmic} " + r"\State $\mathrm{iterations} \gets 0$ " + r"\While{$n > 1$} " + r"\If{$n \mathbin{\%} 2 = 0$} " + r"\State $n \gets \frac{n}{2}$ " + r"\Else \State $n \gets 3 n + 1$ " + r"\EndIf " + r"\State $\mathrm{iterations} \gets \mathrm{iterations} + 1$ " + r"\EndWhile " + r"\State \Return $\mathrm{iterations}$ " + r"\end{algorithmic}" + ) + + check_algorithm(collatz, latex) diff --git a/src/latexify/codegen/__init__.py b/src/latexify/codegen/__init__.py index cddad8d..1aea2c3 100644 --- a/src/latexify/codegen/__init__.py +++ b/src/latexify/codegen/__init__.py @@ -1,6 +1,7 @@ """Package latexify.codegen.""" -from latexify.codegen import expression_codegen, function_codegen +from latexify.codegen import algorithmic_codegen, expression_codegen, function_codegen +AlgorithmicCodegen = algorithmic_codegen.AlgorithmicCodegen ExpressionCodegen = expression_codegen.ExpressionCodegen FunctionCodegen = function_codegen.FunctionCodegen diff --git a/src/latexify/codegen/algorithmic_codegen.py b/src/latexify/codegen/algorithmic_codegen.py new file mode 100644 index 0000000..f39ac74 --- /dev/null +++ b/src/latexify/codegen/algorithmic_codegen.py @@ -0,0 +1,73 @@ +"""Codegen for single algorithms.""" +import ast + +from latexify import exceptions +from latexify.codegen import codegen_utils, expression_codegen + + +class AlgorithmicCodegen(ast.NodeVisitor): + """Codegen for single algorithms.""" + + def __init__( + self, *, use_math_symbols: bool = False, use_set_symbols: bool = False + ) -> None: + """Initializer. + Args: + use_math_symbols: Whether to convert identifiers with a math symbol surface + (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha"). + use_set_symbols: Whether to use set symbols or not. + """ + self._expression_codegen = expression_codegen.ExpressionCodegen( + use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols + ) + + def generic_visit(self, node: ast.AST) -> str: + raise exceptions.LatexifyNotSupportedError( + f"Unsupported AST: {type(node).__name__}" + ) + + def visit_Assign(self, node: ast.Assign) -> str: + operands: list[str] = [ + self._expression_codegen.visit(target) for target in node.targets + ] + operands.append(self._expression_codegen.visit(node.value)) + operands_latex = r" \gets ".join(operands) + return rf"\State ${operands_latex}$" + + def visit_FunctionDef(self, node: ast.FunctionDef) -> str: + body_strs: list[str] = [self.visit(stmt) for stmt in node.body] + return rf"\begin{{algorithmic}} {' '.join(body_strs)} \end{{algorithmic}}" + + def visit_If(self, node: ast.If) -> str: + cond_latex = self._expression_codegen.visit(node.test) + body_latex = " ".join(self.visit(stmt) for stmt in node.body) + + latex = rf"\If{{${cond_latex}$}} {body_latex}" + + if node.orelse: + latex += r" \Else " + latex += " ".join(self.visit(stmt) for stmt in node.orelse) + + return latex + r" \EndIf" + + def visit_Module(self, node: ast.Module) -> str: + return self.visit(node.body[0]) + + def visit_Return(self, node: ast.Return) -> str: + return ( + rf"\State \Return ${self._expression_codegen.visit(node.value)}$" + if node.value is not None + else codegen_utils.convert_constant(None) + ) + + def visit_While(self, node: ast.While) -> str: + cond_latex = self._expression_codegen.visit(node.test) + body_latex = " ".join(self.visit(stmt) for stmt in node.body) + + latex = rf"\While{{${cond_latex}$}} {body_latex}" + + if node.orelse: + latex += r" \Else " + latex += " ".join(self.visit(stmt) for stmt in node.orelse) + + return latex + r" \EndWhile" diff --git a/src/latexify/frontend.py b/src/latexify/frontend.py index e0896bd..9dd98a8 100644 --- a/src/latexify/frontend.py +++ b/src/latexify/frontend.py @@ -2,6 +2,7 @@ from __future__ import annotations +import enum from collections.abc import Callable from typing import Any, overload @@ -16,10 +17,16 @@ _COMMON_PREFIXES = {"math", "numpy", "np"} -# TODO(odashi): move expand_functions to Config. +class Style(str, enum.Enum): + EXPRESSION = "expression" + FUNCTION = "function" + ALGORITHMIC = "algorithmic" + + def get_latex( fn: Callable[..., Any], *, + style: Style = Style.FUNCTION, config: cfg.Config | None = None, **kwargs, ) -> str: @@ -27,9 +34,10 @@ def get_latex( Args: fn: Reference to a function to analyze. - config: use defined Config object, if it is None, it will be automatic assigned + style: Style of the LaTeX description, the default is FUNCTION. + config: Use defined Config object, if it is None, it will be automatic assigned with default value. - **kwargs: dict of Config field values that could be defined individually + **kwargs: Dict of Config field values that could be defined individually by users. Returns: @@ -38,6 +46,9 @@ def get_latex( Raises: latexify.exceptions.LatexifyError: Something went wrong during conversion. """ + if style == Style.EXPRESSION: + kwargs["use_signature"] = kwargs.get("use_signature", False) + merged_config = cfg.Config.defaults().merge(config=config, **kwargs) # Obtains the source AST. @@ -56,11 +67,17 @@ def get_latex( tree = transformers.FunctionExpander(merged_config.expand_functions).visit(tree) # Generates LaTeX. - return codegen.FunctionCodegen( - use_math_symbols=merged_config.use_math_symbols, - use_signature=merged_config.use_signature, - use_set_symbols=merged_config.use_set_symbols, - ).visit(tree) + if style == Style.ALGORITHMIC: + return codegen.AlgorithmicCodegen( + use_math_symbols=merged_config.use_math_symbols, + use_set_symbols=merged_config.use_set_symbols, + ).visit(tree) + else: + return codegen.FunctionCodegen( + use_math_symbols=merged_config.use_math_symbols, + use_signature=merged_config.use_signature, + use_set_symbols=merged_config.use_set_symbols, + ).visit(tree) class LatexifiedFunction: @@ -173,7 +190,7 @@ def expression( This function is a shortcut for `latexify.function` with the default parameter `use_signature=False`. """ - kwargs["use_signature"] = kwargs.get("use_signature", False) + kwargs["style"] = Style.EXPRESSION if fn is not None: return function(fn, **kwargs) else: From abac34b69214132c555fc051b7c92275c6a73a20 Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sat, 10 Dec 2022 07:53:11 +0000 Subject: [PATCH 02/14] comments and annotations --- .../algorithmic_style_test.py | 4 ++ .../function_expansion_test.py | 38 +++++++------ .../{utils.py => integration_utils.py} | 4 ++ src/integration_tests/regression_test.py | 54 ++++++++++--------- 4 files changed, 59 insertions(+), 41 deletions(-) rename src/integration_tests/{utils.py => integration_utils.py} (93%) diff --git a/src/integration_tests/algorithmic_style_test.py b/src/integration_tests/algorithmic_style_test.py index 55b4e15..bda42a5 100644 --- a/src/integration_tests/algorithmic_style_test.py +++ b/src/integration_tests/algorithmic_style_test.py @@ -1,3 +1,7 @@ +"""End-to-end test cases of algorithmic style.""" + +from __future__ import annotations + from typing import Any, Callable from latexify import frontend diff --git a/src/integration_tests/function_expansion_test.py b/src/integration_tests/function_expansion_test.py index 5cc7ab0..d3c373d 100644 --- a/src/integration_tests/function_expansion_test.py +++ b/src/integration_tests/function_expansion_test.py @@ -1,6 +1,10 @@ +"""End-to-end test cases of function expansion.""" + +from __future__ import annotations + import math -from integration_tests import utils +from integration_tests import integration_utils def test_atan2() -> None: @@ -11,7 +15,7 @@ def solve(x, y): r"\mathrm{solve}(x, y) =" r" \arctan \mathopen{}\left( \frac{y}{x} \mathclose{}\right)" ) - utils.check_function(solve, latex, expand_functions={"atan2"}) + integration_utils.check_function(solve, latex, expand_functions={"atan2"}) def test_atan2_nested() -> None: @@ -22,7 +26,7 @@ def solve(x, y): r"\mathrm{solve}(x, y) =" r" \arctan \mathopen{}\left( \frac{e^{y}}{e^{x}} \mathclose{}\right)" ) - utils.check_function(solve, latex, expand_functions={"atan2", "exp"}) + integration_utils.check_function(solve, latex, expand_functions={"atan2", "exp"}) def test_exp() -> None: @@ -30,7 +34,7 @@ def solve(x): return math.exp(x) latex = r"\mathrm{solve}(x) = e^{x}" - utils.check_function(solve, latex, expand_functions={"exp"}) + integration_utils.check_function(solve, latex, expand_functions={"exp"}) def test_exp_nested() -> None: @@ -38,7 +42,7 @@ def solve(x): return math.exp(math.exp(x)) latex = r"\mathrm{solve}(x) = e^{e^{x}}" - utils.check_function(solve, latex, expand_functions={"exp"}) + integration_utils.check_function(solve, latex, expand_functions={"exp"}) def test_exp2() -> None: @@ -46,7 +50,7 @@ def solve(x): return math.exp2(x) latex = r"\mathrm{solve}(x) = 2^{x}" - utils.check_function(solve, latex, expand_functions={"exp2"}) + integration_utils.check_function(solve, latex, expand_functions={"exp2"}) def test_exp2_nested() -> None: @@ -54,7 +58,7 @@ def solve(x): return math.exp2(math.exp2(x)) latex = r"\mathrm{solve}(x) = 2^{2^{x}}" - utils.check_function(solve, latex, expand_functions={"exp2"}) + integration_utils.check_function(solve, latex, expand_functions={"exp2"}) def test_expm1() -> None: @@ -62,7 +66,7 @@ def solve(x): return math.expm1(x) latex = r"\mathrm{solve}(x) = \exp x - 1" - utils.check_function(solve, latex, expand_functions={"expm1"}) + integration_utils.check_function(solve, latex, expand_functions={"expm1"}) def test_expm1_nested() -> None: @@ -70,7 +74,9 @@ def solve(x, y, z): return math.expm1(math.pow(y, z)) latex = r"\mathrm{solve}(x, y, z) = e^{y^{z}} - 1" - utils.check_function(solve, latex, expand_functions={"expm1", "exp", "pow"}) + integration_utils.check_function( + solve, latex, expand_functions={"expm1", "exp", "pow"} + ) def test_hypot_without_attribute() -> None: @@ -80,7 +86,7 @@ def solve(x, y, z): return hypot(x, y, z) latex = r"\mathrm{solve}(x, y, z) = \sqrt{ x^{2} + y^{2} + z^{2} }" - utils.check_function(solve, latex, expand_functions={"hypot"}) + integration_utils.check_function(solve, latex, expand_functions={"hypot"}) def test_hypot() -> None: @@ -88,7 +94,7 @@ def solve(x, y, z): return math.hypot(x, y, z) latex = r"\mathrm{solve}(x, y, z) = \sqrt{ x^{2} + y^{2} + z^{2} }" - utils.check_function(solve, latex, expand_functions={"hypot"}) + integration_utils.check_function(solve, latex, expand_functions={"hypot"}) def test_hypot_nested() -> None: @@ -99,7 +105,7 @@ def solve(a, b, x, y): r"\mathrm{solve}(a, b, x, y) =" r" \sqrt{ \sqrt{ a^{2} + b^{2} }^{2} + x^{2} + y^{2} }" ) - utils.check_function(solve, latex, expand_functions={"hypot"}) + integration_utils.check_function(solve, latex, expand_functions={"hypot"}) def test_log1p() -> None: @@ -107,7 +113,7 @@ def solve(x): return math.log1p(x) latex = r"\mathrm{solve}(x) = \log \mathopen{}\left( 1 + x \mathclose{}\right)" - utils.check_function(solve, latex, expand_functions={"log1p"}) + integration_utils.check_function(solve, latex, expand_functions={"log1p"}) def test_log1p_nested() -> None: @@ -115,7 +121,7 @@ def solve(x): return math.log1p(math.exp(x)) latex = r"\mathrm{solve}(x) = \log \mathopen{}\left( 1 + e^{x} \mathclose{}\right)" - utils.check_function(solve, latex, expand_functions={"log1p", "exp"}) + integration_utils.check_function(solve, latex, expand_functions={"log1p", "exp"}) def test_pow_nested() -> None: @@ -126,7 +132,7 @@ def solve(w, x, y, z): r"\mathrm{solve}(w, x, y, z) = " r"\mathopen{}\left( w^{x} \mathclose{}\right)^{y^{z}}" ) - utils.check_function(solve, latex, expand_functions={"pow"}) + integration_utils.check_function(solve, latex, expand_functions={"pow"}) def test_pow() -> None: @@ -134,4 +140,4 @@ def solve(x, y): return math.pow(x, y) latex = r"\mathrm{solve}(x, y) = x^{y}" - utils.check_function(solve, latex, expand_functions={"pow"}) + integration_utils.check_function(solve, latex, expand_functions={"pow"}) diff --git a/src/integration_tests/utils.py b/src/integration_tests/integration_utils.py similarity index 93% rename from src/integration_tests/utils.py rename to src/integration_tests/integration_utils.py index ecb4d5d..92ada9a 100644 --- a/src/integration_tests/utils.py +++ b/src/integration_tests/integration_utils.py @@ -1,3 +1,7 @@ +"""Utilities for integration tests.""" + +from __future__ import annotations + from typing import Any, Callable from latexify import frontend diff --git a/src/integration_tests/regression_test.py b/src/integration_tests/regression_test.py index e698cdf..f421413 100644 --- a/src/integration_tests/regression_test.py +++ b/src/integration_tests/regression_test.py @@ -4,7 +4,7 @@ import math -from integration_tests import utils +from integration_tests import integration_utils def test_quadratic_solution() -> None: @@ -12,7 +12,7 @@ def solve(a, b, c): return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a) latex = r"\mathrm{solve}(a, b, c) = \frac{-b + \sqrt{ b^{2} - 4 a c }}{2 a}" - utils.check_function(solve, latex) + integration_utils.check_function(solve, latex) def test_sinc() -> None: @@ -29,7 +29,7 @@ def sinc(x): r" \frac{\sin x}{x}, & \mathrm{otherwise}" r" \end{array} \right." ) - utils.check_function(sinc, latex) + integration_utils.check_function(sinc, latex) def test_x_times_beta() -> None: @@ -37,11 +37,15 @@ def xtimesbeta(x, beta): return x * beta latex_without_symbols = r"\mathrm{xtimesbeta}(x, \mathrm{beta}) = x \mathrm{beta}" - utils.check_function(xtimesbeta, latex_without_symbols) - utils.check_function(xtimesbeta, latex_without_symbols, use_math_symbols=False) + integration_utils.check_function(xtimesbeta, latex_without_symbols) + integration_utils.check_function( + xtimesbeta, latex_without_symbols, use_math_symbols=False + ) latex_with_symbols = r"\mathrm{xtimesbeta}(x, \beta) = x \beta" - utils.check_function(xtimesbeta, latex_with_symbols, use_math_symbols=True) + integration_utils.check_function( + xtimesbeta, latex_with_symbols, use_math_symbols=True + ) def test_sum_with_limit_1arg() -> None: @@ -52,7 +56,7 @@ def sum_with_limit(n): r"\mathrm{sum\_with\_limit}(n) = \sum_{i = 0}^{n - 1}" r" \mathopen{}\left({i^{2}}\mathclose{}\right)" ) - utils.check_function(sum_with_limit, latex) + integration_utils.check_function(sum_with_limit, latex) def test_sum_with_limit_2args() -> None: @@ -63,7 +67,7 @@ def sum_with_limit(a, n): r"\mathrm{sum\_with\_limit}(a, n) = \sum_{i = a}^{n - 1} " r"\mathopen{}\left({i^{2}}\mathclose{}\right)" ) - utils.check_function(sum_with_limit, latex) + integration_utils.check_function(sum_with_limit, latex) def test_sum_with_reducible_limit() -> None: @@ -74,7 +78,7 @@ def sum_with_limit(n): r"\mathrm{sum\_with\_limit}(n) = \sum_{i = 0}^{n} " r"\mathopen{}\left({i}\mathclose{}\right)" ) - utils.check_function(sum_with_limit, latex) + integration_utils.check_function(sum_with_limit, latex) def test_sum_with_irreducible_limit() -> None: @@ -85,7 +89,7 @@ def sum_with_limit(n): r"\mathrm{sum\_with\_limit}(n) = \sum_{i = 0}^{n 3 - 1} " r"\mathopen{}\left({i}\mathclose{}\right)" ) - utils.check_function(sum_with_limit, latex) + integration_utils.check_function(sum_with_limit, latex) def test_prod_with_limit_1arg() -> None: @@ -96,7 +100,7 @@ def prod_with_limit(n): r"\mathrm{prod\_with\_limit}(n) = " r"\prod_{i = 0}^{n - 1} \mathopen{}\left({i^{2}}\mathclose{}\right)" ) - utils.check_function(prod_with_limit, latex) + integration_utils.check_function(prod_with_limit, latex) def test_prod_with_limit_2args() -> None: @@ -107,7 +111,7 @@ def prod_with_limit(a, n): r"\mathrm{prod\_with\_limit}(a, n) = " r"\prod_{i = a}^{n - 1} \mathopen{}\left({i^{2}}\mathclose{}\right)" ) - utils.check_function(prod_with_limit, latex) + integration_utils.check_function(prod_with_limit, latex) def test_prod_with_reducible_limits() -> None: @@ -118,7 +122,7 @@ def prod_with_limit(n): r"\mathrm{prod\_with\_limit}(n) = " r"\prod_{i = 0}^{n - 2} \mathopen{}\left({i}\mathclose{}\right)" ) - utils.check_function(prod_with_limit, latex) + integration_utils.check_function(prod_with_limit, latex) def test_prod_with_irreducible_limit() -> None: @@ -129,14 +133,14 @@ def prod_with_limit(n): r"\mathrm{prod\_with\_limit}(n) = " r"\prod_{i = 0}^{n 3 - 1} \mathopen{}\left({i}\mathclose{}\right)" ) - utils.check_function(prod_with_limit, latex) + integration_utils.check_function(prod_with_limit, latex) def test_nested_function() -> None: def nested(x): return 3 * x - utils.check_function(nested, r"\mathrm{nested}(x) = 3 x") + integration_utils.check_function(nested, r"\mathrm{nested}(x) = 3 x") def test_double_nested_function() -> None: @@ -146,7 +150,7 @@ def inner(y): return inner - utils.check_function(nested(3), r"\mathrm{inner}(y) = x y") + integration_utils.check_function(nested(3), r"\mathrm{inner}(y) = x y") def test_reduce_assignments() -> None: @@ -154,11 +158,11 @@ def f(x): a = x + x return 3 * a - utils.check_function( + integration_utils.check_function( f, r"\begin{array}{l} a = x + x \\ f(x) = 3 a \end{array}", ) - utils.check_function( + integration_utils.check_function( f, r"f(x) = 3 \mathopen{}\left( x + x \mathclose{}\right)", reduce_assignments=True, @@ -179,9 +183,9 @@ def f(x): r"\end{array}" ) - utils.check_function(f, latex_without_option) - utils.check_function(f, latex_without_option, reduce_assignments=False) - utils.check_function( + integration_utils.check_function(f, latex_without_option) + integration_utils.check_function(f, latex_without_option, reduce_assignments=False) + integration_utils.check_function( f, r"f(x) = 3 \mathopen{}\left( x^{2} + x^{2} \mathclose{}\right)", reduce_assignments=True, @@ -197,7 +201,7 @@ def sigmoid(x): else: return n - utils.check_function( + integration_utils.check_function( sigmoid, ( r"\mathrm{sigmoid}(x) = \left\{ \begin{array}{ll} " @@ -221,7 +225,7 @@ def solve(a, b): r"a + b \mathclose{}\right) - \mathopen{}\left( " r"a - b \mathclose{}\right) - a b" ) - utils.check_function(solve, latex) + integration_utils.check_function(solve, latex) def test_docstring_allowed() -> None: @@ -230,7 +234,7 @@ def solve(x): return x latex = r"\mathrm{solve}(x) = x" - utils.check_function(solve, latex) + integration_utils.check_function(solve, latex) def test_multiple_constants_allowed() -> None: @@ -241,4 +245,4 @@ def solve(x): return x latex = r"\mathrm{solve}(x) = x" - utils.check_function(solve, latex) + integration_utils.check_function(solve, latex) From c0e4b279a7454149bf6d8fa8da752b106d6743b4 Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sat, 10 Dec 2022 08:02:11 +0000 Subject: [PATCH 03/14] fix indentation --- src/latexify/codegen/function_codegen_test.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/latexify/codegen/function_codegen_test.py b/src/latexify/codegen/function_codegen_test.py index b8e363a..d8d4f19 100644 --- a/src/latexify/codegen/function_codegen_test.py +++ b/src/latexify/codegen/function_codegen_test.py @@ -26,8 +26,8 @@ def test_visit_functiondef_use_signature() -> None: tree = ast.parse( textwrap.dedent( """ - def f(x): - return x + def f(x): + return x """ ) ).body[0] @@ -50,9 +50,9 @@ def test_visit_functiondef_ignore_docstring() -> None: tree = ast.parse( textwrap.dedent( """ - def f(x): - '''docstring''' - return x + def f(x): + '''docstring''' + return x """ ) ).body[0] @@ -66,11 +66,11 @@ def test_visit_functiondef_ignore_multiple_constants() -> None: tree = ast.parse( textwrap.dedent( """ - def f(x): - '''docstring''' - 3 - True - return x + def f(x): + '''docstring''' + 3 + True + return x """ ) ).body[0] From bbd3f747f7eceb14596d90d7d6649ec5715145f8 Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sat, 10 Dec 2022 08:34:03 +0000 Subject: [PATCH 04/14] tests, procedure --- .../algorithmic_style_test.py | 4 + src/latexify/codegen/algorithmic_codegen.py | 40 +++++- .../codegen/algorithmic_codegen_test.py | 125 ++++++++++++++++++ src/latexify/codegen/expression_codegen.py | 4 + 4 files changed, 167 insertions(+), 6 deletions(-) create mode 100644 src/latexify/codegen/algorithmic_codegen_test.py diff --git a/src/integration_tests/algorithmic_style_test.py b/src/integration_tests/algorithmic_style_test.py index bda42a5..33f8f06 100644 --- a/src/integration_tests/algorithmic_style_test.py +++ b/src/integration_tests/algorithmic_style_test.py @@ -35,11 +35,13 @@ def fact(n): latex = ( r"\begin{algorithmic} " + r"\Procedure{fact}{$n$} " r"\If{$n = 0$} " r"\State \Return $1$ " r"\Else " r"\State \Return $n \mathrm{fact} \mathopen{}\left( n - 1 \mathclose{}\right)$ " r"\EndIf " + r"\EndProcedure " r"\end{algorithmic}" ) check_algorithm(fact, latex) @@ -58,6 +60,7 @@ def collatz(n): latex = ( r"\begin{algorithmic} " + r"\Procedure{collatz}{$n$} " r"\State $\mathrm{iterations} \gets 0$ " r"\While{$n > 1$} " r"\If{$n \mathbin{\%} 2 = 0$} " @@ -67,6 +70,7 @@ def collatz(n): r"\State $\mathrm{iterations} \gets \mathrm{iterations} + 1$ " r"\EndWhile " r"\State \Return $\mathrm{iterations}$ " + r"\EndProcedure " r"\end{algorithmic}" ) diff --git a/src/latexify/codegen/algorithmic_codegen.py b/src/latexify/codegen/algorithmic_codegen.py index f39ac74..4db605a 100644 --- a/src/latexify/codegen/algorithmic_codegen.py +++ b/src/latexify/codegen/algorithmic_codegen.py @@ -2,12 +2,14 @@ import ast from latexify import exceptions -from latexify.codegen import codegen_utils, expression_codegen +from latexify.codegen import codegen_utils, expression_codegen, identifier_converter class AlgorithmicCodegen(ast.NodeVisitor): """Codegen for single algorithms.""" + _identifier_converter: identifier_converter.IdentifierConverter + def __init__( self, *, use_math_symbols: bool = False, use_set_symbols: bool = False ) -> None: @@ -20,6 +22,9 @@ def __init__( self._expression_codegen = expression_codegen.ExpressionCodegen( use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols ) + self._identifier_converter = identifier_converter.IdentifierConverter( + use_math_symbols=use_math_symbols + ) def generic_visit(self, node: ast.AST) -> str: raise exceptions.LatexifyNotSupportedError( @@ -27,6 +32,7 @@ def generic_visit(self, node: ast.AST) -> str: ) def visit_Assign(self, node: ast.Assign) -> str: + """Visit an Assign node.""" operands: list[str] = [ self._expression_codegen.visit(target) for target in node.targets ] @@ -34,11 +40,29 @@ def visit_Assign(self, node: ast.Assign) -> str: operands_latex = r" \gets ".join(operands) return rf"\State ${operands_latex}$" + def visit_Expr(self, node: ast.Expr) -> str: + """Visit an Expr node.""" + return rf"\State ${self._expression_codegen.visit(node)}$" + def visit_FunctionDef(self, node: ast.FunctionDef) -> str: + """Visit a FunctionDef node.""" + # Arguments + arg_strs = [ + self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args + ] + body_strs: list[str] = [self.visit(stmt) for stmt in node.body] - return rf"\begin{{algorithmic}} {' '.join(body_strs)} \end{{algorithmic}}" + return ( + rf"\begin{{algorithmic}} " + rf"\Procedure{{{node.name}}}{{${','.join(arg_strs)}$}} " + f"{' '.join(body_strs)} " + r"\EndProcedure " + rf"\end{{algorithmic}}" + ) + # TODO(ZibingZhang): support \ELSIF def visit_If(self, node: ast.If) -> str: + """Visit an If node.""" cond_latex = self._expression_codegen.visit(node.test) body_latex = " ".join(self.visit(stmt) for stmt in node.body) @@ -51,9 +75,11 @@ def visit_If(self, node: ast.If) -> str: return latex + r" \EndIf" def visit_Module(self, node: ast.Module) -> str: + """Visit a Module node.""" return self.visit(node.body[0]) def visit_Return(self, node: ast.Return) -> str: + """Visit a Return node.""" return ( rf"\State \Return ${self._expression_codegen.visit(node.value)}$" if node.value is not None @@ -61,13 +87,15 @@ def visit_Return(self, node: ast.Return) -> str: ) def visit_While(self, node: ast.While) -> str: + """Visit a While node.""" + if node.orelse: + raise exceptions.LatexifyNotSupportedError( + "Codegen does not support while statements with an else clause." + ) + cond_latex = self._expression_codegen.visit(node.test) body_latex = " ".join(self.visit(stmt) for stmt in node.body) latex = rf"\While{{${cond_latex}$}} {body_latex}" - if node.orelse: - latex += r" \Else " - latex += " ".join(self.visit(stmt) for stmt in node.orelse) - return latex + r" \EndWhile" diff --git a/src/latexify/codegen/algorithmic_codegen_test.py b/src/latexify/codegen/algorithmic_codegen_test.py new file mode 100644 index 0000000..99e0345 --- /dev/null +++ b/src/latexify/codegen/algorithmic_codegen_test.py @@ -0,0 +1,125 @@ +"""Tests for latexify.codegen.algorithmic_codegen.""" + +from __future__ import annotations + +import ast +import textwrap + +import pytest + +from latexify import exceptions +from latexify.codegen import algorithmic_codegen + + +def test_generic_visit() -> None: + class UnknownNode(ast.AST): + pass + + with pytest.raises( + exceptions.LatexifyNotSupportedError, + match=r"^Unsupported AST: UnknownNode$", + ): + algorithmic_codegen.AlgorithmicCodegen().visit(UnknownNode()) + + +@pytest.mark.parametrize( + "code,latex", + [("x = 3", r"\State $x \gets 3$"), ("a = b = 0", r"\State $a \gets b \gets 0$")], +) +def test_visit_assign(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.Assign) + assert algorithmic_codegen.AlgorithmicCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ( + "def f(x): return x", + ( + r"\begin{algorithmic} " + r"\Procedure{f}{$x$} " + r"\State \Return $x$ " + r"\EndProcedure " + r"\end{algorithmic}" + ), + ), + ( + "def xyz(a, b, c): return 3", + ( + r"\begin{algorithmic} " + r"\Procedure{xyz}{$a,b,c$} " + r"\State \Return $3$ " + r"\EndProcedure " + r"\end{algorithmic}" + ), + ), + ], +) +def test_visit_functiondef(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.FunctionDef) + assert algorithmic_codegen.AlgorithmicCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("if x < y: return x", r"\If{$x < y$} \State \Return $x$ \EndIf"), + ( + "if True: x\nelse: y", + r"\If{$\mathrm{True}$} \State $x$ \Else \State $y$ \EndIf", + ), + ], +) +def test_visit_if(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.If) + assert algorithmic_codegen.AlgorithmicCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ( + "return x + y", + r"\State \Return $x + y$", + ) + ], +) +def test_visit_return(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.Return) + assert algorithmic_codegen.AlgorithmicCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ( + "while x < y: x = x + 1", + r"\While{$x < y$} \State $x \gets x + 1$ \EndWhile", + ) + ], +) +def test_visit_while(code: str, latex: str) -> None: + node = ast.parse(textwrap.dedent(code)).body[0] + assert isinstance(node, ast.While) + assert algorithmic_codegen.AlgorithmicCodegen().visit(node) == latex + + +def test_visit_while_with_else() -> None: + with pytest.raises(exceptions.LatexifyNotSupportedError): + node = ast.parse( + textwrap.dedent( + """ + while True: + x = x + else: + x = y + """ + ) + ).body[0] + assert isinstance(node, ast.While) + algorithmic_codegen.AlgorithmicCodegen().visit(node) diff --git a/src/latexify/codegen/expression_codegen.py b/src/latexify/codegen/expression_codegen.py index 54bee1b..10050e0 100644 --- a/src/latexify/codegen/expression_codegen.py +++ b/src/latexify/codegen/expression_codegen.py @@ -227,6 +227,10 @@ def generic_visit(self, node: ast.AST) -> str: f"Unsupported AST: {type(node).__name__}" ) + def visit_Expr(self, node: ast.Expr) -> str: + """Visit an Expr node.""" + return self.visit(node.value) + def visit_Tuple(self, node: ast.Tuple) -> str: """Visit a Tuple node.""" elts = [self.visit(elt) for elt in node.elts] From 349464b1c0ed18817fa19933713bfcf59039bbce Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sat, 10 Dec 2022 08:35:45 +0000 Subject: [PATCH 05/14] forgot one --- src/latexify/codegen/algorithmic_codegen.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/latexify/codegen/algorithmic_codegen.py b/src/latexify/codegen/algorithmic_codegen.py index 4db605a..a423952 100644 --- a/src/latexify/codegen/algorithmic_codegen.py +++ b/src/latexify/codegen/algorithmic_codegen.py @@ -1,4 +1,7 @@ """Codegen for single algorithms.""" + +from __future__ import annotations + import ast from latexify import exceptions From af7a3c82e421fde7c0cae915743ebb93375dfaba Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sat, 10 Dec 2022 08:42:53 +0000 Subject: [PATCH 06/14] expose Style, tests --- src/latexify/__init__.py | 2 ++ src/latexify/codegen/algorithmic_codegen.py | 11 ++++++++--- src/latexify/codegen/algorithmic_codegen_test.py | 6 +++++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/latexify/__init__.py b/src/latexify/__init__.py index 0d567f0..f25f699 100644 --- a/src/latexify/__init__.py +++ b/src/latexify/__init__.py @@ -9,6 +9,8 @@ from latexify import frontend +Style = frontend.Style + get_latex = frontend.get_latex function = frontend.function diff --git a/src/latexify/codegen/algorithmic_codegen.py b/src/latexify/codegen/algorithmic_codegen.py index a423952..89ff65a 100644 --- a/src/latexify/codegen/algorithmic_codegen.py +++ b/src/latexify/codegen/algorithmic_codegen.py @@ -9,7 +9,11 @@ class AlgorithmicCodegen(ast.NodeVisitor): - """Codegen for single algorithms.""" + """Codegen for single algorithms. + + This codegen works for Module with single FunctionDef node to generate a single + LaTeX expression of the given algorithm. + """ _identifier_converter: identifier_converter.IdentifierConverter @@ -17,6 +21,7 @@ def __init__( self, *, use_math_symbols: bool = False, use_set_symbols: bool = False ) -> None: """Initializer. + Args: use_math_symbols: Whether to convert identifiers with a math symbol surface (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha"). @@ -53,7 +58,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> str: arg_strs = [ self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args ] - + # Body body_strs: list[str] = [self.visit(stmt) for stmt in node.body] return ( rf"\begin{{algorithmic}} " @@ -86,7 +91,7 @@ def visit_Return(self, node: ast.Return) -> str: return ( rf"\State \Return ${self._expression_codegen.visit(node.value)}$" if node.value is not None - else codegen_utils.convert_constant(None) + else rf"\State \Return ${codegen_utils.convert_constant(None)}$" ) def visit_While(self, node: ast.While) -> str: diff --git a/src/latexify/codegen/algorithmic_codegen_test.py b/src/latexify/codegen/algorithmic_codegen_test.py index 99e0345..d9a88ea 100644 --- a/src/latexify/codegen/algorithmic_codegen_test.py +++ b/src/latexify/codegen/algorithmic_codegen_test.py @@ -85,7 +85,11 @@ def test_visit_if(code: str, latex: str) -> None: ( "return x + y", r"\State \Return $x + y$", - ) + ), + ( + "return", + r"\State \Return $\mathrm{None}$", + ), ], ) def test_visit_return(code: str, latex: str) -> None: From 8c9163994c51263cb715952bd398ca41c2195488 Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sat, 10 Dec 2022 08:53:48 +0000 Subject: [PATCH 07/14] bug --- src/integration_tests/algorithmic_style_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/integration_tests/algorithmic_style_test.py b/src/integration_tests/algorithmic_style_test.py index 33f8f06..990a20e 100644 --- a/src/integration_tests/algorithmic_style_test.py +++ b/src/integration_tests/algorithmic_style_test.py @@ -26,6 +26,8 @@ def check_algorithm( assert latexified == latex +# TODO(ZibingZhang) changing fact -> factorial breaks because factorial is replaced by +# !, substitutions should not occur for the name of the procedure def test_factorial() -> None: def fact(n): if n == 0: From 280f416edaf47ae01ef2a482b7125cdcacf2fdee Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sat, 10 Dec 2022 09:16:58 +0000 Subject: [PATCH 08/14] rm visit_Expr from expr_codegen --- src/latexify/codegen/algorithmic_codegen.py | 2 +- src/latexify/codegen/expression_codegen.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/latexify/codegen/algorithmic_codegen.py b/src/latexify/codegen/algorithmic_codegen.py index 89ff65a..f3b7fc0 100644 --- a/src/latexify/codegen/algorithmic_codegen.py +++ b/src/latexify/codegen/algorithmic_codegen.py @@ -50,7 +50,7 @@ def visit_Assign(self, node: ast.Assign) -> str: def visit_Expr(self, node: ast.Expr) -> str: """Visit an Expr node.""" - return rf"\State ${self._expression_codegen.visit(node)}$" + return rf"\State ${self._expression_codegen.visit(node.value)}$" def visit_FunctionDef(self, node: ast.FunctionDef) -> str: """Visit a FunctionDef node.""" diff --git a/src/latexify/codegen/expression_codegen.py b/src/latexify/codegen/expression_codegen.py index 10050e0..54bee1b 100644 --- a/src/latexify/codegen/expression_codegen.py +++ b/src/latexify/codegen/expression_codegen.py @@ -227,10 +227,6 @@ def generic_visit(self, node: ast.AST) -> str: f"Unsupported AST: {type(node).__name__}" ) - def visit_Expr(self, node: ast.Expr) -> str: - """Visit an Expr node.""" - return self.visit(node.value) - def visit_Tuple(self, node: ast.Tuple) -> str: """Visit a Tuple node.""" elts = [self.visit(elt) for elt in node.elts] From 0c00706fde9eed843139ea2fca3f255b71ce8621 Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sat, 10 Dec 2022 09:54:13 +0000 Subject: [PATCH 09/14] rm line --- src/integration_tests/algorithmic_style_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/integration_tests/algorithmic_style_test.py b/src/integration_tests/algorithmic_style_test.py index 990a20e..9d28b27 100644 --- a/src/integration_tests/algorithmic_style_test.py +++ b/src/integration_tests/algorithmic_style_test.py @@ -75,5 +75,4 @@ def collatz(n): r"\EndProcedure " r"\end{algorithmic}" ) - check_algorithm(collatz, latex) From a66f8af2c3695ca57efd58a7bbe8aa24edb2f550 Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sat, 10 Dec 2022 09:58:02 +0000 Subject: [PATCH 10/14] inline some code --- src/latexify/codegen/algorithmic_codegen.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/latexify/codegen/algorithmic_codegen.py b/src/latexify/codegen/algorithmic_codegen.py index f3b7fc0..9c5d39a 100644 --- a/src/latexify/codegen/algorithmic_codegen.py +++ b/src/latexify/codegen/algorithmic_codegen.py @@ -88,10 +88,10 @@ def visit_Module(self, node: ast.Module) -> str: def visit_Return(self, node: ast.Return) -> str: """Visit a Return node.""" - return ( - rf"\State \Return ${self._expression_codegen.visit(node.value)}$" + return r"\State \Return " + ( + f"${self._expression_codegen.visit(node.value)}$" if node.value is not None - else rf"\State \Return ${codegen_utils.convert_constant(None)}$" + else f"${codegen_utils.convert_constant(None)}$" ) def visit_While(self, node: ast.While) -> str: @@ -103,7 +103,4 @@ def visit_While(self, node: ast.While) -> str: cond_latex = self._expression_codegen.visit(node.test) body_latex = " ".join(self.visit(stmt) for stmt in node.body) - - latex = rf"\While{{${cond_latex}$}} {body_latex}" - - return latex + r" \EndWhile" + return rf"\While{{${cond_latex}$}} {body_latex} \EndWhile" From 119de914a32cae4e3c9406a06fad25ddc808289e Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sat, 10 Dec 2022 09:59:32 +0000 Subject: [PATCH 11/14] specify codegen --- src/latexify/codegen/algorithmic_codegen.py | 2 +- src/latexify/codegen/function_codegen.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/latexify/codegen/algorithmic_codegen.py b/src/latexify/codegen/algorithmic_codegen.py index 9c5d39a..53cd7c8 100644 --- a/src/latexify/codegen/algorithmic_codegen.py +++ b/src/latexify/codegen/algorithmic_codegen.py @@ -98,7 +98,7 @@ def visit_While(self, node: ast.While) -> str: """Visit a While node.""" if node.orelse: raise exceptions.LatexifyNotSupportedError( - "Codegen does not support while statements with an else clause." + "Algorithmic codegen does not support while statements with an else clause." ) cond_latex = self._expression_codegen.visit(node.test) diff --git a/src/latexify/codegen/function_codegen.py b/src/latexify/codegen/function_codegen.py index c9b01e2..d22247d 100644 --- a/src/latexify/codegen/function_codegen.py +++ b/src/latexify/codegen/function_codegen.py @@ -71,7 +71,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> str: if not isinstance(child, ast.Assign): raise exceptions.LatexifyNotSupportedError( - "Codegen supports only Assign nodes in multiline functions, " + "Function codegen supports only Assign nodes in multiline functions, " f"but got: {type(child).__name__}" ) body_strs.append(self.visit(child)) From 5841f93f46e674f0c7f82b7b1423f59e569481fc Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sat, 10 Dec 2022 10:01:47 +0000 Subject: [PATCH 12/14] too long --- src/latexify/codegen/algorithmic_codegen.py | 3 ++- src/latexify/codegen/function_codegen.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/latexify/codegen/algorithmic_codegen.py b/src/latexify/codegen/algorithmic_codegen.py index 53cd7c8..4ee4e1b 100644 --- a/src/latexify/codegen/algorithmic_codegen.py +++ b/src/latexify/codegen/algorithmic_codegen.py @@ -98,7 +98,8 @@ def visit_While(self, node: ast.While) -> str: """Visit a While node.""" if node.orelse: raise exceptions.LatexifyNotSupportedError( - "Algorithmic codegen does not support while statements with an else clause." + "Algorithmic codegen does not support while statements with an else " + "clause" ) cond_latex = self._expression_codegen.visit(node.test) diff --git a/src/latexify/codegen/function_codegen.py b/src/latexify/codegen/function_codegen.py index d22247d..a2f3b47 100644 --- a/src/latexify/codegen/function_codegen.py +++ b/src/latexify/codegen/function_codegen.py @@ -71,8 +71,8 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> str: if not isinstance(child, ast.Assign): raise exceptions.LatexifyNotSupportedError( - "Function codegen supports only Assign nodes in multiline functions, " - f"but got: {type(child).__name__}" + "Function codegen supports only Assign nodes in multiline " + f"functions, but got: {type(child).__name__}" ) body_strs.append(self.visit(child)) From 3634a67e34e7d9470e236c707223e714e825e922 Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sat, 10 Dec 2022 10:26:33 +0000 Subject: [PATCH 13/14] suggestions --- .../algorithmic_style_test.py | 47 ++++++++-------- src/latexify/codegen/algorithmic_codegen.py | 21 ++++---- .../codegen/algorithmic_codegen_test.py | 54 ++++++++++--------- src/latexify/codegen/function_codegen.py | 4 +- src/latexify/frontend.py | 2 +- 5 files changed, 67 insertions(+), 61 deletions(-) diff --git a/src/integration_tests/algorithmic_style_test.py b/src/integration_tests/algorithmic_style_test.py index 9d28b27..3d0d8f0 100644 --- a/src/integration_tests/algorithmic_style_test.py +++ b/src/integration_tests/algorithmic_style_test.py @@ -13,6 +13,7 @@ def check_algorithm( **kwargs, ) -> None: """Helper to check if the obtained function has the expected LaTeX form. + Args: fn: Function to check. latex: LaTeX form of `fn`. @@ -36,15 +37,15 @@ def fact(n): return n * fact(n - 1) latex = ( - r"\begin{algorithmic} " - r"\Procedure{fact}{$n$} " - r"\If{$n = 0$} " - r"\State \Return $1$ " - r"\Else " - r"\State \Return $n \mathrm{fact} \mathopen{}\left( n - 1 \mathclose{}\right)$ " - r"\EndIf " - r"\EndProcedure " - r"\end{algorithmic}" + r"\begin{algorithmic}" + r" \Function{fact}{$n$}" + r" \If{$n = 0$}" + r" \State \Return $1$" + r" \Else" + r" \State \Return $n \mathrm{fact} \mathopen{}\left( n - 1 \mathclose{}\right)$" + r" \EndIf" + r" \EndFunction" + r" \end{algorithmic}" ) check_algorithm(fact, latex) @@ -54,25 +55,25 @@ def collatz(n): iterations = 0 while n > 1: if n % 2 == 0: - n = n / 2 + n = n // 2 else: n = 3 * n + 1 iterations = iterations + 1 return iterations latex = ( - r"\begin{algorithmic} " - r"\Procedure{collatz}{$n$} " - r"\State $\mathrm{iterations} \gets 0$ " - r"\While{$n > 1$} " - r"\If{$n \mathbin{\%} 2 = 0$} " - r"\State $n \gets \frac{n}{2}$ " - r"\Else \State $n \gets 3 n + 1$ " - r"\EndIf " - r"\State $\mathrm{iterations} \gets \mathrm{iterations} + 1$ " - r"\EndWhile " - r"\State \Return $\mathrm{iterations}$ " - r"\EndProcedure " - r"\end{algorithmic}" + r"\begin{algorithmic}" + r" \Function{collatz}{$n$}" + r" \State $\mathrm{iterations} \gets 0$" + r" \While{$n > 1$}" + r" \If{$n \mathbin{\%} 2 = 0$}" + r" \State $n \gets \left\lfloor\frac{n}{2}\right\rfloor$" + r" \Else \State $n \gets 3 n + 1$" + r" \EndIf" + r" \State $\mathrm{iterations} \gets \mathrm{iterations} + 1$" + r" \EndWhile" + r" \State \Return $\mathrm{iterations}$" + r" \EndFunction" + r" \end{algorithmic}" ) check_algorithm(collatz, latex) diff --git a/src/latexify/codegen/algorithmic_codegen.py b/src/latexify/codegen/algorithmic_codegen.py index 4ee4e1b..2fcb16b 100644 --- a/src/latexify/codegen/algorithmic_codegen.py +++ b/src/latexify/codegen/algorithmic_codegen.py @@ -5,7 +5,7 @@ import ast from latexify import exceptions -from latexify.codegen import codegen_utils, expression_codegen, identifier_converter +from latexify.codegen import expression_codegen, identifier_converter class AlgorithmicCodegen(ast.NodeVisitor): @@ -61,11 +61,11 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> str: # Body body_strs: list[str] = [self.visit(stmt) for stmt in node.body] return ( - rf"\begin{{algorithmic}} " - rf"\Procedure{{{node.name}}}{{${','.join(arg_strs)}$}} " - f"{' '.join(body_strs)} " - r"\EndProcedure " - rf"\end{{algorithmic}}" + rf"\begin{{algorithmic}}" + rf" \Function{{{node.name}}}{{${', '.join(arg_strs)}$}}" + f" {' '.join(body_strs)}" + r" \EndFunction" + rf" \end{{algorithmic}}" ) # TODO(ZibingZhang): support \ELSIF @@ -88,18 +88,17 @@ def visit_Module(self, node: ast.Module) -> str: def visit_Return(self, node: ast.Return) -> str: """Visit a Return node.""" - return r"\State \Return " + ( - f"${self._expression_codegen.visit(node.value)}$" + return ( + rf"\State \Return ${self._expression_codegen.visit(node.value)}$" if node.value is not None - else f"${codegen_utils.convert_constant(None)}$" + else r"\State \Return" ) def visit_While(self, node: ast.While) -> str: """Visit a While node.""" if node.orelse: raise exceptions.LatexifyNotSupportedError( - "Algorithmic codegen does not support while statements with an else " - "clause" + "While statement with the else clause is not supported" ) cond_latex = self._expression_codegen.visit(node.test) diff --git a/src/latexify/codegen/algorithmic_codegen_test.py b/src/latexify/codegen/algorithmic_codegen_test.py index d9a88ea..a972beb 100644 --- a/src/latexify/codegen/algorithmic_codegen_test.py +++ b/src/latexify/codegen/algorithmic_codegen_test.py @@ -24,7 +24,10 @@ class UnknownNode(ast.AST): @pytest.mark.parametrize( "code,latex", - [("x = 3", r"\State $x \gets 3$"), ("a = b = 0", r"\State $a \gets b \gets 0$")], + [ + ("x = 3", r"\State $x \gets 3$"), + ("a = b = 0", r"\State $a \gets b \gets 0$"), + ], ) def test_visit_assign(code: str, latex: str) -> None: node = ast.parse(textwrap.dedent(code)).body[0] @@ -38,21 +41,21 @@ def test_visit_assign(code: str, latex: str) -> None: ( "def f(x): return x", ( - r"\begin{algorithmic} " - r"\Procedure{f}{$x$} " - r"\State \Return $x$ " - r"\EndProcedure " - r"\end{algorithmic}" + r"\begin{algorithmic}" + r" \Function{f}{$x$}" + r" \State \Return $x$" + r" \EndFunction" + r" \end{algorithmic}" ), ), ( "def xyz(a, b, c): return 3", ( - r"\begin{algorithmic} " - r"\Procedure{xyz}{$a,b,c$} " - r"\State \Return $3$ " - r"\EndProcedure " - r"\end{algorithmic}" + r"\begin{algorithmic}" + r" \Function{xyz}{$a, b, c$}" + r" \State \Return $3$" + r" \EndFunction" + r" \end{algorithmic}" ), ), ], @@ -88,7 +91,7 @@ def test_visit_if(code: str, latex: str) -> None: ), ( "return", - r"\State \Return $\mathrm{None}$", + r"\State \Return", ), ], ) @@ -114,16 +117,19 @@ def test_visit_while(code: str, latex: str) -> None: def test_visit_while_with_else() -> None: - with pytest.raises(exceptions.LatexifyNotSupportedError): - node = ast.parse( - textwrap.dedent( - """ - while True: - x = x - else: - x = y - """ - ) - ).body[0] - assert isinstance(node, ast.While) + node = ast.parse( + textwrap.dedent( + """ + while True: + x = x + else: + x = y + """ + ) + ).body[0] + assert isinstance(node, ast.While) + with pytest.raises( + exceptions.LatexifyNotSupportedError, + match="^While statement with the else clause is not supported$", + ): algorithmic_codegen.AlgorithmicCodegen().visit(node) diff --git a/src/latexify/codegen/function_codegen.py b/src/latexify/codegen/function_codegen.py index a2f3b47..c9b01e2 100644 --- a/src/latexify/codegen/function_codegen.py +++ b/src/latexify/codegen/function_codegen.py @@ -71,8 +71,8 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> str: if not isinstance(child, ast.Assign): raise exceptions.LatexifyNotSupportedError( - "Function codegen supports only Assign nodes in multiline " - f"functions, but got: {type(child).__name__}" + "Codegen supports only Assign nodes in multiline functions, " + f"but got: {type(child).__name__}" ) body_strs.append(self.visit(child)) diff --git a/src/latexify/frontend.py b/src/latexify/frontend.py index 9dd98a8..83a5d7e 100644 --- a/src/latexify/frontend.py +++ b/src/latexify/frontend.py @@ -17,7 +17,7 @@ _COMMON_PREFIXES = {"math", "numpy", "np"} -class Style(str, enum.Enum): +class Style(enum.Enum): EXPRESSION = "expression" FUNCTION = "function" ALGORITHMIC = "algorithmic" From 628a2da0a09d844b5e16f4a5deccfb083d5e1b56 Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sat, 10 Dec 2022 10:27:21 +0000 Subject: [PATCH 14/14] rm todo --- src/integration_tests/algorithmic_style_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/integration_tests/algorithmic_style_test.py b/src/integration_tests/algorithmic_style_test.py index 3d0d8f0..047d6f9 100644 --- a/src/integration_tests/algorithmic_style_test.py +++ b/src/integration_tests/algorithmic_style_test.py @@ -27,8 +27,6 @@ def check_algorithm( assert latexified == latex -# TODO(ZibingZhang) changing fact -> factorial breaks because factorial is replaced by -# !, substitutions should not occur for the name of the procedure def test_factorial() -> None: def fact(n): if n == 0: