Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions app/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@ func NewAppWithConfig(sota *sotatoml.AppConfig, secretsDir string, unsafeHandler

// Extract extracts secrets from the encrypted configuration into the secrets directory.
// must be called once. A common way to do this is a systemd "oneshot" service after
// NetworkManager is up.
func (a *App) Extract() error {
// NetworkManager is up. It returns true if a config change was detected.
func (a *App) Extract() (bool, error) {
return (*internal.App)(a).Extract()
}

// CheckIn checks with the device gateway for the lastest configuration. If there
// are changes, it applies them locally. If the config on the server is unchanged,
// it returns NotModifiedError. This function also will try and run any pending
// "init" functions required by Fioconfig that have not yet been run.
func (a *App) CheckIn() error {
// It returns true if a config change was detected.
func (a *App) CheckIn() (bool, error) {
return (*internal.App)(a).CheckIn()
}

Expand Down
56 changes: 31 additions & 25 deletions internal/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,11 @@ func updateSecret(secretFile string, newContent []byte) (bool, error) {
return true, sotatoml.SafeWrite(secretFile, newContent)
}

func (a *App) extract(config configSnapshot) error {
func (a *App) extract(config configSnapshot) (bool, error) {
configChanged := false
st, err := os.Stat(a.SecretsDir)
if err != nil {
return err
return configChanged, err
}

all_fname := make(map[string]bool)
Expand All @@ -120,45 +121,47 @@ func (a *App) extract(config configSnapshot) error {
fullpath := filepath.Join(a.SecretsDir, fname)
dirName := filepath.Dir(fullpath)
if err := os.MkdirAll(dirName, st.Mode()); err != nil {
return fmt.Errorf("Unable to create parent directory secret: %s - %w", fullpath, err)
return configChanged, fmt.Errorf("Unable to create parent directory secret: %s - %w", fullpath, err)
}
changed, err := updateSecret(fullpath, []byte(cfgFile.Value))
if err != nil {
return err
return configChanged, err
}
if changed {
configChanged = true
a.runOnChanged(fname, fullpath, cfgFile.OnChanged)
}
}

// Now, watch for file removals (compare with a previous version if present)
if config.prev == nil {
return nil
return configChanged, nil
}
for fname, cfgFile := range config.prev {
if _, ok := all_fname[fname]; ok {
continue
}
slog.Info("Removing file", "file", fname)
configChanged = true
fullpath := filepath.Join(a.SecretsDir, fname)
if err := os.Remove(fullpath); err != nil && !os.IsNotExist(err) {
return err
return configChanged, err
}
a.runOnChanged(fname, fullpath, cfgFile.OnChanged)
}
if err := DeleteEmptyDirs(a.SecretsDir); err != nil {
slog.Error("Unable to remove empty directories", "error", err)
}
return nil
return configChanged, nil
}

func (a *App) Extract() error {
func (a *App) Extract() (bool, error) {
_, crypto := createClient(a.sota)
defer crypto.Close()

config, err := UnmarshallFile(crypto, a.EncryptedConfig, true)
if err != nil {
return err
return false, err
}
return a.extract(configSnapshot{nil, config})
}
Expand Down Expand Up @@ -192,10 +195,9 @@ func (a *App) runOnChanged(fname string, fullpath string, onChanged []string) {
}
}

func (a *App) checkin(client *http.Client, crypto CryptoHandler) error {
func (a *App) checkin(client *http.Client, crypto CryptoHandler) (configChanged bool, err error) {
headers := make(map[string]string)
var config configSnapshot
var err error

if config.prev, err = UnmarshallFile(nil, a.EncryptedConfig, false); err != nil {
var perr *os.PathError
Expand All @@ -210,41 +212,45 @@ func (a *App) checkin(client *http.Client, crypto CryptoHandler) error {

res, err := transport.HttpGet(client, a.configUrl, headers)
if err != nil {
return err // Unable to attempt request
return // Unable to attempt request
}

if res.StatusCode == 200 {
if config.next, err = UnmarshallBuffer(crypto, res.Body, true); err != nil {
return err
return
}

if err = a.extract(config); err != nil {
return err
if configChanged, err = a.extract(config); err != nil {
return
}
if err = sotatoml.SafeWrite(a.EncryptedConfig, res.Body); err != nil {
return err
return
}

modtime, err := time.Parse(time.RFC1123, res.Header.Get("Date"))
if err != nil {
slog.Warn("Unable to get modtime of config file, defaulting to 'now'", "error", err)
modtime, err2 := time.Parse(time.RFC1123, res.Header.Get("Date"))
if err2 != nil {
slog.Warn("Unable to get modtime of config file, defaulting to 'now'", "error", err2)
modtime = time.Now()
}
if err = os.Chtimes(a.EncryptedConfig, modtime, modtime); err != nil {
return fmt.Errorf("Unable to set modified time %s - %w", a.EncryptedConfig, err)
err = fmt.Errorf("Unable to set modified time %s - %w", a.EncryptedConfig, err)
return
}
return nil
return
} else if res.StatusCode == 304 {
slog.Info("Config on server has not changed")
return NotModifiedError
err = NotModifiedError
return
} else if res.StatusCode == 204 {
slog.Info("Device has no config defined on server")
return NotModifiedError
err = NotModifiedError
return
}
return fmt.Errorf("Unable to get %s - HTTP_%d: %s", a.configUrl, res.StatusCode, res.String())
err = fmt.Errorf("Unable to get %s - HTTP_%d: %s", a.configUrl, res.StatusCode, res.String())
return
}

func (a *App) CheckIn() error {
func (a *App) CheckIn() (bool, error) {
client, crypto := createClient(a.sota)
defer crypto.Close()
callInitFunctions(a, client)
Expand Down
25 changes: 16 additions & 9 deletions internal/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func assertNoFile(t *testing.T, path string) {

func TestExtract(t *testing.T) {
testWrapper(t, nil, func(app *App, client *http.Client, tempdir string) {
if err := app.Extract(); err != nil {
if _, err := app.Extract(); err != nil {
t.Fatal(err)
}

Expand All @@ -210,7 +210,7 @@ func TestExtract(t *testing.T) {

// Make sure files that don't change aren't updated
os.Remove(barChanged)
if err := app.Extract(); err != nil {
if _, err := app.Extract(); err != nil {
t.Fatal(err)
}
_, err := os.Stat(barChanged)
Expand All @@ -223,8 +223,10 @@ func TestExtract(t *testing.T) {
func TestSafeHandler(t *testing.T) {
testWrapper(t, nil, func(app *App, client *http.Client, tempdir string) {
app.unsafeHandlers = false
if err := app.Extract(); err != nil {
if changed, err := app.Extract(); err != nil {
t.Fatal(err)
} else if !changed {
t.Fatal("Config change not detected")
}
barChanged := filepath.Join(tempdir, "bar-changed")
_, err := os.Stat(barChanged)
Expand Down Expand Up @@ -257,7 +259,9 @@ func TestHandlerExit(t *testing.T) {
buf, err = json.Marshal(config)
require.Nil(t, err)
require.Nil(t, os.WriteFile(app.EncryptedConfig, buf, 0o777))
require.Nil(t, app.Extract())
changed, err := app.Extract()
require.Nil(t, err)
require.True(t, changed)
require.True(t, called)
})
}
Expand All @@ -270,10 +274,11 @@ func TestCheckBad(t *testing.T) {
testWrapper(t, doGet, func(app *App, client *http.Client, tempdir string) {
_, crypto := createClient(app.sota)
defer crypto.Close()
err := app.checkin(client, crypto)
changed, err := app.checkin(client, crypto)
if err == nil {
t.Fatal("Checkin should have gotten a 404")
}
require.False(t, changed)

if !strings.HasSuffix(strings.TrimSpace(err.Error()), "HTTP_404: 404 page not found") {
t.Fatalf("Unexpected response: '%s'", err)
Expand Down Expand Up @@ -306,7 +311,7 @@ func TestCheckCorruptedConfig(t *testing.T) {
t.Fatal(err)
}

if err := app.checkin(client, crypto); err != nil {
if _, err := app.checkin(client, crypto); err != nil {
t.Fatal(err)
}

Expand Down Expand Up @@ -351,8 +356,10 @@ func TestCheckGood(t *testing.T) {
// Remove this file so we can be sure the check-in creates it
os.Remove(app.EncryptedConfig)

if err := app.checkin(client, crypto); err != nil {
if changed, err := app.checkin(client, crypto); err != nil {
t.Fatal(err)
} else if !changed {
t.Fatal("Config-change not detected")
}

foo := filepath.Join(tempdir, "foo")
Expand Down Expand Up @@ -380,13 +387,13 @@ func TestCheckGood(t *testing.T) {
time.Sleep(1 * time.Millisecond)

// Now make sure the if-not-modified logic works
if err := app.checkin(client, crypto); err != NotModifiedError {
if _, err := app.checkin(client, crypto); err != NotModifiedError {
t.Fatal(err)
}

// Check that files removed on server are also removed on device and onChange is called
removeBar = true
if err := app.checkin(client, crypto); err != nil {
if _, err := app.checkin(client, crypto); err != nil {
t.Fatal(err)
}

Expand Down
6 changes: 3 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func extract(c *cli.Context) error {
}
}
slog.Info("Extracting keys", "from", app.EncryptedConfig, "to", app.SecretsDir)
if err := app.Extract(); err != nil {
if _, err := app.Extract(); err != nil {
if errors.Is(err, os.ErrNotExist) {
slog.Info("Encrypted config does not exist")
} else {
Expand All @@ -76,7 +76,7 @@ func checkin(c *cli.Context) error {
}

slog.Info("Checking in with server ...")
if err := app.CheckIn(); err != nil && !errors.Is(err, internal.NotModifiedError) {
if _, err := app.CheckIn(); err != nil && !errors.Is(err, internal.NotModifiedError) {
return err
}
return nil
Expand All @@ -91,7 +91,7 @@ func daemon(c *cli.Context) error {
slog.Info("Running as daemon", "interval", c.Int("interval"))
for {
slog.Info("Checking in with server")
if err := app.CheckIn(); err != nil && !errors.Is(err, internal.NotModifiedError) {
if _, err := app.CheckIn(); err != nil && !errors.Is(err, internal.NotModifiedError) {
slog.Error("Check-in failed", "error", err)
}
time.Sleep(interval)
Expand Down
Loading