diff --git a/acceleration/TensorRT_inference_acceleration.ipynb b/acceleration/TensorRT_inference_acceleration.ipynb
index 23e72032be..09cb00afd0 100644
--- a/acceleration/TensorRT_inference_acceleration.ipynb
+++ b/acceleration/TensorRT_inference_acceleration.ipynb
@@ -275,6 +275,7 @@
     "    config_file=inference_config,\n",
     "    meta_file=meta_config,\n",
     "    logging_file=os.path.join(bundle_path, \"configs\", \"logging.conf\"),\n",
+    "    bundle_root=bundle_path,\n",
     ")\n",
     "\n",
     "workflow.initialize()\n",
@@ -622,7 +623,7 @@
    "metadata": {},
    "source": [
     "### Benchmark the TensorRT fp32 and fp16 models\n",
-    "In this part, the `trt_fp32_model` and `trt_fp16_model` are loaded to the `workflow`. The updated `workflow` runs the same iterations as before to benchmark the latency difference. Since the `trt_fp32_model` and `trt_fp16_model` cannot be loaded through the `CheckpointLoader` and don't have `amp` mode, the `CheckpointLoader` in the `handlers` of the `workflow` needs to be removed and the `amp` parameter in the `evaluator` of the `workflow` needs to be set to `False`.\n",
+    "In this part, the `trt_fp32_model` and `trt_fp16_model` are loaded into the `workflow`. The updated `workflow` runs the same iterations as before to benchmark the latency difference. Since the `trt_fp32_model` and `trt_fp16_model` cannot be loaded through the `CheckpointLoader` and do not support `amp` mode, disable checkpoint loading by setting the `load_pretrain` property of the `workflow` to `False` before re-initializing, and set the `amp` parameter in the `evaluator` of the `workflow` to `False`.\n",
     "\n",
     "The `POST_PROCESS` and `PREPARE_BATCH` stages require a considerable amount of time. Although the model forward time is much improved, there is still room for acceleration in reducing the end-to-end latency on this particular MONAI bundle."
    ]
@@ -633,20 +634,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def pop_checkpoint_loader(handlers):\n",
-    "    pop_index = -1\n",
-    "    for cnt, obj in enumerate(handlers):\n",
-    "        if isinstance(obj, monai.handlers.CheckpointLoader):\n",
-    "            pop_index = cnt\n",
-    "            break\n",
-    "    if pop_index >= 0:\n",
-    "        handlers.pop(pop_index)\n",
-    "\n",
-    "\n",
     "workflow.initialize()\n",
-    "inference_handlers = workflow.handlers\n",
-    "pop_checkpoint_loader(inference_handlers)\n",
-    "workflow.handlers = inference_handlers\n",
+    "workflow.add_property(\"load_pretrain\", True, \"load_pretrain\")\n",
+    "workflow.load_pretrain = False\n",
     "workflow.network_def = trt_fp32_model\n",
     "\n",
     "workflow.initialize()\n",
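
For readers of this patch, the sketch below shows roughly how the benchmarking cell reads once the hunks above are applied. It is illustrative only: it assumes `workflow` is the `monai.bundle.ConfigWorkflow` constructed earlier in the notebook and that `trt_fp32_model` (and likewise `trt_fp16_model`) comes from the notebook's preceding TensorRT conversion step; only the lines visible in the diff are reproduced.

```python
# Illustrative reconstruction of the patched cell (not part of the diff).
# Assumes `workflow` is the ConfigWorkflow built earlier in the notebook and
# `trt_fp32_model` is the converted TensorRT fp32 model.
workflow.initialize()

# Register a "load_pretrain" property (backed by the config id "load_pretrain")
# and disable it, so the bundle skips its CheckpointLoader: the TensorRT engine
# cannot be restored from the original PyTorch checkpoint.
workflow.add_property("load_pretrain", True, "load_pretrain")
workflow.load_pretrain = False

# Swap the bundle's network definition for the TensorRT fp32 model
# (trt_fp16_model is benchmarked the same way) and re-initialize so the
# evaluator is rebuilt around the new network.
workflow.network_def = trt_fp32_model
workflow.initialize()

# The remainder of the cell -- disabling amp on the evaluator and running the
# timed benchmark iterations -- is unchanged and not shown in these hunks.
```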