Skip to content

Commit 6287ff2

Browse files
authored
Extract the function type and SAM in union type (#11760)
Fixes #11694.
1 parent aa5ed33 commit 6287ff2

File tree

4 files changed

+59
-14
lines changed

4 files changed

+59
-14
lines changed

compiler/src/dotty/tools/dotc/core/Types.scala

+18
Original file line numberDiff line numberDiff line change
@@ -1651,6 +1651,24 @@ object Types {
16511651
case _ => resultType
16521652
}
16531653

1654+
/** Find the function type in union.
1655+
* If there are multiple function types, NoType is returned.
1656+
*/
1657+
def findFunctionTypeInUnion(using Context): Type = this match {
1658+
case t: OrType =>
1659+
val t1 = t.tp1.findFunctionTypeInUnion
1660+
if t1 == NoType then t.tp2.findFunctionTypeInUnion else
1661+
val t2 = t.tp2.findFunctionTypeInUnion
1662+
// Returen NoType if the union contains multiple function types
1663+
if t2 == NoType then t1 else NoType
1664+
case t if defn.isNonRefinedFunction(t) =>
1665+
t
1666+
case t @ SAMType(_) =>
1667+
t
1668+
case _ =>
1669+
NoType
1670+
}
1671+
16541672
/** This type seen as a TypeBounds */
16551673
final def bounds(using Context): TypeBounds = this match {
16561674
case tp: TypeBounds => tp

compiler/src/dotty/tools/dotc/typer/Typer.scala

+18-14
Original file line numberDiff line numberDiff line change
@@ -1125,6 +1125,7 @@ class Typer extends Namer
11251125
newTypeVar(apply(bounds.orElse(TypeBounds.empty)).bounds)
11261126
case _ => mapOver(t)
11271127
}
1128+
11281129
val pt1 = pt.stripTypeVar.dealias
11291130
if (pt1 ne pt1.dropDependentRefinement)
11301131
&& defn.isContextFunctionType(pt1.nonPrivateMember(nme.apply).info.finalResultType)
@@ -1133,22 +1134,25 @@ class Typer extends Namer
11331134
i"""Implementation restriction: Expected result type $pt1
11341135
|is a curried dependent context function type. Such types are not yet supported.""",
11351136
tree.srcPos)
1137+
11361138
pt1 match {
1137-
case pt1 if defn.isNonRefinedFunction(pt1) =>
1138-
// if expected parameter type(s) are wildcards, approximate from below.
1139-
// if expected result type is a wildcard, approximate from above.
1140-
// this can type the greatest set of admissible closures.
1141-
(pt1.argTypesLo.init, typeTree(interpolateWildcards(pt1.argTypesHi.last)))
1142-
case SAMType(sam @ MethodTpe(_, formals, restpe)) =>
1143-
(formals,
1144-
if (sam.isResultDependent)
1145-
untpd.DependentTypeTree(syms => restpe.substParams(sam, syms.map(_.termRef)))
1146-
else
1147-
typeTree(restpe))
11481139
case tp: TypeParamRef =>
11491140
decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity, tree)
1150-
case _ =>
1151-
(List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())
1141+
case _ => pt1.findFunctionTypeInUnion match {
1142+
case pt1 if defn.isNonRefinedFunction(pt1) =>
1143+
// if expected parameter type(s) are wildcards, approximate from below.
1144+
// if expected result type is a wildcard, approximate from above.
1145+
// this can type the greatest set of admissible closures.
1146+
(pt1.argTypesLo.init, typeTree(interpolateWildcards(pt1.argTypesHi.last)))
1147+
case SAMType(sam @ MethodTpe(_, formals, restpe)) =>
1148+
(formals,
1149+
if sam.isResultDependent then
1150+
untpd.DependentTypeTree(syms => restpe.substParams(sam, syms.map(_.termRef)))
1151+
else
1152+
typeTree(restpe))
1153+
case _ =>
1154+
(List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())
1155+
}
11521156
}
11531157
}
11541158

@@ -1399,7 +1403,7 @@ class Typer extends Namer
13991403
if (tree.tpt.isEmpty)
14001404
meth1.tpe.widen match {
14011405
case mt: MethodType =>
1402-
pt.stripNull match {
1406+
pt.findFunctionTypeInUnion match {
14031407
case pt @ SAMType(sam)
14041408
if !defn.isFunctionType(pt) && mt <:< sam =>
14051409
// SAMs of the form C[?] where C is a class cannot be conversion targets.

tests/explicit-nulls/pos/i11694.scala

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def test = {
2+
val x = new java.util.ArrayList[String]()
3+
val y = x.stream().nn.filter(s => s.nn.length > 0)
4+
}

tests/neg/i11694.scala

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
def test1 = {
2+
def f11: (Int => Int) | Unit = x => x + 1
3+
def f12: Null | (Int => Int) = x => x + 1
4+
5+
def f21: (Int => Int) | Null = x => x + 1
6+
def f22: Null | (Int => Int) = x => x + 1
7+
}
8+
9+
def test2 = {
10+
def f1: (Int => String) | (Int => Int) | Null = x => x + 1 // error
11+
def f2: (Int => String) | Function[String, Int] | Null = x => "" + x // error
12+
def f3: Function[Int, Int] | Function[String, Int] | Null = x => x + 1 // error
13+
}
14+
15+
def test3 = {
16+
import java.util.function.Function
17+
val f1: Function[String, Int] | Unit = x => x.length
18+
val f2: Function[String, Int] | Null = x => x.length
19+
}

0 commit comments

Comments
 (0)