ADMIT论文理解

#! https://zhuanlan.zhihu.com/p/649301188

ADMIT论文理解

论文信息

标题:Generalization Bounds for Estimating Causal Effects of Continuous Treatments

发表于:NeurIPS 2022 CCF-A

研究问题和意义

主要研究的是连续治疗的因果效应,因为目前大多数研究主要关注于离散因果效应的研究,然而现实世界中影响因素通常是连续的。例如,研究药物的剂量对于疾病恢复的影响,其中药物的剂量是一个连续的值。

论文的主要贡献如下:

  1. 这是第一项为估计平均剂量反应函数(ADRF)提供泛化界限、并减轻选择偏差的研究;
  2. 提供了具有理论保证的 IPM 距离(用于衡量两个分布之间的距离)的离散近似;
  3. 提出了一种算法 ADMIT(Average Dose-response estiMatIon via re-weighTing schema),能够减少选择偏差,同时进行事实和反事实估计;
  4. 在连续治疗环境中进行了合成和半合成实验,证明了 ADMIT 的有效性优于现有基准。

问题定义

个体剂量反应函数:

平均剂量反应函数(average dose-response function,ADRF):

我们的目标是通过个体剂量反应函数来推导出剂量反应函数:

两个假设

来保证我们能够从个体剂量反应函数推倒出剂量反应函数

Ignorability

Conditional Independence Assumption (CIA)

在给定的情况下,潜在输出与治疗的分配相互独立 :

例子1

假设我们有 4 个司机和 8 个乘客,在未发券的情况下因果效应为 0;而在给 A 组发券时候,由于司机不足导致 B 组的乘客无法打到车,所以导致因果效应估计错误。

例子2

以营销活动投放来作为例子,假设随机将人群分到实验组和控制组,实验组用于正常投放营销活动,控制组不投放营销活动。

(实验一) 随机将50%的流量作为实验组(W=1),50%的流量作为控制组(W=0)。结果可以统计出来,两组的转化率分别为P(Y=1|W=1)=0.6和P(Y=1|W=0)=0.4,说明干预和转化有相关性,且相关差异为0.6-0.4=0.2。

(实验二) 如果将上述已经分好组的实验人群交换一下,实验组(W=1)的50%数据不投放营销活动,控制组(W=0)的50%数据投放营销活动。假设两组人群的流量完全随机分配,则实验结果会和实验一一样。

CIA 假设条件的意义:改变干预条件 W 并不改变实验结果,即干预 W 和输出 Y 独立,干预 W不影响潜在结果。

在 CIA 假设条件下,可以推断出在实验组中获得的潜在结果可以推广到全部人群中,同理,测试组的潜在结果也能推广到全部人群。

基于Ignorability我们可以得到

Positive

重叠性

个体使用每种治疗剂量的概率都不为0

理论分析

公式比较多,不想看推导的同学可以查看整体推导图理解基本思想,然后直接跳转到模型实现部分查看

整体推导图

基本思路:目标误差难以直接优化→证明目标误差存在上界→优化误差上界→间接优化目标误差

估计ADRF的误差

是估计值,但是真实值无法从数据中观测到

首先得出的估计方式为使用一个模型来估计:

该模型的观测数据损失函数可以写为:

这个式子表示在因果推断中,通过一个预测模型 对处理(Treatment)条件 下的因变量 进行预测时,对预测误差进行期望损失(expected loss)的计算。

  • :这是一个函数,表示对于给定的观测值 ,使用预测模型 对因变量进行预测时的期望损失。这里的 可以代表损失函数(loss function),它衡量预测值与真实值之间的差异,用于评估预测的准确性。
  • :这是一个条件期望(conditional expectation)的符号。在给定观测值 的条件下,对因变量 进行期望。
  • :这是损失函数,用于度量预测值 与真实值 之间的差异或误差。损失函数可以根据具体的问题和应用来选择,常见的损失函数包括均方误差(Mean Squared Error, MSE)、绝对误差(Mean Absolute Error, MAE)等。
    因此,整个式子 的含义是:在给定观测值 的情况下,使用预测模型 对处理条件为 的因变量 进行预测,并计算预测误差的期望损失。

定理1

为平方误差损失的时候:

根据柯西不等式:

可以推导出

根据偏差方差分解:

通过重要性采样重加权

由于X和T可能不相互独立,即存在选择偏差,例如,病人的年龄不同选择治疗方式的倾向也不同:

所以 通常不等于

那么不等于

等于的情况下为factual loss

不等于的情况下为counterfactual loss

重要性采样的主要思想是通过利用已知分布(称为重要性分布)来估计难以直接采样的目标分布的性质:

