Skip to content

Commit adcfda4

Browse files
committed
Fix bugs in the PR
1 parent 9b79bb9 commit adcfda4

File tree

8 files changed

+85
-61
lines changed

8 files changed

+85
-61
lines changed

compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala

+7-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import Contexts._
66
import Flags._
77
import Symbols._
88
import Types._
9+
import transform.SymUtils._
910

1011
/** Defines operations on nullable types and tree. */
1112
object NullOpsDecorator:
@@ -80,23 +81,26 @@ object NullOpsDecorator:
8081
case _ => tree
8182

8283
def tryToCastToCanEqualNull(using Context): Tree =
84+
// return the tree directly if not at Typer phase
85+
if !(ctx.explicitNulls && ctx.phase.isTyper) then return tree
86+
8387
val sym = tree.symbol
8488
val tp = tree.tpe
8589

8690
if !ctx.mode.is(Mode.UnsafeJavaReturn)
8791
|| !sym.is(JavaDefined)
88-
|| sym.is(Package)
92+
|| sym.isNoValue
8993
|| !sym.isTerm
9094
|| tp.isError then
9195
return tree
9296

9397
tree match
94-
case _: Apply if sym.is(Method) && !sym.isConstructor =>
98+
case _: Apply if sym.is(Method) =>
9599
val tp2 = tp.replaceOrNull
96100
if tp ne tp2 then
97101
tree.cast(tp2)
98102
else tree
99-
case _: Select if !sym.is(Method) =>
103+
case _: Select | _: Ident if !sym.is(Method) =>
100104
val tpw = tp.widen
101105
val tp2 = tpw.replaceOrNull
102106
if tpw ne tp2 then

compiler/src/dotty/tools/dotc/typer/Typer.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
547547
ref(ownType).withSpan(tree.span)
548548
case _ =>
549549
tree.withType(ownType)
550-
val tree2 = toNotNullTermRef(tree1, pt)
550+
val tree2 = toNotNullTermRef(tree1, pt).tryToCastToCanEqualNull
551551
checkLegalValue(tree2, pt)
552552
tree2
553553

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import language.unsafeJavaReturn
2+
3+
import java.math.MathContext, MathContext._
4+
5+
val x: MathContext = DECIMAL32
6+
val y: MathContext = MathContext.DECIMAL32
7+
8+
import java.io.File
9+
10+
val s: String = File.separator
11+
import java.time.ZoneId
12+
13+
val zids: java.util.Set[String] = ZoneId.getAvailableZoneIds
14+
val zarr: Array[String] = ZoneId.getAvailableZoneIds.toArray(Array.empty[String | Null])
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
11
import scala.language.unsafeJavaReturn
22

3-
val s = "foo"
4-
val methods: Array[java.lang.reflect.Method] = s.getClass.getMethods
3+
import java.lang.reflect.Method
4+
5+
def getMethods(f: String): List[Method] =
6+
val clazz = Class.forName(f)
7+
val methods = clazz.getMethods
8+
if methods == null then List()
9+
else methods.toList
10+
11+
def getClass(o: AnyRef): Class[?] = o.getClass

