Skip to content

Warn if interpolator uses toString #20578

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/config/ScalaSettings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ private sealed trait WarningSettings:
private val WenumCommentDiscard = BooleanSetting(WarningSetting, "Wenum-comment-discard", "Warn when a comment ambiguously assigned to multiple enum cases is discarded.")
private val WimplausiblePatterns = BooleanSetting(WarningSetting, "Wimplausible-patterns", "Warn if comparison with a pattern value looks like it might always fail.")
private val WunstableInlineAccessors = BooleanSetting(WarningSetting, "WunstableInlineAccessors", "Warn an inline methods has references to non-stable binary APIs.")
private val WtoStringInterpolated = BooleanSetting(WarningSetting, "Wtostring-interpolated", "Warn a standard interpolator used toString on a reference type.")
private val Wunused: Setting[List[ChoiceWithHelp[String]]] = MultiChoiceHelpSetting(
WarningSetting,
name = "Wunused",
Expand Down Expand Up @@ -308,6 +309,7 @@ private sealed trait WarningSettings:
def enumCommentDiscard(using Context): Boolean = allOr(WenumCommentDiscard)
def implausiblePatterns(using Context): Boolean = allOr(WimplausiblePatterns)
def unstableInlineAccessors(using Context): Boolean = allOr(WunstableInlineAccessors)
def toStringInterpolated(using Context): Boolean = allOr(WtoStringInterpolated)
def checkInit(using Context): Boolean = allOr(WcheckInit)

/** -X "Extended" or "Advanced" settings */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMe
case EnumMayNotBeValueClassesID // errorNumber: 206
case IllegalUnrollPlacementID // errorNumber: 207
case ExtensionHasDefaultID // errorNumber: 208
case FormatInterpolationErrorID // errorNumber: 209

def errorNumber = ordinal - 1

Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/reporting/MessageKind.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ enum MessageKind:
case PotentialIssue
case UnusedSymbol
case Staging
case Interpolation

/** Human readable message that will end up being shown to the user.
* NOTE: This is only used in the situation where you have multiple words
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/reporting/messages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3444,3 +3444,8 @@ extends DeclarationMsg(IllegalUnrollPlacementID):

def explain(using Context) = ""
end IllegalUnrollPlacement

/** Message carrying a diagnostic about a format interpolation.
 *
 *  @param errorText the text of the diagnostic, returned unchanged by `msg`
 */
class BadFormatInterpolation(errorText: String)(using Context) extends Message(FormatInterpolationErrorID):
  /** Diagnostics of this class are grouped under the Interpolation kind. */
  def kind: MessageKind = MessageKind.Interpolation
  /** The message is exactly the text supplied at construction. */
  def msg(using Context): String = errorText
  /** No extended explanation is provided. */
  def explain(using Context): String = ""
179 changes: 103 additions & 76 deletions compiler/src/dotty/tools/dotc/transform/localopt/FormatChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import scala.util.matching.Regex.Match

import PartialFunction.cond

import dotty.tools.dotc.ast.tpd.{Match => _, *}
import dotty.tools.dotc.core.Contexts.*
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.core.Phases.typerPhase
import dotty.tools.dotc.reporting.BadFormatInterpolation
import dotty.tools.dotc.util.Spans.Span
import dotty.tools.dotc.util.chaining.*

Expand All @@ -29,8 +28,9 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List
def argType(argi: Int, types: Type*): Type =
require(argi < argc, s"$argi out of range picking from $types")
val tpe = argTypes(argi)
types.find(t => argConformsTo(argi, tpe, t))
.orElse(types.find(t => argConvertsTo(argi, tpe, t)))
types.find(t => t != defn.AnyType && argConformsTo(argi, tpe, t))
.orElse(types.find(t => t != defn.AnyType && argConvertsTo(argi, tpe, t)))
.orElse(types.find(t => t == defn.AnyType && argConformsTo(argi, tpe, t)))
.getOrElse {
report.argError(s"Found: ${tpe.show}, Required: ${types.map(_.show).mkString(", ")}", argi)
actuals += args(argi)
Expand Down Expand Up @@ -63,50 +63,57 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List

/** For N part strings and N-1 args to interpolate, normalize parts and check arg types.
*
* Returns normalized part strings and args, where args correcpond to conversions in tail of parts.
* Returns normalized part strings and args, where args correspond to conversions in tail of parts.
*/
def checked: (List[String], List[Tree]) =
val amended = ListBuffer.empty[String]
val convert = ListBuffer.empty[Conversion]

// Check the part string at index n: append the (possibly amended) part to `amended` and the
// conversion that consumes arg n-1 to `convert`, reporting problems along the way.
def checkPart(part: String, n: Int): Unit =
val matches = formatPattern.findAllMatchIn(part)

// No leading conversion is usable for arg n-1: prefix the part with "%s", record a default
// string conversion, check the arg against Any, and lint the implied toString.
def insertStringConversion(): Unit =
amended += "%s" + part
val cv = Conversion.stringXn(n)
cv.accepts(argType(n-1, defn.AnyType))
convert += cv
cv.lintToString(argTypes(n-1))

// Report a conversion that is not in the position immediately after a splice.
def errorLeading(op: Conversion) = op.errorAt(Spec):
s"conversions must follow a splice; ${Conversion.literalHelp}"

// Use `op` as the conversion for arg n-1: the part is kept as written, the arg type is picked
// from the types the conversion can take, and the conversion is linted for toString usage.
def accept(op: Conversion): Unit =
if !op.isLeading then errorLeading(op)
op.accepts(argType(n-1, op.acceptableVariants*))
amended += part
convert += op
op.lintToString(argTypes(n-1))

// after the first part, a leading specifier is required for the interpolated arg; %s is supplied if needed
if n == 0 then amended += part
else if !matches.hasNext then insertStringConversion()
else
val cv = Conversion(matches.next(), n)
// a literal (such as %% or %n) does not consume the arg, so %s is supplied
if cv.isLiteral then insertStringConversion()
else if cv.isIndexed then
// an indexed conversion is accepted only if its explicit index is n; otherwise %s is supplied
if cv.index.getOrElse(-1) == n then accept(cv) else insertStringConversion()
// a conversion of kind ErrorXn is neither accepted nor replaced
else if !cv.isError then accept(cv)

// any remaining conversions in this part must be either literals or indexed
while matches.hasNext do
val cv = Conversion(matches.next(), n)
// in the first part there is no preceding arg for the '<' flag to refer to
if n == 0 && cv.hasFlag('<') then cv.badFlag('<', "No last arg")
else if !cv.isLiteral && !cv.isIndexed then errorLeading(cv)
end checkPart

@tailrec
def loop(remaining: List[String], n: Int): Unit =
remaining match
case part0 :: more =>
def badPart(t: Throwable): String = "".tap(_ => report.partError(t.getMessage.nn, index = n, offset = 0))
val part = try StringContext.processEscapes(part0) catch badPart
val matches = formatPattern.findAllMatchIn(part)

def insertStringConversion(): Unit =
amended += "%s" + part
convert += Conversion(formatPattern.findAllMatchIn("%s").next(), n) // improve
argType(n-1, defn.AnyType)
def errorLeading(op: Conversion) = op.errorAt(Spec)(s"conversions must follow a splice; ${Conversion.literalHelp}")
def accept(op: Conversion): Unit =
if !op.isLeading then errorLeading(op)
op.accepts(argType(n-1, op.acceptableVariants*))
amended += part
convert += op

// after the first part, a leading specifier is required for the interpolated arg; %s is supplied if needed
if n == 0 then amended += part
else if !matches.hasNext then insertStringConversion()
else
val cv = Conversion(matches.next(), n)
if cv.isLiteral then insertStringConversion()
else if cv.isIndexed then
if cv.index.getOrElse(-1) == n then accept(cv) else insertStringConversion()
else if !cv.isError then accept(cv)

// any remaining conversions in this part must be either literals or indexed
while matches.hasNext do
val cv = Conversion(matches.next(), n)
if n == 0 && cv.hasFlag('<') then cv.badFlag('<', "No last arg")
else if !cv.isLiteral && !cv.isIndexed then errorLeading(cv)

loop(more, n + 1)
case Nil => ()
end loop
def loop(remaining: List[String], n: Int): Unit = remaining match
case part0 :: remaining =>
def badPart(t: Throwable): String = "".tap(_ => report.partError(t.getMessage.nn, index = n, offset = 0))
val part = try StringContext.processEscapes(part0) catch badPart
checkPart(part, n)
loop(remaining, n + 1)
case Nil =>

loop(parts, n = 0)
if reported then (Nil, Nil)
Expand All @@ -124,10 +131,8 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List
def intOf(g: SpecGroup): Option[Int] = group(g).map(_.toInt)

extension (inline value: Boolean)
inline def or(inline body: => Unit): Boolean = value || { body ; false }
inline def orElse(inline body: => Unit): Boolean = value || { body ; true }
inline def and(inline body: => Unit): Boolean = value && { body ; true }
inline def but(inline body: => Unit): Boolean = value && { body ; false }
inline infix def or(inline body: => Unit): Boolean = value || { body; false }
inline infix def and(inline body: => Unit): Boolean = value && { body; true }

enum Kind:
case StringXn, HashXn, BooleanXn, CharacterXn, IntegralXn, FloatingPointXn, DateTimeXn, LiteralXn, ErrorXn
Expand All @@ -146,9 +151,10 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List
// the conversion char is the head of the op string (but see DateTimeXn)
val cc: Char =
kind match
case ErrorXn => if op.isEmpty then '?' else op(0)
case DateTimeXn => if op.length > 1 then op(1) else '?'
case _ => op(0)
case ErrorXn => if op.isEmpty then '?' else op(0)
case DateTimeXn => if op.length <= 1 then '?' else op(1)
case StringXn => if op.isEmpty then 's' else op(0) // accommodate the default %s
case _ => op(0)

def isIndexed: Boolean = index.nonEmpty || hasFlag('<')
def isError: Boolean = kind == ErrorXn
Expand Down Expand Up @@ -208,18 +214,28 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List
// is the specifier OK with the given arg
def accepts(arg: Type): Boolean =
kind match
case BooleanXn => arg == defn.BooleanType orElse warningAt(CC)("Boolean format is null test for non-Boolean")
case IntegralXn =>
arg == BigIntType || !cond(cc) {
case 'o' | 'x' | 'X' if hasAnyFlag("+ (") => "+ (".filter(hasFlag).foreach(bad => badFlag(bad, s"only use '$bad' for BigInt conversions to o, x, X")) ; true
}
case BooleanXn if arg != defn.BooleanType =>
warningAt(CC):
"""non-Boolean value formats as "true" for non-null references and boxed primitives, otherwise "false""""
true
case IntegralXn if arg != BigIntType =>
cc match
case 'o' | 'x' | 'X' if hasAnyFlag("+ (") =>
"+ (".filter(hasFlag).foreach: bad =>
badFlag(bad, s"only use '$bad' for BigInt conversions to o, x, X")
false
case _ => true
case _ => true

/** When the toStringInterpolated warning setting is enabled and this is a string conversion,
 *  warn at the conversion char if the arg is neither a String nor a primitive value.
 *
 *  @param arg the type of the interpolated argument consumed by this conversion
 */
def lintToString(arg: Type): Unit =
  // evaluated only when the lint applies, matching the short-circuit order of the checks
  def isStringOrPrimitive = arg.widen =:= defn.StringType || arg.isPrimitiveValueType
  val lintApplies = ctx.settings.Whas.toStringInterpolated && kind == StringXn
  if lintApplies && !isStringOrPrimitive then
    warningAt(CC)("interpolation uses toString")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would probably be better done as a separate PR, but it might be useful to add message IDs for all the error/warning messages in this file.


// what arg type if any does the conversion accept
def acceptableVariants: List[Type] =
kind match
case StringXn => if hasFlag('#') then FormattableType :: Nil else defn.AnyType :: Nil
case BooleanXn => defn.BooleanType :: defn.NullType :: Nil
case BooleanXn => defn.BooleanType :: defn.NullType :: defn.AnyType :: Nil // warn if not boolean
case HashXn => defn.AnyType :: Nil
case CharacterXn => defn.CharType :: defn.ByteType :: defn.ShortType :: defn.IntType :: Nil
case IntegralXn => defn.IntType :: defn.LongType :: defn.ByteType :: defn.ShortType :: BigIntType :: Nil
Expand Down Expand Up @@ -248,25 +264,30 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List

object Conversion:
def apply(m: Match, i: Int): Conversion =
def kindOf(cc: Char) = cc match
case 's' | 'S' => StringXn
case 'h' | 'H' => HashXn
case 'b' | 'B' => BooleanXn
case 'c' | 'C' => CharacterXn
case 'd' | 'o' |
'x' | 'X' => IntegralXn
case 'e' | 'E' |
'f' |
'g' | 'G' |
'a' | 'A' => FloatingPointXn
case 't' | 'T' => DateTimeXn
case '%' | 'n' => LiteralXn
case _ => ErrorXn
end kindOf
m.group(CC) match
case Some(cc) => new Conversion(m, i, kindOf(cc(0))).tap(_.verify)
case None => new Conversion(m, i, ErrorXn).tap(_.errorAt(Spec)(s"Missing conversion operator in '${m.matched}'; $literalHelp"))
case Some(cc) =>
val xn = cc(0) match
case 's' | 'S' => StringXn
case 'h' | 'H' => HashXn
case 'b' | 'B' => BooleanXn
case 'c' | 'C' => CharacterXn
case 'd' | 'o' |
'x' | 'X' => IntegralXn
case 'e' | 'E' |
'f' |
'g' | 'G' |
'a' | 'A' => FloatingPointXn
case 't' | 'T' => DateTimeXn
case '%' | 'n' => LiteralXn
case _ => ErrorXn
new Conversion(m, i, xn)
.tap(_.verify)
case None =>
new Conversion(m, i, ErrorXn)
.tap(_.errorAt(Spec)(s"Missing conversion operator in '${m.matched}'; $literalHelp"))
end apply
/** Construct the default string conversion for the part at index `i`.
 *
 *  The conversion is built from the match of a bare "%", which has no explicit conversion char.
 */
def stringXn(i: Int): Conversion =
  val bareSpec = formatPattern.findAllMatchIn("%").next()
  new Conversion(bareSpec, i, StringXn)
val literalHelp = "use %% for literal %, %n for newline"
end Conversion

Expand All @@ -276,10 +297,16 @@ class TypedFormatChecker(partsElems: List[Tree], parts: List[String], args: List
val pos = partsElems(index).sourcePos
val bgn = pos.span.start + offset
val fin = if end < 0 then pos.span.end else pos.span.start + end
pos.withSpan(Span(bgn, fin, bgn))
pos.withSpan(Span(start = bgn, end = fin, point = bgn))

extension (r: report.type)
def argError(message: String, index: Int): Unit = r.error(message, args(index).srcPos).tap(_ => reported = true)
def partError(message: String, index: Int, offset: Int, end: Int = -1): Unit = r.error(message, partPosAt(index, offset, end)).tap(_ => reported = true)
def partWarning(message: String, index: Int, offset: Int, end: Int = -1): Unit = r.warning(message, partPosAt(index, offset, end)).tap(_ => reported = true)
def argError(message: String, index: Int): Unit =
r.error(BadFormatInterpolation(message), args(index).srcPos)
.tap(_ => reported = true)
def partError(message: String, index: Int, offset: Int, end: Int = -1): Unit =
r.error(BadFormatInterpolation(message), partPosAt(index, offset, end))
.tap(_ => reported = true)
def partWarning(message: String, index: Int, offset: Int, end: Int): Unit =
r.warning(BadFormatInterpolation(message), partPosAt(index, offset, end))
.tap(_ => reported = true)
end TypedFormatChecker
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,22 @@ class StringInterpolatorOpt extends MiniPhase:
def mkConcat(strs: List[Literal], elems: List[Tree]): Tree =
val stri = strs.iterator
val elemi = elems.iterator
var result: Tree = stri.next
var result: Tree = stri.next()
def concat(tree: Tree): Unit =
result = result.select(defn.String_+).appliedTo(tree).withSpan(tree.span)
while elemi.hasNext
do
concat(elemi.next)
val str = stri.next
val elem = elemi.next()
lintToString(elem)
concat(elem)
val str = stri.next()
if !str.const.stringValue.isEmpty then concat(str)
result
end mkConcat
/** When the toStringInterpolated warning setting is enabled, warn at the position of `t`
 *  if its type is neither String nor a primitive value type.
 *
 *  @param t an interpolated argument about to be concatenated
 */
def lintToString(t: Tree): Unit =
  val tpe: Type = t.tpe
  // evaluated only when the setting is on, matching the short-circuit order of the checks
  def isStringOrPrimitive = tpe.widen =:= defn.StringType || tpe.isPrimitiveValueType
  if ctx.settings.Whas.toStringInterpolated && !isStringOrPrimitive then
    report.warning("interpolation uses toString", t.srcPos)
val sym = tree.symbol
// Test names first to avoid loading scala.StringContext if not used, and common names first
val isInterpolatedMethod =
Expand Down
Loading
Loading