Skip to content

Commit

Permalink
[csharp][java] Fix enum discriminator default value (#19614)
Browse files Browse the repository at this point in the history
* Fix enum discriminator default value

* Remove system out call

* Add case when discriminator type is ref

* Use correct schema

* Handle different use cases of mappings

* Add missing enum type Lizzy

* Make it more robust

* Add missing test for Sedan

* Refactor some code to make it cleaner

* Initialize discriminator enum field

* Don't override existing default value

* Fix issue with finding discriminators

* Move setIsEnum back to its original location

* Be smarter about figuring out the model name

* Fix final warnings

* Add javadocs to introduced methods
  • Loading branch information
david-marconis authored Jan 21, 2025
1 parent 1fa07bf commit c75fbb3
Show file tree
Hide file tree
Showing 9 changed files with 310 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3098,6 +3098,7 @@ public CodegenModel fromModel(String name, Schema schema) {
listOLists.add(m.requiredVars);
listOLists.add(m.vars);
listOLists.add(m.allVars);
listOLists.add(m.readWriteVars);
for (List<CodegenProperty> theseVars : listOLists) {
for (CodegenProperty requiredVar : theseVars) {
if (discPropName.equals(requiredVar.baseName)) {
Expand Down Expand Up @@ -3131,6 +3132,63 @@ public CodegenModel fromModel(String name, Schema schema) {
return m;
}

/**
* Sets the default value for an enum discriminator property in the provided {@link CodegenModel}.
* <p>
* If the model's discriminator is defined, this method identifies the discriminator properties among the model's
* variables and assigns the default value to reflect the corresponding enum value for the model type.
* </p>
* <p>
* Example: If the discriminator is for type `Animal`, and the model is `Cat`, the default value
* will be set to `Animal.Cat` for the properties that have the same name as the discriminator.
* </p>
*
* @param model the {@link CodegenModel} whose discriminator property default value is to be set
*/
protected static void setEnumDiscriminatorDefaultValue(CodegenModel model) {
if (model.discriminator == null) {
return;
}
String discPropName = model.discriminator.getPropertyBaseName();
Stream.of(model.requiredVars, model.vars, model.allVars, model.readWriteVars)
.flatMap(List::stream)
.filter(v -> discPropName.equals(v.baseName))
.forEach(v -> v.defaultValue = getEnumValueForProperty(model.schemaName, model.discriminator, v));
}

/**
* Retrieves the appropriate default value for an enum discriminator property based on the model name.
* <p>
* If the discriminator has a mapping defined, it attempts to find a mapping for the model name.
* Otherwise, it defaults to one of the allowable enum value associated with the property.
* If no suitable value is found, the original default value of the property is returned.
* </p>
*
* @param modelName the name of the model to determine the default value for
* @param discriminator the {@link CodegenDiscriminator} containing the mapping and enum details
* @param var the {@link CodegenProperty} representing the discriminator property
* @return the default value for the enum discriminator property, or its original default value if none is found
*/
protected static String getEnumValueForProperty(
String modelName, CodegenDiscriminator discriminator, CodegenProperty var) {
if (!discriminator.getIsEnum() && !var.isEnum) {
return var.defaultValue;
}
Map<String, String> mapping = Optional.ofNullable(discriminator.getMapping()).orElseGet(Collections::emptyMap);
for (Map.Entry<String, String> e : mapping.entrySet()) {
String schemaName = e.getValue().indexOf('/') < 0 ? e.getValue() : ModelUtils.getSimpleRef(e.getValue());
if (modelName.equals(schemaName)) {
return e.getKey();
}
}
Object values = var.allowableValues.get("values");
if (!(values instanceof List<?>)) {
return var.defaultValue;
}
List<?> valueList = (List<?>) values;
return valueList.stream().filter(o -> o.equals(modelName)).map(o -> (String) o).findAny().orElse(var.defaultValue);
}

protected void SortModelPropertiesByRequiredFlag(CodegenModel model) {
Comparator<CodegenProperty> comparator = new Comparator<CodegenProperty>() {
@Override
Expand Down Expand Up @@ -3201,15 +3259,19 @@ protected void setAddProps(Schema schema, IJsonSchemaValidationProperties proper
* @param visitedSchemas A set of visited schema names
*/
private CodegenProperty discriminatorFound(String composedSchemaName, Schema sc, String discPropName, Set<String> visitedSchemas) {
if (visitedSchemas.contains(composedSchemaName)) { // recursive schema definition found
Schema refSchema = ModelUtils.getReferencedSchema(openAPI, sc);
String schemaName = Optional.ofNullable(composedSchemaName)
.or(() -> Optional.ofNullable(refSchema.getName()))
.or(() -> Optional.ofNullable(sc.get$ref()).map(ModelUtils::getSimpleRef))
.orElseGet(sc::toString);
if (visitedSchemas.contains(schemaName)) { // recursive schema definition found
return null;
} else {
visitedSchemas.add(composedSchemaName);
visitedSchemas.add(schemaName);
}

Schema refSchema = ModelUtils.getReferencedSchema(openAPI, sc);
if (refSchema.getProperties() != null && refSchema.getProperties().get(discPropName) != null) {
Schema discSchema = (Schema) refSchema.getProperties().get(discPropName);
Schema discSchema = ModelUtils.getReferencedSchema(openAPI, (Schema)refSchema.getProperties().get(discPropName));
CodegenProperty cp = new CodegenProperty();
if (ModelUtils.isStringSchema(discSchema)) {
cp.isString = true;
Expand All @@ -3218,14 +3280,16 @@ private CodegenProperty discriminatorFound(String composedSchemaName, Schema sc,
if (refSchema.getRequired() != null && refSchema.getRequired().contains(discPropName)) {
cp.setRequired(true);
}
cp.setIsEnum(discSchema.getEnum() != null && !discSchema.getEnum().isEmpty());
return cp;
}
if (ModelUtils.isComposedSchema(refSchema)) {
Schema composedSchema = refSchema;
if (composedSchema.getAllOf() != null) {
// If our discriminator is in one of the allOf schemas break when we find it
for (Object allOf : composedSchema.getAllOf()) {
CodegenProperty cp = discriminatorFound(composedSchemaName, (Schema) allOf, discPropName, visitedSchemas);
Schema allOfSchema = (Schema) allOf;
CodegenProperty cp = discriminatorFound(allOfSchema.getName(), allOfSchema, discPropName, visitedSchemas);
if (cp != null) {
return cp;
}
Expand All @@ -3235,8 +3299,11 @@ private CodegenProperty discriminatorFound(String composedSchemaName, Schema sc,
// All oneOf definitions must contain the discriminator
CodegenProperty cp = new CodegenProperty();
for (Object oneOf : composedSchema.getOneOf()) {
String modelName = ModelUtils.getSimpleRef(((Schema) oneOf).get$ref());
CodegenProperty thisCp = discriminatorFound(composedSchemaName, (Schema) oneOf, discPropName, visitedSchemas);
Schema oneOfSchema = (Schema) oneOf;
String modelName = ModelUtils.getSimpleRef((oneOfSchema).get$ref());
// Must use a copied set as the oneOf schemas can point to the same discriminator.
Set<String> visitedSchemasCopy = new TreeSet<>(visitedSchemas);
CodegenProperty thisCp = discriminatorFound(oneOfSchema.getName(), oneOfSchema, discPropName, visitedSchemasCopy);
if (thisCp == null) {
once(LOGGER).warn(
"'{}' defines discriminator '{}', but the referenced OneOf schema '{}' is missing {}",
Expand All @@ -3258,8 +3325,11 @@ private CodegenProperty discriminatorFound(String composedSchemaName, Schema sc,
// All anyOf definitions must contain the discriminator because a min of one must be selected
CodegenProperty cp = new CodegenProperty();
for (Object anyOf : composedSchema.getAnyOf()) {
String modelName = ModelUtils.getSimpleRef(((Schema) anyOf).get$ref());
CodegenProperty thisCp = discriminatorFound(composedSchemaName, (Schema) anyOf, discPropName, visitedSchemas);
Schema anyOfSchema = (Schema) anyOf;
String modelName = ModelUtils.getSimpleRef(anyOfSchema.get$ref());
// Must use a copied set as the anyOf schemas can point to the same discriminator.
Set<String> visitedSchemasCopy = new TreeSet<>(visitedSchemas);
CodegenProperty thisCp = discriminatorFound(anyOfSchema.getName(), anyOfSchema, discPropName, visitedSchemasCopy);
if (thisCp == null) {
once(LOGGER).warn(
"'{}' defines discriminator '{}', but the referenced AnyOf schema '{}' is missing {}",
Expand Down Expand Up @@ -3542,13 +3612,11 @@ protected CodegenDiscriminator createDiscriminator(String schemaName, Schema sch
discriminator.setPropertyType(propertyType);

// check to see if the discriminator property is an enum string
if (schema.getProperties() != null &&
schema.getProperties().get(discriminatorPropertyName) instanceof StringSchema) {
StringSchema s = (StringSchema) schema.getProperties().get(discriminatorPropertyName);
if (s.getEnum() != null && !s.getEnum().isEmpty()) { // it's an enum string
discriminator.setIsEnum(true);
}
}
boolean isEnum = Optional
.ofNullable(discriminatorFound(schemaName, schema, discriminatorPropertyName, new TreeSet<>()))
.map(CodegenProperty::getIsEnum)
.orElse(false);
discriminator.setIsEnum(isEnum);

discriminator.setMapping(sourceDiscriminator.getMapping());
List<MappedModel> uniqueDescendants = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1713,6 +1713,7 @@ public CodegenModel fromModel(String name, Schema model) {

// additional import for different cases
addAdditionalImports(codegenModel, codegenModel.getComposedSchemas());
setEnumDiscriminatorDefaultValue(codegenModel);
return codegenModel;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ public String apiTestFileFolder() {
public CodegenModel fromModel(String name, Schema model) {
Map<String, Schema> allDefinitions = ModelUtils.getSchemas(this.openAPI);
CodegenModel codegenModel = super.fromModel(name, model);
setEnumDiscriminatorDefaultValue(codegenModel);
if (allDefinitions != null && codegenModel != null && codegenModel.parent != null) {
final Schema<?> parentModel = allDefinitions.get(toModelName(codegenModel.parent));
if (parentModel != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1611,6 +1611,12 @@ public static String getParentName(Schema composedSchema, Map<String, Schema> al
* @return the name of the parent model
*/
public static List<String> getAllParentsName(Schema composedSchema, Map<String, Schema> allSchemas, boolean includeAncestors) {
return getAllParentsName(composedSchema, allSchemas, includeAncestors, new HashSet<>());
}

// Use a set of seen names to avoid infinite recursion
private static List<String> getAllParentsName(
Schema composedSchema, Map<String, Schema> allSchemas, boolean includeAncestors, Set<String> seenNames) {
List<Schema> interfaces = getInterfaces(composedSchema);
List<String> names = new ArrayList<String>();

Expand All @@ -1619,6 +1625,10 @@ public static List<String> getAllParentsName(Schema composedSchema, Map<String,
// get the actual schema
if (StringUtils.isNotEmpty(schema.get$ref())) {
String parentName = getSimpleRef(schema.get$ref());
if (seenNames.contains(parentName)) {
continue;
}
seenNames.add(parentName);
Schema s = allSchemas.get(parentName);
if (s == null) {
LOGGER.error("Failed to obtain schema from {}", parentName);
Expand All @@ -1627,7 +1637,7 @@ public static List<String> getAllParentsName(Schema composedSchema, Map<String,
// discriminator.propertyName is used or x-parent is used
names.add(parentName);
if (includeAncestors && isComposedSchema(s)) {
names.addAll(getAllParentsName(s, allSchemas, true));
names.addAll(getAllParentsName(s, allSchemas, true, seenNames));
}
} else {
// not a parent since discriminator.propertyName is not set
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ public class {{classname}} {{#parent}}extends {{{.}}} {{/parent}}{{#vendorExtens
{{/parcelableModel}}
{{/parent}}
{{#discriminator}}
{{#discriminator.isEnum}}
{{#readWriteVars}}{{#isDiscriminator}}{{#defaultValue}}
this.{{name}} = {{defaultValue}};
{{/defaultValue}}{{/isDiscriminator}}{{/readWriteVars}}
{{/discriminator.isEnum}}
{{^discriminator.isEnum}}
this.{{{discriminatorName}}} = this.getClass().getSimpleName();
{{/discriminator.isEnum}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,46 @@ public void test31specAdditionalPropertiesOfOneOf() throws IOException {
assertFileContains(modelFile.toPath(),
" Dictionary<string, ResponseResultsValue> results = default(Dictionary<string, ResponseResultsValue>");
}

@Test
public void testEnumDiscriminatorDefaultValueIsNotString() throws IOException {
File output = Files.createTempDirectory("test").toFile().getCanonicalFile();
output.deleteOnExit();
final OpenAPI openAPI = TestUtils.parseFlattenSpec(
"src/test/resources/3_0/enum_discriminator_inheritance.yaml");
final DefaultGenerator defaultGenerator = new DefaultGenerator();
final ClientOptInput clientOptInput = new ClientOptInput();
clientOptInput.openAPI(openAPI);
CSharpClientCodegen cSharpClientCodegen = new CSharpClientCodegen();
cSharpClientCodegen.setOutputDir(output.getAbsolutePath());
cSharpClientCodegen.setAutosetConstants(true);
clientOptInput.config(cSharpClientCodegen);
defaultGenerator.opts(clientOptInput);

Map<String, File> files = defaultGenerator.generate().stream()
.collect(Collectors.toMap(File::getPath, Function.identity()));

Map<String, String> expectedContents = Map.of(
"Cat", "PetTypeEnum petType = PetTypeEnum.Catty",
"Dog", "PetTypeEnum petType = PetTypeEnum.Dog",
"Gecko", "PetTypeEnum petType = PetTypeEnum.Gecko",
"Chameleon", "PetTypeEnum petType = PetTypeEnum.Camo",
"MiniVan", "CarType carType = CarType.MiniVan",
"CargoVan", "CarType carType = CarType.CargoVan",
"SUV", "CarType carType = CarType.SUV",
"Truck", "CarType carType = CarType.Truck",
"Sedan", "CarType carType = CarType.Sedan"

);
for (Map.Entry<String, String> e : expectedContents.entrySet()) {
String modelName = e.getKey();
String expectedContent = e.getValue();
File file = files.get(Paths
.get(output.getAbsolutePath(), "src", "Org.OpenAPITools", "Model", modelName + ".cs")
.toString()
);
assertNotNull(file, "Could not find file for model: " + modelName);
assertFileContains(file.toPath(), expectedContent);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2408,6 +2408,39 @@ public void testOpenapiGeneratorIgnoreListOption() {
assertNull(files.get("pom.xml"));
}

@Test
public void testEnumDiscriminatorDefaultValueIsNotString() {
final Path output = newTempFolder();
final OpenAPI openAPI = TestUtils.parseFlattenSpec(
"src/test/resources/3_0/enum_discriminator_inheritance.yaml");
JavaClientCodegen codegen = new JavaClientCodegen();
codegen.setOutputDir(output.toString());

Map<String, File> files = new DefaultGenerator().opts(new ClientOptInput().openAPI(openAPI).config(codegen))
.generate().stream().collect(Collectors.toMap(File::getName, Function.identity()));

Map<String, String> expectedContents = Map.of(
"Cat", "this.petType = PetTypeEnum.CATTY",
"Dog", "this.petType = PetTypeEnum.DOG",
"Gecko", "this.petType = PetTypeEnum.GECKO",
"Chameleon", "this.petType = PetTypeEnum.CAMO",
"MiniVan", "this.carType = CarType.MINI_VAN",
"CargoVan", "this.carType = CarType.CARGO_VAN",
"SUV", "this.carType = CarType.SUV",
"Truck", "this.carType = CarType.TRUCK",
"Sedan", "this.carType = CarType.SEDAN"

);
for (Map.Entry<String, String> e : expectedContents.entrySet()) {
String modelName = e.getKey();
String expectedContent = e.getValue();
File entityFile = files.get(modelName + ".java");
assertNotNull(entityFile);
assertThat(entityFile).content().doesNotContain("Type = this.getClass().getSimpleName();");
assertThat(entityFile).content().contains(expectedContent);
}
}

@Test
public void testRestTemplateHandleURIEnum() {
String[] expectedInnerEnumLines = new String[]{
Expand Down
Loading

0 comments on commit c75fbb3

Please sign in to comment.