|
17 | 17 | package za.co.absa.springdocopenapiscala
|
18 | 18 |
|
19 | 19 | import io.swagger.v3.oas.models.Components
|
20 |
| -import io.swagger.v3.oas.models.media.Schema |
| 20 | +import io.swagger.v3.oas.models.media.{Discriminator, Schema} |
21 | 21 |
|
22 | 22 | import java.time.{Instant, LocalDate, LocalDateTime, ZonedDateTime}
|
23 | 23 | import java.util.UUID
|
24 | 24 | import scala.annotation.tailrec
|
25 | 25 | import scala.collection.JavaConverters._
|
26 | 26 | import scala.reflect.runtime.universe._
|
27 |
| - |
28 | 27 | import OpenAPIModelRegistration._
|
29 | 28 |
|
30 | 29 | class OpenAPIModelRegistration(
|
31 | 30 | components: Components,
|
32 |
| - extraTypesHandler: ExtraTypesHandling.ExtraTypesHandler = ExtraTypesHandling.noExtraHandling |
| 31 | + extraTypesHandler: ExtraTypesHandling.ExtraTypesHandler = ExtraTypesHandling.noExtraHandling, |
| 32 | + config: RegistrationConfig = RegistrationConfig() |
33 | 33 | ) {
|
34 | 34 |
|
35 | 35 | /**
|
@@ -144,13 +144,60 @@ class OpenAPIModelRegistration(
|
144 | 144 | s.isTerm && s.asTerm.isVal && s.typeSignature <:< typeOf[Enumeration#Value]
|
145 | 145 |
|
146 | 146 | private def handleSealedType(tpe: Type): Schema[_] = {
|
| 147 | + |
| 148 | + def addDiscriminatorPropertyToChildren( |
| 149 | + currentSchema: Schema[_], |
| 150 | + discriminatorPropertyName: String, |
| 151 | + addOnlyToDirectChildren: Boolean, |
| 152 | + discriminatorValue: Option[String] = None |
| 153 | + ): Unit = { |
| 154 | + val children = currentSchema.getOneOf.asScala |
| 155 | + children.foreach { s => |
| 156 | + val ref = s.get$ref |
| 157 | + val name = extractSchemaNameFromRef(ref) |
| 158 | + val actualSchema = components.getSchemas.get(name) |
| 159 | + if (actualSchema.getType == "object") { |
| 160 | + val constEnumSchema = createConstEnumSchema(discriminatorValue.getOrElse(name)) |
| 161 | + actualSchema.addProperty(discriminatorPropertyName, constEnumSchema) |
| 162 | + actualSchema.addRequiredItem(discriminatorPropertyName) |
| 163 | + } else if ( |
| 164 | + !addOnlyToDirectChildren && |
| 165 | + Option(actualSchema.getOneOf).map(!_.isEmpty).getOrElse(false) // is schema representing another sum ADT root |
| 166 | + ) { |
| 167 | + addDiscriminatorPropertyToChildren( |
| 168 | + actualSchema, |
| 169 | + discriminatorPropertyName, |
| 170 | + addOnlyToDirectChildren, |
| 171 | + Some(name) |
| 172 | + ) |
| 173 | + } |
| 174 | + } |
| 175 | + } |
| 176 | + |
147 | 177 | val classSymbol = tpe.typeSymbol.asClass
|
148 | 178 | val name = tpe.typeSymbol.name.toString.trim
|
149 | 179 | val children = classSymbol.knownDirectSubclasses
|
150 | 180 | val childrenSchemas = children.map(_.asType.toType).map(handleType)
|
151 | 181 | val schema = new Schema
|
152 | 182 | schema.setOneOf(childrenSchemas.toList.asJava)
|
153 | 183 |
|
| 184 | + config.sumADTsShape match { |
| 185 | + case RegistrationConfig.SumADTsShape.WithDiscriminator(discriminatorPropertyNameFn, addOnlyToDirectChildren) => |
| 186 | + val discriminatorPropertyName = discriminatorPropertyNameFn(name) |
| 187 | + schema.setDiscriminator { |
| 188 | + val discriminator = new Discriminator |
| 189 | + discriminator.setPropertyName(discriminatorPropertyName) |
| 190 | + discriminator |
| 191 | + } |
| 192 | + addDiscriminatorPropertyToChildren( |
| 193 | + schema, |
| 194 | + discriminatorPropertyName, |
| 195 | + addOnlyToDirectChildren |
| 196 | + ) |
| 197 | + |
| 198 | + case _ => () |
| 199 | + } |
| 200 | + |
154 | 201 | registerAsReference(name, schema)
|
155 | 202 | }
|
156 | 203 |
|
@@ -187,10 +234,54 @@ class OpenAPIModelRegistration(
|
187 | 234 | schemaReference
|
188 | 235 | }
|
189 | 236 |
|
| 237 | + private def createConstEnumSchema(const: String): Schema[_] = { |
| 238 | + val constEnumSchema = new Schema[String] |
| 239 | + constEnumSchema.setType("string") |
| 240 | + constEnumSchema.setEnum(Seq(const).asJava) |
| 241 | + constEnumSchema |
| 242 | + } |
| 243 | + |
| 244 | + private def extractSchemaNameFromRef(ref: String): String = { |
| 245 | + ref.substring(ref.lastIndexOf("/") + 1) |
| 246 | + } |
| 247 | + |
190 | 248 | }
|
191 | 249 |
|
192 | 250 | object OpenAPIModelRegistration {
|
193 | 251 |
|
| 252 | + /** |
| 253 | + * Configuration of the registration class. |
| 254 | + * |
| 255 | + * @param sumADTsShape how sum ADTs should be registered (with or without discriminator) |
| 256 | + */ |
| 257 | + case class RegistrationConfig( |
| 258 | + sumADTsShape: RegistrationConfig.SumADTsShape = RegistrationConfig.SumADTsShape.WithoutDiscriminator |
| 259 | + ) |
| 260 | + |
| 261 | + object RegistrationConfig { |
| 262 | + |
| 263 | + sealed abstract class SumADTsShape |
| 264 | + |
| 265 | + object SumADTsShape { |
| 266 | + case object WithoutDiscriminator extends SumADTsShape |
| 267 | + case class WithDiscriminator( |
| 268 | + discriminatorPropertyNameFn: WithDiscriminator.DiscriminatorPropertyNameFn = |
| 269 | + WithDiscriminator.defaultDiscriminatorPropertyNameFn, |
| 270 | + addDiscriminatorPropertyOnlyToDirectChildren: Boolean = true |
| 271 | + ) extends SumADTsShape |
| 272 | + |
| 273 | + object WithDiscriminator { |
| 274 | + |
| 275 | + /** Function from sealed type name to discriminator property name. */ |
| 276 | + type DiscriminatorPropertyNameFn = String => String |
| 277 | + |
| 278 | + val defaultDiscriminatorPropertyNameFn: DiscriminatorPropertyNameFn = sealedTypeName => |
| 279 | + sealedTypeName.head.toLower + sealedTypeName.tail + "Type" |
| 280 | + } |
| 281 | + } |
| 282 | + |
| 283 | + } |
| 284 | + |
194 | 285 | /**
|
195 | 286 | * Context of model registration.
|
196 | 287 | * Currently contains only `Components` that can be mutated if needed
|
|
0 commit comments