Skip to content

#47 Support configurable discriminator for sum ADTs #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ This library aims to avoid pollution of the model by custom annotations and depe
- support for basic Scala collections (`Map`, `Seq`, `Set`, `Array`) as types of `case class` parameters
- only top-level case classes need to be registered, child case classes are then recursively registered
- support for Scala `Enumeration` where simple `Value` constructor is used (without `name`)
- support for sum ADTs (`sealed trait` and `sealed abstract class`)
- support for sum ADTs (`sealed trait` and `sealed abstract class`) with optional discriminator

## Usage

Expand Down Expand Up @@ -215,6 +215,37 @@ Then, in `handleFn`, the handler creates a `Schema` object for `CustomClass`,
adds it to `Components` so that it can be referenced by name `CustomClass`,
and returns reference to that object.

### Registration configuration
It is possible to further customize registration by providing custom `RegistrationConfig` to `OpenAPIModelRegistration`.

#### Example
```scala
val components = ...
val registration = OpenAPIModelRegistration(
components,
config = RegistrationConfig(
OpenAPIModelRegistration.RegistrationConfig(
sumADTsShape =
// default values apply for discriminatorPropertyNameFn, addDiscriminatorPropertyOnlyToDirectChildren
OpenAPIModelRegistration.RegistrationConfig.SumADTsShape.WithDiscriminator()
)
)
)
```

#### sumADTsShape
This config property sets how sum ADTs are registered. It has two possible values:
- `RegistrationConfig.SumADTsShape.WithoutDiscriminator` - default option, doesn't add discriminators
- `RegistrationConfig.SumADTsShape.WithDiscriminator(discriminatorPropertyNameFn, addDiscriminatorPropertyOnlyToDirectChildren)` -
adds discriminator to sealed types schema,
and also adds discriminator to sum ADTs elements properties; discriminator property name is customizable by `discriminatorPropertyNameFn`,
by default it takes sealed type name, converts its first letter to lower case, and adds `"Type"` suffix,
for example if sealed type name is `Expression`, the property name is `expressionType`;
if `addDiscriminatorPropertyOnlyToDirectChildren` is `false`, discriminator property is added to all children,
so for example in `ADT = A | B | C; B = D | E` discriminator of `ADT` would be added to `A`, `C`, `D`, `E`
(`D` and `E` would have discriminator of `B` in addition to that)
while with `addDiscriminatorPropertyOnlyToDirectChildren` set to `true` (default)
it would be added only to `A` and `C`

## Examples

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@
package za.co.absa.springdocopenapiscala

import io.swagger.v3.oas.models.Components
import io.swagger.v3.oas.models.media.Schema
import io.swagger.v3.oas.models.media.{Discriminator, Schema}

import java.time.{Instant, LocalDate, LocalDateTime, ZonedDateTime}
import java.util.UUID
import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.reflect.runtime.universe._

import OpenAPIModelRegistration._

class OpenAPIModelRegistration(
components: Components,
extraTypesHandler: ExtraTypesHandling.ExtraTypesHandler = ExtraTypesHandling.noExtraHandling
extraTypesHandler: ExtraTypesHandling.ExtraTypesHandler = ExtraTypesHandling.noExtraHandling,
config: RegistrationConfig = RegistrationConfig()
) {

/**
Expand Down Expand Up @@ -144,13 +144,60 @@ class OpenAPIModelRegistration(
s.isTerm && s.asTerm.isVal && s.typeSignature <:< typeOf[Enumeration#Value]

private def handleSealedType(tpe: Type): Schema[_] = {

def addDiscriminatorPropertyToChildren(
currentSchema: Schema[_],
discriminatorPropertyName: String,
addOnlyToDirectChildren: Boolean,
discriminatorValue: Option[String] = None
): Unit = {
val children = currentSchema.getOneOf.asScala
children.foreach { s =>
val ref = s.get$ref
val name = extractSchemaNameFromRef(ref)
val actualSchema = components.getSchemas.get(name)
if (actualSchema.getType == "object") {
val constEnumSchema = createConstEnumSchema(discriminatorValue.getOrElse(name))
actualSchema.addProperty(discriminatorPropertyName, constEnumSchema)
actualSchema.addRequiredItem(discriminatorPropertyName)
} else if (
!addOnlyToDirectChildren &&
Option(actualSchema.getOneOf).map(!_.isEmpty).getOrElse(false) // is schema representing another sum ADT root
) {
addDiscriminatorPropertyToChildren(
actualSchema,
discriminatorPropertyName,
addOnlyToDirectChildren,
Some(name)
)
}
}
}

val classSymbol = tpe.typeSymbol.asClass
val name = tpe.typeSymbol.name.toString.trim
val children = classSymbol.knownDirectSubclasses
val childrenSchemas = children.map(_.asType.toType).map(handleType)
val schema = new Schema
schema.setOneOf(childrenSchemas.toList.asJava)

config.sumADTsShape match {
case RegistrationConfig.SumADTsShape.WithDiscriminator(discriminatorPropertyNameFn, addOnlyToDirectChildren) =>
val discriminatorPropertyName = discriminatorPropertyNameFn(name)
schema.setDiscriminator {
val discriminator = new Discriminator
discriminator.setPropertyName(discriminatorPropertyName)
discriminator
}
addDiscriminatorPropertyToChildren(
schema,
discriminatorPropertyName,
addOnlyToDirectChildren
)

case _ => ()
}

registerAsReference(name, schema)
}

Expand Down Expand Up @@ -187,10 +234,54 @@ class OpenAPIModelRegistration(
schemaReference
}

private def createConstEnumSchema(const: String): Schema[_] = {
val constEnumSchema = new Schema[String]
constEnumSchema.setType("string")
constEnumSchema.setEnum(Seq(const).asJava)
constEnumSchema
}

private def extractSchemaNameFromRef(ref: String): String = {
ref.substring(ref.lastIndexOf("/") + 1)
}

}

object OpenAPIModelRegistration {

/**
* Configuration of the registration class.
*
* @param sumADTsShape how sum ADTs should be registered (with or without discriminator)
*/
case class RegistrationConfig(
sumADTsShape: RegistrationConfig.SumADTsShape = RegistrationConfig.SumADTsShape.WithoutDiscriminator
)

object RegistrationConfig {

sealed abstract class SumADTsShape

object SumADTsShape {
case object WithoutDiscriminator extends SumADTsShape
case class WithDiscriminator(
discriminatorPropertyNameFn: WithDiscriminator.DiscriminatorPropertyNameFn =
WithDiscriminator.defaultDiscriminatorPropertyNameFn,
addDiscriminatorPropertyOnlyToDirectChildren: Boolean = true
) extends SumADTsShape

object WithDiscriminator {

/** Function from sealed type name to discriminator property name. */
type DiscriminatorPropertyNameFn = String => String

val defaultDiscriminatorPropertyNameFn: DiscriminatorPropertyNameFn = sealedTypeName =>
sealedTypeName.head.toLower + sealedTypeName.tail + "Type"
}
}

}

/**
* Context of model registration.
* Currently contains only `Components` that can be mutated if needed
Expand Down
Loading