Skip to content

Commit

Permalink
Fix ROC Component (#559)
Browse files Browse the repository at this point in the history
* Fix ROC component.

* fix ROC component.

* Follow up on CR comments.
  • Loading branch information
qimingj authored and k8s-ci-robot committed Dec 18, 2018
1 parent 302e93c commit a23abf8
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions components/local/roc/src/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,29 @@
def main(argv=None):
parser = argparse.ArgumentParser(description='ML Trainer')
parser.add_argument('--predictions', type=str, help='GCS path of prediction file pattern.')
parser.add_argument('--trueclass', type=str, help='The name of the class as true value.')
parser.add_argument('--trueclass', type=str, default='true',
help='The name of the class as true value. If missing, assuming it is ' +
'binary classification and default to "true".')
parser.add_argument('--true_score_column', type=str, default='true',
help='The name of the column for positive prob. If missing, assuming it is ' +
'binary classification and defaults to "true".')
parser.add_argument('--target_lambda', type=str,
help='a lambda function as a string to determine positive or negative.' +
'For example, "lambda x: x[\'a\'] and x[\'b\']". If missing, ' +
'trueclass must be set and input must have a "target" column.')
'input must have a "target" column.')
parser.add_argument('--output', type=str, help='GCS path of the output directory.')
args = parser.parse_args()

if not args.target_lambda and not args.trueclass:
raise ValueError('Either target_lambda or trueclass must be set.')

schema_file = os.path.join(os.path.dirname(args.predictions), 'schema.json')
schema = json.loads(file_io.read_file_to_string(schema_file))
names = [x['name'] for x in schema]

if not args.target_lambda and 'target' not in names:
raise ValueError('There is no "target" column, and target_lambda is not provided.')

if args.true_score_column not in names:
raise ValueError('Cannot find column name "%s"' % args.true_score_column)

dfs = []
files = file_io.get_matching_files(args.predictions)
for file in files:
Expand All @@ -57,8 +66,8 @@ def main(argv=None):
df['target'] = df.apply(eval(args.target_lambda), axis=1)
else:
df['target'] = df['target'].apply(lambda x: 1 if x == args.trueclass else 0)
fpr, tpr, thresholds = roc_curve(df['target'], df[args.trueclass])
roc_auc = roc_auc_score(df['target'], df[args.trueclass])
fpr, tpr, thresholds = roc_curve(df['target'], df[args.true_score_column])
roc_auc = roc_auc_score(df['target'], df[args.true_score_column])
df_roc = pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'thresholds': thresholds})
roc_file = os.path.join(args.output, 'roc.csv')
with file_io.FileIO(roc_file, 'w') as f:
Expand Down

0 comments on commit a23abf8

Please sign in to comment.