Skip to content

Commit 54a9448

Browse files
authored
Merge pull request #60 from tlm/as-type-support
Adds support for a new convenience function AsType and HasType.
2 parents 0ebb696 + a4ce23d commit 54a9448

2 files changed

Lines changed: 136 additions & 29 deletions

File tree

functions.go

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -351,18 +351,17 @@ func Is(err, target error) bool {
351351
return stderrors.Is(err, target)
352352
}
353353

354-
// IsType is a convenience method for ascertaining if an error contains the
355-
// target error type within its chain. This is aimed at ease of development
356-
// where a more complicated error type wants to be to checked for existence but
357-
// pointer var of that type is too much overhead.
358-
func IsType[t error](err error) bool {
359-
for err != nil {
360-
if _, is := err.(t); is {
361-
return true
362-
}
363-
err = stderrors.Unwrap(err)
364-
}
365-
return false
354+
// HasType is a function wrapper around AsType dropping the where return value
355+
// from AsType() making a function that can be used like this:
356+
//
357+
// return HasType[*MyError](err)
358+
//
359+
// Or
360+
//
361+
// if HasType[*MyError](err) {}
362+
func HasType[T error](err error) bool {
363+
_, rval := AsType[T](err)
364+
return rval
366365
}
367366

368367
// As is a proxy for the As function in Go's standard `errors` library
@@ -371,6 +370,44 @@ func As(err error, target interface{}) bool {
371370
return stderrors.As(err, target)
372371
}
373372

373+
// AsType is a convenience method for checking and getting an error from within
374+
// a chain that is of type T. If no error is found of type T in the chain the
375+
// zero value of T is returned with false. If an error in the chain implementes
376+
// As(any) bool then it's As method will be called if it's type is not of type T.
377+
378+
// AsType finds the first error in err's chain that is assignable to type T, and
379+
// if a match is found, returns that error value and true. Otherwise, it returns
380+
// T's zero value and false.
381+
//
382+
// AsType is equivalent to errors.As, but uses a type parameter and returns
383+
// the target, to avoid having to define a variable before the call. For
384+
// example, callers can replace this:
385+
//
386+
// var pathError *fs.PathError
387+
// if errors.As(err, &pathError) {
388+
// fmt.Println("Failed at path:", pathError.Path)
389+
// }
390+
//
391+
// With:
392+
//
393+
// if pathError, ok := errors.AsType[*fs.PathError](err); ok {
394+
// fmt.Println("Failed at path:", pathError.Path)
395+
// }
396+
func AsType[T error](err error) (T, bool) {
397+
for err != nil {
398+
if e, is := err.(T); is {
399+
return e, true
400+
}
401+
var res T
402+
if x, ok := err.(interface{ As(any) bool }); ok && x.As(&res) {
403+
return res, true
404+
}
405+
err = stderrors.Unwrap(err)
406+
}
407+
var zero T
408+
return zero, false
409+
}
410+
374411
// SetLocation takes a given error and records where in the stack SetLocation
375412
// was called from and returns the wrapped error with the location information
376413
// set. The returned error implements the Locationer interface. If err is nil

functions_test.go

Lines changed: 87 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -403,11 +403,9 @@ func (*functionSuite) TestQuietWrappedErrorStillSatisfied(c *gc.C) {
403403
c.Assert(errors.Is(err, simpleTestError), gc.Equals, true)
404404
}
405405

406-
type FooError struct {
407-
}
408-
409-
func (*FooError) Error() string {
410-
return "I am here boss"
406+
type ComplexErrorMessage interface {
407+
error
408+
ComplexMessage() string
411409
}
412410

413411
type complexError struct {
@@ -418,30 +416,90 @@ func (c *complexError) Error() string {
418416
return c.Message
419417
}
420418

419+
func (c *complexError) ComplexMessage() string {
420+
return c.Message
421+
}
422+
421423
type complexErrorOther struct {
422424
Message string
423425
}
424426

427+
func (c *complexErrorOther) As(e any) bool {
428+
if ce, ok := e.(**complexError); ok {
429+
*ce = &complexError{
430+
Message: c.Message,
431+
}
432+
return true
433+
}
434+
return false
435+
}
436+
425437
func (c *complexErrorOther) Error() string {
426438
return c.Message
427439
}
428440

429-
func (*functionSuite) TestIsType(c *gc.C) {
441+
func (c *complexErrorOther) ComplexMessage() string {
442+
return c.Message
443+
}
444+
445+
func (*functionSuite) TestHasType(c *gc.C) {
430446
complexErr := &complexError{Message: "complex error message"}
431447
wrapped1 := fmt.Errorf("wrapping1: %w", complexErr)
432448
wrapped2 := fmt.Errorf("wrapping2: %w", wrapped1)
433449

434-
c.Assert(errors.IsType[*complexError](complexErr), gc.Equals, true)
435-
c.Assert(errors.IsType[*complexError](wrapped1), gc.Equals, true)
436-
c.Assert(errors.IsType[*complexError](wrapped2), gc.Equals, true)
437-
c.Assert(errors.IsType[*complexErrorOther](complexErr), gc.Equals, false)
438-
c.Assert(errors.IsType[*complexErrorOther](wrapped1), gc.Equals, false)
439-
c.Assert(errors.IsType[*complexErrorOther](wrapped2), gc.Equals, false)
450+
c.Assert(errors.HasType[*complexError](complexErr), gc.Equals, true)
451+
c.Assert(errors.HasType[*complexError](wrapped1), gc.Equals, true)
452+
c.Assert(errors.HasType[*complexError](wrapped2), gc.Equals, true)
453+
c.Assert(errors.HasType[ComplexErrorMessage](wrapped2), gc.Equals, true)
454+
c.Assert(errors.HasType[*complexErrorOther](wrapped2), gc.Equals, false)
455+
c.Assert(errors.HasType[*complexErrorOther](nil), gc.Equals, false)
440456

441-
err := errors.New("test")
442-
c.Assert(errors.IsType[*complexErrorOther](err), gc.Equals, false)
457+
complexErrOther := &complexErrorOther{Message: "another complex error"}
458+
459+
c.Assert(errors.HasType[*complexError](complexErrOther), gc.Equals, true)
443460

444-
c.Assert(errors.IsType[*complexErrorOther](nil), gc.Equals, false)
461+
wrapped2 = fmt.Errorf("wrapping1: %w", complexErrOther)
462+
c.Assert(errors.HasType[*complexError](wrapped2), gc.Equals, true)
463+
}
464+
465+
func (*functionSuite) TestAsType(c *gc.C) {
466+
complexErr := &complexError{Message: "complex error message"}
467+
wrapped1 := fmt.Errorf("wrapping1: %w", complexErr)
468+
wrapped2 := fmt.Errorf("wrapping2: %w", wrapped1)
469+
470+
ce, ok := errors.AsType[*complexError](complexErr)
471+
c.Assert(ok, gc.Equals, true)
472+
c.Assert(ce.Message, gc.Equals, complexErr.Message)
473+
474+
ce, ok = errors.AsType[*complexError](wrapped1)
475+
c.Assert(ok, gc.Equals, true)
476+
c.Assert(ce.Message, gc.Equals, complexErr.Message)
477+
478+
ce, ok = errors.AsType[*complexError](wrapped2)
479+
c.Assert(ok, gc.Equals, true)
480+
c.Assert(ce.Message, gc.Equals, complexErr.Message)
481+
482+
cem, ok := errors.AsType[ComplexErrorMessage](wrapped2)
483+
c.Assert(ok, gc.Equals, true)
484+
c.Assert(cem.ComplexMessage(), gc.Equals, complexErr.Message)
485+
486+
ceo, ok := errors.AsType[*complexErrorOther](wrapped2)
487+
c.Assert(ok, gc.Equals, false)
488+
c.Assert(ceo, gc.Equals, (*complexErrorOther)(nil))
489+
490+
ceo, ok = errors.AsType[*complexErrorOther](nil)
491+
c.Assert(ok, gc.Equals, false)
492+
c.Assert(ceo, gc.Equals, (*complexErrorOther)(nil))
493+
494+
complexErrOther := &complexErrorOther{Message: "another complex error"}
495+
ce, ok = errors.AsType[*complexError](complexErrOther)
496+
c.Assert(ok, gc.Equals, true)
497+
c.Assert(ce.Message, gc.Equals, complexErrOther.Message)
498+
499+
wrapped2 = fmt.Errorf("wrapping1: %w", complexErrOther)
500+
ce, ok = errors.AsType[*complexError](wrapped2)
501+
c.Assert(ok, gc.Equals, true)
502+
c.Assert(ce.Message, gc.Equals, complexErrOther.Message)
445503
}
446504

447505
func ExampleHide() {
@@ -464,12 +522,24 @@ func (m *MyError) Error() string {
464522
return m.Message
465523
}
466524

467-
func ExampleIsType() {
525+
func ExampleHasType() {
526+
myErr := &MyError{Message: "these are not the droids you're looking for"}
527+
err := fmt.Errorf("wrapped: %w", myErr)
528+
is := errors.HasType[*MyError](err)
529+
fmt.Println(is)
530+
531+
// Output:
532+
// true
533+
}
534+
535+
func ExampleAsType() {
468536
myErr := &MyError{Message: "these are not the droids you're looking for"}
469537
err := fmt.Errorf("wrapped: %w", myErr)
470-
is := errors.IsType[*MyError](err)
538+
myErr, is := errors.AsType[*MyError](err)
471539
fmt.Println(is)
540+
fmt.Println(myErr.Message)
472541

473542
// Output:
474543
// true
544+
// these are not the droids you're looking for
475545
}

0 commit comments

Comments
 (0)