Skip to content

Commit c4cc3e4

Browse files
authored
#47 Support configurable discriminator for sum ADTs (#49)
1 parent c8ce0cb commit c4cc3e4

File tree

3 files changed

+439
-6
lines changed

3 files changed

+439
-6
lines changed

README.md

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ This library aims to avoid pollution of the model by custom annotations and depe
2121
- support for basic Scala collections (`Map`, `Seq`, `Set`, `Array`) as types of `case class` parameters
2222
- only top-level case classes need to be registered, child case classes are then recursively registered
2323
- support for Scala `Enumeration` where simple `Value` constructor is used (without `name`)
24-
- support for sum ADTs (`sealed trait` and `sealed abstract class`)
24+
- support for sum ADTs (`sealed trait` and `sealed abstract class`) with optional discriminator
2525

2626
## Usage
2727

@@ -215,6 +215,37 @@ Then, in `handleFn`, the handler creates a `Schema` object for `CustomClass`,
215215
adds it to `Components` so that it can be referenced by name `CustomClass`,
216216
and returns reference to that object.
217217

218+
### Registration configuration
219+
It is possible to further customize registration by providing custom `RegistrationConfig` to `OpenAPIModelRegistration`.
220+
221+
#### Example
222+
```scala
223+
val components = ...
224+
val registration = OpenAPIModelRegistration(
225+
components,
226+
config = RegistrationConfig(
227+
OpenAPIModelRegistration.RegistrationConfig(
228+
sumADTsShape =
229+
// default values apply for discriminatorPropertyNameFn, addDiscriminatorPropertyOnlyToDirectChildren
230+
OpenAPIModelRegistration.RegistrationConfig.SumADTsShape.WithDiscriminator()
231+
)
232+
)
233+
)
234+
```
235+
236+
#### sumADTsShape
237+
This config property sets how sum ADTs are registered. It has two possible values:
238+
- `RegistrationConfig.SumADTsShape.WithoutDiscriminator` - default option, doesn't add discriminators
239+
- `RegistrationConfig.SumADTsShape.WithDiscriminator(discriminatorPropertyNameFn, addDiscriminatorPropertyOnlyToDirectChildren)` -
240+
adds discriminator to sealed types schema,
241+
and also adds discriminator to sum ADTs elements properties; discriminator property name is customizable by `discriminatorPropertyNameFn`,
242+
by default it takes sealed type name, converts its first letter to lower case, and adds `"Type"` suffix,
243+
for example if sealed type name is `Expression`, the property name is `expressionType`;
244+
if `addDiscriminatorPropertyOnlyToDirectChildren` is `false`, discriminator property is added to all children,
245+
so for example in `ADT = A | B | C; B = D | E` discriminator of `ADT` would be added to `A`, `C`, `D`, `E`
246+
(`D` and `E` would have discriminator of `B` in addition to that)
247+
while with `addDiscriminatorPropertyOnlyToDirectChildren` set to `true` (default)
248+
it would be added only to `A` and `C`
218249

219250
## Examples
220251

library/src/main/scala/za/co/absa/springdocopenapiscala/OpenAPIModelRegistration.scala

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,19 @@
1717
package za.co.absa.springdocopenapiscala
1818

1919
import io.swagger.v3.oas.models.Components
20-
import io.swagger.v3.oas.models.media.Schema
20+
import io.swagger.v3.oas.models.media.{Discriminator, Schema}
2121

2222
import java.time.{Instant, LocalDate, LocalDateTime, ZonedDateTime}
2323
import java.util.UUID
2424
import scala.annotation.tailrec
2525
import scala.collection.JavaConverters._
2626
import scala.reflect.runtime.universe._
27-
2827
import OpenAPIModelRegistration._
2928

