From d9a27b8052b61ed9ff043993c141af2acfb50027 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Fri, 1 Mar 2019 17:18:23 -0800 Subject: [PATCH] Provide "default" tag for unmarshal (#259) When a struct is unmarshalled, go-toml now looks at the `default` tag to provide a default value in case the key is not present in the TOML document. This is only implemented for string, bool, int, int64, float64. Additional types can be further implemented on a request-basis. --- marshal.go | 87 ++++++++++++++++++++++++++++++++++++++++--------- marshal_test.go | 85 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 15 deletions(-) diff --git a/marshal.go b/marshal.go index e7ef802..f17d049 100644 --- a/marshal.go +++ b/marshal.go @@ -16,15 +16,17 @@ const ( tagFieldComment = "comment" tagCommented = "commented" tagMultiline = "multiline" + tagDefault = "default" ) type tomlOpts struct { - name string - comment string - commented bool - multiline bool - include bool - omitempty bool + name string + comment string + commented bool + multiline bool + include bool + omitempty bool + defaultValue string } type encOpts struct { @@ -37,17 +39,19 @@ var encOptsDefaults = encOpts{ } type annotation struct { - tag string - comment string - commented string - multiline string + tag string + comment string + commented string + multiline string + defaultValue string } var annotationDefault = annotation{ - tag: tagFieldName, - comment: tagFieldComment, - commented: tagCommented, - multiline: tagMultiline, + tag: tagFieldName, + comment: tagFieldComment, + commented: tagCommented, + multiline: tagMultiline, + defaultValue: tagDefault, } var timeType = reflect.TypeOf(time.Time{}) @@ -403,6 +407,14 @@ func (t *Tree) Marshal() ([]byte, error) { // The following struct annotations are supported: // // toml:"Field" Overrides the field's name to map to. +// default:"foo" Provides a default value. +// +// For default values, only fields of the following types are supported: +// * string +// * bool +// * int +// * int64 +// * float64 // // See Marshal() documentation for types mapping table. func Unmarshal(data []byte, v interface{}) error { @@ -484,6 +496,8 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree) (reflect.Value, strings.ToTitle(baseKey), strings.ToLower(string(baseKey[0])) + baseKey[1:], } + + found := false for _, key := range keysToTry { exists := tval.Has(key) if !exists { @@ -495,8 +509,42 @@ func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree) (reflect.Value, return mval, formatError(err, tval.GetPosition(key)) } mval.Field(i).Set(mvalf) + found = true break } + + if !found && opts.defaultValue != "" { + mvalf := mval.Field(i) + var val interface{} = nil + var err error = nil + switch mvalf.Kind() { + case reflect.Bool: + val, err = strconv.ParseBool(opts.defaultValue) + if err != nil { + return mval.Field(i), err + } + case reflect.Int: + val, err = strconv.Atoi(opts.defaultValue) + if err != nil { + return mval.Field(i), err + } + case reflect.String: + val = opts.defaultValue + case reflect.Int64: + val, err = strconv.ParseInt(opts.defaultValue, 10, 64) + if err != nil { + return mval.Field(i), err + } + case reflect.Float64: + val, err = strconv.ParseFloat(opts.defaultValue, 64) + if err != nil { + return mval.Field(i), err + } + default: + return mval.Field(i), fmt.Errorf("unsuported field type for default option") + } + mval.Field(i).Set(reflect.ValueOf(val)) + } } } case reflect.Map: @@ -646,7 +694,16 @@ func tomlOptions(vf reflect.StructField, an annotation) tomlOpts { } commented, _ := strconv.ParseBool(vf.Tag.Get(an.commented)) multiline, _ := strconv.ParseBool(vf.Tag.Get(an.multiline)) - result := tomlOpts{name: vf.Name, comment: comment, commented: commented, multiline: multiline, include: true, omitempty: false} + defaultValue := vf.Tag.Get(tagDefault) + result := tomlOpts{ + name: vf.Name, + comment: comment, + commented: commented, + multiline: multiline, + include: true, + omitempty: false, + defaultValue: defaultValue, + } if parse[0] != "" { if parse[0] == "-" && len(parse) == 1 { result.include = false diff --git a/marshal_test.go b/marshal_test.go index 87d8217..9e5357d 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -1178,3 +1178,88 @@ func TestUnmarshalCamelCaseKey(t *testing.T) { t.Fatal("Did not set camelCase'd key") } } + +func TestUnmarshalDefault(t *testing.T) { + var doc struct { + StringField string `default:"a"` + BoolField bool `default:"true"` + IntField int `default:"1"` + Int64Field int64 `default:"2"` + Float64Field float64 `default:"3.1"` + } + + err := Unmarshal([]byte(``), &doc) + if err != nil { + t.Fatal(err) + } + if doc.BoolField != true { + t.Errorf("BoolField should be true, not %t", doc.BoolField) + } + if doc.StringField != "a" { + t.Errorf("StringField should be \"a\", not %s", doc.StringField) + } + if doc.IntField != 1 { + t.Errorf("IntField should be 1, not %d", doc.IntField) + } + if doc.Int64Field != 2 { + t.Errorf("Int64Field should be 2, not %d", doc.Int64Field) + } + if doc.Float64Field != 3.1 { + t.Errorf("Float64Field should be 3.1, not %f", doc.Float64Field) + } +} + +func TestUnmarshalDefaultFailureBool(t *testing.T) { + var doc struct { + Field bool `default:"blah"` + } + + err := Unmarshal([]byte(``), &doc) + if err == nil { + t.Fatal("should error") + } +} + +func TestUnmarshalDefaultFailureInt(t *testing.T) { + var doc struct { + Field int `default:"blah"` + } + + err := Unmarshal([]byte(``), &doc) + if err == nil { + t.Fatal("should error") + } +} + +func TestUnmarshalDefaultFailureInt64(t *testing.T) { + var doc struct { + Field int64 `default:"blah"` + } + + err := Unmarshal([]byte(``), &doc) + if err == nil { + t.Fatal("should error") + } +} + +func TestUnmarshalDefaultFailureFloat64(t *testing.T) { + var doc struct { + Field float64 `default:"blah"` + } + + err := Unmarshal([]byte(``), &doc) + if err == nil { + t.Fatal("should error") + } +} + +func TestUnmarshalDefaultFailureUnsupported(t *testing.T) { + var doc struct { + Field struct{} `default:"blah"` + } + + err := Unmarshal([]byte(``), &doc) + if err == nil { + t.Fatal("should error") + } +}