Skip to content

Commit 203d9a2

Browse files
authored
#50 Fix StackOverflow when processing recursive ADTs (#52)
1 parent c4cc3e4 commit 203d9a2

File tree

2 files changed

+60
-33
lines changed

2 files changed

+60
-33
lines changed

library/src/main/scala/za/co/absa/springdocopenapiscala/OpenAPIModelRegistration.scala

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ class OpenAPIModelRegistration(
149149
currentSchema: Schema[_],
150150
discriminatorPropertyName: String,
151151
addOnlyToDirectChildren: Boolean,
152-
discriminatorValue: Option[String] = None
152+
discriminatorValue: Option[String] = None,
153+
seen: Set[String] = Set.empty
153154
): Unit = {
154155
val children = currentSchema.getOneOf.asScala
155156
children.foreach { s =>
@@ -162,43 +163,59 @@ class OpenAPIModelRegistration(
162163
actualSchema.addRequiredItem(discriminatorPropertyName)
163164
} else if (
164165
!addOnlyToDirectChildren &&
166+
!seen.contains(name) &&
165167
Option(actualSchema.getOneOf).map(!_.isEmpty).getOrElse(false) // is schema representing another sum ADT root
166168
) {
167169
addDiscriminatorPropertyToChildren(
168170
actualSchema,
169171
discriminatorPropertyName,
170172
addOnlyToDirectChildren,
171-
Some(name)
173+
Some(name),
174+
seen + name
172175
)
173176
}
174177
}
175178
}
176179

177180
val classSymbol = tpe.typeSymbol.asClass
178181
val name = tpe.typeSymbol.name.toString.trim
179-
val children = classSymbol.knownDirectSubclasses
180-
val childrenSchemas = children.map(_.asType.toType).map(handleType)
181-
val schema = new Schema
182-
schema.setOneOf(childrenSchemas.toList.asJava)
183-
184-
config.sumADTsShape match {
185-
case RegistrationConfig.SumADTsShape.WithDiscriminator(discriminatorPropertyNameFn, addOnlyToDirectChildren) =>
186-
val discriminatorPropertyName = discriminatorPropertyNameFn(name)
187-
schema.setDiscriminator {
188-
val discriminator = new Discriminator
189-
discriminator.setPropertyName(discriminatorPropertyName)
190-
discriminator
191-
}
192-
addDiscriminatorPropertyToChildren(
193-
schema,
194-
discriminatorPropertyName,
195-
addOnlyToDirectChildren
196-
)
197182

198-
case _ => ()
199-
}
183+
// in case of recursive ADT, it might already have been processed, thus we should skip
184+
val wasAlreadyProcessed = Option(components.getSchemas).map(_.containsKey(name)).getOrElse(false)
185+
186+
if (wasAlreadyProcessed) {
187+
(new Schema).$ref(s"#/components/schemas/$name")
188+
} else {
189+
val children = classSymbol.knownDirectSubclasses
190+
// we can assume that all sum ADT direct children are registered as reference, as these can be:
191+
// - case classes = registered as reference
192+
// - case objects = registered as reference
193+
// - sealed trait/abstract class = registered as reference
194+
val childrenRefs = children.map(s => (new Schema).$ref(s.name.toString.trim)).toSeq
195+
val schema = new Schema
196+
schema.setOneOf(childrenRefs.asJava)
197+
val schemaRef = registerAsReference(name, schema)
198+
children.map(_.asType.toType).foreach(handleType)
199+
200+
config.sumADTsShape match {
201+
case RegistrationConfig.SumADTsShape.WithDiscriminator(discriminatorPropertyNameFn, addOnlyToDirectChildren) =>
202+
val discriminatorPropertyName = discriminatorPropertyNameFn(name)
203+
schema.setDiscriminator {
204+
val discriminator = new Discriminator
205+
discriminator.setPropertyName(discriminatorPropertyName)
206+
discriminator
207+
}
208+
addDiscriminatorPropertyToChildren(
209+
schema,
210+
discriminatorPropertyName,
211+
addOnlyToDirectChildren
212+
)
200213

201-
registerAsReference(name, schema)
214+
case _ => ()
215+
}
216+
217+
schemaRef
218+
}
202219
}
203220