tests/explicit-nulls/unsafe-java/java-chain/J.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ class J1 {
44

55
class J2 {
66
J1 getJ1() { return new J1(); }
7-
}
7+
}
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
class S {
1+
import scala.language.unsafeJavaReturn
2+
3+
def f = {
24
val j: J2 = new J2()
3-
j.getJ1().getJ2().getJ1().getJ2().getJ1().getJ2() // error
5+
j.getJ1().getJ2().getJ1().getJ2().getJ1().getJ2()
46
}

tests/explicit-nulls/unsafe-java/java-class/J.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import java.util.List;
22

3-
public class JC {
3+
public class J {
44

55
public int a;
66

@@ -39,4 +39,4 @@ public <T> List<T> h2() {
3939
public <T> T[] h3() {
4040
return null;
4141
}
42-
}
42+
}

tests/explicit-nulls/unsafe-java/java-class/S.scala

+47-50
Original file line numberDiff line numberDiff line change
@@ -3,53 +3,50 @@ import scala.language.unsafeJavaReturn
33
import scala.annotation.CanEqualNull
44
import java.{util => ju}
55

6-
class S {
7-
8-
def test[T <: AnyRef](jc: JC) = {
9-
val a: Int = jc.a
10-
11-
val b = jc.b // it returns String @CanEqualNull
12-
val b2: String = b
13-
val b3: String @CanEqualNull = jc.b
14-
val b4: String = jc.b
15-
val bb = jc.b == null // it's ok to compare String @CanEqualNull with Null
16-
val btl = jc.b.trim().length() // String @CanEqualNull is just String, unsafe selecting
17-
18-
val c = jc.c
19-
val cl = c.length
20-
val c2: Array[String] = c
21-
val c3: Array[String @CanEqualNull] @CanEqualNull = jc.c
22-
val c4: Array[String] = jc.c
23-
val cml: Array[Int] = c.map(_.length())
24-
25-
val f1: Int = jc.f1()
26-
27-
val f21: Array[Int] @CanEqualNull = jc.f2()
28-
val f22: Array[Int] = jc.f2()
29-
val f2n = jc.f2() == null
30-
31-
val g11: String @CanEqualNull = jc.g1()
32-
val g12: String = jc.g1()
33-
val g1n = jc.g1() == null
34-
val g1tl = jc.g1().trim().length()
35-
36-
val g21: ju.List[String] @CanEqualNull = jc.g2()
37-
val g22: ju.List[String] = jc.g2()
38-
39-
val g31: Array[String @CanEqualNull] @CanEqualNull = jc.g3()
40-
val g32: Array[String] = jc.g3()
41-
val g3n = jc.g3() == null
42-
val g3m: Array[Boolean] = jc.g3().map(_ == null)
43-
44-
val h11: T @CanEqualNull = jc.h1[T]()
45-
val h12: T = jc.h1[T]()
46-
val h1n = jc.h1[T]() == null
47-
48-
val h21: ju.List[T] @CanEqualNull = jc.h2[T]()
49-
val h22: ju.List[T] = jc.h2[T]()
50-
51-
val h31: Array[T @CanEqualNull] @CanEqualNull = jc.h3[T]()
52-
val h32: Array[T] = jc.h3[T]()
53-
val h3m = jc.h3[T]().map(_ == null)
54-
}
55-
}
6+
def test[T <: AnyRef](j: J) = {
7+
val a: Int = j.a
8+
9+
val b = j.b // it returns String @CanEqualNull
10+
val b2: String = b
11+
val b3: String @CanEqualNull = j.b
12+
val b4: String = j.b
13+
val bb = j.b == null // it's ok to compare String @CanEqualNull with Null
14+
val btl = j.b.trim().length() // String @CanEqualNull is just String, unsafe selecting
15+
16+
val c = j.c
17+
val cl = c.length
18+
val c2: Array[String] = c
19+
val c3: Array[String @CanEqualNull] @CanEqualNull = j.c
20+
val c4: Array[String] = j.c
21+
val cml: Array[Int] = c.map(_.length())
22+
23+
val f1: Int = j.f1()
24+
25+
val f21: Array[Int] @CanEqualNull = j.f2()
26+
val f22: Array[Int] = j.f2()
27+
val f2n = j.f2() == null
28+
29+
val g11: String @CanEqualNull = j.g1()
30+
val g12: String = j.g1()
31+
val g1n = j.g1() == null
32+
val g1tl = j.g1().trim().length()
33+
34+
val g21: ju.List[String] @CanEqualNull = j.g2()
35+
val g22: ju.List[String] = j.g2()
36+
37+
val g31: Array[String @CanEqualNull] @CanEqualNull = j.g3()
38+
val g32: Array[String] = j.g3()
39+
val g3n = j.g3() == null
40+
val g3m: Array[Boolean] = j.g3().map(_ == null)
41+
42+
val h11: T @CanEqualNull = j.h1[T]()
43+
val h12: T = j.h1[T]()
44+
val h1n = j.h1[T]() == null
45+
46+
val h21: ju.List[T] @CanEqualNull = j.h2[T]()
47+
val h22: ju.List[T] = j.h2[T]()
48+
49+
val h31: Array[T @CanEqualNull] @CanEqualNull = j.h3[T]()
50+
val h32: Array[T] = j.h3[T]()
51+
val h3m = j.h3[T]().map(_ == null)
52+
}

0 commit comments

Comments
 (0)