diff --git a/classifier.py b/classifier.py index 7b758e9..8a0affd 100644 --- a/classifier.py +++ b/classifier.py @@ -362,8 +362,8 @@ def get_args(): dev='data/ids-sst-dev.csv', test='data/ids-sst-test-student.csv', option=args.option, - dev_out = 'predictions/sst-dev-out.csv', - test_out = 'predictions/sst-test-out.csv' + dev_out = 'predictions/'+args.option+'-sst-dev-out.csv', + test_out = 'predictions/'+args.option+'-sst-test-out.csv' ) train(config) @@ -383,8 +383,8 @@ def get_args(): dev='data/ids-cfimdb-dev.csv', test='data/ids-cfimdb-test-student.csv', option=args.option, - dev_out = 'predictions/cfimdb-dev-out.csv', - test_out = 'predictions/cfimdb-test-out.csv' + dev_out = 'predictions/'+args.option+'-cfimdb-dev-out.csv', + test_out = 'predictions/'+args.option+'-cfimdb-test-out.csv' ) train(config)