Inteligencia Artificial

Modelo de Predicción de casos de gripe A (H7N9) basado en Bosques Aleatorios

El objetivo de este estudio es presentar un modelo de predicción de casos en aprendizaje automático basado en Bosques Aleatorios (Random Forests) que permita predecir el número de pacientes fallecidos y recuperados que han contraído el virus de la gripe A (H7N9) en China durante 2013. Para ello utilizaremos un dataset con información de 134 pacientes, de los cuales se sabe que 31 han fallecido, 46 se han recuperado y 57 tienen un pronóstico desconocido. Para el desarrollo de este modelo se utilizó el lenguaje de programación R, el software RStudio y librerías, entre las que destacan ggplot2 y caret.

Introduccion

El virus de la gripe A(H7N9) forma parte de un subgrupo de virus gripales que normalmente circulan en las aves. Ahora está produciendo infecciones seres humanos, un fenómeno que hasta hace poco no se había observado. La información existente sobre el alcance de la enfermedad causada por este virus y sobre la fuente de exposición es escasa. La enfermedad es preocupante porque ha sido grave en la mayoría de los casos. Por el momento no hay indicios de que pueda transmitirse de persona a persona, pero se están investigando activamente las vías de transmisión tanto de animales a personas como de persona a persona (Organización Mundial de la Salud, 2017). La minería de datos es un campo de la estadística y la informática referido al proceso que intenta descubrir patrones en grandes volúmenes de conjuntos de datos. Utiliza los métodos de la inteligencia artificial, el aprendizaje automático, la estadística y los sistemas de bases de datos. El objetivo general del proceso de minería de datos consiste en extraer información de un conjunto de datos y transformarla en una estructura comprensible para su uso posterior. Mediante los métodos predictivos utilizados en la minería de datos, es posible obtener previsiones de enfermedades víricas en grupos de población definidos.

Modelo de Predicción de casos de gripe H1N1
Virus de la gripe. CDC, Unsplash

Caso de Estudio

El objetivo general de esta investigación es obtener un modelo óptimo que permita generar previsiones de mortalidad y recuperación en pacientes portadores del virus de la gripe A-H7N9. Los objetivos específicos son los siguientes:

  • Realizar un análisis exploratorio de los datos basado en el estudio de algunas medidas de tendencia central (media, mediana, y cuartiles)
  • Desarrollar, entrenar y probar un modelo de bosque aleatorio.

Predicción de casos: Metodología

A través del software R Studio, se realiza un análisis exploratorio de datos, basado en el estudio de medidas de tendencia central (media, mediana y cuartiles). Las medidas de tendencia central son medidas estadísticas que tratan de resumir en un único valor a un conjunto de valores. Representan un centro en el que se encuentra el conjunto de datos. (Quevedo, 2011).

Posteriormente, a través de la librería «caret» de R, se construyen y evalúan los modelos a partir de los métodos predictivos de bosque aleatorio.

Algunos parámetros generados por los modelos utilizados son los siguientes:

  • Mtry: el algoritmo seleccionará el número mtry de predictores para intentar una división para la clasificación al construir un árbol de clasificación.
  • Accuracy: la precisión del predictor se refiere a lo bien que un predictor dado puede adivinar el valor del atributo predicho para un nuevo dato.
  • Kappa: es una métrica que compara una precisión observada con una precisión esperada (probabilidad aleatoria). Se utiliza no sólo para evaluar un único clasificador, sino también para evaluar clasificadores entre sí. También tiene en cuenta el azar (según un clasificador aleatorio), lo que suele significar que es menos engañoso que utilizar simplemente la precisión como métrica. Landis y Koch consideran en sus investigaciones los valores 0-0,20 como leves, 0,21-0,40 como regulares, 0,41-0,60 como moderados, 0,61-0,80 como sustanciales y 0,81-1 como casi perfectos.

Kappa = $\frac{observed precision – expected precision}{1 – expected precision}$

El método Random Forest es una combinación de árboles predictores tal que cada árbol depende de los valores de un vector aleatorio probado independientemente y con la misma distribución para cada uno de ellos. Es una modificación sustancial del bagging que construye una larga colección de árboles no correlacionados y luego los promedia. (Breimann L, 2001).

