diff --git a/internal/ast/ast.go b/internal/ast/ast.go index 0cd9f93..a8727e3 100644 --- a/internal/ast/ast.go +++ b/internal/ast/ast.go @@ -76,47 +76,90 @@ type Root []Node // Dot returns a dot representation of the AST for debugging. func (r Root) Sdot() string { type edge struct { - from int - to int + from int + childIdx int + to int } - var nodes []string + var nodes []Node var edges []edge // indexes into nodes - nodes = append(nodes, "root") + nodes = append(nodes, Node{ + Kind: Invalid, + Data: []byte(`ROOT`), + Children: r, + }) - labelForNode := func(node *Node) string { - return fmt.Sprintf("{%s}", node.Kind) - } - - var processNode func(int, *Node) - processNode = func(parentIdx int, node *Node) { + var processNode func(int, int, *Node) + processNode = func(parentIdx int, childIdx int, node *Node) { idx := len(nodes) - label := labelForNode(node) - nodes = append(nodes, label) - edges = append(edges, edge{from: parentIdx, to: idx}) + nodes = append(nodes, *node) + edges = append(edges, edge{ + from: parentIdx, + childIdx: childIdx, + to: idx, + }) - for _, c := range node.Children { - processNode(idx, &c) + for i, c := range node.Children { + processNode(idx, i, &c) } } - for _, n := range r { - processNode(0, &n) + for i, n := range r { + processNode(0, i, &n) } var b strings.Builder b.WriteString("digraph tree {\n") + b.WriteString("\tnode [shape=record];\n") - for i, label := range nodes { - _, _ = fmt.Fprintf(&b, "\tnode%d [label=\"%s\"];\n", i, label) + for i, node := range nodes { + label := "" + attrs := map[string]string{} + + if i == 0 { + var ports []string + for i := 0; i < len(node.Children); i++ { + ports = append(ports, fmt.Sprintf(" %d", i, i)) + } + joinedPorts := strings.Join(ports, "|") + label = fmt.Sprintf("{ROOT|{%s}}", joinedPorts) + } else { + fields := []string{node.Kind.String()} + if len(node.Data) > 0 { + fields = append(fields, string(node.Data)) + } + + var ports []string + for i := 0; i < len(node.Children); i++ { + ports = append(ports, fmt.Sprintf(" %d", i, i)) + } + joinedPorts := strings.Join(ports, "|") + + joinedFields := strings.Join(fields, "|") + label = fmt.Sprintf("{{%s}", joinedFields) + if len(ports) > 0 { + label += fmt.Sprintf("|{%s}", joinedPorts) + } + label += "}" + if node.Kind == Invalid { + attrs["style"] = "filled" + attrs["fillcolor"] = "red" + } + } + + _, _ = fmt.Fprintf(&b, "\tnode%d [label=\"%s\"", i, label) + for k, v := range attrs { + _, _ = fmt.Fprintf(&b, ", %s=\"%s\"", k, v) + } + _, _ = fmt.Fprintf(&b, "];\n") } b.WriteString("\n") for _, e := range edges { - _, _ = fmt.Fprintf(&b, "\tnode%d -> node%d;\n", e.from, e.to) + _, _ = fmt.Fprintf(&b, "\tnode%d:f%d -> node%d;\n", e.from, e.childIdx, e.to) } b.WriteString("}") diff --git a/unmarshaler.go b/unmarshaler.go index 1fee384..45f8fbc 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -2,6 +2,7 @@ package toml import ( "fmt" + "os" "reflect" "github.com/pelletier/go-toml/v2/internal/ast" @@ -13,9 +14,28 @@ func Unmarshal(data []byte, v interface{}) error { if err != nil { return err } + + // TODO: remove me; sanity check + allValidOrDump(p.tree, p.tree) + return fromAst(p.tree, v) } +func allValidOrDump(tree ast.Root, nodes []ast.Node) bool { + for i, n := range nodes { + if n.Kind == ast.Invalid { + fmt.Printf("AST contains invalid node! idx=%d\n", i) + fmt.Fprintf(os.Stderr, "%s\n", tree.Sdot()) + return false + } + ok := allValidOrDump(tree, n.Children) + if !ok { + return ok + } + } + return true +} + func fromAst(tree ast.Root, v interface{}) error { r := reflect.ValueOf(v) if r.Kind() != reflect.Ptr {