@@ -319,11 +319,29 @@ object Nullables:
319
319
if ! info.isEmpty then tree.putAttachment(NNInfo , info)
320
320
tree
321
321
322
+ /* Collect the nullability info from parts of `tree` */
323
+ def collectNotNullInfo (using Context ): NotNullInfo = tree match
324
+ case Typed (expr, _) =>
325
+ expr.notNullInfo
326
+ case Apply (fn, args) =>
327
+ val argsInfo = args.map(_.notNullInfo)
328
+ val fnInfo = fn.notNullInfo
329
+ argsInfo.foldLeft(fnInfo)(_ seq _)
330
+ case TypeApply (fn, _) =>
331
+ fn.notNullInfo
332
+ case _ =>
333
+ // Other cases are handled specially in typer.
334
+ NotNullInfo .empty
335
+
322
336
/* The nullability info of `tree` */
323
337
def notNullInfo (using Context ): NotNullInfo =
324
- stripInlined(tree).getAttachment(NNInfo ) match
338
+ val tree1 = stripInlined(tree)
339
+ tree1.getAttachment(NNInfo ) match
325
340
case Some (info) if ! ctx.erasedTypes => info
326
- case _ => NotNullInfo .empty
341
+ case _ =>
342
+ val nnInfo = tree1.collectNotNullInfo
343
+ tree1.withNotNullInfo(nnInfo)
344
+ nnInfo
327
345
328
346
/* The nullability info of `tree`, assuming it is a condition that evaluates to `c` */
329
347
def notNullInfoIf (c : Boolean )(using Context ): NotNullInfo =
@@ -404,21 +422,23 @@ object Nullables:
404
422
end extension
405
423
406
424
extension (tree : Assign )
407
- def computeAssignNullable ()(using Context ): tree.type = tree.lhs match
408
- case TrackedRef (ref) =>
409
- val rhstp = tree.rhs.typeOpt
410
- if ctx.explicitNulls && ref.isNullableUnion then
411
- if rhstp.isNullType || rhstp.isNullableUnion then
412
- // If the type of rhs is nullable (`T|Null` or `Null`), then the nullability of the
413
- // lhs variable is no longer trackable. We don't need to check whether the type `T`
414
- // is correct here, as typer will check it.
415
- tree.withNotNullInfo(NotNullInfo (Set (), Set (ref)))
416
- else
417
- // If the initial type is nullable and the assigned value is non-null,
418
- // we add it to the NotNull.
419
- tree.withNotNullInfo(NotNullInfo (Set (ref), Set ()))
420
- else tree
421
- case _ => tree
425
+ def computeAssignNullable ()(using Context ): tree.type =
426
+ var nnInfo = tree.rhs.notNullInfo
427
+ tree.lhs match
428
+ case TrackedRef (ref) if ctx.explicitNulls && ref.isNullableUnion =>
429
+ nnInfo = nnInfo.seq:
430
+ val rhstp = tree.rhs.typeOpt
431
+ if rhstp.isNullType || rhstp.isNullableUnion then
432
+ // If the type of rhs is nullable (`T|Null` or `Null`), then the nullability of the
433
+ // lhs variable is no longer trackable. We don't need to check whether the type `T`
434
+ // is correct here, as typer will check it.
435
+ NotNullInfo (Set (), Set (ref))
436
+ else
437
+ // If the initial type is nullable and the assigned value is non-null,
438
+ // we add it to the NotNull.
439
+ NotNullInfo (Set (ref), Set ())
440
+ case _ =>
441
+ tree.withNotNullInfo(nnInfo)
422
442
end extension
423
443
424
444
private val analyzedOps = Set (nme.EQ , nme.NE , nme.eq, nme.ne, nme.ZAND , nme.ZOR , nme.UNARY_! )
0 commit comments