Skip to content

Commit 6b1b597

Browse files
committed
An implementation of flexible types for explicit nulls
1 parent 7171211 commit 6b1b597

File tree

67 files changed

+525
-141
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+525
-141
lines changed

compiler/src/dotty/tools/dotc/config/ScalaSettings.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ object ScalaSettings extends ScalaSettings
3131

3232
// Kept as seperate type to avoid breaking backward compatibility
3333
abstract class ScalaSettings extends SettingGroup, AllScalaSettings:
34-
val settingsByCategory: Map[SettingCategory, List[Setting[_]]] =
34+
val settingsByCategory: Map[SettingCategory, List[Setting[_]]] =
3535
allSettings.groupBy(_.category)
3636
.view.mapValues(_.toList).toMap
3737
.withDefaultValue(Nil)
@@ -43,7 +43,7 @@ abstract class ScalaSettings extends SettingGroup, AllScalaSettings:
4343
val verboseSettings: List[Setting[_]] = settingsByCategory(VerboseSetting).sortBy(_.name)
4444
val settingsByAliases: Map[String, Setting[_]] = allSettings.flatMap(s => s.aliases.map(_ -> s)).toMap
4545

46-
46+
4747
trait AllScalaSettings extends CommonScalaSettings, PluginSettings, VerboseSettings, WarningSettings, XSettings, YSettings:
4848
self: SettingGroup =>
4949