Aplicación del Modelo de Predicción de casos

Análisis exploratorio de los datos

El archivo a analizar es un .csv con los datos de 134 pacientes con virus h7n9, de los cuales se sabe que 31 han fallecido, 46 se han recuperado y 57 tienen pronóstico desconocido. Las variables que componen el conjunto de datos son las siguientes:

  • case_id: Identificador del paciente.
  • outcome: Pronóstico si lo hay (recuperado o fallecido).
  • age: Edad del individuo.
  • male: Género (1 = masculino, 0 = femenino)
  • hospital: Dato booleano que indica si el paciente ha sido hospitalizado
  • days_to_hospital: número de días transcurridos entre el inicio de la enfermedad y la hospitalización
  • days_to_outcome: número de días transcurridos entre el inicio y el final de la enfermedad.
  • early_outcome: Indica si la enfermedad ha durado menos que la media del conjunto de datos.
  • Jiangsu, Shanghai, Zhejiang, Otros: Variables booleanas que indican el lugar de origen del paciente.

Para este análisis, empezamos importando las bibliotecas que vamos a utilizar.

In [ ]:

library(dplyr)
library(readr)
library(tidyr)
library(ggplot2)
library(caret)
library(gbm)
library(rpart)
library(rattle)
library(rpart.plot)
library(RColorBrewer)

Aquí podemos ver el conjunto de datos completo.

In [15]:

h7n9 <- read.csv("../input/chinah7n9/h7n9.csv")
h7n9
case_idoutcomeagemalehospitaldays_to_hospitaldays_to_outcomeearly_outcomeJiangsuOtherShanghaiZhejiang
<chr><chr><int><int><int><int><int><int><int><int><int><int>
case_1Death581041310010
case_2Death71141110010
case_3Death1101103110100
case_4NA180184601000
case_5Recover2001115701000
case_6Death90173611000
case_7Death541192011000
case_8Death1411112010001
case_9NA391101800001
case_10Death20114610010
case_11Death36112610001
case_12Death24006710010
case_13Death390131210010
case_14Recover151041010010
case_15NA3400113811000
case_16NA511032001000
case_17Death461161410010
case_18Recover381142010010
case_19Death311156700010
case_20Recover271142210100
case_21Recover391112310010
case_22NA561141701000
case_23Recover50104601000
case_24Death36106610010
case_25Recover351103500010
case_26Death491141110010
case_27Recover2301273710001
case_28NA51106610001
case_29Recover480143200010
case_30Recover530062300010
case_107Death611072100100
case_108NA551032200001
case_109NA35100800001
case_110NA251143200100
case_111Death281042200100
case_112NA411011100100
case_113NA330071310001
case_114NA220010800001
case_115NA141071000001
case_116Recover371152100100
case_117Recover480042800100
case_118NA211065701000
case_119Recover121062601000
case_120NA331033801000
case_121Death360153000100
case_122NA141051800001
case_123Death261171600100
case_124Recover521071700100
case_125Recover80002000100
case_126NA5211103100100
case_127Recover15105900100
case_128Recover301172200100
case_129Recover411142800100
case_130NA411112000100
case_131Recover60111210100
case_132NA510102400100
case_133Recover32110200100
case_134Recover2111701000
case_135Death340133200100
case_136NA230111300100

A continuación se ofrece un resumen de los datos clasificados por edades:

In [16]:

summary(h7n9$age)
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
   2.00   20.25   34.00   32.46   44.75   61.00 

Construyamos algunos gráficos para obtener más información de los datos.

In [17]:

plot(density(h7n9$age), 
     main = "Density histogram",
     xlab = "Age (years old)",
     ylab = "Density")

In [18]:

ggplot(h7n9, aes(age)) + geom_density(aes(fill=outcome), alpha=1/3)

In [19]:

summary(h7n9['age'])
boxplot(h7n9['age'], main = "Distribution of patients by age",
        xlab = "Age",
        ylab = "Patients",
        col = "orange",
        border = "brown",
        horizontal = TRUE,
        notch = TRUE)
      age       
 Min.   : 2.00  
 1st Qu.:20.25  
 Median :34.00  
 Mean   :32.46  
 3rd Qu.:44.75  
 Max.   :61.00  

In [20]:

