#! https://zhuanlan.zhihu.com/p/649301188
ADMIT论文理解
论文信息
标题:Generalization Bounds for Estimating Causal Effects of Continuous Treatments
发表于:NeurIPS 2022 CCF-A
研究问题和意义
主要研究的是连续治疗的因果效应,因为目前大多数研究主要关注于离散因果效应的研究,然而现实世界中影响因素通常是连续的。例如,研究药物的剂量对于疾病恢复的影响,其中药物的剂量是一个连续的值。
论文的主要贡献如下:
- 这是第一项为估计平均剂量反应函数(ADRF)提供泛化界限、并减轻选择偏差的研究;
- 提供了具有理论保证的 IPM 距离(用于衡量两个分布之间的距离)的离散近似;
- 提出了一种算法 ADMIT(Average Dose-response estiMatIon via re-weighTing schema),能够减少选择偏差,同时进行事实和反事实估计;
- 在连续治疗环境中进行了合成和半合成实验,证明了 ADMIT 的有效性优于现有基准。
问题定义
个体剂量反应函数:
平均剂量反应函数(average dose-response function,ADRF):
我们的目标是通过个体剂量反应函数来推导出剂量反应函数:
两个假设
来保证我们能够从个体剂量反应函数推倒出剂量反应函数
Ignorability
Conditional Independence Assumption (CIA)
在给定
例子1

假设我们有 4 个司机和 8 个乘客,在未发券的情况下因果效应为 0;而在给 A 组发券时候,由于司机不足导致 B 组的乘客无法打到车,所以导致因果效应估计错误。
例子2
以营销活动投放来作为例子,假设随机将人群分到实验组和控制组,实验组用于正常投放营销活动,控制组不投放营销活动。
(实验一) 随机将50%的流量作为实验组(W=1),50%的流量作为控制组(W=0)。结果可以统计出来,两组的转化率分别为P(Y=1|W=1)=0.6和P(Y=1|W=0)=0.4,说明干预和转化有相关性,且相关差异为0.6-0.4=0.2。
(实验二) 如果将上述已经分好组的实验人群交换一下,实验组(W=1)的50%数据不投放营销活动,控制组(W=0)的50%数据投放营销活动。假设两组人群的流量完全随机分配,则实验结果会和实验一一样。
CIA 假设条件的意义:改变干预条件 W 并不改变实验结果,即干预 W 和输出 Y 独立,干预 W不影响潜在结果。
在 CIA 假设条件下,可以推断出在实验组中获得的潜在结果可以推广到全部人群中,同理,测试组的潜在结果也能推广到全部人群。
基于Ignorability我们可以得到
Positive
重叠性
个体使用每种治疗剂量的概率都不为0
理论分析
公式比较多,不想看推导的同学可以查看整体推导图理解基本思想,然后直接跳转到模型实现部分查看
整体推导图
基本思路:目标误差难以直接优化→证明目标误差存在上界→优化误差上界→间接优化目标误差

估计ADRF的误差
首先得出
该模型的观测数据损失函数可以写为:
这个式子表示在因果推断中,通过一个预测模型
:这是一个函数,表示对于给定的观测值 ,使用预测模型 对因变量进行预测时的期望损失。这里的 可以代表损失函数(loss function),它衡量预测值与真实值之间的差异,用于评估预测的准确性。 :这是一个条件期望(conditional expectation)的符号。在给定观测值 的条件下,对因变量 进行期望。 :这是损失函数,用于度量预测值 与真实值 之间的差异或误差。损失函数可以根据具体的问题和应用来选择,常见的损失函数包括均方误差(Mean Squared Error, MSE)、绝对误差(Mean Absolute Error, MAE)等。
因此,整个式子的含义是:在给定观测值 的情况下,使用预测模型 对处理条件为 的因变量 进行预测,并计算预测误差的期望损失。
定理1
当
根据柯西不等式:
可以推导出
根据偏差方差分解:
通过重要性采样重加权
由于X和T可能不相互独立,即存在选择偏差,例如,病人的年龄不同选择治疗方式的倾向也不同:

