The dmlalg package contains implementations of double machine learning algorithms in R.
You can install the released version of dmlalg from CRAN with:
install.packages("dmlalg")
The aim of this first set of functions is to perform inference for the linear parameter in partially linear models with confounding variables. The standard DML estimator of the linear parameter has a two-stage least squares interpretation, which can lead to a large variance and overwide confidence intervals. We apply regularization to reduce the variance of the estimator, which produces narrower confidence intervals that remain approximately valid. Nuisance terms can be flexibly estimated with machine learning algorithms.
This algorithm is described in Emmenegger and Bühlmann (2021b) and implemented in the function regsdml.
This is a basic example which shows you how to solve a common problem:
library(dmlalg)
## Generate some data:
## (the seed is fixed so that the simulated data are reproducible)
set.seed(19)
# true linear parameter
beta0 <- 1
# number of observations
n <- 40
# observed confounder
w <- pi * runif(n, -1, 1)
# instrument
a <- 3 * tanh(2 * w) + rnorm(n, 0, 1)
# unobserved confounder
h <- 2 * sin(w) + rnorm(n, 0, 1)
# linear covariate
x <- -1 * abs(a) - h - 2 * tanh(w) + rnorm(n, 0, 1)
# response
y <- beta0 * x - 3 * cos(pi * 0.25 * h) + 0.5 * w^2 + rnorm(n, 0, 1)
## Estimate the linear coefficient from x to y
## (The parameters are chosen small enough to make estimation fast):
## Caveat: A spline estimator is extrapolated, which raises a warning message.
## Extrapolation lies in the nature of our method. To omit the warning message
## resulting from the spline estimator, another estimator may be used.
fit <- regsdml(a, w, x, y,
               # grid of four gamma-values, equally spaced on the log scale
               gamma = exp(seq(-4, 1, length.out = 4)),
               S = 3,
               do_regDML_all_gamma = TRUE,
               cond_method = c("forest",  # for E[A|W]
                               "spline",  # for E[X|W]
                               "spline"), # for E[Y|W]
               # one entry per cond_method; only the forest gets extra settings
               params = list(list(ntree = 1), NULL, NULL))
#> Warning in print_W_E_fun(errors, warningMsgs):
#> Warning messages:
#> some 'x' values beyond boundary knots may cause ill-conditioned bases
## Passing parm = c(2, 3) additionally summarizes the 2nd and 3rd gamma-values
summary(
  fit,
  parm = c(2, 3),
  correlation = TRUE,
  print_gamma = TRUE
)
#>
#> Coefficients :
#> regsDML (2.72e+00) :
#> Estimate Std. Error z value Pr(>|z|)
#> b1 0.910255 0.1731559 5.256852 1.465421e-07
#>
#> regDMLall (9.70e-02) :
#> Estimate Std. Error z value Pr(>|z|)
#> b1 0.7986392 0.1514027 5.274935 1.328031e-07
#>
#> regDMLall (5.13e-01) :
#> Estimate Std. Error z value Pr(>|z|)
#> b1 0.846176 0.1651298 5.124308 2.986318e-07
#>
#>
#> Variance-covariance matrices :
#> regsDML (2.72e+00) :
#> b1
#> b1 0.02998297
#>
#> regDMLall (9.70e-02) :
#> b1
#> b1 0.02292277
#>
#> regDMLall (5.13e-01) :
#> b1
#> b1 0.02726785
## Confidence intervals, including those for the 2nd and 3rd gamma-values
confint(
  fit,
  parm = c(2, 3),
  print_gamma = TRUE
)
#>
#> Two-sided confidence intervals at level 0.95 :
#>
#> regsDML (2.72e+00) :
#> 2.5 % 97.5 %
#> b1 0.5708757 1.249634
#>
#> regDMLall (9.70e-02) :
#> 2.5 % 97.5 %
#> b1 0.5018955 1.095383
#>
#> regDMLall (5.13e-01) :
#> 2.5 % 97.5 %
#> b1 0.5225276 1.169824
coef(fit) # coefficients
#> regsDML
#> b1 0.910255
vcov(fit) # variance-covariance matrices
#>
#> Variance-covariance matrices :
#> regsDML :
#> b1
#> b1 0.02998297
## Alternatively, provide the data in a single data frame
## (see also caveat above):
data <- data.frame(a = a, w = w, x = x, y = y)
## The variables are now referenced by their column names in `data`
fit <- regsdml(a = "a", w = "w", x = "x", y = "y", data = data,
               gamma = exp(seq(-4, 1, length.out = 4)),
               S = 3)