summary(h7n9[h7n9$outcome=='Death', 'age'])
boxplot(h7n9[h7n9$outcome=='Death', 'age'], main = "Distribution of deceased patients by age",
        xlab = "Age",
        ylab = "Patients",
        col = "orange",
        border = "brown",
        horizontal = TRUE,
        notch = TRUE)
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max.    NA's 
   7.00   27.50   36.00   37.19   49.00   61.00      57 

In [21]:

summary(h7n9[h7n9$outcome=='Recover', 'age'])
boxplot(h7n9[h7n9$outcome=='Recover', 'age'], main = "Distribution of recovered patients by age",
        xlab = "Age",
        ylab = "Patients",
        col = "orange",
        border = "brown",
        horizontal = TRUE,
        notch = TRUE)
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max.    NA's 
   2.00   13.25   26.00   26.89   39.75   60.00      57 

In [22]:

summary(h7n9[is.na(h7n9$outcome), 'age'])
boxplot(h7n9[is.na(h7n9$outcome), 'age'], main = "Distribution of patients without prognosis by age",
        xlab = "Age",
        ylab = "Patients",
        col = "orange",
        border = "brown",
        horizontal = TRUE,
        notch = TRUE)
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
   6.00   26.00   35.00   34.39   44.00   56.00 

En la distribución de todos los pacientes, la edad mínima es de 2 años, la edad media es de 32,46 años y la edad máxima es de 61 años.

En la distribución de los pacientes fallecidos, la edad mínima es de 7 años, la edad media es de 37,19 años y la edad máxima es de 61 años.

En la distribución de pacientes recuperados, la edad mínima es de 2 años, la edad media es de 26,89 años y la edad máxima es de 60 años.

En la distribución de pacientes con pronóstico desconocido, la edad mínima es de 6 años, la media de edad es de 34,39 años y la edad máxima es de 56 años.

Podemos observar que la edad media de los pacientes fallecidos es mayor que la edad media de los pacientes recuperados, por lo que podemos inferir que la edad es una variable importante a la hora de establecer modelos predictivos sobre este conjunto de datos. Para llegar a esta conclusión, no es necesario transformar los datos porque no se está utilizando ningún método de análisis paramétrico para ello (regresión, t de student, correlación, ANOVA, etc).

Antes de aplicar los métodos predictivos, dividimos los datos en grupos de entrenamiento y de prueba. Los datos de prueba se componen de los 57 casos con pronóstico desconocido. Los datos de entrenamiento se dividen para validar los modelos: El 70% de los datos de entrenamiento se conservarán para la construcción del modelo y el 30% restante se utilizará para la prueba del modelo. Esta proporción se utilizó debido a que es importante utilizar la mayor parte de los datos para entrenar el modelo, y al mismo tiempo dejar una proporción significativa para ejecutar las pruebas.

In [23]:

unknown_index <- which(is.na(h7n9$outcome))
unknown_data = h7n9[unknown_index, ]
train_data <- h7n9[-unknown_index, ][,-1]

set.seed(1275)
val_index <- createDataPartition(train_data$outcome, p = 0.7, list=FALSE) # training data indices
val_train_data <- train_data[val_index, ] # training data
val_test_data  <- train_data[-val_index, ] # test data

Bosque Aleatorio

Este modelo tiene una precisión del 76,73% con los siguientes parámetros:

  • mtry = 10
  • kappa = 0.4970

Obteniendo la matriz de confusión a partir de la predicción realizada, se registraron 14 resultados acertados y 8 erróneos, lo que representa una precisión del 63,64%. La variable más importante para el modelo es la edad.

In [24]:

model_rf <- caret::train(outcome ~ .,
                         data = val_train_data,
                         method = "rf",
                         preProcess = NULL,
                         trControl = trainControl(method = "repeatedcv", number = 10, repeats = 10, verboseIter = FALSE))

model_rf
Random Forest 

55 samples
10 predictors
 2 classes: 'Death', 'Recover' 

No pre-processing
Resampling: Cross-Validated (10 fold, repeated 10 times) 
Summary of sample sizes: 49, 50, 50, 49, 49, 50, ... 
Resampling results across tuning parameters:

  mtry  Accuracy   Kappa    
   2    0.6804762  0.2994427
   6    0.7564286  0.4744705
  10    0.7585238  0.4872897

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was mtry = 10.

