From dca05a326cfab175c489ecb2fe2f1c4d05d1bc4e Mon Sep 17 00:00:00 2001
From: sjohn4
Date: Sun, 16 Feb 2025 18:41:47 +0100
Subject: [PATCH] Added everything necessary for gradient clipping and accumulation

---
 modyn/common/grpc/grpc_helpers.py             |  4 +-
 .../config/schema/pipeline/training/config.py | 12 ++-
 modyn/protos/trainer_server.proto             |  2 +
 .../grpc/generated/trainer_server_pb2.py      | 37 ++++----
 .../grpc/generated/trainer_server_pb2.pyi     | 30 +++++++
 .../grpc/generated/trainer_server_pb2_grpc.py | 11 +--
 .../internal/trainer/pytorch_trainer.py       | 85 ++++++++++---------
 .../internal/utils/training_info.py           |  5 +-
 8 files changed, 118 insertions(+), 68 deletions(-)

diff --git a/modyn/common/grpc/grpc_helpers.py b/modyn/common/grpc/grpc_helpers.py
index f115d0c22..632c997e6 100644
--- a/modyn/common/grpc/grpc_helpers.py
+++ b/modyn/common/grpc/grpc_helpers.py
@@ -218,7 +218,7 @@ def prepare_start_training_request(
             checkpoint_info = CheckpointInfo(checkpoint_interval=0, checkpoint_path="")
 
         grad_scaler_config = training_config.grad_scaler_config if training_config.grad_scaler_config else {}
-
+        assert training_config.gradient_accumulation_steps > 0, "Gradient accumulation steps should be greater than 0"
         return StartTrainingRequest(
             pipeline_id=pipeline_id,
             trigger_id=trigger_id,
@@ -251,6 +251,8 @@
             enable_accurate_gpu_measurements=training_config.enable_accurate_gpu_measurements,
             record_loss_every=training_config.record_loss_every,
             drop_last_batch=training_config.drop_last_batch,
+            grad_norm=training_config.grad_norm if training_config.grad_norm != 0.0 else None,
+            gradient_accumulation_steps=training_config.gradient_accumulation_steps,
         )
 
     def start_training(
diff --git a/modyn/config/schema/pipeline/training/config.py b/modyn/config/schema/pipeline/training/config.py
index b3f673553..10b1df229 100644
--- a/modyn/config/schema/pipeline/training/config.py
+++ b/modyn/config/schema/pipeline/training/config.py
@@ -7,7 +7,7 @@
 
 from modyn.config.schema.base_model import ModynBaseModel
 
-OptimizerSource = Literal["PyTorch", "APEX"]
+OptimizerSource = Literal["PyTorch", "APEX", "HuggingFace"]
 
 
 class OptimizerParamGroup(ModynBaseModel):
@@ -119,6 +119,7 @@ class TrainingConfig(ModynBaseModel):
             "we start with random weights. If initial_model is 'pretrained', cannot be False."
         )
     )
+
     seed: int | None = Field(
         None,
         description=(
@@ -154,7 +155,14 @@
         None,
         description="Configuration for the torch.cuda.amp.GradScaler.
Effective only when amp is enabled.", ) - + grad_norm: float = Field( + default=0, + description="Clips the gradients normed over this value, if its 0 it will not be used.", + ) + gradient_accumulation_steps: int = Field( + default=1, + description="Number of steps to accumulate gradients over.", + ) # [Additional validation] @field_validator("gpus") diff --git a/modyn/protos/trainer_server.proto b/modyn/protos/trainer_server.proto index 0c7343c70..4de2b3cbc 100644 --- a/modyn/protos/trainer_server.proto +++ b/modyn/protos/trainer_server.proto @@ -59,6 +59,8 @@ message StartTrainingRequest { bool enable_accurate_gpu_measurements = 25; int64 record_loss_every = 26; bool drop_last_batch = 27; + optional float grad_norm = 29; + optional int64 gradient_accumulation_steps = 30; } message StartTrainingResponse { diff --git a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py index c3a21053b..8fe0d4c01 100644 --- a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py +++ b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE # source: trainer_server.proto # Protobuf Python Version: 5.26.1 """Generated protocol buffer code.""" @@ -15,7 +16,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x14trainer_server.proto\x12\x07trainer"\x1b\n\nJsonString\x12\r\n\x05value\x18\x01 \x01(\t"\x1d\n\x0cPythonString\x12\r\n\x05value\x18\x01 \x01(\t"3\n\x04\x44\x61ta\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fnum_dataloaders\x18\x02 \x01(\x05"\x19\n\x17TrainerAvailableRequest"-\n\x18TrainerAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08"F\n\x0e\x43heckpointInfo\x12\x1b\n\x13\x63heckpoint_interval\x18\x01 \x01(\x05\x12\x17\n\x0f\x63heckpoint_path\x18\x02 \x01(\t"\xbb\x07\n\x14StartTrainingRequest\x12\x13\n\x0bpipeline_id\x18\x01 \x01(\x05\x12\x12\n\ntrigger_id\x18\x02 \x01(\x05\x12\x0e\n\x06\x64\x65vice\x18\x03 \x01(\t\x12\x1c\n\x14use_pretrained_model\x18\x04 \x01(\x08\x12\x1c\n\x14load_optimizer_state\x18\x05 \x01(\x08\x12\x1b\n\x13pretrained_model_id\x18\x06 \x01(\x05\x12\x12\n\nbatch_size\x18\x07 \x01(\x05\x12;\n\x1etorch_optimizers_configuration\x18\x08 \x01(\x0b\x32\x13.trainer.JsonString\x12\x17\n\x0ftorch_criterion\x18\t \x01(\t\x12\x31\n\x14\x63riterion_parameters\x18\n \x01(\x0b\x32\x13.trainer.JsonString\x12 \n\tdata_info\x18\x0b \x01(\x0b\x32\r.trainer.Data\x12\x30\n\x0f\x63heckpoint_info\x18\x0c \x01(\x0b\x32\x17.trainer.CheckpointInfo\x12+\n\x0c\x62ytes_parser\x18\r \x01(\x0b\x32\x15.trainer.PythonString\x12\x16\n\x0etransform_list\x18\x0e \x03(\t\x12)\n\x0clr_scheduler\x18\x0f \x01(\x0b\x32\x13.trainer.JsonString\x12\x30\n\x11label_transformer\x18\x10 \x01(\x0b\x32\x15.trainer.PythonString\x12\x36\n\x19grad_scaler_configuration\x18\x11 \x01(\x0b\x32\x13.trainer.JsonString\x12\x1a\n\x12\x65pochs_per_trigger\x18\x12 \x01(\x05\x12!\n\x19num_prefetched_partitions\x18\x13 \x01(\x05\x12"\n\x1aparallel_prefetch_requests\x18\x14 \x01(\x05\x12\x11\n\x04seed\x18\x15 \x01(\x05H\x00\x88\x01\x01\x12-\n\ttokenizer\x18\x16 \x01(\x0b\x32\x15.trainer.PythonStringH\x01\x88\x01\x01\x12\x1b\n\x13num_samples_to_pass\x18\x17 \x01(\x03\x12\x0f\n\x07shuffle\x18\x18 \x01(\x08\x12(\n enable_accurate_gpu_measurements\x18\x19 \x01(\x08\x12\x19\n\x11record_loss_every\x18\x1a \x01(\x03\x12\x17\n\x0f\x64rop_last_batch\x18\x1b 
\x01(\x08\x42\x07\n\x05_seedB\x0c\n\n_tokenizer"F\n\x15StartTrainingResponse\x12\x18\n\x10training_started\x18\x01 \x01(\x08\x12\x13\n\x0btraining_id\x18\x02 \x01(\x05",\n\x15TrainingStatusRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"\xa6\x03\n\x16TrainingStatusResponse\x12\r\n\x05valid\x18\x01 \x01(\x08\x12\x12\n\nis_running\x18\x02 \x01(\x08\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x17\n\x0fstate_available\x18\x04 \x01(\x08\x12\x0f\n\x07\x62locked\x18\x05 \x01(\x08\x12 \n\x03log\x18\x06 \x01(\x0b\x32\x13.trainer.JsonString\x12\x16\n\texception\x18\x07 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0c\x62\x61tches_seen\x18\x08 \x01(\x03H\x01\x88\x01\x01\x12\x19\n\x0csamples_seen\x18\t \x01(\x03H\x02\x88\x01\x01\x12&\n\x19\x64ownsampling_batches_seen\x18\n \x01(\x03H\x03\x88\x01\x01\x12&\n\x19\x64ownsampling_samples_seen\x18\x0b \x01(\x03H\x04\x88\x01\x01\x42\x0c\n\n_exceptionB\x0f\n\r_batches_seenB\x0f\n\r_samples_seenB\x1c\n\x1a_downsampling_batches_seenB\x1c\n\x1a_downsampling_samples_seen"-\n\x16StoreFinalModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"@\n\x17StoreFinalModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x10\n\x08model_id\x18\x02 \x01(\x05",\n\x15GetLatestModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"A\n\x16GetLatestModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x12\n\nmodel_path\x18\x02 \x01(\t2\xc9\x03\n\rTrainerServer\x12Z\n\x11trainer_available\x12 .trainer.TrainerAvailableRequest\x1a!.trainer.TrainerAvailableResponse"\x00\x12Q\n\x0estart_training\x12\x1d.trainer.StartTrainingRequest\x1a\x1e.trainer.StartTrainingResponse"\x00\x12X\n\x13get_training_status\x12\x1e.trainer.TrainingStatusRequest\x1a\x1f.trainer.TrainingStatusResponse"\x00\x12X\n\x11store_final_model\x12\x1f.trainer.StoreFinalModelRequest\x1a .trainer.StoreFinalModelResponse"\x00\x12U\n\x10get_latest_model\x12\x1e.trainer.GetLatestModelRequest\x1a\x1f.trainer.GetLatestModelResponse"\x00\x62\x06proto3' + b'\n\x14trainer_server.proto\x12\x07trainer"\x1b\n\nJsonString\x12\r\n\x05value\x18\x01 \x01(\t"\x1d\n\x0cPythonString\x12\r\n\x05value\x18\x01 \x01(\t"3\n\x04\x44\x61ta\x12\x12\n\ndataset_id\x18\x01 \x01(\t\x12\x17\n\x0fnum_dataloaders\x18\x02 \x01(\x05"\x19\n\x17TrainerAvailableRequest"-\n\x18TrainerAvailableResponse\x12\x11\n\tavailable\x18\x01 \x01(\x08"F\n\x0e\x43heckpointInfo\x12\x1b\n\x13\x63heckpoint_interval\x18\x01 \x01(\x05\x12\x17\n\x0f\x63heckpoint_path\x18\x02 \x01(\t"\xab\x08\n\x14StartTrainingRequest\x12\x13\n\x0bpipeline_id\x18\x01 \x01(\x05\x12\x12\n\ntrigger_id\x18\x02 \x01(\x05\x12\x0e\n\x06\x64\x65vice\x18\x03 \x01(\t\x12\x1c\n\x14use_pretrained_model\x18\x04 \x01(\x08\x12\x1c\n\x14load_optimizer_state\x18\x05 \x01(\x08\x12\x1b\n\x13pretrained_model_id\x18\x06 \x01(\x05\x12\x12\n\nbatch_size\x18\x07 \x01(\x05\x12;\n\x1etorch_optimizers_configuration\x18\x08 \x01(\x0b\x32\x13.trainer.JsonString\x12\x17\n\x0ftorch_criterion\x18\t \x01(\t\x12\x31\n\x14\x63riterion_parameters\x18\n \x01(\x0b\x32\x13.trainer.JsonString\x12 \n\tdata_info\x18\x0b \x01(\x0b\x32\r.trainer.Data\x12\x30\n\x0f\x63heckpoint_info\x18\x0c \x01(\x0b\x32\x17.trainer.CheckpointInfo\x12+\n\x0c\x62ytes_parser\x18\r \x01(\x0b\x32\x15.trainer.PythonString\x12\x16\n\x0etransform_list\x18\x0e \x03(\t\x12)\n\x0clr_scheduler\x18\x0f \x01(\x0b\x32\x13.trainer.JsonString\x12\x30\n\x11label_transformer\x18\x10 \x01(\x0b\x32\x15.trainer.PythonString\x12\x36\n\x19grad_scaler_configuration\x18\x11 \x01(\x0b\x32\x13.trainer.JsonString\x12\x1a\n\x12\x65pochs_per_trigger\x18\x12 
\x01(\x05\x12!\n\x19num_prefetched_partitions\x18\x13 \x01(\x05\x12"\n\x1aparallel_prefetch_requests\x18\x14 \x01(\x05\x12\x11\n\x04seed\x18\x15 \x01(\x05H\x00\x88\x01\x01\x12-\n\ttokenizer\x18\x16 \x01(\x0b\x32\x15.trainer.PythonStringH\x01\x88\x01\x01\x12\x1b\n\x13num_samples_to_pass\x18\x17 \x01(\x03\x12\x0f\n\x07shuffle\x18\x18 \x01(\x08\x12(\n enable_accurate_gpu_measurements\x18\x19 \x01(\x08\x12\x19\n\x11record_loss_every\x18\x1a \x01(\x03\x12\x17\n\x0f\x64rop_last_batch\x18\x1b \x01(\x08\x12\x16\n\tgrad_norm\x18\x1d \x01(\x02H\x02\x88\x01\x01\x12(\n\x1bgradient_accumulation_steps\x18\x1e \x01(\x03H\x03\x88\x01\x01\x42\x07\n\x05_seedB\x0c\n\n_tokenizerB\x0c\n\n_grad_normB\x1e\n\x1c_gradient_accumulation_steps"F\n\x15StartTrainingResponse\x12\x18\n\x10training_started\x18\x01 \x01(\x08\x12\x13\n\x0btraining_id\x18\x02 \x01(\x05",\n\x15TrainingStatusRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"\xa6\x03\n\x16TrainingStatusResponse\x12\r\n\x05valid\x18\x01 \x01(\x08\x12\x12\n\nis_running\x18\x02 \x01(\x08\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x17\n\x0fstate_available\x18\x04 \x01(\x08\x12\x0f\n\x07\x62locked\x18\x05 \x01(\x08\x12 \n\x03log\x18\x06 \x01(\x0b\x32\x13.trainer.JsonString\x12\x16\n\texception\x18\x07 \x01(\tH\x00\x88\x01\x01\x12\x19\n\x0c\x62\x61tches_seen\x18\x08 \x01(\x03H\x01\x88\x01\x01\x12\x19\n\x0csamples_seen\x18\t \x01(\x03H\x02\x88\x01\x01\x12&\n\x19\x64ownsampling_batches_seen\x18\n \x01(\x03H\x03\x88\x01\x01\x12&\n\x19\x64ownsampling_samples_seen\x18\x0b \x01(\x03H\x04\x88\x01\x01\x42\x0c\n\n_exceptionB\x0f\n\r_batches_seenB\x0f\n\r_samples_seenB\x1c\n\x1a_downsampling_batches_seenB\x1c\n\x1a_downsampling_samples_seen"-\n\x16StoreFinalModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"@\n\x17StoreFinalModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x10\n\x08model_id\x18\x02 \x01(\x05",\n\x15GetLatestModelRequest\x12\x13\n\x0btraining_id\x18\x01 \x01(\x05"A\n\x16GetLatestModelResponse\x12\x13\n\x0bvalid_state\x18\x01 \x01(\x08\x12\x12\n\nmodel_path\x18\x02 \x01(\t2\xc9\x03\n\rTrainerServer\x12Z\n\x11trainer_available\x12 .trainer.TrainerAvailableRequest\x1a!.trainer.TrainerAvailableResponse"\x00\x12Q\n\x0estart_training\x12\x1d.trainer.StartTrainingRequest\x1a\x1e.trainer.StartTrainingResponse"\x00\x12X\n\x13get_training_status\x12\x1e.trainer.TrainingStatusRequest\x1a\x1f.trainer.TrainingStatusResponse"\x00\x12X\n\x11store_final_model\x12\x1f.trainer.StoreFinalModelRequest\x1a .trainer.StoreFinalModelResponse"\x00\x12U\n\x10get_latest_model\x12\x1e.trainer.GetLatestModelRequest\x1a\x1f.trainer.GetLatestModelResponse"\x00\x62\x06proto3' ) _globals = globals() @@ -36,21 +37,21 @@ _globals["_CHECKPOINTINFO"]._serialized_start = 220 _globals["_CHECKPOINTINFO"]._serialized_end = 290 _globals["_STARTTRAININGREQUEST"]._serialized_start = 293 - _globals["_STARTTRAININGREQUEST"]._serialized_end = 1248 - _globals["_STARTTRAININGRESPONSE"]._serialized_start = 1250 - _globals["_STARTTRAININGRESPONSE"]._serialized_end = 1320 - _globals["_TRAININGSTATUSREQUEST"]._serialized_start = 1322 - _globals["_TRAININGSTATUSREQUEST"]._serialized_end = 1366 - _globals["_TRAININGSTATUSRESPONSE"]._serialized_start = 1369 - _globals["_TRAININGSTATUSRESPONSE"]._serialized_end = 1791 - _globals["_STOREFINALMODELREQUEST"]._serialized_start = 1793 - _globals["_STOREFINALMODELREQUEST"]._serialized_end = 1838 - _globals["_STOREFINALMODELRESPONSE"]._serialized_start = 1840 - _globals["_STOREFINALMODELRESPONSE"]._serialized_end = 1904 - 
_globals["_GETLATESTMODELREQUEST"]._serialized_start = 1906 - _globals["_GETLATESTMODELREQUEST"]._serialized_end = 1950 - _globals["_GETLATESTMODELRESPONSE"]._serialized_start = 1952 - _globals["_GETLATESTMODELRESPONSE"]._serialized_end = 2017 - _globals["_TRAINERSERVER"]._serialized_start = 2020 - _globals["_TRAINERSERVER"]._serialized_end = 2477 + _globals["_STARTTRAININGREQUEST"]._serialized_end = 1360 + _globals["_STARTTRAININGRESPONSE"]._serialized_start = 1362 + _globals["_STARTTRAININGRESPONSE"]._serialized_end = 1432 + _globals["_TRAININGSTATUSREQUEST"]._serialized_start = 1434 + _globals["_TRAININGSTATUSREQUEST"]._serialized_end = 1478 + _globals["_TRAININGSTATUSRESPONSE"]._serialized_start = 1481 + _globals["_TRAININGSTATUSRESPONSE"]._serialized_end = 1903 + _globals["_STOREFINALMODELREQUEST"]._serialized_start = 1905 + _globals["_STOREFINALMODELREQUEST"]._serialized_end = 1950 + _globals["_STOREFINALMODELRESPONSE"]._serialized_start = 1952 + _globals["_STOREFINALMODELRESPONSE"]._serialized_end = 2016 + _globals["_GETLATESTMODELREQUEST"]._serialized_start = 2018 + _globals["_GETLATESTMODELREQUEST"]._serialized_end = 2062 + _globals["_GETLATESTMODELRESPONSE"]._serialized_start = 2064 + _globals["_GETLATESTMODELRESPONSE"]._serialized_end = 2129 + _globals["_TRAINERSERVER"]._serialized_start = 2132 + _globals["_TRAINERSERVER"]._serialized_end = 2589 # @@protoc_insertion_point(module_scope) diff --git a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi index 96b20dde5..4ed7a2f53 100644 --- a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi +++ b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.pyi @@ -141,6 +141,8 @@ class StartTrainingRequest(google.protobuf.message.Message): ENABLE_ACCURATE_GPU_MEASUREMENTS_FIELD_NUMBER: builtins.int RECORD_LOSS_EVERY_FIELD_NUMBER: builtins.int DROP_LAST_BATCH_FIELD_NUMBER: builtins.int + GRAD_NORM_FIELD_NUMBER: builtins.int + GRADIENT_ACCUMULATION_STEPS_FIELD_NUMBER: builtins.int pipeline_id: builtins.int trigger_id: builtins.int device: builtins.str @@ -158,6 +160,8 @@ class StartTrainingRequest(google.protobuf.message.Message): enable_accurate_gpu_measurements: builtins.bool record_loss_every: builtins.int drop_last_batch: builtins.bool + grad_norm: builtins.float + gradient_accumulation_steps: builtins.int @property def torch_optimizers_configuration(self) -> global___JsonString: ... @property @@ -208,10 +212,16 @@ class StartTrainingRequest(google.protobuf.message.Message): enable_accurate_gpu_measurements: builtins.bool = ..., record_loss_every: builtins.int = ..., drop_last_batch: builtins.bool = ..., + grad_norm: builtins.float | None = ..., + gradient_accumulation_steps: builtins.int | None = ..., ) -> None: ... 
def HasField( self, field_name: typing.Literal[ + "_grad_norm", + b"_grad_norm", + "_gradient_accumulation_steps", + b"_gradient_accumulation_steps", "_seed", b"_seed", "_tokenizer", @@ -224,8 +234,12 @@ class StartTrainingRequest(google.protobuf.message.Message): b"criterion_parameters", "data_info", b"data_info", + "grad_norm", + b"grad_norm", "grad_scaler_configuration", b"grad_scaler_configuration", + "gradient_accumulation_steps", + b"gradient_accumulation_steps", "label_transformer", b"label_transformer", "lr_scheduler", @@ -241,6 +255,10 @@ class StartTrainingRequest(google.protobuf.message.Message): def ClearField( self, field_name: typing.Literal[ + "_grad_norm", + b"_grad_norm", + "_gradient_accumulation_steps", + b"_gradient_accumulation_steps", "_seed", b"_seed", "_tokenizer", @@ -263,8 +281,12 @@ class StartTrainingRequest(google.protobuf.message.Message): b"enable_accurate_gpu_measurements", "epochs_per_trigger", b"epochs_per_trigger", + "grad_norm", + b"grad_norm", "grad_scaler_configuration", b"grad_scaler_configuration", + "gradient_accumulation_steps", + b"gradient_accumulation_steps", "label_transformer", b"label_transformer", "load_optimizer_state", @@ -302,6 +324,14 @@ class StartTrainingRequest(google.protobuf.message.Message): ], ) -> None: ... @typing.overload + def WhichOneof( + self, oneof_group: typing.Literal["_grad_norm", b"_grad_norm"] + ) -> typing.Literal["grad_norm"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing.Literal["_gradient_accumulation_steps", b"_gradient_accumulation_steps"] + ) -> typing.Literal["gradient_accumulation_steps"] | None: ... + @typing.overload def WhichOneof(self, oneof_group: typing.Literal["_seed", b"_seed"]) -> typing.Literal["seed"] | None: ... @typing.overload def WhichOneof( diff --git a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2_grpc.py b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2_grpc.py index a1978b37a..61f4e8a31 100644 --- a/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2_grpc.py +++ b/modyn/trainer_server/internal/grpc/generated/trainer_server_pb2_grpc.py @@ -1,15 +1,14 @@ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" +import grpc import warnings import grpc import modyn.trainer_server.internal.grpc.generated.trainer_server_pb2 as trainer__server__pb2 -GRPC_GENERATED_VERSION = "1.63.0" +GRPC_GENERATED_VERSION = "1.67.1" GRPC_VERSION = grpc.__version__ -EXPECTED_ERROR_RELEASE = "1.65.0" -SCHEDULED_RELEASE_DATE = "June 25, 2024" _version_not_supported = False try: @@ -20,15 +19,12 @@ _version_not_supported = True if _version_not_supported: - warnings.warn( + raise RuntimeError( f"The grpc package installed is at version {GRPC_VERSION}," + f" but the generated code in trainer_server_pb2_grpc.py depends on" + f" grpcio>={GRPC_GENERATED_VERSION}." + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}" + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}." 
- + f" This warning will become an error in {EXPECTED_ERROR_RELEASE}," - + f" scheduled for release on {SCHEDULED_RELEASE_DATE}.", - RuntimeWarning, ) @@ -137,6 +133,7 @@ def add_TrainerServerServicer_to_server(servicer, server): } generic_handler = grpc.method_handlers_generic_handler("trainer.TrainerServer", rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers("trainer.TrainerServer", rpc_method_handlers) # This class is part of an EXPERIMENTAL API. diff --git a/modyn/trainer_server/internal/trainer/pytorch_trainer.py b/modyn/trainer_server/internal/trainer/pytorch_trainer.py index 47acce439..652f746b3 100644 --- a/modyn/trainer_server/internal/trainer/pytorch_trainer.py +++ b/modyn/trainer_server/internal/trainer/pytorch_trainer.py @@ -85,9 +85,9 @@ def __init__( self.pipeline_id = training_info.pipeline_id self.training_id = training_info.training_id self.trigger_id = training_info.trigger_id - + self._grad_norm = training_info.grad_norm self.selector_stub = self.connect_to_selector(training_info.selector_address) - + self.gradient_accumulation_steps = training_info.gradient_accumulation_steps if training_info.seed is not None: self._seed_trainer_server(training_info.seed) self._info("Everything seeded") @@ -258,6 +258,8 @@ def train(self) -> None: # pylint: disable=too-many-locals, too-many-branches stopw.start("IndivFetchBatch", overwrite=True) stopw.start("FetchBatch", resume=True) + + accumulation_counter = 0 # Initialize accumulation counter. for batch in self._train_dataloader: stopw.stop("FetchBatch") batch_timings.append(stopw.stop("IndivFetchBatch")) @@ -277,9 +279,9 @@ def train(self) -> None: # pylint: disable=too-many-locals, too-many-branches # model output is a torch.FloatTensor but weights is a torch.DoubleTensor. 
# We need to cast to do the dot product weights = batch[3].float().to(self._device) - - for _, optimizer in self._optimizers.items(): - optimizer.zero_grad() + if accumulation_counter == 0: # zero grad is moved here + for _, optimizer in self._optimizers.items(): + optimizer.zero_grad() with torch.autocast(self._device_type, enabled=self._amp): if self._downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE: @@ -316,39 +318,46 @@ def train(self) -> None: # pylint: disable=too-many-locals, too-many-branches with GPUMeasurement(self._measure_gpu_ops, "Backward", self._device, stopw, resume=True): self._scaler.scale(loss).backward() - - with GPUMeasurement(self._measure_gpu_ops, "OptimizerStep", self._device, stopw, resume=True): - for _, optimizer in self._optimizers.items(): - self._scaler.step(optimizer) - - self._scaler.update() - trained_batches += 1 - - self._step_lr_if_necessary(True) - - if self._checkpoint_interval > 0 and trained_batches % self._checkpoint_interval == 0: - stopw.start("Checkpoint", resume=True) - checkpoint_file_name = self._checkpoint_path / f"model_{trained_batches}.modyn" - self.save_state(checkpoint_file_name, trained_batches) - stopw.stop("Checkpoint") - - if self._record_loss_every > 0 and trained_batches % self._record_loss_every == 0: - training_loss.append(loss.item()) - - self._num_samples += len(sample_ids) - - stopw.start("OnBatchEnd", resume=True) - for _, callback in self._callbacks.items(): - callback.on_batch_end( - self._model.model, self._optimizers, trained_batches, sample_ids, data, target, output, loss - ) - stopw.stop() - if 0 < self.num_samples_to_pass <= self._num_samples: - self._info("Stopping training as we have reached the sample threshold.") - break - stopw.start("FetchBatch", resume=True) - stopw.start("IndivFetchBatch", overwrite=True) - + if self._grad_norm is not None: + torch.nn.utils.clip_grad_norm_(self._model.model.parameters(), max_norm=self._grad_norm) + accumulation_counter += 1 + if accumulation_counter == self.gradient_accumulation_steps: + with GPUMeasurement(self._measure_gpu_ops, "OptimizerStep", self._device, stopw, resume=True): + for _, optimizer in self._optimizers.items(): + self._scaler.step(optimizer) + self._scaler.update() + trained_batches += 1 # Increment trained_batches only on optimizer update. + accumulation_counter = 0 # Reset accumulation counter. + self._step_lr_if_necessary(True) # Step LR scheduler after optimizer update. 
+ if self._checkpoint_interval > 0 and trained_batches % self._checkpoint_interval == 0: + stopw.start("Checkpoint", resume=True) + checkpoint_file_name = self._checkpoint_path / f"model_{trained_batches}.modyn" + self.save_state(checkpoint_file_name, trained_batches) + stopw.stop("Checkpoint") + if self._record_loss_every > 0 and trained_batches % self._record_loss_every == 0: + training_loss.append(loss.item()) + print(loss.item()) + # Log loss and batch number + log_file = self._checkpoint_path / "training_log.txt" + with ( + open(log_file, "a") as f # pylint: disable=unspecified-encoding + ): # 'a' mode appends if the file exists, else creates it + f.write(f"{trained_batches},{loss.item()}\n") + # Example: Logging training losses in a loop + + self._num_samples += len(sample_ids) + + stopw.start("OnBatchEnd", resume=True) + for _, callback in self._callbacks.items(): + callback.on_batch_end( + self._model.model, self._optimizers, trained_batches, sample_ids, data, target, output, loss + ) + stopw.stop() + if 0 < self.num_samples_to_pass <= self._num_samples: + self._info("Stopping training as we have reached the sample threshold.") + break + stopw.start("FetchBatch", resume=True) + stopw.start("IndivFetchBatch", overwrite=True) self._step_lr_if_necessary(False) if len(batch_timings) <= 100000: diff --git a/modyn/trainer_server/internal/utils/training_info.py b/modyn/trainer_server/internal/utils/training_info.py index 07a246b35..73adc3c8f 100644 --- a/modyn/trainer_server/internal/utils/training_info.py +++ b/modyn/trainer_server/internal/utils/training_info.py @@ -57,7 +57,7 @@ def __init__( self.shuffle = request.shuffle self.enable_accurate_gpu_measurements = request.enable_accurate_gpu_measurements - + self.generative = request.generative assert ( self.pretrained_model_path or not self.use_pretrained_model ), "Inconsistent pretrained model configuration" @@ -80,5 +80,6 @@ def __init__( self.seed: int | None = request.seed if request.HasField("seed") else None self.tokenizer: str | None = request.tokenizer.value if request.HasField("tokenizer") else None - + self.grad_norm: float | None = request.grad_norm if request.HasField("grad_norm") else None + self.gradient_accumulation_steps: int = request.gradient_accumulation_steps self.offline_dataset_path = offline_dataset_path
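The two new request fields are declared optional in trainer_server.proto, so field presence is tracked explicitly and the trainer can tell whether a clipping threshold was provided at all. The following is a small sketch of how that presence check behaves with the regenerated stubs; it is not part of the patch, and the concrete field values are illustrative only.

from modyn.trainer_server.internal.grpc.generated.trainer_server_pb2 import StartTrainingRequest

# Illustrative values; a real StartTrainingRequest carries many more fields.
request = StartTrainingRequest(pipeline_id=1, trigger_id=0, device="cuda:0", gradient_accumulation_steps=2)

assert not request.HasField("grad_norm")  # unset -> TrainingInfo.grad_norm becomes None, clipping stays off
assert request.HasField("gradient_accumulation_steps") and request.gradient_accumulation_steps == 2

request.grad_norm = 1.0  # set -> TrainingInfo.grad_norm == 1.0 and clip_grad_norm_ is applied
assert request.HasField("grad_norm")

On the client side, prepare_start_training_request only sets grad_norm when the configured value is non-zero, so the default grad_norm of 0 in TrainingConfig leaves the field unset and clipping disabled.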
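For readers who want the accumulate-clip-step pattern in isolation, here is a minimal, self-contained sketch. It is not Modyn code: the model, data, and hyperparameters are placeholders, and dividing the loss by the accumulation window plus calling scaler.unscale_() before clipping follow the common torch.cuda.amp recipe rather than mirroring the patch line by line.

# Standalone sketch (not Modyn code): gradient accumulation with norm clipping under AMP.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

loader = DataLoader(TensorDataset(torch.randn(256, 16), torch.randint(0, 2, (256,))), batch_size=8)

grad_norm = 1.0                  # analogous to TrainingConfig.grad_norm (0 would disable clipping)
gradient_accumulation_steps = 4  # analogous to TrainingConfig.gradient_accumulation_steps
accumulation_counter = 0

optimizer.zero_grad()
for data, target in loader:
    data, target = data.to(device), target.to(device)
    with torch.autocast(device, enabled=(device == "cuda")):
        # Average the loss over the accumulation window so the update matches one large batch.
        loss = criterion(model(data), target) / gradient_accumulation_steps
    scaler.scale(loss).backward()
    accumulation_counter += 1

    if accumulation_counter == gradient_accumulation_steps:
        if grad_norm > 0:
            # Unscale first so the clip threshold applies to the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_norm)
        scaler.step(optimizer)  # steps the optimizer, skipping the step if gradients overflowed
        scaler.update()
        optimizer.zero_grad()
        accumulation_counter = 0

In the patch itself, clipping is applied whenever TrainingConfig.grad_norm is non-zero, while the optimizer step, GradScaler.update(), LR-scheduler step, and the reset of accumulation_counter only run once the counter reaches gradient_accumulation_steps; trained_batches therefore counts optimizer updates rather than raw dataloader batches.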