所以
那么
重要性采样的主要思想是通过利用已知分布(称为重要性分布)来估计难以直接采样的目标分布的性质:
使用重要性采样,我们能够通过在
权重计算方式:
重要性采样存在一些问题
- 连续治疗下重要性采样条件密度难以估计,可能会导致问题被放大
- 大的采样权重可能会导致更大的估计方差
边际损失的泛化界限
选择偏见可以通过IPM距离来缓解,IPM的基本计算公式如下:
IPM(Integral Probability Metrics)是一种用于度量两个概率分布之间差异的方法。IPM 距离通过在函数空间中定义一个距离度量来衡量两个概率分布之间的相似性。给定两个概率分布
IPM 距离的一般形式定义如下:
其中,
IPM 距离具有以下特性:
- 非负性:
,当且仅当 时,等号成立。 - 对称性:
,即交换两个概率分布的位置不影响 IPM 距离。 - 三角不等式:
,即 IPM 距离满足三角不等式,类似于其他距离度量的性质。
通过选择不同的函数类,可以得到不同的 IPM 距离。在实际应用中,常用的函数类包括 Lipschitz 函数、高斯核函数等。IPM 距离在统计学、机器学习和信息理论等领域中有广泛的应用,特别是在概率分布匹配和生成模型评估中具有重要作用。
引理1
引入重要性采样权重和IPM
反事实损失的边界
证明
re-weighted factual loss
counterfactual loss
定理2
损失存在上界
证明
采用Maximum Mean Discrepancy (MMD) metric 计算IPM:
其中
RBF kernel 表示径向基函数(Radial Basis Function,RBF)核函数,也称为高斯核函数。核函数是在机器学习和统计中广泛使用的一种函数,用于度量数据点之间的相似性或内积。
径向基函数核函数是一种常用的核函数,它在将数据映射到高维特征空间后,在该空间中计算数据点之间的相似性。它的形式如下:
其中,
径向基函数核函数具有平滑的性质,可以捕捉数据之间的局部相似性,并在一些机器学习算法(如支持向量机、核主成分分析等)中用于构建非线性的特征映射。这使得原本线性不可分的数据在特征空间中变得可分。
在一些机器学习库中,比如Scikit-Learn,RBF kernel 函数可以用于计算数据点之间的径向基函数核矩阵。
引理2
连续的IPM不好求 所以采用离散逼近的方式
引理3
证明
根据IPM 的三角定理
定理3
所以最终优化
方法
算法优化目标
通过推理得到的最终损失函数,其中第一项是通过重要性采样重新加权后的事实损失,第二项是不同治疗下的
模型架构图1

其中,每个 batch 中采样到的数据
模型架构图2

包括三个网络
- a representation network
- a re-weighting network
- an inference network
IPM计算示意图

算法推理公式
算法伪代码

