mlr3基础知识–生存分析的模型解释(survex包)优质

45次浏览 | 2024-07-18 16:10:06 更新
来源 :互联网
最佳经验

简要回答

序言

在机器学习中,模型解释是指尝试理解和解释模型的预测结果的过程。由于一些机器学习模型是黑盒模型,即难以理解其内部工作原理,因此模型解释变得尤为重要。模型解释的目标是使人们能够理解模型是如何做出特定预测的,以及预测结果的依据是什么。

survex包专注于生存分析领域,提供了用于解释生存分析模型的工具。它包括各种方法,例如基于树的方法和模型评估指标,帮助用户解释生存分析模型的预测结果和特征重要性。

survex包的使用原理与DALEX包类似。survex包为机器学习生存模型提供了与模型无关的解释。它基于DALEX软件包。如果您对可解释的机器学习不熟悉,请参考《解释模型分析》一书——survex中包含的大多数方法都是对EMA中描述的方法进行了扩展,并在DALEX中实现,但适用于具有功能输出的模型。

1. 准备示例

library(mlr3verse)
library(mlr3proba)
task =tsk("gbcs")
task$head()

     time status age estrg_recp grade hormone menopause nodes prog_recp size
1: 2282 0 38 105 3 1 1 5 141 18
2: 2006 0 52 14 1 1 1 1 78 20
3: 1456 1 47 89 2 1 1 1 422 30
4: 148 0 40 11 1 1 1 3 25 24
5: 1863 0 64 9 2 2 2 1 19 19
6: 1933 0 49 64 1 2 2 3 356 56

德国乳腺癌数据集,这个数据集包含了乳腺癌患者的临床特征和相关信息。time:观察时间或生存时间 status:事件状态 age:患者的年龄。estrg_recp:雌激素受体状态 grade:肿瘤分级 hormone:激素治疗情况 menopause:更年期状态 nodes:正性淋巴结的数量 prog_recp:孕激素受体状态 size:肿瘤大小(以毫米为单位)

set.seed(111)
split = partition(task) #分割数据

lrn_ranger = lrn("surv.ranger") #学习器

lrn_ranger$train(task, row_ids = split$train) #训练模型

2.构建解释器

library(survival)
library(survex)

# 测试数据的特征
credit_x = task$data(rows = split$test,
cols = task$feature_names)

# 测试数据中的任务目标
credit_y = task$data(rows = split$test,
cols = task$target_names)

ranger_explainer = explain(lrn_ranger,
data = credit_x,
y = Surv(credit_y$time,credit_y$status),
label = "ranger model") #解释器

  Preparation of a new explainer is initiated 
