Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Doc/whatsnew/3.10.rst
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@ New Modules
Improved Modules
================

ast
---

Sequential and optional fields for AST nodes are now auto-initialized with the
corresponding empty values. See the :ref:`ASDL <abstract-grammar>` for more
information about the AST node classes and fields they have.
(Contributed by Batuhan Taskaya in :issue:`39981`)

base64
------

Expand Down
1 change: 1 addition & 0 deletions Include/Python-ast.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 11 additions & 2 deletions Lib/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ def dump(node, annotate_fields=True, include_attributes=False, *, indent=None):
integer or string, then the tree will be pretty-printed with that indent
level. None (the default) selects the single line representation.
"""
def _qualifier_to_default(qualifier):
if qualifier == 1:
return []
elif qualifier == 2:
return None
else:
return ...

def _format(node, level=0):
if indent is not None:
level += 1
Expand All @@ -130,13 +138,14 @@ def _format(node, level=0):
args = []
allsimple = True
keywords = annotate_fields
for name in node._fields:
for name, qualifier in zip(node._fields, node._field_qualifiers):
default_value = _qualifier_to_default(qualifier)
try:
value = getattr(node, name)
except AttributeError:
keywords = True
continue
if value is None and getattr(cls, name, ...) is None:
if value is None and default_value is None:
keywords = True
continue
value, simple = _format(value, level)
Expand Down
37 changes: 35 additions & 2 deletions Lib/test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,19 +358,52 @@ def test_arguments(self):
self.assertEqual(x._fields, ('posonlyargs', 'args', 'vararg', 'kwonlyargs',
'kw_defaults', 'kwarg', 'defaults'))

with self.assertRaises(AttributeError):
x.args
self.assertIsNone(x.vararg)
self.assertEqual(x.args, [])

x = ast.arguments(*range(1, 8))
self.assertEqual(x.args, 2)
self.assertEqual(x.vararg, 3)

def test_field_defaults(self):
func = ast.FunctionDef("foo", ast.arguments())
self.assertEqual(func.name, "foo")
self.assertEqual(ast.dump(func.args), ast.dump(ast.arguments()))
self.assertEqual(func.body, [])
self.assertEqual(func.decorator_list, [])
self.assertEqual(func.returns, None)
self.assertEqual(func.type_comment, None)

func2 = ast.FunctionDef()
with self.assertRaises(AttributeError):
func2.name2

self.assertEqual(func.body, [])
self.assertEqual(func.returns, None)

func3 = ast.FunctionDef(body=[1])
self.assertEqual(func3.body, [1])
self.assertFalse(hasattr(func3, "name"))
self.assertTrue(hasattr(func3, "returns"))

def test_field_attr_writable(self):
x = ast.Num()
# We can assign to _fields
x._fields = 666
x._field_qualifiers = 999
self.assertEqual(x._fields, 666)
self.assertEqual(x._field_qualifiers, 999)

functiondef_qualifiers = ast.FunctionDef._field_qualifiers
del ast.FunctionDef._field_qualifiers
fnctdef = ast.FunctionDef("foo")
self.assertEqual(fnctdef.name, "foo")
with self.assertRaises(AttributeError):
fnctdef.body
ast.FunctionDef._field_qualifiers = (5,) * len(functiondef_qualifiers)
with self.assertRaises(ValueError):
ast.FunctionDef() # 5 as a field qualifier is an invalid value
ast.FunctionDef._field_qualifiers = functiondef_qualifiers

def test_classattrs(self):
x = ast.Num()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Introduce default values for AST node class initializations.
156 changes: 124 additions & 32 deletions Parser/asdl_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,10 @@ class TypeDefVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)

self.emit(
"typedef enum _field_qualifier {Q_SEQUENCE=1, Q_OPTIONAL=2} "
"field_qualifier;", 0
)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)

Expand Down Expand Up @@ -665,6 +668,7 @@ def visitProduct(self, prod, name):
for f in prod.fields:
self.emit('"%s",' % f.name, 1)
self.emit("};", 0)
self._emit_field_qualifiers(name, prod.fields, 0)

def visitSum(self, sum, name):
self.emit_type("%s_type" % name)
Expand Down Expand Up @@ -692,6 +696,19 @@ def visitConstructor(self, cons, name):
for t in cons.fields:
self.emit('"%s",' % t.name, 1)
self.emit("};",0)
self._emit_field_qualifiers(cons.name, cons.fields, 0)

def _emit_field_qualifiers(self, name, fields, depth):
self.emit("static const field_qualifier %s_field_qualifiers[]={" % name, depth)
for field in fields:
if field.seq:
qualifier = "Q_SEQUENCE"
elif field.opt:
qualifier = "Q_OPTIONAL"
else:
qualifier = "0"
self.emit("%s, // %s" % (qualifier, field.name), depth+1)
self.emit("};", depth)


class PyTypesVisitor(PickleVisitor):
Expand Down Expand Up @@ -742,7 +759,7 @@ def visitModule(self, mod):

Py_ssize_t i, numfields = 0;
int res = -1;
PyObject *key, *value, *fields;
PyObject *key, *value, *fields, *field_qualifiers = NULL;
if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
goto cleanup;
}
Expand Down Expand Up @@ -802,8 +819,67 @@ def visitModule(self, mod):
}
}
}
if (_PyObject_LookupAttr(self, state->_field_qualifiers, &field_qualifiers) < 0) {
res = -1;
goto cleanup;
}

if (!PyTuple_CheckExact(field_qualifiers) || PyTuple_Size(field_qualifiers) != numfields) {
goto cleanup;
}

PyObject *field, *field_qualifier;
for (i = 0; i < numfields; i++) {
field = PySequence_GetItem(fields, i);
field_qualifier = PySequence_GetItem(field_qualifiers, i);
if (!field_qualifier || !field) {
res = -1;
goto next_iteration;
}

if (PyObject_HasAttr(self, field)) {
goto next_iteration;
}

PyObject *field_default = NULL;
switch (PyLong_AsLong(field_qualifier)) {
case -1:
res = -1;
goto next_iteration;
case 0:
goto next_iteration;
case Q_SEQUENCE:
field_default = PyList_New(0);
if (field_default == NULL) {
res = -1;
goto next_iteration;
}
break;
case Q_OPTIONAL:
field_default = Py_None;
Py_INCREF(field_default);
break;
default:
PyErr_Format(PyExc_ValueError,
"Unknown field qualifier: \\"%R\\"", field_qualifier);
res = -1;
goto next_iteration;
}
assert(field_default != NULL);
res = PyObject_SetAttr(self, field, field_default);
Py_DECREF(field_default);
next_iteration:
Py_XDECREF(field);
Py_XDECREF(field_qualifier);
if (res < 0) {
goto cleanup;
}
continue;
}

cleanup:
Py_XDECREF(fields);
Py_XDECREF(field_qualifiers);
return res;
}

Expand Down Expand Up @@ -866,29 +942,45 @@ def visitModule(self, mod):
};

static PyObject *
make_type(astmodulestate *state, const char *type, PyObject* base,
const char* const* fields, int num_fields, const char *doc)
make_type(
astmodulestate *state,
const char *type,
PyObject* base,
const char* const* fields,
const field_qualifier* field_qualifiers,
Py_ssize_t num_fields,
const char *doc
)
{
PyObject *fnames, *result;
int i;
fnames = PyTuple_New(num_fields);
if (!fnames) return NULL;
Py_ssize_t i;
PyObject *result = NULL;
PyObject *fnames = PyTuple_New(num_fields);
PyObject *fqualifiers = PyTuple_New(num_fields);

if (!fnames || !fqualifiers) {
goto exit;
}

for (i = 0; i < num_fields; i++) {
PyObject *field = PyUnicode_InternFromString(fields[i]);
if (!field) {
Py_DECREF(fnames);
return NULL;
PyObject *qualifier = PyLong_FromLong((long)field_qualifiers[i]);
if (!field || !qualifier) {
goto exit;
}
PyTuple_SET_ITEM(fnames, i, field);
PyTuple_SET_ITEM(fqualifiers, i, qualifier);
}
result = PyObject_CallFunction((PyObject*)&PyType_Type, "s(O){OOOOOs}",
result = PyObject_CallFunction((PyObject*)&PyType_Type, "s(O){OOOOOOOs}",
type, base,
state->_fields, fnames,
state->_field_qualifiers, fqualifiers,
state->__module__,
state->ast,
state->__doc__, doc);
Py_DECREF(fnames);
return result;
exit:
Py_XDECREF(fnames);
Py_XDECREF(fqualifiers);
return result;
}

static int
Expand Down Expand Up @@ -1012,8 +1104,10 @@ def visitModule(self, mod):
{
PyObject *empty_tuple;
empty_tuple = PyTuple_New(0);
Py_XINCREF(empty_tuple); // for _field_qualifiers
if (!empty_tuple ||
PyObject_SetAttrString(state->AST_type, "_fields", empty_tuple) < 0 ||
PyObject_SetAttrString(state->AST_type, "_field_qualifiers", empty_tuple) < 0 ||
PyObject_SetAttrString(state->AST_type, "_attributes", empty_tuple) < 0) {
Py_XDECREF(empty_tuple);
return -1;
Expand All @@ -1040,10 +1134,11 @@ def visitModule(self, mod):
def visitProduct(self, prod, name):
if prod.fields:
fields = name+"_fields"
field_qualifiers = name+"_field_qualifiers"
else:
fields = "NULL"
self.emit('state->%s_type = make_type(state, "%s", state->AST_type, %s, %d,' %
(name, name, fields, len(prod.fields)), 1)
fields = field_defaults = "NULL"
self.emit('state->%s_type = make_type(state, "%s", state->AST_type, %s, %s, %d,' %
(name, name, fields, field_qualifiers, len(prod.fields)), 1)
self.emit('%s);' % reflow_c_string(asdl_of(name, prod), 2), 2, reflow=False)
self.emit("if (!state->%s_type) return 0;" % name, 1)
self.emit_type("AST_type")
Expand All @@ -1053,11 +1148,9 @@ def visitProduct(self, prod, name):
(name, name, len(prod.attributes)), 1)
else:
self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1)
self.emit_defaults(name, prod.fields, 1)
self.emit_defaults(name, prod.attributes, 1)

def visitSum(self, sum, name):
self.emit('state->%s_type = make_type(state, "%s", state->AST_type, NULL, 0,' %
self.emit('state->%s_type = make_type(state, "%s", state->AST_type, NULL, NULL, 0,' %
(name, name), 1)
self.emit('%s);' % reflow_c_string(asdl_of(name, sum), 2), 2, reflow=False)
self.emit_type("%s_type" % name)
Expand All @@ -1067,35 +1160,33 @@ def visitSum(self, sum, name):
(name, name, len(sum.attributes)), 1)
else:
self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1)
self.emit_defaults(name, sum.attributes, 1)
for attribute in sum.attributes:
if attribute.opt:
self.emit('if (PyObject_SetAttr(state->%s_type, state->%s, Py_None) == -1) {' %
(name, attribute.name), 1)
self.emit("return 0;", 2)
self.emit("}", 1)
simple = is_simple(sum)
for t in sum.types:
self.visitConstructor(t, name, simple)

def visitConstructor(self, cons, name, simple):
if cons.fields:
fields = cons.name+"_fields"
field_qualifiers = cons.name+"_field_qualifiers"
else:
fields = "NULL"
self.emit('state->%s_type = make_type(state, "%s", state->%s_type, %s, %d,' %
(cons.name, cons.name, name, fields, len(cons.fields)), 1)
fields = field_qualifiers = "NULL"
self.emit('state->%s_type = make_type(state, "%s", state->%s_type, %s, %s, %d,' %
(cons.name, cons.name, name, fields, field_qualifiers, len(cons.fields)), 1)
self.emit('%s);' % reflow_c_string(asdl_of(cons.name, cons), 2), 2, reflow=False)
self.emit("if (!state->%s_type) return 0;" % cons.name, 1)
self.emit_type("%s_type" % cons.name)
self.emit_defaults(cons.name, cons.fields, 1)
if simple:
self.emit("state->%s_singleton = PyType_GenericNew((PyTypeObject *)"
"state->%s_type, NULL, NULL);" %
(cons.name, cons.name), 1)
self.emit("if (!state->%s_singleton) return 0;" % cons.name, 1)

def emit_defaults(self, name, fields, depth):
for field in fields:
if field.opt:
self.emit('if (PyObject_SetAttr(state->%s_type, state->%s, Py_None) == -1)' %
(name, field.name), depth)
self.emit("return 0;", depth+1)


class ASTModuleVisitor(PickleVisitor):

Expand Down Expand Up @@ -1397,6 +1488,7 @@ def generate_module_def(f, mod):
state_strings = {
"ast",
"_fields",
"_field_qualifiers",
"__doc__",
"__dict__",
"__module__",
Expand Down
Loading