Skip to content

Commit a4ce23d

Browse files
committed
Adds AsType & HasType and removes IsType.
The original implementation of IsType was wrong and asserted assign ability over comparability. Now we have an AsType that works the same As but creates it's own target with generics and a new HasType that disregards the target return of AsType.
1 parent 0ebb696 commit a4ce23d

2 files changed

Lines changed: 136 additions & 29 deletions

File tree

functions.go

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -351,18 +351,17 @@ func Is(err, target error) bool {
351351
return stderrors.Is(err, target)
352352
}
353353

354-
// IsType is a convenience method for ascertaining if an error contains the
355-
// target error type within its chain. This is aimed at ease of development
356-
// where a more complicated error type wants to be to checked for existence but
357-
// pointer var of that type is too much overhead.
358-
func IsType[t error](err error) bool {
359-
for err != nil {
360-
if _, is := err.(t); is {
361-
return true
362-
}
363-
err = stderrors.Unwrap(err)
364-
}
365-
return false
354+
// HasType is a function wrapper around AsType dropping the where return value
355+
// from AsType() making a function that can be used like this:
356+
//
357+
// return HasType[*MyError](err)
358+
//
359+
// Or
360+
//
361+
// if HasType[*MyError](err) {}
362+
func HasType[T error](err error) bool {
363+
_, rval := AsType[T](err)
364+
return rval
366365
}
367366

368367
// As is a proxy for the As function in Go's standard `errors` library
@@ -371,6 +370,44 @@ func As(err error, target interface{}) bool {
371370
return stderrors.As(err, target)
372371
}
373372

373+
// AsType is a convenience method for checking and getting an error from within
374+
// a chain that is of type T. If no error is found of type T in the chain the
375+
// zero value of T is returned with false. If an error in the chain implementes
376+
// As(any) bool then it's As method will be called if it's type is not of type T.
377+
378+
// AsType finds the first error in err's chain that is assignable to type T, and
379+
// if a match is found, returns that error value and true. Otherwise, it returns
380+
// T's zero value and false.
381+
//
382+
// AsType is equivalent to errors.As, but uses a type parameter and returns
383+
// the target, to avoid having to define a variable before the call. For
384+
// example, callers can replace this:
385+
//
386+
// var pathError *fs.PathError
387+
// if errors.As(err, &pathError) {
388+
// fmt.Println("Failed at path:", pathError.Path)
389+
// }
390+
//
391+
// With:
392+
//
393+
// if pathError, ok := errors.AsType[*fs.PathError](err); ok {
394+
// fmt.Println("Failed at path:", pathError.Path)
395+
// }
396+
func AsType[T error](err error) (T, bool) {
397+
for err != nil {
398+
if e, is := err.(T); is {
399+
return e, true
400+
}
401+
var res T
402+
if x, ok := err.(interface{ As(any) bool }); ok && x.As(&res) {
403+
return res, true
404+
}
405+
err = stderrors.Unwrap(err)
406+
}
407+
var zero T
408+
return zero, false
409+
}
410+
374411
// SetLocation takes a given error and records where in the stack SetLocation
375412
// was called from and returns the wrapped error with the location information
376413
// set. The returned error implements the Locationer interface. If err is nil

functions_test.go

Lines changed: 87 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -403,11 +403,9 @@ func (*functionSuite) TestQuietWrappedErrorStillSatisfied(c *gc.C) {
403403
c.Assert(errors.Is(err, simpleTestError), gc.Equals, true)
404404
}
405405

406-
type FooError struct {
407-
}
408-
409-
func (*FooError) Error() string {
410-
return "I am here boss"
406+
type ComplexErrorMessage interface {
407+
error
408+
ComplexMessage() string
411409
}
412410

413411
type complexError struct {
@@ -418,30 +416,90 @@ func (c *complexError) Error() string {
418416
return c.Message
419417
}
420418

419+
func (c *complexError) ComplexMessage() string {
420+
return c.Message
421+
}
422+
421423
type complexErrorOther struct {
422424
Message string
423425
}
424426

427+
func (c *complexErrorOther) As(e any) bool {
428+
if ce, ok := e.(**complexError); ok {
429+
*ce = &complexError{
430+
Message: c.Message,
431+
}
432+
return true
433+
}
434+
return false
435+
}
436+
425437
func (c *complexErrorOther) Error() string {
426438
return c.Message
427439
}
428440

429-
func (*functionSuite) TestIsType(c *gc.C) {
441+
func (c *complexErrorOther) ComplexMessage() string {
442+
return c.Message
443+
}
444+
445+
func (*functionSuite) TestHasType(c *gc.C) {
430446
complexErr := &complexError{Message: "complex error message"}
431447
wrapped1 := fmt.Errorf("wrapping1: %w", complexErr)
432448
wrapped2 := fmt.Errorf("wrapping2: %w", wrapped1)
433449

434-
c.Assert(errors.IsType[*complexError](complexErr), gc.Equals, true)
435-
c.Assert(errors.IsType[*complexError](wrapped1), gc.Equals, true)
436-
c.Assert(errors.IsType[*complexError](wrapped2), gc.Equals, true)
437-
c.Assert(errors.IsType[*complexErrorOther](complexErr), gc.Equals, false)
438-
c.Assert(errors.IsType[*complexErrorOther](wrapped1), gc.Equals, false)
439-
c.Assert(errors.IsType[*complexErrorOther](wrapped2), gc.Equals, false)
450+
c.Assert(errors.HasType[*complexError](complexErr), gc.Equals, true)
451+
c.Assert(errors.HasType[*complexError](wrapped1), gc.Equals, true)
452+
c.Assert(errors.HasType[*complexError](wrapped2), gc.Equals, true)
453+
c.Assert(errors.HasType[ComplexErrorMessage](wrapped2), gc.Equals, true)
454+
c.Assert(errors.HasType[*complexErrorOther](wrapped2), gc.Equals, false)
455+
c.Assert(errors.HasType[*complexErrorOther](nil), gc.Equals, false)
440456

441-
err := errors.New("test")
442-
c.Assert(errors.IsType[*complexErrorOther](err), gc.Equals, false)
457+
complexErrOther := &complexErrorOther{Message: "another complex error"}
458+
459+
c.Assert(errors.HasType[*complexError](complexErrOther), gc.Equals, true)
443460

444-
c.Assert(errors.IsType[*complexErrorOther](nil), gc.Equals, false)
461+
wrapped2 = fmt.Errorf("wrapping1: %w", complexErrOther)
462+
c.Assert(errors.HasType[*complexError](wrapped2), gc.Equals, true)
463+
}
464+
465+
func (*functionSuite) TestAsType(c *gc.C) {
466+
complexErr := &complexError{Message: "complex error message"}
467+
wrapped1 := fmt.Errorf("wrapping1: %w", complexErr)
468+
wrapped2 := fmt.Errorf("wrapping2: %w", wrapped1)
469+
470+
ce, ok := errors.AsType[*complexError](complexErr)
471+
c.Assert(ok, gc.Equals, true)
472+
c.Assert(ce.Message, gc.Equals, complexErr.Message)
473+
474+
ce, ok = errors.AsType[*complexError](wrapped1)
475+
c.Assert(ok, gc.Equals, true)
476+
c.Assert(ce.Message, gc.Equals, complexErr.Message)
477+
478+
ce, ok = errors.AsType[*complexError](wrapped2)
479+
c.Assert(ok, gc.Equals, true)
480+
c.Assert(ce.Message, gc.Equals, complexErr.Message)
481+
482+
cem, ok := errors.AsType[ComplexErrorMessage](wrapped2)
483+
c.Assert(ok, gc.Equals, true)
484+
c.Assert(cem.ComplexMessage(), gc.Equals, complexErr.Message)
485+
486+
ceo, ok := errors.AsType[*complexErrorOther](wrapped2)
487+
c.Assert(ok, gc.Equals, false)
488+
c.Assert(ceo, gc.Equals, (*complexErrorOther)(nil))
489+
490+
ceo, ok = errors.AsType[*complexErrorOther](nil)
491+
c.Assert(ok, gc.Equals, false)
492+
c.Assert(ceo, gc.Equals, (*complexErrorOther)(nil))
493+
494+
complexErrOther := &complexErrorOther{Message: "another complex error"}
495+
ce, ok = errors.AsType[*complexError](complexErrOther)
496+
c.Assert(ok, gc.Equals, true)
497+
c.Assert(ce.Message, gc.Equals, complexErrOther.Message)
498+
499+
wrapped2 = fmt.Errorf("wrapping1: %w", complexErrOther)
500+
ce, ok = errors.AsType[*complexError](wrapped2)
501+
c.Assert(ok, gc.Equals, true)
502+
c.Assert(ce.Message, gc.Equals, complexErrOther.Message)
445503
}
446504

447505
func ExampleHide() {
@@ -464,12 +522,24 @@ func (m *MyError) Error() string {
464522
return m.Message
465523
}
466524

467-
func ExampleIsType() {
525+
func ExampleHasType() {
526+
myErr := &MyError{Message: "these are not the droids you're looking for"}
527+
err := fmt.Errorf("wrapped: %w", myErr)
528+
is := errors.HasType[*MyError](err)
529+
fmt.Println(is)
530+
531+
// Output:
532+
// true
533+
}
534+
535+
func ExampleAsType() {
468536
myErr := &MyError{Message: "these are not the droids you're looking for"}
469537
err := fmt.Errorf("wrapped: %w", myErr)
470-
is := errors.IsType[*MyError](err)
538+
myErr, is := errors.AsType[*MyError](err)
471539
fmt.Println(is)
540+
fmt.Println(myErr.Message)
472541

473542
// Output:
474543
// true
544+
// these are not the droids you're looking for
475545
}

0 commit comments

Comments
 (0)