Skip to content

Commit f39539c

Browse files
committed
Added basic CheckedTypeAnnotation impl with compiler the plugin check
1 parent 657cace commit f39539c

File tree

14 files changed

+1132
-8
lines changed

14 files changed

+1132
-8
lines changed

compiler-plugin/compiler-plugin-common/src/main/core/kotlinx/rpc/codegen/common/Names.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import org.jetbrains.kotlin.name.Name
1111
object RpcClassId {
1212
val remoteServiceInterface = ClassId(FqName("kotlinx.rpc"), Name.identifier("RemoteService"))
1313
val rpcAnnotation = ClassId(FqName("kotlinx.rpc.annotations"), Name.identifier("Rpc"))
14+
val checkedTypeAnnotation = ClassId(FqName("kotlinx.rpc.annotations"), Name.identifier("CheckedTypeAnnotation"))
1415

1516
val serializableAnnotation = ClassId(FqName("kotlinx.serialization"), Name.identifier("Serializable"))
1617
val contextualAnnotation = ClassId(FqName("kotlinx.serialization"), Name.identifier("Contextual"))

compiler-plugin/compiler-plugin-k2/src/main/core/kotlinx/rpc/codegen/FirRpcCheckers.kt

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,32 @@
44

55
package kotlinx.rpc.codegen
66

7+
import kotlinx.rpc.codegen.checkers.FirCheckedAnnotationHelper
78
import kotlinx.rpc.codegen.checkers.FirRpcDeclarationCheckers
9+
import kotlinx.rpc.codegen.checkers.FirRpcExpressionCheckers
810
import org.jetbrains.kotlin.fir.FirSession
911
import org.jetbrains.kotlin.fir.analysis.checkers.declaration.DeclarationCheckers
12+
import org.jetbrains.kotlin.fir.analysis.checkers.expression.ExpressionCheckers
1013
import org.jetbrains.kotlin.fir.analysis.extensions.FirAdditionalCheckersExtension
14+
import org.jetbrains.kotlin.fir.caches.createCache
15+
import org.jetbrains.kotlin.fir.caches.firCachesFactory
1116
import org.jetbrains.kotlin.fir.extensions.FirDeclarationPredicateRegistrar
17+
import org.jetbrains.kotlin.fir.symbols.impl.FirTypeParameterSymbol
1218

1319
class FirRpcCheckers(session: FirSession) : FirAdditionalCheckersExtension(session) {
1420
override fun FirDeclarationPredicateRegistrar.registerPredicates() {
1521
register(FirRpcPredicates.rpc)
22+
register(FirRpcPredicates.checkedAnnotationMeta)
1623
}
1724

18-
override val declarationCheckers: DeclarationCheckers = FirRpcDeclarationCheckers
25+
private val ctx = FirCheckersContext(session)
26+
27+
override val declarationCheckers: DeclarationCheckers = FirRpcDeclarationCheckers(ctx)
28+
override val expressionCheckers: ExpressionCheckers = FirRpcExpressionCheckers(ctx)
29+
}
30+
31+
class FirCheckersContext(private val session: FirSession) {
32+
val typeParametersCache = session.firCachesFactory.createCache { typeParameter: FirTypeParameterSymbol ->
33+
FirCheckedAnnotationHelper.checkedAnnotations(session, typeParameter)
34+
}
1935
}

compiler-plugin/compiler-plugin-k2/src/main/core/kotlinx/rpc/codegen/FirRpcPredicates.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,8 @@ object FirRpcPredicates {
1111
internal val rpc = DeclarationPredicate.create {
1212
annotated(RpcClassId.rpcAnnotation.asSingleFqName()) // @Rpc
1313
}
14+
15+
internal val checkedAnnotationMeta = DeclarationPredicate.create {
16+
metaAnnotated(RpcClassId.checkedTypeAnnotation.asSingleFqName(), includeItself = false)
17+
}
1418
}
Lines changed: 337 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,337 @@
1+
/*
2+
* Copyright 2023-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.rpc.codegen.checkers
6+
7+
import kotlinx.rpc.codegen.FirCheckersContext
8+
import kotlinx.rpc.codegen.FirRpcPredicates
9+
import kotlinx.rpc.codegen.checkers.FirCheckedAnnotationHelper.checkTypeArguments
10+
import kotlinx.rpc.codegen.checkers.diagnostics.FirRpcDiagnostics
11+
import kotlinx.rpc.codegen.common.RpcClassId
12+
import org.jetbrains.kotlin.KtSourceElement
13+
import org.jetbrains.kotlin.diagnostics.DiagnosticReporter
14+
import org.jetbrains.kotlin.diagnostics.reportOn
15+
import org.jetbrains.kotlin.fir.FirElement
16+
import org.jetbrains.kotlin.fir.FirSession
17+
import org.jetbrains.kotlin.fir.analysis.checkers.FirTypeRefSource
18+
import org.jetbrains.kotlin.fir.analysis.checkers.MppCheckerKind
19+
import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
20+
import org.jetbrains.kotlin.fir.analysis.checkers.declaration.FirClassChecker
21+
import org.jetbrains.kotlin.fir.analysis.checkers.declaration.FirFunctionChecker
22+
import org.jetbrains.kotlin.fir.analysis.checkers.declaration.FirTypeParameterChecker
23+
import org.jetbrains.kotlin.fir.analysis.checkers.expression.FirFunctionCallChecker
24+
import org.jetbrains.kotlin.fir.analysis.checkers.extractArgumentsTypeRefAndSource
25+
import org.jetbrains.kotlin.fir.caches.getValue
26+
import org.jetbrains.kotlin.fir.declarations.FirClass
27+
import org.jetbrains.kotlin.fir.declarations.FirFunction
28+
import org.jetbrains.kotlin.fir.declarations.FirTypeParameter
29+
import org.jetbrains.kotlin.fir.declarations.hasAnnotation
30+
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
31+
import org.jetbrains.kotlin.fir.extensions.predicateBasedProvider
32+
import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
33+
import org.jetbrains.kotlin.fir.resolve.defaultType
34+
import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol
35+
import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
36+
import org.jetbrains.kotlin.fir.symbols.impl.FirTypeParameterSymbol
37+
import org.jetbrains.kotlin.fir.types.*
38+
import org.jetbrains.kotlin.name.ClassId
39+
40+
class FirCheckedAnnotationFunctionCallChecker(
41+
private val ctx: FirCheckersContext,
42+
) : FirFunctionCallChecker(MppCheckerKind.Common) {
43+
override fun check(
44+
expression: FirFunctionCall,
45+
context: CheckerContext,
46+
reporter: DiagnosticReporter,
47+
) {
48+
checkTypeArguments(
49+
context = context,
50+
reporter = reporter,
51+
ctx = ctx,
52+
origin = expression,
53+
originMapper = { it },
54+
symbolProvider = { it.calleeReference.toResolvedCallableSymbol() },
55+
typeParameterSymbolsProvider = { it.typeParameterSymbols },
56+
typeArgumentsProvider = { it.typeArguments },
57+
typeArgumentsMapper = { it.toConeTypeProjection() },
58+
sourceProvider = { _, type -> type.source },
59+
)
60+
}
61+
}
62+
63+
class FirCheckedAnnotationTypeParameterChecker(
64+
private val ctx: FirCheckersContext,
65+
) : FirTypeParameterChecker(MppCheckerKind.Common) {
66+
override fun check(
67+
declaration: FirTypeParameter,
68+
context: CheckerContext,
69+
reporter: DiagnosticReporter,
70+
) {
71+
@Suppress("DuplicatedCode")
72+
declaration.bounds.forEach { bound ->
73+
checkTypeArguments(
74+
context = context,
75+
reporter = reporter,
76+
ctx = ctx,
77+
origin = bound,
78+
originMapper = { it.coneType },
79+
symbolProvider = { it.toClassSymbol(context.session) },
80+
typeParameterSymbolsProvider = { it.typeParameterSymbols },
81+
typeArgumentsProvider = { it.typeArguments.toList() },
82+
typeArgumentsMapper = { it },
83+
sourceProvider = { _, _ -> declaration.source },
84+
)
85+
}
86+
}
87+
}
88+
89+
class FirCheckedAnnotationFirClassChecker(
90+
private val ctx: FirCheckersContext,
91+
) : FirClassChecker(MppCheckerKind.Common) {
92+
override fun check(
93+
declaration: FirClass,
94+
context: CheckerContext,
95+
reporter: DiagnosticReporter,
96+
) {
97+
@Suppress("DuplicatedCode")
98+
declaration.superTypeRefs.forEach { superType ->
99+
checkTypeArguments(
100+
context = context,
101+
reporter = reporter,
102+
ctx = ctx,
103+
origin = superType,
104+
originMapper = { it.coneType },
105+
symbolProvider = { it.toClassSymbol(context.session) },
106+
typeParameterSymbolsProvider = { it.typeParameterSymbols },
107+
typeArgumentsProvider = { it.typeArguments.toList() },
108+
typeArgumentsMapper = { it },
109+
sourceProvider = { ref, _ -> ref.source },
110+
)
111+
}
112+
}
113+
}
114+
115+
class FirCheckedAnnotationFirFunctionChecker(
116+
private val ctx: FirCheckersContext,
117+
) : FirFunctionChecker(MppCheckerKind.Common) {
118+
override fun check(
119+
declaration: FirFunction,
120+
context: CheckerContext,
121+
reporter: DiagnosticReporter,
122+
) {
123+
declaration.valueParameters.forEach { valueParameter ->
124+
checkTypeArguments(
125+
context = context,
126+
reporter = reporter,
127+
ctx = ctx,
128+
origin = valueParameter.returnTypeRef,
129+
originMapper = { it.coneType },
130+
symbolProvider = { it.toClassSymbol(context.session) },
131+
typeParameterSymbolsProvider = { it.typeParameterSymbols },
132+
typeArgumentsProvider = { it.typeArguments.toList() },
133+
typeArgumentsMapper = { it },
134+
sourceProvider = { ref, _ -> ref.source },
135+
)
136+
}
137+
}
138+
}
139+
140+
object FirCheckedAnnotationHelper {
141+
@Suppress("detekt.LongParameterList", "detekt.CyclomaticComplexMethod", "detekt.LongMethod")
142+
fun <Origin, OriginTransformed, Symbol, TypeArgument> checkTypeArguments(
143+
context: CheckerContext,
144+
reporter: DiagnosticReporter,
145+
ctx: FirCheckersContext,
146+
origin: Origin,
147+
originTypeRefSource: FirTypeRefSource? = null,
148+
originMapper: (Origin) -> OriginTransformed?,
149+
symbolProvider: (OriginTransformed) -> Symbol?,
150+
typeParameterSymbolsProvider: (Symbol) -> List<FirTypeParameterSymbol>?,
151+
typeArgumentsProvider: (OriginTransformed) -> List<TypeArgument>,
152+
typeArgumentsMapper: (TypeArgument) -> ConeTypeProjection?,
153+
sourceProvider: (Origin, TypeArgument) -> KtSourceElement?,
154+
) {
155+
val originTransformed = originMapper(origin) ?: return
156+
val symbol = symbolProvider(originTransformed) ?: return
157+
158+
val parameters = checkedAnnotationsOnTypeParameters(
159+
session = context.session,
160+
ctx = ctx,
161+
typeParameterSymbols = typeParameterSymbolsProvider(symbol),
162+
)
163+
164+
val typeArguments = typeArgumentsProvider(originTransformed)
165+
166+
val extractedOriginSource = extractArgumentsTypeRefAndSource(
167+
typeRef = originTypeRefSource?.typeRef
168+
?: (origin as? FirTypeProjectionWithVariance)?.typeRef
169+
?: (origin as? FirTypeRef)
170+
).orEmpty()
171+
172+
parameters.forEach { (i, annotations) ->
173+
val typeArgument = typeArguments[i]
174+
val type = typeArgumentsMapper(typeArgument)?.type
175+
val classSymbol = type?.toClassSymbol(context.session)
176+
177+
val symbol = when {
178+
classSymbol != null -> {
179+
classSymbol
180+
}
181+
182+
typeArgument is ConeTypeParameterType -> {
183+
typeArgument.lookupTag.typeParameterSymbol
184+
}
185+
186+
typeArgument is FirTypeProjectionWithVariance -> {
187+
(typeArgument.typeRef.coneType as? ConeTypeParameterType)
188+
?.lookupTag
189+
?.typeParameterSymbol
190+
}
191+
192+
else -> {
193+
null
194+
}
195+
} ?: return@forEach
196+
197+
val originSource = extractedOriginSource.getOrNull(i)?.source
198+
?: originTypeRefSource?.source
199+
?: (origin as? FirTypeProjectionWithVariance)?.source
200+
?: (origin as? FirTypeRef)?.source
201+
202+
val source = when {
203+
classSymbol != null -> {
204+
originSource ?: sourceProvider(origin, typeArgument)
205+
}
206+
207+
typeArgument is ConeTypeParameterType -> {
208+
originSource ?: typeArgument.lookupTag.typeParameterSymbol.source
209+
}
210+
211+
typeArgument is FirTypeProjectionWithVariance -> {
212+
extractArgumentsTypeRefAndSource(typeArgument.typeRef)
213+
?.getOrNull(i)
214+
?.source
215+
?: typeArgument.source
216+
}
217+
218+
else -> {
219+
null
220+
}
221+
}
222+
223+
checkSymbolAnnotated(
224+
annotations = annotations,
225+
classSymbol = symbol,
226+
source = source,
227+
context = context,
228+
reporter = reporter,
229+
)
230+
}
231+
232+
typeArguments.forEachIndexed { i, typeArgument ->
233+
val nextOriginSource = (origin as? FirTypeProjectionWithVariance)
234+
?.let { FirTypeRefSource(it.typeRef, it.source) }
235+
?: (origin as? FirTypeRef)?.let { FirTypeRefSource(it, it.source) }
236+
237+
checkTypeArguments<TypeArgument, ConeKotlinType, FirClassSymbol<*>, ConeTypeProjection>(
238+
context = context,
239+
reporter = reporter,
240+
ctx = ctx,
241+
origin = typeArgument,
242+
originMapper = {
243+
when (typeArgument) {
244+
is ConeKotlinTypeProjection -> typeArgument.type
245+
is FirTypeProjectionWithVariance -> typeArgument.toConeTypeProjection()
246+
else -> null
247+
}?.type
248+
},
249+
originTypeRefSource = extractedOriginSource.getOrNull(i) ?: nextOriginSource ?: originTypeRefSource,
250+
symbolProvider = { it.toClassSymbol(context.session) },
251+
typeParameterSymbolsProvider = { it.typeParameterSymbols },
252+
typeArgumentsProvider = { it.typeArguments.toList() },
253+
typeArgumentsMapper = { it },
254+
sourceProvider = { arg, _ ->
255+
when (arg) {
256+
is FirElement -> arg.source
257+
is ConeKotlinTypeProjection -> sourceProvider(origin, arg)
258+
else -> null
259+
}
260+
},
261+
)
262+
}
263+
}
264+
265+
private fun checkedAnnotationsOnTypeParameters(
266+
session: FirSession,
267+
ctx: FirCheckersContext,
268+
typeParameterSymbols: List<FirTypeParameterSymbol>?,
269+
): List<Pair<Int, List<FirClassSymbol<*>>>> {
270+
return typeParameterSymbols.orEmpty().withIndex().filter { (_, parameter) ->
271+
session.predicateBasedProvider.matches(
272+
predicate = FirRpcPredicates.checkedAnnotationMeta,
273+
declaration = parameter,
274+
)
275+
}.map { (i, parameter) ->
276+
i to ctx.typeParametersCache.getValue(parameter)
277+
}
278+
}
279+
280+
fun checkedAnnotations(
281+
session: FirSession,
282+
symbol: FirBasedSymbol<*>,
283+
visited: Set<FirBasedSymbol<*>> = emptySet(),
284+
): List<FirClassSymbol<*>> {
285+
return symbol.annotations.mapNotNull {
286+
it.resolvedType.toClassSymbol(session)
287+
}.filter { annotation ->
288+
when {
289+
annotation in visited -> false
290+
annotation.hasAnnotation(RpcClassId.checkedTypeAnnotation, session) -> true
291+
else -> checkedAnnotations(session, annotation, visited + annotation).isNotEmpty()
292+
}
293+
}
294+
}
295+
296+
private fun checkSymbolAnnotated(
297+
annotations: List<FirClassSymbol<*>>,
298+
classSymbol: FirBasedSymbol<*>,
299+
source: KtSourceElement?,
300+
context: CheckerContext,
301+
reporter: DiagnosticReporter,
302+
) {
303+
for (annotationClass in annotations) {
304+
val hasCheckedAnnotation = hasCheckedAnnotation(
305+
session = context.session,
306+
symbol = classSymbol,
307+
annotationId = annotationClass.classId,
308+
)
309+
310+
if (!hasCheckedAnnotation) {
311+
reporter.reportOn(
312+
source = source,
313+
factory = FirRpcDiagnostics.CHECKED_ANNOTATION_VIOLATION,
314+
a = annotationClass.defaultType(),
315+
context = context,
316+
)
317+
}
318+
}
319+
}
320+
321+
private fun hasCheckedAnnotation(
322+
session: FirSession,
323+
symbol: FirBasedSymbol<*>,
324+
annotationId: ClassId,
325+
visited: Set<FirBasedSymbol<*>> = emptySet(),
326+
): Boolean {
327+
return when {
328+
symbol in visited -> false
329+
symbol.hasAnnotation(annotationId, session) -> true
330+
else -> symbol.annotations.any { annotation ->
331+
annotation.resolvedType.toClassSymbol(session)?.let {
332+
hasCheckedAnnotation(session, it, annotationId, visited + symbol)
333+
} == true
334+
}
335+
}
336+
}
337+
}

0 commit comments

Comments
 (0)