204221
private def handleSimpleType(tpe: Type): Schema[_] = {

library/src/test/scala/za/co/absa/springdocopenapiscala/OpenAPIModelRegistrationSpec.scala

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,13 @@ class OpenAPIModelRegistrationSpec extends AnyFlatSpec {
135135
private sealed trait NestedSealedTrait
136136

137137
private object NestedSealedTrait {
138-
case class NestedSealedTraitVariant1(a: String, b: Int) extends NestedSealedTrait
138+
case class NestedSealedTraitVariant1(a: String, b: Int, c: NestedSealedTrait) extends NestedSealedTrait
139139
case object NestedSealedTraitVariant2 extends NestedSealedTrait
140140

141141
sealed abstract class NestedSealedTraitVariant3 extends NestedSealedTrait
142142

143143
object NestedSealedTraitVariant3 {
144-
case class NestedSealedTraitVariant3Subvariant1(a: Float) extends NestedSealedTraitVariant3
144+
case class NestedSealedTraitVariant3Subvariant1(a: Float, b: NestedSealedTrait) extends NestedSealedTraitVariant3
145145
case object NestedSealedTraitVariant3Subvariant2 extends NestedSealedTraitVariant3
146146
}
147147

@@ -393,13 +393,16 @@ class OpenAPIModelRegistrationSpec extends AnyFlatSpec {
393393
schema => {
394394
val actualProperties = schema.getProperties.asScala
395395

396-
val areNonDiscriminatorPropertiesCorrect = actualProperties.contains("a") && actualProperties.contains("b")
396+
val areNonDiscriminatorPropertiesCorrect = actualProperties.contains("a") &&
397+
actualProperties.contains("b") &&
398+
actualProperties.contains("c") &&
399+
actualProperties("c").get$ref === "#/components/schemas/NestedSealedTrait"
397400
val isDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitType") && {
398401
val s = actualProperties("nestedSealedTraitType")
399402
s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant1")
400403
}
401404
val isDiscriminatorPropertyRequired = schema.getRequired.contains("nestedSealedTraitType")
402-
val isCountOfPropertiesCorrect = actualProperties.size === 3
405+
val isCountOfPropertiesCorrect = actualProperties.size === 4
403406

404407
areNonDiscriminatorPropertiesCorrect &&
405408
isDiscriminatorPropertyCorrect &&
@@ -451,7 +454,9 @@ class OpenAPIModelRegistrationSpec extends AnyFlatSpec {
451454
schema => {
452455
val actualProperties = schema.getProperties.asScala
453456

454-
val areNonDiscriminatorPropertiesCorrect = actualProperties.contains("a")
457+
val areNonDiscriminatorPropertiesCorrect = actualProperties.contains("a") &&
458+
actualProperties.contains("b") &&
459+
actualProperties("b").get$ref === "#/components/schemas/NestedSealedTrait"
455460
val isDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitVariant3Type") && {
456461
val s = actualProperties("nestedSealedTraitVariant3Type")
457462
s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant3Subvariant1")
@@ -463,7 +468,7 @@ class OpenAPIModelRegistrationSpec extends AnyFlatSpec {
463468
val areDiscriminatorPropertyRequired = schema.getRequired.contains(
464469
"nestedSealedTraitVariant3Type"
465470
) && schema.getRequired.contains("nestedSealedTraitType")
466-
val isCountOfPropertiesCorrect = actualProperties.size === 3
471+
val isCountOfPropertiesCorrect = actualProperties.size === 4
467472

468473
areNonDiscriminatorPropertiesCorrect &&
469474
isDiscriminatorPropertyCorrect &&
@@ -543,13 +548,16 @@ class OpenAPIModelRegistrationSpec extends AnyFlatSpec {
543548
schema => {
544549
val actualProperties = schema.getProperties.asScala
545550

546-
val areNonDiscriminatorPropertiesCorrect = actualProperties.contains("a") && actualProperties.contains("b")
551+
val areNonDiscriminatorPropertiesCorrect = actualProperties.contains("a") &&
552+
actualProperties.contains("b") &&
553+
actualProperties.contains("c") &&
554+
actualProperties("c").get$ref === "#/components/schemas/NestedSealedTrait"
547555
val isDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitType") && {
548556
val s = actualProperties("nestedSealedTraitType")
549557
s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant1")
550558
}
551559
val isDiscriminatorPropertyRequired = schema.getRequired.contains("nestedSealedTraitType")
552-
val isCountOfPropertiesCorrect = actualProperties.size === 3
560+
val isCountOfPropertiesCorrect = actualProperties.size === 4
553561

554562
areNonDiscriminatorPropertiesCorrect &&
555563
isDiscriminatorPropertyCorrect &&
@@ -601,14 +609,16 @@ class OpenAPIModelRegistrationSpec extends AnyFlatSpec {
601609
schema => {
602610
val actualProperties = schema.getProperties.asScala
603611

604-
val areNonDiscriminatorPropertiesCorrect = actualProperties.contains("a")
612+
val areNonDiscriminatorPropertiesCorrect = actualProperties.contains("a") &&
613+
actualProperties.contains("b") &&
614+
actualProperties("b").get$ref == "#/components/schemas/NestedSealedTrait"
605615
val isDiscriminatorPropertyCorrect = actualProperties.contains("nestedSealedTraitVariant3Type") && {
606616
val s = actualProperties("nestedSealedTraitVariant3Type")
607617
s.getType === "string" && s.getEnum.asScala === Seq("NestedSealedTraitVariant3Subvariant1")
608618
}
609619
val isParentDiscriminatorNotInProperties = !actualProperties.contains("nestedSealedTraitType")
610620
val isDiscriminatorPropertyRequired = schema.getRequired.contains("nestedSealedTraitVariant3Type")
611-
val isCountOfPropertiesCorrect = actualProperties.size === 2
621+
val isCountOfPropertiesCorrect = actualProperties.size === 3
612622

613623
areNonDiscriminatorPropertiesCorrect &&
614624
isDiscriminatorPropertyCorrect &&

0 commit comments

Comments
 (0)