|
| 1 | +/* |
| 2 | + * Copyright 2023-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. |
| 3 | + */ |
| 4 | + |
| 5 | +package kotlinx.rpc.codegen.checkers |
| 6 | + |
| 7 | +import kotlinx.rpc.codegen.FirCheckersContext |
| 8 | +import kotlinx.rpc.codegen.FirRpcPredicates |
| 9 | +import kotlinx.rpc.codegen.checkers.FirCheckedAnnotationHelper.checkTypeArguments |
| 10 | +import kotlinx.rpc.codegen.checkers.diagnostics.FirRpcDiagnostics |
| 11 | +import kotlinx.rpc.codegen.common.RpcClassId |
| 12 | +import org.jetbrains.kotlin.KtSourceElement |
| 13 | +import org.jetbrains.kotlin.diagnostics.DiagnosticReporter |
| 14 | +import org.jetbrains.kotlin.diagnostics.reportOn |
| 15 | +import org.jetbrains.kotlin.fir.FirElement |
| 16 | +import org.jetbrains.kotlin.fir.FirSession |
| 17 | +import org.jetbrains.kotlin.fir.analysis.checkers.FirTypeRefSource |
| 18 | +import org.jetbrains.kotlin.fir.analysis.checkers.MppCheckerKind |
| 19 | +import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext |
| 20 | +import org.jetbrains.kotlin.fir.analysis.checkers.declaration.FirClassChecker |
| 21 | +import org.jetbrains.kotlin.fir.analysis.checkers.declaration.FirFunctionChecker |
| 22 | +import org.jetbrains.kotlin.fir.analysis.checkers.declaration.FirTypeParameterChecker |
| 23 | +import org.jetbrains.kotlin.fir.analysis.checkers.expression.FirFunctionCallChecker |
| 24 | +import org.jetbrains.kotlin.fir.analysis.checkers.extractArgumentsTypeRefAndSource |
| 25 | +import org.jetbrains.kotlin.fir.caches.getValue |
| 26 | +import org.jetbrains.kotlin.fir.declarations.FirClass |
| 27 | +import org.jetbrains.kotlin.fir.declarations.FirFunction |
| 28 | +import org.jetbrains.kotlin.fir.declarations.FirTypeParameter |
| 29 | +import org.jetbrains.kotlin.fir.declarations.hasAnnotation |
| 30 | +import org.jetbrains.kotlin.fir.expressions.FirFunctionCall |
| 31 | +import org.jetbrains.kotlin.fir.extensions.predicateBasedProvider |
| 32 | +import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol |
| 33 | +import org.jetbrains.kotlin.fir.resolve.defaultType |
| 34 | +import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol |
| 35 | +import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol |
| 36 | +import org.jetbrains.kotlin.fir.symbols.impl.FirTypeParameterSymbol |
| 37 | +import org.jetbrains.kotlin.fir.types.* |
| 38 | +import org.jetbrains.kotlin.name.ClassId |
| 39 | + |
| 40 | +class FirCheckedAnnotationFunctionCallChecker( |
| 41 | + private val ctx: FirCheckersContext, |
| 42 | +) : FirFunctionCallChecker(MppCheckerKind.Common) { |
| 43 | + override fun check( |
| 44 | + expression: FirFunctionCall, |
| 45 | + context: CheckerContext, |
| 46 | + reporter: DiagnosticReporter, |
| 47 | + ) { |
| 48 | + checkTypeArguments( |
| 49 | + context = context, |
| 50 | + reporter = reporter, |
| 51 | + ctx = ctx, |
| 52 | + origin = expression, |
| 53 | + originMapper = { it }, |
| 54 | + symbolProvider = { it.calleeReference.toResolvedCallableSymbol() }, |
| 55 | + typeParameterSymbolsProvider = { it.typeParameterSymbols }, |
| 56 | + typeArgumentsProvider = { it.typeArguments }, |
| 57 | + typeArgumentsMapper = { it.toConeTypeProjection() }, |
| 58 | + sourceProvider = { _, type -> type.source }, |
| 59 | + ) |
| 60 | + } |
| 61 | +} |
| 62 | + |
| 63 | +class FirCheckedAnnotationTypeParameterChecker( |
| 64 | + private val ctx: FirCheckersContext, |
| 65 | +) : FirTypeParameterChecker(MppCheckerKind.Common) { |
| 66 | + override fun check( |
| 67 | + declaration: FirTypeParameter, |
| 68 | + context: CheckerContext, |
| 69 | + reporter: DiagnosticReporter, |
| 70 | + ) { |
| 71 | + @Suppress("DuplicatedCode") |
| 72 | + declaration.bounds.forEach { bound -> |
| 73 | + checkTypeArguments( |
| 74 | + context = context, |
| 75 | + reporter = reporter, |
| 76 | + ctx = ctx, |
| 77 | + origin = bound, |
| 78 | + originMapper = { it.coneType }, |
| 79 | + symbolProvider = { it.toClassSymbol(context.session) }, |
| 80 | + typeParameterSymbolsProvider = { it.typeParameterSymbols }, |
| 81 | + typeArgumentsProvider = { it.typeArguments.toList() }, |
| 82 | + typeArgumentsMapper = { it }, |
| 83 | + sourceProvider = { _, _ -> declaration.source }, |
| 84 | + ) |
| 85 | + } |
| 86 | + } |
| 87 | +} |
| 88 | + |
| 89 | +class FirCheckedAnnotationFirClassChecker( |
| 90 | + private val ctx: FirCheckersContext, |
| 91 | +) : FirClassChecker(MppCheckerKind.Common) { |
| 92 | + override fun check( |
| 93 | + declaration: FirClass, |
| 94 | + context: CheckerContext, |
| 95 | + reporter: DiagnosticReporter, |
| 96 | + ) { |
| 97 | + @Suppress("DuplicatedCode") |
| 98 | + declaration.superTypeRefs.forEach { superType -> |
| 99 | + checkTypeArguments( |
| 100 | + context = context, |
| 101 | + reporter = reporter, |
| 102 | + ctx = ctx, |
| 103 | + origin = superType, |
| 104 | + originMapper = { it.coneType }, |
| 105 | + symbolProvider = { it.toClassSymbol(context.session) }, |
| 106 | + typeParameterSymbolsProvider = { it.typeParameterSymbols }, |
| 107 | + typeArgumentsProvider = { it.typeArguments.toList() }, |
| 108 | + typeArgumentsMapper = { it }, |
| 109 | + sourceProvider = { ref, _ -> ref.source }, |
| 110 | + ) |
| 111 | + } |
| 112 | + } |
| 113 | +} |
| 114 | + |
| 115 | +class FirCheckedAnnotationFirFunctionChecker( |
| 116 | + private val ctx: FirCheckersContext, |
| 117 | +) : FirFunctionChecker(MppCheckerKind.Common) { |
| 118 | + override fun check( |
| 119 | + declaration: FirFunction, |
| 120 | + context: CheckerContext, |
| 121 | + reporter: DiagnosticReporter, |
| 122 | + ) { |
| 123 | + declaration.valueParameters.forEach { valueParameter -> |
| 124 | + checkTypeArguments( |
| 125 | + context = context, |
| 126 | + reporter = reporter, |
| 127 | + ctx = ctx, |
| 128 | + origin = valueParameter.returnTypeRef, |
| 129 | + originMapper = { it.coneType }, |
| 130 | + symbolProvider = { it.toClassSymbol(context.session) }, |
| 131 | + typeParameterSymbolsProvider = { it.typeParameterSymbols }, |
| 132 | + typeArgumentsProvider = { it.typeArguments.toList() }, |
| 133 | + typeArgumentsMapper = { it }, |
| 134 | + sourceProvider = { ref, _ -> ref.source }, |
| 135 | + ) |
| 136 | + } |
| 137 | + } |
| 138 | +} |
| 139 | + |
| 140 | +object FirCheckedAnnotationHelper { |
| 141 | + @Suppress("detekt.LongParameterList", "detekt.CyclomaticComplexMethod", "detekt.LongMethod") |
| 142 | + fun <Origin, OriginTransformed, Symbol, TypeArgument> checkTypeArguments( |
| 143 | + context: CheckerContext, |
| 144 | + reporter: DiagnosticReporter, |
| 145 | + ctx: FirCheckersContext, |
| 146 | + origin: Origin, |
| 147 | + originTypeRefSource: FirTypeRefSource? = null, |
| 148 | + originMapper: (Origin) -> OriginTransformed?, |
| 149 | + symbolProvider: (OriginTransformed) -> Symbol?, |
| 150 | + typeParameterSymbolsProvider: (Symbol) -> List<FirTypeParameterSymbol>?, |
| 151 | + typeArgumentsProvider: (OriginTransformed) -> List<TypeArgument>, |
| 152 | + typeArgumentsMapper: (TypeArgument) -> ConeTypeProjection?, |
| 153 | + sourceProvider: (Origin, TypeArgument) -> KtSourceElement?, |
| 154 | + ) { |
| 155 | + val originTransformed = originMapper(origin) ?: return |
| 156 | + val symbol = symbolProvider(originTransformed) ?: return |
| 157 | + |
| 158 | + val parameters = checkedAnnotationsOnTypeParameters( |
| 159 | + session = context.session, |
| 160 | + ctx = ctx, |
| 161 | + typeParameterSymbols = typeParameterSymbolsProvider(symbol), |
| 162 | + ) |
| 163 | + |
| 164 | + val typeArguments = typeArgumentsProvider(originTransformed) |
| 165 | + |
| 166 | + val extractedOriginSource = extractArgumentsTypeRefAndSource( |
| 167 | + typeRef = originTypeRefSource?.typeRef |
| 168 | + ?: (origin as? FirTypeProjectionWithVariance)?.typeRef |
| 169 | + ?: (origin as? FirTypeRef) |
| 170 | + ).orEmpty() |
| 171 | + |
| 172 | + parameters.forEach { (i, annotations) -> |
| 173 | + val typeArgument = typeArguments[i] |
| 174 | + val type = typeArgumentsMapper(typeArgument)?.type |
| 175 | + val classSymbol = type?.toClassSymbol(context.session) |
| 176 | + |
| 177 | + val symbol = when { |
| 178 | + classSymbol != null -> { |
| 179 | + classSymbol |
| 180 | + } |
| 181 | + |
| 182 | + typeArgument is ConeTypeParameterType -> { |
| 183 | + typeArgument.lookupTag.typeParameterSymbol |
| 184 | + } |
| 185 | + |
| 186 | + typeArgument is FirTypeProjectionWithVariance -> { |
| 187 | + (typeArgument.typeRef.coneType as? ConeTypeParameterType) |
| 188 | + ?.lookupTag |
| 189 | + ?.typeParameterSymbol |
| 190 | + } |
| 191 | + |
| 192 | + else -> { |
| 193 | + null |
| 194 | + } |
| 195 | + } ?: return@forEach |
| 196 | + |
| 197 | + val originSource = extractedOriginSource.getOrNull(i)?.source |
| 198 | + ?: originTypeRefSource?.source |
| 199 | + ?: (origin as? FirTypeProjectionWithVariance)?.source |
| 200 | + ?: (origin as? FirTypeRef)?.source |
| 201 | + |
| 202 | + val source = when { |
| 203 | + classSymbol != null -> { |
| 204 | + originSource ?: sourceProvider(origin, typeArgument) |
| 205 | + } |
| 206 | + |
| 207 | + typeArgument is ConeTypeParameterType -> { |
| 208 | + originSource ?: typeArgument.lookupTag.typeParameterSymbol.source |
| 209 | + } |
| 210 | + |
| 211 | + typeArgument is FirTypeProjectionWithVariance -> { |
| 212 | + extractArgumentsTypeRefAndSource(typeArgument.typeRef) |
| 213 | + ?.getOrNull(i) |
| 214 | + ?.source |
| 215 | + ?: typeArgument.source |
| 216 | + } |
| 217 | + |
| 218 | + else -> { |
| 219 | + null |
| 220 | + } |
| 221 | + } |
| 222 | + |
| 223 | + checkSymbolAnnotated( |
| 224 | + annotations = annotations, |
| 225 | + classSymbol = symbol, |
| 226 | + source = source, |
| 227 | + context = context, |
| 228 | + reporter = reporter, |
| 229 | + ) |
| 230 | + } |
| 231 | + |
| 232 | + typeArguments.forEachIndexed { i, typeArgument -> |
| 233 | + val nextOriginSource = (origin as? FirTypeProjectionWithVariance) |
| 234 | + ?.let { FirTypeRefSource(it.typeRef, it.source) } |
| 235 | + ?: (origin as? FirTypeRef)?.let { FirTypeRefSource(it, it.source) } |
| 236 | + |
| 237 | + checkTypeArguments<TypeArgument, ConeKotlinType, FirClassSymbol<*>, ConeTypeProjection>( |
| 238 | + context = context, |
| 239 | + reporter = reporter, |
| 240 | + ctx = ctx, |
| 241 | + origin = typeArgument, |
| 242 | + originMapper = { |
| 243 | + when (typeArgument) { |
| 244 | + is ConeKotlinTypeProjection -> typeArgument.type |
| 245 | + is FirTypeProjectionWithVariance -> typeArgument.toConeTypeProjection() |
| 246 | + else -> null |
| 247 | + }?.type |
| 248 | + }, |
| 249 | + originTypeRefSource = extractedOriginSource.getOrNull(i) ?: nextOriginSource ?: originTypeRefSource, |
| 250 | + symbolProvider = { it.toClassSymbol(context.session) }, |
| 251 | + typeParameterSymbolsProvider = { it.typeParameterSymbols }, |
| 252 | + typeArgumentsProvider = { it.typeArguments.toList() }, |
| 253 | + typeArgumentsMapper = { it }, |
| 254 | + sourceProvider = { arg, _ -> |
| 255 | + when (arg) { |
| 256 | + is FirElement -> arg.source |
| 257 | + is ConeKotlinTypeProjection -> sourceProvider(origin, arg) |
| 258 | + else -> null |
| 259 | + } |
| 260 | + }, |
| 261 | + ) |
| 262 | + } |
| 263 | + } |
| 264 | + |
| 265 | + private fun checkedAnnotationsOnTypeParameters( |
| 266 | + session: FirSession, |
| 267 | + ctx: FirCheckersContext, |
| 268 | + typeParameterSymbols: List<FirTypeParameterSymbol>?, |
| 269 | + ): List<Pair<Int, List<FirClassSymbol<*>>>> { |
| 270 | + return typeParameterSymbols.orEmpty().withIndex().filter { (_, parameter) -> |
| 271 | + session.predicateBasedProvider.matches( |
| 272 | + predicate = FirRpcPredicates.checkedAnnotationMeta, |
| 273 | + declaration = parameter, |
| 274 | + ) |
| 275 | + }.map { (i, parameter) -> |
| 276 | + i to ctx.typeParametersCache.getValue(parameter) |
| 277 | + } |
| 278 | + } |
| 279 | + |
| 280 | + fun checkedAnnotations( |
| 281 | + session: FirSession, |
| 282 | + symbol: FirBasedSymbol<*>, |
| 283 | + visited: Set<FirBasedSymbol<*>> = emptySet(), |
| 284 | + ): List<FirClassSymbol<*>> { |
| 285 | + return symbol.annotations.mapNotNull { |
| 286 | + it.resolvedType.toClassSymbol(session) |
| 287 | + }.filter { annotation -> |
| 288 | + when { |
| 289 | + annotation in visited -> false |
| 290 | + annotation.hasAnnotation(RpcClassId.checkedTypeAnnotation, session) -> true |
| 291 | + else -> checkedAnnotations(session, annotation, visited + annotation).isNotEmpty() |
| 292 | + } |
| 293 | + } |
| 294 | + } |
| 295 | + |
| 296 | + private fun checkSymbolAnnotated( |
| 297 | + annotations: List<FirClassSymbol<*>>, |
| 298 | + classSymbol: FirBasedSymbol<*>, |
| 299 | + source: KtSourceElement?, |
| 300 | + context: CheckerContext, |
| 301 | + reporter: DiagnosticReporter, |
| 302 | + ) { |
| 303 | + for (annotationClass in annotations) { |
| 304 | + val hasCheckedAnnotation = hasCheckedAnnotation( |
| 305 | + session = context.session, |
| 306 | + symbol = classSymbol, |
| 307 | + annotationId = annotationClass.classId, |
| 308 | + ) |
| 309 | + |
| 310 | + if (!hasCheckedAnnotation) { |
| 311 | + reporter.reportOn( |
| 312 | + source = source, |
| 313 | + factory = FirRpcDiagnostics.CHECKED_ANNOTATION_VIOLATION, |
| 314 | + a = annotationClass.defaultType(), |
| 315 | + context = context, |
| 316 | + ) |
| 317 | + } |
| 318 | + } |
| 319 | + } |
| 320 | + |
| 321 | + private fun hasCheckedAnnotation( |
| 322 | + session: FirSession, |
| 323 | + symbol: FirBasedSymbol<*>, |
| 324 | + annotationId: ClassId, |
| 325 | + visited: Set<FirBasedSymbol<*>> = emptySet(), |
| 326 | + ): Boolean { |
| 327 | + return when { |
| 328 | + symbol in visited -> false |
| 329 | + symbol.hasAnnotation(annotationId, session) -> true |
| 330 | + else -> symbol.annotations.any { annotation -> |
| 331 | + annotation.resolvedType.toClassSymbol(session)?.let { |
| 332 | + hasCheckedAnnotation(session, it, annotationId, visited + symbol) |
| 333 | + } == true |
| 334 | + } |
| 335 | + } |
| 336 | + } |
| 337 | +} |
0 commit comments