Skip to content

Commit d5a16af

Browse files
committed
Add option to add discriminator property only to direct children
1 parent cb7bd4c commit d5a16af

File tree

3 files changed

+164
-8
lines changed

3 files changed

+164
-8
lines changed

README.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,10 +234,16 @@ val registration = OpenAPIModelRegistration(
234234
#### sumADTsShape
235235
This config property sets how sum ADTs are registered. It has two possible values:
236236
- `RegistrationConfig.SumADTsShape.WithoutDiscriminator` - default option, doesn't add discriminators
237-
- `RegistrationConfig.SumADTsShape.WithDiscriminator(discriminatorPropertyNameFn)` - adds discriminator to sealed types schema,
237+
- `RegistrationConfig.SumADTsShape.WithDiscriminator(discriminatorPropertyNameFn, addDiscriminatorPropertyOnlyToDirectChildren)` -
238+
adds discriminator to sealed types schema,
238239
and also adds discriminator to sum ADTs elements properties; discriminator property name is customizable by `discriminatorPropertyNameFn`,
239240
by default it takes sealed type name, converts its first letter to lower case, and adds `"Type"` suffix,
240-
for example if sealed type name is `Expression`, the property name is `expressionType`
241+
for example if sealed type name is `Expression`, the property name is `expressionType`;
242+
if `addDiscriminatorPropertyOnlyToDirectChildren` is `false`, discriminator property is added to all children,
243+
so for example in `ADT = A | B | C; B = D | E` discriminator of `ADT` would be added to `A`, `C`, `D`, `E`
244+
(`D` and `E` would have discriminator of `B` in addition to that)
245+
while with `addDiscriminatorPropertyOnlyToDirectChildren` set to `true` (default)
246+
it would be added only to `A` and `C`
241247

242248
## Examples
243249

library/src/main/scala/za/co/absa/springdocopenapiscala/OpenAPIModelRegistration.scala

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ class OpenAPIModelRegistration(
148148
def addDiscriminatorPropertyToChildren(
149149
currentSchema: Schema[_],
150150
discriminatorPropertyName: String,
151+
addOnlyToDirectChildren: Boolean,
151152
discriminatorValue: Option[String] = None
152153
): Unit = {
153154
val children = currentSchema.getOneOf.asScala
@@ -159,10 +160,11 @@ class OpenAPIModelRegistration(
159160
val constEnumSchema = createConstEnumSchema(discriminatorValue.getOrElse(name))
160161
actualSchema.addProperty(discriminatorPropertyName, constEnumSchema)
161162
actualSchema.addRequiredItem(discriminatorPropertyName)
162-
} else if (Option(actualSchema.getOneOf).map(!_.isEmpty).getOrElse(false)) {
163+
} else if (!addOnlyToDirectChildren && Option(actualSchema.getOneOf).map(!_.isEmpty).getOrElse(false)) {
163164
addDiscriminatorPropertyToChildren(
164165
actualSchema,
165166
discriminatorPropertyName,
167+
addOnlyToDirectChildren,
166168
Some(name)
167169
)
168170
}
@@ -177,14 +179,18 @@ class OpenAPIModelRegistration(
177179
schema.setOneOf(childrenSchemas.toList.asJava)
178180

179181
config.sumADTsShape match {
180-
case RegistrationConfig.SumADTsShape.WithDiscriminator(discriminatorPropertyNameFn) =>
182+
case RegistrationConfig.SumADTsShape.WithDiscriminator(discriminatorPropertyNameFn, addOnlyToDirectChildren) =>
181183
val discriminatorPropertyName = discriminatorPropertyNameFn(name)
182184
schema.setDiscriminator {
183185
val discriminator = new Discriminator
184186
discriminator.setPropertyName(discriminatorPropertyName)
185187
discriminator
186188
}
187-
addDiscriminatorPropertyToChildren(schema, discriminatorPropertyName)
189+
addDiscriminatorPropertyToChildren(
190+
schema,
191+
discriminatorPropertyName,
192+
addOnlyToDirectChildren
193+
)
188194

189195
case _ => ()
190196
}
@@ -257,7 +263,8 @@ object OpenAPIModelRegistration {
257263
case object WithoutDiscriminator extends SumADTsShape
258264
case class WithDiscriminator(
259265
discriminatorPropertyNameFn: WithDiscriminator.DiscriminatorPropertyNameFn =
260-
WithDiscriminator.defaultDiscriminatorPropertyNameFn
266+
WithDiscriminator.defaultDiscriminatorPropertyNameFn,
267+
addDiscriminatorPropertyOnlyToDirectChildren: Boolean = true
261268
) extends SumADTsShape
262269

263270
object WithDiscriminator {

library/src/test/scala/za/co/absa/springdocopenapiscala/OpenAPIModelRegistrationSpec.scala

Lines changed: 145 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,13 +350,16 @@ class OpenAPIModelRegistrationSpec extends AnyFlatSpec {
350350
)
351351
}
352352

353-
it should "add discriminator to sealed types and corresponding discriminator property to their children " +
353+
it should "add discriminator to sealed types " +
354+
"and discriminator property to all children (if addDiscriminatorPropertyOnlyToDirectChildren is false)" +
354355
"if SumADTsShape is WithDiscriminator (with default DiscriminatorPropertyNameFn) in the config" in {
355356
val components = new Components
356357
val openAPIModelRegistration = new OpenAPIModelRegistration(
357358
components,
358359
config = OpenAPIModelRegistration.RegistrationConfig(
359-
sumADTsShape = OpenAPIModelRegistration.RegistrationConfig.SumADTsShape.WithDiscriminator()
360+
sumADTsShape = OpenAPIModelRegistration.RegistrationConfig.SumADTsShape.WithDiscriminator(
361+
addDiscriminatorPropertyOnlyToDirectChildren = false
362+
)
360363
)
361364
)
362365

@@ -497,6 +500,146 @@ class OpenAPIModelRegistrationSpec extends AnyFlatSpec {
497500
)
498501
}
499502

503+
it should "add discriminator to sealed types " +
504+
"and discriminator property to direct children (if addDiscriminatorPropertyOnlyToDirectChildren is true)" +
505+
"if SumADTsShape is WithDiscriminator (with default DiscriminatorPropertyNameFn) in the config" in {
506+
val components = new Components
507+
val openAPIModelRegistration = new OpenAPIModelRegistration(
508+
components,
509+
config = OpenAPIModelRegistration.RegistrationConfig(
510+
sumADTsShape = OpenAPIModelRegistration.RegistrationConfig.SumADTsShape.WithDiscriminator(
511+
addDiscriminatorPropertyOnlyToDirectChildren = true
512+
)
513+
)
514+
)
515+
516+
openAPIModelRegistration.register[NestedSealedTrait]()
517+
518+
val actualSchemas = components.getSchemas
519+
520+
assertPredicateForPath(
521+
actualSchemas,
522+
"NestedSealedTrait",
523+
schema => {
524+
val actualDiscriminator = schema.getDiscriminator
525+
val actualOneOf = schema.getOneOf.asScala
526+
val expectedOneOf = Seq(
527+
new Schema().$ref("#/components/schemas/NestedSealedTraitVariant1"),
528+
new Schema().$ref("#/components/schemas/NestedSealedTraitVariant2"),
529+
new Schema().$ref("#/components/schemas/NestedSealedTraitVariant3")
530+
)
531+
532+
val isPropertyNameCorrect = actualDiscriminator.getPropertyName === "nestedSealedTraitType"
533+
val isMappingEmpty = Option(actualDiscriminator.getMapping).map(_.isEmpty).getOrElse(true)
534+
val isOneOfCorrect = actualOneOf === expectedOneOf
535+
536+
isPropertyNameCorrect && isMappingEmpty && isOneOfCorrect
537+
}
538+
)
539+
540+
assertPredicateForPath(
541+
actualSchemas,
542+
"NestedSealedTraitVariant1",
543+
schema => {
544+
val actualProperties = schema.getProperties.asScala
545+
546+
val areNonDiscriminatorPropertiesCorrect = actualProperties.contains("a") && actualProperties.contains("b")
547+
val isDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitType") && {
548+
val s = actualProperties("nestedSealedTraitType")
549+
s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant1")
550+
}
551+
val isDiscriminatorPropertyRequired = schema.getRequired.contains("nestedSealedTraitType")
552+
val isCountOfPropertiesCorrect = actualProperties.size === 3
553+
554+
areNonDiscriminatorPropertiesCorrect &&
555+
isDiscriminatorPropertyCorrect &&
556+
isDiscriminatorPropertyRequired &&
557+
isCountOfPropertiesCorrect
558+
}
559+
)
560+
561+
assertPredicateForPath(
562+
actualSchemas,
563+
"NestedSealedTraitVariant2",
564+
schema => {
565+
val actualProperties = schema.getProperties.asScala
566+
567+
val isDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitType") && {
568+
val s = actualProperties("nestedSealedTraitType")
569+
s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant2")
570+
}
571+
val isDiscriminatorPropertyRequired = schema.getRequired.contains("nestedSealedTraitType")
572+
val isCountOfPropertiesCorrect = actualProperties.size === 1
573+
574+
isDiscriminatorPropertyCorrect && isDiscriminatorPropertyRequired && isCountOfPropertiesCorrect
575+
}
576+
)
577+
578+
assertPredicateForPath(
579+
actualSchemas,
580+
"NestedSealedTraitVariant3",
581+
schema => {
582+
val actualDiscriminator = schema.getDiscriminator
583+
val actualOneOf = schema.getOneOf.asScala
584+
val expectedOneOf = Seq(
585+
new Schema().$ref("#/components/schemas/NestedSealedTraitVariant3Subvariant1"),
586+
new Schema().$ref("#/components/schemas/NestedSealedTraitVariant3Subvariant2")
587+
)
588+
589+
val isPropertyNameCorrect = actualDiscriminator.getPropertyName === "nestedSealedTraitVariant3Type"
590+
val isMappingEmpty = Option(actualDiscriminator.getMapping).map(_.isEmpty).getOrElse(true)
591+
val isOneOfCorrect = actualOneOf === expectedOneOf
592+
val isParentDiscriminatorNotInProperties = Option(schema.getProperties).map(_.size == 0).getOrElse(true)
593+
594+
isPropertyNameCorrect && isMappingEmpty && isOneOfCorrect && isParentDiscriminatorNotInProperties
595+
}
596+
)
597+
598+
assertPredicateForPath(
599+
actualSchemas,
600+
"NestedSealedTraitVariant3Subvariant1",
601+
schema => {
602+
val actualProperties = schema.getProperties.asScala
603+
604+
val areNonDiscriminatorPropertiesCorrect = actualProperties.contains("a")
605+
val isDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitVariant3Type") && {
606+
val s = actualProperties("nestedSealedTraitVariant3Type")
607+
s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant3Subvariant1")
608+
}
609+
val isParentDiscriminatorNotInProperties = !actualProperties.contains("nestedSealedTraitType")
610+
val isDiscriminatorPropertyRequired = schema.getRequired.contains("nestedSealedTraitVariant3Type")
611+
val isCountOfPropertiesCorrect = actualProperties.size === 2
612+
613+
areNonDiscriminatorPropertiesCorrect &&
614+
isDiscriminatorPropertyCorrect &&
615+
isParentDiscriminatorNotInProperties &&
616+
isDiscriminatorPropertyRequired &&
617+
isCountOfPropertiesCorrect
618+
}
619+
)
620+
621+
assertPredicateForPath(
622+
actualSchemas,
623+
"NestedSealedTraitVariant3Subvariant2",
624+
schema => {
625+
val actualProperties = schema.getProperties.asScala
626+
627+
val isDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitVariant3Type") && {
628+
val s = actualProperties("nestedSealedTraitVariant3Type")
629+
s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant3Subvariant2")
630+
}
631+
val isParentDiscriminatorNotInProperties = !actualProperties.contains("nestedSealedTraitType")
632+
val isDiscriminatorPropertyRequired = schema.getRequired.contains("nestedSealedTraitVariant3Type")
633+
val isCountOfPropertiesCorrect = actualProperties.size === 1
634+
635+
isDiscriminatorPropertyCorrect &&
636+
isParentDiscriminatorNotInProperties &&
637+
isCountOfPropertiesCorrect &&
638+
isDiscriminatorPropertyRequired
639+
}
640+
)
641+
}
642+
500643
it should "not fail for empty sealed trait" in {
501644
val components = new Components
502645
val openAPIModelRegistration = new OpenAPIModelRegistration(components)

0 commit comments

Comments
 (0)