574 lines
12 KiB
Go
574 lines
12 KiB
Go
package toml
|
|
|
|
import (
|
|
"encoding"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"reflect"
|
|
"time"
|
|
|
|
"github.com/pelletier/go-toml/v2/internal/ast"
|
|
"github.com/pelletier/go-toml/v2/internal/tracker"
|
|
"github.com/pelletier/go-toml/v2/internal/unsafe"
|
|
)
|
|
|
|
// Unmarshal deserializes a TOML document into a Go value.
|
|
//
|
|
// It is a shortcut for Decoder.Decode() with the default options.
|
|
func Unmarshal(data []byte, v interface{}) error {
|
|
p := parser{}
|
|
p.Reset(data)
|
|
d := decoder{}
|
|
|
|
return d.FromParser(&p, v)
|
|
}
|
|
|
|
// Decoder reads and decode a TOML document from an input stream.
|
|
type Decoder struct {
|
|
// input
|
|
r io.Reader
|
|
|
|
// global settings
|
|
strict bool
|
|
}
|
|
|
|
// NewDecoder creates a new Decoder that will read from r.
|
|
func NewDecoder(r io.Reader) *Decoder {
|
|
return &Decoder{r: r}
|
|
}
|
|
|
|
// SetStrict toggles decoding in stict mode.
|
|
//
|
|
// When the decoder is in strict mode, it will record fields from the document
|
|
// that could not be set on the target value. In that case, the decoder returns
|
|
// a StrictMissingError that can be used to retrieve the individual errors as
|
|
// well as generate a human readable description of the missing fields.
|
|
func (d *Decoder) SetStrict(strict bool) {
|
|
d.strict = strict
|
|
}
|
|
|
|
// Decode the whole content of r into v.
|
|
//
|
|
// By default, values in the document that don't exist in the target Go value
|
|
// are ignored. See Decoder.SetStrict() to change this behavior.
|
|
//
|
|
// When a TOML local date, time, or date-time is decoded into a time.Time, its
|
|
// value is represented in time.Local timezone. Otherwise the approriate Local*
|
|
// structure is used.
|
|
//
|
|
// Empty tables decoded in an interface{} create an empty initialized
|
|
// map[string]interface{}.
|
|
//
|
|
// Types implementing the encoding.TextUnmarshaler interface are decoded from a
|
|
// TOML string.
|
|
//
|
|
// When decoding a number, go-toml will return an error if the number is out of
|
|
// bounds for the target type (which includes negative numbers when decoding
|
|
// into an unsigned int).
|
|
//
|
|
// Type mapping
|
|
//
|
|
// List of supported TOML types and their associated accepted Go types:
|
|
//
|
|
// String -> string
|
|
// Integer -> uint*, int*, depending on size
|
|
// Float -> float*, depending on size
|
|
// Boolean -> bool
|
|
// Offset Date-Time -> time.Time
|
|
// Local Date-time -> LocalDateTime, time.Time
|
|
// Local Date -> LocalDate, time.Time
|
|
// Local Time -> LocalTime, time.Time
|
|
// Array -> slice and array, depending on elements types
|
|
// Table -> map and struct
|
|
// Inline Table -> same as Table
|
|
// Array of Tables -> same as Array and Table
|
|
func (d *Decoder) Decode(v interface{}) error {
|
|
b, err := ioutil.ReadAll(d.r)
|
|
if err != nil {
|
|
return fmt.Errorf("toml: %w", err)
|
|
}
|
|
|
|
p := parser{}
|
|
p.Reset(b)
|
|
dec := decoder{
|
|
strict: strict{
|
|
Enabled: d.strict,
|
|
},
|
|
}
|
|
|
|
return dec.FromParser(&p, v)
|
|
}
|
|
|
|
type decoder struct {
|
|
// Tracks position in Go arrays.
|
|
arrayIndexes map[reflect.Value]int
|
|
|
|
// Tracks keys that have been seen, with which type.
|
|
seen tracker.SeenTracker
|
|
|
|
// Strict mode
|
|
strict strict
|
|
}
|
|
|
|
func (d *decoder) arrayIndex(shouldAppend bool, v reflect.Value) int {
|
|
if d.arrayIndexes == nil {
|
|
d.arrayIndexes = make(map[reflect.Value]int, 1)
|
|
}
|
|
|
|
idx, ok := d.arrayIndexes[v]
|
|
|
|
if !ok {
|
|
d.arrayIndexes[v] = 0
|
|
} else if shouldAppend {
|
|
idx++
|
|
d.arrayIndexes[v] = idx
|
|
}
|
|
|
|
return idx
|
|
}
|
|
|
|
func (d *decoder) FromParser(p *parser, v interface{}) error {
|
|
err := d.fromParser(p, v)
|
|
if err == nil {
|
|
return d.strict.Error(p.data)
|
|
}
|
|
|
|
var e *decodeError
|
|
if errors.As(err, &e) {
|
|
return wrapDecodeError(p.data, e)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func keyLocation(node ast.Node) []byte {
|
|
k := node.Key()
|
|
|
|
hasOne := k.Next()
|
|
if !hasOne {
|
|
panic("should not be called with empty key")
|
|
}
|
|
|
|
start := k.Node().Data
|
|
end := k.Node().Data
|
|
|
|
for k.Next() {
|
|
end = k.Node().Data
|
|
}
|
|
|
|
return unsafe.BytesRange(start, end)
|
|
}
|
|
|
|
//nolint:funlen,cyclop
|
|
func (d *decoder) fromParser(p *parser, v interface{}) error {
|
|
r := reflect.ValueOf(v)
|
|
if r.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("toml: decoding can only be performed into a pointer, not %s", r.Kind())
|
|
}
|
|
|
|
if r.IsNil() {
|
|
return fmt.Errorf("toml: decoding pointer target cannot be nil")
|
|
}
|
|
|
|
var (
|
|
skipUntilTable bool
|
|
root target = valueTarget(r.Elem())
|
|
)
|
|
|
|
current := root
|
|
|
|
for p.NextExpression() {
|
|
node := p.Expression()
|
|
|
|
if node.Kind == ast.KeyValue && skipUntilTable {
|
|
continue
|
|
}
|
|
|
|
err := d.seen.CheckExpression(node)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var found bool
|
|
|
|
switch node.Kind {
|
|
case ast.KeyValue:
|
|
err = d.unmarshalKeyValue(current, node)
|
|
found = true
|
|
case ast.Table:
|
|
d.strict.EnterTable(node)
|
|
|
|
current, found, err = d.scopeWithKey(root, node.Key())
|
|
if err == nil && found {
|
|
// In case this table points to an interface,
|
|
// make sure it at least holds something that
|
|
// looks like a table. Otherwise the information
|
|
// of a table is lost, and marshal cannot do the
|
|
// round trip.
|
|
ensureMapIfInterface(current)
|
|
}
|
|
case ast.ArrayTable:
|
|
d.strict.EnterArrayTable(node)
|
|
current, found, err = d.scopeWithArrayTable(root, node.Key())
|
|
default:
|
|
panic(fmt.Sprintf("this should not be a top level node type: %s", node.Kind))
|
|
}
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !found {
|
|
skipUntilTable = true
|
|
|
|
d.strict.MissingTable(node)
|
|
}
|
|
}
|
|
|
|
return p.Error()
|
|
}
|
|
|
|
// scopeWithKey performs target scoping when unmarshaling an ast.KeyValue node.
|
|
//
|
|
// The goal is to hop from target to target recursively using the names in key.
|
|
// Parts of the key should be used to resolve field names for structs, and as
|
|
// keys when targeting maps.
|
|
//
|
|
// When encountering slices, it should always use its last element, and error
|
|
// if the slice does not have any.
|
|
func (d *decoder) scopeWithKey(x target, key ast.Iterator) (target, bool, error) {
|
|
var (
|
|
err error
|
|
found bool
|
|
)
|
|
|
|
for key.Next() {
|
|
n := key.Node()
|
|
|
|
x, found, err = d.scopeTableTarget(false, x, string(n.Data))
|
|
if err != nil || !found {
|
|
return nil, found, err
|
|
}
|
|
}
|
|
|
|
return x, true, nil
|
|
}
|
|
|
|
//nolint:cyclop
|
|
// scopeWithArrayTable performs target scoping when unmarshaling an
|
|
// ast.ArrayTable node.
|
|
//
|
|
// It is the same as scopeWithKey, but when scoping the last part of the key
|
|
// it creates a new element in the array instead of using the last one.
|
|
func (d *decoder) scopeWithArrayTable(x target, key ast.Iterator) (target, bool, error) {
|
|
var (
|
|
err error
|
|
found bool
|
|
)
|
|
|
|
for key.Next() {
|
|
n := key.Node()
|
|
if !n.Next().Valid() { // want to stop at one before last
|
|
break
|
|
}
|
|
|
|
x, found, err = d.scopeTableTarget(false, x, string(n.Data))
|
|
if err != nil || !found {
|
|
return nil, found, err
|
|
}
|
|
}
|
|
|
|
n := key.Node()
|
|
|
|
x, found, err = d.scopeTableTarget(false, x, string(n.Data))
|
|
if err != nil || !found {
|
|
return x, found, err
|
|
}
|
|
|
|
v := x.get()
|
|
|
|
if v.Kind() == reflect.Ptr {
|
|
x = scopePtr(x)
|
|
v = x.get()
|
|
}
|
|
|
|
if v.Kind() == reflect.Interface {
|
|
x = scopeInterface(true, x)
|
|
v = x.get()
|
|
}
|
|
|
|
switch v.Kind() {
|
|
case reflect.Slice:
|
|
x = scopeSlice(true, x)
|
|
case reflect.Array:
|
|
x, err = d.scopeArray(true, x)
|
|
default:
|
|
}
|
|
|
|
return x, found, err
|
|
}
|
|
|
|
func (d *decoder) unmarshalKeyValue(x target, node ast.Node) error {
|
|
assertNode(ast.KeyValue, node)
|
|
|
|
d.strict.EnterKeyValue(node)
|
|
defer d.strict.ExitKeyValue(node)
|
|
|
|
x, found, err := d.scopeWithKey(x, node.Key())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// A struct in the path was not found. Skip this value.
|
|
if !found {
|
|
d.strict.MissingField(node)
|
|
|
|
return nil
|
|
}
|
|
|
|
return d.unmarshalValue(x, node.Value())
|
|
}
|
|
|
|
var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
|
|
|
|
func tryTextUnmarshaler(x target, node ast.Node) (bool, error) {
|
|
v := x.get()
|
|
|
|
if v.Kind() != reflect.Struct {
|
|
return false, nil
|
|
}
|
|
|
|
// Special case for time, because we allow to unmarshal to it from
|
|
// different kind of AST nodes.
|
|
if v.Type() == timeType {
|
|
return false, nil
|
|
}
|
|
|
|
if v.Type().Implements(textUnmarshalerType) {
|
|
err := v.Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data)
|
|
if err != nil {
|
|
return false, newDecodeError(node.Data, "error calling UnmarshalText: %w", err)
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
if v.CanAddr() && v.Addr().Type().Implements(textUnmarshalerType) {
|
|
err := v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data)
|
|
if err != nil {
|
|
return false, newDecodeError(node.Data, "error calling UnmarshalText: %w", err)
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
//nolint:cyclop
|
|
func (d *decoder) unmarshalValue(x target, node ast.Node) error {
|
|
v := x.get()
|
|
|
|
if v.Kind() == reflect.Ptr {
|
|
if !v.Elem().IsValid() {
|
|
x.set(reflect.New(v.Type().Elem()))
|
|
v = x.get()
|
|
}
|
|
|
|
return d.unmarshalValue(valueTarget(v.Elem()), node)
|
|
}
|
|
|
|
ok, err := tryTextUnmarshaler(x, node)
|
|
if ok {
|
|
return err
|
|
}
|
|
|
|
switch node.Kind {
|
|
case ast.String:
|
|
return unmarshalString(x, node)
|
|
case ast.Bool:
|
|
return unmarshalBool(x, node)
|
|
case ast.Integer:
|
|
return unmarshalInteger(x, node)
|
|
case ast.Float:
|
|
return unmarshalFloat(x, node)
|
|
case ast.Array:
|
|
return d.unmarshalArray(x, node)
|
|
case ast.InlineTable:
|
|
return d.unmarshalInlineTable(x, node)
|
|
case ast.LocalDateTime:
|
|
return unmarshalLocalDateTime(x, node)
|
|
case ast.DateTime:
|
|
return unmarshalDateTime(x, node)
|
|
case ast.LocalDate:
|
|
return unmarshalLocalDate(x, node)
|
|
default:
|
|
panic(fmt.Sprintf("unhandled node kind %s", node.Kind))
|
|
}
|
|
}
|
|
|
|
func unmarshalLocalDate(x target, node ast.Node) error {
|
|
assertNode(ast.LocalDate, node)
|
|
|
|
v, err := parseLocalDate(node.Data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
setDate(x, v)
|
|
|
|
return nil
|
|
}
|
|
|
|
func unmarshalLocalDateTime(x target, node ast.Node) error {
|
|
assertNode(ast.LocalDateTime, node)
|
|
|
|
v, rest, err := parseLocalDateTime(node.Data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(rest) > 0 {
|
|
return newDecodeError(rest, "extra characters at the end of a local date time")
|
|
}
|
|
|
|
setLocalDateTime(x, v)
|
|
|
|
return nil
|
|
}
|
|
|
|
func unmarshalDateTime(x target, node ast.Node) error {
|
|
assertNode(ast.DateTime, node)
|
|
|
|
v, err := parseDateTime(node.Data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
setDateTime(x, v)
|
|
|
|
return nil
|
|
}
|
|
|
|
func setLocalDateTime(x target, v LocalDateTime) {
|
|
if x.get().Type() == timeType {
|
|
cast := v.In(time.Local)
|
|
|
|
setDateTime(x, cast)
|
|
return
|
|
}
|
|
|
|
x.set(reflect.ValueOf(v))
|
|
}
|
|
|
|
func setDateTime(x target, v time.Time) {
|
|
x.set(reflect.ValueOf(v))
|
|
}
|
|
|
|
var timeType = reflect.TypeOf(time.Time{})
|
|
|
|
func setDate(x target, v LocalDate) {
|
|
if x.get().Type() == timeType {
|
|
cast := v.In(time.Local)
|
|
|
|
setDateTime(x, cast)
|
|
return
|
|
}
|
|
|
|
x.set(reflect.ValueOf(v))
|
|
}
|
|
|
|
func unmarshalString(x target, node ast.Node) error {
|
|
assertNode(ast.String, node)
|
|
|
|
return setString(x, string(node.Data))
|
|
}
|
|
|
|
func unmarshalBool(x target, node ast.Node) error {
|
|
assertNode(ast.Bool, node)
|
|
v := node.Data[0] == 't'
|
|
|
|
return setBool(x, v)
|
|
}
|
|
|
|
func unmarshalInteger(x target, node ast.Node) error {
|
|
assertNode(ast.Integer, node)
|
|
|
|
v, err := parseInteger(node.Data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return setInt64(x, v)
|
|
}
|
|
|
|
func unmarshalFloat(x target, node ast.Node) error {
|
|
assertNode(ast.Float, node)
|
|
|
|
v, err := parseFloat(node.Data)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return setFloat64(x, v)
|
|
}
|
|
|
|
func (d *decoder) unmarshalInlineTable(x target, node ast.Node) error {
|
|
assertNode(ast.InlineTable, node)
|
|
|
|
ensureMapIfInterface(x)
|
|
|
|
it := node.Children()
|
|
for it.Next() {
|
|
n := it.Node()
|
|
|
|
err := d.unmarshalKeyValue(x, n)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d *decoder) unmarshalArray(x target, node ast.Node) error {
|
|
assertNode(ast.Array, node)
|
|
|
|
err := ensureValueIndexable(x)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
idx := 0
|
|
|
|
it := node.Children()
|
|
for it.Next() {
|
|
n := it.Node()
|
|
|
|
v := elementAt(x, idx)
|
|
|
|
if v == nil {
|
|
// when we go out of bound for an array just stop processing it to
|
|
// mimic encoding/json
|
|
break
|
|
}
|
|
|
|
err = d.unmarshalValue(v, n)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
idx++
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func assertNode(expected ast.Kind, node ast.Node) {
|
|
if node.Kind != expected {
|
|
panic(fmt.Sprintf("expected node of kind %s, not %s", expected, node.Kind))
|
|
}
|
|
}
|