Skip to content

Commit

Permalink
Allow to generate getters returning Optional for nullable fields (#40)
Browse files Browse the repository at this point in the history
* Allow to generate getters with Optional for nullable fields

* Set avroOptionalGetters to false by default

* Bugfix when generating idls

* Use avroOptionalGetters settings only for avro 1.10+

Co-authored-by: Eric Palacios <eric.palacios@xing.com>
Co-authored-by: Michel Davit <michel.davit@gmail.com>
  • Loading branch information
3 people authored Oct 12, 2020
1 parent 030821b commit 914213a
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 15 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ libraryDependencies += "org.apache.avro" % "avro" % "1.10.0"
| `avroUseNamespace` | `false` | Validate that directory layout reflects namespaces, i.e. `com/myorg/MyRecord.avsc`. |
| `avroFieldVisibility` | `public_deprecated` | Field Visibility for the properties. Possible values: `private`, `public`, `public_deprecated`. |
| `avroEnableDecimalLogicalType` | `true` | Set to true to use `java.math.BigDecimal` instead of `java.nio.ByteBuffer` for logical type `decimal`. |
| `avroOptionalGetters` | `false` (requires avro `1.10+`) | Set to true to generate getters that return `Optional` for nullable fields. |

## Examples

Expand Down
15 changes: 11 additions & 4 deletions src/main/java/com/spotify/avro/mojo/AvscFilesCompiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@

import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;

public class AvscFilesCompiler {
Expand All @@ -27,6 +24,7 @@ public class AvscFilesCompiler {
private boolean useNamespace;
private boolean enableDecimalLogicalType;
private boolean createSetters;
private Optional<Boolean> optionalGetters = Optional.empty();
private Map<AvroFileRef, Exception> compileExceptions;
private boolean logCompileExceptions;

Expand Down Expand Up @@ -95,6 +93,11 @@ private boolean tryCompile(AvroFileRef src, File outputDirectory) {
compiler.setFieldVisibility(fieldVisibility);
compiler.setEnableDecimalLogicalType(enableDecimalLogicalType);
compiler.setCreateSetters(createSetters);
if (optionalGetters.isPresent()) {
compiler.setGettersReturnOptional(optionalGetters.get());
compiler.setOptionalGettersForNullableFieldsOnly(optionalGetters.get());
}

try {
compiler.compileToDestination(src.getFile(), outputDirectory);
} catch (IOException e) {
Expand Down Expand Up @@ -163,4 +166,8 @@ public void setCreateSetters(boolean createSetters) {
public void setLogCompileExceptions(final boolean logCompileExceptions) {
this.logCompileExceptions = logCompileExceptions;
}

public void setOptionalGetters(final boolean optionalGetters) {
this.optionalGetters = Optional.of(optionalGetters);
}
}
36 changes: 26 additions & 10 deletions src/main/scala/sbtavro/SbtAvro.scala
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
package sbtavro

import java.io.File
import java.util.concurrent.atomic.AtomicReference

import org.apache.avro.Protocol
import org.apache.avro.compiler.idl.Idl
import org.apache.avro.compiler.specific.SpecificCompiler
import org.apache.avro.compiler.specific.SpecificCompiler.FieldVisibility
import org.apache.avro.generic.GenericData.StringType
import org.apache.avro.{Protocol, Schema}
import sbt.Keys._
import sbt._
import Path.relativeTo
import CrossVersion.partialVersion
import com.spotify.avro.mojo.{AvroFileRef, SchemaParserBuilder}
import sbt.librarymanagement.DependencyFilter

Expand All @@ -21,6 +21,8 @@ object SbtAvro extends AutoPlugin {

val AvroClassifier = "avro"

private val avroCompilerVersion = classOf[SpecificCompiler].getPackage.getImplementationVersion

private val AvroAvrpFilter: NameFilter = "*.avpr"
private val AvroAvdlFilter: NameFilter = "*.avdl"
private val AvroAvscFilter: NameFilter = "*.avsc"
Expand All @@ -37,6 +39,7 @@ object SbtAvro extends AutoPlugin {
val avroEnableDecimalLogicalType = settingKey[Boolean]("Set to true to use java.math.BigDecimal instead of java.nio.ByteBuffer for logical type \"decimal\".")
val avroFieldVisibility = settingKey[String]("Field visibility for the properties. Possible values: private, public, public_deprecated. Default: public_deprecated.")
val avroUseNamespace = settingKey[Boolean]("Validate that directory layout reflects namespaces, i.e. src/main/avro/com/myorg/MyRecord.avsc.")
val avroOptionalGetters = settingKey[Boolean]("Set to true to generate getters that return Optional for nullable fields")
val avroSource = settingKey[File]("Default Avro source directory.")
val avroIncludes = settingKey[Seq[File]]("Avro schema includes.")
val avroSchemaParserBuilder = settingKey[SchemaParserBuilder](".avsc schema parser builder")
Expand Down Expand Up @@ -92,6 +95,7 @@ object SbtAvro extends AutoPlugin {
avroFieldVisibility := "public_deprecated",
avroEnableDecimalLogicalType := true,
avroUseNamespace := false,
avroOptionalGetters := false,
avroSchemaParserBuilder := DefaultSchemaParserBuilder.default()
)

Expand Down Expand Up @@ -135,7 +139,7 @@ object SbtAvro extends AutoPlugin {
)
}

def compileIdls(idls: Seq[File], target: File, log: Logger, stringType: StringType, fieldVisibility: FieldVisibility, enableDecimalLogicalType: Boolean) = {
def compileIdls(idls: Seq[File], target: File, log: Logger, stringType: StringType, fieldVisibility: FieldVisibility, enableDecimalLogicalType: Boolean, optionalGetters: Option[Boolean]) = {
idls.foreach { idl =>
log.info(s"Compiling Avro IDL $idl")
val parser = new Idl(idl)
Expand All @@ -144,19 +148,23 @@ object SbtAvro extends AutoPlugin {
compiler.setStringType(stringType)
compiler.setFieldVisibility(fieldVisibility)
compiler.setEnableDecimalLogicalType(enableDecimalLogicalType)
optionalGetters.foreach(compiler.setGettersReturnOptional)
optionalGetters.foreach(compiler.setOptionalGettersForNullableFieldsOnly)
compiler.compileToDestination(null, target)
}
}

def compileAvscs(refs: Seq[AvroFileRef], target: File, log: Logger, stringType: StringType, fieldVisibility: FieldVisibility, enableDecimalLogicalType: Boolean, useNamespace: Boolean, builder: SchemaParserBuilder) = {
def compileAvscs(refs: Seq[AvroFileRef], target: File, log: Logger, stringType: StringType, fieldVisibility: FieldVisibility, enableDecimalLogicalType: Boolean, useNamespace: Boolean, optionalGetters: Option[Boolean], builder: SchemaParserBuilder) = {
import com.spotify.avro.mojo._

import scala.collection.JavaConverters._
val compiler = new AvscFilesCompiler(builder)
compiler.setStringType(stringType)
compiler.setFieldVisibility(fieldVisibility)
compiler.setUseNamespace(useNamespace)
compiler.setEnableDecimalLogicalType(enableDecimalLogicalType)
compiler.setCreateSetters(true)
optionalGetters.foreach(compiler.setOptionalGetters)
compiler.setLogCompileExceptions(true)
compiler.setTemplateDirectory("/org/apache/avro/compiler/specific/templates/java/classic/")

Expand All @@ -166,14 +174,16 @@ object SbtAvro extends AutoPlugin {
compiler.compileFiles(refs.toSet.asJava, target)
}

def compileAvprs(avprs: Seq[File], target: File, log: Logger, stringType: StringType, fieldVisibility: FieldVisibility, enableDecimalLogicalType: Boolean) = {
def compileAvprs(avprs: Seq[File], target: File, log: Logger, stringType: StringType, fieldVisibility: FieldVisibility, enableDecimalLogicalType: Boolean, optionalGetters: Option[Boolean]) = {
avprs.foreach { avpr =>
log.info(s"Compiling Avro protocol $avpr")
val protocol = Protocol.parse(avpr)
val compiler = new SpecificCompiler(protocol)
compiler.setStringType(stringType)
compiler.setFieldVisibility(fieldVisibility)
compiler.setEnableDecimalLogicalType(enableDecimalLogicalType)
optionalGetters.foreach(compiler.setGettersReturnOptional)
optionalGetters.foreach(compiler.setOptionalGettersForNullableFieldsOnly)
compiler.compileToDestination(null, target)
}
}
Expand All @@ -185,19 +195,21 @@ object SbtAvro extends AutoPlugin {
fieldVisibility: FieldVisibility,
enableDecimalLogicalType: Boolean,
useNamespace: Boolean,
optionalGetters: Option[Boolean],
builder: SchemaParserBuilder): Set[File] = {
val avdls = srcDirs.flatMap(d => (d ** AvroAvdlFilter).get)
val avscs = srcDirs.flatMap(d => (d ** AvroAvscFilter).get.map(avsc => new AvroFileRef(d, avsc.relativeTo(d).get.toString)))
val avprs = srcDirs.flatMap(d => (d ** AvroAvrpFilter).get)

compileIdls(avdls, target, log, stringType, fieldVisibility, enableDecimalLogicalType)
compileAvscs(avscs, target, log, stringType, fieldVisibility, enableDecimalLogicalType, useNamespace, builder)
compileAvprs(avprs, target, log, stringType, fieldVisibility, enableDecimalLogicalType)
compileIdls(avdls, target, log, stringType, fieldVisibility, enableDecimalLogicalType, optionalGetters)
compileAvscs(avscs, target, log, stringType, fieldVisibility, enableDecimalLogicalType, useNamespace, optionalGetters, builder)
compileAvprs(avprs, target, log, stringType, fieldVisibility, enableDecimalLogicalType, optionalGetters)

(target ** JavaFileFilter).get.toSet
}

private def sourceGeneratorTask(key: TaskKey[Seq[File]]) = Def.task {

val out = (key / streams).value
val srcDir = avroSource.value
val externalSrcDir = (avroUnpackDependencies / target).value
Expand All @@ -208,11 +220,15 @@ object SbtAvro extends AutoPlugin {
val fieldVis = SpecificCompiler.FieldVisibility.valueOf(avroFieldVisibility.value.toUpperCase)
val enbDecimal = avroEnableDecimalLogicalType.value
val useNs = avroUseNamespace.value
val optionalGetters = partialVersion(avroCompilerVersion) match {
case Some((1, minor)) if minor >= 10 => Some(avroOptionalGetters.value)
case _ => None
}
val builder = avroSchemaParserBuilder.value
val cachedCompile = {
FileFunction.cached(out.cacheDirectory / "avro", FilesInfo.lastModified, FilesInfo.exists) { _ =>
out.log.info(s"Avro compiler using stringType=$strType")
compileAvroSchema(srcDirs, outDir, out.log, strType, fieldVis, enbDecimal, useNs, builder)
out.log.info(s"Avro compiler $avroCompilerVersion using stringType=$strType")
compileAvroSchema(srcDirs, outDir, out.log, strType, fieldVis, enbDecimal, useNs, optionalGetters, builder)
}
}

Expand Down
12 changes: 11 additions & 1 deletion src/test/scala/sbtavro/SbtAvroSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,17 @@ class SbtAvroSpec extends Specification {
_eJavaFile.delete()

val refs = sourceFiles.map(s => new AvroFileRef(sourceDir, s.getName))
SbtAvro.compileAvscs(refs, targetDir, logger, StringType.CharSequence, FieldVisibility.PUBLIC_DEPRECATED, true, false, builder)
SbtAvro.compileAvscs(
refs = refs,
target = targetDir,
log = logger,
stringType = StringType.CharSequence,
fieldVisibility = FieldVisibility.PUBLIC_DEPRECATED,
enableDecimalLogicalType = true,
useNamespace = false,
optionalGetters = None,
builder = builder
)

aJavaFile.isFile must beTrue
bJavaFile.isFile must beTrue
Expand Down

0 comments on commit 914213a

Please sign in to comment.