Análisis discriminante lineal en Python (paso a paso)

El análisis discriminante lineal es un método que puede utilizar cuando tiene un conjunto de variables predictoras y le gustaría clasificar una variable de respuesta en dos o más clases.

Este tutorial proporciona un ejemplo paso a paso de cómo realizar un análisis discriminante lineal en Python.

Paso 1: cargue las bibliotecas necesarias

Primero, cargaremos las funciones y bibliotecas necesarias para este ejemplo:

de sklearn. model_selection  importar train_test_split
 desde sklearn. model_selection  importar RepeatedStratifiedKFold
 desde sklearn. model_selection  importar cross_val_score
 de sklearn. discriminant_analysis  importar LinearDiscriminantAnalysis 
 de sklearn importar conjuntos de datos
 importar matplotlib. pyplot  como plt
 importar pandas como pd
 importar numpy como np

Paso 2: cargue los datos

Para este ejemplo, usaremos el conjunto de datos de iris de la biblioteca sklearn. El siguiente código muestra cómo cargar este conjunto de datos y convertirlo en un DataFrame de pandas para que sea más fácil trabajar con él:

#load iris dataset 
iris = datasets. load_iris ()

#convertir conjunto de datos a pandas DataFrame
 df = pd.DataFrame (data = np.c_ [iris [' data '], iris [' target ']],
                 columnas = iris [' feature_names '] + [' target '])
df [' especie '] = pd. Categórico . from_codes (iris.target, iris.target_names)
df.columns = [' s_length ', ' s_width ', ' p_length ', ' p_width ', ' target ', ' especie ']

#ver las primeras seis filas de DataFrame
 df. cabeza ()

   s_length s_width p_length p_width especies de destino
0 5,1 3,5 1,4 0,2 0,0 setosa
1 4,9 3,0 1,4 0,2 0,0 setosa
2 4,7 3,2 1,3 0,2 0,0 setosa
3 4,6 3,1 1,5 0,2 0,0 setosa
4 5,0 3,6 1,4 0,2 0,0 setosa

# encontrar cuántas observaciones totales hay en el conjunto de datos 
len ( índice df. )

150

Podemos ver que el conjunto de datos contiene 150 observaciones en total.

Para este ejemplo, crearemos un modelo de análisis discriminante lineal para clasificar a qué especie pertenece una flor determinada.

Usaremos las siguientes variables predictoras en el modelo:

  • Longitud del sépalo
  • Ancho del sépalo
  • Longitud del pétalo
  • Ancho del pétalo

Y los usaremos para predecir la variable de respuesta Species , que toma las siguientes tres clases potenciales:

  • setosa
  • versicolor
  • virginica

Paso 3: ajuste el modelo LDA

A continuación, ajustaremos el modelo LDA a nuestros datos usando la función LinearDiscriminantAnalsyis de sklearn:

#definir variables de predicción y respuesta
 X = df [[' s_length ', ' s_width ', ' p_length ', ' p_width ']]
y = df [' especie ']

#Ajuste el modelo del modelo LDA
 = LinearDiscriminantAnalysis ()
modelo. encajar (X, y)

Paso 4: use el modelo para hacer predicciones

Una vez que hemos ajustado el modelo usando nuestros datos, podemos evaluar qué tan bien se desempeñó el modelo usando la validación cruzada estratificada repetida de k-veces.

Para este ejemplo, usaremos 10 pliegues y 3 repeticiones:

#Define el método para evaluar el modelo
 cv = RepeatedStratifiedKFold (n_splits = 10 , n_repeats = 3 , random_state = 1 )

#evaluar modelo
puntuaciones = cross_val_score (modelo, X, y, puntuación = ' precisión ', cv = cv, n_jobs = -1)
imprimir (np. media (puntuaciones))  

0.9777777777777779

Podemos ver que el modelo obtuvo una precisión media del 97,78% .

También podemos usar el modelo para predecir a qué clase pertenece una nueva flor, en función de los valores de entrada:

#definir nueva observación
 nueva = [5, 3, 1, .4]

# predice a qué clase pertenece la nueva observación del
 modelo. predecir ([nuevo])

matriz (['setosa'], dtype = '<U10')

Podemos ver que el modelo predice que esta nueva observación pertenece a la especie llamada setosa .

Paso 5: Visualice los resultados

Por último, podemos crear una gráfica LDA para ver los discriminantes lineales del modelo y visualizar qué tan bien separó las tres especies diferentes en nuestro conjunto de datos:

# definir datos para trazar
 X = iris.data
y = iris.target
modelo = LinearDiscriminantAnalysis ()
diagrama_datos = modelo. encajar (X, y). transformar (X)
target_names = iris. target_names

#create LDA plot
 plt. figura ()
colores = [' rojo ', ' verde ', ' azul ']
lw = 2
para color, i, target_name en zip (colores, [0, 1, 2], target_names):
    plt. scatter (data_plot [y == i, 0], data_plot [y == i, 1], alpha = .8, color = color,
                label = target_name)

#add leyenda para trazar
 plt. leyenda (loc = ' mejor ', sombra = Falso , puntos de dispersión = 1)

#display LDA plot
 plt. mostrar ()

análisis discriminante lineal en Python

Puede encontrar el código Python completo utilizado en este tutorial aquí .

  • https://r-project.org
  • https://www.python.org/
  • https://www.stata.com/

Deja un comentario

En muchas pruebas estadísticas, como un ANOVA unidireccional o un ANOVA bidireccional , asumimos que la varianza entre varios grupos…
statologos comunidad-2

Compartimos información EXCLUSIVA y GRATUITA solo para suscriptores (cursos privados, programas, consejos y mucho más)

You have Successfully Subscribed!