Skip to content

Revive function specialization #10452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,23 @@ class Compiler {
new CacheAliasImplicits, // Cache RHS of parameterless alias implicits
new ByNameClosures, // Expand arguments to by-name parameters to closures
new HoistSuperArgs, // Hoist complex arguments of supercalls to enclosing scope
new SpecializeApplyMethods, // Adds specialized methods to FunctionN
new RefChecks) :: // Various checks mostly related to abstract members and overriding
List(new ElimOpaque, // Turn opaque into normal aliases
new TryCatchPatterns, // Compile cases in try/catch
new PatternMatcher, // Compile pattern matches
new sjs.ExplicitJSClasses, // Make all JS classes explicit (Scala.js only)
new ExplicitOuter, // Add accessors to outer classes from nested ones.
new ExplicitSelf, // Make references to non-trivial self types explicit as casts
new ElimByName, // Expand by-name parameter references
new StringInterpolatorOpt) :: // Optimizes raw and s string interpolators by rewriting them to string concatentations
List(new PruneErasedDefs, // Drop erased definitions from scopes and simplify erased expressions
new InlinePatterns, // Remove placeholders of inlined patterns
new VCInlineMethods, // Inlines calls to value class methods
new SeqLiterals, // Express vararg arguments as arrays
new InterceptedMethods, // Special handling of `==`, `|=`, `getClass` methods
new Getters, // Replace non-private vals and vars with getter defs (fields are added later)
new ElimByName, // Expand by-name parameter references
new SpecializeFunctions, // Specialized Function{0,1,2} by replacing super with specialized super
new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods
new CollectNullableFields, // Collect fields that can be nulled out after use in lazy initialization
new ElimOuterSelect, // Expand outer selections
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -784,8 +784,8 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
def tupleArgs(tree: Tree)(using Context): List[Tree] = tree match {
case Block(Nil, expr) => tupleArgs(expr)
case Inlined(_, Nil, expr) => tupleArgs(expr)
case Apply(fn, args)
if fn.symbol.name == nme.apply &&
case Apply(fn: NameTree, args)
if fn.name == nme.apply &&
fn.symbol.owner.is(Module) &&
defn.isTupleClass(fn.symbol.owner.companionClass) => args
case _ => Nil
Expand Down
32 changes: 30 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,11 @@ class Definitions {
else funType(n)
).symbol.asClass

@tu lazy val Function0_apply: Symbol = FunctionClass(0).requiredMethod(nme.apply)
@tu lazy val Function0_apply: Symbol = Function0.requiredMethod(nme.apply)

@tu lazy val Function0: Symbol = FunctionClass(0)
@tu lazy val Function1: Symbol = FunctionClass(1)
@tu lazy val Function2: Symbol = FunctionClass(2)

def FunctionType(n: Int, isContextual: Boolean = false, isErased: Boolean = false)(using Context): TypeRef =
FunctionClass(n, isContextual && !ctx.erasedTypes, isErased).typeRef
Expand Down Expand Up @@ -1244,7 +1248,7 @@ class Definitions {

def isBottomClassAfterErasure(cls: Symbol): Boolean = cls == NothingClass || cls == NullClass

/** Is a function class.
/** Is any function class where
* - FunctionXXL
* - FunctionN for N >= 0
* - ContextFunctionN for N >= 0
Expand All @@ -1253,6 +1257,11 @@ class Definitions {
*/
def isFunctionClass(cls: Symbol): Boolean = scalaClassName(cls).isFunction

/** Is a function class where
* - FunctionN for N >= 0 and N != XXL
*/
def isPlainFunctionClass(cls: Symbol) = isVarArityClass(cls, str.Function)

/** Is an context function class.
* - ContextFunctionN for N >= 0
* - ErasedContextFunctionN for N > 0
Expand Down Expand Up @@ -1488,6 +1497,25 @@ class Definitions {
false
})

@tu lazy val Function0SpecializedApplyNames: collection.Set[TermName] =
for r <- Function0SpecializedReturnTypes
yield nme.apply.specializedFunction(r, Nil).asTermName

@tu lazy val Function1SpecializedApplyNames: collection.Set[TermName] =
for
r <- Function1SpecializedReturnTypes
t1 <- Function1SpecializedParamTypes
yield
nme.apply.specializedFunction(r, List(t1)).asTermName

@tu lazy val Function2SpecializedApplyNames: collection.Set[TermName] =
for
r <- Function2SpecializedReturnTypes
t1 <- Function2SpecializedParamTypes
t2 <- Function2SpecializedParamTypes
yield
nme.apply.specializedFunction(r, List(t1, t2)).asTermName

def functionArity(tp: Type)(using Context): Int = tp.dropDependentRefinement.dealias.argInfos.length - 1

/** Return underlying context function type (i.e. instance of an ContextFunctionN class)
Expand Down
25 changes: 24 additions & 1 deletion compiler/src/dotty/tools/dotc/core/NameOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,11 @@ object NameOps {
def isFunction: Boolean =
(name eq tpnme.FunctionXXL) || checkedFunArity(functionSuffixStart) >= 0

/** Is a function name
* - FunctionN for N >= 0
*/
def isPlainFunction: Boolean = functionArity >= 0

/** Is an context function name, i.e one of ContextFunctionN or ErasedContextFunctionN for N >= 0
*/
def isContextFunction: Boolean =
Expand Down Expand Up @@ -276,8 +281,10 @@ object NameOps {
case nme.clone_ => nme.clone_
}

/** This method is to be used on **type parameters** from a class, since
* this method does sorting based on their names
*/
def specializedFor(classTargs: List[Type], classTargsNames: List[Name], methodTargs: List[Type], methodTarsNames: List[Name])(using Context): N = {

val methodTags: Seq[Name] = (methodTargs zip methodTarsNames).sortBy(_._2).map(x => defn.typeTag(x._1))
val classTags: Seq[Name] = (classTargs zip classTargsNames).sortBy(_._2).map(x => defn.typeTag(x._1))

Expand All @@ -286,6 +293,22 @@ object NameOps {
classTags.fold(nme.EMPTY)(_ ++ _) ++ nme.specializedTypeNames.suffix)
}

/** Use for specializing function names ONLY and use it if you are **not**
* creating specialized name from type parameters. The order of names will
* be:
*
* `<return type><first type><second type><...>`
*/
def specializedFunction(ret: Type, args: List[Type])(using Context): Name =
val sb = new StringBuilder
sb.append(name.toString)
sb.append(nme.specializedTypeNames.prefix.toString)
sb.append(nme.specializedTypeNames.separator)
sb.append(defn.typeTag(ret).toString)
args.foreach { arg => sb.append(defn.typeTag(arg)) }
sb.append(nme.specializedTypeNames.suffix)
termName(sb.toString)

/** If name length exceeds allowable limit, replace part of it by hash */
def compactified(using Context): TermName = termName(compactify(name.toString))

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class ExpandSAMs extends MiniPhase {
cpy.Block(tree)(pfDef :: Nil, New(pfSym.typeRef, Nil))

case _ =>
val found = tpe.baseType(defn.FunctionClass(1))
val found = tpe.baseType(defn.Function1)
report.error(TypeMismatch(found, tpe), tree.srcPos)
tree
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,20 @@ class FunctionXXLForwarders extends MiniPhase with IdentityDenotTransformer {
ref(receiver.symbol).appliedToArgss(argss).cast(defn.ObjectType)
}

if impl.symbol.owner.is(Trait) then return impl

val forwarders =
for {
tree <- if (impl.symbol.owner.is(Trait)) Nil else impl.body
if tree.symbol.is(Method) && tree.symbol.name == nme.apply &&
tree.symbol.signature.paramsSig.size > MaxImplementedFunctionArity &&
tree.symbol.allOverriddenSymbols.exists(sym => defn.isXXLFunctionClass(sym.owner))
(ddef: DefDef) <- impl.body
if ddef.name == nme.apply && ddef.symbol.is(Method) &&
ddef.symbol.signature.paramsSig.size > MaxImplementedFunctionArity &&
ddef.symbol.allOverriddenSymbols.exists(sym => defn.isXXLFunctionClass(sym.owner))
}
yield {
val xsType = defn.ArrayType.appliedTo(List(defn.ObjectType))
val methType = MethodType(List(nme.args))(_ => List(xsType), _ => defn.ObjectType)
val meth = newSymbol(tree.symbol.owner, nme.apply, Synthetic | Method, methType)
DefDef(meth, paramss => forwarderRhs(tree, paramss.head.head))
val meth = newSymbol(ddef.symbol.owner, nme.apply, Synthetic | Method, methType)
DefDef(meth, paramss => forwarderRhs(ddef, paramss.head.head))
}

cpy.Template(impl)(body = forwarders ::: impl.body)
Expand Down
118 changes: 118 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/SpecializeApplyMethods.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package dotty.tools.dotc
package transform

import ast.Trees._, ast.tpd, core._
import Contexts._, Types._, Decorators._, Symbols._, DenotTransformers._
import SymDenotations._, Scopes._, StdNames._, NameOps._, Names._
import MegaPhase.MiniPhase

import scala.collection.mutable


/** This phase synthesizes specialized methods for FunctionN, this is done
* since there are no scala signatures in the bytecode for the specialized
* methods.
*
* We know which specializations exist for the different arities, therefore we
* can hardcode them. This should, however be removed once we're using a
* different standard library.
*/
class SpecializeApplyMethods extends MiniPhase with InfoTransformer {
import ast.tpd._

val phaseName = "specializeApplyMethods"

override def isEnabled(using Context): Boolean =
!ctx.settings.scalajs.value

private def specApplySymbol(sym: Symbol, args: List[Type], ret: Type)(using Context): Symbol = {
val name = nme.apply.specializedFunction(ret, args)
// Create the symbol at the next phase, so that it is a valid member of the
// corresponding function for all valid periods of its SymDenotations.
// Otherwise, the valid period will offset by 1, which causes a stale symbol
// in compiling stdlib.
atNextPhase(newSymbol(sym, name, Flags.Method, MethodType(args, ret)))
}

private inline def specFun0(inline op: Type => Unit)(using Context): Unit = {
for (r <- defn.Function0SpecializedReturnTypes) do
op(r)
}

private inline def specFun1(inline op: (Type, Type) => Unit)(using Context): Unit = {
for
r <- defn.Function1SpecializedReturnTypes
t1 <- defn.Function1SpecializedParamTypes
do
op(t1, r)
}

private inline def specFun2(inline op: (Type, Type, Type) => Unit)(using Context): Unit = {
for
r <- defn.Function2SpecializedReturnTypes
t1 <- defn.Function2SpecializedParamTypes
t2 <- defn.Function2SpecializedParamTypes
do
op(t1, t2, r)
}

override def infoMayChange(sym: Symbol)(using Context) =
sym == defn.Function0
|| sym == defn.Function1
|| sym == defn.Function2

/** Add symbols for specialized methods to FunctionN */
override def transformInfo(tp: Type, sym: Symbol)(using Context) = tp match {
case tp: ClassInfo =>
if sym == defn.Function0 then
val scope = tp.decls.cloneScope
specFun0 { r => scope.enter(specApplySymbol(sym, Nil, r)) }
tp.derivedClassInfo(decls = scope)

else if sym == defn.Function1 then
val scope = tp.decls.cloneScope
specFun1 { (t1, r) => scope.enter(specApplySymbol(sym, t1 :: Nil, r)) }
tp.derivedClassInfo(decls = scope)

else if sym == defn.Function2 then
val scope = tp.decls.cloneScope
specFun2 { (t1, t2, r) => scope.enter(specApplySymbol(sym, t1 :: t2 :: Nil, r)) }
tp.derivedClassInfo(decls = scope)

else tp

case _ => tp
}

/** Create bridge methods for FunctionN with specialized applys */
override def transformTemplate(tree: Template)(using Context) = {
val cls = tree.symbol.owner.asClass

def synthesizeApply(names: collection.Set[TermName]): Tree = {
val applyBuf = new mutable.ListBuffer[DefDef]
names.foreach { name =>
val applySym = cls.info.decls.lookup(name)
val ddef = DefDef(
applySym.asTerm,
{ vparamss =>
This(cls)
.select(nme.apply)
.appliedToArgss(vparamss)
.ensureConforms(applySym.info.finalResultType)
}
)
applyBuf += ddef
}
cpy.Template(tree)(body = tree.body ++ applyBuf)
}

if cls == defn.Function0 then
synthesizeApply(defn.Function0SpecializedApplyNames)
else if cls == defn.Function1 then
synthesizeApply(defn.Function1SpecializedApplyNames)
else if cls == defn.Function2 then
synthesizeApply(defn.Function2SpecializedApplyNames)
else
tree
}
}
Loading