#> Warning in print_W_E_fun(errors, warningMsgs):
#> Warning messages:
#> some 'x' values beyond boundary knots may cause ill-conditioned bases
## With more realistic parameter choices:
## (wrapped in if (FALSE) so that this code is shown but never executed)
if (FALSE) {
  fit <- regsdml(a, w, x, y,
                 cond_method = c("forest",  # for E[A|W]
                                 "spline",  # for E[X|W]
                                 "spline")) # for E[Y|W]
  summary(fit)
  confint(fit)

  ## Alternatively, provide the data in a single data frame:
  ## (see also caveat above):
  data <- data.frame(a = a, w = w, x = x, y = y)
  fit <- regsdml(a = "a", w = "w", x = "x", y = "y", data = data)
}
The aim of this second set of functions is to estimate and perform inference for the linear coefficient in a partially linear mixed-effects model with DML. Machine learning algorithms allow us to incorporate more complex interaction structures and high-dimensional variables.
This algorithm is described in Emmenegger and Bühlmann (2021a) and implemented in the function mmdml.
This is a basic example which shows you how to solve a common problem:
library(dmlalg)
## generate data
## (RNG kind and seed are fixed so that the example is reproducible)
RNGkind("L'Ecuyer-CMRG")
set.seed(19)
## data1: scalar beta0; data2: beta0 of length two
data1 <- example_data_mmdml(beta0 = 0.2)
data2 <- example_data_mmdml(beta0 = c(0.2, 0.2))
## fit models
## Caveat: Warning messages are displayed because the small number of
## observations results in a singular random effects model
## Model with a single linear covariate (x1), fitted on data1
fit1 <-
  mmdml(w = c("w1", "w2", "w3"), x = "x1", y = "resp", z = c("id", "cask"),
        data = data1, z_formula = "(1|id) + (1|cask:id)", group = "id", S = 3)
#> Warning in mmdml(w = c("w1", "w2", "w3"), x = "x1", y = "resp", z = c("id", :
#> Warning messages:
#> boundary (singular) fit: see ?isSingular
## Model with two linear covariates (x1 and x2), fitted on data2
fit2 <-
  mmdml(w = c("w1", "w2", "w3"), x = c("x1", "x2"), y = "resp", z = c("id", "cask"),
        data = data2, z_formula = "(1|id) + (1|cask:id)", group = "id", S = 3)
