Skip to content
41 changes: 41 additions & 0 deletions src/latexify/analyzers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import ast
import dataclasses
import sys

from latexify import ast_utils, exceptions

Expand Down Expand Up @@ -62,3 +63,43 @@ def analyze_range(node: ast.Call) -> RangeInfo:
stop_int=ast_utils.extract_int_or_none(stop),
step_int=ast_utils.extract_int_or_none(step),
)


def reduce_stop_parameter(node: ast.expr) -> ast.expr:
"""Adjusts the stop expression of the range.

This function tries to convert the syntax as follows:
* n + 1 --> n
* n + 2 --> n + 1
* n - 1 --> n - 2

Args:
node: The target expression.

Returns:
Converted expression.
"""
if not (isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Sub))):
return ast.BinOp(left=node, op=ast.Sub(), right=ast_utils.make_constant(1))

# Treatment for Python 3.7.
rhs = (
ast.Constant(value=node.right.n)
if sys.version_info.minor < 8 and isinstance(node.right, ast.Num)
else node.right
)

if not isinstance(rhs, ast.Constant):
return ast.BinOp(left=node, op=ast.Sub(), right=ast_utils.make_constant(1))

shift = 1 if isinstance(node.op, ast.Add) else -1

return (
node.left
if rhs.value == shift
else ast.BinOp(
left=node.left,
op=node.op,
right=ast_utils.make_constant(value=rhs.value - shift),
)
)
17 changes: 17 additions & 0 deletions src/latexify/analyzers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,20 @@ def test_analyze_range_invalid(code: str) -> None:
exceptions.LatexifySyntaxError, match=r"^Unsupported AST for analyze_range\.$"
):
analyzers.analyze_range(node)


@pytest.mark.parametrize(
"before,after",
[
("n + 1", "n"),
("n + 2", "n + 1"),
("n - (-1)", "n - (-1) - 1"),
("n - 1", "n - 2"),
("1 * 2", "1 * 2 - 1"),
],
)
def test_reduce_stop_parameter(before: str, after: str) -> None:
test_utils.assert_ast_equal(
analyzers.reduce_stop_parameter(ast_utils.parse_expr(before)),
ast_utils.parse_expr(after),
)
3 changes: 2 additions & 1 deletion src/latexify/codegen/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Package latexify.codegen."""

from latexify.codegen import function_codegen
from latexify.codegen import expression_codegen, function_codegen

ExpressionCodegen = expression_codegen.ExpressionCodegen
FunctionCodegen = function_codegen.FunctionCodegen
28 changes: 28 additions & 0 deletions src/latexify/codegen/codegen_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any

from latexify import exceptions


def convert_constant(value: Any) -> str:
"""Helper to convert constant values to LaTeX.
Args:
value: A constant value.
Returns:
The LaTeX representation of `value`.
"""
if value is None or isinstance(value, bool):
return r"\mathrm{" + str(value) + "}"
if isinstance(value, (int, float, complex)):
# TODO(odashi): Support other symbols for the imaginary unit than j.
return str(value)
if isinstance(value, str):
return r'\textrm{"' + value + '"}'
if isinstance(value, bytes):
return r"\textrm{" + str(value) + "}"
if value is ...:
return r"\cdots"
raise exceptions.LatexifyNotSupportedError(
f"Unrecognized constant: {type(value).__name__}"
)
34 changes: 34 additions & 0 deletions src/latexify/codegen/codegen_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Tests for latexify.codegen.codegen_utils."""

from __future__ import annotations

from typing import Any

import pytest

from latexify import exceptions
from latexify.codegen.codegen_utils import convert_constant


@pytest.mark.parametrize(
"constant,latex",
[
(None, r"\mathrm{None}"),
(True, r"\mathrm{True}"),
(False, r"\mathrm{False}"),
(123, "123"),
(456.789, "456.789"),
(-3 + 4j, "(-3+4j)"),
("string", r'\textrm{"string"}'),
(..., r"\cdots"),
],
)
def test_convert_constant(constant: Any, latex: str) -> None:
assert convert_constant(constant) == latex


def test_convert_constant_unsupported_constant() -> None:
with pytest.raises(
exceptions.LatexifyNotSupportedError, match="^Unrecognized constant: "
):
convert_constant({})
Loading