Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions src/latexify/codegen/expression_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,15 +347,72 @@ def generate_matrix_from_array(data: list[list[str]]) -> str:

return generate_matrix_from_array(rows)

def _generate_zeros(self, node: ast.Call) -> str | None:
"""Generates LaTeX for numpy.zeros.
Args:
node: ast.Call node containing the appropriate method invocation.
Returns:
Generated LaTeX, or None if the node has unsupported syntax.
"""
name = ast_utils.extract_function_name_or_none(node)
assert name == "zeros"

if len(node.args) != 1:
return None

# All args to np.zeros should be numeric.
if isinstance(node.args[0], ast.Tuple):
dims = [ast_utils.extract_int_or_none(x) for x in node.args[0].elts]
if any(x is None for x in dims):
return None
if not dims:
return "0"
if len(dims) == 1:
dims = [1, dims[0]]

dims_latex = r" \times ".join(str(x) for x in dims)
else:
dim = ast_utils.extract_int_or_none(node.args[0])
if not isinstance(dim, int):
return None
# 1 x N array of zeros
dims_latex = rf"1 \times {dim}"

return rf"\mathbf{{0}}^{{{dims_latex}}}"

def _generate_identity(self, node: ast.Call) -> str | None:
"""Generates LaTeX for numpy.identity.
Args:
node: ast.Call node containing the appropriate method invocation.
Returns:
Generated LaTeX, or None if the node has unsupported syntax.
"""
name = ast_utils.extract_function_name_or_none(node)
assert name == "identity"

if len(node.args) != 1:
return None

ndims = ast_utils.extract_int_or_none(node.args[0])
if ndims is None:
return None

return rf"\mathbf{{I}}_{{{ndims}}}"

def visit_Call(self, node: ast.Call) -> str:
"""Visit a Call node."""
func_name = ast_utils.extract_function_name_or_none(node)

# Special treatments for some functions.
# TODO(odashi): Move these functions to some separate utility.
if func_name in ("fsum", "sum", "prod"):
special_latex = self._generate_sum_prod(node)
elif func_name in ("array", "ndarray"):
special_latex = self._generate_matrix(node)
elif func_name == "zeros":
special_latex = self._generate_zeros(node)
elif func_name == "identity":
special_latex = self._generate_identity(node)
else:
special_latex = None

Expand Down
107 changes: 85 additions & 22 deletions src/latexify/codegen/expression_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

from latexify import ast_utils, exceptions, test_utils
from latexify.codegen import ExpressionCodegen
from latexify.codegen import expression_codegen


def test_generic_visit() -> None:
Expand All @@ -18,7 +18,7 @@ class UnknownNode(ast.AST):
exceptions.LatexifyNotSupportedError,
match=r"^Unsupported AST: UnknownNode$",
):
ExpressionCodegen().visit(UnknownNode())
expression_codegen.ExpressionCodegen().visit(UnknownNode())


