Skip to content
This repository was archived by the owner on Oct 31, 2024. It is now read-only.

Commit 3fcc0ef

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent a62aa3b commit 3fcc0ef

2 files changed

Lines changed: 61 additions & 61 deletions

File tree

development/hessian_approximation/compute_hessian.py

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ def compute_hessian(x0, X1, X0, Z1, Z0, Y1, Y0):
2929

3030
# aux_params
3131
nu1 = (Y1 - np.dot(beta1, X1.T)) / sd1
32-
lambda1 = (np.dot(gamma, Z1.T) - rho1v * nu1) / (np.sqrt(1 - rho1v ** 2))
32+
lambda1 = (np.dot(gamma, Z1.T) - rho1v * nu1) / (np.sqrt(1 - rho1v**2))
3333
nu0 = (Y0 - np.dot(beta0, X0.T)) / sd0
34-
lambda0 = (np.dot(gamma, Z0.T) - rho0v * nu0) / (np.sqrt(1 - rho0v ** 2))
34+
lambda0 = (np.dot(gamma, Z0.T) - rho0v * nu0) / (np.sqrt(1 - rho0v**2))
3535

3636
eta1 = (
3737
-lambda1 * norm.pdf(lambda1) * norm.cdf(lambda1) - norm.pdf(lambda1) ** 2
@@ -113,21 +113,21 @@ def calc_hess_beta1(
113113
"""
114114

115115
# define some auxiliary variables
116-
rho_aux1 = lambda1 * rho1v / (1 - rho1v ** 2) - nu1 / (1 - rho1v ** 2) ** 0.5
117-
rho_aux2 = rho1v ** 2 / ((1 - rho1v ** 2) ** (3 / 2)) + 1 / (1 - rho1v ** 2) ** 0.5
118-
sd_aux1 = rho1v ** 2 / (1 - rho1v ** 2)
119-
sd_aux2 = rho1v / np.sqrt(1 - rho1v ** 2)
116+
rho_aux1 = lambda1 * rho1v / (1 - rho1v**2) - nu1 / (1 - rho1v**2) ** 0.5
117+
rho_aux2 = rho1v**2 / ((1 - rho1v**2) ** (3 / 2)) + 1 / (1 - rho1v**2) ** 0.5
118+
sd_aux1 = rho1v**2 / (1 - rho1v**2)
119+
sd_aux2 = rho1v / np.sqrt(1 - rho1v**2)
120120
# derivation wrt beta1
121121
der_b1_beta1 = -(
122-
X1X1 * (rho1v ** 2 / (1 - rho1v ** 2)) * 1 / sd1 ** 2 - X1.T @ X1 / sd1 ** 2
122+
X1X1 * (rho1v**2 / (1 - rho1v**2)) * 1 / sd1**2 - X1.T @ X1 / sd1**2
123123
)
124124
# add zeros for derv beta 0
125125
der_b1_beta1 = np.concatenate(
126126
(der_b1_beta1, np.zeros((n_col_X1, n_col_X0))), axis=1
127127
)
128128

129129
# derivation wrt gamma
130-
der_b1_gamma = -(X1Z1 * rho1v / (sd1 * (1 - rho1v ** 2)))
130+
der_b1_gamma = -(X1Z1 * rho1v / (sd1 * (1 - rho1v**2)))
131131
der_b1_gamma = np.concatenate((der_b1_beta1, der_b1_gamma), axis=1)
132132
# derv wrt sigma 1
133133
der_b1_sd = (
@@ -150,7 +150,7 @@ def calc_hess_beta1(
150150
# derv wrt rho1
151151
der_b1_rho = (
152152
-(
153-
eta1 * rho_aux1 * rho1v / ((1 - rho1v ** 2) ** 0.5)
153+
eta1 * rho_aux1 * rho1v / ((1 - rho1v**2) ** 0.5)
154154
+ norm.pdf(lambda1) / norm.cdf(lambda1) * rho_aux2
155155
)
156156
* 1
@@ -176,21 +176,21 @@ def calc_hess_beta0(
176176
"""
177177

178178
# define some aux_vars
179-
rho_aux1 = lambda0 * rho0v / (1 - rho0v ** 2) - nu0 / (1 - rho0v ** 2) ** 0.5
180-
rho_aux2 = rho0v ** 2 / ((1 - rho0v ** 2) ** (3 / 2)) + 1 / (1 - rho0v ** 2) ** 0.5
181-
sd_aux1 = rho0v ** 2 / (1 - rho0v ** 2)
182-
sd_aux2 = rho0v / (np.sqrt(1 - rho0v ** 2))
179+
rho_aux1 = lambda0 * rho0v / (1 - rho0v**2) - nu0 / (1 - rho0v**2) ** 0.5
180+
rho_aux2 = rho0v**2 / ((1 - rho0v**2) ** (3 / 2)) + 1 / (1 - rho0v**2) ** 0.5
181+
sd_aux1 = rho0v**2 / (1 - rho0v**2)
182+
sd_aux2 = rho0v / (np.sqrt(1 - rho0v**2))
183183

184184
# add zeros for beta0
185185
der_b0_beta1 = np.zeros((n_col_X1, n_col_X0))
186186

187187
# beta0
188188
der_b0_beta0 = (
189-
-(X0X0 * (rho0v ** 2 / (1 - rho0v ** 2)) * 1 / sd0 ** 2) + X0.T @ X0 / sd0 ** 2
189+
-(X0X0 * (rho0v**2 / (1 - rho0v**2)) * 1 / sd0**2) + X0.T @ X0 / sd0**2
190190
)
191191
der_b0_beta0 = np.concatenate((der_b0_beta1, der_b0_beta0), axis=1)
192192
# gamma
193-
der_b0_gamma = -X0Z0 * rho0v / (1 - rho0v ** 2) * 1 / sd0
193+
der_b0_gamma = -X0Z0 * rho0v / (1 - rho0v**2) * 1 / sd0
194194
der_b0_gamma = np.concatenate((der_b0_beta0, der_b0_gamma), axis=1)
195195

196196
# add zeros for sigma1 and rho1
@@ -204,15 +204,15 @@ def calc_hess_beta0(
204204
- 2 * nu0
205205
)
206206
* 1
207-
/ sd0 ** 2
207+
/ sd0**2
208208
)
209209
der_b0_sd = np.expand_dims((der_b0_sd.T @ X0), 1)
210210
der_b0_sd = np.concatenate((der_b0_gamma, der_b0_sd), axis=1)
211211

212212
# rho
213213
der_b0_rho = (
214214
(
215-
eta0 * -rho_aux1 * (rho0v / ((1 - rho0v ** 2) ** 0.5))
215+
eta0 * -rho_aux1 * (rho0v / ((1 - rho0v**2) ** 0.5))
216216
+ norm.pdf(lambda0) / (1 - norm.cdf(lambda0)) * rho_aux2
217217
)
218218
* 1
@@ -252,7 +252,7 @@ def calc_hess_gamma(
252252

253253
der_gamma_beta = np.zeros((Z1.shape[1], num_col_X1X0))
254254

255-
der_g_gamma = -(1 / (1 - rho1v ** 2) * Z1Z1 + 1 / (1 - rho0v ** 2) * Z0Z0)
255+
der_g_gamma = -(1 / (1 - rho1v**2) * Z1Z1 + 1 / (1 - rho0v**2) * Z0Z0)
256256
der_g_gamma = np.concatenate((der_gamma_beta, der_g_gamma), axis=1)
257257

258258
# sigma1
@@ -261,19 +261,19 @@ def calc_hess_gamma(
261261
@ np.einsum("ij, i ->ij", X1, nu1)
262262
/ sd1
263263
* rho1v
264-
/ (1 - rho1v ** 2)
264+
/ (1 - rho1v**2)
265265
)[:, 0]
266266
der_g_sd1 = np.expand_dims(der_g_sd1, 0).T
267267
der_g_sd1 = np.concatenate((der_g_gamma, der_g_sd1), axis=1)
268268

269269
# rho1
270270
aux_rho11 = np.einsum("ij, i ->ij", Z1, eta1).T @ (
271-
lambda1 * rho1v / (1 - rho1v ** 2) - nu1 / np.sqrt(1 - rho1v ** 2)
271+
lambda1 * rho1v / (1 - rho1v**2) - nu1 / np.sqrt(1 - rho1v**2)
272272
)
273273
aux_rho21 = Z1.T @ (norm.pdf(lambda1) / norm.cdf(lambda1))
274274

275-
der_g_rho1 = -aux_rho11 * 1 / (np.sqrt(1 - rho1v ** 2)) - aux_rho21 * rho1v / (
276-
(1 - rho1v ** 2) ** (3 / 2)
275+
der_g_rho1 = -aux_rho11 * 1 / (np.sqrt(1 - rho1v**2)) - aux_rho21 * rho1v / (
276+
(1 - rho1v**2) ** (3 / 2)
277277
)
278278
der_g_rho1 = np.expand_dims(der_g_rho1, 0).T
279279
der_g_rho1 = np.concatenate((der_g_sd1, der_g_rho1), axis=1)
@@ -284,19 +284,19 @@ def calc_hess_gamma(
284284
@ np.einsum("ij, i ->ij", X0, nu0)
285285
/ sd0
286286
* rho0v
287-
/ (1 - rho0v ** 2)
287+
/ (1 - rho0v**2)
288288
)[:, 0]
289289
der_g_sd0 = np.expand_dims(der_g_sd0, 0).T
290290
der_g_sd0 = np.concatenate((der_g_rho1, -der_g_sd0), axis=1)
291291

292292
# rho1
293293
aux_rho10 = np.einsum("ij, i ->ij", Z0, eta0).T @ (
294-
lambda0 * rho0v / (1 - rho0v ** 2) - nu0 / np.sqrt(1 - rho0v ** 2)
294+
lambda0 * rho0v / (1 - rho0v**2) - nu0 / np.sqrt(1 - rho0v**2)
295295
)
296296
aux_rho20 = -Z0.T @ (norm.pdf(lambda0) / (1 - norm.cdf(lambda0)))
297297

298-
der_g_rho0 = aux_rho10 * 1 / (np.sqrt(1 - rho0v ** 2)) + aux_rho20 * rho0v / (
299-
(1 - rho0v ** 2) ** (3 / 2)
298+
der_g_rho0 = aux_rho10 * 1 / (np.sqrt(1 - rho0v**2)) + aux_rho20 * rho0v / (
299+
(1 - rho0v**2) ** (3 / 2)
300300
)
301301
der_g_rho0 = np.expand_dims(-der_g_rho0, 0).T
302302
der_g_rho0 = np.concatenate((der_g_sd0, der_g_rho0), axis=1)
@@ -328,52 +328,52 @@ def calc_hess_dist(
328328
Delta_sd1 = (
329329
+1 / sd1
330330
- (norm.pdf(lambda1) / norm.cdf(lambda1))
331-
* (rho1v * nu1 / (np.sqrt(1 - rho1v ** 2) * sd1))
332-
- nu1 ** 2 / sd1
331+
* (rho1v * nu1 / (np.sqrt(1 - rho1v**2) * sd1))
332+
- nu1**2 / sd1
333333
)
334334
Delta_sd1_der = (
335335
nu1
336336
/ sd1
337337
* (
338-
-eta1 * (rho1v ** 2 * nu1) / (1 - rho1v ** 2)
339-
+ (norm.pdf(lambda1) / norm.cdf(lambda1)) * rho1v / np.sqrt(1 - rho1v ** 2)
338+
-eta1 * (rho1v**2 * nu1) / (1 - rho1v**2)
339+
+ (norm.pdf(lambda1) / norm.cdf(lambda1)) * rho1v / np.sqrt(1 - rho1v**2)
340340
+ 2 * nu1
341341
)
342342
)
343343

344344
Delta_sd0 = (
345345
+1 / sd0
346346
+ (norm.pdf(lambda0) / (1 - norm.cdf(lambda0)))
347-
* (rho0v * nu0 / (np.sqrt(1 - rho0v ** 2) * sd0))
348-
- nu0 ** 2 / sd0
347+
* (rho0v * nu0 / (np.sqrt(1 - rho0v**2) * sd0))
348+
- nu0**2 / sd0
349349
)
350350
Delta_sd0_der = (
351351
nu0
352352
/ sd0
353353
* (
354-
-eta0 * (rho0v ** 2 * nu0) / (1 - rho0v ** 2)
354+
-eta0 * (rho0v**2 * nu0) / (1 - rho0v**2)
355355
- (norm.pdf(lambda0) / (1 - norm.cdf(lambda0)))
356356
* rho0v
357-
/ np.sqrt(1 - rho0v ** 2)
357+
/ np.sqrt(1 - rho0v**2)
358358
+ 2 * nu0
359359
)
360360
)
361361

362-
aux_rho11 = lambda1 * rho1v / (1 - rho1v ** 2) - nu1 / np.sqrt(1 - rho1v ** 2)
363-
aux_rho12 = 1 / (1 - rho1v ** 2) ** (3 / 2)
362+
aux_rho11 = lambda1 * rho1v / (1 - rho1v**2) - nu1 / np.sqrt(1 - rho1v**2)
363+
aux_rho12 = 1 / (1 - rho1v**2) ** (3 / 2)
364364

365-
aux_rho_rho11 = (np.dot(gamma, Z1.T) * rho1v - nu1) / (1 - rho1v ** 2) ** (3 / 2)
365+
aux_rho_rho11 = (np.dot(gamma, Z1.T) * rho1v - nu1) / (1 - rho1v**2) ** (3 / 2)
366366
aux_rho_rho12 = (
367-
2 * np.dot(gamma, Z1.T) * rho1v ** 2 + np.dot(gamma, Z1.T) - 3 * nu1 * rho1v
368-
) / (1 - rho1v ** 2) ** (5 / 2)
367+
2 * np.dot(gamma, Z1.T) * rho1v**2 + np.dot(gamma, Z1.T) - 3 * nu1 * rho1v
368+
) / (1 - rho1v**2) ** (5 / 2)
369369

370-
aux_rho01 = lambda0 * rho0v / (1 - rho0v ** 2) - nu0 / np.sqrt(1 - rho0v ** 2)
371-
aux_rho02 = 1 / (1 - rho0v ** 2) ** (3 / 2)
370+
aux_rho01 = lambda0 * rho0v / (1 - rho0v**2) - nu0 / np.sqrt(1 - rho0v**2)
371+
aux_rho02 = 1 / (1 - rho0v**2) ** (3 / 2)
372372

373-
aux_rho_rho01 = (np.dot(gamma, Z0.T) * rho0v - nu0) / (1 - rho0v ** 2) ** (3 / 2)
373+
aux_rho_rho01 = (np.dot(gamma, Z0.T) * rho0v - nu0) / (1 - rho0v**2) ** (3 / 2)
374374
aux_rho_rho02 = (
375-
2 * np.dot(gamma, Z0.T) * rho0v ** 2 + np.dot(gamma, Z0.T) - 3 * nu0 * rho0v
376-
) / (1 - rho0v ** 2) ** (5 / 2)
375+
2 * np.dot(gamma, Z0.T) * rho0v**2 + np.dot(gamma, Z0.T) - 3 * nu0 * rho0v
376+
) / (1 - rho0v**2) ** (5 / 2)
377377

378378
# for sigma1
379379

@@ -385,7 +385,7 @@ def calc_hess_dist(
385385
1
386386
/ sd1
387387
* (
388-
-eta1 * aux_rho11 * (rho1v * nu1) / (np.sqrt(1 - rho1v ** 2))
388+
-eta1 * aux_rho11 * (rho1v * nu1) / (np.sqrt(1 - rho1v**2))
389389
- (norm.pdf(lambda1) / norm.cdf(lambda1)) * aux_rho12 * nu1
390390
)
391391
)
@@ -411,7 +411,7 @@ def calc_hess_dist(
411411
1
412412
/ sd0
413413
* (
414-
-eta0 * aux_rho01 * (rho0v * nu0) / (np.sqrt(1 - rho0v ** 2))
414+
-eta0 * aux_rho01 * (rho0v * nu0) / (np.sqrt(1 - rho0v**2))
415415
+ (norm.pdf(lambda0) / (1 - norm.cdf(lambda0))) * aux_rho02 * nu0
416416
)
417417
)

grmpy/estimate/estimate_par.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -474,10 +474,10 @@ def log_likelihood(
474474
sd1, sd0, rho1v, rho0v = x0[-4], x0[-2], x0[-3], x0[-1]
475475

476476
nu1 = (Y1 - np.dot(beta1, X1.T)) / sd1
477-
lambda1 = (np.dot(gamma, Z1.T) - rho1v * nu1) / (np.sqrt(1 - rho1v ** 2))
477+
lambda1 = (np.dot(gamma, Z1.T) - rho1v * nu1) / (np.sqrt(1 - rho1v**2))
478478

479479
nu0 = (Y0 - np.dot(beta0, X0.T)) / sd0
480-
lambda0 = (np.dot(gamma, Z0.T) - rho0v * nu0) / (np.sqrt(1 - rho0v ** 2))
480+
lambda0 = (np.dot(gamma, Z0.T) - rho0v * nu0) / (np.sqrt(1 - rho0v**2))
481481

482482
treated = (1 / sd1) * norm.pdf(nu1) * norm.cdf(lambda1)
483483
untreated = (1 / sd0) * norm.pdf(nu0) * (1 - norm.cdf(lambda0))
@@ -916,7 +916,7 @@ def gradient(X1, X0, Z1, Z0, nu1, nu0, lambda1, lambda0, gamma, sd1, sd0, rho1v,
916916
"ij, i ->ij",
917917
X1,
918918
-(norm.pdf(lambda1) / norm.cdf(lambda1))
919-
* (rho1v / (np.sqrt(1 - rho1v ** 2) * sd1))
919+
* (rho1v / (np.sqrt(1 - rho1v**2) * sd1))
920920
- nu1 / sd1,
921921
),
922922
0,
@@ -929,7 +929,7 @@ def gradient(X1, X0, Z1, Z0, nu1, nu0, lambda1, lambda0, gamma, sd1, sd0, rho1v,
929929
X0,
930930
norm.pdf(lambda0)
931931
/ (1 - norm.cdf(lambda0))
932-
* (rho0v / (np.sqrt(1 - rho0v ** 2) * sd0))
932+
* (rho0v / (np.sqrt(1 - rho0v**2) * sd0))
933933
- nu0 / sd0,
934934
),
935935
0,
@@ -939,8 +939,8 @@ def gradient(X1, X0, Z1, Z0, nu1, nu0, lambda1, lambda0, gamma, sd1, sd0, rho1v,
939939
* (
940940
+1 / sd1
941941
- (norm.pdf(lambda1) / norm.cdf(lambda1))
942-
* (rho1v * nu1 / (np.sqrt(1 - rho1v ** 2) * sd1))
943-
- nu1 ** 2 / sd1
942+
* (rho1v * nu1 / (np.sqrt(1 - rho1v**2) * sd1))
943+
- nu1**2 / sd1
944944
),
945945
keepdims=True,
946946
)
@@ -949,16 +949,16 @@ def gradient(X1, X0, Z1, Z0, nu1, nu0, lambda1, lambda0, gamma, sd1, sd0, rho1v,
949949
* (
950950
+1 / sd0
951951
+ (norm.pdf(lambda0) / (1 - norm.cdf(lambda0)))
952-
* (rho0v * nu0 / (np.sqrt(1 - rho0v ** 2) * sd0))
953-
- nu0 ** 2 / sd0
952+
* (rho0v * nu0 / (np.sqrt(1 - rho0v**2) * sd0))
953+
- nu0**2 / sd0
954954
),
955955
keepdims=True,
956956
)
957957
grad_rho1v = np.sum(
958958
(
959959
-(norm.pdf(lambda1) / norm.cdf(lambda1))
960960
* ((np.dot(gamma, Z1.T) * rho1v) - nu1)
961-
/ (1 - rho1v ** 2) ** (1 / 2)
961+
/ (1 - rho1v**2) ** (1 / 2)
962962
),
963963
keepdims=True,
964964
)
@@ -967,7 +967,7 @@ def gradient(X1, X0, Z1, Z0, nu1, nu0, lambda1, lambda0, gamma, sd1, sd0, rho1v,
967967
(
968968
(norm.pdf(lambda0) / (1 - norm.cdf(lambda0)))
969969
* ((np.dot(gamma, Z0.T) * rho0v) - nu0)
970-
/ (1 - rho0v ** 2) ** (1 / 2)
970+
/ (1 - rho0v**2) ** (1 / 2)
971971
),
972972
keepdims=True,
973973
)
@@ -976,14 +976,14 @@ def gradient(X1, X0, Z1, Z0, nu1, nu0, lambda1, lambda0, gamma, sd1, sd0, rho1v,
976976
np.einsum(
977977
"ij, i ->ij",
978978
Z1,
979-
(norm.pdf(lambda1) / norm.cdf(lambda1)) * 1 / np.sqrt(1 - rho1v ** 2),
979+
(norm.pdf(lambda1) / norm.cdf(lambda1)) * 1 / np.sqrt(1 - rho1v**2),
980980
)
981981
) - sum(
982982
np.einsum(
983983
"ij, i ->ij",
984984
Z0,
985985
(norm.pdf(lambda0) / (1 - norm.cdf(lambda0)))
986-
* (1 / np.sqrt(1 - rho0v ** 2)),
986+
* (1 / np.sqrt(1 - rho0v**2)),
987987
)
988988
)
989989

@@ -1042,10 +1042,10 @@ def gradient_hessian(x0, X1, X0, Z1, Z0, Y1, Y0):
10421042
# compute gradient for beta 1
10431043

10441044
nu1 = (Y1 - np.dot(beta1, X1.T)) / sd1
1045-
lambda1 = (np.dot(gamma, Z1.T) - rho1v * nu1) / (np.sqrt(1 - rho1v ** 2))
1045+
lambda1 = (np.dot(gamma, Z1.T) - rho1v * nu1) / (np.sqrt(1 - rho1v**2))
10461046

10471047
nu0 = (Y0 - np.dot(beta0, X0.T)) / sd0
1048-
lambda0 = (np.dot(gamma, Z0.T) - rho0v * nu0) / (np.sqrt(1 - rho0v ** 2))
1048+
lambda0 = (np.dot(gamma, Z0.T) - rho0v * nu0) / (np.sqrt(1 - rho0v**2))
10491049

10501050
grad = gradient(
10511051
X1, X0, Z1, Z0, nu1, nu0, lambda1, lambda0, gamma, sd1, sd0, rho1v, rho0v
@@ -1054,7 +1054,7 @@ def gradient_hessian(x0, X1, X0, Z1, Z0, Y1, Y0):
10541054
multiplier = np.concatenate(
10551055
(
10561056
np.ones(len(grad[:-4])),
1057-
np.array([1 / sd1, 1 / (1 - rho1v ** 2), 1 / sd0, 1 / (1 - rho0v ** 2)]),
1057+
np.array([1 / sd1, 1 / (1 - rho1v**2), 1 / sd0, 1 / (1 - rho0v**2)]),
10581058
)
10591059
)
10601060

0 commit comments

Comments
 (0)