Estás leyendo la publicación: Una guía simple para Elastic… – Hacia la IA
Publicado originalmente en Hacia la IA.
El problema del olvido catastrófico
En el campo de la inteligencia artificial, los modelos de aprendizaje profundo, especialmente las redes neuronales, han mostrado un gran éxito en una amplia gama de aplicaciones. Sin embargo, un gran desafío al que se enfrentan estos modelos es el fenómeno del olvido catastrófico. El olvido catastrófico ocurre cuando una red neuronal aprende nuevas tareas y, en el proceso, olvida tareas previamente aprendidas. Esta limitación dificulta el desarrollo de sistemas o redes de aprendizaje permanente que requieren una constante adaptación a nuevos entornos y tareas.
Para contrarrestar el olvido catastrófico, se han propuesto varios métodos:
- Consolidación de Peso Elástico (EWC): Una técnica de regularización que agrega un término de penalización a la función de pérdida basada en la matriz de información de Fisher, restringiendo el proceso de aprendizaje para retener el conocimiento de tareas anteriores.
- Redes neuronales progresivas (PNN): Una arquitectura que entrena columnas separadas de redes neuronales para cada tarea, con conexiones laterales para transferir el conocimiento previamente aprendido a la nueva tarea sin alterar las redes anteriores.
- Aprender sin olvidar (LwF): Un método que incorpora la destilación de conocimiento para entrenar a la red en nuevas tareas mientras preserva las probabilidades de salida de las tareas anteriores, manteniendo así el conocimiento antiguo.
- Aprendizaje continuo con repetición de memoria (CLMR): una técnica que mantiene un búfer de memoria de ejemplos aprendidos previamente, que se reproducen periódicamente junto con nuevos datos para evitar el olvido.
- Codificación escasa: un método de aprendizaje de representación que impone la escasez en las activaciones de la red, lo que lleva a representaciones más distintas y sin superposición para diferentes tareas, lo que reduce la interferencia.
- Aprendizaje Meta-Continuo: Un enfoque que aprovecha los algoritmos de metaaprendizaje, como MAML, para aprender una inicialización independiente del modelo que puede adaptarse rápidamente a nuevas tareas con una interferencia mínima con las tareas aprendidas previamente.
- Inteligencia sináptica (SI): un método de regularización que acumula la importancia de cada peso a lo largo del tiempo durante el aprendizaje y utiliza esta información para restringir las actualizaciones de peso, conservando el conocimiento de las tareas anteriores.
Este tutorial explica la consolidación de peso elástico, incluidos sus principios fundamentales y cómo aplicarlo en PyTorch, un marco de aprendizaje profundo ampliamente utilizado. También cubriremos técnicas de optimización y examinaremos varios escenarios donde EWC puede ser ventajoso. Con el conocimiento obtenido de este tutorial, podrá integrar EWC en sus propios proyectos y abordar el problema del olvido catastrófico.
Consolidación de Peso Elástico
Elastic Weight Consolidation (EWC) es una técnica de regularización que mitiga el olvido catastrófico al restringir el proceso de aprendizaje en las redes neuronales. La idea clave detrás de EWC es agregar un término de penalización cuadrático a la función de pérdida estándar. Este término de penalización considera la distancia entre los valores de peso actuales y los pesos óptimos obtenidos durante el aprendizaje de la tarea anterior. Al hacer esto, EWC reduce la interferencia entre tareas y ayuda a mantener un equilibrio entre aprender nuevas tareas y retener las antiguas.
EWC facilita la retención del conocimiento de la tarea A mientras se entrena en la tarea B. El proceso de entrenamiento se representa en un espacio de parámetros conceptuales, donde las regiones de parámetros que dan como resultado un buen desempeño en la tarea A se representan en gris y las de la tarea B en color crema. Después de aprender la tarea A, los parámetros se ubican en θ∗A.
Si solo consideramos los pasos de gradiente de la tarea B (flecha azul), minimizamos la pérdida de la tarea B pero comprometemos el conocimiento adquirido de la tarea A. Esto corresponde a “sin penalización” en la figura.
Por el contrario, si restringimos todos los pesos uniformemente (flecha verde), el proceso de aprendizaje corresponde a “l2” en la figura.
Resulta que la restricción l2 es tan fuerte que podría obstaculizar el proceso de aprendizaje de la tarea B. En las redes neuronales, a menudo sobre parametrizamos los modelos. Puede haber algunos parámetros que son menos útiles y otros son más valiosos.
EWC descubre una solución para la tarea B sin afectar significativamente el rendimiento de la tarea A (flecha roja) mediante el cálculo explícito de la importancia de cada peso para la tarea A. Este valor de importancia, denominado matriz de información de Fisher, cuantifica la contribución del peso al rendimiento en tareas aprendidas previamente. . La matriz de información de Fisher proporciona una aproximación de la curvatura de la función de pérdida, lo que nos da una idea de cuán sensible es la red a los cambios en los pesos. Los pesos con valores de importancia más altos tienen un mayor impacto en el desempeño de las tareas anteriores, por lo que sus actualizaciones deben estar más restringidas durante el aprendizaje de nuevas tareas.
El proceso de aprendizaje del CEE se puede formular de la siguiente manera:
Dónde:
- es la diagonal de la matriz de información de Fisher, que representa la importancia de cada ponderación
- λ es un hiperparámetro escalar que controla la fuerza de la penalización EWC
Al entrenar la red en una nueva tarea, la función de pérdida de EWC combina la pérdida de la nueva tarea con el término de penalización que restringe las actualizaciones de los pesos según sus valores de importancia. Esto asegura que el proceso de aprendizaje permanezca sesgado hacia la retención del conocimiento aprendido previamente mientras se adapta a la nueva tarea.
Matriz de información de Fisher (FIM)
La matriz de información de Fisher se calcula en función de las derivadas de segundo orden de la verosimilitud logarítmica de los datos dados los parámetros del modelo.
EWC calcula los elementos diagonales de la matriz de información de Fisher aproximada. Esta serie de aproximaciones da como resultado que la diagonal se estime como los gradientes cuadrados promediados entre mini lotes durante un solo paso a través del conjunto de datos de entrenamiento.
Para calcular la diagonal de Fisher para cada peso, siga estos pasos:
- Entrene al modelo en la tarea actual y obtenga los pesos óptimos (θ^*_i) para esa tarea.
- Calcule los gradientes de la función de pérdida con respecto a cada peso: ∇_θ_i L(θ)
- Estime la diagonal de Fisher I_i para cada peso como la expectativa al cuadrado de los gradientes: I_i = E[(∇_θ_i L(θ))²]
Implementación de PyTorch de la pérdida de EWC
Suponiendo que tenemos un modelo de red neuronal entrenado en la tarea A usando el conjunto de datos A, ahora lo entrenamos en la tarea B. A continuación se muestra un fragmento de código para obtener la pérdida de EWC con Fisher empírico.
def get_fisher_diag(modelo, conjunto de datos, parámetros, empírico=Verdadero):
pescador = {}
para n, p en deepcopy(params).items():
p.data.zero_()
pescador[n] = Variable(p.datos)
modelo.eval()
para la entrada, gt_label en el conjunto de datos:
modelo.zero_grad()
salida = modelo (entrada). vista (1, -1)
si es empírico:
etiqueta = gt_etiqueta
demás:
etiqueta = salida.max(1)[1].ver(-1)
negloglikelihood = F.nll_loss(F.log_softmax(salida, dim=1), etiqueta)
negloglikelihood.backward()
para n, p en model.named_parameters():
pescador[n].data += p.grad.data ** 2 / len(conjunto de datos)
pescador = {n: p por n, p en pescador.elementos()}
pescador de vuelta
def get_ewc_loss(modelo, pescador, p_antiguo):
pérdida = 0
para n, p en model.named_parameters():
_pérdida = pescador[n] * (p – p_antiguo[n]) ** 2
pérdida += _pérdida.suma()
pérdida de retorno
modelo = modelo_entrenado_en_tarea_A
conjunto de datos = una_pequeña_muestra_de_conjunto_de_datos_A
params = {n: p para n, p en model.named_parameters() si p.requires_grad}
p_antiguo = {}
para n, p en deepcopy(params).items():
p_viejo[n] = Variable(p.datos)
fisher_matrix = get_fisher_diag(modelo, conjunto de datos, parámetros)
ewc_loss = get_ewc_loss(modelo, fisher_matrix, p_old)
Algunos comentarios:
- El conjunto de datos utilizado para calcular FIM puede ser una pequeña muestra de la tarea 1
- Al calcular la pérdida de NLL, podemos usar la etiqueta de verdad básica o la etiqueta predicha. Ambos dan información al pescador. Cuando usamos la verdad básica, estamos calculando el Fisher empírico.
- En el ciclo de entrenamiento para la tarea B, uno simplemente usa:
pérdida = task2_loss + lambda * ewc_loss_task_A
Aplicaciones y casos de uso
Elastic Weight Consolidation (EWC) tiene una amplia gama de aplicaciones en varios dominios, particularmente en escenarios donde las redes neuronales necesitan aprender y adaptarse a nuevas tareas o entornos sin olvidar los conocimientos adquiridos previamente.
- Inteligencia incorporada: en el campo de la inteligencia incorporada, como la robótica, los vehículos autónomos y la IA de los videojuegos, EWC se puede emplear para hacer que la IE aprenda y se adapte a nuevas tareas en tiempo real.
- Sistemas de recomendación personalizados: en el contexto de los sistemas de recomendación, EWC puede ayudar a crear modelos que aprenden continuamente y se adaptan a las preferencias y comportamientos de los usuarios a lo largo del tiempo.
- Salud y diagnóstico médico: EWC se puede utilizar en aplicaciones de atención médica, como análisis de imágenes médicas o monitoreo de pacientes, donde los modelos deben adaptarse continuamente a nuevos datos de pacientes sin perder el conocimiento de casos anteriores.
- Procesamiento de lenguaje natural: en tareas de procesamiento de lenguaje natural, como el análisis de sentimientos o la traducción automática, EWC se puede emplear para desarrollar modelos que aprenden y se adaptan continuamente a nuevos dominios o idiomas sin olvidar el conocimiento adquirido en tareas anteriores.
Estos son solo algunos ejemplos de las numerosas aplicaciones y casos de uso en los que EWC puede ser beneficioso. Al incorporar EWC en sus proyectos, puede crear redes neuronales que aprenden y se adaptan de manera efectiva a nuevas tareas o entornos, al tiempo que conservan el conocimiento de experiencias anteriores, lo que en última instancia conduce a sistemas de IA más versátiles y robustos.
Materiales sugeridos
- Una encuesta de aprendizaje continuo: Desafiando el olvido en las tareas de clasificación (Matthias Lange)
- Superar el olvido catastrófico en las redes neuronales (James Kirkpatrick)
- Matriz de información de Fisher (Yuan-Hong Liao)
- Implementación de EWC PyTorch por moskomule
Publicado a través de Hacia la IA