-
Notifications
You must be signed in to change notification settings - Fork 685
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make benchmark a standalone module (#1043)
Change-Id: I3b169b2c27239ae57be09ca63cb24e032d32e4f6
- Loading branch information
Showing
14 changed files
with
324 additions
and
234 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
plugins { | ||
id 'application' | ||
} | ||
|
||
dependencies { | ||
implementation "commons-cli:commons-cli:${commons_cli_version}" | ||
implementation "org.apache.logging.log4j:log4j-slf4j-impl:${log4j_slf4j_version}" | ||
implementation project(":model-zoo") | ||
|
||
runtimeOnly project(":pytorch:pytorch-model-zoo") | ||
runtimeOnly "ai.djl.pytorch:pytorch-native-auto:1.8.1" | ||
|
||
// javacpp bug fix https://github.com/bytedeco/javacpp/commit/7f27899578dfa18e22738a3dd49701e1806b464a | ||
runtimeOnly "org.bytedeco:javacpp:1.5.6-SNAPSHOT" | ||
runtimeOnly(project(":tensorflow:tensorflow-model-zoo")) { | ||
exclude group: "org.bytedeco", module: "javacpp" | ||
} | ||
runtimeOnly "ai.djl.tensorflow:tensorflow-native-auto:${tensorflow_version}" | ||
|
||
runtimeOnly project(":mxnet:mxnet-model-zoo") | ||
runtimeOnly "ai.djl.mxnet:mxnet-native-auto:${mxnet_version}" | ||
|
||
runtimeOnly project(":tflite:tflite-engine") | ||
runtimeOnly "ai.djl.tflite:tflite-native-auto:${tflite_version}" | ||
|
||
runtimeOnly project(":paddlepaddle:paddlepaddle-model-zoo") | ||
runtimeOnly "ai.djl.paddlepaddle:paddlepaddle-native-auto:${paddlepaddle_version}" | ||
|
||
// onnxruntime requires user install libgomp.so.1 manually, exclude from default dependency | ||
runtimeOnly project(":onnxruntime:onnxruntime-engine") | ||
|
||
runtimeOnly project(":dlr:dlr-engine") | ||
runtimeOnly "ai.djl.dlr:dlr-native-auto:${dlr_version}" | ||
|
||
runtimeOnly(project(":ml:xgboost")) { | ||
exclude group: "ml.dmlc", module: "xgboost4j_2.12" | ||
} | ||
|
||
testImplementation("org.testng:testng:${testng_version}") { | ||
exclude group: "junit", module: "junit" | ||
} | ||
} | ||
|
||
application { | ||
mainClassName = System.getProperty("main", "ai.djl.benchmark.Benchmark") | ||
} | ||
|
||
run { | ||
environment("TF_CPP_MIN_LOG_LEVEL", "1") // turn off TensorFlow print out | ||
systemProperties System.getProperties() | ||
systemProperties.remove("user.dir") | ||
systemProperty("file.encoding", "UTF-8") | ||
} | ||
|
||
task benchmark(type: JavaExec) { | ||
environment("TF_CPP_MIN_LOG_LEVEL", "1") // turn off TensorFlow print out | ||
List<String> arguments = gradle.startParameter["taskRequests"]["args"].getAt(0) | ||
for (String argument : arguments) { | ||
if (argument.trim().startsWith("--args")) { | ||
String[] line = argument.split("=", 2) | ||
if (line.length == 2) { | ||
line = line[1].split(" ") | ||
if (line.contains("-t")) { | ||
if (System.getProperty("ai.djl.default_engine") == "TensorFlow") { | ||
environment("OMP_NUM_THREADS", "1") | ||
environment("TF_NUM_INTRAOP_THREADS", "1") | ||
} else { | ||
environment("MXNET_ENGINE_TYPE", "NaiveEngine") | ||
environment("OMP_NUM_THREADS", "1") | ||
} | ||
} | ||
break | ||
} | ||
} | ||
} | ||
|
||
systemProperties System.getProperties() | ||
systemProperties.remove("user.dir") | ||
systemProperty("file.encoding", "UTF-8") | ||
classpath = sourceSets.main.runtimeClasspath | ||
// restrict the jvm heap size for better monitoring benchmark | ||
jvmArgs = ["-Xmx2g"] | ||
if (Boolean.getBoolean("loggc")) { | ||
if (JavaVersion.current() == JavaVersion.VERSION_1_8) { | ||
jvmArgs += ["-XX:+PrintGCTimeStamps", "-Xloggc:build/gc.log"] | ||
} else { | ||
jvmArgs += ["-Xlog:gc*=debug:file=build/gc.log"] | ||
} | ||
} | ||
main = "ai.djl.benchmark.Benchmark" | ||
} | ||
|
||
startScripts { | ||
defaultJvmOpts = [] | ||
doLast { | ||
String replacement = 'CLASSPATH=\\$APP_HOME/lib/*\n\n' + | ||
'if [[ "\\$*" == *-t* || "\\$*" == *--threads* ]]\n' + | ||
'then\n' + | ||
' export TF_CPP_MIN_LOG_LEVEL=1\n' + | ||
' export MXNET_ENGINE_TYPE=NaiveEngine\n' + | ||
' export OMP_NUM_THREADS=1\n' + | ||
' export TF_NUM_INTRAOP_THREADS=1\n' + | ||
'fi' | ||
|
||
String text = unixScript.text.replaceAll('CLASSPATH=\\$APP_HOME/lib/.*', replacement) | ||
unixScript.text = text | ||
} | ||
} | ||
|
||
tasks.distZip.enabled = false |
Oops, something went wrong.