Skip to content

Commit

Permalink
Merge branch 'master' of github.com:bagustris/nkululeko
Browse files Browse the repository at this point in the history
  • Loading branch information
bagustris committed Apr 26, 2024
2 parents dddf7d2 + 44e7c4b commit fb0e000
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 24 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Changelog
=========

Version 0.83.1
--------------
* add test module to nkuluflag

Version 0.83.0
--------------
* test module now prints out reports
Expand Down
2 changes: 1 addition & 1 deletion nkululeko/constants.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
VERSION="0.83.0"
VERSION="0.83.1"
SAMPLING_RATE = 16000
3 changes: 2 additions & 1 deletion nkululeko/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,8 @@ def predict_test_and_save(self, result_name):
test_predictor = TestPredictor(
model, self.df_test, self.label_encoder, result_name
)
test_predictor.predict_and_store()
result = test_predictor.predict_and_store()
return result

def load(self, filename):
f = open(filename, "rb")
Expand Down
25 changes: 19 additions & 6 deletions nkululeko/nkuluflag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import configparser
import os
import os.path
import sys

from nkululeko.nkululeko import doit as nkulu
from nkululeko.test import do_it as test_mod


def do_it(src_dir):
def doit(cla):
parser = argparse.ArgumentParser(description="Call the nkululeko framework.")
parser.add_argument("--config", help="The base configuration")
parser.add_argument("--mod", default="nkulu", help="Which nkululeko module to call")
parser.add_argument("--data", help="The databases", nargs="*", action="append")
parser.add_argument(
"--label", nargs="*", help="The labels for the target", action="append"
Expand All @@ -25,20 +28,23 @@ def do_it(src_dir):
parser.add_argument("--model", default="xgb", help="The model type")
parser.add_argument("--feat", default="['os']", help="The feature type")
parser.add_argument("--set", help="The opensmile set")
parser.add_argument("--with_os", help="To add os features")
parser.add_argument("--target", help="The target designation")
parser.add_argument("--epochs", help="The number of epochs")
parser.add_argument("--runs", help="The number of runs")
parser.add_argument("--learning_rate", help="The learning rate")
parser.add_argument("--drop", help="The dropout rate [0:1]")

args = parser.parse_args()
args = parser.parse_args(cla)

if args.config is not None:
config_file = args.config
else:
print("ERROR: need config file")
quit(-1)

if args.mod is not None:
nkulu_mod = args.mod

# test if config is there
if not os.path.isfile(config_file):
print(f"ERROR: no such file {config_file}")
Expand Down Expand Up @@ -86,10 +92,17 @@ def do_it(src_dir):
with open(tmp_config, "w") as tmp_file:
config.write(tmp_file)

result, last_epoch = nkulu(tmp_config)
result, last_epoch = 0, 0
if nkulu_mod == "nkulu":
result, last_epoch = nkulu(tmp_config)
elif nkulu_mod == "test":
result, last_epoch = test_mod(tmp_config, "test_results.csv")
else:
print(f"ERROR: unknown module: {nkulu_mod}, should be [nkulu | test]")
return result, last_epoch


if __name__ == "__main__":
cwd = os.path.dirname(os.path.abspath(__file__))
do_it(cwd) # sys.argv[1])
cla = sys.argv
cla.pop(0)
doit(cla) # sys.argv[1])
35 changes: 20 additions & 15 deletions nkululeko/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,7 @@
from nkululeko.utils.util import Util


def main(src_dir):
parser = argparse.ArgumentParser(
description="Call the nkululeko TEST framework.")
parser.add_argument("--config", default="exp.ini",
help="The base configuration")
parser.add_argument(
"--outfile",
default="my_results.csv",
help="File name to store the predictions",
)

args = parser.parse_args()

config_file = args.config
def do_it(config_file, outfile):

# test if the configuration file exists
if not os.path.isfile(config_file):
Expand All @@ -48,10 +35,28 @@ def main(src_dir):
expr.load(f"{util.get_save_name()}")
expr.fill_tests()
expr.extract_test_feats()
expr.predict_test_and_save(args.outfile)
result = expr.predict_test_and_save(outfile)

print("DONE")

return result, 0


def main(src_dir):
parser = argparse.ArgumentParser(description="Call the nkululeko TEST framework.")
parser.add_argument("--config", default="exp.ini", help="The base configuration")
parser.add_argument(
"--outfile",
default="my_results.csv",
help="File name to store the predictions",
)
args = parser.parse_args()
if args.config is not None:
config_file = args.config
else:
config_file = f"{src_dir}/exp.ini"
do_it(config_file, args.outfile)


if __name__ == "__main__":
cwd = os.path.dirname(os.path.abspath(__file__))
Expand Down
3 changes: 3 additions & 0 deletions nkululeko/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self, model, orig_df, labenc, name):

def predict_and_store(self):
label_data = self.util.config_val("DATA", "label_data", False)
result = 0
if label_data:
data = Dataset(label_data)
data.load()
Expand Down Expand Up @@ -57,6 +58,7 @@ def predict_and_store(self):
test_dbs_string = "_".join(test_dbs)
predictions = self.model.get_predictions()
report = self.model.predict()
result = report.result.get_result()
report.set_filename_add(f"test-{test_dbs_string}")
self.util.print_best_results([report])
report.plot_confmatrix(self.util.get_plot_name(), 0)
Expand All @@ -74,3 +76,4 @@ def predict_and_store(self):
df = df.rename(columns={"class_label": target})
df.to_csv(self.name)
self.util.debug(f"results stored in {self.name}")
return result
2 changes: 1 addition & 1 deletion tests/exp_emodb_audmodel_xgb.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ save = True
databases = ['emodb']
emodb = ./data/emodb/emodb
emodb.split_strategy = random
emodb.limit_samples = 50
emodb.limit_samples = 200
emodb.mapping = {'anger':'angry', 'happiness':'happy', 'sadness':'sad', 'neutral':'neutral'}
labels = ['angry', 'happy', 'neutral', 'sad']
target = emotion
Expand Down

0 comments on commit fb0e000

Please sign in to comment.