diff --git a/anomalib/data/utils/split.py b/anomalib/data/utils/split.py index 8c63497a41..ef12ae9773 100644 --- a/anomalib/data/utils/split.py +++ b/anomalib/data/utils/split.py @@ -49,7 +49,7 @@ def split_normal_images_in_train_set( DataFrame: Output dataframe where the part of the training set is assigned to test set. """ - if seed > 0: + if seed >= 0: random.seed(seed) normal_train_image_indices = samples.index[(samples.split == "train") & (samples.label == normal_label)].to_list() @@ -74,7 +74,7 @@ def create_validation_set_from_test_set(samples: DataFrame, seed: int = 0, norma normal_label (str): Name of the normal label. For MVTec AD, for instance, this is normal_label. """ - if seed > 0: + if seed >= 0: random.seed(seed) # Split normal images.