#> Warning in mmdml(w = c("w1", "w2", "w3"), x = c("x1", "x2"), y = "resp", :
#> Warning messages:
#> boundary (singular) fit: see ?isSingular
## apply methods
confint(fit2)
#> 2.5% 97.5%
#> x1 -0.03415795 0.3480103
#> x2 0.15930098 0.3893938
fixef(fit2)
#> x1 x2
#> 0.1569261 0.2743474
print(fit2)
#> Semiparametric mixed model fit by maximum likelihood ['mmdml']
#> Random effects:
#> Groups Name Std.Dev.
#> cask:id (Intercept) 1.908e-06
#> id (Intercept) 1.107e-01
#> Residual 2.756e-01
#> Number of obs: 46, groups: cask:id, 20; id, 10
#> Fixed Effects:
#> x1 x2
#> 0.1569 0.2743
#> optimizer (nloptwrap) convergence code: 0 (OK) ; 0 optimizer warnings; 1 lme4 warnings
ranef(fit2)
#> $`cask:id`
#> (Intercept)
#> 1:1 -0.0023043914
#> 1:10 -0.0050894736
#> 1:2 0.0024571669
#> 1:3 0.0007708872
#> 1:4 -0.0012417525
#> 1:5 0.0029010344
#> 1:6 0.0012307712
#> 1:7 -0.0028418387
#> 1:8 -0.0015618712
#> 1:9 -0.0048037635
#> 2:1 0.0100768089
#> 2:10 -0.0031560819
#> 2:2 -0.0033427429
#> 2:3 -0.0044928425
#> 2:4 -0.0054049237
#> 2:5 -0.0021157461
#> 2:6 -0.0023122280
#> 2:7 0.0038004751
#> 2:8 0.0148222090
#> 2:9 0.0026385335
#>
#> $id
#> (Intercept)
#> 1 0.100740957
#> 10 -0.124434023
#> 2 -0.036918731
#> 3 -0.030230821
#> 4 -0.081051109
#> 5 0.018887512
#> 6 -0.006711504
#> 7 0.025545300
#> 8 0.235373382
#> 9 -0.020965920
residuals(fit2)
#> [[1]]
#> [1] -0.1311195998 0.5733692328 0.1398125051 -0.0705463911 -0.1196552839
#> [6] -0.0354080600 0.6205378654 -0.1057642425 -0.4355021749 -0.0633888854
#> [11] 0.0070044016 -0.1777683530 -0.0214893719 0.0052358066 0.1594839987
#> [16] -0.2353753755 -0.2216497409 -0.1034882421 0.0175984650 -0.0388497525
#> [21] 0.4636325671 -0.2597143034 0.3528825573 -0.4739722035 0.0007039458
#> [26] 0.0700307380 -0.1315655000 -0.1617002846 0.2162465843 0.0934414339
#> [31] -0.0480554546 -0.1342562672 -0.2349311153 0.4021334289 0.6761796261
#> [36] 0.3514207835 -0.0918140917 -0.2144924370 -0.3184478283 -0.2704273590
#> [41] -0.1953366308 0.7209607369 -0.1050645053 -0.2895904461 -0.2737160112
#> [46] 0.0941353224
#>
#> [[2]]
#> [1] 0.066708484 0.381936532 0.083961541 -0.244607521 -0.116940987
#> [6] -0.015024540 0.605540877 0.128223071 -0.186010749 -0.119432458
#> [11] -0.101885530 -0.153724682 -0.214346785 -0.126400135 0.090522034
#> [16] -0.103818112 -0.170763502 -0.102507199 0.047067741 -0.026325741
#> [21] 0.472126666 -0.231575911 0.324749223 -0.423215690 -0.013990681
#> [26] 0.066537726 -0.086954711 -0.025470109 0.227756255 0.224213587
#> [31] -0.070700603 0.081484834 -0.226268534 0.615553468 0.723110460
#> [36] 0.333538915 -0.076459138 -0.198241935 -0.245660371 -0.366166157
#> [41] -0.142947352 0.677671159 -0.047532882 -0.305555800 -0.379445954
#> [46] 0.007159723
#>
#> [[3]]
#> [1] 0.09685066 0.34629638 0.09582384 -0.27150981 -0.12653048 -0.02387543
#> [7] 0.62488259 0.12730531 -0.19466784 -0.12227940 -0.07635676 -0.16470188
#> [13] -0.20223445 -0.11432450 0.13844295 -0.12234863 -0.18662475 -0.09034621
#> [19] 0.07330126 -0.02704395 0.51049151 -0.23716208 0.36116367 -0.42669942
#> [25] -0.02948530 0.10139429 -0.06858354 -0.03611104 0.19153360 0.21971922
#> [31] -0.04085530 0.09453877 -0.20903814 0.60734696 0.69658489 0.33318587
#> [37] -0.09082740 -0.21317885 -0.24276713 -0.34992920 -0.09491974 0.68198892
#> [43] -0.07291051 -0.24350682 -0.40714805 0.05067157
sigma(fit2)
#> [1] 0.2756384
summary(fit2)
#> Semiparametric mixed model fit by maximum likelihood ['mmdml']
#> Scaled residuals (nr_res = 3):
#> Min 1Q Median 3Q Max
#> -1.7195 -0.6674 -0.2394 0.3637 2.6234
#>
#> Random effects:
#> Groups Name Variance Std.Dev.
#> cask:id (Intercept) 3.641e-12 1.908e-06
#> id (Intercept) 1.226e-02 1.107e-01
#> Residual 7.598e-02 2.756e-01
#> Number of obs: 46, groups: cask:id, 20; id, 10
#>
#> Fixed effects:
#> Estimate Std. Error z value Pr(>|z|)
#> x1 0.15693 0.09749 1.610 0.107
#> x2 0.27435 0.05870 4.674 2.96e-06 ***
#> ---
#> Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
#>
#> Correlation of Fixed Effects:
#> x1
#> x2 -0.029
#> optimizer (nloptwrap) convergence code: 0 (OK)
#> boundary (singular) fit: see ?isSingular
vcov(fit2)
#> 2 x 2 Matrix of class "dpoMatrix"
#> x1 x2
#> x1 9.505018e-03 -9.208662e-05
#> x2 -9.208662e-05 3.445483e-03
VarCorr(fit2)
#> Groups Name Std.Dev.
#> cask:id (Intercept) 1.9083e-06
#> id (Intercept) 1.1074e-01
#> Residual 2.7564e-01