@pytest.mark.parametrize(
Expand All @@ -33,7 +33,7 @@ class UnknownNode(ast.AST):
def test_visit_tuple(code: str, latex: str) -> None:
node = ast_utils.parse_expr(code)
assert isinstance(node, ast.Tuple)
assert ExpressionCodegen().visit(node) == latex
assert expression_codegen.ExpressionCodegen().visit(node) == latex


@pytest.mark.parametrize(
Expand All @@ -48,7 +48,7 @@ def test_visit_tuple(code: str, latex: str) -> None:
def test_visit_list(code: str, latex: str) -> None:
node = ast_utils.parse_expr(code)
assert isinstance(node, ast.List)
assert ExpressionCodegen().visit(node) == latex
assert expression_codegen.ExpressionCodegen().visit(node) == latex


@pytest.mark.parametrize(
Expand All @@ -64,7 +64,7 @@ def test_visit_list(code: str, latex: str) -> None:
def test_visit_set(code: str, latex: str) -> None:
node = ast_utils.parse_expr(code)
assert isinstance(node, ast.Set)
assert ExpressionCodegen().visit(node) == latex
assert expression_codegen.ExpressionCodegen().visit(node) == latex


@pytest.mark.parametrize(
Expand Down Expand Up @@ -114,7 +114,7 @@ def test_visit_set(code: str, latex: str) -> None:
def test_visit_listcomp(code: str, latex: str) -> None:
node = ast_utils.parse_expr(code)
assert isinstance(node, ast.ListComp)
assert ExpressionCodegen().visit(node) == latex
assert expression_codegen.ExpressionCodegen().visit(node) == latex


@pytest.mark.parametrize(
Expand Down Expand Up @@ -164,7 +164,7 @@ def test_visit_listcomp(code: str, latex: str) -> None:
def test_visit_setcomp(code: str, latex: str) -> None:
node = ast_utils.parse_expr(code)
assert isinstance(node, ast.SetComp)
assert ExpressionCodegen().visit(node) == latex
assert expression_codegen.ExpressionCodegen().visit(node) == latex


@pytest.mark.parametrize(
Expand Down Expand Up @@ -215,7 +215,7 @@ def test_visit_setcomp(code: str, latex: str) -> None:
def test_visit_call(code: str, latex: str) -> None:
node = ast_utils.parse_expr(code)
assert isinstance(node, ast.Call)
assert ExpressionCodegen().visit(node) == latex
assert expression_codegen.ExpressionCodegen().visit(node) == latex


@pytest.mark.parametrize(
Expand Down Expand Up @@ -330,7 +330,9 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
for src_fn, dest_fn in [("fsum", r"\sum"), ("sum", r"\sum"), ("prod", r"\prod")]:
node = ast_utils.parse_expr(src_fn + src_suffix)
assert isinstance(node, ast.Call)
assert ExpressionCodegen().visit(node) == dest_fn + dest_suffix
assert (
expression_codegen.ExpressionCodegen().visit(node) == dest_fn + dest_suffix
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -381,7 +383,7 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> None:
node = ast_utils.parse_expr(code)
assert isinstance(node, ast.Call)
assert ExpressionCodegen().visit(node) == latex
assert expression_codegen.ExpressionCodegen().visit(node) == latex


@pytest.mark.parametrize(
Expand All @@ -407,7 +409,9 @@ def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None:
for src_fn, dest_fn in [("sum", r"\sum"), ("prod", r"\prod")]:
node = ast_utils.parse_expr(src_fn + src_suffix)
assert isinstance(node, ast.Call)
assert ExpressionCodegen().visit(node) == dest_fn + dest_suffix
assert (
expression_codegen.ExpressionCodegen().visit(node) == dest_fn + dest_suffix
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -442,7 +446,7 @@ def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None:
def test_if_then_else(code: str, latex: str) -> None:
node = ast_utils.parse_expr(code)
assert isinstance(node, ast.IfExp)
assert ExpressionCodegen().visit(node) == latex
assert expression_codegen.ExpressionCodegen().visit(node) == latex


@pytest.mark.parametrize(
Expand Down Expand Up @@ -625,7 +629,7 @@ def test_if_then_else(code: str, latex: str) -> None:
def test_visit_binop(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.BinOp)
assert ExpressionCodegen().visit(tree) == latex
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
Expand Down Expand Up @@ -664,7 +668,7 @@ def test_visit_binop(code: str, latex: str) -> None:
def test_visit_unaryop(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.UnaryOp)
assert ExpressionCodegen().visit(tree) == latex
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
Expand Down Expand Up @@ -718,7 +722,7 @@ def test_visit_unaryop(code: str, latex: str) -> None:
def test_visit_compare(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Compare)
assert ExpressionCodegen().visit(tree) == latex
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
Expand Down Expand Up @@ -764,7 +768,7 @@ def test_visit_compare(code: str, latex: str) -> None:
def test_visit_boolop(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.BoolOp)
assert ExpressionCodegen().visit(tree) == latex
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@test_utils.require_at_most(7)
Expand All @@ -789,7 +793,7 @@ def test_visit_boolop(code: str, latex: str) -> None:
def test_visit_constant_lagacy(code: str, cls: type[ast.expr], latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, cls)
assert ExpressionCodegen().visit(tree) == latex
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@test_utils.require_at_least(8)
Expand All @@ -814,7 +818,7 @@ def test_visit_constant_lagacy(code: str, cls: type[ast.expr], latex: str) -> No
def test_visit_constant(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Constant)
assert ExpressionCodegen().visit(tree) == latex
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
Expand All @@ -830,7 +834,7 @@ def test_visit_constant(code: str, latex: str) -> None:
def test_visit_subscript(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Subscript)
assert ExpressionCodegen().visit(tree) == latex
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
Expand All @@ -845,7 +849,9 @@ def test_visit_subscript(code: str, latex: str) -> None:
def test_visit_binop_use_set_symbols(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.BinOp)
assert ExpressionCodegen(use_set_symbols=True).visit(tree) == latex
assert (
expression_codegen.ExpressionCodegen(use_set_symbols=True).visit(tree) == latex
)


@pytest.mark.parametrize(
Expand All @@ -860,7 +866,9 @@ def test_visit_binop_use_set_symbols(code: str, latex: str) -> None:
def test_visit_compare_use_set_symbols(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Compare)
assert ExpressionCodegen(use_set_symbols=True).visit(tree) == latex
assert (
expression_codegen.ExpressionCodegen(use_set_symbols=True).visit(tree) == latex
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -906,4 +914,59 @@ def test_visit_compare_use_set_symbols(code: str, latex: str) -> None:
def test_numpy_array(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Call)
assert ExpressionCodegen().visit(tree) == latex
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
"code,latex",
[
("zeros(0)", r"\mathbf{0}^{1 \times 0}"),
("zeros(1)", r"\mathbf{0}^{1 \times 1}"),
("zeros(2)", r"\mathbf{0}^{1 \times 2}"),
("zeros(())", r"0"),
("zeros((0,))", r"\mathbf{0}^{1 \times 0}"),
("zeros((1,))", r"\mathbf{0}^{1 \times 1}"),
("zeros((2,))", r"\mathbf{0}^{1 \times 2}"),
("zeros((0, 0))", r"\mathbf{0}^{0 \times 0}"),
("zeros((1, 1))", r"\mathbf{0}^{1 \times 1}"),
("zeros((2, 3))", r"\mathbf{0}^{2 \times 3}"),
("zeros((0, 0, 0))", r"\mathbf{0}^{0 \times 0 \times 0}"),
("zeros((1, 1, 1))", r"\mathbf{0}^{1 \times 1 \times 1}"),
("zeros((2, 3, 5))", r"\mathbf{0}^{2 \times 3 \times 5}"),
# Unsupported
("zeros()", r"\mathrm{zeros} \mathopen{}\left( \mathclose{}\right)"),
("zeros(x)", r"\mathrm{zeros} \mathopen{}\left( x \mathclose{}\right)"),
("zeros(0, x)", r"\mathrm{zeros} \mathopen{}\left( 0, x \mathclose{}\right)"),
(
"zeros((x,))",
r"\mathrm{zeros} \mathopen{}\left("
r" \mathopen{}\left( x \mathclose{}\right)"
r" \mathclose{}\right)",
),
],
)
def test_zeros(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Call)
assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
"code,latex",
[
("identity(0)", r"\mathbf{I}_{0}"),
("identity(1)", r"\mathbf{I}_{1}"),
("identity(2)", r"\mathbf{I}_{2}"),
# Unsupported
("identity()", r"\mathrm{identity} \mathopen{}\left( \mathclose{}\right)"),
("identity(x)", r"\mathrm{identity} \mathopen{}\left( x \mathclose{}\right)"),
(
"identity(0, x)",
r"\mathrm{identity} \mathopen{}\left( 0, x \mathclose{}\right)",
),
],
)
def test_identity(code: str, latex: str) -> None:
tree = ast_utils.parse_expr(code)
assert isinstance(tree, ast.Call)
assert expression_codegen.ExpressionCodegen().visit(tree) == latex