¿Cómo automatizar la búsqueda de la tasa de aprendizaje óptima?

Estás leyendo la publicación: ¿Cómo automatizar la búsqueda de la tasa de aprendizaje óptima?

Encontrar la mejor configuración para los hiperparámetros del modelo de aprendizaje profundo se ha considerado durante mucho tiempo más un arte que una ciencia y ha dependido principalmente de prueba y error. La tasa de aprendizaje (LR) es posiblemente el hiperparámetro más importante en el aprendizaje profundo, ya que determina cuánto gradiente se propaga hacia atrás. Esto, a su vez, determina cuánto avanzamos hacia los mínimos. Una tasa de aprendizaje lenta hace que el modelo converja lentamente, mientras que una tasa de aprendizaje rápida hace que el modelo diverja. Como resultado, la tasa de aprendizaje debe ser precisamente la correcta. Este artículo lo ayudará a aprender sobre el buscador automático de tasas de aprendizaje y su implementación. Los siguientes son los temas a considerar.

Tabla de contenido

  1. Dilema detrás de la tasa de aprendizaje óptima
  2. Arquitectura común del buscador automático de LR
  3. Encontrar la tasa de aprendizaje óptima con PyTorch

El aprendizaje no se puede arreglar para todas las redes. Veamos el detrás de escena de las pruebas de tasa de aprendizaje.

Dilema detrás de la tasa de aprendizaje óptima

La tasa de aprendizaje es un hiperparámetro que gobierna la alteración de los pesos de la red en relación con el gradiente de pérdida. Cuantifica la cantidad de información que se aprenderá de cada nuevo mini lote de datos de entrenamiento. Matemáticamente es una penalización sobre la cantidad de información consumida por los pesos del modelo. Cuanto más grandes sean los pasos dados a lo largo de la trayectoria hasta la función de pérdida más baja, donde se encuentran los parámetros óptimos del modelo, más rápido aprende.

Revista de análisis de la India

El rango de prueba de tasas de aprendizaje está limitado a una época de repeticiones de entrenamiento, y la tasa de aprendizaje aumenta con cada mini lote de datos. La tasa de aprendizaje aumenta de un número muy pequeño a un número muy grande durante el procedimiento, lo que hace que la pérdida de entrenamiento comience con una meseta, luego disminuya a un valor mínimo y luego explote. Este comportamiento habitual se puede trazar (como se muestra en el gráfico a continuación) y utilizarse para elegir un rango adecuado para la tasa de aprendizaje, especialmente en la región donde la pérdida está disminuyendo.

🔥 Recomendado:  9 modelos innovadores basados ​​en GPT desarrollados en India
Revista de análisis de la India

La tasa de aprendizaje más baja sugerida es el valor en el que la pérdida se reduce más rápidamente (gradiente negativo mínimo), mientras que la tasa de aprendizaje máxima recomendada es mucho menor que la tasa de aprendizaje en la que la pérdida es menor. Es mucho menos, digamos diez veces, porque para trazar una versión suavizada de la pérdida, es probable que elegir la tasa de aprendizaje que corresponde a la pérdida más pequeña sea demasiado grande, lo que hace que la pérdida diverja durante el entrenamiento.

¿Está buscando un repositorio completo de bibliotecas de Python utilizadas en ciencia de datos, echa un vistazo aquí.

Arquitectura común del buscador automático de LR

Un buscador de tasa de aprendizaje automático típico utiliza un método de tasa de aprendizaje cíclico. El objetivo del algoritmo es proporcionar un método para determinar las tasas de aprendizaje globales para entrenar redes neuronales que eliminen el requisito de cientos de pruebas para determinar los mejores valores sin procesamiento adicional. Al introducir la noción de la prueba de rango LR, CLR ofrece un excelente rango de tasa de aprendizaje (rango LR) para un experimento.

Una buena tasa de aprendizaje es aquella que da como resultado una disminución significativa en la pérdida de la red. Aquí viene la brujería de CLR. El documento original de CLR describe un experimento en el que puede monitorear el comportamiento de la tasa de aprendizaje en relación con la pérdida. El experimento es simple de entender: después de cada mini-lote, aumente progresivamente la tasa de aprendizaje mientras observa la pérdida en cada paso. Este lento aumento puede ser lineal o exponencial. Y, claro, esto es similar a la prueba de rango LR.

Revista de análisis de la India

Después de realizar el experimento, Leslie demostró que a tasas de aprendizaje excesivamente bajas, la pérdida puede disminuir, pero solo a un ritmo muy lento. Cuando ingrese a la zona de tasa de aprendizaje ideal, verá una fuerte disminución en la función de pérdida. Si la tasa de aprendizaje aumenta aún más, puede producir una pérdida de parámetros en la red, lo que puede resultar en un aumento de las pérdidas. Entonces, según este experimento, es evidente que está buscando una fuerte caída en la función de pérdida, y puede hacerlo analizando los gradientes de la función de pérdida en varias etapas de entrenamiento.

🔥 Recomendado:  Cómo crear un blog gratis en WordPress con dominio personalizado

Encontrar la tasa de aprendizaje óptima con PyTorch

