Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions fiddle/_src/codegen/auto_config/code_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ class AttributeExpression(CodegenNode):
base: Any # Wrapped expression, can involve VariableReference's
attribute: str

def __hash__(self):
# Currently, some Pax (https://github.com/google/paxml) codegen involves
# having AttributeExpression's as dict keys, as those keys are rewritten to
# expressions. This function allows for that, but one shouldn't generally
# assume equality as object identity here.
return id(self)


@dataclasses.dataclass
class ArgFactoryExpr(CodegenNode):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,9 @@ def test_sub_fixtures_with_shared_nodes(self, api: str):
# the MoveComplexNodesToVariables pass is run.
num_lines = len(code.splitlines())
if complexity is None:
self.assertLessEqual(num_lines, 25)
self.assertLessEqual(num_lines, 22)
else:
self.assertGreater(num_lines, 25)
self.assertGreater(num_lines, 22)

matches = re.findall(r"def\ (?P<name>[\w_]+)\(", code)
self.assertEqual(
Expand Down
18 changes: 17 additions & 1 deletion fiddle/_src/codegen/auto_config/ir_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def traverse(value, state: daglish.State) -> str:
+ ", ".join(f'"{key}": {value}' for key, value in value.items())
+ "}"
)
elif isinstance(value, code_ir.VariableReference):
elif isinstance(value, code_ir.BaseNameReference):
return value.name.value
elif isinstance(value, code_ir.AttributeExpression):
base_obj = state.call(value.base, daglish.Attr("base"))
Expand All @@ -90,6 +90,22 @@ def traverse(value, state: daglish.State) -> str:
elif isinstance(value, code_ir.WithTagsCall):
sub_value = state.map_children(value).expression
return f"WithTagsCall[{sub_value}]"
elif isinstance(value, code_ir.SymbolOrFixtureCall):
symbol_expression = state.call(
value.symbol_expression, daglish.Attr("symbol_expression")
)
positional_arg_expressions = state.call(
value.positional_arg_expressions,
daglish.Attr("positional_arg_expressions"),
)
arg_expressions = state.call(
value.arg_expressions, daglish.Attr("arg_expressions")
)
return (
f"call:<{symbol_expression}"
f"(*[{positional_arg_expressions}],"
f" **{arg_expressions})>"
)
elif isinstance(value, code_ir.Name):
return value.value
elif isinstance(value, type):
Expand Down
14 changes: 14 additions & 0 deletions fiddle/_src/codegen/auto_config/ir_printer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ def test_format_attributes(self):
attr = code_ir.AttributeExpression(self_var, "foo")
self.assertEqual(ir_printer.format_expr(attr), "self.foo")

def test_format_calls(self):
call = code_ir.SymbolOrFixtureCall(
symbol_expression=code_ir.Name("foo"),
positional_arg_expressions=[code_ir.Name("bar")],
arg_expressions={"baz": code_ir.Name("qux")},
)
self.assertEqual(
ir_printer.format_expr(call), 'call:<foo(*[[bar]], **{"baz": qux})>'
)

def test_format_module_reference(self):
module_reference = code_ir.ModuleReference(code_ir.Name("foo"))
self.assertEqual(ir_printer.format_expr(module_reference), "foo")

def test_format_simple_ir(self):
task = test_fixtures.simple_ir()
code = "\n".join(ir_printer.format_fn(task.top_level_call.fn))
Expand Down
3 changes: 2 additions & 1 deletion fiddle/_src/codegen/auto_config/ir_to_cst.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def _prepare_args_helper(
except:
print(f"\n\nERROR CONVERTING: {value!r}")
print(f"\n\nTYPE: {type(value)}")
print(f"\n\nPATH: {daglish.path_str(state.current_path)}")
raise

return daglish.MemoizedTraversal.run(traverse, expr)
Expand Down Expand Up @@ -198,7 +199,7 @@ def code_for_fn(
),
]
)
if fn.parameters:
if fn.parameters and len(fn.parameters) > 1:
whitespace_before_params = cst.ParenthesizedWhitespace(
cst.TrailingWhitespace(),
indent=True,
Expand Down
35 changes: 23 additions & 12 deletions fiddle/_src/codegen/codegen_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import functools
import re
import types
from typing import Any, Callable, Dict, List, Set, Tuple
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

from fiddle import daglish
from fiddle import diffing
Expand All @@ -31,10 +31,13 @@
import libcst as cst


def fiddler_from_diff(diff: diffing.Diff,
old: Any = None,
func_name: str = 'fiddler',
param_name: str = 'cfg'):
def fiddler_from_diff(
diff: diffing.Diff,
old: Any = None,
func_name: str = 'fiddler',
param_name: str = 'cfg',
import_manager: Optional[import_manager_lib.ImportManager] = None,
):
"""Returns the CST for a fiddler function that applies the changes in `diff`.

The returned `cst.Module` consists of a set of `import` statements for any
Expand Down Expand Up @@ -66,18 +69,26 @@ def fiddler_from_diff(diff: diffing.Diff,
all referenced paths.
func_name: The name for the fiddler function.
param_name: The name for the parameter to the fiddler function.
import_manager: Existing import manager. Usually set to None, but if you are
integrating this with other code generation tasks, it can be nice to
share.

Returns:
An `cst.Module` object. You can convert this to a string using
`result.code`.
"""
# Create a namespace to keep track of variables that we add. Reserve the
# names of the param & func.
namespace = namespace_lib.Namespace()
namespace.add(param_name)
namespace.add(func_name)

import_manager = import_manager_lib.ImportManager(namespace)
if import_manager is None:
# Create a namespace to keep track of variables that we add. Reserve the
# names of the param & func.
namespace = namespace_lib.Namespace()
namespace.add(param_name)
namespace.add(func_name)

import_manager = import_manager_lib.ImportManager(namespace)
else:
namespace = import_manager.namespace
namespace.add(param_name)
namespace.add(func_name)

# Get a list of paths that are referenced by the diff.
used_paths = _find_used_paths(diff)
Expand Down