From 76be5b074eb5c0789605821f12dfb97192ce5e18 Mon Sep 17 00:00:00 2001 From: Tom Fleet Date: Sun, 26 Apr 2026 14:37:44 +0100 Subject: [PATCH] More performance wins --- command.go | 76 +++--- internal/arg/arg.go | 363 ++++++++++++------------ internal/flag/flag.go | 546 +++++++++++++++++++------------------ internal/flag/flag_test.go | 7 + internal/flag/set.go | 7 - internal/flag/set_test.go | 89 ------ internal/kind/kind.go | 48 ++++ 7 files changed, 542 insertions(+), 594 deletions(-) create mode 100644 internal/kind/kind.go diff --git a/command.go b/command.go index b921744..7262e73 100644 --- a/command.go +++ b/command.go @@ -379,47 +379,37 @@ func (cmd *Command) hasShortFlag(name string) bool { // findRequestedCommand uses the raw arguments and the command tree to determine what // (if any) subcommand is being requested and return that command along with the arguments // that were meant for it. +// +// On the first descent into a subcommand it snapshots args into a working +// slice that we own; subsequent levels then mutate that slice in place via +// [slices.Delete]. The original cmd.rawArgs is never touched, so re-Execute +// on the same Command still sees pristine input. func findRequestedCommand(cmd *Command, args []string) (*Command, []string) { - // The next non-flag argument (if any) is the first immediate subcommand - // e.g. in 'go mod tidy' we're looking for 'mod'. - nextSubCommand, ok := firstNonFlagArg(cmd, args) - if !ok { - // No non-flag arguments, so we must already be either at the root command - // or the correct subcommand - return cmd, args - } - - // Lookup this immediate subcommand by name and if we find it, recursively call - // this function so we eventually end up at the end of the command tree with - // the right arguments - next := findSubCommand(cmd, nextSubCommand) - if next != nil { - return findRequestedCommand(next, argsMinusFirstX(args, nextSubCommand)) - } + owned := false + + for { + // The next non-flag argument (if any) is the immediate subcommand + // e.g. in 'go mod tidy' we're looking for 'mod'. + idx, ok := firstNonFlagArg(cmd, args) + if !ok { + return cmd, args + } - // Found it - return cmd, args -} + next := findSubCommand(cmd, args[idx]) + if next == nil { + return cmd, args + } -// argsMinusFirstX removes only the first x from args. Otherwise, commands that look like -// openshift admin policy add-role-to-user admin my-user, lose the admin argument (arg[4]). -// -// The input slice is not mutated so that repeated Execute calls on the same -// Command see the original rawArgs. -func argsMinusFirstX(args []string, x string) []string { - // Note: this is borrowed from Cobra but ours is a lot simpler because we don't support - // persistent flags - for i, arg := range args { - if arg == x { - result := make([]string, 0, len(args)-1) - result = append(result, args[:i]...) - result = append(result, args[i+1:]...) - - return result + if !owned { + working := make([]string, len(args)) + copy(working, args) + args = working + owned = true } - } - return args + args = slices.Delete(args, idx, idx+1) + cmd = next + } } // findSubCommand searches the immediate subcommands of cmd by name, looking for next. @@ -435,25 +425,25 @@ func findSubCommand(cmd *Command, next string) *Command { return nil } -// firstNonFlagArg walks args and returns the first positional (non-flag) -// argument along with a boolean indicating whether one was found. +// firstNonFlagArg walks args and returns the index of the first positional +// (non-flag) argument along with a boolean indicating whether one was found. // // It consumes flag-value pairs (e.g. '--flag value' or '-f value') so they // aren't mistaken for positional arguments, and stops at '--'. -func firstNonFlagArg(cmd *Command, args []string) (arg string, ok bool) { +func firstNonFlagArg(cmd *Command, args []string) (idx int, ok bool) { for i := 0; i < len(args); i++ { a := args[i] switch { case a == "--": // "--" terminates the flags - return "", false + return -1, false case strings.HasPrefix(a, "--") && !strings.Contains(a, "=") && !cmd.hasFlag(a[2:]): // If '--flag value' then skip value fallthrough case strings.HasPrefix(a, "-") && !strings.Contains(a, "=") && len(a) == 2 && !cmd.hasShortFlag(a[1:]): // '-f value' skip the value too. If there isn't one, we're done. if i+1 >= len(args) { - return "", false + return -1, false } i++ @@ -461,11 +451,11 @@ func firstNonFlagArg(cmd *Command, args []string) (arg string, ok bool) { continue case a != "" && !strings.HasPrefix(a, "-"): // First valid positional arg - return a, true + return i, true } } - return "", false + return -1, false } // showHelp is the default for a command's helpFunc. diff --git a/internal/arg/arg.go b/internal/arg/arg.go index 32985f8..043883b 100644 --- a/internal/arg/arg.go +++ b/internal/arg/arg.go @@ -16,6 +16,7 @@ import ( "go.followtheprocess.codes/cli/arg" "go.followtheprocess.codes/cli/internal/format" + "go.followtheprocess.codes/cli/internal/kind" "go.followtheprocess.codes/cli/internal/parse" ) @@ -23,10 +24,12 @@ var _ Value = Arg[string]{} // This will fail if we violate our Value interface // Arg represents a single command line argument. type Arg[T arg.Argable] struct { - value *T // The actual stored value - config Config[T] // Additional configuration - name string // Name of the argument as it appears on the command line - usage string // One line description of the argument. + value *T // The actual stored value + config Config[T] // Additional configuration + name string // Name of the argument as it appears on the command line + usage string // One line description of the argument. + typeStr string // Cached result of Type() + kind kind.Kind // Cached concrete kind of T, set in New so hot paths skip any() boxing } // New constructs and returns a new [Arg]. @@ -39,11 +42,15 @@ func New[T arg.Argable](p *T, name, usage string, config Config[T]) (Arg[T], err p = new(T) } + k, typeStr := typeInfo[T]() + argument := Arg[T]{ - value: p, - name: name, - usage: usage, - config: config, + value: p, + name: name, + usage: usage, + config: config, + typeStr: typeStr, + kind: k, } return argument, nil @@ -61,355 +68,321 @@ func (a Arg[T]) Usage() string { // Default returns the default value of the argument as a string // or "" if the argument is required. -// -//nolint:cyclop // No other way of doing this func (a Arg[T]) Default() string { if a.config.DefaultValue == nil { // DefaultValue is nil, therefore this is a required arg return "" } - switch typ := any(*a.config.DefaultValue).(type) { - case int: - return format.Int(typ) - case int8: - return format.Int(typ) - case int16: - return format.Int(typ) - case int32: - return format.Int(typ) - case int64: - return format.Int(typ) - case uint: - return format.Uint(typ) - case uint8: - return format.Uint(typ) - case uint16: - return format.Uint(typ) - case uint32: - return format.Uint(typ) - case uint64: - return format.Uint(typ) - case uintptr: - return format.Uint(typ) - case float32: - return format.Float32(typ) - case float64: - return format.Float64(typ) - case string: - return typ - case *url.URL: - return typ.String() - case bool: - return strconv.FormatBool(typ) - case []byte: - return hex.EncodeToString(typ) - case time.Time: - return typ.Format(time.RFC3339) - case time.Duration: - return typ.String() - case net.IP: - return typ.String() - default: - return fmt.Sprintf("Arg.String: unsupported arg type: %T", typ) - } + return formatValue(a.kind, a.config.DefaultValue) } // String returns the string representation of the current value of the arg. -// -//nolint:cyclop // No other way of doing this realistically func (a Arg[T]) String() string { if a.value == nil { return format.Nil } - switch typ := any(*a.value).(type) { - case int: - return format.Int(typ) - case int8: - return format.Int(typ) - case int16: - return format.Int(typ) - case int32: - return format.Int(typ) - case int64: - return format.Int(typ) - case uint: - return format.Uint(typ) - case uint8: - return format.Uint(typ) - case uint16: - return format.Uint(typ) - case uint32: - return format.Uint(typ) - case uint64: - return format.Uint(typ) - case uintptr: - return format.Uint(typ) - case float32: - return format.Float32(typ) - case float64: - return format.Float64(typ) - case string: - return typ - case *url.URL: - return typ.String() - case bool: - return strconv.FormatBool(typ) - case []byte: - return hex.EncodeToString(typ) - case time.Time: - return typ.Format(time.RFC3339) - case time.Duration: - return typ.String() - case net.IP: - return typ.String() - default: - return fmt.Sprintf("Arg.String: unsupported arg type: %T", typ) - } + return formatValue(a.kind, a.value) } // Type returns a string representation of the type of the Arg. -// -//nolint:cyclop // No other way of doing this realistically func (a Arg[T]) Type() string { if a.value == nil { return format.Nil } - switch typ := any(*a.value).(type) { - case int: - return format.TypeInt - case int8: - return format.TypeInt8 - case int16: - return format.TypeInt16 - case int32: - return format.TypeInt32 - case int64: - return format.TypeInt64 - case uint: - return format.TypeUint - case uint8: - return format.TypeUint8 - case uint16: - return format.TypeUint16 - case uint32: - return format.TypeUint32 - case uint64: - return format.TypeUint64 - case uintptr: - return format.TypeUintptr - case float32: - return format.TypeFloat32 - case float64: - return format.TypeFloat64 - case string: - return format.TypeString - case *url.URL: - return format.TypeURL - case bool: - return format.TypeBool - case []byte: - return format.TypeBytesHex - case time.Time: - return format.TypeTime - case time.Duration: - return format.TypeDuration - case net.IP: - return format.TypeIP - default: - return fmt.Sprintf("%T", typ) - } + return a.typeStr } // Set sets an [Arg] value by parsing it's string value. // -//nolint:gocognit,maintidx // No other way of doing this realistically +//nolint:gocognit,maintidx,cyclop // No other way of doing this realistically func (a Arg[T]) Set(str string) error { if a.value == nil { return fmt.Errorf("cannot set value %s, arg.value was nil", str) } - switch typ := any(*a.value).(type) { - case int: + switch a.kind { + case kind.Int: val, err := parse.Int(str) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case int8: + case kind.Int8: val, err := parse.Int8(str) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case int16: + case kind.Int16: val, err := parse.Int16(str) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case int32: + case kind.Int32: val, err := parse.Int32(str) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case int64: + case kind.Int64: val, err := parse.Int64(str) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case uint: + case kind.Uint: val, err := parse.Uint(str) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case uint8: + case kind.Uint8: val, err := parse.Uint8(str) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case uint16: + case kind.Uint16: val, err := parse.Uint16(str) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case uint32: + case kind.Uint32: val, err := parse.Uint32(str) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) - } - - *a.value = *parse.Cast[T](&val) - - return nil - case uint64: - val, err := parse.Uint64(str) - if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case uintptr: + case kind.Uint64, kind.Uintptr: val, err := parse.Uint64(str) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case float32: + case kind.Float32: val, err := parse.Float32(str) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case float64: + case kind.Float64: val, err := parse.Float64(str) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case string: + case kind.String: val := str *a.value = *parse.Cast[T](&val) return nil - case *url.URL: + case kind.URL: val, err := url.ParseRequestURI(str) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case bool: + case kind.Bool: val, err := strconv.ParseBool(str) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case []byte: + case kind.BytesHex: val, err := hex.DecodeString(strings.TrimSpace(str)) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case time.Time: + case kind.Time: val, err := time.Parse(time.RFC3339, str) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case time.Duration: + case kind.Duration: val, err := time.ParseDuration(str) if err != nil { - return parse.Error(parse.KindArgument, a.name, str, typ, err) + return parse.Error(parse.KindArgument, a.name, str, *a.value, err) } *a.value = *parse.Cast[T](&val) return nil - case net.IP: + case kind.IP: val := net.ParseIP(str) if val == nil { - return parse.Error(parse.KindArgument, a.name, str, typ, errors.New("invalid IP address")) + return parse.Error(parse.KindArgument, a.name, str, *a.value, errors.New("invalid IP address")) } *a.value = *parse.Cast[T](&val) return nil default: - return fmt.Errorf("Arg.Set: unsupported arg type: %T", typ) + return fmt.Errorf("Arg.Set: unsupported arg type: %T", *a.value) + } +} + +// typeInfo computes the type-dependent metadata (kind, type string) for an +// arg of type T. It is called once per arg at construction so that hot paths +// (Set, String, Type) never have to type-switch on any(*a.value), which would +// box the value on every call. +// +//nolint:cyclop // No other way of doing this realistically +func typeInfo[T arg.Argable]() (kind.Kind, string) { + var zero T + + switch typ := any(zero).(type) { + case int: + return kind.Int, format.TypeInt + case int8: + return kind.Int8, format.TypeInt8 + case int16: + return kind.Int16, format.TypeInt16 + case int32: + return kind.Int32, format.TypeInt32 + case int64: + return kind.Int64, format.TypeInt64 + case uint: + return kind.Uint, format.TypeUint + case uint8: + return kind.Uint8, format.TypeUint8 + case uint16: + return kind.Uint16, format.TypeUint16 + case uint32: + return kind.Uint32, format.TypeUint32 + case uint64: + return kind.Uint64, format.TypeUint64 + case uintptr: + return kind.Uintptr, format.TypeUintptr + case float32: + return kind.Float32, format.TypeFloat32 + case float64: + return kind.Float64, format.TypeFloat64 + case string: + return kind.String, format.TypeString + case *url.URL: + return kind.URL, format.TypeURL + case bool: + return kind.Bool, format.TypeBool + case []byte: + return kind.BytesHex, format.TypeBytesHex + case time.Time: + return kind.Time, format.TypeTime + case time.Duration: + return kind.Duration, format.TypeDuration + case net.IP: + return kind.IP, format.TypeIP + default: + return kind.Invalid, fmt.Sprintf("%T", typ) + } +} + +// formatValue renders the value pointed to by p as a string using the kind dispatch. +// +//nolint:cyclop // No other way of doing this realistically +func formatValue[T arg.Argable](k kind.Kind, p *T) string { + switch k { + case kind.Int: + return format.Int(*parse.Cast[int, T](p)) + case kind.Int8: + return format.Int(*parse.Cast[int8, T](p)) + case kind.Int16: + return format.Int(*parse.Cast[int16, T](p)) + case kind.Int32: + return format.Int(*parse.Cast[int32, T](p)) + case kind.Int64: + return format.Int(*parse.Cast[int64, T](p)) + case kind.Uint: + return format.Uint(*parse.Cast[uint, T](p)) + case kind.Uint8: + return format.Uint(*parse.Cast[uint8, T](p)) + case kind.Uint16: + return format.Uint(*parse.Cast[uint16, T](p)) + case kind.Uint32: + return format.Uint(*parse.Cast[uint32, T](p)) + case kind.Uint64: + return format.Uint(*parse.Cast[uint64, T](p)) + case kind.Uintptr: + return format.Uint(*parse.Cast[uintptr, T](p)) + case kind.Float32: + return format.Float32(*parse.Cast[float32, T](p)) + case kind.Float64: + return format.Float64(*parse.Cast[float64, T](p)) + case kind.String: + return *parse.Cast[string, T](p) + case kind.URL: + u := *parse.Cast[*url.URL, T](p) + if u == nil { + return format.Nil + } + + return u.String() + case kind.Bool: + return strconv.FormatBool(*parse.Cast[bool, T](p)) + case kind.BytesHex: + return hex.EncodeToString(*parse.Cast[[]byte, T](p)) + case kind.Time: + return parse.Cast[time.Time, T](p).Format(time.RFC3339) + case kind.Duration: + return parse.Cast[time.Duration, T](p).String() + case kind.IP: + return parse.Cast[net.IP, T](p).String() + default: + return fmt.Sprintf("Arg.String: unsupported arg type: %T", *p) } } diff --git a/internal/flag/flag.go b/internal/flag/flag.go index 8924f10..6ec18c7 100644 --- a/internal/flag/flag.go +++ b/internal/flag/flag.go @@ -17,6 +17,7 @@ import ( "go.followtheprocess.codes/cli/flag" "go.followtheprocess.codes/cli/internal/format" + "go.followtheprocess.codes/cli/internal/kind" "go.followtheprocess.codes/cli/internal/parse" ) @@ -24,14 +25,15 @@ var _ Value = &Flag[string]{} // This will fail if we violate our Value interfac // Flag represents a single command line flag. type Flag[T flag.Flaggable] struct { - value *T // The actual stored value - name string // The name of the flag as appears on the command line, e.g. "force" for a --force flag - usage string // one line description of the flag, e.g. "Force deletion without confirmation" - envVar string // Name of an environment variable that may set this flag's value if the flag is not explicitly provided on the command line - typeStr string // Cached result of Type() - noArgValue string // Cached result of NoArgValue() - short rune // Optional shorthand version of the flag, e.g. "f" for a -f flag - isSlice bool // Cached result of IsSlice() + value *T // The actual stored value + name string // The name of the flag as appears on the command line, e.g. "force" for a --force flag + usage string // one line description of the flag, e.g. "Force deletion without confirmation" + envVar string // Name of an environment variable that may set this flag's value if the flag is not explicitly provided on the command line + typeStr string // Cached result of Type() + noArgValue string // Cached result of NoArgValue() + short rune // Optional shorthand version of the flag, e.g. "f" for a -f flag + kind kind.Kind // Cached concrete kind of T + isSlice bool // Cached result of IsSlice() } // New constructs and returns a new [Flag]. @@ -53,7 +55,7 @@ func New[T flag.Flaggable](p *T, name string, short rune, usage string, config C *p = config.DefaultValue - typeStr, noArgValue, isSlice := typeInfo[T]() + info := typeInfo[T]() return &Flag[T]{ value: p, @@ -61,9 +63,10 @@ func New[T flag.Flaggable](p *T, name string, short rune, usage string, config C usage: usage, short: short, envVar: config.EnvVar, - typeStr: typeStr, - noArgValue: noArgValue, - isSlice: isSlice, + typeStr: info.typeStr, + noArgValue: info.noArgValue, + kind: info.kind, + isSlice: info.isSlice, }, nil } @@ -91,7 +94,7 @@ func (f *Flag[T]) Default() string { // Special case a --help flag, because if we didn't, when you call --help // it would show up with a default of true because you've passed it // so it's value is true here - if isZeroIsh(*f.value) || f.name == "help" { + if f.isZeroIsh() || f.name == "help" { return "" } @@ -136,477 +139,478 @@ func (f *Flag[T]) String() string { return format.Nil } - switch typ := any(*f.value).(type) { - case int: - return format.Int(typ) - case int8: - return format.Int(typ) - case int16: - return format.Int(typ) - case int32: - return format.Int(typ) - case int64: - return format.Int(typ) - case flag.Count: - return format.Uint(typ) - case uint: - return format.Uint(typ) - case uint8: - return format.Uint(typ) - case uint16: - return format.Uint(typ) - case uint32: - return format.Uint(typ) - case uint64: - return format.Uint(typ) - case uintptr: - return format.Uint(typ) - case float32: - return format.Float32(typ) - case float64: - return format.Float64(typ) - case string: - return typ - case bool: - return strconv.FormatBool(typ) - case []byte: - return hex.EncodeToString(typ) - case time.Time: - return typ.Format(time.RFC3339) - case time.Duration: - return typ.String() - case net.IP: - return typ.String() - case *url.URL: - if typ == nil { + switch f.kind { + case kind.Int: + return format.Int(*parse.Cast[int, T](f.value)) + case kind.Int8: + return format.Int(*parse.Cast[int8, T](f.value)) + case kind.Int16: + return format.Int(*parse.Cast[int16, T](f.value)) + case kind.Int32: + return format.Int(*parse.Cast[int32, T](f.value)) + case kind.Int64: + return format.Int(*parse.Cast[int64, T](f.value)) + case kind.Count: + return format.Uint(*parse.Cast[flag.Count, T](f.value)) + case kind.Uint: + return format.Uint(*parse.Cast[uint, T](f.value)) + case kind.Uint8: + return format.Uint(*parse.Cast[uint8, T](f.value)) + case kind.Uint16: + return format.Uint(*parse.Cast[uint16, T](f.value)) + case kind.Uint32: + return format.Uint(*parse.Cast[uint32, T](f.value)) + case kind.Uint64: + return format.Uint(*parse.Cast[uint64, T](f.value)) + case kind.Uintptr: + return format.Uint(*parse.Cast[uintptr, T](f.value)) + case kind.Float32: + return format.Float32(*parse.Cast[float32, T](f.value)) + case kind.Float64: + return format.Float64(*parse.Cast[float64, T](f.value)) + case kind.String: + return *parse.Cast[string, T](f.value) + case kind.Bool: + return strconv.FormatBool(*parse.Cast[bool, T](f.value)) + case kind.BytesHex: + return hex.EncodeToString(*parse.Cast[[]byte, T](f.value)) + case kind.Time: + return parse.Cast[time.Time, T](f.value).Format(time.RFC3339) + case kind.Duration: + return parse.Cast[time.Duration, T](f.value).String() + case kind.IP: + return parse.Cast[net.IP, T](f.value).String() + case kind.URL: + u := *parse.Cast[*url.URL, T](f.value) + if u == nil { return format.Nil } - return typ.String() - case []int: - return format.Slice(typ) - case []int8: - return format.Slice(typ) - case []int16: - return format.Slice(typ) - case []int32: - return format.Slice(typ) - case []int64: - return format.Slice(typ) - case []uint: - return format.Slice(typ) - case []uint16: - return format.Slice(typ) - case []uint32: - return format.Slice(typ) - case []uint64: - return format.Slice(typ) - case []float32: - return format.Slice(typ) - case []float64: - return format.Slice(typ) - case []string: - return format.Slice(typ) + return u.String() + case kind.IntSlice: + return format.Slice(*parse.Cast[[]int, T](f.value)) + case kind.Int8Slice: + return format.Slice(*parse.Cast[[]int8, T](f.value)) + case kind.Int16Slice: + return format.Slice(*parse.Cast[[]int16, T](f.value)) + case kind.Int32Slice: + return format.Slice(*parse.Cast[[]int32, T](f.value)) + case kind.Int64Slice: + return format.Slice(*parse.Cast[[]int64, T](f.value)) + case kind.UintSlice: + return format.Slice(*parse.Cast[[]uint, T](f.value)) + case kind.Uint16Slice: + return format.Slice(*parse.Cast[[]uint16, T](f.value)) + case kind.Uint32Slice: + return format.Slice(*parse.Cast[[]uint32, T](f.value)) + case kind.Uint64Slice: + return format.Slice(*parse.Cast[[]uint64, T](f.value)) + case kind.Float32Slice: + return format.Slice(*parse.Cast[[]float32, T](f.value)) + case kind.Float64Slice: + return format.Slice(*parse.Cast[[]float64, T](f.value)) + case kind.StringSlice: + return format.Slice(*parse.Cast[[]string, T](f.value)) default: - return fmt.Sprintf("Flag.String: unsupported flag type: %T", typ) + return fmt.Sprintf("Flag.String: unsupported flag type: %T", *f.value) } } -// typeInfo computes the type-dependent metadata (Type string, NoArgValue, -// IsSlice) for a flag of type T. It is called once per flag at construction -// so that the hot path of Parse never has to type-switch on any(*f.value), -// which would otherwise box the value on every call. -func typeInfo[T flag.Flaggable]() (typeStr, noArgValue string, isSlice bool) { //nolint:cyclop // No other way of doing this realistically +// info bundles the cacheable, type-dependent metadata for a Flag of a given T. +type info struct { + typeStr string + noArgValue string + kind kind.Kind + isSlice bool +} + +// typeInfo computes the type-dependent metadata (kind, type string, no-arg +// value, isSlice) for a flag of type T. It is called once per flag at +// construction so that the hot path of Parse never has to type-switch on +// any(*f.value), which would otherwise box the value on every call. +func typeInfo[T flag.Flaggable]() info { //nolint:cyclop // No other way of doing this realistically var zero T switch typ := any(zero).(type) { case int: - return format.TypeInt, "", false + return info{kind: kind.Int, typeStr: format.TypeInt} case int8: - return format.TypeInt8, "", false + return info{kind: kind.Int8, typeStr: format.TypeInt8} case int16: - return format.TypeInt16, "", false + return info{kind: kind.Int16, typeStr: format.TypeInt16} case int32: - return format.TypeInt32, "", false + return info{kind: kind.Int32, typeStr: format.TypeInt32} case int64: - return format.TypeInt64, "", false + return info{kind: kind.Int64, typeStr: format.TypeInt64} case flag.Count: - return format.TypeCount, "1", false + return info{kind: kind.Count, typeStr: format.TypeCount, noArgValue: "1"} case uint: - return format.TypeUint, "", false + return info{kind: kind.Uint, typeStr: format.TypeUint} case uint8: - return format.TypeUint8, "", false + return info{kind: kind.Uint8, typeStr: format.TypeUint8} case uint16: - return format.TypeUint16, "", false + return info{kind: kind.Uint16, typeStr: format.TypeUint16} case uint32: - return format.TypeUint32, "", false + return info{kind: kind.Uint32, typeStr: format.TypeUint32} case uint64: - return format.TypeUint64, "", false + return info{kind: kind.Uint64, typeStr: format.TypeUint64} case uintptr: - return format.TypeUintptr, "", false + return info{kind: kind.Uintptr, typeStr: format.TypeUintptr} case float32: - return format.TypeFloat32, "", false + return info{kind: kind.Float32, typeStr: format.TypeFloat32} case float64: - return format.TypeFloat64, "", false + return info{kind: kind.Float64, typeStr: format.TypeFloat64} case string: - return format.TypeString, "", false + return info{kind: kind.String, typeStr: format.TypeString} case bool: - return format.TypeBool, format.True, false + return info{kind: kind.Bool, typeStr: format.TypeBool, noArgValue: format.True} case []byte: - return format.TypeBytesHex, "", false + return info{kind: kind.BytesHex, typeStr: format.TypeBytesHex} case time.Time: - return format.TypeTime, "", false + return info{kind: kind.Time, typeStr: format.TypeTime} case time.Duration: - return format.TypeDuration, "", false + return info{kind: kind.Duration, typeStr: format.TypeDuration} case net.IP: - return format.TypeIP, "", false + return info{kind: kind.IP, typeStr: format.TypeIP} case *url.URL: - return format.TypeURL, "", false + return info{kind: kind.URL, typeStr: format.TypeURL} case []int: - return format.TypeIntSlice, "", true + return info{kind: kind.IntSlice, typeStr: format.TypeIntSlice, isSlice: true} case []int8: - return format.TypeInt8Slice, "", true + return info{kind: kind.Int8Slice, typeStr: format.TypeInt8Slice, isSlice: true} case []int16: - return format.TypeInt16Slice, "", true + return info{kind: kind.Int16Slice, typeStr: format.TypeInt16Slice, isSlice: true} case []int32: - return format.TypeInt32Slice, "", true + return info{kind: kind.Int32Slice, typeStr: format.TypeInt32Slice, isSlice: true} case []int64: - return format.TypeInt64Slice, "", true + return info{kind: kind.Int64Slice, typeStr: format.TypeInt64Slice, isSlice: true} case []uint: - return format.TypeUintSlice, "", true + return info{kind: kind.UintSlice, typeStr: format.TypeUintSlice, isSlice: true} case []uint16: - return format.TypeUint16Slice, "", true + return info{kind: kind.Uint16Slice, typeStr: format.TypeUint16Slice, isSlice: true} case []uint32: - return format.TypeUint32Slice, "", true + return info{kind: kind.Uint32Slice, typeStr: format.TypeUint32Slice, isSlice: true} case []uint64: - return format.TypeUint64Slice, "", true + return info{kind: kind.Uint64Slice, typeStr: format.TypeUint64Slice, isSlice: true} case []float32: - return format.TypeFloat32Slice, "", true + return info{kind: kind.Float32Slice, typeStr: format.TypeFloat32Slice, isSlice: true} case []float64: - return format.TypeFloat64Slice, "", true + return info{kind: kind.Float64Slice, typeStr: format.TypeFloat64Slice, isSlice: true} case []string: - return format.TypeStringSlice, "", true + return info{kind: kind.StringSlice, typeStr: format.TypeStringSlice, isSlice: true} default: - return fmt.Sprintf("%T", typ), "", false + return info{kind: kind.Invalid, typeStr: fmt.Sprintf("%T", typ)} } } // Set sets a [Flag] value based on string input, i.e. parsing from the command line. // -//nolint:gocognit,maintidx // No other way of doing this realistically +//nolint:gocognit,maintidx,cyclop // No other way of doing this realistically func (f *Flag[T]) Set(str string) error { if f.value == nil { return fmt.Errorf("cannot set value %s, flag.value was nil", str) } - switch typ := any(*f.value).(type) { - case int: + switch f.kind { + case kind.Int: val, err := parse.Int(str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case int8: + case kind.Int8: val, err := parse.Int8(str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case int16: + case kind.Int16: val, err := parse.Int16(str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case int32: + case kind.Int32: val, err := parse.Int32(str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case int64: + case kind.Int64: val, err := parse.Int64(str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case flag.Count: + case kind.Count: // Add the count and store it back, we still parse the given str rather // than just +1 every time as this allows people to do e.g. --verbosity=3 // as well as -vvv val, err := parse.Uint(str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } - newValue := typ + flag.Count(val) + newValue := *parse.Cast[flag.Count, T](f.value) + flag.Count(val) *f.value = *parse.Cast[T](&newValue) return nil - case uint: + case kind.Uint: val, err := parse.Uint(str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case uint8: + case kind.Uint8: val, err := parse.Uint8(str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case uint16: + case kind.Uint16: val, err := parse.Uint16(str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case uint32: + case kind.Uint32: val, err := parse.Uint32(str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) - } - - *f.value = *parse.Cast[T](&val) - - return nil - case uint64: - val, err := parse.Uint64(str) - if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case uintptr: + case kind.Uint64, kind.Uintptr: val, err := parse.Uint64(str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case float32: + case kind.Float32: val, err := parse.Float32(str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case float64: + case kind.Float64: val, err := parse.Float64(str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case string: + case kind.String: val := str *f.value = *parse.Cast[T](&val) return nil - case bool: + case kind.Bool: val, err := strconv.ParseBool(str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case []byte: + case kind.BytesHex: val, err := hex.DecodeString(strings.TrimSpace(str)) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case time.Time: + case kind.Time: val, err := time.Parse(time.RFC3339, str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case time.Duration: + case kind.Duration: val, err := time.ParseDuration(str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case net.IP: + case kind.IP: val := net.ParseIP(str) if val == nil { - return parse.Error(parse.KindFlag, f.name, str, typ, errors.New("invalid IP address")) + return parse.Error(parse.KindFlag, f.name, str, *f.value, errors.New("invalid IP address")) } *f.value = *parse.Cast[T](&val) return nil - case *url.URL: + case kind.URL: val, err := url.ParseRequestURI(str) if err != nil { - return parse.Error(parse.KindFlag, f.name, str, typ, err) + return parse.Error(parse.KindFlag, f.name, str, *f.value, err) } *f.value = *parse.Cast[T](&val) return nil - case []int: + case kind.IntSlice: // Like Count, a slice flag is a read/write op newValue, err := parse.Int(str) if err != nil { - return parse.ErrorSlice(parse.KindFlag, f.name, str, typ, err) + return parse.ErrorSlice(parse.KindFlag, f.name, str, *f.value, err) } - typ = append(typ, newValue) + typ := append(*parse.Cast[[]int, T](f.value), newValue) *f.value = *parse.Cast[T](&typ) return nil - case []int8: + case kind.Int8Slice: newValue, err := parse.Int8(str) if err != nil { - return parse.ErrorSlice(parse.KindFlag, f.name, str, typ, err) + return parse.ErrorSlice(parse.KindFlag, f.name, str, *f.value, err) } - typ = append(typ, newValue) + typ := append(*parse.Cast[[]int8, T](f.value), newValue) *f.value = *parse.Cast[T](&typ) return nil - case []int16: + case kind.Int16Slice: newValue, err := parse.Int16(str) if err != nil { - return parse.ErrorSlice(parse.KindFlag, f.name, str, typ, err) + return parse.ErrorSlice(parse.KindFlag, f.name, str, *f.value, err) } - typ = append(typ, newValue) + typ := append(*parse.Cast[[]int16, T](f.value), newValue) *f.value = *parse.Cast[T](&typ) return nil - case []int32: + case kind.Int32Slice: newValue, err := parse.Int32(str) if err != nil { - return parse.ErrorSlice(parse.KindFlag, f.name, str, typ, err) + return parse.ErrorSlice(parse.KindFlag, f.name, str, *f.value, err) } - typ = append(typ, newValue) + typ := append(*parse.Cast[[]int32, T](f.value), newValue) *f.value = *parse.Cast[T](&typ) return nil - case []int64: + case kind.Int64Slice: newValue, err := parse.Int64(str) if err != nil { - return parse.ErrorSlice(parse.KindFlag, f.name, str, typ, err) + return parse.ErrorSlice(parse.KindFlag, f.name, str, *f.value, err) } - typ = append(typ, newValue) + typ := append(*parse.Cast[[]int64, T](f.value), newValue) *f.value = *parse.Cast[T](&typ) return nil - case []uint: + case kind.UintSlice: newValue, err := parse.Uint(str) if err != nil { - return parse.ErrorSlice(parse.KindFlag, f.name, str, typ, err) + return parse.ErrorSlice(parse.KindFlag, f.name, str, *f.value, err) } - typ = append(typ, newValue) + typ := append(*parse.Cast[[]uint, T](f.value), newValue) *f.value = *parse.Cast[T](&typ) return nil - case []uint16: + case kind.Uint16Slice: newValue, err := parse.Uint16(str) if err != nil { - return parse.ErrorSlice(parse.KindFlag, f.name, str, typ, err) + return parse.ErrorSlice(parse.KindFlag, f.name, str, *f.value, err) } - typ = append(typ, newValue) + typ := append(*parse.Cast[[]uint16, T](f.value), newValue) *f.value = *parse.Cast[T](&typ) return nil - case []uint32: + case kind.Uint32Slice: newValue, err := parse.Uint32(str) if err != nil { - return parse.ErrorSlice(parse.KindFlag, f.name, str, typ, err) + return parse.ErrorSlice(parse.KindFlag, f.name, str, *f.value, err) } - typ = append(typ, newValue) + typ := append(*parse.Cast[[]uint32, T](f.value), newValue) *f.value = *parse.Cast[T](&typ) return nil - case []uint64: + case kind.Uint64Slice: newValue, err := parse.Uint64(str) if err != nil { - return parse.ErrorSlice(parse.KindFlag, f.name, str, typ, err) + return parse.ErrorSlice(parse.KindFlag, f.name, str, *f.value, err) } - typ = append(typ, newValue) + typ := append(*parse.Cast[[]uint64, T](f.value), newValue) *f.value = *parse.Cast[T](&typ) return nil - case []float32: + case kind.Float32Slice: newValue, err := parse.Float32(str) if err != nil { - return parse.ErrorSlice(parse.KindFlag, f.name, str, typ, err) + return parse.ErrorSlice(parse.KindFlag, f.name, str, *f.value, err) } - typ = append(typ, newValue) + typ := append(*parse.Cast[[]float32, T](f.value), newValue) *f.value = *parse.Cast[T](&typ) return nil - case []float64: + case kind.Float64Slice: newValue, err := parse.Float64(str) if err != nil { - return parse.ErrorSlice(parse.KindFlag, f.name, str, typ, err) + return parse.ErrorSlice(parse.KindFlag, f.name, str, *f.value, err) } - typ = append(typ, newValue) + typ := append(*parse.Cast[[]float64, T](f.value), newValue) *f.value = *parse.Cast[T](&typ) return nil - case []string: + case kind.StringSlice: + typ := *parse.Cast[[]string, T](f.value) typ = append(typ, str) *f.value = *parse.Cast[T](&typ) return nil default: - return fmt.Errorf("Flag.Set: unsupported flag type: %T", typ) + return fmt.Errorf("Flag.Set: unsupported flag type: %T", *f.value) } } @@ -674,61 +678,83 @@ func validateFlagShort(short rune) error { return nil } -// isZeroIsh reports whether value is the zero value (ish) for it's type. +// isZeroIsh reports whether the flag's value is the zero value (ish) for it's type. // -// "ish" means that empty slices will return true from isZeroIsh despite their official -// zero value being nil. The primary use of isZeroIsh is to determine whether or not -// a default value is worth displaying to the user in the help text, and an empty slice -// is probably not. -func isZeroIsh[T flag.Flaggable](value T) bool { //nolint:cyclop // Not much else we can do here - // Note: all the slice values ([]T) are in their own separate branches because if you - // combine them, the resulting value in the body of the case block is 'any' and - // you cannot do len(any) - switch typ := any(value).(type) { - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr, float32, float64: - return typ == 0 - case flag.Count: - return typ == flag.Count(0) - case string: - return typ == "" - case bool: - return !typ - case []byte: - return len(typ) == 0 - case net.IP: - return len(typ) == 0 - case *url.URL: - return typ == nil - case []int: - return len(typ) == 0 - case []int8: - return len(typ) == 0 - case []int16: - return len(typ) == 0 - case []int32: - return len(typ) == 0 - case []int64: - return len(typ) == 0 - case []uint: - return len(typ) == 0 - case []uint16: - return len(typ) == 0 - case []uint32: - return len(typ) == 0 - case []uint64: - return len(typ) == 0 - case []float32: - return len(typ) == 0 - case []float64: - return len(typ) == 0 - case []string: - return len(typ) == 0 - case time.Time: +// "ish" means that empty slices will return true despite their official zero +// value being nil. The primary use is to determine whether a default value is +// worth displaying to the user in the help text — an empty slice is probably +// not. +// +//nolint:cyclop // Not much else we can do here +func (f *Flag[T]) isZeroIsh() bool { + switch f.kind { + case kind.Int: + return *parse.Cast[int, T](f.value) == 0 + case kind.Int8: + return *parse.Cast[int8, T](f.value) == 0 + case kind.Int16: + return *parse.Cast[int16, T](f.value) == 0 + case kind.Int32: + return *parse.Cast[int32, T](f.value) == 0 + case kind.Int64: + return *parse.Cast[int64, T](f.value) == 0 + case kind.Uint: + return *parse.Cast[uint, T](f.value) == 0 + case kind.Uint8: + return *parse.Cast[uint8, T](f.value) == 0 + case kind.Uint16: + return *parse.Cast[uint16, T](f.value) == 0 + case kind.Uint32: + return *parse.Cast[uint32, T](f.value) == 0 + case kind.Uint64: + return *parse.Cast[uint64, T](f.value) == 0 + case kind.Uintptr: + return *parse.Cast[uintptr, T](f.value) == 0 + case kind.Float32: + return *parse.Cast[float32, T](f.value) == 0 + case kind.Float64: + return *parse.Cast[float64, T](f.value) == 0 + case kind.Count: + return *parse.Cast[flag.Count, T](f.value) == 0 + case kind.String: + return *parse.Cast[string, T](f.value) == "" + case kind.Bool: + return !*parse.Cast[bool, T](f.value) + case kind.BytesHex: + return len(*parse.Cast[[]byte, T](f.value)) == 0 + case kind.IP: + return len(*parse.Cast[net.IP, T](f.value)) == 0 + case kind.URL: + return *parse.Cast[*url.URL, T](f.value) == nil + case kind.IntSlice: + return len(*parse.Cast[[]int, T](f.value)) == 0 + case kind.Int8Slice: + return len(*parse.Cast[[]int8, T](f.value)) == 0 + case kind.Int16Slice: + return len(*parse.Cast[[]int16, T](f.value)) == 0 + case kind.Int32Slice: + return len(*parse.Cast[[]int32, T](f.value)) == 0 + case kind.Int64Slice: + return len(*parse.Cast[[]int64, T](f.value)) == 0 + case kind.UintSlice: + return len(*parse.Cast[[]uint, T](f.value)) == 0 + case kind.Uint16Slice: + return len(*parse.Cast[[]uint16, T](f.value)) == 0 + case kind.Uint32Slice: + return len(*parse.Cast[[]uint32, T](f.value)) == 0 + case kind.Uint64Slice: + return len(*parse.Cast[[]uint64, T](f.value)) == 0 + case kind.Float32Slice: + return len(*parse.Cast[[]float32, T](f.value)) == 0 + case kind.Float64Slice: + return len(*parse.Cast[[]float64, T](f.value)) == 0 + case kind.StringSlice: + return len(*parse.Cast[[]string, T](f.value)) == 0 + case kind.Time: var zero time.Time - return typ.Equal(zero) - case time.Duration: - var zero time.Duration - return typ == zero + return parse.Cast[time.Time, T](f.value).Equal(zero) + case kind.Duration: + return *parse.Cast[time.Duration, T](f.value) == 0 default: return false } diff --git a/internal/flag/flag_test.go b/internal/flag/flag_test.go index 40ee3d4..5b1d259 100644 --- a/internal/flag/flag_test.go +++ b/internal/flag/flag_test.go @@ -1043,6 +1043,13 @@ func TestFlagValidation(t *testing.T) { wantErr: true, errMsg: `invalid shorthand for flag "delete": invalid character, must be a single ASCII letter, got "本"`, }, + { + name: "short is whitespace", + flagName: "delete", + short: ' ', + wantErr: true, + errMsg: `invalid shorthand for flag "delete": cannot contain whitespace`, + }, } for _, tt := range tests { diff --git a/internal/flag/set.go b/internal/flag/set.go index 70a061c..2d2fe78 100644 --- a/internal/flag/set.go +++ b/internal/flag/set.go @@ -290,9 +290,6 @@ func (s *Set) parseLongFlag(long string, rest []string) (remaining []string, err // name will either be the entire string or the name before the "=" name, value, containsEquals := strings.Cut(name, "=") - if err := validateFlagName(name); err != nil { - return nil, fmt.Errorf("invalid flag name %q: %w", name, err) - } flag, exists := s.flags[name] if !exists { @@ -368,10 +365,6 @@ func (s *Set) parseShortFlag(short string, rest []string) (remaining []string, e func (s *Set) parseSingleShortFlag(shorthands string, rest []string) (string, []string, error) { char, _ := utf8.DecodeRuneInString(shorthands) - if err := validateFlagShort(char); err != nil { - return "", nil, fmt.Errorf("invalid flag shorthand %q: %w", string(char), err) - } - flag, exists := s.shorthands[char] if !exists { return "", nil, fmt.Errorf("unrecognised shorthand flag: %q in -%s", string(char), shorthands) diff --git a/internal/flag/set_test.go b/internal/flag/set_test.go index 69bf2ad..2ee1cac 100644 --- a/internal/flag/set_test.go +++ b/internal/flag/set_test.go @@ -162,59 +162,6 @@ func TestParse(t *testing.T) { wantErr: true, errMsg: `invalid flag name "": must not be empty`, }, - { - name: "bad syntax long extra hyphen", - newSet: func(t *testing.T) *flag.Set { - return flag.NewSet() - }, - args: []string{"---"}, - wantErr: true, - errMsg: `invalid flag name "-": trailing hyphen`, - }, - { - name: "bad syntax long leading whitespace", - newSet: func(t *testing.T) *flag.Set { - return flag.NewSet() - }, - args: []string{"-- delete"}, - wantErr: true, - errMsg: `invalid flag name " delete": cannot contain whitespace`, - }, - { - name: "bad syntax short leading whitespace", - newSet: func(t *testing.T) *flag.Set { - return flag.NewSet() - }, - args: []string{"- d"}, - wantErr: true, - errMsg: `invalid flag shorthand " ": cannot contain whitespace`, - }, - { - name: "bad syntax long trailing whitespace", - newSet: func(t *testing.T) *flag.Set { - return flag.NewSet() - }, - args: []string{"--delete "}, - wantErr: true, - errMsg: `invalid flag name "delete ": cannot contain whitespace`, - }, - { - name: "bad syntax short trailing whitespace", - newSet: func(t *testing.T) *flag.Set { - f, err := flag.New(new(bool), "delete", 'd', "Delete something", flag.Config[bool]{}) - test.Ok(t, err) - - set := flag.NewSet() - - err = flag.AddToSet(set, f) - test.Ok(t, err) - - return set - }, - args: []string{"-d "}, - wantErr: true, - errMsg: `invalid flag shorthand " ": cannot contain whitespace`, - }, { name: "bad syntax short more than 1 char equals", newSet: func(t *testing.T) *flag.Set { @@ -224,42 +171,6 @@ func TestParse(t *testing.T) { wantErr: true, errMsg: `unrecognised shorthand flag: "d" in -dfv=something`, }, - { - name: "bad syntax short non utf8", - newSet: func(t *testing.T) *flag.Set { - return flag.NewSet() - }, - args: []string{"-Ê"}, - wantErr: true, - errMsg: `invalid flag shorthand "Ê": invalid character, must be a single ASCII letter, got "Ê"`, - }, - { - name: "bad syntax short non utf8 equals", - newSet: func(t *testing.T) *flag.Set { - return flag.NewSet() - }, - args: []string{"-Ê=something"}, - wantErr: true, - errMsg: `invalid flag shorthand "Ê": invalid character, must be a single ASCII letter, got "Ê"`, - }, - { - name: "bad syntax short multiple non utf8", - newSet: func(t *testing.T) *flag.Set { - return flag.NewSet() - }, - args: []string{"-本¼語"}, - wantErr: true, - errMsg: `invalid flag shorthand "本": invalid character, must be a single ASCII letter, got "本"`, - }, - { - name: "bad syntax long internal whitespace", - newSet: func(t *testing.T) *flag.Set { - return flag.NewSet() - }, - args: []string{"--de lete"}, - wantErr: true, - errMsg: `invalid flag name "de lete": cannot contain whitespace`, - }, { name: "valid long", newSet: func(t *testing.T) *flag.Set { diff --git a/internal/kind/kind.go b/internal/kind/kind.go new file mode 100644 index 0000000..7ffcff2 --- /dev/null +++ b/internal/kind/kind.go @@ -0,0 +1,48 @@ +// Package kind defines a compact type tag identifying the underlying +// concrete type of a Flag or Arg value. +// +// It exists so that hot paths do not have to do type switching which +// boosts performance and cuts allocations. +package kind + +// Kind identifies the underlying concrete type of a Flag or Arg value. +type Kind uint8 + +// Concrete kinds for every type in the public flag.Flaggable / arg.Argable +// constraints. +const ( + Invalid Kind = iota + Int + Int8 + Int16 + Int32 + Int64 + Uint + Uint8 + Uint16 + Uint32 + Uint64 + Uintptr + Float32 + Float64 + String + Bool + BytesHex + Count + Time + Duration + IP + URL + IntSlice + Int8Slice + Int16Slice + Int32Slice + Int64Slice + UintSlice + Uint16Slice + Uint32Slice + Uint64Slice + Float32Slice + Float64Slice + StringSlice +)