diff --git a/src/sdk/pynni/nni/curvefitting_assessor/curvefitting_assessor.py b/src/sdk/pynni/nni/curvefitting_assessor/curvefitting_assessor.py index 8d13dd6b1d..22fe07a543 100644 --- a/src/sdk/pynni/nni/curvefitting_assessor/curvefitting_assessor.py +++ b/src/sdk/pynni/nni/curvefitting_assessor/curvefitting_assessor.py @@ -57,8 +57,8 @@ def __init__(self, epoch_num=20, optimize_mode='maximize', start_step=6, thresho self.threshold = threshold # Record the number of gap self.gap = gap - # Record the number of times of judgments - self.judgment_num = 0 + # Record the number of intermediate result in the lastest judgment + self.last_judgment_num = dict() # Record the best performance self.set_best_performance = False self.completed_best_performance = None @@ -112,9 +112,10 @@ def assess_trial(self, trial_job_id, trial_history): curr_step = len(trial_history) if curr_step < self.start_step: return AssessResult.Good - if (curr_step - self.start_step) // self.gap <= self.judgment_num: + + if trial_job_id in self.last_judgment_num.keys() and curr_step - self.last_judgment_num[trial_job_id] < self.gap: return AssessResult.Good - self.judgment_num = (curr_step - self.start_step) // self.gap + self.last_judgment_num[trial_job_id] = curr_step try: start_time = datetime.datetime.now()