Skip to content
This repository was archived by the owner on Mar 4, 2025. It is now read-only.

Commit 18b9ea0

Browse files
committed
config: when no config file is found, generate an initial one
1 parent 8cb7169 commit 18b9ea0

3 files changed

Lines changed: 72 additions & 5 deletions

File tree

cmd/dio_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ func (s *DioSuite) SetUpSuite(c *chk.C) {
9292
if err != nil {
9393
log.Fatalln(err)
9494
}
95+
defer f.Close()
9596
d, err := os.Getwd()
9697
if err != nil {
9798
log.Fatalln(err)

cmd/root.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,24 +61,34 @@ func init() {
6161

6262
// Read all of our configuration data now
6363
if cfgFile != "" {
64-
// Use config file from the flag.
64+
// Use config file from the flag
6565
viper.SetConfigFile(cfgFile)
6666
} else {
67-
// Find home directory.
67+
// Find home directory
6868
home, err := homedir.Dir()
6969
if err != nil {
7070
fmt.Println(err)
7171
os.Exit(1)
7272
}
7373

74-
// Search config in home directory with name ".dio" (without extension).
75-
viper.AddConfigPath(filepath.Join(home, ".dio"))
74+
// Search for config in ".dio" subdirectory under the users home directory
75+
p := filepath.Join(home, ".dio")
76+
viper.AddConfigPath(p)
7677
viper.SetConfigName("config")
78+
cfgFile = filepath.Join(p, "config.toml")
7779
}
7880

7981
// If a config file is found, read it in.
8082
if err := viper.ReadInConfig(); err != nil {
81-
log.Fatalf("Error loading config file: %s", err.Error())
83+
// No configuration file was found, so generate a default one and let the user know they need to supply the
84+
// missing info
85+
errInner := generateConfig(cfgFile)
86+
if errInner != nil {
87+
log.Fatalln(errInner)
88+
return
89+
}
90+
log.Fatalf("No usable configuration file was found, so a default one has been generated in: %s\n"+
91+
"Please update it with your name, and the path to your DBHub.io user certificate file.\n", cfgFile)
8292
return
8393
}
8494

cmd/shared.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package cmd
33
import (
44
"bytes"
55
"crypto/sha256"
6+
"crypto/tls"
67
"crypto/x509"
78
"encoding/hex"
89
"encoding/json"
@@ -17,6 +18,7 @@ import (
1718
"strings"
1819
"time"
1920

21+
"github.com/mitchellh/go-homedir"
2022
rq "github.com/parnurzeal/gorequest"
2123
)
2224

@@ -169,6 +171,60 @@ var getDatabases = func(url string, user string) (dbList []dbListEntry, err erro
169171
return
170172
}
171173

174+
// Generates an initial default (production) configuration file. Before it's useful, the user will need to fill out
175+
// their display name + provide a DB4S certificate file
176+
func generateConfig(cfgFile string) (err error) {
177+
// Create the ".dio" directory in the users home folder, to store the configuration file in
178+
var home string
179+
home, err = homedir.Dir()
180+
if err != nil {
181+
return
182+
}
183+
if _, err = os.Stat(filepath.Join(home, ".dio")); os.IsNotExist(err) {
184+
err = os.Mkdir(filepath.Join(home, ".dio"), 0770)
185+
if err != nil {
186+
return
187+
}
188+
}
189+
190+
// Download the Certificate Authority chain file
191+
caURL := "https://github.com/sqlitebrowser/dio/raw/master/cert/ca-chain.cert.pem"
192+
chainFile := filepath.Join(home, ".dio", "ca-chain.cert.pem")
193+
resp, body, errs := rq.New().TLSClientConfig(&tls.Config{InsecureSkipVerify: true}).Get(caURL).EndBytes()
194+
if errs != nil {
195+
e := fmt.Sprintln("errors when retrieving the CA chain file:")
196+
for _, errInner := range errs {
197+
e += fmt.Sprintf(errInner.Error())
198+
}
199+
return errors.New(e)
200+
}
201+
defer resp.Body.Close()
202+
err = ioutil.WriteFile(chainFile, body, 0644)
203+
if err != nil {
204+
return err
205+
}
206+
207+
// Generate the initial config file
208+
const CFG = `[certs]
209+
cachain = "%s"
210+
cert = "/path/to/your/certificate/here"
211+
212+
[general]
213+
cloud = "https://dbhub.io:5550"
214+
215+
[user]
216+
name = "Your Name"
217+
`
218+
var f *os.File
219+
f, err = os.Create(cfgFile)
220+
if err != nil {
221+
return
222+
}
223+
defer f.Close()
224+
_, err = fmt.Fprintf(f, CFG, chainFile)
225+
return
226+
}
227+
172228
// Returns the name of the default database, if one has been selected. Returns an empty string if not
173229
func getDefaultDatabase() (db string, err error) {
174230
// Check if the local defaults info exists

0 commit comments

Comments
 (0)