diff --git a/README.md b/README.md index 8db97e4..d2c4c02 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # dotenvgo -[![Go Version](https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat-square&logo=go)](https://golang.org) +[![Go Version](https://img.shields.io/badge/Go-1.22+-00ADD8?style=flat-square&logo=go)](https://golang.org) [![License](https://img.shields.io/badge/License-MIT-blue.svg?style=flat-square)](LICENSE) [![Zero Dependencies](https://img.shields.io/badge/Dependencies-Zero-green?style=flat-square)](go.mod) [![Coverage](https://img.shields.io/badge/Coverage-91%25-brightgreen?style=flat-square)](https://github.com/godeh/dotenvgo) @@ -13,6 +13,8 @@ go get github.com/godeh/dotenvgo ``` +Requires Go 1.22 or newer. + ## Quick Examples ### Type-Safe Variables @@ -201,8 +203,13 @@ Pointer fields are supported during struct loading. For scalar pointer fields su KEY=value MESSAGE="Hello World" DEBUG=true # inline comment +export PORT=8080 + +# Multiline quoted values +CERT="line1 +line2" -# Variable expansion +# Variable expansion (uses the current environment and variables loaded earlier in the file) BASE=/app CONFIG=${BASE}/config ``` diff --git a/dotenv_file.go b/dotenv_file.go new file mode 100644 index 0000000..0e04fc1 --- /dev/null +++ b/dotenv_file.go @@ -0,0 +1,182 @@ +package dotenvgo + +import ( + "os" + "strings" +) + +// LoadDotEnv loads environment variables from a .env file. +// By default, it does NOT override existing environment variables. +// Pass true as the second argument to override existing variables. +// +// Examples: +// +// LoadDotEnv(".env") // doesn't override existing vars +// LoadDotEnv(".env", false) // same as above +// LoadDotEnv(".env", true) // overrides existing vars +func LoadDotEnv(path string, override ...bool) error { + shouldOverride := false + if len(override) > 0 { + shouldOverride = override[0] + } + return loadDotEnvInternal(path, shouldOverride) +} + +func loadDotEnvInternal(path string, override bool) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + + entries, err := parseDotEnvEntries(string(data)) + if err != nil { + return err + } + + resolved := make(map[string]string, len(entries)) + for _, entry := range entries { + value := os.Expand(entry.value, func(name string) string { + if resolvedValue, ok := resolved[name]; ok { + return resolvedValue + } + envValue, _ := os.LookupEnv(name) + return envValue + }) + + if override { + _ = os.Setenv(entry.key, value) + resolved[entry.key] = value + continue + } + + if existing, exists := os.LookupEnv(entry.key); exists { + resolved[entry.key] = existing + continue + } + + _ = os.Setenv(entry.key, value) + resolved[entry.key] = value + } + + return nil +} + +type dotenvEntry struct { + key string + value string +} + +func parseDotEnvEntries(data string) ([]dotenvEntry, error) { + lines := strings.Split(data, "\n") + entries := make([]dotenvEntry, 0, len(lines)) + + for i := 0; i < len(lines); i++ { + line := strings.TrimSpace(lines[i]) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + if strings.HasPrefix(line, "export ") { + line = strings.TrimSpace(strings.TrimPrefix(line, "export ")) + } + + before, after, ok := strings.Cut(line, "=") + if !ok { + continue + } + + key := strings.TrimSpace(before) + if key == "" { + continue + } + + value, nextLine, err := parseDotEnvValue(lines, i, strings.TrimLeft(after, " \t")) + if err != nil { + return nil, err + } + + entries = append(entries, dotenvEntry{key: key, value: value}) + i = nextLine + } + + return entries, nil +} + +func parseDotEnvValue(lines []string, start int, valuePart string) (string, int, error) { + if valuePart == "" { + return "", start, nil + } + + quote := valuePart[0] + if quote == '"' || quote == '\'' { + return parseQuotedDotEnvValue(lines, start, valuePart, quote) + } + + return parseUnquotedDotEnvValue(valuePart), start, nil +} + +func parseQuotedDotEnvValue(lines []string, start int, valuePart string, quote byte) (string, int, error) { + var b strings.Builder + currentLine := start + segment := valuePart[1:] + + for { + for i := 0; i < len(segment); i++ { + ch := segment[i] + if quote == '"' && ch == '\\' && i+1 < len(segment) { + i++ + switch segment[i] { + case 'n': + b.WriteByte('\n') + case 'r': + b.WriteByte('\r') + case 't': + b.WriteByte('\t') + case '\\', '"', '$': + b.WriteByte(segment[i]) + default: + b.WriteByte(segment[i]) + } + continue + } + + if ch == quote { + return b.String(), currentLine, nil + } + + b.WriteByte(ch) + } + + if currentLine+1 >= len(lines) { + return b.String(), currentLine, nil + } + + b.WriteByte('\n') + currentLine++ + segment = lines[currentLine] + } +} + +func parseUnquotedDotEnvValue(valuePart string) string { + value := valuePart + for i := 0; i < len(valuePart); i++ { + if valuePart[i] != '#' { + continue + } + if i == 0 { + return "" + } + if valuePart[i-1] == ' ' || valuePart[i-1] == '\t' { + return strings.TrimSpace(valuePart[:i]) + } + } + + return value +} + +// MustLoadDotEnv loads a .env file or panics. +func MustLoadDotEnv(path string) { + if err := LoadDotEnv(path); err != nil { + panic(err) + } +} diff --git a/dotenvgo.go b/dotenvgo.go index d57fbec..32da73f 100644 --- a/dotenvgo.go +++ b/dotenvgo.go @@ -1,310 +1 @@ package dotenvgo - -import ( - "encoding" - "fmt" - "os" - "reflect" - "strings" -) - -// Var represents an environment variable with type-safe access. -type Var[T any] struct { - key string - defaultValue *T - required bool - parser func(string) (T, error) - prefix string -} - -// New creates a new environment variable of type T using the default loader. -func New[T any](key string) *Var[T] { - return NewVar[T](DefaultLoader, key) -} - -// NewVar creates a new environment variable of type T using the specified Loader. -// It searches for a registered parser or uses encoding.TextUnmarshaler. -func NewVar[T any](l *Loader, key string) *Var[T] { - var zero T - typ := reflect.TypeOf(zero) - - // 1. Check registry - if p, ok := l.getParser(typ); ok { - return &Var[T]{ - key: key, - parser: func(s string) (T, error) { - v, err := p(s) - if err != nil { - return zero, err - } - return v.(T), nil - }, - } - } - - // 2. Check TextUnmarshaler - // Better approach: Generic Parser Factory - return &Var[T]{ - key: key, - parser: func(s string) (T, error) { - // Re-lookup registry (fast) - if p, ok := l.getParser(typ); ok { - v, err := p(s) - if err != nil { - return zero, err - } - return v.(T), nil - } - - // TextUnmarshaler - valPtr := reflect.New(typ) - if u, ok := valPtr.Interface().(encoding.TextUnmarshaler); ok { - if err := u.UnmarshalText([]byte(s)); err != nil { - return zero, err - } - return valPtr.Elem().Interface().(T), nil - } - - return zero, fmt.Errorf("dotenvgo: no parser registered for type %v", typ) - }, - } -} - -// Default sets the default value if the environment variable is not set. -func (v *Var[T]) Default(value T) *Var[T] { - v.defaultValue = &value - return v -} - -// Required marks the environment variable as required. -// Get() will panic if the variable is not set. -// GetE() will return an error. -func (v *Var[T]) Required() *Var[T] { - v.required = true - return v -} - -// WithPrefix adds a prefix to the environment variable key. -// For example, WithPrefix("APP").String("PORT") will look for "APP_PORT". -func (v *Var[T]) WithPrefix(prefix string) *Var[T] { - v.prefix = prefix - return v -} - -// fullKey returns the full environment variable key with prefix. -func (v *Var[T]) fullKey() string { - if v.prefix != "" { - return v.prefix + "_" + v.key - } - return v.key -} - -// Get returns the value of the environment variable. -// Panics if the variable is required but not set. -func (v *Var[T]) Get() T { - value, err := v.GetE() - if err != nil { - panic(err) - } - return value -} - -// GetE returns the value of the environment variable or an error. -func (v *Var[T]) GetE() (T, error) { - var zero T - key := v.fullKey() - raw, exists := os.LookupEnv(key) - - if !exists { - if v.required { - return zero, &RequiredError{Key: key} - } - if v.defaultValue != nil { - return *v.defaultValue, nil - } - return zero, nil - } - - value, err := v.parser(os.ExpandEnv(raw)) - if err != nil { - return zero, &ParseError{Key: key, Value: raw, Err: err} - } - - return value, nil -} - -// Lookup returns the value and whether it was set. -func (v *Var[T]) Lookup() (T, bool) { - var zero T - key := v.fullKey() - raw, exists := os.LookupEnv(key) - - if !exists { - if v.defaultValue != nil { - return *v.defaultValue, true - } - return zero, false - } - - value, err := v.parser(os.ExpandEnv(raw)) - if err != nil { - return zero, false - } - - return value, true -} - -// MustGet returns the value or panics if there's an error. -// Alias for Get(). -func (v *Var[T]) MustGet() T { - return v.Get() -} - -// IsSet returns whether the environment variable is set. -func (v *Var[T]) IsSet() bool { - key := v.fullKey() - _, exists := os.LookupEnv(key) - return exists -} - -// LoadDotEnv loads environment variables from a .env file. -// By default, it does NOT override existing environment variables. -// Pass true as the second argument to override existing variables. -// -// Examples: -// -// LoadDotEnv(".env") // doesn't override existing vars -// LoadDotEnv(".env", false) // same as above -// LoadDotEnv(".env", true) // overrides existing vars -func LoadDotEnv(path string, override ...bool) error { - shouldOverride := false - if len(override) > 0 { - shouldOverride = override[0] - } - return loadDotEnvInternal(path, shouldOverride) -} - -func loadDotEnvInternal(path string, override bool) error { - data, err := os.ReadFile(path) - if err != nil { - return err - } - - lines := strings.Split(string(data), "\n") - for _, line := range lines { - line = strings.TrimSpace(line) - - // Skip empty lines and comments - if line == "" || strings.HasPrefix(line, "#") { - continue - } - - // Find the first '=' - before, after, ok := strings.Cut(line, "=") - if !ok { - continue - } - - key := strings.TrimSpace(before) - valPart := strings.TrimSpace(after) - var value string - - if len(valPart) > 0 { - quote := valPart[0] - if quote == '"' || quote == '\'' { - // Quoted value: look for matching close quote - // We start looking from index 1 - if endIdx := strings.IndexByte(valPart[1:], quote); endIdx != -1 { - // endIdx is relative to valPart[1:], so actual index in valPart is endIdx + 1 - value = valPart[1 : endIdx+1] - } else { - // Unclosed quote, take valid part or whole? - // Usually behave as if unquoted or error. - // For simplicity/robustness, treat as unquoted if not closed properly - // or just take the whole thing. - // Let's assume the user meant it to be the value if unclosed. - value = valPart - } - } else { - // Unquoted value: stop at first '#' if it is preceded by a space - value = valPart - for i := 0; i < len(valPart); i++ { - if valPart[i] == '#' { - // Comment matches if it's the first char (handled mostly by loop skip, but technically possible here if valPart is just #) - // OR if preceded by whitespace - if i == 0 { - value = "" - break - } - // Check previous char for whitespace - if i > 0 && (valPart[i-1] == ' ' || valPart[i-1] == '\t') { - value = strings.TrimSpace(valPart[:i]) - break - } - } - } - } - } - - // Only set if not already set (unless override) - if override { - _ = os.Setenv(key, value) - continue - } - - if _, exists := os.LookupEnv(key); !exists { - _ = os.Setenv(key, value) - } - } - - return nil -} - -// MustLoadDotEnv loads a .env file or panics. -func MustLoadDotEnv(path string) { - if err := LoadDotEnv(path); err != nil { - panic(err) - } -} - -// Export returns all environment variables as a map. -func Export() map[string]string { - result := make(map[string]string) - for _, env := range os.Environ() { - idx := strings.Index(env, "=") - if idx != -1 { - result[env[:idx]] = env[idx+1:] - } - } - return result -} - -// ExportWithPrefix returns environment variables matching a prefix. -func ExportWithPrefix(prefix string) map[string]string { - result := make(map[string]string) - prefixUpper := strings.ToUpper(prefix) - if !strings.HasSuffix(prefixUpper, "_") { - prefixUpper += "_" - } - - for _, env := range os.Environ() { - idx := strings.Index(env, "=") - if idx != -1 { - key := env[:idx] - if strings.HasPrefix(strings.ToUpper(key), prefixUpper) { - result[key] = env[idx+1:] - } - } - } - return result -} - -// Set sets an environment variable. -func Set(key, value string) { - _ = os.Setenv(key, value) -} - -// Unset removes an environment variable. -func Unset(key string) { - _ = os.Unsetenv(key) -} diff --git a/dotenvgo_test.go b/dotenvgo_test.go index 039c272..cf23bef 100644 --- a/dotenvgo_test.go +++ b/dotenvgo_test.go @@ -2,6 +2,7 @@ package dotenvgo import ( "os" + "strings" "testing" ) @@ -94,6 +95,48 @@ func TestLoadDotEnv(t *testing.T) { } } +func TestLoadDotEnvParsingFeatures(t *testing.T) { + filename := ".env.features" + content := []byte(strings.Join([]string{ + "export EXPORTED=from_export", + "BASE_URL=https://example.com", + `API_URL="${BASE_URL}/api"`, + `ESCAPED="tab:\t newline:\n quote:\" slash:\\"`, + `SINGLE_QUOTED='single quoted value'`, + "MULTILINE=\"line1", + "line2\"", + }, "\n")) + if err := os.WriteFile(filename, content, 0o644); err != nil { + t.Fatal(err) + } + defer func() { _ = os.Remove(filename) }() + defer Unset("EXPORTED") + defer Unset("BASE_URL") + defer Unset("API_URL") + defer Unset("ESCAPED") + defer Unset("SINGLE_QUOTED") + defer Unset("MULTILINE") + + if err := LoadDotEnv(filename, true); err != nil { + t.Fatalf("LoadDotEnv failed: %v", err) + } + + checks := map[string]string{ + "EXPORTED": "from_export", + "BASE_URL": "https://example.com", + "API_URL": "https://example.com/api", + "ESCAPED": "tab:\t newline:\n quote:\" slash:\\", + "SINGLE_QUOTED": "single quoted value", + "MULTILINE": "line1\nline2", + } + + for key, expected := range checks { + if got := os.Getenv(key); got != expected { + t.Errorf("Key %s: expected %q, got %q", key, expected, got) + } + } +} + func TestParsers(t *testing.T) { _ = os.Setenv("INT_VAL", "123") _ = os.Setenv("BOOL_VAL", "true") @@ -355,7 +398,7 @@ UNCLOSED="unclosed "SINGLE_QUOTED": "single quoted", "WITH_HASH": "val#ue", "WITH_COMMENT": "value", - "UNCLOSED": "\"unclosed", // Implementation dependent, usually raw + "UNCLOSED": "unclosed\n\t", } for k, expected := range checks { diff --git a/env_resolve.go b/env_resolve.go new file mode 100644 index 0000000..9608a4b --- /dev/null +++ b/env_resolve.go @@ -0,0 +1,41 @@ +package dotenvgo + +import "os" + +type resolvedEnvValue struct { + raw string + value string + exists bool +} + +func resolveEnvValue(key string, required bool) (resolvedEnvValue, error) { + raw, exists := os.LookupEnv(key) + if !exists { + if required { + return resolvedEnvValue{}, &RequiredError{Key: key} + } + return resolvedEnvValue{}, nil + } + + return resolvedEnvValue{ + raw: raw, + value: os.ExpandEnv(raw), + exists: true, + }, nil +} + +func resolveFieldValue(key, defaultValue string, required bool) (string, bool, error) { + resolved, err := resolveEnvValue(key, required) + if err != nil { + return "", false, err + } + if resolved.exists { + return resolved.value, true, nil + } + + if defaultValue == "" { + return "", false, nil + } + + return os.ExpandEnv(defaultValue), true, nil +} diff --git a/env_utils.go b/env_utils.go new file mode 100644 index 0000000..724fb34 --- /dev/null +++ b/env_utils.go @@ -0,0 +1,48 @@ +package dotenvgo + +import ( + "os" + "strings" +) + +// Export returns all environment variables as a map. +func Export() map[string]string { + result := make(map[string]string) + for _, env := range os.Environ() { + idx := strings.Index(env, "=") + if idx != -1 { + result[env[:idx]] = env[idx+1:] + } + } + return result +} + +// ExportWithPrefix returns environment variables matching a prefix. +func ExportWithPrefix(prefix string) map[string]string { + result := make(map[string]string) + prefixUpper := strings.ToUpper(prefix) + if !strings.HasSuffix(prefixUpper, "_") { + prefixUpper += "_" + } + + for _, env := range os.Environ() { + idx := strings.Index(env, "=") + if idx != -1 { + key := env[:idx] + if strings.HasPrefix(strings.ToUpper(key), prefixUpper) { + result[key] = env[idx+1:] + } + } + } + return result +} + +// Set sets an environment variable. +func Set(key, value string) { + _ = os.Setenv(key, value) +} + +// Unset removes an environment variable. +func Unset(key string) { + _ = os.Unsetenv(key) +} diff --git a/loader.go b/loader.go index 1fd5736..0a8a7e0 100644 --- a/loader.go +++ b/loader.go @@ -2,10 +2,9 @@ package dotenvgo import ( "encoding" + "errors" "fmt" - "os" "reflect" - "strings" ) // Load populates a struct from environment variables using struct tags. @@ -70,7 +69,7 @@ func MustLoadWithPrefix(cfg any, prefix string) { func (l *Loader) loadStruct(v reflect.Value, prefix string) (bool, error) { t := v.Type() - var errors []error + var errs []error loadedAny := false for i := 0; i < t.NumField(); i++ { @@ -78,7 +77,7 @@ func (l *Loader) loadStruct(v reflect.Value, prefix string) (bool, error) { fieldValue := v.Field(i) envKey := field.Tag.Get("env") - // Skip unexported fields + // Skip unexported fields. if !fieldValue.CanSet() { continue } @@ -93,7 +92,7 @@ func (l *Loader) loadStruct(v reflect.Value, prefix string) (bool, error) { nestedValue := reflect.New(nestedType).Elem() nestedLoaded, err := l.loadStruct(nestedValue, nestedPrefix) if err != nil { - errors = append(errors, err) + errs = appendError(errs, err) } if nestedLoaded { target := reflect.New(nestedType) @@ -106,7 +105,7 @@ func (l *Loader) loadStruct(v reflect.Value, prefix string) (bool, error) { nestedLoaded, err := l.loadStruct(fieldValue, nestedPrefix) if err != nil { - errors = append(errors, err) + errs = appendError(errs, err) } if nestedLoaded { loadedAny = true @@ -114,9 +113,7 @@ func (l *Loader) loadStruct(v reflect.Value, prefix string) (bool, error) { continue } - // Handle structs (embedded or named) that don't have a parser/unmarshaler if field.Type.Kind() == reflect.Struct { - // Check if it's a "leaf" type (has parser or implements TextUnmarshaler) _, hasParser := l.getParser(field.Type) isUnmarshaler := field.Type.Implements(reflect.TypeFor[encoding.TextUnmarshaler]()) || reflect.PointerTo(field.Type).Implements(reflect.TypeFor[encoding.TextUnmarshaler]()) @@ -126,42 +123,32 @@ func (l *Loader) loadStruct(v reflect.Value, prefix string) (bool, error) { } } - // Get struct tags if envKey == "" { continue } defaultValue := field.Tag.Get("default") required := field.Tag.Get("required") == "true" - - // Build full key with prefix fullKey := joinEnvKey(prefix, envKey) - // Get value from environment - rawValue, exists := os.LookupEnv(fullKey) - value := os.ExpandEnv(rawValue) - if !exists { - if required { - errors = append(errors, &RequiredError{Key: fullKey}) - continue - } - value = os.ExpandEnv(defaultValue) + value, exists, err := resolveFieldValue(fullKey, defaultValue, required) + if err != nil { + errs = appendError(errs, err) + continue } - - if !exists && value == "" { + if !exists { continue } - // Parse and set value if err := l.setField(fieldValue, field.Tag, value); err != nil { - errors = append(errors, &ParseError{Key: fullKey, Value: value, Err: err}) + errs = append(errs, &ParseError{Key: fullKey, Value: value, Err: err}) continue } loadedAny = true } - if len(errors) > 0 { - return loadedAny, &MultiError{Errors: errors} + if len(errs) > 0 { + return loadedAny, &MultiError{Errors: errs} } return loadedAny, nil } @@ -173,6 +160,19 @@ func joinEnvKey(prefix, key string) string { return prefix + "_" + key } +func appendError(existing []error, err error) []error { + if err == nil { + return existing + } + + var multiErr *MultiError + if errors.As(err, &multiErr) { + return append(existing, multiErr.Errors...) + } + + return append(existing, err) +} + func (l *Loader) nestedStructType(t reflect.Type) (reflect.Type, bool) { baseType := t if baseType.Kind() == reflect.Pointer { @@ -232,7 +232,6 @@ func (l *Loader) setField(field reflect.Value, tag reflect.StructTag, value stri return nil } - // 0. Handle custom separator for slices if field.Kind() == reflect.Slice { sep := tag.Get("sep") if sep != "" || field.Type().Elem().Kind() == reflect.Pointer { @@ -248,7 +247,6 @@ func (l *Loader) setField(field reflect.Value, tag reflect.StructTag, value stri } } - // 1. Check if type has a registered parser if parser, ok := l.getParser(field.Type()); ok { parsed, err := parser(value) if err != nil { @@ -258,14 +256,11 @@ func (l *Loader) setField(field reflect.Value, tag reflect.StructTag, value stri return nil } - // 2. Check if field implements encoding.TextUnmarshaler if field.CanAddr() { - // Try pointer receiver if u, ok := field.Addr().Interface().(encoding.TextUnmarshaler); ok { return u.UnmarshalText([]byte(value)) } } else if u, ok := field.Interface().(encoding.TextUnmarshaler); ok { - // Try value receiver (less common for mutation but possible) return u.UnmarshalText([]byte(value)) } @@ -273,22 +268,12 @@ func (l *Loader) setField(field reflect.Value, tag reflect.StructTag, value stri } func (l *Loader) parseSlice(sliceType reflect.Type, tag reflect.StructTag, value, sep string) (reflect.Value, error) { - parts := strings.Split(value, sep) - slice := reflect.MakeSlice(sliceType, 0, len(parts)) elemType := sliceType.Elem() - - for _, part := range parts { - part = strings.TrimSpace(part) - if part == "" { - continue - } - + return parseSliceValue(sliceType, value, sep, func(part string) (reflect.Value, error) { elem := reflect.New(elemType).Elem() if err := l.setField(elem, tag, part); err != nil { - return reflect.Value{}, fmt.Errorf("dotenvgo: no parser registered for slice element type %v: %w", elemType, err) + return reflect.Value{}, err } - slice = reflect.Append(slice, elem) - } - - return slice, nil + return elem, nil + }) } diff --git a/loader_test.go b/loader_test.go index 6ccd1f9..679559f 100644 --- a/loader_test.go +++ b/loader_test.go @@ -3,6 +3,7 @@ package dotenvgo import ( "errors" "os" + "strings" "testing" "time" ) @@ -336,9 +337,9 @@ func TestLoadErrors(t *testing.T) { } }) - t.Run("Parse Error", func(t *testing.T) { - setEnv(t, "REQUIRED_VAR", "ok") - setEnv(t, "PORT", "invalid-int") + t.Run("Parse Error", func(t *testing.T) { + setEnv(t, "REQUIRED_VAR", "ok") + setEnv(t, "PORT", "invalid-int") var cfg TestConfig err := Load(&cfg) @@ -351,9 +352,71 @@ func TestLoadErrors(t *testing.T) { t.Fatal("Expected MultiError with at least one error") } + var parseErr *ParseError + if !errors.As(multiErr.Errors[0], &parseErr) { + t.Errorf("Expected ParseError, got %T", multiErr.Errors[0]) + } + }) + + t.Run("Nested Errors Are Flattened", func(t *testing.T) { + type Database struct { + URL string `env:"URL" required:"true"` + User string `env:"USER" required:"true"` + } + + type Config struct { + DB Database `env:"DB"` + } + + var cfg Config + err := Load(&cfg) + if err == nil { + t.Fatal("Expected error for missing nested required vars") + } + + var multiErr *MultiError + if !errors.As(err, &multiErr) { + t.Fatalf("Expected MultiError, got %T", err) + } + if len(multiErr.Errors) != 2 { + t.Fatalf("Expected 2 flattened errors, got %d", len(multiErr.Errors)) + } + + for _, item := range multiErr.Errors { + var reqErr *RequiredError + if !errors.As(item, &reqErr) { + t.Fatalf("Expected flattened RequiredError, got %T", item) + } + } + }) + + t.Run("Slice Pointer Parse Error Keeps Cause", func(t *testing.T) { + type Config struct { + IDs []*int `env:"IDS"` + } + + setEnv(t, "IDS", "1, nope, 3") + + var cfg Config + err := Load(&cfg) + if err == nil { + t.Fatal("Expected parse error for invalid slice element") + } + + var multiErr *MultiError + if !errors.As(err, &multiErr) || len(multiErr.Errors) == 0 { + t.Fatalf("Expected MultiError with errors, got %v", err) + } + var parseErr *ParseError if !errors.As(multiErr.Errors[0], &parseErr) { - t.Errorf("Expected ParseError, got %T", multiErr.Errors[0]) + t.Fatalf("Expected ParseError, got %T", multiErr.Errors[0]) + } + if !strings.Contains(parseErr.Error(), "cannot parse") { + t.Fatalf("Expected parse-oriented message, got %q", parseErr.Error()) + } + if strings.Contains(parseErr.Error(), "no parser registered") { + t.Fatalf("Expected real parse cause, got %q", parseErr.Error()) } }) } diff --git a/registry.go b/registry.go index 62c2325..443b5fb 100644 --- a/registry.go +++ b/registry.go @@ -9,6 +9,8 @@ import ( "time" ) +var errorType = reflect.TypeFor[error]() + // Loader manages the configuration loading and parser registry. type Loader struct { mu sync.RWMutex @@ -130,7 +132,7 @@ func (l *Loader) RegisterParser(parser any) { if t.NumIn() != 1 || t.In(0).Kind() != reflect.String { panic("parser must take a single string argument") } - if t.NumOut() != 2 || t.Out(1).Name() != "error" { + if t.NumOut() != 2 || !t.Out(1).Implements(errorType) { panic("parser must return (T, error)") } @@ -141,7 +143,15 @@ func (l *Loader) RegisterParser(parser any) { l.registry[targetType] = func(s string) (any, error) { res := v.Call([]reflect.Value{reflect.ValueOf(s)}) - errVal := res[1].Interface() + errResult := res[1] + switch errResult.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice: + if errResult.IsNil() { + return res[0].Interface(), nil + } + } + + errVal := errResult.Interface() if errVal != nil { return nil, errVal.(error) } @@ -179,21 +189,15 @@ func (l *Loader) getParser(t reflect.Type) (func(string) (any, error), bool) { if elemParser, ok := l.registry[elemType]; ok { // Generate slice parser dynamically sliceParser := func(s string) (any, error) { - if s == "" { - return reflect.MakeSlice(t, 0, 0).Interface(), nil - } - parts := strings.Split(s, ",") - slice := reflect.MakeSlice(t, 0, len(parts)) - for _, p := range parts { - trimmed := strings.TrimSpace(p) - if trimmed == "" { - continue - } - val, err := elemParser(trimmed) + slice, err := parseSliceValue(t, s, ",", func(part string) (reflect.Value, error) { + val, err := elemParser(part) if err != nil { - return nil, err + return reflect.Value{}, err } - slice = reflect.Append(slice, reflect.ValueOf(val)) + return reflect.ValueOf(val), nil + }) + if err != nil { + return nil, err } return slice.Interface(), nil } diff --git a/registry_test.go b/registry_test.go index 7055588..76a6e82 100644 --- a/registry_test.go +++ b/registry_test.go @@ -1,6 +1,7 @@ package dotenvgo import ( + "fmt" "reflect" "testing" ) @@ -260,3 +261,86 @@ func TestLoaderIsolation(t *testing.T) { t.Error("DefaultLoader should NOT have a parser for ValidationStatus") } } + +type fakeError struct { + message string +} + +func (e *fakeError) Error() string { + return e.message +} + +func TestRegisterParserValidation(t *testing.T) { + t.Run("accepts custom error type", func(t *testing.T) { + loader := NewLoader() + type parserType int + + loader.RegisterParser(func(s string) (parserType, *fakeError) { + return parserType(len(s)), nil + }) + + parser, ok := loader.getParser(reflect.TypeFor[parserType]()) + if !ok { + t.Fatal("expected parser to be registered") + } + + value, err := parser("abc") + if err != nil { + t.Fatalf("expected parser to succeed: %v", err) + } + if value.(parserType) != 3 { + t.Fatalf("expected parsed value 3, got %v", value) + } + }) + + t.Run("rejects non error second return", func(t *testing.T) { + loader := NewLoader() + assertPanicsWithMessage(t, "parser must return (T, error)", func() { + loader.RegisterParser(func(s string) (int, fmt.Stringer) { + return 0, nil + }) + }) + }) + + t.Run("rejects wrong signatures", func(t *testing.T) { + loader := NewLoader() + + assertPanicsWithMessage(t, "parser must be a function", func() { + loader.RegisterParser(123) + }) + + assertPanicsWithMessage(t, "parser must take a single string argument", func() { + loader.RegisterParser(func(int) (string, error) { + return "", nil + }) + }) + + assertPanicsWithMessage(t, "parser must return (T, error)", func() { + loader.RegisterParser(func(string) string { + return "" + }) + }) + }) +} + +func assertPanicsWithMessage(t *testing.T, message string, fn func()) { + t.Helper() + + defer func() { + recovered := recover() + if recovered == nil { + t.Fatalf("expected panic %q", message) + } + + panicMessage, ok := recovered.(string) + if !ok { + t.Fatalf("expected panic string %q, got %T", message, recovered) + } + + if panicMessage != message { + t.Fatalf("expected panic %q, got %q", message, panicMessage) + } + }() + + fn() +} diff --git a/slice_parse.go b/slice_parse.go new file mode 100644 index 0000000..0d4e153 --- /dev/null +++ b/slice_parse.go @@ -0,0 +1,30 @@ +package dotenvgo + +import ( + "reflect" + "strings" +) + +func parseSliceValue(sliceType reflect.Type, value, sep string, parseElem func(string) (reflect.Value, error)) (reflect.Value, error) { + if sep == "" { + sep = "," + } + + parts := strings.Split(value, sep) + slice := reflect.MakeSlice(sliceType, 0, len(parts)) + + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + + elem, err := parseElem(part) + if err != nil { + return reflect.Value{}, err + } + slice = reflect.Append(slice, elem) + } + + return slice, nil +} diff --git a/var.go b/var.go new file mode 100644 index 0000000..766be86 --- /dev/null +++ b/var.go @@ -0,0 +1,161 @@ +package dotenvgo + +import ( + "encoding" + "fmt" + "os" + "reflect" +) + +// Var represents an environment variable with type-safe access. +type Var[T any] struct { + key string + defaultValue *T + required bool + parser func(string) (T, error) + prefix string +} + +// New creates a new environment variable of type T using the default loader. +func New[T any](key string) *Var[T] { + return NewVar[T](DefaultLoader, key) +} + +// NewVar creates a new environment variable of type T using the specified Loader. +// It searches for a registered parser or uses encoding.TextUnmarshaler. +func NewVar[T any](l *Loader, key string) *Var[T] { + var zero T + typ := reflect.TypeOf(zero) + + if p, ok := l.getParser(typ); ok { + return &Var[T]{ + key: key, + parser: func(s string) (T, error) { + v, err := p(s) + if err != nil { + return zero, err + } + return v.(T), nil + }, + } + } + + return &Var[T]{ + key: key, + parser: func(s string) (T, error) { + if p, ok := l.getParser(typ); ok { + v, err := p(s) + if err != nil { + return zero, err + } + return v.(T), nil + } + + valPtr := reflect.New(typ) + if u, ok := valPtr.Interface().(encoding.TextUnmarshaler); ok { + if err := u.UnmarshalText([]byte(s)); err != nil { + return zero, err + } + return valPtr.Elem().Interface().(T), nil + } + + return zero, fmt.Errorf("dotenvgo: no parser registered for type %v", typ) + }, + } +} + +// Default sets the default value if the environment variable is not set. +func (v *Var[T]) Default(value T) *Var[T] { + v.defaultValue = &value + return v +} + +// Required marks the environment variable as required. +// Get() will panic if the variable is not set. +// GetE() will return an error. +func (v *Var[T]) Required() *Var[T] { + v.required = true + return v +} + +// WithPrefix adds a prefix to the environment variable key. +// For example, WithPrefix("APP").String("PORT") will look for "APP_PORT". +func (v *Var[T]) WithPrefix(prefix string) *Var[T] { + v.prefix = prefix + return v +} + +// Get returns the value of the environment variable. +// Panics if the variable is required but not set. +func (v *Var[T]) Get() T { + value, err := v.GetE() + if err != nil { + panic(err) + } + return value +} + +// GetE returns the value of the environment variable or an error. +func (v *Var[T]) GetE() (T, error) { + var zero T + key := v.fullKey() + resolved, err := resolveEnvValue(key, v.required) + if err != nil { + return zero, err + } + if !resolved.exists { + if v.defaultValue != nil { + return *v.defaultValue, nil + } + return zero, nil + } + + value, err := v.parser(resolved.value) + if err != nil { + return zero, &ParseError{Key: key, Value: resolved.raw, Err: err} + } + + return value, nil +} + +// Lookup returns the value and whether it was set. +func (v *Var[T]) Lookup() (T, bool) { + var zero T + key := v.fullKey() + raw, exists := os.LookupEnv(key) + + if !exists { + if v.defaultValue != nil { + return *v.defaultValue, true + } + return zero, false + } + + value, err := v.parser(os.ExpandEnv(raw)) + if err != nil { + return zero, false + } + + return value, true +} + +// MustGet returns the value or panics if there's an error. +// Alias for Get(). +func (v *Var[T]) MustGet() T { + return v.Get() +} + +// IsSet returns whether the environment variable is set. +func (v *Var[T]) IsSet() bool { + key := v.fullKey() + _, exists := os.LookupEnv(key) + return exists +} + +// fullKey returns the full environment variable key with prefix. +func (v *Var[T]) fullKey() string { + if v.prefix != "" { + return v.prefix + "_" + v.key + } + return v.key +}