diff --git a/src/latexify/codegen/expression_codegen.py b/src/latexify/codegen/expression_codegen.py
index 54bee1b..f85f0a7 100644
--- a/src/latexify/codegen/expression_codegen.py
+++ b/src/latexify/codegen/expression_codegen.py
@@ -347,15 +347,76 @@ def generate_matrix_from_array(data: list[list[str]]) -> str:
 
         return generate_matrix_from_array(rows)
 
+    def _generate_zeros(self, node: ast.Call) -> str | None:
+        """Generates LaTeX for numpy.zeros.
+
+        Args:
+            node: ast.Call node containing the appropriate method invocation.
+
+        Returns:
+            Generated LaTeX, or None if the node has unsupported syntax.
+        """
+        assert ast_utils.extract_function_name_or_none(node) == "zeros"
+
+        # Only the single-positional-argument form is supported.
+        # NOTE(review): node.keywords is not inspected, so keyword arguments
+        # neither change the output nor trigger the None fallback.
+        if len(node.args) != 1:
+            return None
+        shape = node.args[0]
+
+        if isinstance(shape, ast.Tuple):
+            sizes = [ast_utils.extract_int_or_none(elt) for elt in shape.elts]
+            if any(size is None for size in sizes):
+                return None
+            # An empty shape is rendered as a plain 0.
+            if not sizes:
+                return "0"
+            # A single dimension N is rendered as 1 x N.
+            if len(sizes) == 1:
+                sizes.insert(0, 1)
+        else:
+            # A bare (non-tuple) argument N is also rendered as 1 x N.
+            size = ast_utils.extract_int_or_none(shape)
+            if not isinstance(size, int):
+                return None
+            sizes = [1, size]
+
+        exponent = r" \times ".join(map(str, sizes))
+        return rf"\mathbf{{0}}^{{{exponent}}}"
+
+    def _generate_identity(self, node: ast.Call) -> str | None:
+        """Generates LaTeX for numpy.identity.
+
+        Args:
+            node: ast.Call node containing the appropriate method invocation.
+
+        Returns:
+            Generated LaTeX, or None if the node has unsupported syntax.
+        """
+        assert ast_utils.extract_function_name_or_none(node) == "identity"
+
+        # Only the single-positional-argument form is supported.
+        if len(node.args) != 1:
+            return None
+
+        size = ast_utils.extract_int_or_none(node.args[0])
+        return None if size is None else rf"\mathbf{{I}}_{{{size}}}"
+
     def visit_Call(self, node: ast.Call) -> str:
         """Visit a Call node."""
         func_name = ast_utils.extract_function_name_or_none(node)
 
         # Special treatments for some functions.
