diff --git a/CHANGELOG.md b/CHANGELOG.md index d9cdfcf8..2798c806 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,10 @@ Changelog ========= +Version 0.83.1 +-------------- +* add test module to nkuluflag + Version 0.83.0 -------------- * test module now prints out reports diff --git a/nkululeko/constants.py b/nkululeko/constants.py index f47062a9..b2246bd4 100644 --- a/nkululeko/constants.py +++ b/nkululeko/constants.py @@ -1,2 +1,2 @@ -VERSION="0.83.0" +VERSION="0.83.1" SAMPLING_RATE = 16000 diff --git a/nkululeko/experiment.py b/nkululeko/experiment.py index 899f578b..12353895 100644 --- a/nkululeko/experiment.py +++ b/nkululeko/experiment.py @@ -675,7 +675,8 @@ def predict_test_and_save(self, result_name): test_predictor = TestPredictor( model, self.df_test, self.label_encoder, result_name ) - test_predictor.predict_and_store() + result = test_predictor.predict_and_store() + return result def load(self, filename): f = open(filename, "rb") diff --git a/nkululeko/nkuluflag.py b/nkululeko/nkuluflag.py index 5603bcff..a827b3e6 100644 --- a/nkululeko/nkuluflag.py +++ b/nkululeko/nkuluflag.py @@ -2,13 +2,16 @@ import configparser import os import os.path +import sys from nkululeko.nkululeko import doit as nkulu +from nkululeko.test import do_it as test_mod -def do_it(src_dir): +def doit(cla): parser = argparse.ArgumentParser(description="Call the nkululeko framework.") parser.add_argument("--config", help="The base configuration") + parser.add_argument("--mod", default="nkulu", help="Which nkululeko module to call") parser.add_argument("--data", help="The databases", nargs="*", action="append") parser.add_argument( "--label", nargs="*", help="The labels for the target", action="append" @@ -25,20 +28,23 @@ def do_it(src_dir): parser.add_argument("--model", default="xgb", help="The model type") parser.add_argument("--feat", default="['os']", help="The feature type") parser.add_argument("--set", help="The opensmile set") - parser.add_argument("--with_os", help="To add os features") parser.add_argument("--target", help="The target designation") parser.add_argument("--epochs", help="The number of epochs") parser.add_argument("--runs", help="The number of runs") parser.add_argument("--learning_rate", help="The learning rate") parser.add_argument("--drop", help="The dropout rate [0:1]") - args = parser.parse_args() + args = parser.parse_args(cla) if args.config is not None: config_file = args.config else: print("ERROR: need config file") quit(-1) + + if args.mod is not None: + nkulu_mod = args.mod + # test if config is there if not os.path.isfile(config_file): print(f"ERROR: no such file {config_file}") @@ -86,10 +92,17 @@ def do_it(src_dir): with open(tmp_config, "w") as tmp_file: config.write(tmp_file) - result, last_epoch = nkulu(tmp_config) + result, last_epoch = 0, 0 + if nkulu_mod == "nkulu": + result, last_epoch = nkulu(tmp_config) + elif nkulu_mod == "test": + result, last_epoch = test_mod(tmp_config, "test_results.csv") + else: + print(f"ERROR: unknown module: {nkulu_mod}, should be [nkulu | test]") return result, last_epoch if __name__ == "__main__": - cwd = os.path.dirname(os.path.abspath(__file__)) - do_it(cwd) # sys.argv[1]) + cla = sys.argv + cla.pop(0) + doit(cla) # sys.argv[1]) diff --git a/nkululeko/test.py b/nkululeko/test.py index ac1a781c..06462d77 100644 --- a/nkululeko/test.py +++ b/nkululeko/test.py @@ -10,20 +10,7 @@ from nkululeko.utils.util import Util -def main(src_dir): - parser = argparse.ArgumentParser( - description="Call the nkululeko TEST framework.") - parser.add_argument("--config", default="exp.ini", - help="The base configuration") - parser.add_argument( - "--outfile", - default="my_results.csv", - help="File name to store the predictions", - ) - - args = parser.parse_args() - - config_file = args.config +def do_it(config_file, outfile): # test if the configuration file exists if not os.path.isfile(config_file): @@ -48,10 +35,28 @@ def main(src_dir): expr.load(f"{util.get_save_name()}") expr.fill_tests() expr.extract_test_feats() - expr.predict_test_and_save(args.outfile) + result = expr.predict_test_and_save(outfile) print("DONE") + return result, 0 + + +def main(src_dir): + parser = argparse.ArgumentParser(description="Call the nkululeko TEST framework.") + parser.add_argument("--config", default="exp.ini", help="The base configuration") + parser.add_argument( + "--outfile", + default="my_results.csv", + help="File name to store the predictions", + ) + args = parser.parse_args() + if args.config is not None: + config_file = args.config + else: + config_file = f"{src_dir}/exp.ini" + do_it(config_file, args.outfile) + if __name__ == "__main__": cwd = os.path.dirname(os.path.abspath(__file__)) diff --git a/nkululeko/test_predictor.py b/nkululeko/test_predictor.py index 1579cedb..0cfb68a5 100644 --- a/nkululeko/test_predictor.py +++ b/nkululeko/test_predictor.py @@ -29,6 +29,7 @@ def __init__(self, model, orig_df, labenc, name): def predict_and_store(self): label_data = self.util.config_val("DATA", "label_data", False) + result = 0 if label_data: data = Dataset(label_data) data.load() @@ -57,6 +58,7 @@ def predict_and_store(self): test_dbs_string = "_".join(test_dbs) predictions = self.model.get_predictions() report = self.model.predict() + result = report.result.get_result() report.set_filename_add(f"test-{test_dbs_string}") self.util.print_best_results([report]) report.plot_confmatrix(self.util.get_plot_name(), 0) @@ -74,3 +76,4 @@ def predict_and_store(self): df = df.rename(columns={"class_label": target}) df.to_csv(self.name) self.util.debug(f"results stored in {self.name}") + return result diff --git a/tests/exp_emodb_audmodel_xgb.ini b/tests/exp_emodb_audmodel_xgb.ini index ee6cd684..e4ae2a9c 100644 --- a/tests/exp_emodb_audmodel_xgb.ini +++ b/tests/exp_emodb_audmodel_xgb.ini @@ -8,7 +8,7 @@ save = True databases = ['emodb'] emodb = ./data/emodb/emodb emodb.split_strategy = random -emodb.limit_samples = 50 +emodb.limit_samples = 200 emodb.mapping = {'anger':'angry', 'happiness':'happy', 'sadness':'sad', 'neutral':'neutral'} labels = ['angry', 'happy', 'neutral', 'sad'] target = emotion