使用重要性采样,我们能够通过在的情况下通过重加权后的损失函数来优化

权重计算方式:

重要性采样存在一些问题

  1. 连续治疗下重要性采样条件密度难以估计,可能会导致问题被放大
  2. 大的采样权重可能会导致更大的估计方差

边际损失的泛化界限

选择偏见可以通过IPM距离来缓解,IPM的基本计算公式如下:

IPM(Integral Probability Metrics)是一种用于度量两个概率分布之间差异的方法。IPM 距离通过在函数空间中定义一个距离度量来衡量两个概率分布之间的相似性。给定两个概率分布 ,它们的 IPM 距离记为

IPM 距离的一般形式定义如下:

其中, 是一组函数类, 是从这个函数类中选取的函数, 分别是概率分布 在样本点 处的概率密度函数。

IPM 距离具有以下特性:

  1. 非负性:,当且仅当 时,等号成立。
  2. 对称性:,即交换两个概率分布的位置不影响 IPM 距离。
  3. 三角不等式:,即 IPM 距离满足三角不等式,类似于其他距离度量的性质。
    通过选择不同的函数类 ,可以得到不同的 IPM 距离。在实际应用中,常用的函数类包括 Lipschitz 函数、高斯核函数等。IPM 距离在统计学、机器学习和信息理论等领域中有广泛的应用,特别是在概率分布匹配和生成模型评估中具有重要作用。

引理1

引入重要性采样权重和IPM

反事实损失的边界

证明

re-weighted factual loss

counterfactual loss

定理2

损失存在上界

证明

采用Maximum Mean Discrepancy (MMD) metric 计算IPM:

其中表示一种距离度量,这里使用RBF kernel计算距离

RBF kernel 表示径向基函数(Radial Basis Function,RBF)核函数,也称为高斯核函数。核函数是在机器学习和统计中广泛使用的一种函数,用于度量数据点之间的相似性或内积。

径向基函数核函数是一种常用的核函数,它在将数据映射到高维特征空间后,在该空间中计算数据点之间的相似性。它的形式如下:

其中, 是输入样本, 表示欧几里德距离, 是控制核函数宽度的参数。当 较小时,核函数的值会在邻近的数据点之间迅速减小,而当 较大时,核函数的值会在更大范围内有较小的减小趋势。

径向基函数核函数具有平滑的性质,可以捕捉数据之间的局部相似性,并在一些机器学习算法(如支持向量机、核主成分分析等)中用于构建非线性的特征映射。这使得原本线性不可分的数据在特征空间中变得可分。

在一些机器学习库中,比如Scikit-Learn,RBF kernel 函数可以用于计算数据点之间的径向基函数核矩阵。

引理2

连续的IPM不好求 所以采用离散逼近的方式

引理3

证明

根据IPM 的三角定理

定理3

所以最终优化以及

方法

算法优化目标

通过推理得到的最终损失函数,其中第一项是通过重要性采样重新加权后的事实损失,第二项是不同治疗下的分布的差异,一定程度上表示选择偏差,最小化该项可以减少选择偏差。

模型架构图1


其中,每个 batch 中采样到的数据经过re-weight 网络预测的权重重加权后会与 batch 中的数据计算 IPM 距离,其中,re-weight 网络的目的就是通过预测权重使得分布间的 IPM 变小,从而减少选择偏差。Factual Loss就是观测数据的真实值和 Inference 网络预测值间的差距,同时通过重新加权的方式使得该损失。

模型架构图2

包括三个网络

  1. a representation network
  2. a re-weighting network
  3. an inference network

IPM计算示意图

image_axtv52-BS0

算法推理公式

算法伪代码

源码解析

模型推理

    def forward(self, x, t):
        hidden = self.hidden_features(x)
        hidden = self.drop_hidden(hidden)
        t_hidden = torch.cat((torch.unsqueeze(t, 1), hidden), 1)
        w = self.rwt(t_hidden)
        w = torch.sigmoid(w) * 2
        w = torch.exp(w) / torch.exp(w).sum() * w.shape[0]
        
        out = self.out(t_hidden)

        return out, w, hidden

回归损失

def rwt_regression_loss(w, y, y_pre):
    y_pre, w = y_pre.to('cpu'), w.to('cpu')
    return ((y_pre.squeeze() - y.squeeze())**2 * w.squeeze()).mean()

IPM损失

