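# R语言 多个模型的DCA曲线绘制并寻找模型间的净收益交点
# (R: plotting decision curve analysis (DCA) curves for multiple models and
#  finding the thresholds at which their net-benefit curves intersect.)
#
# 发布时间 2023-04-14 17:34:36  作者: 王姑娘呀~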

# library
library(ggplot2);library(reshape2);library(ggsci)

# Intersection points of DCA curves for multiple models

# Net benefit of the rule "treat when Prob >= threshold" at each threshold:
#   NB(pt) = TP / N - FP / N * pt / (1 - pt)
# Rows with a missing Label or Prob are ignored. Label is assumed to be coded
# 0/1 with 1 = event (TODO confirm for other codings).
.DCANetBenefit <- function(Label, Prob, Thresholds) {
  Keep <- !is.na(Label) & !is.na(Prob)
  IsEvent <- Label[Keep] == 1
  Prob <- Prob[Keep]
  N <- length(Prob)
  vapply(Thresholds, function(Pt) {
    Treated <- Prob >= Pt
    TP <- sum(Treated & IsEvent)
    FP <- sum(Treated & !IsEvent)
    TP / N - FP / N * Pt / (1 - Pt)
  }, numeric(1))
}

# For every pair of curves in NB (pairs taken in list order), collect the
# thresholds where the two curves differ by less than Tolerance. This is a
# "curves touch" test on the (rounded) values, not a sign-change test, so a
# pair that runs side by side (e.g. a model and TreatNone near 0) reports a
# run of consecutive thresholds. Prints each message and returns them all
# (NULL when no pair touches).
.DCAIntersections <- function(NB, Thresholds, Tolerance = 0.01) {
  Pairs <- utils::combn(length(NB), 2)
  Messages <- lapply(seq_len(ncol(Pairs)), function(k) {
    A <- Pairs[1, k]
    B <- Pairs[2, k]
    # Round the difference so floating-point noise (0.29 - 0.28 < 0.01) does
    # not turn a genuine 0.01 gap into a spurious intersection.
    Gap <- round(abs(NB[[A]] - NB[[B]]), 10)
    Hits <- Thresholds[which(Gap < Tolerance)]
    if (length(Hits) == 0) {
      return(NULL)
    }
    Verb <- if (length(Hits) > 1) " are " else " is "
    Res <- paste0("The intersection points of ", names(NB)[A], " and ",
                  names(NB)[B], Verb, paste0(Hits, collapse = ", "))
    print(Res)
    Res
  })
  print("Intersection calculation Over")
  unlist(Messages)
}

# Build the DCA line plot from a wide table with a "Threshold" column plus one
# net-benefit column per curve.
.DCAPlot <- function(DaTaNB, LegendPosition) {
  GGData <- melt(DaTaNB, id.vars = "Threshold")
  names(GGData)[2:3] <- c("Models", "Net_Benefit")
  YMax <- max(GGData$Net_Benefit, na.rm = TRUE)
  ggplot(data = GGData,
         aes(x = Threshold, y = Net_Benefit, group = Models)) +
    geom_line(aes(color = Models), lwd = 0.8) +
    scale_y_continuous(
      name = "Net Benefit",
      breaks = round(seq(-0.2, round(YMax, 1), by = 0.1), 2)
    ) +
    scale_x_continuous(breaks = seq(0, 1, by = 0.1)) +
    # Zoom instead of setting scale limits: curves that leave the window (e.g.
    # TreatAll at high thresholds) are clipped at the panel edge rather than
    # having their out-of-range points dropped.
    coord_cartesian(ylim = c(-0.2, YMax + 0.05)) +
    # lancet can be swapped for npg, nejm, jama, jco, ucscgb, d3, locuszoom,
    # tron or futurama (other ggsci palettes).
    scale_color_lancet() +
    theme_bw() +
    theme(
      panel.grid.minor = element_blank(),
      legend.position = LegendPosition,
      legend.text = element_text(size = 14, colour = "black"),
      legend.title = element_text(size = 14, colour = "black"),
      legend.background = element_blank(),
      legend.box.background = element_blank(),
      panel.border = element_rect(),
      axis.title = element_text(size = 14, colour = "black"),
      axis.text = element_text(size = 14, colour = "black")
    )
}

# Write Plot to a PDF file. The device is closed even if printing fails.
.DCASavePdf <- function(Plot, File) {
  grDevices::pdf(File, family = "Times", height = 6, width = 8)
  on.exit(grDevices::dev.off(), add = TRUE)
  print(Plot)
  invisible(File)
}

# Decision curve analysis (DCA) for several models on one subset of Data.
# Computes the net-benefit curve of each model plus the "treat all" and
# "treat none" references over thresholds 0.01..0.99. It then prints and
# returns the thresholds where pairs of curves meet, and saves the DCA plot as
# "<Savepath>/DCA curve in <Group>.pdf".
#
# Data:           data frame with an outcome column "Label" (assumed 0/1,
#                 1 = event), a "Group" column, and one predicted-probability
#                 column per model.
# ModelNames:     names of the probability columns to evaluate.
# Group:          value of Data$Group selecting the rows to analyse.
# LegendPosition: passed to ggplot2 theme(legend.position = ...).
# Savepath:       directory the PDF is written to.
# Returns a character vector of intersection messages (NULL if none).
DCAPoints <- function(Data, ModelNames, Group, LegendPosition = "right",
                      Savepath) {
  Missing <- setdiff(c("Label", "Group", ModelNames), names(Data))
  if (length(Missing) > 0) {
    stop("Columns not found in Data: ", paste(Missing, collapse = ", "),
         call. = FALSE)
  }
  Rows <- which(Data[["Group"]] == Group)
  DataAna <- Data[Rows, c("Label", ModelNames), drop = FALSE]
  Label <- DataAna[["Label"]]
  if (all(is.na(Label))) {
    stop("No labelled rows with Group == ", Group, call. = FALSE)
  }

  # Net benefit for each model and the two reference strategies.
  Pi <- seq(0.01, 0.99, 0.01)
  ModelNB <- lapply(ModelNames, function(Model) {
    .DCANetBenefit(Label, DataAna[[Model]], Pi)
  })
  # "Treat all" is the same rule with every probability set to 1 (>= any pt),
  # i.e. prevalence - (1 - prevalence) * pt / (1 - pt).
  TreatAll <- .DCANetBenefit(Label, rep(1, length(Label)), Pi)
  TreatNone <- rep(0, length(Pi))
  # No capping at max(TreatAll): correct net benefit never exceeds the
  # prevalence, and a good model may legitimately beat treat-all's maximum.
  NB <- lapply(c(ModelNB, list(TreatAll, TreatNone)), round, digits = 2)
  names(NB) <- c(ModelNames, "TreatAll", "TreatNone")

  InterSave <- .DCAIntersections(NB, Pi)

  # check.names = FALSE keeps model names (spaces etc.) intact for the legend.
  DaTaNB <- data.frame(Threshold = Pi, NB, check.names = FALSE)
  DCAPlot <- .DCAPlot(DaTaNB, LegendPosition)
  .DCASavePdf(DCAPlot,
              file.path(Savepath, paste0("DCA curve in ", Group, ".pdf")))
  InterSave
}

# Example. Assumes CalibrationData.csv has "Label" and "Group" columns and
# that columns 8-12 hold the models' predicted probabilities (TODO confirm).
Savepath <- "E:\\WorkMates\\Qxw"
DCAData <- read.csv(file.path(Savepath, "CalibrationData.csv"))
LegendPosition <- "right"
Res <- DCAPoints(DCAData, names(DCAData)[8:12], "Train", LegendPosition,
                 Savepath)