Este artículo para encontrar la tasa de aprendizaje óptima para la red neuronal utiliza el paquete de iluminación PyTorch. El modelo utilizado para este artículo es un clasificador LeNet, una típica red neuronal convolucional para principiantes. Este modelo se utiliza como clasificador de imágenes en este artículo, el conjunto de datos utilizado es el famoso conjunto de datos MNIST.

Comencemos por instalar e importar las dependencias y los requisitos previos.

!pip instalar PyTorch-lightning !pip instalar torchmetrics

Es necesario instalar la métrica de la antorcha con la iluminación PyTorch porque el módulo de métricas se desplaza del paquete de iluminación PyTorch. Las métricas de antorcha tienen predefinidas y también pueden crear un método de evaluación personalizado.

importar torch.nn como nn importar torch.nn.funcional como F importar torchvision.transforms como transformaciones importar PyTorch_lightning como pl de PyTorch_lightning.loggers importar TensorBoardLogger de torchmetrics importar funcional como FM

Uno puede descargar fácilmente el conjunto de datos MNIST o usar el código del cuaderno adjunto en la sección de referencias para descargar y dividir los datos.

El modelo es un clasificador LeNet construido con el “LightningModule” de la iluminación PyTorch. El modelo tiene 2 capas de convolución y 3 capas lineales con 120, 84 y 10 neuronas completamente conectadas respectivamente.

def __init__(self, num_classes=10): super().__init__() self.lr = 2e-3 self.conv1 = nn.Conv2d(1, 6, 5, padding=2) self.conv2 = nn.Conv2d( 6, 16, 5) self.fc1 = nn.Linear(16*5*5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def training_step(self , lote, lote_idx): x, y = lote y_hat = self(x) acc = FM.accuracy(y_hat, y) loss = F.cross_entropy(y_hat, y) self.log_dict({‘train_acc’: acc, ‘train_loss ‘: loss}, on_step=True, on_epoch=False) return {“precisión del entrenamiento”: acc,”loss”:loss} def validation_step(self, lote, lote_idx): x, y = lote y_hat = self(x) val_acc = FM.accuracy(y_hat, y) loss = F.cross_entropy(y_hat, y) self.log_dict({‘val_acc’: val_acc, ‘val_loss’: loss}, reduce_fx=torch.mean) return {“precisión de validación”: val_acc, “pérdida de validación”:pérdida}

Aquí hay un vistazo del modelo clasificador. Para obtener más información, consulte el cuaderno de Colab adjunto en la sección de referencia.

Ahora es el momento de entrenar el modelo. Habrá dos versiones de este modelo, una sin el buscador de tasas de autoaprendizaje y la otra con la tasa de autoaprendizaje. Las pérdidas y las métricas de precisión se registrarían utilizando el tensorboard para una mejor visualización.

🔥 Recomendado:  ¿Por qué la tasa de aprendizaje debe ser siempre baja?

data_directory= “/content/” batch_size=60 logger_directory= ‘registros/sin_auto_lr’ name_of_log= ‘LeNet classifier’ version_of_log = 1.0 Default_lr= 1e-3 max_epochs=8 model = LeNet_classifier_model() dataset = MNISTData(data_directory, batch_size) logger = TensorBoardLogger (save_dir=logger_directory,version=version_of_log,name=name_of_log) trainer = pl.Trainer(gpus=1, max_epochs=max_epochs, logger=logger, auto_lr_find=False, val_check_interval=0.5) model.lr = Default_lr print(f’Modelo predeterminado LR: {modelo.lr}’) trainer.fit(modelo, conjunto de datos)

Revista de análisis de la India
Revista de análisis de la India
Revista de análisis de la India

Entonces, con una tasa de aprendizaje de 0.001 y un total de 8 épocas, la pérdida mínima se logra en 5000 pasos para los datos de entrenamiento y para la validación, son 6500 pasos que parecen disminuir a medida que aumentan las épocas.

Encontremos la tasa de aprendizaje óptima con menos pasos requeridos y menor pérdida y puntaje de alta precisión. Para usar el buscador LR automático, cada uno sería el mismo que antes. Solo necesita agregar estas líneas al código que encontrará la tasa de aprendizaje óptima y trazará la curva de tasa de pérdida frente a aprendizaje para una mejor visualización.

lr_finder = trainer_2.tuner.lr_find(modelo, conjunto de datos) model.hparams.lr = lr_finder.suggestion() print(f’Auto-buscar modelo LR: {model.hparams.lr}’) fig = lr_finder.plot(suggest= Verdadero)

Revista de análisis de la India
Revista de análisis de la India

Entonces, la tasa de aprendizaje óptima para el modelo es 0,025, que es mayor que la tasa de aprendizaje predeterminada. Por lo tanto, el tiempo computacional sería menor en comparación y sería menos rentable. Además, puede entrenar el modelo en esta tasa de aprendizaje, dejándolo a usted.

Conclusiones

El buscador de LR es una gran herramienta para determinar la mejor tasa de aprendizaje para una situación determinada, pero debe usarse con precaución. Es fundamental utilizar los mismos pesos iniciales tanto en la prueba de rango de velocidad de aprendizaje como en el entrenamiento posterior del modelo. Nunca suponga que las tasas de aprendizaje descubiertas son óptimas para la inicialización de cualquier modelo. Con este artículo, hemos entendido el concepto de buscador automático de tasas de aprendizaje con implementación.

Referencias

Tabla de Contenido