@@ -416,6 +416,7 @@ private sealed trait YSettings:
416416
// Experimental language features
417417
val YnoKindPolymorphism: Setting[Boolean] = BooleanSetting(ForkSetting, "Yno-kind-polymorphism", "Disable kind polymorphism.")
418418
val YexplicitNulls: Setting[Boolean] = BooleanSetting(ForkSetting, "Yexplicit-nulls", "Make reference types non-nullable. Nullable types can be expressed with unions: e.g. String|Null.")
419+
val YnoFlexibleTypes: Setting[Boolean] = BooleanSetting("-Yno-flexible-types", "Disable turning nullable Java return types and parameter types into flexible types, which behaves like abstract types with a nullable lower bound and non-nullable upper bound.")
419420
val YcheckInit: Setting[Boolean] = BooleanSetting(ForkSetting, "Ysafe-init", "Ensure safe initialization of objects.")
420421
val YcheckInitGlobal: Setting[Boolean] = BooleanSetting(ForkSetting, "Ysafe-init-global", "Check safe initialization of global objects.")
421422
val YrequireTargetName: Setting[Boolean] = BooleanSetting(ForkSetting, "Yrequire-targetName", "Warn if an operator is defined without a @targetName annotation.")

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,9 @@ object Contexts {
472472
/** Is the explicit nulls option set? */
473473
def explicitNulls: Boolean = base.settings.YexplicitNulls.value
474474

475+
/** Is the flexible types option set? */
476+
def flexibleTypes: Boolean = base.settings.YexplicitNulls.value && !base.settings.YnoFlexibleTypes.value
477+
475478
/** A fresh clone of this context embedded in this context. */
476479
def fresh: FreshContext = freshOver(this)
477480

compiler/src/dotty/tools/dotc/core/JavaNullInterop.scala

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,11 @@ object JavaNullInterop {
7878
* but the result type is not nullable.
7979
*/
8080
private def nullifyExceptReturnType(tp: Type)(using Context): Type =
81-
new JavaNullMap(true)(tp)
81+
new JavaNullMap(outermostLevelAlreadyNullable = true)(tp)
8282

8383
/** Nullifies a Java type by adding `| Null` in the relevant places. */
8484
private def nullifyType(tp: Type)(using Context): Type =
85-
new JavaNullMap(false)(tp)
85+
new JavaNullMap(outermostLevelAlreadyNullable = false)(tp)
8686

8787
/** A type map that implements the nullification function on types. Given a Java-sourced type, this adds `| Null`
8888
* in the right places to make the nulls explicit in Scala.
@@ -96,25 +96,29 @@ object JavaNullInterop {
9696
* to `(A & B) | Null`, instead of `(A | Null & B | Null) | Null`.
9797
*/
9898
private class JavaNullMap(var outermostLevelAlreadyNullable: Boolean)(using Context) extends TypeMap {
99+
def nullify(tp: Type): Type = if ctx.flexibleTypes then FlexibleType(tp) else OrNull(tp)
100+
99101
/** Should we nullify `tp` at the outermost level? */
100102
def needsNull(tp: Type): Boolean =
101-
!outermostLevelAlreadyNullable && (tp match {
103+
!(outermostLevelAlreadyNullable || (tp match {
102104
case tp: TypeRef =>
103105
// We don't modify value types because they're non-nullable even in Java.
104-
!tp.symbol.isValueClass &&
106+
tp.symbol.isValueClass
107+
// We don't modify unit types.
108+
|| tp.isRef(defn.UnitClass)
105109
// We don't modify `Any` because it's already nullable.
106-
!tp.isRef(defn.AnyClass) &&
110+
|| tp.isRef(defn.AnyClass)
107111
// We don't nullify Java varargs at the top level.
108112
// Example: if `setNames` is a Java method with signature `void setNames(String... names)`,
109113
// then its Scala signature will be `def setNames(names: (String|Null)*): Unit`.
110114
// This is because `setNames(null)` passes as argument a single-element array containing the value `null`,
111115
// and not a `null` array.
112-
!tp.isRef(defn.RepeatedParamClass)
113-
case _ => true
114-
})
116+
|| !ctx.flexibleTypes && tp.isRef(defn.RepeatedParamClass)
117+
case _ => false
118+
}))
115119

116120
override def apply(tp: Type): Type = tp match {
117-
case tp: TypeRef if needsNull(tp) => OrNull(tp)
121+
case tp: TypeRef if needsNull(tp) => nullify(tp)
118122
case appTp @ AppliedType(tycon, targs) =>
119123
val oldOutermostNullable = outermostLevelAlreadyNullable
120124
// We don't make the outmost levels of type arguments nullable if tycon is Java-defined.
@@ -124,7 +128,7 @@ object JavaNullInterop {
124128
val targs2 = targs map this
125129
outermostLevelAlreadyNullable = oldOutermostNullable
126130
val appTp2 = derivedAppliedType(appTp, tycon, targs2)
127-
if needsNull(tycon) then OrNull(appTp2) else appTp2
131+
if needsNull(tycon) then nullify(appTp2) else appTp2
128132
case ptp: PolyType =>
129133
derivedLambdaType(ptp)(ptp.paramInfos, this(ptp.resType))
130134
case mtp: MethodType =>
@@ -138,12 +142,12 @@ object JavaNullInterop {
138142
// nullify(A & B) = (nullify(A) & nullify(B)) | Null, but take care not to add
139143
// duplicate `Null`s at the outermost level inside `A` and `B`.
140144
outermostLevelAlreadyNullable = true
141-
OrNull(derivedAndType(tp, this(tp.tp1), this(tp.tp2)))
142-
case tp: TypeParamRef if needsNull(tp) => OrNull(tp)
145+
nullify(derivedAndType(tp, this(tp.tp1), this(tp.tp2)))
146+
case tp: TypeParamRef if needsNull(tp) => nullify(tp)
143147
// In all other cases, return the type unchanged.
144148
// In particular, if the type is a ConstantType, then we don't nullify it because it is the
145149
// type of a final non-nullable field.
146150
case _ => tp
147151
}
148152
}
149-
}
153+
}

compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ import Types.*
88
object NullOpsDecorator:
99

1010
extension (self: Type)
11+
def stripFlexible(using Context): Type = self match
12+
case FlexibleType(tp) => tp
13+
case _ => self
14+
1115
/** Syntactically strips the nullability from this type.
1216
* If the type is `T1 | ... | Tn`, and `Ti` references to `Null`,
1317
* then return `T1 | ... | Ti-1 | Ti+1 | ... | Tn`.
@@ -33,6 +37,7 @@ object NullOpsDecorator:
3337
if (tp1s ne tp1) && (tp2s ne tp2) then
3438
tp.derivedAndType(tp1s, tp2s)
3539
else tp
40+
case tp: FlexibleType => tp.hi
3641
case tp @ TypeBounds(lo, hi) =>
3742
tp.derivedTypeBounds(strip(lo), strip(hi))
3843
case tp => tp

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,9 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
564564
case CapturingType(parent, refs) =>
565565
val parent1 = recur(parent)
566566
if parent1 ne parent then tp.derivedCapturingType(parent1, refs) else tp
567+
case tp: FlexibleType =>
568+
val underlying = recur(tp.underlying)
569+
if underlying ne tp.underlying then tp.derivedFlexibleType(underlying) else tp
567570
case tp: AnnotatedType =>
568571
val parent1 = recur(tp.parent)
569572
if parent1 ne tp.parent then tp.derivedAnnotatedType(parent1, tp.annot) else tp

compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
163163
}
164164
}
165165

166-
def dealiasDropNonmoduleRefs(tp: Type) = tp.dealias match {
166+
def dealiasDropNonmoduleRefs(tp: Type): Type = tp.dealias match {
167167
case tp: TermRef =>
168168
// we drop TermRefs that don't have a class symbol, as they can't
169169
// meaningfully participate in GADT reasoning and just get in the way.
@@ -172,6 +172,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
172172
// additional trait - argument-less enum cases desugar to vals.
173173
// See run/enum-Tree.scala.
174174
if tp.classSymbol.exists then tp else tp.info
175+
case FlexibleType(tp) => dealiasDropNonmoduleRefs(tp)
175176
case tp => tp
176177
}
177178

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import reporting.trace
2424
import annotation.constructorOnly
2525
import cc.*
2626
import NameKinds.WildcardParamName
27+
import NullOpsDecorator.stripFlexible
2728

2829
/** Provides methods to compare types.
2930
*/
@@ -864,6 +865,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
864865
false
865866
}
866867
compareClassInfo
868+
case tp2: FlexibleType =>
869+
recur(tp1, tp2.lo)
867870
case _ =>
868871
fourthTry
869872
}
@@ -1059,6 +1062,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
10591062
case tp1: ExprType if ctx.phaseId > gettersPhase.id =>
10601063
// getters might have converted T to => T, need to compensate.
10611064
recur(tp1.widenExpr, tp2)
1065+
case tp1: FlexibleType =>
1066+
recur(tp1.hi, tp2)
10621067
case _ =>
10631068
false
10641069
}
@@ -2516,15 +2521,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
25162521
case _ =>
25172522
tp
25182523

