Added everything necessary for gradient clipping and accumulation
sjohn4 committed Feb 16, 2025
1 parent 1a40faf commit dca05a3
Showing 8 changed files with 118 additions and 68 deletions.
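The trainer-side loop that consumes these new settings is not rendered in this excerpt. As background, the usual PyTorch pattern for combining gradient accumulation with norm clipping looks roughly like the sketch below; this is illustrative only, not the repository's trainer code, and the parameter names (accumulation_steps, max_grad_norm) are placeholders.

import torch

def train_epoch(model, dataloader, optimizer, criterion,
                accumulation_steps: int = 1, max_grad_norm: float | None = None):
    # Illustrative accumulation + clipping loop; not Modyn's actual trainer code.
    model.train()
    optimizer.zero_grad()
    for step, (data, target) in enumerate(dataloader):
        loss = criterion(model(data), target)
        # Scale the loss so the accumulated gradient matches one large batch.
        (loss / accumulation_steps).backward()
        if (step + 1) % accumulation_steps == 0:
            if max_grad_norm is not None:
                # Clip the total gradient norm before the optimizer step.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()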
4 changes: 3 additions & 1 deletion modyn/common/grpc/grpc_helpers.py
@@ -218,7 +218,7 @@ def prepare_start_training_request(
checkpoint_info = CheckpointInfo(checkpoint_interval=0, checkpoint_path="")

grad_scaler_config = training_config.grad_scaler_config if training_config.grad_scaler_config else {}

assert training_config.gradient_accumulation_steps > 0, "Gradient accumulation steps should be greater than 0"
return StartTrainingRequest(
pipeline_id=pipeline_id,
trigger_id=trigger_id,
@@ -251,6 +251,8 @@ def prepare_start_training_request(
enable_accurate_gpu_measurements=training_config.enable_accurate_gpu_measurements,
record_loss_every=training_config.record_loss_every,
drop_last_batch=training_config.drop_last_batch,
grad_norm=training_config.grad_norm if training_config.grad_norm != 0.0 else None,
gradient_accumulation_steps=training_config.gradient_accumulation_steps,
)

def start_training(
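Note the convention introduced above: a grad_norm of 0.0 in the pipeline config means clipping is disabled, so it is sent as an absent optional field. A minimal standalone restatement of that conversion:

def grad_norm_or_none(grad_norm: float) -> float | None:
    # 0.0 means "no clipping", so the optional proto field is simply left unset.
    return grad_norm if grad_norm != 0.0 else None

assert grad_norm_or_none(0.0) is None
assert grad_norm_or_none(1.5) == 1.5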
12 changes: 10 additions & 2 deletions modyn/config/schema/pipeline/training/config.py
@@ -7,7 +7,7 @@

from modyn.config.schema.base_model import ModynBaseModel

OptimizerSource = Literal["PyTorch", "APEX"]
OptimizerSource = Literal["PyTorch", "APEX", "HuggingFace"]


class OptimizerParamGroup(ModynBaseModel):
@@ -119,6 +119,7 @@ class TrainingConfig(ModynBaseModel):
"we start with random weights. If initial_model is 'pretrained', cannot be False."
)
)

seed: int | None = Field(
None,
description=(
@@ -154,7 +155,14 @@
None,
description="Configuration for the torch.cuda.amp.GradScaler. Effective only when amp is enabled.",
)

grad_norm: float = Field(
default=0,
description="Clips the gradients normed over this value, if its 0 it will not be used.",
)
gradient_accumulation_steps: int = Field(
default=1,
description="Number of steps to accumulate gradients over.",
)
# [Additional validation]

@field_validator("gpus")
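A hypothetical fragment of a pipeline's training section using the two new fields (all other required TrainingConfig fields, such as device, batch size, and optimizer settings, are omitted here):

training_config_fragment = {
    "grad_norm": 1.0,                  # clip gradients to a total norm of 1.0; 0 disables clipping
    "gradient_accumulation_steps": 4,  # accumulate over 4 micro-batches per optimizer step
}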
2 changes: 2 additions & 0 deletions modyn/protos/trainer_server.proto
@@ -59,6 +59,8 @@ message StartTrainingRequest {
bool enable_accurate_gpu_measurements = 25;
int64 record_loss_every = 26;
bool drop_last_batch = 27;
optional float grad_norm = 29;
optional int64 gradient_accumulation_steps = 30;
}

message StartTrainingResponse {
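Because both new fields are declared optional, the generated Python message tracks their presence explicitly, so a reader on the trainer server would typically check HasField before using them. An illustrative sketch (assuming the generated trainer_server_pb2 module is importable):

from modyn.trainer_server.internal.grpc.generated import trainer_server_pb2

def read_clipping_settings(request: trainer_server_pb2.StartTrainingRequest):
    # An unset grad_norm means clipping is disabled; an unset accumulation count defaults to 1.
    grad_norm = request.grad_norm if request.HasField("grad_norm") else None
    steps = request.gradient_accumulation_steps if request.HasField("gradient_accumulation_steps") else 1
    return grad_norm, steps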
37 changes: 19 additions & 18 deletions modyn/trainer_server/internal/grpc/generated/trainer_server_pb2.py

Some generated files are not rendered by default.

@@ -141,6 +141,8 @@ class StartTrainingRequest(google.protobuf.message.Message):
ENABLE_ACCURATE_GPU_MEASUREMENTS_FIELD_NUMBER: builtins.int
RECORD_LOSS_EVERY_FIELD_NUMBER: builtins.int
DROP_LAST_BATCH_FIELD_NUMBER: builtins.int
GRAD_NORM_FIELD_NUMBER: builtins.int
GRADIENT_ACCUMULATION_STEPS_FIELD_NUMBER: builtins.int
pipeline_id: builtins.int
trigger_id: builtins.int
device: builtins.str
@@ -158,6 +160,8 @@
enable_accurate_gpu_measurements: builtins.bool
record_loss_every: builtins.int
drop_last_batch: builtins.bool
grad_norm: builtins.float
gradient_accumulation_steps: builtins.int
@property
def torch_optimizers_configuration(self) -> global___JsonString: ...
@property
@@ -208,10 +212,16 @@ class StartTrainingRequest(google.protobuf.message.Message):
enable_accurate_gpu_measurements: builtins.bool = ...,
record_loss_every: builtins.int = ...,
drop_last_batch: builtins.bool = ...,
grad_norm: builtins.float | None = ...,
gradient_accumulation_steps: builtins.int | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing.Literal[
"_grad_norm",
b"_grad_norm",
"_gradient_accumulation_steps",
b"_gradient_accumulation_steps",
"_seed",
b"_seed",
"_tokenizer",
@@ -224,8 +234,12 @@
b"criterion_parameters",
"data_info",
b"data_info",
"grad_norm",
b"grad_norm",
"grad_scaler_configuration",
b"grad_scaler_configuration",
"gradient_accumulation_steps",
b"gradient_accumulation_steps",
"label_transformer",
b"label_transformer",
"lr_scheduler",
@@ -241,6 +255,10 @@
def ClearField(
self,
field_name: typing.Literal[
"_grad_norm",
b"_grad_norm",
"_gradient_accumulation_steps",
b"_gradient_accumulation_steps",
"_seed",
b"_seed",
"_tokenizer",
@@ -263,8 +281,12 @@
b"enable_accurate_gpu_measurements",
"epochs_per_trigger",
b"epochs_per_trigger",
"grad_norm",
b"grad_norm",
"grad_scaler_configuration",
b"grad_scaler_configuration",
"gradient_accumulation_steps",
b"gradient_accumulation_steps",
"label_transformer",
b"label_transformer",
"load_optimizer_state",
@@ -302,6 +324,14 @@ class StartTrainingRequest(google.protobuf.message.Message):
],
) -> None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing.Literal["_grad_norm", b"_grad_norm"]
) -> typing.Literal["grad_norm"] | None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing.Literal["_gradient_accumulation_steps", b"_gradient_accumulation_steps"]
) -> typing.Literal["gradient_accumulation_steps"] | None: ...
@typing.overload
def WhichOneof(self, oneof_group: typing.Literal["_seed", b"_seed"]) -> typing.Literal["seed"] | None: ...
@typing.overload
def WhichOneof(
@@ -1,15 +1,14 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""

import grpc
import warnings

import grpc
import modyn.trainer_server.internal.grpc.generated.trainer_server_pb2 as trainer__server__pb2

GRPC_GENERATED_VERSION = "1.63.0"
GRPC_GENERATED_VERSION = "1.67.1"
GRPC_VERSION = grpc.__version__
EXPECTED_ERROR_RELEASE = "1.65.0"
SCHEDULED_RELEASE_DATE = "June 25, 2024"
_version_not_supported = False

try:
@@ -20,15 +19,12 @@
_version_not_supported = True

if _version_not_supported:
warnings.warn(
raise RuntimeError(
f"The grpc package installed is at version {GRPC_VERSION},"
+ f" but the generated code in trainer_server_pb2_grpc.py depends on"
+ f" grpcio>={GRPC_GENERATED_VERSION}."
+ f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}"
+ f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}."
+ f" This warning will become an error in {EXPECTED_ERROR_RELEASE},"
+ f" scheduled for release on {SCHEDULED_RELEASE_DATE}.",
RuntimeWarning,
)


@@ -137,6 +133,7 @@ def add_TrainerServerServicer_to_server(servicer, server):
}
generic_handler = grpc.method_handlers_generic_handler("trainer.TrainerServer", rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers("trainer.TrainerServer", rpc_method_handlers)


# This class is part of an EXPERIMENTAL API.