def IPM_loss(x, t, w, k=5, rbf_sigma=1):
    _, idx = torch.sort(t)
    xw = x * w
    sorted_x = x[idx]
    sorted_xw = xw[idx]
    split_x = torch.tensor_split(sorted_x, k)
    split_xw = torch.tensor_split(sorted_xw, k)
    loss = torch.zeros(k)
    for i in range(k):
        A = split_xw[i]
        tmp_loss = torch.zeros(k - 1)
        idx = 0
        for j in range(k):
            if i == j:
                continue
            B = split_x[j]
            partial_loss = calculate_mmd(A, B, rbf_sigma)
            tmp_loss[idx] = partial_loss
            idx += 1
        loss[i] = tmp_loss.max()
    return loss.mean()
    
 def calculate_mmd(A, B, rbf_sigma=1):
    Kaa = rbf_kernel(A, A, rbf_sigma)
    Kab = rbf_kernel(A, B, rbf_sigma)
    Kbb = rbf_kernel(B, B, rbf_sigma)
    mmd = Kaa.mean() - 2 * Kab.mean() + Kbb.mean()
    return mmd
    
def rbf_kernel(A, B, rbf_sigma=1):
    rbf_sigma = torch.tensor(rbf_sigma)
    return torch.exp(-pdist2sq(A, B) / torch.square(rbf_sigma) *.5)
    
def pdist2sq(A, B):
    # return pairwise euclidean difference matrix
    D = torch.sum((torch.unsqueeze(A, 1) - torch.unsqueeze(B, 0))**2, 2) 
    return D

运行合成数据代码结果

193 0.2953145
194 0.29614568
195 0.2945741
196 0.2949531
197 0.29519412
198 0.2924068
199 0.29387188
200 0.2961232
eval time cost 1.162
eval_mse 0.00599

实验结果

使用合成数据和半合成数据进行评估

合成数据的产生方式

均通过分布产生

半合成数据产生方式

是从真实数据采集的,根据分布产生

New是新闻数据5000news每个3477维度特征,其中表示新闻中单词的次数,治疗表示阅读时间,表示满意度

TCGA 9659个数据每个数据4000维特征,治疗表示药量,表示癌症复发概率

对比算法

  1. DRNet
  2. SCIGAN
  3. VCNet+EBCT
  4. GPS

评价标准

结果分析

  1. TCGA和News的维度很高因此说明ADMIT可以被运用于高纬问题
  2. ADMIT的效果最好

  1. 随着选择偏差的增加,ADMIT表现的一致性很强,说明能够有效缓解选择偏差

结论

  1. 提出了一种新颖的重新加权模式,以减轻连续治疗因果推理中选择偏差的影响
  2. 提供并证明了基于观测(重新加权)和反事实分布之间的 IPM 距离的 ADRF 估计的泛化误差界,以减轻选择偏差
  3. 提供了具有理论保证的IPM距离的离散近似

相关工作

CFRNet和TARNet

论文信息

标题:Estimating individual treatment effect: generalization bounds and algorithms

发表于:ICML 2017 CCF-A

知乎解读:

https://zhuanlan.zhihu.com/p/496194898

针对离散治疗空间

提出了

  • Treatment-Agnostic Representation Network(TARNet)
  • Counterfactual Regression Network(CFRNet)

模型架构

优化目标

其中是标准泛化误差,是不同策略组的协变量分布的距离,目标函数多出来的一项为模型复杂度的惩罚项。我们发现这三项前面都有一个权重系数,其中,是固定的,根据不同策略策略组的样本比例来对泛化误差进行加权求和,当策略组和控制组的样本数相同时。当时候为CFRNet,当时候为TARNet。

算法伪代码

DRNet

论文信息

标题:Learning Counterfactual Representations for Estimating Individual Dose-Response Curves

发表于:AAAI 2020 CCF-A

知乎解读:

https://zhuanlan.zhihu.com/p/448208546

针对连续治疗空间

提出了

  • Dose-Response Network(DRNet)

网络架构

把治疗的连续空间离散化分别用不同的网络头进行预测,类似于一种分层方法。

VCNet

论文信息

标题:VCNet and Functional Targeted Regularization fot Learning Causal Effects of Continuous Treatments

发表于:ICLR 2021

针对连续治疗空间

提出了

  • Varying Coefficient Network(VCNet)

网络架构

优化目标

Loss分为两个部分第一个部分是最小化预测值和观测值间的误差,第二个部分是最大化 的概率。

相当学习了一个倾向性得分→相当于在中间表示中增加了倾向性得分特征→更好的预测因果效应。

SCIGAN

论文信息

标题:GANITE: Estimation of Individualized Treatment Effects Using Generative Adversarial Net

使用了生成对抗网络

针对连续治疗空间

EBCT

entropy balancing for continuous treatments

使用了熵的概念

针对连续治疗空间

暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