2519-
private def andTypeGen(tp1: Type, tp2: Type, op: (Type, Type) => Type,
2520-
original: (Type, Type) => Type = _ & _, isErased: Boolean = ctx.erasedTypes): Type = trace(s"andTypeGen(${tp1.show}, ${tp2.show})", subtyping, show = true) {
2524+
private def andTypeGen(tp1orig: Type, tp2orig: Type, op: (Type, Type) => Type,
2525+
original: (Type, Type) => Type = _ & _, isErased: Boolean = ctx.erasedTypes): Type = trace(s"andTypeGen(${tp1orig.show}, ${tp2orig.show})", subtyping, show = true) {
2526+
val tp1 = tp1orig.stripFlexible
2527+
val tp2 = tp2orig.stripFlexible
25212528
val t1 = distributeAnd(tp1, tp2)
2522-
if (t1.exists) t1
2523-
else {
2524-
val t2 = distributeAnd(tp2, tp1)
2525-
if (t2.exists) t2
2526-
else if (isErased) erasedGlb(tp1, tp2)
2527-
else liftIfHK(tp1, tp2, op, original, _ | _)
2529+
val ret =
2530+
if t1.exists then t1
2531+
else
2532+
val t2 = distributeAnd(tp2, tp1)
2533+
if t2.exists then t2
2534+
else if isErased then erasedGlb(tp1, tp2)
2535+
else liftIfHK(tp1, tp2, op, original, _ | _)
25282536
// The ` | ` on variances is needed since variances are associated with bounds
25292537
// not lambdas. Example:
25302538
//
@@ -2534,7 +2542,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
25342542
//
25352543
// Here, `F` is treated as bivariant in `O`. That is, only bivariant implementation
25362544
// of `F` are allowed. See neg/hk-variance2s.scala test.
2537-
}
2545+
2546+
if tp1orig.isInstanceOf[FlexibleType] && tp2orig.isInstanceOf[FlexibleType]
2547+
then FlexibleType(ret) else ret
25382548
}
25392549

25402550
/** Form a normalized conjunction of two types.

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ object Types extends TypeUtils {
307307
isRef(defn.ObjectClass) && (typeSymbol eq defn.FromJavaObjectSymbol)
308308

309309
def containsFromJavaObject(using Context): Boolean = this match
310+
case tp: FlexibleType => tp.original.containsFromJavaObject
310311
case tp: OrType => tp.tp1.containsFromJavaObject || tp.tp2.containsFromJavaObject
311312
case tp: AndType => tp.tp1.containsFromJavaObject && tp.tp2.containsFromJavaObject
312313
case _ => isFromJavaObject
@@ -345,6 +346,7 @@ object Types extends TypeUtils {
345346
/** Is this type guaranteed not to have `null` as a value? */
346347
final def isNotNull(using Context): Boolean = this match {
347348
case tp: ConstantType => tp.value.value != null
349+
case tp: FlexibleType => false
348350
case tp: ClassInfo => !tp.cls.isNullableClass && tp.cls != defn.NothingClass
349351
case tp: AppliedType => tp.superType.isNotNull
350352
case tp: TypeBounds => tp.lo.isNotNull
@@ -374,6 +376,7 @@ object Types extends TypeUtils {
374376
case AppliedType(tycon, args) => tycon.unusableForInference || args.exists(_.unusableForInference)
375377
case RefinedType(parent, _, rinfo) => parent.unusableForInference || rinfo.unusableForInference
376378
case TypeBounds(lo, hi) => lo.unusableForInference || hi.unusableForInference
379+
case FlexibleType(underlying) => underlying.unusableForInference
377380
case tp: AndOrType => tp.tp1.unusableForInference || tp.tp2.unusableForInference
378381
case tp: LambdaType => tp.resultType.unusableForInference || tp.paramInfos.exists(_.unusableForInference)
379382
case WildcardType(optBounds) => optBounds.unusableForInference
@@ -1616,6 +1619,9 @@ object Types extends TypeUtils {
16161619

16171620
/** If this is a repeated type, its element type, otherwise the type itself */
16181621
def repeatedToSingle(using Context): Type = this match {
1622+
case tp: FlexibleType =>
1623+
val underlyingSingle = tp.underlying.repeatedToSingle
1624+
if underlyingSingle ne tp.underlying then underlyingSingle else tp
16191625
case tp @ ExprType(tp1) => tp.derivedExprType(tp1.repeatedToSingle)
16201626
case _ => if (isRepeatedParam) this.argTypesHi.head else this
16211627
}
@@ -3440,6 +3446,56 @@ object Types extends TypeUtils {
34403446
}
34413447
}
34423448

3449+
// --- FlexibleType -----------------------------------------------------------------
3450+
3451+
/* A flexible type is a type with a custom subtyping relationship.
3452+
* It is used by explicit nulls to represent a type coming from Java which can be
3453+
* consider as nullable or non-nullable depending on the context, in a similar way to Platform
3454+
* Types in Kotlin. A `FlexibleType(T)` generally behaves like a type variable with special bounds
3455+
* `T | Null .. T`, so that `T | Null <: FlexibleType(T) <: T`.
3456+
* A flexible type will be erased to its original type `T`.
3457+
*/
3458+
case class FlexibleType(original: Type, lo: Type, hi: Type) extends CachedProxyType with ValueType {
3459+
def underlying(using Context): Type = original
3460+
3461+
override def superType(using Context): Type = hi
3462+
3463+
def derivedFlexibleType(original: Type)(using Context): Type =
3464+
if this.original eq original then this else FlexibleType(original)
3465+
3466+
override def computeHash(bs: Binders): Int = doHash(bs, original)
3467+
3468+
override final def baseClasses(using Context): List[ClassSymbol] = original.baseClasses
3469+
}
3470+
3471+
object FlexibleType {
3472+
def apply(original: Type)(using Context): Type = original match {
3473+
case ft: FlexibleType => ft
3474+
case _ =>
3475+
// val original1 = original.stripNull
3476+
// if original1.isNullType then
3477+
// // (Null)? =:= ? >: Null <: (Object & Null)
3478+
// FlexibleType(defn.NullType, original, AndType(defn.ObjectType, defn.NullType))
3479+
// else
3480+
// // (T | Null)? =:= ? >: T | Null <: T
3481+
// // (T)? =:= ? >: T | Null <: T
3482+
// val hi = original1
3483+
// val lo = if hi eq original then OrNull(hi) else original
3484+
// FlexibleType(original, lo, hi)
3485+
//
3486+
// The commented out code does more work to analyze the original type to ensure the
3487+
// flexible type is always a subtype of the original type and the Object type.
3488+
// It is not necessary according to the use cases, so we choose to use a simpler
3489+
// rule.
3490+
FlexibleType(original, OrNull(original), original)
3491+
}
3492+
3493+
def unapply(tp: Type)(using Context): Option[Type] = tp match {
3494+
case ft: FlexibleType => Some(ft.original)
3495+
case _ => None
3496+
}
3497+
}
3498+
34433499
// --- AndType/OrType ---------------------------------------------------------------
34443500

34453501
abstract class AndOrType extends CachedGroundType with ValueType {
@@ -5941,6 +5997,8 @@ object Types extends TypeUtils {
59415997
samClass(tp.underlying)
59425998
case tp: AnnotatedType =>
59435999
samClass(tp.underlying)
6000+
case tp: FlexibleType =>
6001+
samClass(tp.superType)
59446002
case _ =>
59456003
NoSymbol
59466004

@@ -6071,6 +6129,8 @@ object Types extends TypeUtils {
60716129
tp.derivedJavaArrayType(elemtp)
60726130
protected def derivedExprType(tp: ExprType, restpe: Type): Type =
60736131
tp.derivedExprType(restpe)
6132+
protected def derivedFlexibleType(tp: FlexibleType, under: Type): Type =
6133+
tp.derivedFlexibleType(under)
60746134
// note: currying needed because Scala2 does not support param-dependencies
60756135
protected def derivedLambdaType(tp: LambdaType)(formals: List[tp.PInfo], restpe: Type): Type =
60766136
tp.derivedLambdaType(tp.paramNames, formals, restpe)
@@ -6194,6 +6254,9 @@ object Types extends TypeUtils {
61946254
case tp: OrType =>
61956255
derivedOrType(tp, this(tp.tp1), this(tp.tp2))
61966256

6257+
case tp: FlexibleType =>
6258+
derivedFlexibleType(tp, this(tp.underlying))
6259+
61976260
case tp: MatchType =>
61986261
val bound1 = this(tp.bound)
61996262
val scrut1 = atVariance(0)(this(tp.scrutinee))
@@ -6481,6 +6544,14 @@ object Types extends TypeUtils {
64816544
if (underlying.isExactlyNothing) underlying
64826545
else tp.derivedAnnotatedType(underlying, annot)
64836546
}
6547+
override protected def derivedFlexibleType(tp: FlexibleType, underlying: Type): Type =
6548+
underlying match {
6549+
case Range(lo, hi) =>
6550+
range(tp.derivedFlexibleType(lo), tp.derivedFlexibleType(hi))
6551+
case _ =>
6552+
if (underlying.isExactlyNothing) underlying
6553+
else tp.derivedFlexibleType(underlying)
6554+
}
64846555
override protected def derivedCapturingType(tp: Type, parent: Type, refs: CaptureSet): Type =
64856556
parent match // TODO ^^^ handle ranges in capture sets as well
64866557
case Range(lo, hi) =>
@@ -6610,6 +6681,9 @@ object Types extends TypeUtils {
66106681
case tp: TypeVar =>
66116682
this(x, tp.underlying)
66126683

6684+
case tp: FlexibleType =>
6685+
this(x, tp.underlying)
6686+
66136687
case ExprType(restpe) =>
66146688
this(x, restpe)
66156689

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,9 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
272272
case tpe: OrType =>
273273
writeByte(ORtype)
274274
withLength { pickleType(tpe.tp1, richTypes); pickleType(tpe.tp2, richTypes) }
275+
case tpe: FlexibleType =>
276+
writeByte(FLEXIBLEtype)
277+
withLength { pickleType(tpe.underlying, richTypes) }
275278
case tpe: ExprType =>
276279
writeByte(BYNAMEtype)
277280
pickleType(tpe.underlying)

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,8 @@ class TreeUnpickler(reader: TastyReader,
444444
readTypeRef() match {
445445
case binder: LambdaType => binder.paramRefs(readNat())
446446
}
447+
case FLEXIBLEtype =>
448+
FlexibleType(readType())
447449
}
448450
assert(currentAddr == end, s"$start $currentAddr $end ${astTagToString(tag)}")
449451
result

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,8 @@ class PlainPrinter(_ctx: Context) extends Printer {
294294
&& !printDebug
295295
then atPrec(GlobalPrec)( Str("into ") ~ toText(tpe) )
296296
else toTextLocal(tpe) ~ " " ~ toText(annot)
297+
case FlexibleType(tpe) =>
298+
"(" ~ toText(tpe) ~ ")?"
297299
case tp: TypeVar =>
298300
def toTextCaret(tp: Type) = if printDebug then toTextLocal(tp) ~ Str("^") else toText(tp)
299301
if (tp.isInstantiated)

0 commit comments

Comments
 (0)