For most users, the vanilla function calc_weight should be sufficient for use cases with binary treatments. However, for users with more complicated data structures or problems, the causalOT package offers a more flexible interface that relies heavily on the torch package. We will walk through a few use cases here to show how one might use the object-oriented programming (OOP) objects.
One very important thing to note is that these objects are mutable; in other words, they are always passed by reference, so changes to the base objects will affect all other objects that rely on them. Thus, changes propagate forward and backward. For these reasons, these objects are more prone to side effects and should be used carefully.
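To make the reference semantics concrete, here is a minimal base-R sketch that uses a plain environment as a stand-in for an R6 object (R6 objects are built on environments). The names make_measure, m1, and m2 are hypothetical and are not part of causalOT.

```r
# Environments, like R6 objects, are modified in place rather than copied.
# make_measure() is a hypothetical stand-in, not a causalOT function.
make_measure <- function(weights) {
  obj <- new.env()
  obj$weights <- weights
  obj
}

m1 <- make_measure(c(0.5, 0.5))
m2 <- m1                      # no copy is made: m2 refers to the same object
m2$weights <- c(0.9, 0.1)     # modifying m2 ...
m1$weights                    # ... also changes what m1 sees

# An ordinary R vector, by contrast, is copied on modification.
v1 <- c(0.5, 0.5)
v2 <- v1
v2[1] <- 0.9
v1                            # unchanged
```

The same pattern is why an object that holds a reference to another object sees every later change made to it.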
Finally, these data structures are heavily reliant on the torch package in R. This allows relatively easy use of GPUs and also has other advantages, such as passing by reference and various optimization methods available by default.
The fundamental object underlying the OOP methods in the package is the R6 class Measure. These objects are named for the fact that they specify an empirical distribution on a set of support points. In light of this, the first two arguments should be intuitive: x, the set of data for the measure, and weights, the empirical mass.
n <- 5
d <- 3
x <- matrix(stats::rnorm(n*d), nrow = n)
w <- stats::runif(n)
w <- w/sum(w)
m <- Measure(x = x, weights = w)
We can also view the weights and data of a measure object by accessing these public fields:
m$x
#> torch_tensor
#> -0.1088 0.1408 -0.9207
#> 0.0217 -1.5297 -0.2621
#> 1.0911 0.1324 -0.2298
#> -0.0686 0.5114 -0.0258
#> 0.7843 1.1124 -1.3936
#> [ CPUDoubleType{5,3} ]
m$weights
#> torch_tensor
#> 0.2228
#> 0.1881
#> 0.2386
#> 0.2295
#> 0.1211
#> [ CPUDoubleType{5} ]
The next argument in the constructor function, probability.measure, lets the function know if your weights are a probability measure (i.e., the weights sum to 1 and are positive) versus a more general type of measure. The default assumption is that you are using a probability measure.
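As a quick illustration of that condition, the following base-R sketch checks whether a weight vector is positive and sums to one. The helper is_probability_weights is hypothetical and is not part of causalOT; the package may perform its own checks differently.

```r
# Hypothetical helper (not part of causalOT): TRUE when the weights are
# strictly positive and sum to one up to a numerical tolerance.
is_probability_weights <- function(w, tol = 1e-8) {
  is.numeric(w) && all(is.finite(w)) && all(w > 0) && abs(sum(w) - 1) < tol
}

w_raw <- c(2, 1, 1)
is_probability_weights(w_raw)               # FALSE: the weights sum to 4
is_probability_weights(w_raw / sum(w_raw))  # TRUE after normalizing
```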
Then we come to a very important argument: adapt. This lets the function know if you are seeking to change nothing (“none”) and keep the measure static, if you want to adapt the weights (“weights”) towards another measure, or if you want to move the data points of the measure itself (“x”).
Adapting the measure to functions of target data.
The next two arguments are useful in the setting when you want to adapt specific functions of the Measure to target data. Typically, these target functions will be the empirical means of some aspect of the covariates in a target data set. As an example:
target.data <- matrix(rnorm(n*d),n,d)
target.values <- colMeans(target.data)
m <- Measure(x = x, weights = w,
probability.measure = TRUE,
adapt = "weights",
target.values = target.values)
Note that if we don’t supply the balance.functions argument and target.values are provided, the function will use the data in x as the balance.functions. We can view our balance functions and targets through the balance_functions and balance_target fields.
Note that the values returned are different from the originals. This is because the software divides the balance functions and target values by the standard deviation of each balance function.
all.equal(as.numeric(m$balance_target), target.values)
#> [1] "Mean relative difference: 0.396385"
all.equal(as.matrix(m$balance_functions), x)
#> [1] "Mean relative difference: 0.3160986"
sds <- apply(x,2,sd)
all.equal(as.numeric(m$balance_target), target.values/sds)
#> [1] TRUE
all.equal(as.matrix(m$balance_functions), sweep(x,2,sds,"/"))
#> [1] TRUE
Obviously, if adapt = "none", then the balance.functions and target.values are essentially useless.
Finally, the arguments dtype and device set the data type and device of the torch_tensors in the underlying data structures. For more information, see the torch documentation.
Also, we can print the measure objects to the screen to see some of the underlying information quickly. We also get the object address which can be useful in distinguishing the different objects.
m
#> Measure: 0x7fdd8b7c4eb8
#> x : a 5x3 matrix
#> -0.11, 0.14, -0.92
#> 0.02, -1.53, -0.26
#> 1.09, 0.13, -0.23
#> -0.07, 0.51, -0.03
#> 0.78, 1.11, -1.39
#> weights: 0.22, 0.19, 0.24, 0.23, 0.12
#> balance:
#> funct.: -0.2, 0.14, -1.61
#> target: 0.55, -0.2, -1.53 …
#> adapt : weights
#> dtype : torch_Double
#> device : torch_device(type='cpu')
The next important component of the OOP framework in causalOT is the OTProblem object. Say we have two measures: one we want to target, m_target, and one we want to adapt to the target measure, m_source, by changing its weights.
m_target <- Measure(x = matrix(rnorm(n*2*d), n*2,d))
m_source <- Measure(x = x, weights = w, adapt = "weights")
Now we need some way of adapting m_source, and in this package we will use optimal transport methods. Thus, we specify our optimal transport problem.
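The call that creates otp is not shown in this copy of the text. A minimal sketch, assuming the constructor takes the measure to adapt first and the target second (the order used in calls such as OTProblem(a, overall) later in this document), might look like the following. The guard simply skips the call when causalOT is not installed or the two measures have not been created yet.

```r
# Assumed reconstruction of the missing step; the argument order
# (source first, target second) is an assumption, not confirmed here.
if (requireNamespace("causalOT", quietly = TRUE) &&
    exists("m_source") && exists("m_target")) {
  otp <- causalOT::OTProblem(m_source, m_target)
}
```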
The OTProblem is the basis for setting up the following objective function:
$$w^\star = \operatorname{argmin}_w OT_\lambda(m_{\text{source}}(w), m_{\text{target}}) \quad \text{s.t.} \quad \frac{E_w(B(x_{\text{source}})) - E(B(x_{\text{target}}))}{\sigma} \leq \delta.$$
$OT_\lambda$ is an optimal transport distance specified by either the Sinkhorn distance, $S_\lambda(a,b) = \min \ldots$ for some cost matrix $C_{i,j} = c(x_i, x_j)$, or the Sinkhorn divergence, $S_\lambda(a,b) - 0.5 S_\lambda(a,a) - 0.5 S_\lambda(b,b)$. The linear constraint on the problem bounds the difference in the balance functions, scaled by their original standard deviation $\sigma$, by $\delta$.
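To give a feel for these quantities, here is a small base-R sketch of textbook Sinkhorn iterations and the debiasing formula above. It is not the solver used by causalOT, which relies on torch. Two things are assumptions because the definition above is incomplete: the entropic objective is taken to be the transport cost plus lambda times the KL divergence from the product of the marginals, and the cost is the squared Euclidean distance. The function names are hypothetical.

```r
# Squared Euclidean cost matrix between the rows of x and the rows of y.
sq_dist <- function(x, y) {
  C <- outer(rowSums(x^2), rowSums(y^2), "+") - 2 * tcrossprod(x, y)
  pmax(C, 0)  # guard against tiny negative values from rounding
}

# Entropic OT value <C, P> + lambda * KL(P | a b^T) via Sinkhorn iterations.
sinkhorn_value <- function(a, b, C, lambda, n_iter = 1000L, tol = 1e-10) {
  K <- exp(-C / lambda)                       # Gibbs kernel
  u <- rep(1, length(a))
  v <- rep(1, length(b))
  for (i in seq_len(n_iter)) {
    u_new <- a / as.numeric(K %*% v)          # match the row marginals
    v <- b / as.numeric(crossprod(K, u_new))  # match the column marginals
    converged <- max(abs(u_new - u)) < tol
    u <- u_new
    if (converged) break
  }
  P <- sweep(u * K, 2, v, "*")                # transport plan diag(u) K diag(v)
  sum(P * C) + lambda * sum(P * log(P / outer(a, b)))
}

# Debiased version: S(a, b) - 0.5 * S(a, a) - 0.5 * S(b, b).
sinkhorn_divergence <- function(x, y, a, b, lambda) {
  sinkhorn_value(a, b, sq_dist(x, y), lambda) -
    0.5 * sinkhorn_value(a, a, sq_dist(x, x), lambda) -
    0.5 * sinkhorn_value(b, b, sq_dist(y, y), lambda)
}

x <- matrix(c(0, 0, 1, 0, 0, 1), ncol = 2, byrow = TRUE)
y <- x + 2                                    # same shape, shifted support
a <- rep(1 / 3, 3)
b <- rep(1 / 3, 3)
sinkhorn_divergence(x, y, a, b, lambda = 1)   # positive: the measures differ
sinkhorn_divergence(x, x, a, a, lambda = 1)   # zero: identical measures
```

The debiasing terms are what make the divergence vanish when a measure is compared with itself, which the plain Sinkhorn distance does not do.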
With this detail, we then need to specify which optimal transport problem we’re using, the various penalty parameters, etc. To do this, we use the setup_arguments function below:
otp$setup_arguments(
  lambda = NULL, # penalty values of the optimal transport (OT) distances to try
  delta = NULL, # constraint values to try for balancing functions
  grid.length = 7L, # number of values of lambda and delta to try
                    # if none are provided
  cost.function = NULL, # the ground cost to use between covariates
                        # default is the Euclidean distance
  p = 2, # power to raise the cost by
  cost.online = "auto", # should the cost be calculated "online" or "tensorized" (stored in memory); "auto" will try to decide for you
  debias = TRUE, # use Sinkhorn divergences (debias = TRUE), i.e. debiased Sinkhorn distances,
                 # or use the Sinkhorn distances (debias = FALSE)
  diameter = NULL, # the diameter of the covariate space if known
  ot_niter = 1000L, # the number of iterations to run when solving OT distances
  ot_tol = 0.001 # the tolerance for convergence of OT distances
)
The last two arguments may be confusing at first, but understanding how the OTProblem objects adapt the measure may help add some clarity. The OTProblem first has to solve an optimal transport problem between the two measures (with runtime parameters specified in the setup_arguments function). Then the object takes a step to update the weights, which is done by the next function.
Once we have set up the arguments, we can solve this OTProblem:
otp$solve(
niter = 1000L, # maximum number of iterations
tol = 1e-5, # tolerance for convergence
optimizer = "torch", # which optimizer to use "torch" or "frank-wolfe"
torch_optim = torch::optim_lbfgs, # torch optimizer to use if required
torch_scheduler = torch::lr_reduce_on_plateau, # torch scheduler to use if required
torch_args = list(line_search_fn = "strong_wolfe"), # args passed to the torch functions,
osqp_args = NULL, #arguments passed to the osqp solver used for "frank-wolfe" and balance functions
quick.balance.function = TRUE # if balance functions are also present, should an approximate value of the hyperparameter "delta" be found first
)
Since the objects are passed by reference, the weights of the measure object that was adapted are now different.
#> adapted original
#> [1,] 3.919938e-01 0.2227984
#> [2,] 1.150060e-01 0.1880527
#> [3,] 1.802888e-09 0.2385908
#> [4,] 4.930001e-01 0.2295054
#> [5,] 6.598182e-08 0.1210527
Note: the dual optimization method currently available for the COT method in the calc_weight function is not implemented for OTProblem objects. Thus, these optimization problems may take longer to solve.
We have run the function with a variety of lambda parameters chosen by the OTProblem object. We should select one to move forward with.
otp$choose_hyperparameters(
n_boot_lambda = 100L, #Number of bootstrap iterations to choose lambda
n_boot_delta = 1000L, #Number of bootstrap iterations to choose delta
lambda_bootstrap = Inf # penalty parameter to use for OT distances
)
The delta parameter wasn’t used, so we only select the value of lambda. This gives us a final value of lambda and the following final weights:
as.numeric(m_source$weights)
#> [1] 2.489658e-01 1.233070e-01 1.145063e-09 6.277272e-01 4.190680e-08
We can also see the final value of the optimal transport problem with the chosen value of lambda and weights.
In summary, we have the following steps to solve our causal inference problems using optimal transport.
Measure objects
m_target <- Measure(x = matrix(rnorm(n*2*d), n*2,d))
m_source <- Measure(x = x, weights = w, adapt = "weights")
OTProblem
OTProblem
otp$setup_arguments(
  lambda = NULL, # penalty values of the optimal transport (OT) distances to try
  delta = NULL, # constraint values to try for balancing functions
  grid.length = 7L, # number of values of lambda and delta to try
                    # if none are provided
  cost.function = NULL, # the ground cost to use between covariates
                        # default is the Euclidean distance
  p = 2, # power to raise the cost by
  cost.online = "auto", # should the cost be calculated "online" or "tensorized" (stored in memory); "auto" will try to decide for you
  debias = TRUE, # use Sinkhorn divergences (debias = TRUE), i.e. debiased Sinkhorn distances,
                 # or use the Sinkhorn distances (debias = FALSE)
  diameter = NULL, # the diameter of the covariate space if known
  ot_niter = 1000L, # the number of iterations to run when solving OT distances
  ot_tol = 0.001 # the tolerance for convergence of OT distances
)
OTProblem
otp$solve(
niter = 1000L, # maximum number of iterations
tol = 1e-5, # tolerance for convergence
optimizer = "torch", # which optimizer to use "torch" or "frank-wolfe"
torch_optim = torch::optim_lbfgs, # torch optimizer to use if required
torch_scheduler = torch::lr_reduce_on_plateau, # torch scheduler to use if required
torch_args = list(line_search_fn = "strong_wolfe"), # args passed to the torch functions,
osqp_args = NULL, #arguments passed to the osqp solver used for "frank-wolfe" and balance functions
quick.balance.function = TRUE # if balance functions are also present, should an approximate value of the hyperparameter "delta" be found first
)
otp$choose_hyperparameters(
n_boot_lambda = 100L, #Number of bootstrap iterations to choose lambda
n_boot_delta = 1000L, #Number of bootstrap iterations to choose delta
lambda_bootstrap = Inf # penalty parameter to use for OT distances
)
The case above was simply a vanilla optimal transport problem that could easily be solved by the calc_weight function in the main package. Let’s look at a more complicated use case.
Ideally, we’d simply dump the data together and run our OT framework.
nrow <- 100
ncol <- 2
a <- Measure(x = matrix(rnorm(nrow*ncol,mean=c(0.1,0.1)) + 0.1,nrow,ncol,byrow = TRUE), adapt = "weights")
b <- Measure(x = matrix(rnorm(nrow*ncol,mean=c(-0.1,-0.1),sd=0.25),nrow,ncol,byrow = TRUE), adapt = "weights")
c <- Measure(x = matrix(rnorm(nrow*ncol,mean=c(0.1,-0.1)),nrow,ncol,byrow = TRUE), adapt = "weights")
d <- Measure(x = matrix(rnorm(nrow*ncol,mean= c(-0.1,0.1),sd=0.25),nrow,ncol,byrow = TRUE), adapt = "weights")
overall <- Measure(x = torch::torch_vstack(lapply(list(a,b,c,d), function(meas) meas$x)),
adapt = "none")
overall_ot <- OTProblem(a,overall) + OTProblem(b, overall) +
OTProblem(c, overall) + OTProblem(d, overall)
One thing to note that’s kind of cool is that we can add our OTProblem objects together to make a unified objective function.
overall_ot
#> OT Problem:
#> OT(0x7fdd7a0af5d0, 0x7fdd7a2b50a0) +
#> OT(0x7fdd7a15d830, 0x7fdd7a2b50a0) +
#> OT(0x7fdd7a1e1798, 0x7fdd7a2b50a0) +
#> OT(0x7fdd79b17810, 0x7fdd7a2b50a0)
Neat!
We can also run the calc_weight function in each treatment group, targeting the overall population:
source_measures <- list(a,b,c,d)
meas <- x_temp <- NULL
z_temp <- c(rep(1, nrow*4), rep(0,nrow))
wt <- list()
for(i in seq_along(source_measures)) {
meas <- source_measures[[i]]
x_temp <- as.matrix(torch::torch_vstack(list(overall$x,meas$x)))
wt[[i]] <- calc_weight(x = x_temp,
z = z_temp,
estimand = "ATT",
method = "COT")
}
If only moments are available, then each site can run essentially independently. We just need to collect the moments from each site and combine them:
target.values <-
as.numeric(a$x$mean(1) + b$x$mean(1) +
c$x$mean(1) + d$x$mean(1))/4
a_t <- Measure(x = a$x, adapt = "weights",
               target.values = target.values)
b_t <- Measure(x = b$x, adapt = "weights",
               target.values = target.values)
c_t <- Measure(x = c$x, adapt = "weights",
               target.values = target.values)
d_t <- Measure(x = d$x, adapt = "weights",
               target.values = target.values)
all.target.measures <- list(a_t, b_t, c_t, d_t)
Then we can optimize the weights targeting the moments in a bit of a hacky way.
ot_targ <- NULL
for(meas in all.target.measures) {
ot_targ <- OTProblem(meas, meas)
ot_targ$setup_arguments(lambda = 100)
ot_targ$solve(torch_optim = torch::optim_lbfgs,
torch_args = list(line_search_fn = "strong_wolfe"))
}
And we can check the final balance
final.bal <- as.numeric(a_t$x$mT()$matmul(a_t$weights$detach()))
original <- as.numeric(a_t$x$mean(1))
rbind(original,
`final balance` = final.bal,
`target values` = target.values)
#> [,1] [,2]
#> original 0.08015652 0.31129337
#> final balance -0.01093957 0.09104138
#> target values -0.01100336 0.09096446
This will target the moments without information about the underlying distributions. Obviously, we would prefer to use more of the available information, as we describe next.
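The combination step in this section averages the four site means with equal weight, which matches the pooled mean because every site has the same number of rows. A sketch of the more general case, assuming each site reports only its column means and its sample size, is given below. The names sites, site_means, site_sizes, and pooled are hypothetical, and the data are simulated purely for illustration.

```r
set.seed(1)
# Simulated sites with unequal sample sizes (illustrative data only).
sites <- list(
  matrix(rnorm(50 * 2, mean = 0.1), 50, 2),
  matrix(rnorm(200 * 2, mean = -0.1), 200, 2)
)

# What each site would share: its column means and its number of rows.
site_means <- lapply(sites, colMeans)
site_sizes <- vapply(sites, nrow, integer(1))

# Size-weighted average of the site means.
pooled <- Reduce(`+`, Map(`*`, site_means, site_sizes)) / sum(site_sizes)

# Agrees with the mean of the stacked data, which the sites never share.
all.equal(pooled, colMeans(do.call(rbind, sites)))
```

The resulting vector could then play the role of target.values when the sites differ in size.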
Say we can pass any amount of data, but privacy or other restrictions prevent us from sharing the full data at each site. We can instead construct a pseudo-overall population using Wasserstein barycenters, which construct average distributions.
In this option, we pass gradients back to the main site. From this, we can construct a pseudo-average population. Let’s see how it might work. We first construct our pseudo data.
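The construction of the pseudo data is not shown in this copy of the text. A minimal sketch follows, assuming the pseudo data is a Measure named pseudo whose support points are adapted (adapt = "x", as described for the Measure constructor) and that it has the same two columns as the site data. The number of pseudo points, n_pseudo, is a hypothetical choice.

```r
set.seed(2)
n_pseudo <- 100L   # hypothetical number of pseudo points
n_cov <- 2L        # matches the two covariates at each site

# Random starting values so every pseudo point begins at a distinct location.
pseudo_x <- matrix(stats::rnorm(n_pseudo * n_cov), n_pseudo, n_cov)
stopifnot(anyDuplicated(pseudo_x) == 0L)

# Wrap in a Measure whose data points can move (skipped if causalOT is absent).
if (requireNamespace("causalOT", quietly = TRUE)) {
  pseudo <- causalOT::Measure(x = pseudo_x, adapt = "x")
}
```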
Importantly, each data point must be initialized to a separate value; otherwise, all of the points will move together.
Then we pass this pseudo data and set up a problem at each site.
pseudo_a <- pseudo$detach()
pseudo_b <- pseudo$detach()
pseudo_c <- pseudo$detach()
pseudo_d <- pseudo$detach()
pseudo_a$requires_grad <- pseudo_b$requires_grad <-
pseudo_c$requires_grad <- pseudo_d$requires_grad <- "x"
ota <- OTProblem(a$detach(), # don't update a
pseudo_a)
otb <- OTProblem(b$detach(), # don't update b
pseudo_b)
otc <- OTProblem(c$detach(), # don't update c
pseudo_c)
otd <- OTProblem(d$detach(), # don't update d
pseudo_d)
Then we set up the arguments. For simplicity, we will set lambda = 0.1.
ota$setup_arguments(lambda = .1)
otb$setup_arguments(lambda = .1)
otc$setup_arguments(lambda = .1)
otd$setup_arguments(lambda = .1)
Then we set up our optimizer at the main site:
opt <- torch::optim_rmsprop(pseudo$x)
sched <- torch::lr_multiplicative(opt, lr_lambda = function(epoch) {0.99})
Then we run our optimization loop like so:
#optimization loop
for (i in 1:100) {
# zero grad of main optimizer
opt$zero_grad()
# get gradients at each site
ota$loss$backward()
otb$loss$backward()
otc$loss$backward()
otd$loss$backward()
# pass grads back to main site
pseudo$grad <- pseudo_a$grad + pseudo_b$grad +
pseudo_c$grad + pseudo_d$grad
# update pseudo data at main site
opt$step()
# zero site gradients
torch::with_no_grad({
pseudo_a$grad$copy_(0.0)
pseudo_b$grad$copy_(0.0)
pseudo_c$grad$copy_(0.0)
pseudo_d$grad$copy_(0.0)
})
# update scheduler
sched$step()
}
Then we pass the final pseudo data back to the sites and optimize the weights at each site:
pseudo_a$x <- pseudo_b$x <- pseudo_c$x <- pseudo_d$x <-
pseudo$x
ota_w <- OTProblem(a, pseudo_a$detach())
otb_w <- OTProblem(b, pseudo_b$detach())
otc_w <- OTProblem(c, pseudo_c$detach())
otd_w <- OTProblem(d, pseudo_d$detach())
ota_w$setup_arguments()
ota_w$solve(torch_args = list(line_search_fn = "strong_wolfe"))
ota_w$choose_hyperparameters()
otb_w$setup_arguments()
otb_w$solve(torch_args = list(line_search_fn = "strong_wolfe"))
otb_w$choose_hyperparameters()
otc_w$setup_arguments()
otc_w$solve(torch_args = list(line_search_fn = "strong_wolfe"))
otc_w$choose_hyperparameters()
otd_w$setup_arguments()
otd_w$solve(torch_args = list(line_search_fn = "strong_wolfe"))
otd_w$choose_hyperparameters()
Note we haven’t checked for convergence when constructing the pseudo-data to save time. You should, however, do this in your own work.
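One simple way to add such a check is to record the loss at every iteration and stop once its relative change becomes small. The helper below is a hypothetical base-R sketch and is not part of causalOT; in the gradient-passing loop, the recorded value could be, for example, the sum of the site losses converted to a number. The toy loss sequence only stands in for those values.

```r
# TRUE once the relative change between the last two recorded losses is below tol.
# has_converged() is a hypothetical helper, not a causalOT function.
has_converged <- function(loss_history, tol = 1e-6) {
  k <- length(loss_history)
  if (k < 2L) return(FALSE)
  prev <- loss_history[k - 1L]
  curr <- loss_history[k]
  abs(curr - prev) <= tol * max(abs(prev), 1e-12)
}

# Toy loss sequence standing in for the recorded values.
losses <- numeric(0)
for (i in 1:100) {
  losses <- c(losses, 1 + 0.5^i)   # decreases toward 1
  if (has_converged(losses)) break
}
i  # iteration at which the loop stopped
```

A relative rather than absolute tolerance keeps the rule meaningful when the loss itself is large or small in magnitude.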
Of course, maybe we can’t pass gradients. Instead, we can create pseudo data at each site.
The second option is to create pseudo-data for each site and then use this to generate an overall average data set. This will allow us to create privacy-respecting pseudo-data in each site, i.e., data that is close to the population at A but with different values. Then we take these pseudo-data and create an overall average data set as follows.
First, we need to reinitialize sites again since they were changed in the previous example.
a$weights <- a$init_weights
b$weights <- b$init_weights
c$weights <- c$init_weights
d$weights <- d$init_weights
Then we again create pseudo data.
Then we pass this pseudo data and set up a problem at each site.
pseudo_a <- pseudo$detach()
pseudo_b <- pseudo$detach()
pseudo_c <- pseudo$detach()
pseudo_d <- pseudo$detach()
pseudo_a$requires_grad <- pseudo_b$requires_grad <-
pseudo_c$requires_grad <- pseudo_d$requires_grad <- "x"
ota <- OTProblem(a$detach(), # don't update a
pseudo_a)
otb <- OTProblem(b$detach(), # don't update b
pseudo_b)
otc <- OTProblem(c$detach(), # don't update c
pseudo_c)
otd <- OTProblem(d$detach(), # don't update d
pseudo_d)
and set up the arguments. For simplicity, we will set lambda = 0.1.
ota$setup_arguments(lambda = .1)
otb$setup_arguments(lambda = .1)
otc$setup_arguments(lambda = .1)
otd$setup_arguments(lambda = .1)
Then we solve for the barycenters.
# run separately at each site
ota$solve(torch_optim = torch::optim_rmsprop)
otb$solve(torch_optim = torch::optim_rmsprop)
otc$solve(torch_optim = torch::optim_rmsprop)
otd$solve(torch_optim = torch::optim_rmsprop)
Looking at the pseudo data in group B we can see that the pseudo-data is a much better approximation to B after optimization.
Then we send the pseudo-data back to our main site to create the overall pseudo-data
# send back to the main site and create overall problem
ot_overall <-
OTProblem(pseudo_a$detach(),
pseudo) +
OTProblem(pseudo_b$detach(),
pseudo) +
OTProblem(pseudo_c$detach(),
pseudo) +
OTProblem(pseudo_d$detach(),
pseudo)
ot_overall$setup_arguments(lambda = 0.1)
ot_overall$solve(torch_optim = torch::optim_rmsprop)
Now we have an average population to target at each site, which we can do like so:
# pass pseudo to each site then setup the problems again
ota2 <- OTProblem(a,
pseudo$detach())
otb2 <- OTProblem(b,
pseudo$detach())
otc2 <- OTProblem(c,
pseudo$detach())
otd2 <- OTProblem(d,
pseudo$detach())
all.problems <- list(ota2,
otb2,
otc2,
otd2)
# then we optimize the weights at each site separately.
for (prob in all.problems) {
prob$setup_arguments()
prob$solve(
torch_optim = torch::optim_lbfgs,
torch_args = list(line_search_fn = "strong_wolfe")
)
prob$choose_hyperparameters()
}
The final example is a situation where we may have covariate, treatment, and outcome data at one location and want to use it to infer effects in another population with only covariate data. Say we have a binary treatment at our source site and only moments available from the target site.
x_1 <- matrix(rnorm(128*2),128) +
matrix(c(-0.1,-0.1), 128, 2,byrow = TRUE)
x_2 <- matrix(rnorm(256*2), 256) +
matrix(c(0.1,0.1), 256, 2,byrow = TRUE)
target.data <- matrix(rnorm(512*2), 512, 2) * 0.5 +
matrix(c(0.1,-0.1), 512, 2, byrow = TRUE)
constructor.formula <- formula("~ 0 + . + I(V1^2) + I(V2^2)")
target.values <- colMeans(model.matrix(constructor.formula,
as.data.frame(target.data)))
m_1 <- Measure(x = x_1, adapt = "weights",
balance.functions = model.matrix(constructor.formula,
as.data.frame(x_1)),
target.values = target.values)
m_2 <- Measure(x = x_2, adapt = "weights",
balance.functions = model.matrix(constructor.formula,
as.data.frame(x_2)),
target.values = target.values)
ot_binary <- OTProblem(m_1, m_2)
In this case, we’d like the treatment groups to have the same distributions but also to match the first and second moments from our target site.
ot_binary$setup_arguments()
ot_binary$solve(torch_optim = torch::optim_lbfgs,
torch_args = list(line_search_fn = "strong_wolfe"))
ot_binary$choose_hyperparameters()
We now check to see how everything looks using the info() function.
info <- ot_binary$info()
names(info)
#> [1] "loss" "iterations"
#> [3] "balance.function.differences" "hyperparam.metrics"
We can see a variety of things like the metrics from the hyperparameter selection, iterations run, final loss, etc. We can also see how the balance functions are doing in terms of targeting the moments.
info$balance.function.differences
#> $`0x7fdd8cdf5660`
#> $`0x7fdd8cdf5660`$balance
#> torch_tensor
#> 0.001 *
#> -3.3433
#> 3.5927
#> -2.4325
#> 3.5927
#> [ CPUDoubleType{4} ][ grad_fn = <SubBackward0> ]
#>
#> $`0x7fdd8cdf5660`$delta
#> [1] 1e-04
#>
#>
#> $`0x7fdd9e7e9d60`
#> $`0x7fdd9e7e9d60`$balance
#> torch_tensor
#> 0.001 *
#> 1.8163
#> -0.3774
#> 1.8235
#> -1.8946
#> [ CPUDoubleType{4} ][ grad_fn = <SubBackward0> ]
#>
#> $`0x7fdd9e7e9d60`$delta
#> [1] 1e-04
It appears that all of our balance functions are less than the desired tolerance. Finally, the optimal transport distance between treatments 1 and 2 is also improved:
We have demonstrated a variety of examples here. Hopefully we have made it clear that you can also do regular optimal transport barycenters even in the case where causal inference isn’t the goal. You can even use the OTProblem to solve optimal transport problems when there are no weights or data to adapt.