3029
class OpenAPIModelRegistration(
3130
components: Components,
32-
extraTypesHandler: ExtraTypesHandling.ExtraTypesHandler = ExtraTypesHandling.noExtraHandling
31+
extraTypesHandler: ExtraTypesHandling.ExtraTypesHandler = ExtraTypesHandling.noExtraHandling,
32+
config: RegistrationConfig = RegistrationConfig()
3333
) {
3434

3535
/**
@@ -144,13 +144,60 @@ class OpenAPIModelRegistration(
144144
s.isTerm && s.asTerm.isVal && s.typeSignature <:< typeOf[Enumeration#Value]
145145

146146
private def handleSealedType(tpe: Type): Schema[_] = {
147+
148+
def addDiscriminatorPropertyToChildren(
149+
currentSchema: Schema[_],
150+
discriminatorPropertyName: String,
151+
addOnlyToDirectChildren: Boolean,
152+
discriminatorValue: Option[String] = None
153+
): Unit = {
154+
val children = currentSchema.getOneOf.asScala
155+
children.foreach { s =>
156+
val ref = s.get$ref
157+
val name = extractSchemaNameFromRef(ref)
158+
val actualSchema = components.getSchemas.get(name)
159+
if (actualSchema.getType == "object") {
160+
val constEnumSchema = createConstEnumSchema(discriminatorValue.getOrElse(name))
161+
actualSchema.addProperty(discriminatorPropertyName, constEnumSchema)
162+
actualSchema.addRequiredItem(discriminatorPropertyName)
163+
} else if (
164+
!addOnlyToDirectChildren &&
165+
Option(actualSchema.getOneOf).map(!_.isEmpty).getOrElse(false) // is schema representing another sum ADT root
166+
) {
167+
addDiscriminatorPropertyToChildren(
168+
actualSchema,
169+
discriminatorPropertyName,
170+
addOnlyToDirectChildren,
171+
Some(name)
172+
)
173+
}
174+
}
175+
}
176+
147177
val classSymbol = tpe.typeSymbol.asClass
148178
val name = tpe.typeSymbol.name.toString.trim
149179
val children = classSymbol.knownDirectSubclasses
150180
val childrenSchemas = children.map(_.asType.toType).map(handleType)
151181
val schema = new Schema
152182
schema.setOneOf(childrenSchemas.toList.asJava)
153183

184+
config.sumADTsShape match {
185+
case RegistrationConfig.SumADTsShape.WithDiscriminator(discriminatorPropertyNameFn, addOnlyToDirectChildren) =>
186+
val discriminatorPropertyName = discriminatorPropertyNameFn(name)
187+
schema.setDiscriminator {
188+
val discriminator = new Discriminator
189+
discriminator.setPropertyName(discriminatorPropertyName)
190+
discriminator
191+
}
192+
addDiscriminatorPropertyToChildren(
193+
schema,
194+
discriminatorPropertyName,
195+
addOnlyToDirectChildren
196+
)
197+
198+
case _ => ()
199+
}
200+
154201
registerAsReference(name, schema)
155202
}
156203

@@ -187,10 +234,54 @@ class OpenAPIModelRegistration(
187234
schemaReference
188235
}
189236

237+
private def createConstEnumSchema(const: String): Schema[_] = {
238+
val constEnumSchema = new Schema[String]
239+
constEnumSchema.setType("string")
240+
constEnumSchema.setEnum(Seq(const).asJava)
241+
constEnumSchema
242+
}
243+
244+
private def extractSchemaNameFromRef(ref: String): String = {
245+
ref.substring(ref.lastIndexOf("/") + 1)
246+
}
247+
190248
}
191249

192250
object OpenAPIModelRegistration {
193251

252+
/**
253+
* Configuration of the registration class.
254+
*
255+
* @param sumADTsShape how sum ADTs should be registered (with or without discriminator)
256+
*/
257+
case class RegistrationConfig(
258+
sumADTsShape: RegistrationConfig.SumADTsShape = RegistrationConfig.SumADTsShape.WithoutDiscriminator
259+
)
260+
261+
object RegistrationConfig {
262+
263+
sealed abstract class SumADTsShape
264+
265+
object SumADTsShape {
266+
case object WithoutDiscriminator extends SumADTsShape
267+
case class WithDiscriminator(
268+
discriminatorPropertyNameFn: WithDiscriminator.DiscriminatorPropertyNameFn =
269+
WithDiscriminator.defaultDiscriminatorPropertyNameFn,
270+
addDiscriminatorPropertyOnlyToDirectChildren: Boolean = true
271+
) extends SumADTsShape
272+
273+
object WithDiscriminator {
274+
275+
/** Function from sealed type name to discriminator property name. */
276+
type DiscriminatorPropertyNameFn = String => String
277+
278+
val defaultDiscriminatorPropertyNameFn: DiscriminatorPropertyNameFn = sealedTypeName =>
279+
sealedTypeName.head.toLower + sealedTypeName.tail + "Type"
280+
}
281+
}
282+
283+
}
284+
194285
/**
195286
* Context of model registration.
196287
* Currently contains only `Components` that can be mutated if needed

0 commit comments

Comments
 (0)