-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtune_plot.R
executable file
·29 lines (23 loc) · 1.45 KB
/
tune_plot.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#!/usr/bin/env Rscript
# Visualize the output of the hyperparameter tuning step in train_RF.R as a
# heatmap of the mean test F-beta score over the (mtry, min.node.size) grid.
#
# Usage: tune_plot.R <tuning_results.tsv> [output.png]
#
# param1: Path to a TSV containing the results of tuning the hyperparameters.
#         Each row is one cross-validation instance and must contain at least
#         the columns `mtry` and `min.node.size` (hyperparameter values) and
#         `fbeta.test.mean` (mean test F-beta score).
# param2 (optional): Path of the image file to write; ggplot2::ggsave() picks
#         the graphics device from the file extension. If omitted, the plot is
#         written next to param1 with its extension replaced by ".png".

# Parse and validate arguments ----
args <- commandArgs(trailingOnly = TRUE)
if (length(args) < 1L || length(args) > 2L) {
  stop("usage: tune_plot.R <tuning_results.tsv> [output.png]", call. = FALSE)
}
results_path <- args[[1]]
# args[2] is NA when param2 is omitted, which ggsave() cannot handle, so fall
# back to a path derived from the input file.
out_path <- if (length(args) >= 2L) {
  args[[2]]
} else {
  paste0(tools::file_path_sans_ext(results_path), ".png")
}
if (!file.exists(results_path)) {
  stop("tuning results file not found: ", results_path, call. = FALSE)
}

# Load libraries ----
suppressMessages(library(mlr))
library(BBmisc)
library(ggplot2)
library(R.devices)

# Load data.frame ----
# Status goes to stderr via message() so it never mixes with program output.
message("loading results of hyperparam tuning into R")
tuning_results <- read.table(
  results_path,
  header = TRUE,
  sep = "\t",
  na.strings = c("NA", ".", "na", "N/A"),
  row.names = NULL
)

hyperparams <- c("mtry", "min.node.size")
measure <- "fbeta.test.mean"
missing_cols <- setdiff(c(hyperparams, measure), names(tuning_results))
if (length(missing_cols) > 0L) {
  stop(
    "tuning results are missing required column(s): ",
    paste(missing_cols, collapse = ", "),
    call. = FALSE
  )
}

# Plot ----
# Hand-build the "HyperParsEffectData" S3 object that mlr's
# plotHyperParsEffect() consumes, from the table read above.
effect_data <- makeS3Obj(
  "HyperParsEffectData",
  data = tuning_results,
  measures = measure,
  hyperparameters = hyperparams,
  partial = FALSE,
  nested = FALSE
)
plt <- plotHyperParsEffect(
  effect_data,
  x = hyperparams[[1]],
  y = hyperparams[[2]],
  z = measure,
  plot.type = "heatmap",
  show.experiments = TRUE
)

# Optional: diverging fill scale centred on the midpoint of the observed
# scores. Uncomment to replace the default gradient.
# score_range <- range(tuning_results[[measure]], na.rm = TRUE)
# plt <- plt + scale_fill_gradient2(
#   breaks = seq(score_range[[1]], score_range[[2]], length.out = 5),
#   low = "blue", mid = "white", high = "red", midpoint = mean(score_range)
# )

# Save ----
# suppressGraphics() (R.devices) evaluates the call while suppressing other
# graphics output; ggsave() still writes the image to out_path.
suppressGraphics(ggsave(out_path, plt))