diff --git a/README.md b/README.md index 22587fc..facccc9 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ This library aims to avoid pollution of the model by custom annotations and depe - support for basic Scala collections (`Map`, `Seq`, `Set`, `Array`) as types of `case class` parameters - only top-level case classes need to be registered, child case classes are then recursively registered - support for Scala `Enumeration` where simple `Value` constructor is used (without `name`) -- support for sum ADTs (`sealed trait` and `sealed abstract class`) +- support for sum ADTs (`sealed trait` and `sealed abstract class`) with optional discriminator ## Usage @@ -215,6 +215,37 @@ Then, in `handleFn`, the handler creates a `Schema` object for `CustomClass`, adds it to `Components` so that it can be referenced by name `CustomClass`, and returns reference to that object. +### Registration configuration +It is possible to further customize registration by providing custom `RegistrationConfig` to `OpenAPIModelRegistration`. + +#### Example +```scala +val components = ... +val registration = OpenAPIModelRegistration( + components, + config = RegistrationConfig( + OpenAPIModelRegistration.RegistrationConfig( + sumADTsShape = + // default values apply for discriminatorPropertyNameFn, addDiscriminatorPropertyOnlyToDirectChildren + OpenAPIModelRegistration.RegistrationConfig.SumADTsShape.WithDiscriminator() + ) + ) +) +``` + +#### sumADTsShape +This config property sets how sum ADTs are registered. It has two possible values: +- `RegistrationConfig.SumADTsShape.WithoutDiscriminator` - default option, doesn't add discriminators +- `RegistrationConfig.SumADTsShape.WithDiscriminator(discriminatorPropertyNameFn, addDiscriminatorPropertyOnlyToDirectChildren)` - + adds discriminator to sealed types schema, + and also adds discriminator to sum ADTs elements properties; discriminator property name is customizable by `discriminatorPropertyNameFn`, + by default it takes sealed type name, converts its first letter to lower case, and adds `"Type"` suffix, + for example if sealed type name is `Expression`, the property name is `expressionType`; + if `addDiscriminatorPropertyOnlyToDirectChildren` is `false`, discriminator property is added to all children, + so for example in `ADT = A | B | C; B = D | E` discriminator of `ADT` would be added to `A`, `C`, `D`, `E` + (`D` and `E` would have discriminator of `B` in addition to that) + while with `addDiscriminatorPropertyOnlyToDirectChildren` set to `true` (default) + it would be added only to `A` and `C` ## Examples diff --git a/library/src/main/scala/za/co/absa/springdocopenapiscala/OpenAPIModelRegistration.scala b/library/src/main/scala/za/co/absa/springdocopenapiscala/OpenAPIModelRegistration.scala index af01a4b..8419aa3 100644 --- a/library/src/main/scala/za/co/absa/springdocopenapiscala/OpenAPIModelRegistration.scala +++ b/library/src/main/scala/za/co/absa/springdocopenapiscala/OpenAPIModelRegistration.scala @@ -17,19 +17,19 @@ package za.co.absa.springdocopenapiscala import io.swagger.v3.oas.models.Components -import io.swagger.v3.oas.models.media.Schema +import io.swagger.v3.oas.models.media.{Discriminator, Schema} import java.time.{Instant, LocalDate, LocalDateTime, ZonedDateTime} import java.util.UUID import scala.annotation.tailrec import scala.collection.JavaConverters._ import scala.reflect.runtime.universe._ - import OpenAPIModelRegistration._ class OpenAPIModelRegistration( components: Components, - extraTypesHandler: ExtraTypesHandling.ExtraTypesHandler = ExtraTypesHandling.noExtraHandling + extraTypesHandler: ExtraTypesHandling.ExtraTypesHandler = ExtraTypesHandling.noExtraHandling, + config: RegistrationConfig = RegistrationConfig() ) { /** @@ -144,6 +144,36 @@ class OpenAPIModelRegistration( s.isTerm && s.asTerm.isVal && s.typeSignature <:< typeOf[Enumeration#Value] private def handleSealedType(tpe: Type): Schema[_] = { + + def addDiscriminatorPropertyToChildren( + currentSchema: Schema[_], + discriminatorPropertyName: String, + addOnlyToDirectChildren: Boolean, + discriminatorValue: Option[String] = None + ): Unit = { + val children = currentSchema.getOneOf.asScala + children.foreach { s => + val ref = s.get$ref + val name = extractSchemaNameFromRef(ref) + val actualSchema = components.getSchemas.get(name) + if (actualSchema.getType == "object") { + val constEnumSchema = createConstEnumSchema(discriminatorValue.getOrElse(name)) + actualSchema.addProperty(discriminatorPropertyName, constEnumSchema) + actualSchema.addRequiredItem(discriminatorPropertyName) + } else if ( + !addOnlyToDirectChildren && + Option(actualSchema.getOneOf).map(!_.isEmpty).getOrElse(false) // is schema representing another sum ADT root + ) { + addDiscriminatorPropertyToChildren( + actualSchema, + discriminatorPropertyName, + addOnlyToDirectChildren, + Some(name) + ) + } + } + } + val classSymbol = tpe.typeSymbol.asClass val name = tpe.typeSymbol.name.toString.trim val children = classSymbol.knownDirectSubclasses @@ -151,6 +181,23 @@ class OpenAPIModelRegistration( val schema = new Schema schema.setOneOf(childrenSchemas.toList.asJava) + config.sumADTsShape match { + case RegistrationConfig.SumADTsShape.WithDiscriminator(discriminatorPropertyNameFn, addOnlyToDirectChildren) => + val discriminatorPropertyName = discriminatorPropertyNameFn(name) + schema.setDiscriminator { + val discriminator = new Discriminator + discriminator.setPropertyName(discriminatorPropertyName) + discriminator + } + addDiscriminatorPropertyToChildren( + schema, + discriminatorPropertyName, + addOnlyToDirectChildren + ) + + case _ => () + } + registerAsReference(name, schema) } @@ -187,10 +234,54 @@ class OpenAPIModelRegistration( schemaReference } + private def createConstEnumSchema(const: String): Schema[_] = { + val constEnumSchema = new Schema[String] + constEnumSchema.setType("string") + constEnumSchema.setEnum(Seq(const).asJava) + constEnumSchema + } + + private def extractSchemaNameFromRef(ref: String): String = { + ref.substring(ref.lastIndexOf("/") + 1) + } + } object OpenAPIModelRegistration { + /** + * Configuration of the registration class. + * + * @param sumADTsShape how sum ADTs should be registered (with or without discriminator) + */ + case class RegistrationConfig( + sumADTsShape: RegistrationConfig.SumADTsShape = RegistrationConfig.SumADTsShape.WithoutDiscriminator + ) + + object RegistrationConfig { + + sealed abstract class SumADTsShape + + object SumADTsShape { + case object WithoutDiscriminator extends SumADTsShape + case class WithDiscriminator( + discriminatorPropertyNameFn: WithDiscriminator.DiscriminatorPropertyNameFn = + WithDiscriminator.defaultDiscriminatorPropertyNameFn, + addDiscriminatorPropertyOnlyToDirectChildren: Boolean = true + ) extends SumADTsShape + + object WithDiscriminator { + + /** Function from sealed type name to discriminator property name. */ + type DiscriminatorPropertyNameFn = String => String + + val defaultDiscriminatorPropertyNameFn: DiscriminatorPropertyNameFn = sealedTypeName => + sealedTypeName.head.toLower + sealedTypeName.tail + "Type" + } + } + + } + /** * Context of model registration. * Currently contains only `Components` that can be mutated if needed diff --git a/library/src/test/scala/za/co/absa/springdocopenapiscala/OpenAPIModelRegistrationSpec.scala b/library/src/test/scala/za/co/absa/springdocopenapiscala/OpenAPIModelRegistrationSpec.scala index a816fd9..c8485b8 100644 --- a/library/src/test/scala/za/co/absa/springdocopenapiscala/OpenAPIModelRegistrationSpec.scala +++ b/library/src/test/scala/za/co/absa/springdocopenapiscala/OpenAPIModelRegistrationSpec.scala @@ -132,6 +132,21 @@ class OpenAPIModelRegistrationSpec extends AnyFlatSpec { private case class EmptySealedTraitClass(a: EmptySealedTrait) + private sealed trait NestedSealedTrait + + private object NestedSealedTrait { + case class NestedSealedTraitVariant1(a: String, b: Int) extends NestedSealedTrait + case object NestedSealedTraitVariant2 extends NestedSealedTrait + + sealed abstract class NestedSealedTraitVariant3 extends NestedSealedTrait + + object NestedSealedTraitVariant3 { + case class NestedSealedTraitVariant3Subvariant1(a: Float) extends NestedSealedTraitVariant3 + case object NestedSealedTraitVariant3Subvariant2 extends NestedSealedTraitVariant3 + } + + } + private case class CustomClassComplexChild(a: Option[Int]) private class CustomClass(val complexChild: CustomClassComplexChild) { @@ -335,6 +350,296 @@ class OpenAPIModelRegistrationSpec extends AnyFlatSpec { ) } + it should "add discriminator to sealed types " + + "and discriminator property to all children (if addDiscriminatorPropertyOnlyToDirectChildren is false)" + + "if SumADTsShape is WithDiscriminator (with default DiscriminatorPropertyNameFn) in the config" in { + val components = new Components + val openAPIModelRegistration = new OpenAPIModelRegistration( + components, + config = OpenAPIModelRegistration.RegistrationConfig( + sumADTsShape = OpenAPIModelRegistration.RegistrationConfig.SumADTsShape.WithDiscriminator( + addDiscriminatorPropertyOnlyToDirectChildren = false + ) + ) + ) + + openAPIModelRegistration.register[NestedSealedTrait]() + + val actualSchemas = components.getSchemas + + assertPredicateForPath( + actualSchemas, + "NestedSealedTrait", + schema => { + val actualDiscriminator = schema.getDiscriminator + val actualOneOf = schema.getOneOf.asScala + val expectedOneOf = Seq( + new Schema().$ref("#/components/schemas/NestedSealedTraitVariant1"), + new Schema().$ref("#/components/schemas/NestedSealedTraitVariant2"), + new Schema().$ref("#/components/schemas/NestedSealedTraitVariant3") + ) + + val isPropertyNameCorrect = actualDiscriminator.getPropertyName === "nestedSealedTraitType" + val isMappingEmpty = Option(actualDiscriminator.getMapping).map(_.isEmpty).getOrElse(true) + val isOneOfCorrect = actualOneOf === expectedOneOf + + isPropertyNameCorrect && isMappingEmpty && isOneOfCorrect + } + ) + + assertPredicateForPath( + actualSchemas, + "NestedSealedTraitVariant1", + schema => { + val actualProperties = schema.getProperties.asScala + + val areNonDiscriminatorPropertiesCorrect = actualProperties.contains("a") && actualProperties.contains("b") + val isDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitType") && { + val s = actualProperties("nestedSealedTraitType") + s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant1") + } + val isDiscriminatorPropertyRequired = schema.getRequired.contains("nestedSealedTraitType") + val isCountOfPropertiesCorrect = actualProperties.size === 3 + + areNonDiscriminatorPropertiesCorrect && + isDiscriminatorPropertyCorrect && + isDiscriminatorPropertyRequired && + isCountOfPropertiesCorrect + } + ) + + assertPredicateForPath( + actualSchemas, + "NestedSealedTraitVariant2", + schema => { + val actualProperties = schema.getProperties.asScala + + val isDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitType") && { + val s = actualProperties("nestedSealedTraitType") + s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant2") + } + val isDiscriminatorPropertyRequired = schema.getRequired.contains("nestedSealedTraitType") + val isCountOfPropertiesCorrect = actualProperties.size === 1 + + isDiscriminatorPropertyCorrect && isDiscriminatorPropertyRequired && isCountOfPropertiesCorrect + } + ) + + assertPredicateForPath( + actualSchemas, + "NestedSealedTraitVariant3", + schema => { + val actualDiscriminator = schema.getDiscriminator + val actualOneOf = schema.getOneOf.asScala + val expectedOneOf = Seq( + new Schema().$ref("#/components/schemas/NestedSealedTraitVariant3Subvariant1"), + new Schema().$ref("#/components/schemas/NestedSealedTraitVariant3Subvariant2") + ) + + val isPropertyNameCorrect = actualDiscriminator.getPropertyName === "nestedSealedTraitVariant3Type" + val isMappingEmpty = Option(actualDiscriminator.getMapping).map(_.isEmpty).getOrElse(true) + val isOneOfCorrect = actualOneOf === expectedOneOf + val isParentDiscriminatorNotInProperties = Option(schema.getProperties).map(_.size == 0).getOrElse(true) + + isPropertyNameCorrect && isMappingEmpty && isOneOfCorrect && isParentDiscriminatorNotInProperties + } + ) + + assertPredicateForPath( + actualSchemas, + "NestedSealedTraitVariant3Subvariant1", + schema => { + val actualProperties = schema.getProperties.asScala + + val areNonDiscriminatorPropertiesCorrect = actualProperties.contains("a") + val isDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitVariant3Type") && { + val s = actualProperties("nestedSealedTraitVariant3Type") + s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant3Subvariant1") + } + val isParentDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitType") && { + val s = actualProperties("nestedSealedTraitType") + s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant3") + } + val areDiscriminatorPropertyRequired = schema.getRequired.contains( + "nestedSealedTraitVariant3Type" + ) && schema.getRequired.contains("nestedSealedTraitType") + val isCountOfPropertiesCorrect = actualProperties.size === 3 + + areNonDiscriminatorPropertiesCorrect && + isDiscriminatorPropertyCorrect && + isParentDiscriminatorPropertyCorrect && + areDiscriminatorPropertyRequired && + isCountOfPropertiesCorrect + } + ) + + assertPredicateForPath( + actualSchemas, + "NestedSealedTraitVariant3Subvariant2", + schema => { + val actualProperties = schema.getProperties.asScala + + val isDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitVariant3Type") && { + val s = actualProperties("nestedSealedTraitVariant3Type") + s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant3Subvariant2") + } + val isParentDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitType") && { + val s = actualProperties("nestedSealedTraitType") + s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant3") + } + val areDiscriminatorPropertyRequired = schema.getRequired.contains( + "nestedSealedTraitVariant3Type" + ) && schema.getRequired.contains("nestedSealedTraitType") + val isCountOfPropertiesCorrect = actualProperties.size === 2 + + isDiscriminatorPropertyCorrect && + isParentDiscriminatorPropertyCorrect && + isCountOfPropertiesCorrect && + areDiscriminatorPropertyRequired + } + ) + } + + it should "add discriminator to sealed types " + + "and discriminator property to direct children (if addDiscriminatorPropertyOnlyToDirectChildren is true)" + + "if SumADTsShape is WithDiscriminator (with default DiscriminatorPropertyNameFn) in the config" in { + val components = new Components + val openAPIModelRegistration = new OpenAPIModelRegistration( + components, + config = OpenAPIModelRegistration.RegistrationConfig( + sumADTsShape = OpenAPIModelRegistration.RegistrationConfig.SumADTsShape.WithDiscriminator( + addDiscriminatorPropertyOnlyToDirectChildren = true + ) + ) + ) + + openAPIModelRegistration.register[NestedSealedTrait]() + + val actualSchemas = components.getSchemas + + assertPredicateForPath( + actualSchemas, + "NestedSealedTrait", + schema => { + val actualDiscriminator = schema.getDiscriminator + val actualOneOf = schema.getOneOf.asScala + val expectedOneOf = Seq( + new Schema().$ref("#/components/schemas/NestedSealedTraitVariant1"), + new Schema().$ref("#/components/schemas/NestedSealedTraitVariant2"), + new Schema().$ref("#/components/schemas/NestedSealedTraitVariant3") + ) + + val isPropertyNameCorrect = actualDiscriminator.getPropertyName === "nestedSealedTraitType" + val isMappingEmpty = Option(actualDiscriminator.getMapping).map(_.isEmpty).getOrElse(true) + val isOneOfCorrect = actualOneOf === expectedOneOf + + isPropertyNameCorrect && isMappingEmpty && isOneOfCorrect + } + ) + + assertPredicateForPath( + actualSchemas, + "NestedSealedTraitVariant1", + schema => { + val actualProperties = schema.getProperties.asScala + + val areNonDiscriminatorPropertiesCorrect = actualProperties.contains("a") && actualProperties.contains("b") + val isDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitType") && { + val s = actualProperties("nestedSealedTraitType") + s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant1") + } + val isDiscriminatorPropertyRequired = schema.getRequired.contains("nestedSealedTraitType") + val isCountOfPropertiesCorrect = actualProperties.size === 3 + + areNonDiscriminatorPropertiesCorrect && + isDiscriminatorPropertyCorrect && + isDiscriminatorPropertyRequired && + isCountOfPropertiesCorrect + } + ) + + assertPredicateForPath( + actualSchemas, + "NestedSealedTraitVariant2", + schema => { + val actualProperties = schema.getProperties.asScala + + val isDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitType") && { + val s = actualProperties("nestedSealedTraitType") + s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant2") + } + val isDiscriminatorPropertyRequired = schema.getRequired.contains("nestedSealedTraitType") + val isCountOfPropertiesCorrect = actualProperties.size === 1 + + isDiscriminatorPropertyCorrect && isDiscriminatorPropertyRequired && isCountOfPropertiesCorrect + } + ) + + assertPredicateForPath( + actualSchemas, + "NestedSealedTraitVariant3", + schema => { + val actualDiscriminator = schema.getDiscriminator + val actualOneOf = schema.getOneOf.asScala + val expectedOneOf = Seq( + new Schema().$ref("#/components/schemas/NestedSealedTraitVariant3Subvariant1"), + new Schema().$ref("#/components/schemas/NestedSealedTraitVariant3Subvariant2") + ) + + val isPropertyNameCorrect = actualDiscriminator.getPropertyName === "nestedSealedTraitVariant3Type" + val isMappingEmpty = Option(actualDiscriminator.getMapping).map(_.isEmpty).getOrElse(true) + val isOneOfCorrect = actualOneOf === expectedOneOf + val isParentDiscriminatorNotInProperties = Option(schema.getProperties).map(_.size == 0).getOrElse(true) + + isPropertyNameCorrect && isMappingEmpty && isOneOfCorrect && isParentDiscriminatorNotInProperties + } + ) + + assertPredicateForPath( + actualSchemas, + "NestedSealedTraitVariant3Subvariant1", + schema => { + val actualProperties = schema.getProperties.asScala + + val areNonDiscriminatorPropertiesCorrect = actualProperties.contains("a") + val isDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitVariant3Type") && { + val s = actualProperties("nestedSealedTraitVariant3Type") + s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant3Subvariant1") + } + val isParentDiscriminatorNotInProperties = !actualProperties.contains("nestedSealedTraitType") + val isDiscriminatorPropertyRequired = schema.getRequired.contains("nestedSealedTraitVariant3Type") + val isCountOfPropertiesCorrect = actualProperties.size === 2 + + areNonDiscriminatorPropertiesCorrect && + isDiscriminatorPropertyCorrect && + isParentDiscriminatorNotInProperties && + isDiscriminatorPropertyRequired && + isCountOfPropertiesCorrect + } + ) + + assertPredicateForPath( + actualSchemas, + "NestedSealedTraitVariant3Subvariant2", + schema => { + val actualProperties = schema.getProperties.asScala + + val isDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitVariant3Type") && { + val s = actualProperties("nestedSealedTraitVariant3Type") + s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant3Subvariant2") + } + val isParentDiscriminatorNotInProperties = !actualProperties.contains("nestedSealedTraitType") + val isDiscriminatorPropertyRequired = schema.getRequired.contains("nestedSealedTraitVariant3Type") + val isCountOfPropertiesCorrect = actualProperties.size === 1 + + isDiscriminatorPropertyCorrect && + isParentDiscriminatorNotInProperties && + isCountOfPropertiesCorrect && + isDiscriminatorPropertyRequired + } + ) + } + it should "not fail for empty sealed trait" in { val components = new Components val openAPIModelRegistration = new OpenAPIModelRegistration(components) @@ -389,7 +694,10 @@ class OpenAPIModelRegistrationSpec extends AnyFlatSpec { } (Set.empty, handleFn) } - val openAPIModelRegistration = new OpenAPIModelRegistration(components, extraTypesHandler) + val openAPIModelRegistration = new OpenAPIModelRegistration( + components, + extraTypesHandler = extraTypesHandler + ) openAPIModelRegistration.register[ForCustomHandling]() @@ -440,7 +748,10 @@ class OpenAPIModelRegistrationSpec extends AnyFlatSpec { schema.setFormat("my-custom-format") schema } - val openAPIModelRegistration = new OpenAPIModelRegistration(components, extraTypesHandler) + val openAPIModelRegistration = new OpenAPIModelRegistration( + components, + extraTypesHandler = extraTypesHandler + ) openAPIModelRegistration.register[SimpleTypesMaybeInOption]()