In [25]:

confusionMatrix(predict(model_rf, val_test_data), as.factor(val_test_data$outcome))
Confusion Matrix and Statistics

          Reference
Prediction Death Recover
   Death       3       2
   Recover     6      11

               Accuracy : 0.6364         
                 95% CI : (0.4066, 0.828)
    No Information Rate : 0.5909         
    P-Value [Acc > NIR] : 0.4195         

                  Kappa : 0.1927         

 Mcnemar's Test P-Value : 0.2888         

            Sensitivity : 0.3333         
            Specificity : 0.8462         
         Pos Pred Value : 0.6000         
         Neg Pred Value : 0.6471         
             Prevalence : 0.4091         
         Detection Rate : 0.1364         
   Detection Prevalence : 0.2273         
      Balanced Accuracy : 0.5897         

       'Positive' Class : Death          
                                         

In [26]:

varImp(model_rf, scale=TRUE) # Importance of the variable
varImp(model_rf, scale=TRUE) %>% plot()
rf variable importance

                  Overall
age              100.0000
days_to_outcome   35.4028
days_to_hospital  32.1660
early_outcome      8.8473
Other              4.9845
male               2.6046
hospital           2.5463
Shanghai           0.7959
Zhejiang           0.1249
Jiangsu            0.0000

In [27]:

predict(model_rf, newdata = unknown_data)  # Prediction

new_h7n9 = unknown_data # Include results in dataset
new_h7n9 %>%
  mutate(outcome=predict(model_rf, newdata=unknown_data))
  1. Recover
  2. Recover
  3. Death
  4. Recover
  5. Death
  6. Death
  7. Recover
  8. Recover
  9. Recover
  10. Recover
  11. Death
  12. Death
  13. Death
  14. Recover
  15. Recover
  16. Death
  17. Recover
  18. Death
  19. Death
  20. Death
  21. Recover
  22. Recover
  23. Recover
  24. Recover
  25. Death
  26. Death
  27. Death
  28. Recover
  29. Recover
  30. Death
  31. Recover
  32. Recover
  33. Recover
  34. Death
  35. Recover
  36. Recover
  37. Recover
  38. Recover
  39. Death
  40. Recover
  41. Recover
  42. Death
  43. Recover
  44. Death
  45. Recover
  46. Death
  47. Recover
  48. Death
  49. Recover
  50. Recover
  51. Recover
  52. Death
  53. Recover
  54. Recover
  55. Recover
  56. Recover
  57. Recover

Levels:

  1. ‘Death’
  2. ‘Recover’
case_idoutcomeagemalehospitaldays_to_hospitaldays_to_outcomeearly_outcomeJiangsuOtherShanghaiZhejiang
<chr><fct><int><int><int><int><int><int><int><int><int><int>
4case_4Recover180184601000
9case_9Recover391101800001
15case_15Death3400113811000
16case_16Recover511032001000
22case_22Death561141701000
28case_28Death51106610001
31case_31Recover431042101000
32case_32Recover461032001000
38case_38Recover28102711000
39case_39Recover381101800001
40case_40Death461151410001
41case_41Death260162800001
42case_42Death251173800010
47case_47Recover441061601000
48case_48Recover371161700001
52case_52Death3600102010001
54case_54Recover47100800001
56case_56Death451161710010
62case_62Death400011810001
63case_63Death33107210100
66case_66Recover44100211000
67case_67Recover281021000001
68case_68Recover29100600001
69case_69Recover351001700001
70case_70Death300083100001
71case_71Death44008800001
78case_80Death46106700001
82case_84Recover60072211000
83case_85Recover520173810010
84case_86Death260081110001
86case_88Recover151041810100
88case_90Recover171131100001
90case_92Recover380172200001
91case_93Death28105700001
93case_95Recover131022200001
94case_96Recover171042201000
97case_99Recover400081700001
98case_100Recover3010111700001
99case_101Death5100111000001
100case_102Recover53100800001
101case_103Recover401163201000
102case_104Death26007800001
103case_105Recover910111710001
106case_108Death551032200001
107case_109Recover35100800001
108case_110Death251143200100
110case_112Recover411011100100
111case_113Death330071310001
112case_114Recover220010800001
113case_115Recover141071000001
116case_118Recover211065701000
118case_120Death331033801000
120case_122Recover141051800001
124case_126Recover5211103100100
128case_130Recover411112000100
130case_132Recover510102400100
134case_136Recover230111300100

