From ddc431eabdbdab2997c625c53fe15d73cf9466d7 Mon Sep 17 00:00:00 2001 From: matthiasblondeel <32500173+matthiasblondeel@users.noreply.github.com> Date: Thu, 27 Feb 2025 07:56:59 -0800 Subject: [PATCH] [Enhancements]: Re-arch inputs and tsv generations and few more updates (#3876) * More component updates * Updating training component for MI2 --------- Co-authored-by: Matthias Blondeel --- .../finetune/medimage_adapter/spec.yaml | 16 +- .../finetune/medimage_insight/spec.yaml | 4 +- .../medimage_insight_ft/spec.yaml | 28 +- .../preprocess/image_embedding/spec.yaml | 37 +-- .../medimage_train.py | 77 ++++- .../training.py | 314 ++++++++++++++++++ .../medimage_datapreprocess.py | 129 +++---- .../medimage_embedding_finetune.py | 36 +- 8 files changed, 473 insertions(+), 168 deletions(-) create mode 100644 assets/training/finetune_acft_image/src/medimage_insight_adapter_finetune/training.py diff --git a/assets/training/finetune_acft_image/components/finetune/medimage_adapter/spec.yaml b/assets/training/finetune_acft_image/components/finetune/medimage_adapter/spec.yaml index ac23b8d77d..c4236fa280 100644 --- a/assets/training/finetune_acft_image/components/finetune/medimage_adapter/spec.yaml +++ b/assets/training/finetune_acft_image/components/finetune/medimage_adapter/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: medimgage_adapter_finetune -version: 0.0.4 +version: 0.0.7 type: command is_deterministic: True @@ -29,6 +29,18 @@ inputs: description: Path to the validation data file. mode: ro_mount + validation_text_tsv: + type: uri_file + optional: false + description: Path to the evaluation text TSV file. + mode: ro_mount + + train_text_tsv: + type: uri_file + optional: false + description: Path to the text TSV file. + mode: ro_mount + train_dataloader_batch_size: type: integer min: 1 @@ -101,6 +113,8 @@ command: >- --task_name "AdapterTrain" --train_data_path "${{inputs.train_data_path}}" --validation_data_path "${{inputs.validation_data_path}}" + --validation_text_tsv "${{inputs.validation_text_tsv}}" + --train_text_tsv "${{inputs.train_text_tsv}}" --label_file "${{inputs.label_file}}" $[[--train_dataloader_batch_size "${{inputs.train_dataloader_batch_size}}"]] $[[--validation_dataloader_batch_size "${{inputs.validation_dataloader_batch_size}}"]] diff --git a/assets/training/finetune_acft_image/components/finetune/medimage_insight/spec.yaml b/assets/training/finetune_acft_image/components/finetune/medimage_insight/spec.yaml index bed6a9dfca..954ebe5221 100644 --- a/assets/training/finetune_acft_image/components/finetune/medimage_insight/spec.yaml +++ b/assets/training/finetune_acft_image/components/finetune/medimage_insight/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: medimgage_embedding_finetune -version: 0.0.23 +version: 0.0.27 type: command @@ -9,7 +9,7 @@ is_deterministic: True display_name: Medical Image Insight Embedding Finetune description: Component to finetune the model using the medical image data -environment : azureml://registries/mablonde-registry-101/environments/acpt-medimage-embedding/versions/15 +environment : azureml://registries/mablonde-registry-101/environments/acpt-medimage-embedding/versions/16 code: ../../../src/medimage_insight_embedding_finetune distribution: diff --git a/assets/training/finetune_acft_image/components/pipeline_components/medimage_insight_ft/spec.yaml b/assets/training/finetune_acft_image/components/pipeline_components/medimage_insight_ft/spec.yaml index f7ad90ae68..55b98efe65 100644 --- a/assets/training/finetune_acft_image/components/pipeline_components/medimage_insight_ft/spec.yaml +++ b/assets/training/finetune_acft_image/components/pipeline_components/medimage_insight_ft/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json name: medimage_insight_ft_pipeline -version: 0.0.16 +version: 0.0.24 type: pipeline display_name: Medical Image Insight Embedding Generator and Classification Adapter Pipeline description: Pipeline Component to finetune Hugging Face pretrained models for chat completion task. The component supports optimizations such as LoRA, Deepspeed and ONNXRuntime for performance enhancement. See [docs](https://aka.ms/azureml/components/chat_completion_pipeline) to learn more. @@ -341,7 +341,7 @@ outputs: jobs: medical_image_embedding_model_finetune: type: command - component: azureml://registries/mablonde-registry-101/components/medimgage_embedding_finetune/versions/0.0.23 + component: azureml://registries/mablonde-registry-101/components/medimgage_embedding_finetune/versions/0.0.27 compute: '${{parent.inputs.compute_finetune}}' resources: instance_type: '${{parent.inputs.instance_type_finetune}}' @@ -374,27 +374,35 @@ jobs: outputs: save_dir: '${{parent.outputs.save_dir}}' mlflow_model_folder: '${{parent.outputs.mlflow_model_folder}}' - medical_image_embedding_datapreprocessing: + medical_image_embedding_datapreprocessing_train: type: command - component: azureml://registries/mablonde-registry-101/components/medical_image_embedding_datapreprocessing/versions/0.0.9 + component: azureml://registries/mablonde-registry-101/components/medical_image_embedding_datapreprocessing/versions/0.0.11 compute: '${{parent.inputs.compute_preprocess}}' resources: instance_type: '${{parent.inputs.instance_type_preprocess}}' inputs: mlflow_model_path: '${{parent.jobs.medical_image_embedding_model_finetune.outputs.mlflow_model_folder}}' - eval_image_tsv: '${{parent.inputs.eval_image_tsv}}' - eval_text_tsv: '${{parent.inputs.eval_text_tsv}}' image_tsv: '${{parent.inputs.image_tsv}}' - text_tsv: '${{parent.inputs.text_tsv}}' + medical_image_embedding_datapreprocessing_validation: + type: command + component: azureml://registries/mablonde-registry-101/components/medical_image_embedding_datapreprocessing/versions/0.0.11 + compute: '${{parent.inputs.compute_preprocess}}' + resources: + instance_type: '${{parent.inputs.instance_type_preprocess}}' + inputs: + mlflow_model_path: '${{parent.jobs.medical_image_embedding_model_finetune.outputs.mlflow_model_folder}}' + image_tsv: '${{parent.inputs.eval_image_tsv}}' medimgage_adapter_finetune: type: command - component: azureml://registries/mablonde-registry-101/components/medimgage_adapter_finetune/versions/0.0.4 + component: azureml://registries/mablonde-registry-101/components/medimgage_adapter_finetune/versions/0.0.7 compute: '${{parent.inputs.compute_finetune}}' resources: instance_type: '${{parent.inputs.instance_type_finetune}}' inputs: - train_data_path: '${{parent.jobs.medical_image_embedding_datapreprocessing.outputs.output_train_pkl}}' - validation_data_path: '${{parent.jobs.medical_image_embedding_datapreprocessing.outputs.output_validation_pkl}}' + train_data_path: '${{parent.jobs.medical_image_embedding_datapreprocessing_train.outputs.output_pkl}}' + validation_data_path: '${{parent.jobs.medical_image_embedding_datapreprocessing_validation.outputs.output_pkl}}' + train_text_tsv: '${{parent.inputs.text_tsv}}' + validation_text_tsv: '${{parent.inputs.eval_text_tsv}}' train_dataloader_batch_size: '${{parent.inputs.train_dataloader_batch_size}}' validation_dataloader_batch_size: '${{parent.inputs.validation_dataloader_batch_size}}' train_dataloader_workers: '${{parent.inputs.train_dataloader_workers}}' diff --git a/assets/training/finetune_acft_image/components/preprocess/image_embedding/spec.yaml b/assets/training/finetune_acft_image/components/preprocess/image_embedding/spec.yaml index b39081a23b..db82e7d861 100644 --- a/assets/training/finetune_acft_image/components/preprocess/image_embedding/spec.yaml +++ b/assets/training/finetune_acft_image/components/preprocess/image_embedding/spec.yaml @@ -1,6 +1,6 @@ $schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json name: medical_image_embedding_datapreprocessing -version: 0.0.9 +version: 0.0.11 type: command is_deterministic: True @@ -9,34 +9,16 @@ display_name: Embedding Generation for Medical Images description: To genrate embeddings for medical images. See [docs](https://aka.ms/azureml/components/medical_image_embedding_datapreprocessing) to learn more. #environment: azureml:/subscriptions/dbd697c3-ef40-488f-83e6-5ad4dfb78f9b/resourceGroups/rdondera/providers/Microsoft.MachineLearningServices/workspaces/validatr/environments/medimage-embedding-generation/versions/5 -environment: azureml://registries/models-staging/environments/medimage-embedding-generation/versions/5 +environment: azureml://registries/mablonde-registry-101/environments/medimage-embedding-generation/versions/7 code: ../../../src/medimage_insight_adapter_preprocess inputs: - eval_image_tsv: - type: uri_file - optional: false - description: Path to the evaluation image TSV file. - mode: ro_mount - - eval_text_tsv: - type: uri_file - optional: false - description: Path to the evaluation text TSV file. - mode: ro_mount - image_tsv: type: uri_file optional: false description: Path to the image TSV file. mode: ro_mount - text_tsv: - type: uri_file - optional: false - description: Path to the text TSV file. - mode: ro_mount - mlflow_model_path: type: uri_folder optional: false @@ -44,23 +26,14 @@ inputs: mode: ro_mount outputs: - output_train_pkl: + output_pkl: type: uri_folder description: Path to the output training PKL file. mode: rw_mount - output_validation_pkl: - type: uri_folder - description: Path to the output validation PKL file. - mode: rw_mount - command: >- python medimage_datapreprocess.py - --task_name "MedEmbedding" - --eval_image_tsv "${{inputs.eval_image_tsv}}" - --eval_text_tsv "${{inputs.eval_text_tsv}}" + --task_name "MedEmbedding" --image_tsv "${{inputs.image_tsv}}" - --text_tsv "${{inputs.text_tsv}}" - --output_train_pkl "${{outputs.output_train_pkl}}" - --output_validation_pkl "${{outputs.output_validation_pkl}}" + --output_pkl "${{outputs.output_pkl}}" --mlflow_model_path "${{inputs.mlflow_model_path}}" diff --git a/assets/training/finetune_acft_image/src/medimage_insight_adapter_finetune/medimage_train.py b/assets/training/finetune_acft_image/src/medimage_insight_adapter_finetune/medimage_train.py index e6982af493..7e064300a7 100644 --- a/assets/training/finetune_acft_image/src/medimage_insight_adapter_finetune/medimage_train.py +++ b/assets/training/finetune_acft_image/src/medimage_insight_adapter_finetune/medimage_train.py @@ -1,31 +1,24 @@ import argparse +import json from azureml.acft.common_components import get_logger_app, set_logging_parameters, LoggingLiterals from azureml.acft.common_components.utils.error_handling.exceptions import ACFTValidationException from azureml.acft.common_components.utils.error_handling.error_definitions import ACFTUserError from azureml.acft.common_components.utils.error_handling.swallow_all_exceptions_decorator import ( swallow_all_exceptions, ) -from azureml._common._error_definition.azureml_error import AzureMLError from azureml.acft.contrib.hf import VERSION, PROJECT_NAME from azureml.acft.contrib.hf.nlp.constants.constants import LOGS_TO_BE_FILTERED_IN_APPINSIGHTS import pandas as pd import torch import os -from classification_demo.MedImageInsight import medimageinsight_package -from classification_demo.adaptor_training import training +import training import matplotlib.pyplot as plt -import SimpleITK as sitk -from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, f1_score - -# Suppress SimpleITK warnings -sitk.ProcessObject_SetGlobalWarningDisplay(False) COMPONENT_NAME = "ACFT-MedImage-Classification-Training" logger = get_logger_app("azureml.acft.contrib.hf.scripts.src.train.classification_adaptor_train") -TRAIN_EMBEDDING_FILE_NAME = "train_embeddings.pkl" -VALIDATION_EMBEDDING_FILE_NAME = "validation_embeddings.pkl" +EMBEDDING_FILE_NAME = "embeddings.pkl" def get_parser(): @@ -54,6 +47,18 @@ def get_parser(): required=True, help='The path to the validation data.' ) + parser.add_argument( + "--train_text_tsv", + type=str, + help="Path to evaluation text TSV file.", + required=True + ) + parser.add_argument( + "--validation_text_tsv", + type=str, + help="Path to training text TSV file.", + required=True + ) parser.add_argument( '--train_dataloader_batch_size', type=int, @@ -117,8 +122,7 @@ def get_parser(): return parser -def load_data(train_data_path: str, validation_data_path: str, train_file_name: str, - validation_file_name: str) -> tuple[pd.DataFrame, pd.DataFrame]: +def load_data(train_data_path: str, validation_data_path: str) -> tuple[pd.DataFrame, pd.DataFrame]: """ Load the training and validation data from the provided folder paths. @@ -132,12 +136,50 @@ def load_data(train_data_path: str, validation_data_path: str, train_file_name: tuple[pd.DataFrame, pd.DataFrame]: DataFrames containing the training and validation data. """ - train_data_file = os.path.join(train_data_path, train_file_name) - validation_data_file = os.path.join(validation_data_path, validation_file_name) + train_data_file = os.path.join(train_data_path, EMBEDDING_FILE_NAME) + validation_data_file = os.path.join(validation_data_path, EMBEDDING_FILE_NAME) train_data = pd.read_pickle(train_data_file) validation_data = pd.read_pickle(validation_data_file) return train_data, validation_data +def merge_data_with_text( + train_data: pd.DataFrame, + validation_data: pd.DataFrame, + train_text_tsv: str, + validation_text_tsv: str +) -> tuple[pd.DataFrame, pd.DataFrame]: + """ + Merge the training and validation data with the corresponding text data. + + Args: + train_data (pd.DataFrame): DataFrame containing the training data. + validation_data (pd.DataFrame): DataFrame containing the validation data. + train_text_tsv (str): Path to the TSV file containing training text data. + validation_text_tsv (str): Path to the TSV file containing validation text data. + + Returns: + tuple[pd.DataFrame, pd.DataFrame]: Merged DataFrames for training and validation data. + """ + train_text_df = pd.read_csv(train_text_tsv, sep="\t") + train_text_df.columns = ["Name", "classification_json"] + validation_text_df = pd.read_csv(validation_text_tsv, sep="\t") + validation_text_df.columns = ["Name", "classification_json"] + + def extract_label_from_json(json_str): + try: + json_obj = json.loads(json_str) + return json_obj.get("class_id", -1) + except json.JSONDecodeError: + logger.error("Failed to decode JSON from text column") + return -1 + + train_text_df["Label"] = train_text_df["classification_json"].apply(extract_label_from_json) + validation_text_df["Label"] = validation_text_df["classification_json"].apply(extract_label_from_json) + + train_data = pd.merge(train_data, train_text_df, on="Name")[["Name", "features", "Label"]] + validation_data = pd.merge(validation_data, validation_text_df, on="Name")[["Name", "features", "Label"]] + + return train_data, validation_data def initialize_model(args: argparse.Namespace) -> torch.nn.Module: """ @@ -255,10 +297,11 @@ def main(): }, azureml_pkg_denylist_logging_patterns=LOGS_TO_BE_FILTERED_IN_APPINSIGHTS, ) - train_data, validation_data = load_data(args.train_data_path, args.validation_data_path, - TRAIN_EMBEDDING_FILE_NAME, VALIDATION_EMBEDDING_FILE_NAME) + train_data, validation_data = load_data(args.train_data_path, args.validation_data_path) + train_data, validation_data = merge_data_with_text(train_data, validation_data, args.train_text_tsv, args.validation_text_tsv) model = initialize_model(args) train_dataloader, validation_dataloader = prepare_dataloaders(train_data, validation_data, args) + best_accuracy, best_auc = train_model(train_dataloader, validation_dataloader, model, args) print(f"Best Accuracy of the Adaptor: {best_accuracy:.4f}") print(f"Best AUC of the Adaptor: {best_auc:.4f}") @@ -268,4 +311,4 @@ def main(): main() # Example command to run this script: -# python medimage_train.py --task_name "AdapterTrain" --train_data_path "/home/healthcare-ai/train_merged.pkl" --validation_data_path "/home/healthcare-ai/val_merged.pkl" --train_dataloader_batch_size 8 --validation_dataloader_batch_size 1 --train_dataloader_workers 2 --validation_dataloader_workers 2 --output_classes 5 --hidden_dimensions 512 --input_channels 1024 --learning_rate 0.0003 --max_epochs 10 --output_model_path "/home/healthcare-ai/" +# python medimage_train.py --task_name "AdapterTrain" --train_data_path "/home/healthcare-ai/train_data" --validation_data_path "/home/healthcare-ai/val_data" --train_text_tsv "/home/healthcare-ai/train_text.tsv" --validation_text_tsv "/home/healthcare-ai/val_text.tsv" --train_dataloader_batch_size 8 --validation_dataloader_batch_size 1 --train_dataloader_workers 2 --validation_dataloader_workers 2 --label_file "/home/healthcare-ai/labels.txt" --hidden_dimensions 512 --input_channels 1024 --learning_rate 0.0003 --max_epochs 10 --output_model_path "/home/healthcare-ai/" diff --git a/assets/training/finetune_acft_image/src/medimage_insight_adapter_finetune/training.py b/assets/training/finetune_acft_image/src/medimage_insight_adapter_finetune/training.py new file mode 100644 index 0000000000..38e0e2430c --- /dev/null +++ b/assets/training/finetune_acft_image/src/medimage_insight_adapter_finetune/training.py @@ -0,0 +1,314 @@ +import os +import numpy as np +from torch.utils import data +from torch import nn +import torch +import pandas as pd +from sklearn.metrics import roc_auc_score +from tqdm import tqdm +import time + + +class feature_loader(data.Dataset): + def __init__(self, data_dict, csv, mode="train"): + self.data_dict = data_dict + self.csv = csv + self.mode = mode + self.img_name = data_dict["img_name"] + self.features = data_dict["features"] + + def __getitem__(self, item): + img_name = self.img_name[item] + features = self.features[item] + features = features.astype("float32") + + row = self.csv[self.csv["Name"] == img_name] + if self.mode == "train" or self.mode == "val": + label = row["Label"].values + + label = np.array(label) + label = np.reshape(label, (1,)) + label = label.squeeze() + + return features, label, img_name + + elif self.mode == "test": + return features, img_name + + def __len__(self): + return len(self.img_name) + + +## MLP Adaptors +## Input: 1-Dimensional Embeddings +## in_channels: Number of channels for input embeddings, num_class: Number of classes, finetune_mode: image (image-only) +## Output: Class-wise Prediction +class MLP_model(nn.Module): + def __init__(self, in_channels, hidden_dim, num_class): + super().__init__() + + self.in_channels = int(in_channels) + self.hidden_dim = int(hidden_dim) + self.num_class = num_class + + ## Adaptor Module + self.vision_embd = nn.Sequential( + nn.Linear(self.in_channels, 512), + nn.GELU(), + nn.Linear(512, 512), + nn.LayerNorm(512), + ) + + self.retrieval_conv = nn.Sequential( + nn.Conv1d( + in_channels=512, + out_channels=self.hidden_dim, + kernel_size=3, + padding=1, + ), + nn.GELU(), + nn.Conv1d( + in_channels=self.hidden_dim, + out_channels=self.hidden_dim, + kernel_size=3, + padding=1, + ), + ) + + ## Prediction Head + self.prediction_head = nn.Sequential(nn.Linear(self.hidden_dim, self.num_class)) + + def forward(self, vision_feat): + feat = self.vision_embd(vision_feat.squeeze(1)) + feat = self.retrieval_conv(torch.unsqueeze(feat, 2)) + class_output = self.prediction_head(feat.squeeze(2)) + + return feat, class_output + + +def load_label_csv(label_csv_file): + """ + Loads the label CSV file. + + Args: + - label_csv_file (str): Path to the label CSV file. + + Returns: + - df_label (pandas.DataFrame): Loaded label CSV as a DataFrame. + """ + df_label = pd.read_csv(label_csv_file) + return df_label + + +def create_data_loader(samples, csv, mode, batch_size, num_workers=2, pin_memory=True): + """ + Creates a data loader for the generated embeddings. + + Args: + - samples (dict): Dictionary containing the features and image names. + - csv (pandas.DataFrame): DataFrame containing the labels. + - mode (str): Mode of the data loader (train or test). + - batch_size (int): Batch size for the data loader. + - num_workers (int): Number of workers for the data loader (default: 2). + - pin_memory (bool): Whether to pin the memory for the data loader (default: True). + + Returns: + - data_loader (torch.utils.data.DataLoader): Data loader for the generated embeddings. + """ + ds = feature_loader(samples, csv=csv, mode=mode) + data_loader = torch.utils.data.DataLoader( + ds, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=pin_memory, + ) + return data_loader + + +def create_output_directory(output_dir): + """ + Create the output directory if it does not exist. + + Args: + - output_dir (str): Path to the output directory. + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + +def create_model(in_channels, hidden_dim, num_class): + """ + Create a model for the adaptor model (Default: MLP). + + Args: + in_channels (int): Number of input channels. + hidden_dim (int): Dimension of the hidden layer. + num_class (int): Number of output classes. + + Returns: + torch.nn.Module: The created MLP model. + """ + model = MLP_model( + in_channels=in_channels, hidden_dim=hidden_dim, num_class=num_class + ) + return model + + +def trainer(train_ds, test_ds, model, loss_function_ts, optimizer, epochs, root_dir): + """ + Trains a classification model and evaluates it on a validation set. + Saves the model with the best validation ROC AUC score. + """ + + start_time = time.time() + + max_epoch = epochs + best_metric = -1 + best_acc = -1 + best_metric_epoch = -1 + epoch_loss_values = [] + metric_values = [] + + # Set device + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model = model.to(device) + + for epoch in range(max_epoch): + print("-" * 10) + print(f"Epoch {epoch + 1}/{max_epoch}") + model.train() + epoch_loss = 0 + step = 0 + + # Training loop + for batch_idx, (features, pathology_label, img_name) in tqdm( + enumerate(train_ds), + total=len(train_ds), + desc=f"Train Epoch={epoch}", + ncols=80, + leave=False, + ): + + step += 1 + features = features.to(device) + pathology_label = pathology_label.to(device) + + optimizer.zero_grad() + _, pred_pathology = model(features) + + loss = loss_function_ts(pred_pathology, pathology_label) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + + print(f"{step}/{len(train_ds)}, train_loss: {loss.item():.4f}") + + epoch_loss /= step + epoch_loss_values.append(epoch_loss) + print(f"Epoch {epoch + 1} average loss: {epoch_loss:.4f}") + + # Validation loop + model.eval() + with torch.no_grad(): + y_pred_list = [] + y_true_list = [] + + for batch_idx, (features, pathology_label, img_name) in tqdm( + enumerate(test_ds), + total=len(test_ds), + desc=f"Test Epoch={epoch}", + ncols=80, + leave=False, + ): + + features = features.to(device) + pathology_label = pathology_label.to(device) + + _, pred_pathology = model(features) + + y_pred_list.append(pred_pathology) + y_true_list.append(pathology_label) + + # Concatenate predictions and true labels + y_pred = torch.cat(y_pred_list, dim=0) + y_true = torch.cat(y_true_list, dim=0) + + # Compute probabilities for the positive class + y_scores = torch.softmax(y_pred, dim=1).cpu().numpy() + y_true_np = y_true.cpu().numpy() + + # Compute ROC AUC + if y_scores.shape[1] == 2: + # Compute ROC AUC for binary classification + # y_scores[:, 1] contains the probabilities for the positive class + auc = roc_auc_score(y_true_np, y_scores[:, 1]) + else: + auc = roc_auc_score(y_true_np, y_scores, multi_class="ovr") + + # Compute accuracy + acc_metric = (y_pred.argmax(dim=1) == y_true).sum().item() / len(y_true) + + metric_values.append(auc) + + # Save the best model + if auc > best_metric: + best_metric = auc + best_acc = acc_metric + best_metric_epoch = epoch + 1 + torch.save( + model.state_dict(), os.path.join(root_dir, "best_metric_model.pth") + ) + print("Saved new best metric model") + + print( + f"Current epoch: {epoch + 1} Current AUC: {auc:.4f}" + f" Current accuracy: {acc_metric:.4f}" + f" Best AUC: {best_metric:.4f}" + f" Best accuracy: {best_acc:.4f}" + f" at epoch: {best_metric_epoch}" + ) + + end_time = time.time() + training_time = end_time - start_time + hours, rem = divmod(training_time, 3600) + minutes, seconds = divmod(rem, 60) + print(f"Total Training Time: {int(hours):02}:{int(minutes):02}:{seconds:.2f}") + print( + f"Training completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}" + ) + return best_acc, best_metric + + +def perform_inference(model, test_loader): + predictions = [] + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model.eval() + with torch.no_grad(): + for features, img_names in tqdm(test_loader, desc="Inference", ncols=80): + features = features.to(device) + _, output = model(features) + # Apply softmax to get probabilities + probabilities = torch.softmax(output, dim=1) + predicted_classes = probabilities.argmax(dim=1).cpu().numpy() + # Collect predictions + for img_name, predicted_class, prob in zip( + img_names, predicted_classes, probabilities.cpu().numpy() + ): + predictions.append( + { + "Name": img_name, + "PredictedClass": predicted_class, + "Probability": prob[predicted_class], + } + ) + return predictions + + +def load_trained_model(model, model_path): + # Load Model State + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model.load_state_dict(torch.load(model_path, map_location=device)) + model.to(device) + + return model diff --git a/assets/training/finetune_acft_image/src/medimage_insight_adapter_preprocess/medimage_datapreprocess.py b/assets/training/finetune_acft_image/src/medimage_insight_adapter_preprocess/medimage_datapreprocess.py index df667efc37..322370db8e 100644 --- a/assets/training/finetune_acft_image/src/medimage_insight_adapter_preprocess/medimage_datapreprocess.py +++ b/assets/training/finetune_acft_image/src/medimage_insight_adapter_preprocess/medimage_datapreprocess.py @@ -23,19 +23,20 @@ import pandas as pd import numpy as np import os -import json COMPONENT_NAME = "ACFT-MedImage-Embedding-Generator" -TRAIN_EMBEDDING_FILE_NAME = "train_embeddings.pkl" -VALIDATION_EMBEDDING_FILE_NAME = "validation_embeddings.pkl" +EMBEDDING_FILE_NAME = "embeddings.pkl" logger = get_logger_app( "azureml.acft.contrib.hf.scripts.src.process_embedding.embeddings_generator" ) """ -Input Arguments: endpoint_url, endpoint_key, zeroshot_path, test_train_split_pkl_path +Input Arguments: + --image_tsv: Path to image TSV file. + --mlflow_model_path: The path to the MLflow model. + --output_pkl: Output PKL file path. """ @@ -48,139 +49,90 @@ def get_parser(): parser = argparse.ArgumentParser( description="Process medical images and get embeddigns", allow_abbrev=False ) - - parser.add_argument( - "--eval_image_tsv", type=str, help="Path to evaluation image TSV file." - ) - parser.add_argument( - "--eval_text_tsv", type=str, help="Path to evaluation text TSV file." - ) - parser.add_argument( - "--image_tsv", type=str, help="Path to training image TSV file." - ) - parser.add_argument("--text_tsv", type=str, help="Path to training text TSV file.") - parser.add_argument( - "--mlflow_model_path", - type=str, - required=True, - help="The path to the MLflow model", - ) parser.add_argument( "--task_name", type=str, required=True, help="The name of the task to be executed", + ) + parser.add_argument( + "--image_tsv", type=str, help="Path to image TSV file." ) parser.add_argument( - "--output_train_pkl", + "--mlflow_model_path", type=str, - help="Output train PKL file path", + required=True, + help="The path to the MLflow model", ) parser.add_argument( - "--output_validation_pkl", + "--output_pkl", type=str, - help="Output validation PKL file path", + help="Output PKL file path", ) return parser -def generate_embeddings(image_tsv, text_tsv, mlflow_model): +def generate_embeddings(image_tsv, mlflow_model): image_df = pd.read_csv(image_tsv, sep="\t") image_df.columns = ["Name", "image"] image_df["text"] = None image_embeddings = mlflow_model.predict(image_df) image_df["features"] = image_embeddings["image_features"].apply(lambda item: np.array(item[0])) - text_df = pd.read_csv(text_tsv, sep="\t") - text_df.columns = ["Name", "classification"] + return image_df - def extract_text_field(text): - try: - text_json = json.loads(text) - return text_json.get("class_id", -1) - except json.JSONDecodeError: - logger.error("Failed to decode JSON from text column") - return "" - text_df["Label"] = text_df["classification"].apply(extract_text_field) - return pd.merge(image_df, text_df, on="Name", how="inner") - - -def save_merged_dataframes( - train_merged: pd.DataFrame, - val_merged: pd.DataFrame, - output_train_pkl_path: str, - output_validation_pkl_path: str, - train_embedding_file_name: str, - validation_embedding_file_name: str, +def save_dataframe( + image_embeddings: pd.DataFrame, + output_pkl_path: str, ) -> None: - """Save merged DataFrames to PKL files. + """Save image embeddings DataFrame to a PKL file. - This function saves the provided training and validation merged DataFrames - to the specified PKL file paths with the given file names. It also creates - the directories if they do not exist. + This function saves the provided image embeddings DataFrame + to the specified PKL file path with the given file name. It also creates + the directory if it does not exist. Args: - train_merged (pd.DataFrame): The merged training DataFrame to be saved. - val_merged (pd.DataFrame): The merged validation DataFrame to be saved. - output_train_pkl_path (str): The directory path where the training PKL file will be saved. - output_validation_pkl_path (str): The directory path where the validation PKL file will be saved. - train_embedding_file_name (str): The file name for the training PKL file. - validation_embedding_file_name (str): The file name for the validation PKL file. + image_embeddings (pd.DataFrame): The DataFrame containing image embeddings to be saved. + output_pkl_path (str): The directory path where the PKL file will be saved. Returns: None """ - os.makedirs(output_train_pkl_path, exist_ok=True) - os.makedirs(output_validation_pkl_path, exist_ok=True) + os.makedirs(output_pkl_path, exist_ok=True) - train_merged.to_pickle( - os.path.join(output_train_pkl_path, train_embedding_file_name) - ) - val_merged.to_pickle( - os.path.join(output_validation_pkl_path, validation_embedding_file_name) + image_embeddings.to_pickle( + os.path.join(output_pkl_path, EMBEDDING_FILE_NAME) ) + logger.info("Saved merged DataFrames to PKL files") def process_embeddings(args): """ - Process medical image embeddings and save the results to PKL files. - This function initializes the medimageinsight object, generates image embeddings, - creates a features dataframe, loads train and validation PKL files, merges the dataframes, - and saves the merged dataframes to specified output PKL files. + Process medical image embeddings and save the results to a PKL file. + This function loads the MLflow model, generates image embeddings from the provided TSV file, + and saves the embeddings to the specified output PKL file. + Args: args (Namespace): A namespace object containing the following attributes: - mlflow_model_path (str): The path to the MLflow model. - - zeroshot_path (str): The path to the zeroshot data. - - output_train_pkl (str): The path to save the output training PKL file. - - output_validation_pkl (str): The path to save the output validation PKL file. - - test_train_split_csv_path (str): The path to the test/train split CSV file. + - image_tsv (str): The path to the image TSV file. + - output_pkl (str): The path to save the output PKL file. Returns: None """ model_path = args.mlflow_model_path - output_train_pkl = args.output_train_pkl - output_validation_pkl = args.output_validation_pkl + output_pkl = args.output_pkl image_tsv = args.image_tsv - text_tsv = args.text_tsv - eval_image_tsv = args.eval_image_tsv - eval_text_tsv = args.eval_text_tsv mlflow_model = mlflow.pyfunc.load_model(model_path) - image_embeddings = generate_embeddings(image_tsv, text_tsv, mlflow_model) - eval_image_embeddings = generate_embeddings( - eval_image_tsv, eval_text_tsv, mlflow_model - ) + image_embeddings = generate_embeddings(image_tsv, mlflow_model) - save_merged_dataframes( + save_dataframe( image_embeddings, - eval_image_embeddings, - output_train_pkl, - output_validation_pkl, - TRAIN_EMBEDDING_FILE_NAME, - VALIDATION_EMBEDDING_FILE_NAME, + output_pkl ) logger.info("Processing medical images and getting embeddings completed") @@ -207,8 +159,3 @@ def main(): if __name__ == "__main__": main() - -""" -python medimage_datapreprocess.py --task_name "MedEmbedding" --mlflow_model_path "/mnt/model/MedImageInsight/mlflow_model_folder" --zeroshot_path "/home/healthcare-ai/medimageinsight-zeroshot/" --test_train_split_csv_path "/home/healthcare-ai/medimageinsight/classification_demo/data_input/" --output_train_pkl "/home/healthcare-ai/" --output_validation_pkl "/home/healthcare-ai/" - -""" diff --git a/assets/training/finetune_acft_image/src/medimage_insight_embedding_finetune/medimage_embedding_finetune.py b/assets/training/finetune_acft_image/src/medimage_insight_embedding_finetune/medimage_embedding_finetune.py index 522041bec4..7472c196d9 100644 --- a/assets/training/finetune_acft_image/src/medimage_insight_embedding_finetune/medimage_embedding_finetune.py +++ b/assets/training/finetune_acft_image/src/medimage_insight_embedding_finetune/medimage_embedding_finetune.py @@ -1,4 +1,5 @@ import argparse +import uuid from azureml.acft.common_components import get_logger_app, set_logging_parameters, LoggingLiterals from azureml.acft.common_components.utils.error_handling.exceptions import ACFTValidationException from azureml.acft.common_components.utils.error_handling.error_definitions import ACFTUserError @@ -15,7 +16,7 @@ import torch import yaml from typing import Any, Dict, List, Tuple -from safetensors.torch import save_file +from safetensors.torch import save_file, load_file from MainzTrain.Trainers.MainzTrainer import MainzTrainer from MainzTrain.Utils.Timing import Timer import MainzVision as mv @@ -450,22 +451,22 @@ def get_parser() -> argparse.ArgumentParser: return parser -def copy_eval_image_tsv(eval_image_tsv: str, save_dir: str) -> str: +def copy_tsv(tsv_file: str, save_dir: str) -> str: """ - Copy the evaluation image TSV file to the evaluate directory within the save directory. + Copy the TSV file to the save directory in a unique folder Args: - eval_image_tsv (str): Path to the evaluation image TSV file. + tsv_file (str): Path to the TSV file. save_dir (str): Directory to save the copied TSV file. Returns: str: The path to the copied TSV file. """ - evaluate_dir = os.path.join(save_dir, 'evaluate') - os.makedirs(evaluate_dir, exist_ok=True) - eval_image_tsv_dest = os.path.join(evaluate_dir, os.path.basename(eval_image_tsv)) - os.system(f'cp {eval_image_tsv} {eval_image_tsv_dest}') - return eval_image_tsv_dest + unique_dir = os.path.join(save_dir, str(uuid.uuid4())) + os.makedirs(unique_dir, exist_ok=True) + tsv_dest = os.path.join(unique_dir, os.path.basename(tsv_file)) + os.system(f'cp {tsv_file} {tsv_dest}') + return tsv_dest def load_opt_command(cmdline_args: argparse.Namespace) -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -501,16 +502,21 @@ def load_opt_command(cmdline_args: argparse.Namespace) -> Tuple[Dict[str, Any], opt[key] = val # Append CHECKPOINT_PATH to mlflow_model_folder and update UNICL_MODEL's PRETRAINED key + # Convert from safetensors to torch if MLFLOW_MODEL_FOLDER in cmdline_args: mlflow_model_path = os.path.join(cmdline_args[MLFLOW_MODEL_FOLDER], CHECKPOINT_PATH) + safe_model = load_file(mlflow_model_path) + os.makedirs(SAVE_DIR, exist_ok=True) + new_path = os.path.join(SAVE_DIR, "medimageinsigt-v1.0.0-native.pt") + torch.save(safe_model, new_path) if UNICL_MODEL in opt and PRETRAINED in opt[UNICL_MODEL]: - opt[UNICL_MODEL][PRETRAINED] = mlflow_model_path + opt[UNICL_MODEL][PRETRAINED] = new_path - eval_image_tsv = copy_eval_image_tsv(cmdline_args[EVAL_IMAGE_TSV], opt[SAVE_DIR]) - eval_text_tsv = copy_eval_image_tsv(cmdline_args[EVAL_TEXT_TSV], opt[SAVE_DIR]) - image_tsv = copy_eval_image_tsv(cmdline_args[IMAGE_TSV], opt[SAVE_DIR]) - text_tsv = copy_eval_image_tsv(cmdline_args[TEXT_TSV], opt[SAVE_DIR]) - label_file = copy_eval_image_tsv(cmdline_args[LABEL_FILE], opt[SAVE_DIR]) + eval_image_tsv = copy_tsv(cmdline_args[EVAL_IMAGE_TSV], opt[SAVE_DIR]) + eval_text_tsv = copy_tsv(cmdline_args[EVAL_TEXT_TSV], opt[SAVE_DIR]) + image_tsv = copy_tsv(cmdline_args[IMAGE_TSV], opt[SAVE_DIR]) + text_tsv = copy_tsv(cmdline_args[TEXT_TSV], opt[SAVE_DIR]) + label_file = copy_tsv(cmdline_args[LABEL_FILE], opt[SAVE_DIR]) if DATASET in opt and ROOT in opt[DATASET]: opt[DATASET]["TRAIN_TSV_LIST"] = [image_tsv, text_tsv]