Skip to content

Commit

Permalink
Fixing code generation bug with underscores in names (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
vpatryshev authored and tovbinm committed Jan 15, 2019
1 parent 9898c76 commit c0ed618
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 18 deletions.
17 changes: 17 additions & 0 deletions cli/passengers_.answers
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
'age' - what kind of feature is this? [0] integral [1] categorical: => 0
'gender' - what kind of feature is this? [0] text [1] categorical: => 1
'height' - what kind of feature is this? [0] integral [1] categorical: => 0
'weight' - what kind of feature is this? [0] integral [1] categorical: => 0
'description' - what kind of feature is this? [0] text [1] categorical: => 0
'boarded' - what kind of feature is this? [0] integral [1] categorical: => 1
'recordDate' - what kind of feature is this? [0] integral [1] categorical: => 0
'Survived' - what kind of feature is this? [0] integral [1] categorical: => 1
'P_class' - what kind of feature is this? [0] integral [1] categorical: => 0
'Name' - what kind of feature is this? [0] text [1] categorical: => 0
'Sex' - what kind of feature is this? [0] text [1] categorical: => 1
'SibSp' - what kind of feature is this? [0] integral [1] categorical: => 0
'Parch' - what kind of feature is this? [0] integral [1] categorical: => 0
'Ticket' - what kind of feature is this? [0] text [1] categorical: => 1
'Cabin' - what kind of feature is this? [0] text [1] categorical: => 0
'Embarked' - what kind of feature is this? [0] text [1] categorical: => 0
Cannot infer the kind of problem based on response field 'Survived'. What kind of problem is this? => binclass
11 changes: 8 additions & 3 deletions cli/src/main/scala/com/salesforce/op/cli/CliExec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,21 @@ class CliExec {
}

def main(args: Array[String]): Unit = try {
val outcome = for {
val ops = for {
arguments <- CommandParser.parse(args, CliParameters())
if arguments.command == "gen"
settings <- arguments.values
} yield Ops(settings).run()
} yield Ops(settings)

outcome getOrElse {
ops getOrElse {
CommandParser.showUsage()
quit("wrong arguments", 1)
}

val outcome = ops.map (_.run())

outcome getOrElse quit("Generation failed; see error messages", 1)

} catch {
case x: Exception =>
if (DEBUG) x.printStackTrace()
Expand Down
9 changes: 6 additions & 3 deletions cli/src/main/scala/com/salesforce/op/cli/CommandParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ trait OpCli {
else success
}

private[cli] def tweak(file: File): File =
new File(file.getPath.replace("`pwd`/", ""))

def fileExists(file: File): Either[String, Unit] = {
if (!file.exists()) failure(s"File '${file.getAbsolutePath}' not found")
else success
val f = tweak(file)
if (f.exists()) success else failure(s"File '${f.getAbsolutePath}' not found")
}

private val Identifier = "([a-zA-Z]\\w*)".r
Expand Down Expand Up @@ -80,7 +83,7 @@ object CommandParser extends scopt.OptionParser[CliParameters]("transmogrifai")
.text("Input file for the TransmogrifAI project [required]")
.validate(fileExists)
.required
.action((inputFile, cfg) => cfg.copy(inputFile = Option(inputFile))),
.action((inputFile, cfg) => cfg.copy(inputFile = Option(tweak(inputFile)))),

opt[String]("id")
.text("Name for the ID field [required]")
Expand Down
11 changes: 7 additions & 4 deletions cli/src/main/scala/com/salesforce/op/cli/gen/ProblemSchema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ package com.salesforce.op.cli.gen
import java.io.File

import com.salesforce.op.cli.SchemaSource
import org.apache.avro.Schema

import scala.collection.JavaConverters._
import scala.io.Source
import AvroField._

Expand Down Expand Up @@ -83,7 +81,7 @@ object ProblemSchema {

val orf = MakeRawFeature(ops)

val (responseFeature :: features) = orderedFields.map { field =>
val responseFeature :: features = orderedFields.map { field =>
orf.from(field, schemaSource.name, field == responseField)
}

Expand Down Expand Up @@ -144,10 +142,15 @@ sealed trait OPRawFeature {

/**
* Gets the java method that Avro generates for this field. e.g. `getPassengerId` for field with name "passengerId"
* Note that variable names with underscores are converted to CamelCase
* by avro; so we should do the same
*
* @return The java getter as a string
*/
def avroGetter: String = s"get${avroField.name.capitalize}"
def avroGetter: String = {
val pieces = avroField.name.split("_").map(_.capitalize)
s"get${pieces.mkString("")}"
}

/**
* Gets a name corresponding to the scala `val` that will be generated for this feature builder, e.g.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ package com.salesforce.op.cli

import language.postfixOps
import java.io.File
import java.nio.file.Paths

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

import scala.io.Source

/**
* Test for generator operations
*/
Expand Down Expand Up @@ -89,7 +92,7 @@ class CliCodeGenerationTest extends CliTestBase {
assertResult(result, Succeeded)
}

it should "read answers from a file" in {
it should "generage code, answers in a file" in {
val sut = new Sut()
val result = sut.run(
"gen",
Expand All @@ -106,6 +109,33 @@ class CliCodeGenerationTest extends CliTestBase {
}
}

it should "not fail when fields have underscores in names" in {
val sut = new Sut()
val result = sut.run(
"gen",
"--input", TestCsvHeadless,
"--id", "passengerId",
"--response", "survived",
"--schema", TestAvscWithUnderscores,
"--answers", AnswersFileWithUnderscores,
ProjectName,
"--overwrite"
)
withClue(result.err) {
result.outcome shouldBe Succeeded
}
val scalaSourcesFolder = Paths.get(projectFolder, "src", "main", "scala", "com", "salesforce", "app")

val featuresFile = Source.fromFile(new File(scalaSourcesFolder.toFile, "Features.scala")).getLines
val testLines = featuresFile.dropWhile(!_.contains("val p_class = FB"))
testLines.hasNext shouldBe true
testLines.next
val integralPassenger = testLines.next
integralPassenger.trim shouldBe ".Integral[Passenger]"
val thisOneShouldNotHaveUnderscoreInGetter = testLines.next
thisOneShouldNotHaveUnderscoreInGetter.trim shouldBe ".extract(_.getPClass.toIntegral)"
}

it should "work with autogeneration" in {
val sut = new Sut
val result = sut.run(
Expand Down Expand Up @@ -147,7 +177,6 @@ class CliCodeGenerationTest extends CliTestBase {
result.err should include("Bad data file")
val folder = new File(ProjectName.toLowerCase)
folder.exists() shouldBe false

}

it should "complain properly if neither avro nor auto is specified" in {
Expand Down
13 changes: 8 additions & 5 deletions cli/src/test/scala/com/salesforce/op/cli/CliTestBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import com.salesforce.op.OpWorkflowRunType
import com.salesforce.op.test.TestCommon
import org.scalactic.source
import org.scalatest.{Assertion, Assertions, BeforeAndAfter, FlatSpec}
import org.slf4j.LoggerFactory
import org.slf4j.{Logger, LoggerFactory}

import scala.io.Source
import scala.language.postfixOps
Expand All @@ -49,7 +49,8 @@ class CliTestBase extends FlatSpec with TestCommon with Assertions with BeforeAn
CommandParser.AUTO_ENABLED = true

protected val ProjectName = "CliGeneratedTestProject"
val log = LoggerFactory.getLogger("cli-test")
protected val projectFolder: String = ProjectName.toLowerCase
val log: Logger = LoggerFactory.getLogger("cli-test")

trait Outcome
case class Crashed(msg: String, code: Int) extends Throwable with Outcome {
Expand Down Expand Up @@ -107,9 +108,9 @@ class CliTestBase extends FlatSpec with TestCommon with Assertions with BeforeAn

after { new Sut().delete(new File(ProjectName)) }

val expectedSourceFiles = "Features.scala" :: s"$ProjectName.scala"::Nil
val expectedSourceFiles: List[String] = "Features.scala" :: s"$ProjectName.scala"::Nil

val projectDir = ProjectName.toLowerCase
val projectDir: String = ProjectName.toLowerCase

def checkAvroFile(source: File): Unit = {
val avroFile = Paths.get(projectDir, "src", "main", "avro", source.getName).toFile
Expand Down Expand Up @@ -144,13 +145,15 @@ class CliTestBase extends FlatSpec with TestCommon with Assertions with BeforeAn
}

protected lazy val TestAvsc: String = findFile("test-data/PassengerDataAll.avsc")
protected lazy val TestAvscWithUnderscores: String = findFile("test-data/PassengerDataAll_.avsc")
protected lazy val TestCsvHeadless: String = findFile("test-data/PassengerDataAll.csv")
protected lazy val TestSmallCsvWithHeaders: String = findFile("test-data/PassengerDataWithHeader.csv")
protected lazy val TestBigCsvWithHeaders: String = findFile("test-data/PassengerDataAllWithHeader.csv")
protected lazy val AvcsSchema: String = findFile("templates/simple/src/main/avro/Passenger.avsc")
protected lazy val AnswersFile: String = findFile("cli/passengers.answers")
protected lazy val AnswersFileWithUnderscores: String = findFile("cli/passengers_.answers")

protected def appRuntimeArgs(runType: OpWorkflowRunType) =
protected def appRuntimeArgs(runType: OpWorkflowRunType): String =
s"--run-type=${runType.toString.toLowerCase} --model-location=/tmp/titanic-model " +
s"--read-location Passenger=$TestCsvHeadless"
}
3 changes: 2 additions & 1 deletion gradlew
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`

# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS=""
# the one below is good for using with newer Java versions that don't have the options gradle tries to pass.
DEFAULT_JVM_OPTS="-XX:+IgnoreUnrecognizedVMOptions"

# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"
Expand Down
43 changes: 43 additions & 0 deletions test-data/PassengerDataAll_.avsc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
{
"type" : "record",
"name" : "Passenger",
"namespace" : "com.salesforce.app.schema",
"fields" : [ {
"name" : "PassengerId",
"type" : [ "int", "null" ]
}, {
"name" : "Survived",
"type" : "int",
"default": 0
}, {
"name" : "P_class",
"type" : [ "int", "null" ]
}, {
"name" : "Name",
"type" : [ "string", "null" ]
}, {
"name" : "Sex",
"type" : [ "string", "null" ]
}, {
"name" : "Age",
"type" : [ "double", "null" ]
}, {
"name" : "SibSp",
"type" : [ "int", "null" ]
}, {
"name" : "Parch",
"type" : [ "int", "null" ]
}, {
"name" : "Ticket",
"type" : [ "string", "null" ]
}, {
"name" : "Fare",
"type" : [ "double", "null" ]
}, {
"name" : "Cabin",
"type" : [ "string", "null" ]
}, {
"name" : "Embarked",
"type" : [ "string", "null" ]
} ]
}

0 comments on commit c0ed618

Please sign in to comment.