Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - Jax cannot find the static argnums

This is related to this question. I managed to get most of the code working, except for one strange thing.

Here is the modified code.

import jax.numpy as jnp
from jax import grad, jit, value_and_grad
from jax import vmap, pmap
from jax import random
import jax
from jax import lax
from jax import custom_jvp


def p_tau(z, tau, alpha=1.5):
    return jnp.clip((alpha - 1) * z - tau, a_min=0) ** (1 / (alpha - 1))


def get_tau(tau, tau_max, tau_min, z_value):
    return lax.cond(z_value < 1,
                    lambda _: (tau, tau_min),
                    lambda _: (tau_max, tau),
                    operand=None
                    )


def body(kwargs, x):
    tau_min = kwargs['tau_min']
    tau_max = kwargs['tau_max']
    z = kwargs['z']
    alpha = kwargs['alpha']

    tau = (tau_min + tau_max) / 2
    z_value = p_tau(z, tau, alpha).sum()
    taus = get_tau(tau, tau_max, tau_min, z_value)
    tau_max, tau_min = taus[0], taus[1]
    return {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, None

@jax.partial(jax.jit, static_argnums=(2,))
def map_row(z_input, alpha, T):
    z = (alpha - 1) * z_input

    tau_min, tau_max = jnp.min(z) - 1, jnp.max(z) - z.shape[0] ** (1 - alpha)
    result, _ = lax.scan(body, {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, xs=None,
                         length=T)
    tau = (result['tau_max'] + result['tau_min']) / 2
    result = p_tau(z, tau, alpha)
    return result / result.sum()

@jax.partial(jax.jit, static_argnums=(1,3,))
def _entmax(input, axis=-1, alpha=1.5, T=20):
    result = vmap(jax.partial(map_row, alpha=alpha, T=T), axis)(input)
    return result

@jax.partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
def entmax(input, axis=-1, alpha=1.5, T=10):
    return _entmax(input, axis, alpha, T)
    
@jax.partial(jax.jit, static_argnums=(0,2,))
def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
    input = primals[0]
    Y = entmax(input, axis, alpha, T)
    gppr = Y  ** (2 - alpha)
    grad_output = tangents[0]
    dX = grad_output * gppr
    q = dX.sum(axis=axis) / gppr.sum(axis=axis)
    q = jnp.expand_dims(q, axis=axis)
    dX -= q * gppr
    return Y, dX


@entmax.defjvp
def entmax_jvp(axis, alpha, T, primals, tangents):
    return _entmax_jvp_impl(axis, alpha, T, primals, tangents)


import numpy as np
input = jnp.array(np.random.randn(64, 10)).block_until_ready()
weight = jnp.array(np.random.randn(64, 10)).block_until_ready()

def toy(input, weight):
    return (weight*entmax(input, axis=0, alpha=1.5, T=20)).sum()

jax.jit(value_and_grad(toy))(input, weight)

This code produces the following error:

tuple index out of range

which is caused by this line of code:

@jax.partial(jax.jit, static_argnums=(2,))
def map_row(z_input, alpha, T):

Even if I replace the function body with nothing but an identity function, the error persists, which is really strange behavior. However, it is very important for me that T is static, because that allows the loop to be unrolled.



1 Answer


This error is due to a wart that I hope will be fixed soon in JAX: static arguments cannot be passed by keyword. In other words, you should change this:

def toy(input, weight):
    return (weight*entmax(input, axis=0, alpha=1.5, T=20)).sum()

to this:

def toy(input, weight):
    return (weight*entmax(input, 0, 1.5, 20)).sum()

The same fix should be applied to the call to map_row inside _entmax, which currently passes alpha and T by keyword through jax.partial.
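As a rough illustration of the positional-argument fix for a jitted function called under vmap, here is a minimal sketch. The names halve_row and apply_rows are hypothetical stand-ins, not part of the original code, and functools.partial is used in place of jax.partial. Newer JAX releases may treat keyword static arguments differently, but passing them positionally should work either way.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for map_row: T is marked static by position.
@partial(jax.jit, static_argnums=(2,))
def halve_row(row, alpha, T):
    # T is a plain Python int here, so this loop is unrolled at trace time.
    for _ in range(T):
        row = row / alpha
    return row


def apply_rows(x, alpha, T):
    # Pass alpha and T positionally via a lambda instead of
    # partial(halve_row, alpha=alpha, T=T), so T lands in static slot 2.
    return jax.vmap(lambda row: halve_row(row, alpha, T))(x)


x = jnp.full((2, 3), 8.0)
out = apply_rows(x, 2.0, 3)
```

The lambda closes over alpha and T, so vmap only maps over the row argument while the static argument still reaches the jitted function in its declared position.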

At this point, you end up with a ValueError, because traced variables are being passed to functions that require static arguments. The solution will be similar to that in "How to handle JAX reshape with JIT".
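One way to avoid handing a traced value to a static parameter is to keep that argument static at every level of jit nesting. The sketch below assumes this is the kind of fix meant here; inner and outer are hypothetical names used only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1,))
def inner(x, T):
    # Requires a concrete Python int for T to drive the Python loop.
    for _ in range(T):
        x = x + 1.0
    return x


# T is marked static in the outer function as well. Without this, the outer
# jit would trace T, and inner would receive a tracer instead of an int.
@partial(jax.jit, static_argnums=(1,))
def outer(x, T):
    return inner(x, T)


result = outer(jnp.zeros(2), 4)
```

The trade-off is that each distinct value of T triggers a separate compilation of outer, which is usually acceptable when T takes only a few values.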


One additional note: the error message for this static_argnums problem has recently been improved, and in the next release it will be clearer:

ValueError: jitted function has static_argnums=(2,), donate_argnums=() but was called with only 1 positional arguments.
