@@ -8,14 +8,15 @@ import (
88 "context"
99 "fmt"
1010 "go/ast"
11- "go/importer"
1211 "go/parser"
1312 "go/token"
1413 "go/types"
1514 "io/ioutil"
1615 "os"
1716 "os/exec"
1817 "path/filepath"
18+ "regexp"
19+ "strings"
1920
2021 "github.com/dave/jennifer/jen"
2122 "github.com/pkg/errors"
@@ -24,7 +25,85 @@ import (
2425// TODO(saswatamcode): Add tests.
2526// TODO(saswatamcode): Check jennifer code for some safety.
2627// TODO(saswatamcode): Add mechanism for caching output from generated code.
27- // TODO(saswatamcode): Currently takes file names, need to make it module based(something such as https://golang.org/pkg/cmd/go/internal/list/).
28+
29+ // getSourceFromMod fetches source code file from $GOPATH/pkg/mod.
30+ func getSourceFromMod (root string , structName string ) ([]byte , error ) {
31+ var src []byte
32+ stopWalk := errors .New ("stop walking" )
33+
34+ // Walk source dir.
35+ err := filepath .Walk (root , func (path string , info os.FileInfo , err error ) error {
36+ // Check if file is Go code.
37+ if ! info .IsDir () && filepath .Ext (path ) == ".go" && err == nil {
38+ src , err = ioutil .ReadFile (path )
39+ if err != nil {
40+ return errors .Wrapf (err , "read file for yaml gen %v" , path )
41+ }
42+ // Check if file has struct.
43+ if bytes .Contains (src , []byte ("type " + structName + " struct" )) {
44+ return stopWalk
45+ }
46+ }
47+ return nil
48+ })
49+ if err == stopWalk {
50+ err = nil
51+ }
52+
53+ return src , err
54+ }
55+
56+ // GetSource get source code of file containing the struct.
57+ func GetSource (ctx context.Context , structLocation string ) ([]byte , error ) {
58+ // Get struct name.
59+ loc := strings .Split (structLocation , ":" )
60+
61+ // Check if it is a local file.
62+ _ , err := os .Stat (loc [0 ])
63+ if err == nil {
64+ src , err := ioutil .ReadFile (loc [0 ])
65+ if err != nil {
66+ return nil , errors .Wrapf (err , "read file for yaml gen %v" , loc [0 ])
67+ }
68+
69+ // As it is a local file, return source directly.
70+ return src , nil
71+ }
72+
73+ // Not local file so must be module. Will be of form `github.com/bwplotka/mdox@v0.2.2-0.20210712170635-f49414cc6b5a/pkg/mdformatter/linktransformer:Config`.
74+ // Split using version number (if provided).
75+ getModule := loc [0 ]
76+ moduleName := strings .SplitN (loc [0 ], "@" , 2 )
77+ if len (moduleName ) == 2 {
78+ // Split package dir (if provided).
79+ pkg := strings .SplitN (moduleName [1 ], "/" , 2 )
80+ if len (pkg ) == 2 {
81+ getModule = moduleName [0 ] + "@" + pkg [0 ]
82+ }
83+ }
84+ //TODO(saswatamcode): Handle case where version number not present but package name is.
85+
86+ // Fetch module.
87+ cmd := exec .CommandContext (ctx , "go" , "get" , "-u" , getModule )
88+ err = cmd .Run ()
89+ if err != nil {
90+ return nil , errors .Wrapf (err , "run %v" , cmd )
91+ }
92+
93+ // Get GOPATH.
94+ goPath , ok := os .LookupEnv ("GOPATH" )
95+ if ! ok {
96+ return nil , errors .New ("GOPATH not set" )
97+ }
98+
99+ // Get source file of struct.
100+ file , err := getSourceFromMod (filepath .Join (goPath , "pkg/mod" , loc [0 ]), loc [1 ])
101+ if err != nil {
102+ return nil , err
103+ }
104+
105+ return file , nil
106+ }
28107
29108// GenGoCode generates Go code for yaml gen from structs in src file.
30109func GenGoCode (src []byte ) (string , error ) {
@@ -56,7 +135,11 @@ func GenGoCode(src []byte) (string, error) {
56135 if typeDecl , ok := genericDecl .Specs [0 ].(* ast.TypeSpec ); ok {
57136 var structFields []jen.Code
58137 // Cast to `type struct`.
59- structDecl := typeDecl .Type .(* ast.StructType )
138+ structDecl , ok := typeDecl .Type .(* ast.StructType )
139+ if ! ok {
140+ generatedCode .Type ().Id (typeDecl .Name .Name ).Id (string (src [typeDecl .Type .Pos ()- 1 : typeDecl .Type .End ()- 1 ]))
141+ continue
142+ }
60143 fields := structDecl .Fields .List
61144 arrayInit := make (jen.Dict )
62145
@@ -68,20 +151,12 @@ func GenGoCode(src []byte) (string, error) {
68151 if n .IsExported () {
69152 pos := n .Obj .Decl .(* ast.Field )
70153
71- // Make type map to check if field is array.
72- info := types.Info {Types : make (map [ast.Expr ]types.TypeAndValue )}
73- _ , err = (& types.Config {Importer : importer .ForCompiler (fset , "source" , nil )}).Check ("mypkg" , fset , []* ast.File {f }, & info )
74- if err != nil {
75- return "" , err
76- }
77- typ := info .Types [field .Type ].Type
78-
79- switch typ .(type ) {
80- case * types.Slice :
81- // Field is of type array so initialize it using code like `[]Type{Type{}}`.
154+ // Check if field is a slice type.
155+ sliceRe := regexp .MustCompile (`.*\[.*\].*` )
156+ if sliceRe .MatchString (types .ExprString (field .Type )) {
82157 arrayInit [jen .Id (n .Name )] = jen .Id (string (src [pos .Type .Pos ()- 1 : pos .Type .End ()- 1 ])).Values (jen .Id (string (src [pos .Type .Pos ()+ 1 : pos .Type .End ()- 1 ])).Values ())
83- default :
84158 }
159+
85160 // Copy struct field to generated code.
86161 if pos .Tag != nil {
87162 structFields = append (structFields , jen .Id (n .Name ).Id (string (src [pos .Type .Pos ()- 1 :pos .Type .End ()- 1 ])).Id (pos .Tag .Value ))
0 commit comments