diff --git a/cmd/blockchaincmd/add_validator.go b/cmd/blockchaincmd/add_validator.go index 78f0bd763..e34b68f31 100644 --- a/cmd/blockchaincmd/add_validator.go +++ b/cmd/blockchaincmd/add_validator.go @@ -376,7 +376,7 @@ func CallAddValidator( } } if duration == 0 { - duration, err = PromptDuration(time.Now(), network) + duration, err = PromptDuration(time.Now(), network, true) // it's pos if err != nil { return nil } @@ -651,17 +651,19 @@ func CallAddValidatorNonSOV( return err } -func PromptDuration(start time.Time, network models.Network) (time.Duration, error) { +func PromptDuration(start time.Time, network models.Network, isPos bool) (time.Duration, error) { for { txt := "How long should this validator be validating? Enter a duration, e.g. 8760h. Valid time units are \"ns\", \"us\" (or \"µs\"), \"ms\", \"s\", \"m\", \"h\"" var d time.Duration var err error - switch network.Kind { - case models.Fuji: + switch { + case network.Kind == models.Fuji: d, err = app.Prompt.CaptureFujiDuration(txt) - case models.Mainnet: + case network.Kind == models.Mainnet && isPos: + d, err = app.Prompt.CaptureMainnetL1StakingDuration(txt) + case network.Kind == models.Mainnet && !isPos: d, err = app.Prompt.CaptureMainnetDuration(txt) - case models.EtnaDevnet: + case network.Kind == models.EtnaDevnet: d, err = app.Prompt.CaptureEtnaDuration(txt) default: d, err = app.Prompt.CaptureDuration(txt) @@ -756,7 +758,7 @@ func getTimeParameters(network models.Network, nodeID ids.NodeID, isValidator bo case defaultDurationOption: useDefaultDuration = true default: - duration, err = PromptDuration(start, network) + duration, err = PromptDuration(start, network, false) // notSoV if err != nil { return time.Time{}, 0, err } diff --git a/cmd/nodecmd/validate_primary.go b/cmd/nodecmd/validate_primary.go index ddf755492..3331d9a3b 100644 --- a/cmd/nodecmd/validate_primary.go +++ b/cmd/nodecmd/validate_primary.go @@ -206,7 +206,7 @@ func GetTimeParametersPrimaryNetwork(network models.Network, nodeIndex int, vali } default: useCustomDuration = true - duration, err = blockchaincmd.PromptDuration(start, network) + duration, err = blockchaincmd.PromptDuration(start, network, false) // not L1 if err != nil { return time.Time{}, 0, err } diff --git a/internal/mocks/prompter.go b/internal/mocks/prompter.go index 74a552840..24e44b1d6 100644 --- a/internal/mocks/prompter.go +++ b/internal/mocks/prompter.go @@ -508,6 +508,34 @@ func (_m *Prompter) CaptureMainnetDuration(promptStr string) (time.Duration, err return r0, r1 } +// CaptureMainnetL1StakingDuration provides a mock function with given fields: promptStr +func (_m *Prompter) CaptureMainnetL1StakingDuration(promptStr string) (time.Duration, error) { + ret := _m.Called(promptStr) + + if len(ret) == 0 { + panic("no return value specified for CaptureMainnetDuration") + } + + var r0 time.Duration + var r1 error + if rf, ok := ret.Get(0).(func(string) (time.Duration, error)); ok { + return rf(promptStr) + } + if rf, ok := ret.Get(0).(func(string) time.Duration); ok { + r0 = rf(promptStr) + } else { + r0 = ret.Get(0).(time.Duration) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(promptStr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // CaptureNewFilepath provides a mock function with given fields: promptStr func (_m *Prompter) CaptureNewFilepath(promptStr string) (string, error) { ret := _m.Called(promptStr) diff --git a/pkg/prompts/prompts.go b/pkg/prompts/prompts.go index a5fd7baf9..a09ce8113 100644 --- a/pkg/prompts/prompts.go +++ b/pkg/prompts/prompts.go @@ -107,6 +107,7 @@ type Prompter interface { CaptureEtnaDuration(promptStr string) (time.Duration, error) CaptureFujiDuration(promptStr string) (time.Duration, error) CaptureMainnetDuration(promptStr string) (time.Duration, error) + CaptureMainnetL1StakingDuration(promptStr string) (time.Duration, error) CaptureDate(promptStr string) (time.Time, error) CaptureNodeID(promptStr string) (ids.NodeID, error) CaptureID(promptStr string) (ids.ID, error) @@ -263,6 +264,20 @@ func (*realPrompter) CaptureMainnetDuration(promptStr string) (time.Duration, er return time.ParseDuration(durationStr) } +func (*realPrompter) CaptureMainnetL1StakingDuration(promptStr string) (time.Duration, error) { + prompt := promptui.Prompt{ + Label: promptStr, + Validate: validateMainnetL1StakingDuration, + } + + durationStr, err := prompt.Run() + if err != nil { + return 0, err + } + + return time.ParseDuration(durationStr) +} + func (*realPrompter) CaptureDate(promptStr string) (time.Time, error) { prompt := promptui.Prompt{ Label: promptStr, diff --git a/pkg/prompts/validations.go b/pkg/prompts/validations.go index fdfa11ea2..29f1eb3a6 100644 --- a/pkg/prompts/validations.go +++ b/pkg/prompts/validations.go @@ -58,6 +58,21 @@ func validateMainnetStakingDuration(input string) error { return nil } +func validateMainnetL1StakingDuration(input string) error { + const minL1StakingDuration = 24 * time.Hour + d, err := time.ParseDuration(input) + if err != nil { + return err + } + if d > genesis.MainnetParams.MaxStakeDuration { + return fmt.Errorf("exceeds maximum staking duration of %s", ux.FormatDuration(genesis.MainnetParams.MaxStakeDuration)) + } + if d < minL1StakingDuration { + return fmt.Errorf("below the minimum staking duration of %s", ux.FormatDuration(minL1StakingDuration)) + } + return nil +} + func validateFujiStakingDuration(input string) error { d, err := time.ParseDuration(input) if err != nil {