源码解析
模型推理
def forward(self, x, t):
hidden = self.hidden_features(x)
hidden = self.drop_hidden(hidden)
t_hidden = torch.cat((torch.unsqueeze(t, 1), hidden), 1)
w = self.rwt(t_hidden)
w = torch.sigmoid(w) * 2
w = torch.exp(w) / torch.exp(w).sum() * w.shape[0]
out = self.out(t_hidden)
return out, w, hidden
回归损失
def rwt_regression_loss(w, y, y_pre):
y_pre, w = y_pre.to('cpu'), w.to('cpu')
return ((y_pre.squeeze() - y.squeeze())**2 * w.squeeze()).mean()
IPM损失
def IPM_loss(x, t, w, k=5, rbf_sigma=1):
_, idx = torch.sort(t)
xw = x * w
sorted_x = x[idx]
sorted_xw = xw[idx]
split_x = torch.tensor_split(sorted_x, k)
split_xw = torch.tensor_split(sorted_xw, k)
loss = torch.zeros(k)
for i in range(k):
A = split_xw[i]
tmp_loss = torch.zeros(k - 1)
idx = 0
for j in range(k):
if i == j:
continue
B = split_x[j]
partial_loss = calculate_mmd(A, B, rbf_sigma)
tmp_loss[idx] = partial_loss
idx += 1
loss[i] = tmp_loss.max()
return loss.mean()
def calculate_mmd(A, B, rbf_sigma=1):
Kaa = rbf_kernel(A, A, rbf_sigma)
Kab = rbf_kernel(A, B, rbf_sigma)
Kbb = rbf_kernel(B, B, rbf_sigma)
mmd = Kaa.mean() - 2 * Kab.mean() + Kbb.mean()
return mmd
def rbf_kernel(A, B, rbf_sigma=1):
rbf_sigma = torch.tensor(rbf_sigma)
return torch.exp(-pdist2sq(A, B) / torch.square(rbf_sigma) *.5)
def pdist2sq(A, B):
# return pairwise euclidean difference matrix
D = torch.sum((torch.unsqueeze(A, 1) - torch.unsqueeze(B, 0))**2, 2)
return D
运行合成数据代码结果
193 0.2953145
194 0.29614568
195 0.2945741
196 0.2949531
197 0.29519412
198 0.2924068
199 0.29387188
200 0.2961232
eval time cost 1.162
eval_mse 0.00599
实验结果
使用合成数据和半合成数据进行评估
合成数据的产生方式
半合成数据产生方式
New是新闻数据5000news每个3477维度特征,其中
TCGA 9659个数据每个数据4000维特征,治疗
对比算法
- DRNet
- SCIGAN
- VCNet+EBCT
- GPS
评价标准
结果分析

- TCGA和News的维度很高因此说明ADMIT可以被运用于高纬问题
- ADMIT的效果最好

- 随着选择偏差的增加,ADMIT表现的一致性很强,说明能够有效缓解选择偏差
结论
- 提出了一种新颖的重新加权模式,以减轻连续治疗因果推理中选择偏差的影响
- 提供并证明了基于观测(重新加权)和反事实分布之间的 IPM 距离的 ADRF 估计的泛化误差界,以减轻选择偏差
- 提供了具有理论保证的IPM距离的离散近似
相关工作
CFRNet和TARNet
论文信息
标题:Estimating individual treatment effect: generalization bounds and algorithms
发表于:ICML 2017 CCF-A
知乎解读:
https://zhuanlan.zhihu.com/p/496194898
针对离散治疗空间
提出了
- Treatment-Agnostic Representation Network(TARNet)
- Counterfactual Regression Network(CFRNet)
模型架构

优化目标
其中
算法伪代码

DRNet
论文信息
标题:Learning Counterfactual Representations for Estimating Individual Dose-Response Curves
发表于:AAAI 2020 CCF-A
知乎解读:
https://zhuanlan.zhihu.com/p/448208546
针对连续治疗空间
提出了
- Dose-Response Network(DRNet)
网络架构

把治疗的连续空间离散化分别用不同的网络头进行预测,类似于一种分层方法。
VCNet
论文信息
标题:VCNet and Functional Targeted Regularization fot Learning Causal Effects of Continuous Treatments
发表于:ICLR 2021
针对连续治疗空间
提出了
- Varying Coefficient Network(VCNet)
网络架构

优化目标
Loss分为两个部分第一个部分是最小化预测值和观测值间的误差,第二个部分是最大化
相当学习了一个倾向性得分→相当于在中间表示
SCIGAN
论文信息
标题:GANITE: Estimation of Individualized Treatment Effects Using Generative Adversarial Net
使用了生成对抗网络
针对连续治疗空间
EBCT
entropy balancing for continuous treatments
使用了熵的概念
针对连续治疗空间