-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcheckgrad.py
More file actions
117 lines (87 loc) · 2.62 KB
/
Copy pathcheckgrad.py
File metadata and controls
117 lines (87 loc) · 2.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
__all__ = ['checkgrad', 'checkgradf', 'GradientError']
import numpy as np
from scipy.optimize import approx_fprime
from functools import wraps
# Step size handed to approx_fprime for the finite-difference gradient
# estimate: the square root of the machine epsilon for ``float``.
_epsilon = np.sqrt(np.finfo(float).eps)
def checkgrad(func):
    """Decorator that checks the gradient returned by an objective function.

    The decorated function receives the evaluation point ``x`` as its
    first positional argument (plus any extra positional arguments) and
    should return a ``(value, gradient)`` pair.  On every call the returned
    gradient is compared, via ``compare_grad``, against a finite-difference
    estimate obtained from ``approx_fprime``.

    When ``__debug__`` is false (Python started with ``-O``) ``func`` is
    returned undecorated, so no checking overhead is incurred.

    A result that is not a two-element sequence -- including an unsized
    result such as a bare scalar -- is returned as is, without any check.

    Raises
    ------
    GradientError
        Raised by ``compare_grad`` when the returned gradient and the
        finite-difference estimate disagree.

    Example
    -------
    @checkgrad
    def func_grad(x):
        f = (3 * x**2 + 2 * x + 1).sum()
        g = 6 * x + 2
        return f, g
    """
    if not __debug__:
        return func

    @wraps(func)
    def _func_and_grad(x, *args):
        ret = func(x, *args)

        try:
            is_pair = len(ret) == 2
        except TypeError:
            # Unsized result (e.g. a plain float): there is no gradient to
            # check, so pass the value through instead of crashing.
            is_pair = False

        if is_pair:
            # approx_fprime needs a callable returning only the value, so
            # the objective is re-evaluated and its first item extracted.
            def _value_only(x, *args):
                return func(x, *args)[0]

            approx_grad = approx_fprime(x, _value_only, _epsilon, *args)
            compare_grad(ret[1], approx_grad)

        return ret

    return _func_and_grad
def checkgradf(func):
    """Decorator factory that checks a gradient function against ``func``.

    ``func`` is the objective evaluation function.  The returned decorator
    wraps a gradient function so that, on every call, its result is
    compared (via ``compare_grad``) with the finite-difference gradient of
    ``func`` computed by ``approx_fprime`` at the same point and with the
    same extra positional arguments.

    When ``__debug__`` is false (Python started with ``-O``) the gradient
    function is returned undecorated.

    The wrapper carries the metadata (name, docstring, ...) of the gradient
    function it decorates.

    Raises
    ------
    GradientError
        Raised by ``compare_grad`` when the two gradients disagree.

    Example
    -------
    def func(x):
        return (3 * x**2 + 2 * x + 1).sum()

    @checkgradf(func)
    def grad(x):
        return 6 * x + 2
    """
    def _checkgrad(grad_func):
        if not __debug__:
            return grad_func

        # Copy metadata from the gradient function being wrapped, not from
        # the objective: the wrapper stands in for ``grad_func``.
        @wraps(grad_func)
        def _grad(x, *args):
            grad = grad_func(x, *args)
            approx_grad = approx_fprime(x, func, _epsilon, *args)
            compare_grad(grad, approx_grad)
            return grad

        return _grad

    return _checkgrad
class GradientError(Exception):
    """Error describing a mismatch between two gradients.

    Attributes
    ----------
    grad
        The gradient that was being checked.
    approx_grad
        The reference gradient it was compared against.
    diff
        The measured discrepancy between the two.
    """

    def __init__(self, grad, approx_grad, diff):
        # Forward the values to the base class so that ``args`` is
        # populated; this keeps the exception picklable and copyable.
        super(GradientError, self).__init__(grad, approx_grad, diff)
        self.grad = grad
        self.approx_grad = approx_grad
        self.diff = diff

    def __str__(self):
        """Return the discrepancy followed by both gradients, one per line."""
        return 'Gradient diff: {}\n{}\n{}'.format(
            self.diff, self.grad, self.approx_grad)
def compare_grad(grad, approx_grad, tol=1e-5):
    """Raise ``GradientError`` if two gradients differ by more than ``tol``.

    Parameters
    ----------
    grad
        Gradient under test (array-like).
    approx_grad
        Reference gradient to compare against (array-like).
    tol
        Largest accepted absolute element-wise difference.  Defaults to
        ``1e-5``.

    Raises
    ------
    GradientError
        If the largest absolute element-wise difference exceeds ``tol`` or
        is NaN.
    """
    diff = np.fabs(np.asarray(grad) - np.asarray(approx_grad))
    if diff.size == 0:
        # Nothing to compare; ``max`` of an empty array would raise.
        return
    err = diff.max()
    # Written as ``not (err <= tol)`` rather than ``err > tol`` so that a
    # NaN error (every comparison with NaN is false) is reported as a
    # failure instead of slipping through.
    if not err <= tol:
        raise GradientError(grad, approx_grad, err)
if __name__ == '__main__':
    # Demo 1: ``checkgrad`` on a function returning (value, gradient).
    # The gradient is deliberately off by 0.1 so the check fails.
    try:
        @checkgrad
        def fg(x):
            f_ = (3 * x**2 + 2 * x + 1).sum()
            g_ = 6 * x + 2 + 0.1
            return f_, g_

        fg(np.ones(5))
    except GradientError as e:
        # Single-argument print calls behave the same on Python 2 and 3.
        print('grad: {}'.format(e.grad))
        print('approx grad: {}'.format(e.approx_grad))

    # Demo 2: ``checkgradf`` on a separate gradient function, again with a
    # deliberate 0.1 offset.
    try:
        def f(x):
            f_ = (3 * x**2 + 2 * x + 1).sum()
            return f_

        @checkgradf(f)
        def g(x):
            g_ = 6 * x + 2 + 0.1
            return g_

        g(np.ones(5))
    except GradientError as e:
        print(e)