Skip to content

Implement non-exhaustive enums #4191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 17, 2020
41 changes: 41 additions & 0 deletions doc/langref.html.in
Original file line number Diff line number Diff line change
Expand Up @@ -2893,6 +2893,47 @@ test "switch using enum literals" {
}
{#code_end#}
{#header_close#}

{#header_open|Non-exhaustive enum#}
<p>
A non-exhaustive enum can be created by adding a trailing '_' field.
It must specify a tag type, and its named fields cannot use up every value of that tag type.
</p>
<p>
{#link|@intToEnum#} on a non-exhaustive enum cannot fail.
</p>
<p>
A switch on a non-exhaustive enum can include a '_' prong as an alternative to an {#syntax#}else{#endsyntax#} prong.
Unlike {#syntax#}else{#endsyntax#}, a '_' prong makes it a compile error if any of the known tag names is not handled by the switch.
</p>
{#code_begin|test#}
const std = @import("std");
const assert = std.debug.assert;

/// Example non-exhaustive enum with an explicit `u8` tag type and three
/// named tags.
const Number = enum(u8) {
    One,
    Two,
    Three,
    // The trailing `_` field is what makes this enum non-exhaustive.
    _,
};

test "switch on non-exhaustive enum" {
    const number = Number.One;

    // With a `_` prong, every named tag (`One`, `Two`, `Three`) is listed
    // explicitly; `_` takes the place of an `else` prong.
    const matched_one = switch (number) {
        .One => true,
        .Two, .Three => false,
        _ => false,
    };
    assert(matched_one);

    // With an `else` prong, the named tags do not all have to be listed.
    const is_one = switch (number) {
        .One => true,
        else => false,
    };
    assert(is_one);
}
{#code_end#}
{#header_close#}
{#header_close#}

{#header_open|union#}
Expand Down
1 change: 1 addition & 0 deletions lib/std/builtin.zig
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ pub const TypeInfo = union(enum) {
tag_type: type,
fields: []EnumField,
decls: []Declaration,
is_exhaustive: bool,
};

/// This data structure is used by the Zig language code generation and
Expand Down
93 changes: 42 additions & 51 deletions src-self-hosted/translate_c.zig
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,7 @@ pub fn translate(
tree.errors = ast.Tree.ErrorList.init(arena);

tree.root_node = try arena.create(ast.Node.Root);
tree.root_node.* = ast.Node.Root{
.base = ast.Node{ .id = ast.Node.Id.Root },
tree.root_node.* = .{
.decls = ast.Node.Root.DeclList.init(arena),
// initialized with the eof token at the end
.eof_token = undefined,
Expand Down Expand Up @@ -440,7 +439,6 @@ fn visitFnDecl(c: *Context, fn_decl: *const ZigClangFunctionDecl) Error!void {
.PrivateExtern => return failDecl(c, fn_decl_loc, fn_name, "unsupported storage class: private extern", .{}),
.Auto => unreachable, // Not legal on functions
.Register => unreachable, // Not legal on functions
else => unreachable,
},
};

Expand Down Expand Up @@ -877,25 +875,23 @@ fn transEnumDecl(c: *Context, enum_decl: *const ZigClangEnumDecl) Error!?*ast.No
// types, while that's not ISO-C compliant many compilers allow this and
// default to the usual integer type used for all the enums.

// TODO only emit this tag type if the enum tag type is not the default.
// I don't know what the default is, need to figure out how clang is deciding.
// it appears to at least be different across gcc/msvc
if (int_type.ptr != null and
!isCBuiltinType(int_type, .UInt) and
!isCBuiltinType(int_type, .Int))
{
_ = try appendToken(c, .LParen, "(");
container_node.init_arg_expr = .{
.Type = transQualType(rp, int_type, enum_loc) catch |err| switch (err) {
// default to c_int since msvc and gcc default to different types
_ = try appendToken(c, .LParen, "(");
container_node.init_arg_expr = .{
.Type = if (int_type.ptr != null and
!isCBuiltinType(int_type, .UInt) and
!isCBuiltinType(int_type, .Int))
transQualType(rp, int_type, enum_loc) catch |err| switch (err) {
error.UnsupportedType => {
try failDecl(c, enum_loc, name, "unable to translate enum tag type", .{});
return null;
},
else => |e| return e,
},
};
_ = try appendToken(c, .RParen, ")");
}
}
else
try transCreateNodeIdentifier(c, "c_int"),
};
_ = try appendToken(c, .RParen, ")");

container_node.lbrace_token = try appendToken(c, .LBrace, "{");

Expand Down Expand Up @@ -953,6 +949,19 @@ fn transEnumDecl(c: *Context, enum_decl: *const ZigClangEnumDecl) Error!?*ast.No
tld_node.semicolon_token = try appendToken(c, .Semicolon, ";");
try addTopLevelDecl(c, field_name, &tld_node.base);
}
// make non exhaustive
const field_node = try c.a().create(ast.Node.ContainerField);
field_node.* = .{
.doc_comments = null,
.comptime_token = null,
.name_token = try appendIdentifier(c, "_"),
.type_expr = null,
.value_expr = null,
.align_expr = null,
};

try container_node.fields_and_decls.push(&field_node.base);
_ = try appendToken(c, .Comma, ",");
container_node.rbrace_token = try appendToken(c, .RBrace, "}");

break :blk &container_node.base;
Expand Down Expand Up @@ -1231,18 +1240,6 @@ fn transBinaryOperator(
op_id = .BitOr;
op_token = try appendToken(rp.c, .Pipe, "|");
},
.Assign,
.MulAssign,
.DivAssign,
.RemAssign,
.AddAssign,
.SubAssign,
.ShlAssign,
.ShrAssign,
.AndAssign,
.XorAssign,
.OrAssign,
=> unreachable,
else => unreachable,
}

Expand Down Expand Up @@ -1678,7 +1675,6 @@ fn transStringLiteral(
"TODO: support string literal kind {}",
.{kind},
),
else => unreachable,
}
}

Expand Down Expand Up @@ -2199,6 +2195,19 @@ fn transDoWhileLoop(
.id = .Loop,
};

// if (!cond) break;
const if_node = try transCreateNodeIf(rp.c);
var cond_scope = Scope{
.parent = scope,
.id = .Condition,
};
const prefix_op = try transCreateNodePrefixOp(rp.c, .BoolNot, .Bang, "!");
prefix_op.rhs = try transBoolExpr(rp, &cond_scope, @ptrCast(*const ZigClangExpr, ZigClangDoStmt_getCond(stmt)), .used, .r_value, true);
_ = try appendToken(rp.c, .RParen, ")");
if_node.condition = &prefix_op.base;
if_node.body = &(try transCreateNodeBreak(rp.c, null)).base;
_ = try appendToken(rp.c, .Semicolon, ";");

const body_node = if (ZigClangStmt_getStmtClass(ZigClangDoStmt_getBody(stmt)) == .CompoundStmtClass) blk: {
// there's already a block in C, so we'll append our condition to it.
// c: do {
Expand All @@ -2210,10 +2219,7 @@ fn transDoWhileLoop(
// zig: b;
// zig: if (!cond) break;
// zig: }
const body = (try transStmt(rp, &loop_scope, ZigClangDoStmt_getBody(stmt), .unused, .r_value)).cast(ast.Node.Block).?;
// if this is used as an expression in Zig it needs to be immediately followed by a semicolon
_ = try appendToken(rp.c, .Semicolon, ";");
break :blk body;
break :blk (try transStmt(rp, &loop_scope, ZigClangDoStmt_getBody(stmt), .unused, .r_value)).cast(ast.Node.Block).?;
} else blk: {
// the C statement is without a block, so we need to create a block to contain it.
// c: do
Expand All @@ -2229,19 +2235,6 @@ fn transDoWhileLoop(
break :blk block;
};

// if (!cond) break;
const if_node = try transCreateNodeIf(rp.c);
var cond_scope = Scope{
.parent = scope,
.id = .Condition,
};
const prefix_op = try transCreateNodePrefixOp(rp.c, .BoolNot, .Bang, "!");
prefix_op.rhs = try transBoolExpr(rp, &cond_scope, @ptrCast(*const ZigClangExpr, ZigClangDoStmt_getCond(stmt)), .used, .r_value, true);
_ = try appendToken(rp.c, .RParen, ")");
if_node.condition = &prefix_op.base;
if_node.body = &(try transCreateNodeBreak(rp.c, null)).base;
_ = try appendToken(rp.c, .Semicolon, ";");

try body_node.statements.push(&if_node.base);
if (new)
body_node.rbrace = try appendToken(rp.c, .RBrace, "}");
Expand Down Expand Up @@ -4776,8 +4769,7 @@ fn appendIdentifier(c: *Context, name: []const u8) !ast.TokenIndex {
fn transCreateNodeIdentifier(c: *Context, name: []const u8) !*ast.Node {
const token_index = try appendIdentifier(c, name);
const identifier = try c.a().create(ast.Node.Identifier);
identifier.* = ast.Node.Identifier{
.base = ast.Node{ .id = ast.Node.Id.Identifier },
identifier.* = .{
.token = token_index,
};
return &identifier.base;
Expand Down Expand Up @@ -4916,8 +4908,7 @@ fn transMacroFnDefine(c: *Context, it: *ctok.TokenList.Iterator, name: []const u

const token_index = try appendToken(c, .Keyword_var, "var");
const identifier = try c.a().create(ast.Node.Identifier);
identifier.* = ast.Node.Identifier{
.base = ast.Node{ .id = ast.Node.Id.Identifier },
identifier.* = .{
.token = token_index,
};

Expand Down
2 changes: 2 additions & 0 deletions src/all_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,7 @@ struct ZigTypeEnum {
ContainerLayout layout;
ResolveStatus resolve_status;

bool non_exhaustive;
bool resolve_loop_flag;
};

Expand Down Expand Up @@ -3665,6 +3666,7 @@ struct IrInstructionCheckSwitchProngs {
IrInstructionCheckSwitchProngsRange *ranges;
size_t range_count;
bool have_else_prong;
bool have_underscore_prong;
};

struct IrInstructionCheckStatementIsVoid {
Expand Down
38 changes: 31 additions & 7 deletions src/analyze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2569,15 +2569,8 @@ static Error resolve_enum_zero_bits(CodeGen *g, ZigType *enum_type) {
return ErrorSemanticAnalyzeFail;
}

enum_type->data.enumeration.src_field_count = field_count;
enum_type->data.enumeration.fields = allocate<TypeEnumField>(field_count);
enum_type->data.enumeration.fields_by_name.init(field_count);

Scope *scope = &enum_type->data.enumeration.decls_scope->base;

HashMap<BigInt, AstNode *, bigint_hash, bigint_eql> occupied_tag_values = {};
occupied_tag_values.init(field_count);

ZigType *tag_int_type;
if (enum_type->data.enumeration.layout == ContainerLayoutExtern) {
tag_int_type = get_c_int_type(g, CIntTypeInt);
Expand Down Expand Up @@ -2619,6 +2612,7 @@ static Error resolve_enum_zero_bits(CodeGen *g, ZigType *enum_type) {
}
}

enum_type->data.enumeration.non_exhaustive = false;
enum_type->data.enumeration.tag_int_type = tag_int_type;
enum_type->size_in_bits = tag_int_type->size_in_bits;
enum_type->abi_size = tag_int_type->abi_size;
Expand All @@ -2627,6 +2621,31 @@ static Error resolve_enum_zero_bits(CodeGen *g, ZigType *enum_type) {
BigInt bi_one;
bigint_init_unsigned(&bi_one, 1);

AstNode *last_field_node = decl_node->data.container_decl.fields.at(field_count - 1);
if (buf_eql_str(last_field_node->data.struct_field.name, "_")) {
field_count -= 1;
if (field_count > 1 && log2_u64(field_count) == enum_type->size_in_bits) {
add_node_error(g, last_field_node, buf_sprintf("non-exhaustive enum specifies every value"));
enum_type->data.enumeration.resolve_status = ResolveStatusInvalid;
}
if (decl_node->data.container_decl.init_arg_expr == nullptr) {
add_node_error(g, last_field_node, buf_sprintf("non-exhaustive enum must specify size"));
enum_type->data.enumeration.resolve_status = ResolveStatusInvalid;
}
if (last_field_node->data.struct_field.value != nullptr) {
add_node_error(g, last_field_node, buf_sprintf("value assigned to '_' field of non-exhaustive enum"));
enum_type->data.enumeration.resolve_status = ResolveStatusInvalid;
}
enum_type->data.enumeration.non_exhaustive = true;
}

enum_type->data.enumeration.src_field_count = field_count;
enum_type->data.enumeration.fields = allocate<TypeEnumField>(field_count);
enum_type->data.enumeration.fields_by_name.init(field_count);

HashMap<BigInt, AstNode *, bigint_hash, bigint_eql> occupied_tag_values = {};
occupied_tag_values.init(field_count);

TypeEnumField *last_enum_field = nullptr;

for (uint32_t field_i = 0; field_i < field_count; field_i += 1) {
Expand All @@ -2648,6 +2667,11 @@ static Error resolve_enum_zero_bits(CodeGen *g, ZigType *enum_type) {
buf_sprintf("consider 'union(enum)' here"));
}

if (buf_eql_str(type_enum_field->name, "_")) {
add_node_error(g, field_node, buf_sprintf("'_' field of non-exhaustive enum must be last"));
enum_type->data.enumeration.resolve_status = ResolveStatusInvalid;
}

auto field_entry = enum_type->data.enumeration.fields_by_name.put_unique(type_enum_field->name, type_enum_field);
if (field_entry != nullptr) {
ErrorMsg *msg = add_node_error(g, field_node,
Expand Down
7 changes: 6 additions & 1 deletion src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3356,7 +3356,7 @@ static LLVMValueRef ir_render_int_to_enum(CodeGen *g, IrExecutable *executable,
LLVMValueRef tag_int_value = gen_widen_or_shorten(g, ir_want_runtime_safety(g, &instruction->base),
instruction->target->value->type, tag_int_type, target_val);

if (ir_want_runtime_safety(g, &instruction->base) && wanted_type->data.enumeration.layout != ContainerLayoutExtern) {
if (ir_want_runtime_safety(g, &instruction->base) && !wanted_type->data.enumeration.non_exhaustive) {
LLVMBasicBlockRef bad_value_block = LLVMAppendBasicBlock(g->cur_fn_val, "BadValue");
LLVMBasicBlockRef ok_value_block = LLVMAppendBasicBlock(g->cur_fn_val, "OkValue");
size_t field_count = wanted_type->data.enumeration.src_field_count;
Expand Down Expand Up @@ -5065,6 +5065,11 @@ static LLVMValueRef ir_render_enum_tag_name(CodeGen *g, IrExecutable *executable
{
ZigType *enum_type = instruction->target->value->type;
assert(enum_type->id == ZigTypeIdEnum);
if (enum_type->data.enumeration.non_exhaustive) {
add_node_error(g, instruction->base.source_node,
buf_sprintf("TODO @tagName on non-exhaustive enum https://github.com/ziglang/zig/issues/3991"));
codegen_report_errors_and_exit(g);
}

LLVMValueRef enum_name_function = get_enum_tag_name_function(g, enum_type);

Expand Down
Loading