-> model label : ranger model
-> data : 226 rows 8 cols
-> target variable : 226 values ( 56 events and 170 censored )
-> times : 37 unique time points , min = 308.5 , max = 2531
-> times : ( generated from y as uniformly distributed survival quantiles based on Kaplan-Meier estimator )
-> predict function : predict_newdata()$crank will be used ( [33m default [39m )
-> predict survival function : predict_newdata()$distr$survival will be used ( [33m default [39m )
-> predict cumulative hazard function : predict_newdata()$distr$cumHazard will be used ( [33m default [39m )
-> model_info : package mlr3proba , ver. 0.5.9 , task survival ( [33m default [39m )
A new explainer has been created!

ranger_explainer

  Model label:  ranger model 
Model class: LearnerSurvRanger,LearnerSurv,Learner,R6
Data head :
age estrg_recp grade hormone menopause nodes prog_recp size
1: 38 105 3 1 1 5 141 18
2: 64 9 2 2 2 1 19 19

3.全局解释

3.1 评估模型的性能

perf_credit = model_performance(ranger_explainer)
perf_credit$result

  $`C-index`   #度量模型对生存时间的排序能力。越接近1越好
[1] 0.7699608
attr(,"loss_type")
[1] "risk-based"

$`Integrated C/D AUC`
[1] 0.6505562
attr(,"loss_type")
[1] "integrated"

$`Brier score`
308.5 369 392 425 462 464
0.004173946 0.017304819 0.025160017 0.033198868 0.037988601 0.041442632
534 569 622 668 711 740
0.048734555 0.054926604 0.061468128 0.071571486 0.077183340 0.083407742
777 901 919 978 1022 1044
0.088953440 0.094217604 0.099449164 0.105381540 0.110237268 0.109868125
1088 1154 1177 1180 1280 1298
0.113241186 0.115567089 0.117005850 0.122404880 0.129719570 0.131721290
1328 1377 1456 1521 1684 1774
0.139693925 0.145954202 0.151369430 0.154794009 0.165538786 0.165925747
1781 1931 1959 1990 2033 2450
0.173451857 0.179998478 0.177094735 0.184637948 0.173764049 0.211974561
2531
0.164138469
attr(,"loss_type")
[1] "time-dependent"

# Brier score度量模型的生存概率预测与实际观测值之间的差距。较低的brier 分数表示模型的生存

概率预测与实际结果更为一致。

$`Integrated Brier score`
[1] 0.1329769
attr(,"loss_type")
[1] "integrated"

$`C/D AUC`
[1] 0.5081539 0.5463002 0.6049115 0.6398810 0.7152913 0.7279412 0.6976923
[8] 0.7143357 0.7101586 0.6760918 0.6844086 0.6789560 0.6833554 0.6767014
[15] 0.6669560 0.6504316 0.6626067 0.6773684 0.6816518 0.6788321 0.6846123
[22] 0.6786058 0.6610169 0.6642387 0.6647870 0.6526564 0.6514029 0.6554522
[29] 0.6323679 0.6300516 0.6278521 0.6257466 0.6419355 0.6557987 0.6991678
[36] 0.5696970 0.5892857
attr(,"loss_type")

[1] "time-dependent"


C/D AUC是一个关于时间的面积下的曲线,用于评估生存模型的时间相关性。

plot(perf_credit)#不同时间点对应的Brier分数和C/D AUC值

plot(perf_credit,metrics_type="scalar")

3.2 model_parts()计算特征重要性分数

model_parts()函数中的type参数允许您指定如何计算特征的重要性,通过损失函数 ( type = “difference” )、熵的差值( type = “ratio” ) 或不进行任何变换 ( type = “raw” )。

gbm_effect = model_parts(ranger_explainer)
plot(gbm_effect) #绘制特征重要性

在1300天之前,prog_recp特征的排列会导致损失函数的最大增加,其次是nodes;在1300天之后,nodes特征的排列会导致损失函数的最大增加,其次是prog_recp。说明这两个特征对模型进行预测最为重要。

3.3 model_profile() PD 图计算特征效应 ,默认情况下绘制为 PD 图

model_profile()函数的type参数还允许计算边际剖面(type = "conditional" )和累积局部剖面(type = "accumulated" );

这些图显示将一个变量设置为不同值会如何影响模型的预测。

gbm_profiles = model_profile(ranger_explainer)

  Aggregating predictions.. Progress: 76%. Estimated remaining time: 9 seconds.
Aggregating predictions.. Progress: 81%. Estimated remaining time: 7 seconds.

plot(gbm_profiles)

从图中可以看到,grade 、hormone、menopause、age四个特征对于模型不是很重要,绘制的分布图非常薄,几乎重叠,意味着无论这些变量取什么值,总体预测都会相似。而prog_recp和nodes特征的带非常宽,意味着即使它们的值发生微小变化也会导致预测生存函数的较大差异;其中nodes特征的值越大生存率越低,prog_recp值越小生存率越低。

4.局部解释

4.1 计算模型预测开始

Charlie = credit_x[180, ]  #选择数据点

#output_type = "survival"指定了模型输出为生存函数(survival function)或生存概率;生存函数描述了一个个体在某一时间点存活下来的概率,而生存概率描述了在给定时间段内存活的概率

predict(ranger_explainer, Charlie,output_type = "survival")

           308.5       369       392       425       462       464       534
[1,] 0.9809668 0.9809668 0.9809668 0.9809668 0.9789416 0.9789416 0.9615744
569 622 668 711 740 777 901
[1,] 0.9529591 0.8951028 0.8441129 0.8441129 0.8145368 0.8145368 0.8122322
919 978 1022 1044 1088 1154 1177
[1,] 0.8122322 0.8122322 0.8104742 0.8104742 0.7956562 0.7956562 0.7940665
1180 1280 1298 1328 1377 1456 1521
[1,] 0.7940665 0.7477239 0.7150605 0.7150605 0.7150605 0.7117313 0.7095994
1684 1774 1781 1931 1959 1990 2033
[1,] 0.6735683 0.6728951 0.6728951 0.6578582 0.6578582 0.6578582 0.616663
2450 2531
[1,] 0.5918911 0.5918911

predict(ranger_explainer, Charlie,output_type = "chf")#返回累积风险曲线(CHF),CHF 是指在给定时间点 t 前发生事件的概率,通常用于描述生存时间随时间的累积风险

            308.5        369        392        425        462        464
[1,] 0.01921667 0.01921667 0.01921667 0.01921667 0.02128333 0.02128333
534 569 622 668 711 740 777
[1,] 0.03918333 0.04818333 0.1108167 0.169469 0.169469 0.2051357 0.2051357
901 919 978 1022 1044 1088 1154
[1,] 0.207969 0.207969 0.207969 0.2101357 0.2101357 0.2285881 0.2285881
1177 1180 1280 1298 1328 1377 1456
[1,] 0.2305881 0.2305881 0.2907214 0.3353881 0.3353881 0.3353881 0.3400548
1521 1684 1774 1781 1931 1959 1990
[1,] 0.3430548 0.3951659 0.3961659 0.3961659 0.4187659 0.4187659 0.4187659
2033 2450 2531
[1,] 0.4834325 0.5244325 0.5244325

predict(ranger_explainer, Charlie,output_type = "risk") #风险(risk)的输出类型;风险通常是指在未来一段时间内发生事件的可能性

  [1] 118.8595

4.2 predict_parts()绘制分解图

它将模型的预测分解为可归因于不同解释变量的贡献,order参数允许您指示所选特征的顺序 可以通过两种方法完成:SurvSHAP(t)和SurvLIME

plot(predict_parts(ranger_explainer, new_observation = Charlie))#SurvSHAP(t)

  Aggregating predictions.. Progress: 12%. Estimated remaining time: 3 minutes, 43 seconds.
Aggregating predictions.. Progress: 25%. Estimated remaining time: 3 minutes, 9 seconds.
Aggregating predictions.. Progress: 37%. Estimated remaining time: 2 minutes, 36 seconds.
Aggregating predictions.. Progress: 51%. Estimated remaining time: 2 minutes, 1 seconds.
Aggregating predictions.. Progress: 63%. Estimated remaining time: 1 minute, 30 seconds.
Aggregating predictions.. Progress: 76%. Estimated remaining time: 58 seconds.
Aggregating predictions.. Progress: 89%. Estimated remaining time: 26 seconds.

如图所示,size的值增加这个患者的生存几率,prog_recp则降低了这个患者的生存几率。

plot(predict_parts(ranger_explainer, 
new_observation = Charlie,
type="survlime"))#SurvLIME

图的左侧显示哪些变量最重要,它们的值是增加还是降低生存几率;右侧显示黑盒模型预测与找到的替代模型预测,两条曲线越接近,解释就越准确。

4.3 predict_profile() 绘制 ICE 曲线

plot(predict_profile(ranger_explainer,  credit_x[176, ]))

它们显示当我们一次改变一个变量的值时,预测如何改变。

参考资料

Applied Machine Learning Using mlr3 in R

survex: an R package for explaining machine learning survival models. arXiv preprint arXiv:2308.16113, 2023.

本文地址:https://www.huajie.net.cn/qkl/44376.html

发布于 2024-07-18 16:10:06
收藏
分享
海报
45
上一篇:二宝郭宏才都准备挖中本聪(btcs)? 下一篇:真假郭宏才币市无情二宝有意放下牛腰力推FT逆市上扬,将引领交易所新生态?比特币价格重挫9% 日本6家交易所曝洗钱漏洞

推荐阅读

0 条评论

本站已关闭游客评论,请登录或者注册后再评论吧~

忘记密码?

图形验证码