In [28]:

summary(predict(model_rf, newdata = unknown_data))

Death21Recover36

Resultados

A partir de los datos obtenidos, se puede determinar que de los 57 pacientes con pronóstico desconocido, 21 fallecerán y 36 se recuperarán, lo que indica una tasa de mortalidad del 36,84% para este conjunto de datos.

En cuanto a los pacientes que fallecerán, los datos son los siguientes:

  • El 52,38% (11 pacientes) son de sexo masculino y el 47,62% (10 pacientes) son de sexo femenino.
  • El 27,27% (6 pacientes) estaban hospitalizados, mientras que el 72,73% (15 pacientes) no lo estaban.
  • El 14,28% (3 pacientes) proceden de Jiangsu, el 66,66% (14 pacientes) de Zhenjiang, el 9,52% (2 pacientes) de Shanghai y el 9,52% (2 pacientes) de otras localidades.
  • El 42,85% (9 pacientes) tuvieron la enfermedad menos tiempo de lo habitual, lo que indica que murieron más rápidamente.

En cuanto a los pacientes que se recuperarán, se obtienen los siguientes datos:

  • El 77,77% (28 pacientes) son de sexo masculino, y el 22,22% (8 pacientes) son de sexo femenino.
  • El 33,33% (12 pacientes) estaban hospitalizados, mientras que el 66,66% (24 pacientes) no lo estaban.
  • El 30,55% (11 pacientes) proceden de Jiangsu, el 50% (18 pacientes) de Zhenjiang, el 2,77% (1 paciente) de Shanghai y el 16,68% (6 pacientes) de otras localidades.
  • El 16,66% (6 pacientes) tuvieron la enfermedad menos tiempo de lo habitual, lo que indica que superaron la enfermedad más rápidamente.

In [29]:

unknown_data$outcome = c('Recover', 'Recover', 'Death', 'Recover', 'Death', 'Death', 'Recover', 'Recover', 'Recover', 'Recover', 'Death', 'Death', 'Death', 'Recover', 
                         'Recover', 'Death', 'Recover', 'Death', 'Death', 'Death', 'Recover', 'Recover', 'Recover', 'Recover', 'Death', 'Death', 'Death', 'Recover', 'Recover', 
                         'Death', 'Recover', 'Recover', 'Recover', 'Death', 'Recover', 'Recover', 'Recover', 'Recover', 'Death', 'Recover', 'Recover', 'Death', 'Recover', 
                         'Death', 'Recover', 'Death', 'Recover', 'Death', 'Recover', 'Recover', 'Recover', 'Death', 'Recover', 'Recover', 'Recover', 'Recover', 'Recover')

par(mfrow=c(1,2))
plot(density(unknown_data$age), main = "Final results of the predictive model",
     xlab = "Forecast", ylab = "Frequency")


hist(unknown_data$age, main = "Histogram of frequencies",
     xlab = "Age",
     ylab = "Frequency",
     col = "red",
     border = "black")

death_male = unknown_data[ which(unknown_data$male=='1' & unknown_data$outcome=='Death'),]
death_female = unknown_data[ which(unknown_data$male=='0' & unknown_data$outcome=='Death'),]

death_hospital = unknown_data[ which(unknown_data$hospital=='1' & unknown_data$outcome=='Death'),]
death_not_hospital = unknown_data[ which(unknown_data$hospital=='0' & unknown_data$outcome=='Death'),]

death_jiangsu = unknown_data[ which(unknown_data$Jiangsu=='1' & unknown_data$outcome=='Death'),]
death_zhejiang = unknown_data[ which(unknown_data$Zhejiang =='1' & unknown_data$outcome=='Death'),]
death_shanghai = unknown_data[ which(unknown_data$Shanghai=='1' & unknown_data$outcome=='Death'),]
death_other = unknown_data[ which(unknown_data$Other=='1' & unknown_data$outcome=='Death'),]