+ # TODO(odashi): Move these functions to some separate utility. if func_name in ("fsum", "sum", "prod"): special_latex = self._generate_sum_prod(node) elif func_name in ("array", "ndarray"): special_latex = self._generate_matrix(node) + elif func_name == "zeros": + special_latex = self._generate_zeros(node) + elif func_name == "identity": + special_latex = self._generate_identity(node) else: special_latex = None diff --git a/src/latexify/codegen/expression_codegen_test.py b/src/latexify/codegen/expression_codegen_test.py index 5368a13..ee75cef 100644 --- a/src/latexify/codegen/expression_codegen_test.py +++ b/src/latexify/codegen/expression_codegen_test.py @@ -7,7 +7,7 @@ import pytest from latexify import ast_utils, exceptions, test_utils -from latexify.codegen import ExpressionCodegen +from latexify.codegen import expression_codegen def test_generic_visit() -> None: @@ -18,7 +18,7 @@ class UnknownNode(ast.AST): exceptions.LatexifyNotSupportedError, match=r"^Unsupported AST: UnknownNode$", ): - ExpressionCodegen().visit(UnknownNode()) + expression_codegen.ExpressionCodegen().visit(UnknownNode()) @pytest.mark.parametrize( @@ -33,7 +33,7 @@ class UnknownNode(ast.AST): def test_visit_tuple(code: str, latex: str) -> None: node = ast_utils.parse_expr(code) assert isinstance(node, ast.Tuple) - assert ExpressionCodegen().visit(node) == latex + assert expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( @@ -48,7 +48,7 @@ def test_visit_tuple(code: str, latex: str) -> None: def test_visit_list(code: str, latex: str) -> None: node = ast_utils.parse_expr(code) assert isinstance(node, ast.List) - assert ExpressionCodegen().visit(node) == latex + assert expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( @@ -64,7 +64,7 @@ def test_visit_list(code: str, latex: str) -> None: def test_visit_set(code: str, latex: str) -> None: node = ast_utils.parse_expr(code) assert isinstance(node, ast.Set) - assert 
ExpressionCodegen().visit(node) == latex + assert expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( @@ -114,7 +114,7 @@ def test_visit_set(code: str, latex: str) -> None: def test_visit_listcomp(code: str, latex: str) -> None: node = ast_utils.parse_expr(code) assert isinstance(node, ast.ListComp) - assert ExpressionCodegen().visit(node) == latex + assert expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( @@ -164,7 +164,7 @@ def test_visit_listcomp(code: str, latex: str) -> None: def test_visit_setcomp(code: str, latex: str) -> None: node = ast_utils.parse_expr(code) assert isinstance(node, ast.SetComp) - assert ExpressionCodegen().visit(node) == latex + assert expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( @@ -215,7 +215,7 @@ def test_visit_setcomp(code: str, latex: str) -> None: def test_visit_call(code: str, latex: str) -> None: node = ast_utils.parse_expr(code) assert isinstance(node, ast.Call) - assert ExpressionCodegen().visit(node) == latex + assert expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( @@ -330,7 +330,9 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None: for src_fn, dest_fn in [("fsum", r"\sum"), ("sum", r"\sum"), ("prod", r"\prod")]: node = ast_utils.parse_expr(src_fn + src_suffix) assert isinstance(node, ast.Call) - assert ExpressionCodegen().visit(node) == dest_fn + dest_suffix + assert ( + expression_codegen.ExpressionCodegen().visit(node) == dest_fn + dest_suffix + ) @pytest.mark.parametrize( @@ -381,7 +383,7 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None: def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> None: node = ast_utils.parse_expr(code) assert isinstance(node, ast.Call) - assert ExpressionCodegen().visit(node) == latex + assert expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( 
@@ -407,7 +409,9 @@ def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None: for src_fn, dest_fn in [("sum", r"\sum"), ("prod", r"\prod")]: node = ast_utils.parse_expr(src_fn + src_suffix) assert isinstance(node, ast.Call) - assert ExpressionCodegen().visit(node) == dest_fn + dest_suffix + assert ( + expression_codegen.ExpressionCodegen().visit(node) == dest_fn + dest_suffix + ) @pytest.mark.parametrize( @@ -442,7 +446,7 @@ def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None: def test_if_then_else(code: str, latex: str) -> None: node = ast_utils.parse_expr(code) assert isinstance(node, ast.IfExp) - assert ExpressionCodegen().visit(node) == latex + assert expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( @@ -625,7 +629,7 @@ def test_if_then_else(code: str, latex: str) -> None: def test_visit_binop(code: str, latex: str) -> None: tree = ast_utils.parse_expr(code) assert isinstance(tree, ast.BinOp) - assert ExpressionCodegen().visit(tree) == latex + assert expression_codegen.ExpressionCodegen().visit(tree) == latex @pytest.mark.parametrize( @@ -664,7 +668,7 @@ def test_visit_binop(code: str, latex: str) -> None: def test_visit_unaryop(code: str, latex: str) -> None: tree = ast_utils.parse_expr(code) assert isinstance(tree, ast.UnaryOp) - assert ExpressionCodegen().visit(tree) == latex + assert expression_codegen.ExpressionCodegen().visit(tree) == latex @pytest.mark.parametrize( @@ -718,7 +722,7 @@ def test_visit_unaryop(code: str, latex: str) -> None: def test_visit_compare(code: str, latex: str) -> None: tree = ast_utils.parse_expr(code) assert isinstance(tree, ast.Compare) - assert ExpressionCodegen().visit(tree) == latex + assert expression_codegen.ExpressionCodegen().visit(tree) == latex @pytest.mark.parametrize( @@ -764,7 +768,7 @@ def test_visit_compare(code: str, latex: str) -> None: def test_visit_boolop(code: str, latex: str) -> None: tree = ast_utils.parse_expr(code) 
assert isinstance(tree, ast.BoolOp) - assert ExpressionCodegen().visit(tree) == latex + assert expression_codegen.ExpressionCodegen().visit(tree) == latex @test_utils.require_at_most(7) @@ -789,7 +793,7 @@ def test_visit_boolop(code: str, latex: str) -> None: def test_visit_constant_lagacy(code: str, cls: type[ast.expr], latex: str) -> None: tree = ast_utils.parse_expr(code) assert isinstance(tree, cls) - assert ExpressionCodegen().visit(tree) == latex + assert expression_codegen.ExpressionCodegen().visit(tree) == latex @test_utils.require_at_least(8) @@ -814,7 +818,7 @@ def test_visit_constant_lagacy(code: str, cls: type[ast.expr], latex: str) -> No def test_visit_constant(code: str, latex: str) -> None: tree = ast_utils.parse_expr(code) assert isinstance(tree, ast.Constant) - assert ExpressionCodegen().visit(tree) == latex + assert expression_codegen.ExpressionCodegen().visit(tree) == latex @pytest.mark.parametrize( @@ -830,7 +834,7 @@ def test_visit_constant(code: str, latex: str) -> None: def test_visit_subscript(code: str, latex: str) -> None: tree = ast_utils.parse_expr(code) assert isinstance(tree, ast.Subscript) - assert ExpressionCodegen().visit(tree) == latex + assert expression_codegen.ExpressionCodegen().visit(tree) == latex @pytest.mark.parametrize( @@ -845,7 +849,9 @@ def test_visit_subscript(code: str, latex: str) -> None: def test_visit_binop_use_set_symbols(code: str, latex: str) -> None: tree = ast_utils.parse_expr(code) assert isinstance(tree, ast.BinOp) - assert ExpressionCodegen(use_set_symbols=True).visit(tree) == latex + assert ( + expression_codegen.ExpressionCodegen(use_set_symbols=True).visit(tree) == latex + ) @pytest.mark.parametrize( @@ -860,7 +866,9 @@ def test_visit_binop_use_set_symbols(code: str, latex: str) -> None: def test_visit_compare_use_set_symbols(code: str, latex: str) -> None: tree = ast_utils.parse_expr(code) assert isinstance(tree, ast.Compare) - assert ExpressionCodegen(use_set_symbols=True).visit(tree) == latex + 
assert (
+        expression_codegen.ExpressionCodegen(use_set_symbols=True).visit(tree) == latex
+    )
 
 
 @pytest.mark.parametrize(
@@ -906,4 +914,59 @@ def test_visit_compare_use_set_symbols(code: str, latex: str) -> None:
 def test_numpy_array(code: str, latex: str) -> None:
     tree = ast_utils.parse_expr(code)
     assert isinstance(tree, ast.Call)
-    assert ExpressionCodegen().visit(tree) == latex
+    assert expression_codegen.ExpressionCodegen().visit(tree) == latex
+
+
+@pytest.mark.parametrize(
+    "code,latex",
+    [
+        ("zeros(0)", r"\mathbf{0}^{1 \times 0}"),
+        ("zeros(1)", r"\mathbf{0}^{1 \times 1}"),
+        ("zeros(2)", r"\mathbf{0}^{1 \times 2}"),
+        ("zeros(())", r"0"),
+        ("zeros((0,))", r"\mathbf{0}^{1 \times 0}"),
+        ("zeros((1,))", r"\mathbf{0}^{1 \times 1}"),
+        ("zeros((2,))", r"\mathbf{0}^{1 \times 2}"),
+        ("zeros((0, 0))", r"\mathbf{0}^{0 \times 0}"),
+        ("zeros((1, 1))", r"\mathbf{0}^{1 \times 1}"),
+        ("zeros((2, 3))", r"\mathbf{0}^{2 \times 3}"),
+        ("zeros((0, 0, 0))", r"\mathbf{0}^{0 \times 0 \times 0}"),
+        ("zeros((1, 1, 1))", r"\mathbf{0}^{1 \times 1 \times 1}"),
+        ("zeros((2, 3, 5))", r"\mathbf{0}^{2 \times 3 \times 5}"),
+        # Unsupported argument forms are rendered as ordinary function calls.
+        ("zeros()", r"\mathrm{zeros} \mathopen{}\left( \mathclose{}\right)"),
+        ("zeros(x)", r"\mathrm{zeros} \mathopen{}\left( x \mathclose{}\right)"),
+        ("zeros(0, x)", r"\mathrm{zeros} \mathopen{}\left( 0, x \mathclose{}\right)"),
+        (
+            "zeros((x,))",
+            r"\mathrm{zeros} \mathopen{}\left("
+            r" \mathopen{}\left( x \mathclose{}\right)"
+            r" \mathclose{}\right)",
+        ),
+    ],
+)
+def test_zeros(code: str, latex: str) -> None:
+    tree = ast_utils.parse_expr(code)
+    assert isinstance(tree, ast.Call)
+    assert expression_codegen.ExpressionCodegen().visit(tree) == latex
+
+
+@pytest.mark.parametrize(
+    "code,latex",
+    [
+        ("identity(0)", r"\mathbf{I}_{0}"),
+        ("identity(1)", r"\mathbf{I}_{1}"),
+        ("identity(2)", r"\mathbf{I}_{2}"),
+        # Unsupported argument forms are rendered as ordinary function calls.
+        ("identity()", r"\mathrm{identity} \mathopen{}\left( \mathclose{}\right)"),
+        ("identity(x)", r"\mathrm{identity} \mathopen{}\left( x \mathclose{}\right)"),
+        (
+            "identity(0, x)",
+            r"\mathrm{identity} \mathopen{}\left( 0, x \mathclose{}\right)",
+        ),
+    ],
+)
+def test_identity(code: str, latex: str) -> None:
+    tree = ast_utils.parse_expr(code)
+    assert isinstance(tree, ast.Call)
+    assert expression_codegen.ExpressionCodegen().visit(tree) == latex