diff --git a/app/api.go b/app/api.go index 8f5ddc3..e06f154 100644 --- a/app/api.go +++ b/app/api.go @@ -21,8 +21,8 @@ func NewAppWithConfig(sota *sotatoml.AppConfig, secretsDir string, unsafeHandler // Extract extracts secrets from the encrypted configuration into the secrets directory. // must be called once. A common way to do this is a systemd "oneshot" service after -// NetworkManager is up. -func (a *App) Extract() error { +// NetworkManager is up. It returns true if a config change was detected. +func (a *App) Extract() (bool, error) { return (*internal.App)(a).Extract() } @@ -30,7 +30,8 @@ func (a *App) Extract() error { // are changes, it applies them locally. If the config on the server is unchanged, // it returns NotModifiedError. This function also will try and run any pending // "init" functions required by Fioconfig that have not yet been run. -func (a *App) CheckIn() error { +// It returns true if a config change was detected. +func (a *App) CheckIn() (bool, error) { return (*internal.App)(a).CheckIn() } diff --git a/internal/app.go b/internal/app.go index 6aad8b9..5f3e481 100644 --- a/internal/app.go +++ b/internal/app.go @@ -107,10 +107,11 @@ func updateSecret(secretFile string, newContent []byte) (bool, error) { return true, sotatoml.SafeWrite(secretFile, newContent) } -func (a *App) extract(config configSnapshot) error { +func (a *App) extract(config configSnapshot) (bool, error) { + configChanged := false st, err := os.Stat(a.SecretsDir) if err != nil { - return err + return configChanged, err } all_fname := make(map[string]bool) @@ -120,45 +121,47 @@ func (a *App) extract(config configSnapshot) error { fullpath := filepath.Join(a.SecretsDir, fname) dirName := filepath.Dir(fullpath) if err := os.MkdirAll(dirName, st.Mode()); err != nil { - return fmt.Errorf("Unable to create parent directory secret: %s - %w", fullpath, err) + return configChanged, fmt.Errorf("Unable to create parent directory secret: %s - %w", fullpath, err) } changed, err := updateSecret(fullpath, []byte(cfgFile.Value)) if err != nil { - return err + return configChanged, err } if changed { + configChanged = true a.runOnChanged(fname, fullpath, cfgFile.OnChanged) } } // Now, watch for file removals (compare with a previous version if present) if config.prev == nil { - return nil + return configChanged, nil } for fname, cfgFile := range config.prev { if _, ok := all_fname[fname]; ok { continue } slog.Info("Removing file", "file", fname) + configChanged = true fullpath := filepath.Join(a.SecretsDir, fname) if err := os.Remove(fullpath); err != nil && !os.IsNotExist(err) { - return err + return configChanged, err } a.runOnChanged(fname, fullpath, cfgFile.OnChanged) } if err := DeleteEmptyDirs(a.SecretsDir); err != nil { slog.Error("Unable to remove empty directories", "error", err) } - return nil + return configChanged, nil } -func (a *App) Extract() error { +func (a *App) Extract() (bool, error) { _, crypto := createClient(a.sota) defer crypto.Close() config, err := UnmarshallFile(crypto, a.EncryptedConfig, true) if err != nil { - return err + return false, err } return a.extract(configSnapshot{nil, config}) } @@ -192,10 +195,9 @@ func (a *App) runOnChanged(fname string, fullpath string, onChanged []string) { } } -func (a *App) checkin(client *http.Client, crypto CryptoHandler) error { +func (a *App) checkin(client *http.Client, crypto CryptoHandler) (configChanged bool, err error) { headers := make(map[string]string) var config configSnapshot - var err error if config.prev, err = UnmarshallFile(nil, a.EncryptedConfig, false); err != nil { var perr *os.PathError @@ -210,41 +212,45 @@ func (a *App) checkin(client *http.Client, crypto CryptoHandler) error { res, err := transport.HttpGet(client, a.configUrl, headers) if err != nil { - return err // Unable to attempt request + return // Unable to attempt request } if res.StatusCode == 200 { if config.next, err = UnmarshallBuffer(crypto, res.Body, true); err != nil { - return err + return } - if err = a.extract(config); err != nil { - return err + if configChanged, err = a.extract(config); err != nil { + return } if err = sotatoml.SafeWrite(a.EncryptedConfig, res.Body); err != nil { - return err + return } - modtime, err := time.Parse(time.RFC1123, res.Header.Get("Date")) - if err != nil { - slog.Warn("Unable to get modtime of config file, defaulting to 'now'", "error", err) + modtime, err2 := time.Parse(time.RFC1123, res.Header.Get("Date")) + if err2 != nil { + slog.Warn("Unable to get modtime of config file, defaulting to 'now'", "error", err2) modtime = time.Now() } if err = os.Chtimes(a.EncryptedConfig, modtime, modtime); err != nil { - return fmt.Errorf("Unable to set modified time %s - %w", a.EncryptedConfig, err) + err = fmt.Errorf("Unable to set modified time %s - %w", a.EncryptedConfig, err) + return } - return nil + return } else if res.StatusCode == 304 { slog.Info("Config on server has not changed") - return NotModifiedError + err = NotModifiedError + return } else if res.StatusCode == 204 { slog.Info("Device has no config defined on server") - return NotModifiedError + err = NotModifiedError + return } - return fmt.Errorf("Unable to get %s - HTTP_%d: %s", a.configUrl, res.StatusCode, res.String()) + err = fmt.Errorf("Unable to get %s - HTTP_%d: %s", a.configUrl, res.StatusCode, res.String()) + return } -func (a *App) CheckIn() error { +func (a *App) CheckIn() (bool, error) { client, crypto := createClient(a.sota) defer crypto.Close() callInitFunctions(a, client) diff --git a/internal/app_test.go b/internal/app_test.go index 0d25352..96548a3 100644 --- a/internal/app_test.go +++ b/internal/app_test.go @@ -197,7 +197,7 @@ func assertNoFile(t *testing.T, path string) { func TestExtract(t *testing.T) { testWrapper(t, nil, func(app *App, client *http.Client, tempdir string) { - if err := app.Extract(); err != nil { + if _, err := app.Extract(); err != nil { t.Fatal(err) } @@ -210,7 +210,7 @@ func TestExtract(t *testing.T) { // Make sure files that don't change aren't updated os.Remove(barChanged) - if err := app.Extract(); err != nil { + if _, err := app.Extract(); err != nil { t.Fatal(err) } _, err := os.Stat(barChanged) @@ -223,8 +223,10 @@ func TestExtract(t *testing.T) { func TestSafeHandler(t *testing.T) { testWrapper(t, nil, func(app *App, client *http.Client, tempdir string) { app.unsafeHandlers = false - if err := app.Extract(); err != nil { + if changed, err := app.Extract(); err != nil { t.Fatal(err) + } else if !changed { + t.Fatal("Config change not detected") } barChanged := filepath.Join(tempdir, "bar-changed") _, err := os.Stat(barChanged) @@ -257,7 +259,9 @@ func TestHandlerExit(t *testing.T) { buf, err = json.Marshal(config) require.Nil(t, err) require.Nil(t, os.WriteFile(app.EncryptedConfig, buf, 0o777)) - require.Nil(t, app.Extract()) + changed, err := app.Extract() + require.Nil(t, err) + require.True(t, changed) require.True(t, called) }) } @@ -270,10 +274,11 @@ func TestCheckBad(t *testing.T) { testWrapper(t, doGet, func(app *App, client *http.Client, tempdir string) { _, crypto := createClient(app.sota) defer crypto.Close() - err := app.checkin(client, crypto) + changed, err := app.checkin(client, crypto) if err == nil { t.Fatal("Checkin should have gotten a 404") } + require.False(t, changed) if !strings.HasSuffix(strings.TrimSpace(err.Error()), "HTTP_404: 404 page not found") { t.Fatalf("Unexpected response: '%s'", err) @@ -306,7 +311,7 @@ func TestCheckCorruptedConfig(t *testing.T) { t.Fatal(err) } - if err := app.checkin(client, crypto); err != nil { + if _, err := app.checkin(client, crypto); err != nil { t.Fatal(err) } @@ -351,8 +356,10 @@ func TestCheckGood(t *testing.T) { // Remove this file so we can be sure the check-in creates it os.Remove(app.EncryptedConfig) - if err := app.checkin(client, crypto); err != nil { + if changed, err := app.checkin(client, crypto); err != nil { t.Fatal(err) + } else if !changed { + t.Fatal("Config-change not detected") } foo := filepath.Join(tempdir, "foo") @@ -380,13 +387,13 @@ func TestCheckGood(t *testing.T) { time.Sleep(1 * time.Millisecond) // Now make sure the if-not-modified logic works - if err := app.checkin(client, crypto); err != NotModifiedError { + if _, err := app.checkin(client, crypto); err != NotModifiedError { t.Fatal(err) } // Check that files removed on server are also removed on device and onChange is called removeBar = true - if err := app.checkin(client, crypto); err != nil { + if _, err := app.checkin(client, crypto); err != nil { t.Fatal(err) } diff --git a/main.go b/main.go index 6e22ecb..2d967d9 100644 --- a/main.go +++ b/main.go @@ -59,7 +59,7 @@ func extract(c *cli.Context) error { } } slog.Info("Extracting keys", "from", app.EncryptedConfig, "to", app.SecretsDir) - if err := app.Extract(); err != nil { + if _, err := app.Extract(); err != nil { if errors.Is(err, os.ErrNotExist) { slog.Info("Encrypted config does not exist") } else { @@ -76,7 +76,7 @@ func checkin(c *cli.Context) error { } slog.Info("Checking in with server ...") - if err := app.CheckIn(); err != nil && !errors.Is(err, internal.NotModifiedError) { + if _, err := app.CheckIn(); err != nil && !errors.Is(err, internal.NotModifiedError) { return err } return nil @@ -91,7 +91,7 @@ func daemon(c *cli.Context) error { slog.Info("Running as daemon", "interval", c.Int("interval")) for { slog.Info("Checking in with server") - if err := app.CheckIn(); err != nil && !errors.Is(err, internal.NotModifiedError) { + if _, err := app.CheckIn(); err != nil && !errors.Is(err, internal.NotModifiedError) { slog.Error("Check-in failed", "error", err) } time.Sleep(interval)