death_early_outcome = unknown_data[ which(unknown_data$early_outcome=='1' & unknown_data$outcome=='Death'),]
death_not_early_outcome = unknown_data[ which(unknown_data$early_outcome=='0' & unknown_data$outcome=='Death'),]

recover_male = unknown_data[ which(unknown_data$male=='1' & unknown_data$outcome=='Recover'),]
recover_female = unknown_data[ which(unknown_data$male=='0' & unknown_data$outcome=='Recover'),]

recover_hospital = unknown_data[ which(unknown_data$hospital=='1' & unknown_data$outcome=='Recover'),]
recover_not_hospital = unknown_data[ which(unknown_data$hospital=='0' & unknown_data$outcome=='Recover'),]

recover_jiangsu = unknown_data[ which(unknown_data$Jiangsu=='1' & unknown_data$outcome=='Recover'),]
recover_zhejiang = unknown_data[ which(unknown_data$Zhejiang =='1' & unknown_data$outcome=='Recover'),]
recover_shanghai = unknown_data[ which(unknown_data$Shanghai=='1' & unknown_data$outcome=='Recover'),]
recover_other = unknown_data[ which(unknown_data$Other=='1' & unknown_data$outcome=='Recover'),]

recover_early_outcome = unknown_data[ which(unknown_data$early_outcome=='1' & unknown_data$outcome=='Recover'),]
recover_not_early_outcome = unknown_data[ which(unknown_data$early_outcome=='0' & unknown_data$outcome=='Recover'),]

count(death_early_outcome)
n
<int>
9

In [33]:

# Pie Chart
x <-  c(22, 35)
labels <-  c("Deceased","Recovered")
piepercent<- round(100*x/sum(x), 1)
pie(x, labels = piepercent, main = "Deaths & Recoveries comparison",col = rainbow(length(x)))
legend("topright", c("Deceased","Recovered"), cex = 0.8,
       fill = rainbow(length(x)))

par(mfrow=c(2,2))

hist(death_male$age, main = "Total deaths by male gender",
     xlab = "Age",
     ylab = "Frequency",
     col = "red",
     border = "black")

hist(death_female$age, main = "Total deaths by female gender",
     xlab = "Age",
     ylab = "Frequency",
     col = "red",
     border = "black")

hist(recover_male$age, main = "Total recovered by male gender",
     xlab = "Age",
     ylab = "Frequency",
     col = "red",
     border = "black")

hist(recover_female$age, main = "Total recovered by female gender",
     xlab = "Age",
     ylab = "Frequency",
     col = "red",
     border = "black")

In [31]:

par(mfrow=c(4,2))
hist(death_jiangsu$age, main = "Total deaths in Jiangsu",
     xlab = "Age",
     ylab = "Frequency",
     col = "red",
     border = "black")

hist(recover_jiangsu$age, main = "Total recovered in Jiangsu",
     xlab = "Age",
     ylab = "Frequency",
     col = "red",
     border = "black")

hist(death_shanghai$age, main = "Total deaths in Shanghai",
     xlab = "Age",
     ylab = "Frequency",
     col = "red",
     border = "black")

hist(recover_shanghai$age, main = "Total recovered in Shanghai",
     xlab = "Age",
     ylab = "Frequency",
     col = "red",
     border = "black")

hist(death_zhejiang$age, main = "Total deaths in Zhejiang",
     xlab = "Age",
     ylab = "Frequency",
     col = "red",
     border = "black")

hist(recover_zhejiang$age, main = "Total recovered in Zhejiang",
     xlab = "Age",
     ylab = "Frequency",
     col = "red",
     border = "black")

hist(death_other$age, main = "Total deaths in other location",
     xlab = "Age",
     ylab = "Frequency",
     col = "red",
     border = "black")

hist(recover_other$age, main = "Total recovered in other location
",
     xlab = "Age",
     ylab = "Frequency",
     col = "red",
     border = "black")

Referencias

Sergio Alves

Ingeniero de Sistemas. MSc. en Data Science. Cuento con una amplia trayectoria profesional en las áreas de Desarrollo Web FullStack, DBA, DevOps, Inteligencia Artificial y Ciencia de Datos. Soy un entusiasta de la música, la tecnología y el aprendizaje contínuo.

Artículos Relacionados

Back to top button