Estás leyendo la publicación: ¿Qué hace que JAX sea tan impresionante?
Para la investigación de aprendizaje automático de alto rendimiento, Just After eXceution (JAX) es NumPy en la CPU, la GPU y la TPU, con una excelente diferenciación automatizada. Es una biblioteca de Python para computación numérica de alto rendimiento, en particular para la investigación de aprendizaje automático. Su API numérica se basa en NumPy, una biblioteca de funciones utilizadas en computación científica. Python y NumPy son lenguajes de programación reconocidos y utilizados, lo que hace que JAX sea sencillo, versátil y fácil de implementar. Este artículo se centrará en las funciones y la implementación de JAX para crear un modelo de aprendizaje profundo. Los siguientes son los temas a tratar.
Tabla de contenido
- Razón para usar JAX
- ¿Qué es XLA?
- ¿Qué hay en el ecosistema de JAX?
- Construyendo un modelo ML con JAX
JAX no es un producto oficial de Google, pero su popularidad está aumentando, conozcamos las razones detrás de la popularidad.
Razón para usar JAX
Aunque JAX proporciona una API sólida y directa para desarrollar código numérico acelerado, trabajar de manera eficiente con JAX ocasionalmente requiere una reflexión adicional. JAX es esencialmente un compilador Just-In-Time (JIT) que se enfoca en generar código eficiente mientras utiliza la simplicidad de Python puro. Además de la API de NumPy, JAX contiene un conjunto extensible de transformaciones de funciones componibles que ayudan en la investigación de aprendizaje automático, como:
- Diferenciación: La optimización basada en gradientes es esencial para el aprendizaje automático. JAX permite de forma nativa la diferenciación automatizada de funciones numéricas arbitrarias tanto en modo directo como inverso mediante transformaciones de funciones como Gradients, Hessian y Jacobians (jacfwd y jacrev).
- Vectorización: En la investigación de aprendizaje automático, con frecuencia se aplica una sola función a grandes cantidades de datos, como calcular la pérdida en un lote o evaluar gradientes por ejemplo para un aprendizaje diferencialmente privado. La transformación vmap en JAX permite la vectorización automatizada, lo que simplifica este tipo de programación. Al desarrollar nuevos algoritmos, por ejemplo, los investigadores no necesitan considerar el procesamiento por lotes. JAX también permite el paralelismo de datos a gran escala con la transformación pmap relacionada, que distribuye elegantemente datos que son demasiado grandes para la memoria de un solo acelerador.
- Compilación justo a tiempo (JIT): XLA se utiliza para compilar JIT y ejecutar aplicaciones JAX en aceleradores GPU y Cloud TPU. La compilación JIT, junto con la API compatible con NumPy de JAX, permite a los investigadores sin experiencia previa en computación de alto rendimiento escalar fácilmente a uno o más aceleradores.
¿Está buscando un repositorio completo de bibliotecas de Python utilizadas en ciencia de datos, echa un vistazo aquí.
¿Qué es XLA?
XLA (álgebra lineal acelerada) es un compilador de álgebra lineal específico de dominio que puede acelerar los modelos de TensorFlow con pequeñas modificaciones en el código fuente.
Cuando se ejecuta un programa de TensorFlow, el ejecutor de TensorFlow realiza cada operación de forma independiente. El ejecutor se envía a una implementación de kernel de GPU precompilada para cada operación de TensorFlow. XLA ofrece una forma adicional de ejecución del modelo al compilar el gráfico de TensorFlow en una secuencia de núcleos informáticos creados especialmente para el modelo especificado. Debido a que estos núcleos son específicos del modelo, pueden usar información específica del modelo para optimizar.
Arquitectura de XLA
El lenguaje de entrada a XLA se llama Operaciones de alto nivel (HLO). Es más conveniente pensar en HLO como una representación intermedia del compilador. Entonces, HLO representa un programa “entre” los idiomas de origen y de destino.
XLA traduce los gráficos descritos en HLO en instrucciones de máquina para múltiples plataformas. XLA es modular en el sentido de que se puede insertar fácilmente un backend alternativo para apuntar a alguna arquitectura de hardware innovadora. XLA transfiere el cálculo de HLO a un backend después de la fase independiente del objetivo. El backend puede realizar optimizaciones adicionales a nivel de HLO, esta vez teniendo en cuenta los requisitos y los datos específicos del objetivo.
El siguiente paso es generar código específico de destino. LLVM es utilizado por los backends de CPU y GPU incluidos con XLA para la optimización de representación intermedia de bajo nivel y la creación de código. Estos backends producen el IR de LLVM requerido para describir de manera eficiente el cálculo de XLA HLO y luego usan LLVM para emitir código nativo desde esta representación intermedia de LLVM.
Razón para usar XLA
Hay cuatro razones principales para usar XLA.
- Porque la traducción parece implicar análisis y síntesis por definición. La traducción palabra por palabra es ineficaz.
- Dividir el complejo desafío de la traducción en dos mitades más simples y manejables.
- Se puede construir un back-end nuevo para un front-end existente para proporcionar compiladores redirigibles y viceversa.
- Para llevar a cabo optimizaciones independientes de la máquina.
¿Qué hay en el ecosistema de JAX?
El ecosistema consta de cinco bibliotecas diferentes.
haiku
Tratar con objetos con estado, como redes neuronales con parámetros entrenables, puede ser difícil con el paradigma de programación JAX de transformaciones de funciones componibles. Haiku es una biblioteca de redes neuronales que permite a los usuarios utilizar paradigmas tradicionales de programación orientada a objetos mientras aprovechan la potencia y la simplicidad del paradigma funcional puro de JAX.
Varios proyectos externos, incluidos Coax, DeepChem y NumPyro, utilizan activamente Haiku. Extiende la API para Sonnet, nuestro modelo de programación de redes neuronales basado en módulos en TensorFlow.
Optax
La optimización basada en gradientes es importante para el aprendizaje automático. Optax incluye una biblioteca de transformación de gradiente, así como operadores de composición (como cadena) que permiten el desarrollo de numerosos optimizadores comunes (como RMSProp o Adam) en una sola línea de código. La estructura compositiva de Optax se presta fácilmente a la recombinación de los mismos elementos fundamentales en optimizadores personalizados. También incluye utilidades para la estimación de gradientes estocásticos y optimización de segundo orden.
Rlax
RLax es una biblioteca que proporciona componentes básicos importantes para el desarrollo del aprendizaje por refuerzo (RL), también conocido como aprendizaje por refuerzo profundo. Los componentes de RLax incluyen TD-learning, gradientes de políticas, críticos de actores, MAP, optimización de políticas proximales, transformación de valor no lineal, funciones de valor genéricas y numerosos enfoques de exploración.
RLax no pretende ser un marco para desarrollar e implementar sistemas de agentes RL completos. Acme es un ejemplo de una arquitectura de agente con todas las funciones construida sobre componentes RLax.
Chex
Las pruebas son esenciales para la confiabilidad del software, y el código de investigación no es una excepción. Extraer hallazgos científicos de ensayos de investigación requiere fe en la precisión de su código. Chex es una colección de utilidades de prueba que utilizan los escritores de bibliotecas para garantizar que los componentes básicos comunes sean correctos y resistentes, así como los usuarios finales para validar sus programas experimentales.
Chex incluye una serie de herramientas, como pruebas unitarias con reconocimiento de JAX, afirmaciones sobre atributos de tipo de datos JAX, simulacros y falsificaciones, y entornos de prueba de dispositivos múltiples.
Jraph
Jraph es una pequeña biblioteca para trabajar con redes neuronales Graph GNN en JAX. Jraph proporciona una estructura de datos estandarizada para gráficos, un conjunto de herramientas para trabajar con gráficos y un conjunto de modelos de redes neuronales gráficas que se pueden bifurcar y expandir fácilmente. Otras características importantes incluyen el procesamiento por lotes de GraphTuple que aprovecha los aceleradores de hardware, la compatibilidad con la compilación JIT para gráficos de forma variable a través del relleno y el enmascaramiento, y las pérdidas especificadas en las particiones de entrada. Jraph, como Optax y nuestras otras bibliotecas, no tiene restricciones en la elección del usuario de una biblioteca de red neuronal.
Construyendo un modelo ML con JAX
Para este artículo, se crea un modelo de red adversaria generativa en la plataforma TensorFlow entrenada en el conjunto de datos MNIST en Jax’s Haiku.
Comencemos instalando Haiku y Optax
!pip instalar dm-haiku! pip instalar optax
Importar bibliotecas necesarias
importar funciones de tipeo importar Cualquiera, NamedTuple importar haiku como hk importar jax importar jax.numpy como jnp importar matplotlib.pyplot como plt importar numpy como np importar tensorflow como tf importar tensorflow_datasets como tfds
Lectura del conjunto de datos
mnist_dataset = tfds.load(“mnist”) def make_dataset(batch_size, seed=1): def _preprocess(sample): image = tf.image.convert_image_dtype(sample[“image”]tf.float32) devuelve 2.0 * imagen – 1.0 ds = mnist[“train”]
ds = ds.map(map_func=_preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE) ds = ds.cache() ds = ds.shuffle(10 * tamaño_lote, seed=seed).repeat().batch(tamaño_lote) volver iter(tfds.as_numpy(ds))
Creando generador y discriminador
El modelo se utiliza como generador para producir nuevas instancias plausibles del área del problema, mientras que el modelo se utiliza como discriminador para determinar si un ejemplo es real (del dominio) o generado.
class Generator(hk.Module): def __init__(self, canales_de_salida=(32, 1), nombre=Ninguno): super().__init__(nombre=nombre) self.canales_de_salida = canales_de_salida def __call__(self, x): x = hk.Lineal(7 * 7 * 64)(x) x = jnp.reforma(x, x.forma[:1] + (7, 7, 64)) para output_channels en self.output_channels: x = jax.nn.relu(x) x = hk.Conv2DTranspose(output_channels=output_channels, kernel_shape=[5, 5]zancada=2, padding=”MISMO”)(x) return jnp.tanh(x) clase Discriminador(hk.Module): def __init__(self, canales_de_salida=(8, 16, 32, 64, 128), zancadas= (2, 1, 2, 1, 2), nombre=Ninguno): super().__init__(nombre=nombre) self.output_channels = output_channels self.strides = strides def __call__(self, x): for output_channels, stride in zip(self.output_channels, self.strides): x = hk.Conv2D(output_channels=output_channels, kernel_shape=[5, 5]paso=paso, relleno=”MISMO”)(x) x = jax.nn.leaky_relu(x, pendiente_negativa=0.2) x = hk.Flatten()(x) logits = hk.Linear(2)(x) return logits
Creando el algoritmo GAN
import optax class GAN_algo_basic: def __init__(self, num_latents): self.num_latents = num_latents self.gen_transform = hk.sin_aplicar_rng( hk.transform(lambda *args: Generator()(*args))) self.disc_transform = hk.sin_aplicar_rng ( hk.transform(lambda *args: Discriminator()(*args))) self.optimizers = GANTuple(gen=optax.adam(1e-4, b1=0.5, b2=0.9), disc=optax.adam(1e -4, b1=0.5, b2=0.9)) @functools.partial(jax.jit, static_argnums=0) def initial_state(self, rng, batch): dummy_latents = jnp.zeros((batch.shape[0]self.num_latents)) rng_gen, rng_disc = jax.random.split(rng) params = GANTuple(gen=self.gen_transform.init(rng_gen, dummy_latents), disc=self.disc_transform.init(rng_disc, lote)) print( “Generador: \n\n{}\n”.format(tree_shape(params.gen))) print(“Discriminador: \n\n{}\n”.format(tree_shape(params.disc))) opt_state = GANTuple(gen=self.optimizers.gen.init(params.gen), disc=self.optimizers.disc.init(params.disc)) return GANState(params=params, opt_state=opt_state) def sample(self, rng, gen_params, num_samples): “””Genera imágenes a partir de latentes de ruido.””” latents = jax.random.normal(rng, shape=(num_samples, self.num_latents)) return self.gen_transform.apply(gen_params, latents) def gen_loss (self, gen_params, rng, disc_params, lote): fake_batch = self.sample(rng, gen_params, num_samples=batch.shape[0]) fake_logits = self.disc_transform.apply(disc_params, fake_batch) fake_probs = jax.nn.softmax(fake_logits)[:, 1]
loss = -jnp.log(fake_probs) return jnp.mean(loss) def disc_loss(self, disc_params, rng, gen_params, lote): fake_batch = self.sample(rng, gen_params, num_samples=batch.shape[0]) real_and_fake_batch = jnp.concatenate([batch, fake_batch]eje=0) real_and_fake_logits = self.disc_transform.apply(disc_params, real_and_fake_batch) real_logits, fake_logits = jnp.split(real_and_fake_logits, 2, axis=0) real_labels = jnp.ones((batch.shape[0],), dtype=jnp.int32) real_loss = sparse_softmax_cross_entropy(real_logits, real_labels) fake_labels = jnp.zeros((batch.shape[0],), dtype=jnp.int32) fake_loss = sparse_softmax_cross_entropy(fake_logits, fake_labels) return jnp.mean(real_loss + fake_loss) @functools.partial(jax.jit, static_argnums=0) def update(self, rng, gan_state, batch) : rng, rng_gen, rng_disc = jax.random.split(rng, 3) disc_loss, disc_grads = jax.value_and_grad(self.disc_loss)( gan_state.params.disc, rng_disc, gan_state.params.gen, lote) disc_update, disc_opt_state = self.optimizers.disc.update( disc_grads, gan_state.opt_state.disc) disc_params = optax.apply_updates(gan_state.params.disc, disc_update) gen_loss, gen_grads = jax.value_and_grad(self.gen_loss)(gan_state.params.gen, rng_gen , gan_state.params.disc, lote) gen_update, gen_opt_state = self.optimizers.gen.update( gen_grads, gan_state.opt_state.gen) gen_params = optax.apply_updates(gan_state.params.gen, gen_update) params = GANTuple(gen=gen_params , disc=disc_params) opt_state = GANTuple(gen=gen_opt_state, disc=disc_opt_state) gan_state = GANState(params=params, opt_state=opt_state) log = { “gen_loss”: gen_loss, “disc_loss”: disc_loss, } return rng, gan_state, registro
Entrenando al modelo
for step in range(num_steps): rng, gan_state, log = model.update(rng, gan_state, next(dataset)) if step % log_every == 0: log = jax.device_get(log) gen_loss = log[“gen_loss”]
pérdida_disco = registro[“disc_loss”]
print(f”Paso {paso}: ” f”pérdida_gen = {pérdida_gen:.3f}, pérdida_disco = {pérdida_disco:.3f}”) pasos.append(paso) pérdidas_gen.append(pérdida_gen) pérdidas_disco.append(pérdida_disco)
El modelo será entrenado para 5000 pasos debido a limitaciones de tiempo. Depende del usuario para seleccionar el número de pasos. Para 5000 pasos tomó aproximadamente 60 minutos.
Analizando las pérdidas para el generador y el discriminador
fig, axes = plt.subplots(1, 2, figsize=(20, 6)) # Trazar la pérdida del discriminador. hachas[0].plot(pasos, disc_losses, “-“) ejes[0].set_title(“Pérdida del discriminador”, tamaño de fuente=20) # Trazar la pérdida del generador. hachas[1].plot(pasos, gen_losses, ‘-‘) ejes[1].set_title(“Pérdida del generador”, tamaño de fuente=20);
Podemos observar que la pérdida del generador fue bastante alta durante los 2000 pasos iniciales y después de 3000 pasos, las pérdidas del discriminador y del generador se mantuvieron aproximadamente constantes en promedio.
Conclusión
Just After eXceution (JAX) es un cálculo numérico de alto rendimiento, particularmente en la investigación de aprendizaje automático. Su API numérica se basa en NumPy, una biblioteca de funciones utilizadas en computación científica. Con este artículo hemos entendido el ecosistema de JAX y la implementación de Optax y Haiku que son parte de ese ecosistema.