Need Help with this Trax Error Code #6295
Unanswered
memora0101
asked this question in
Q&A
Replies: 0 comments
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
-
for sent in train_x:
inputs = np.array(sent_to_tensor(sent, vocab_dict=Vocab))
inputs = inputs[None, :]
predictions=model(inputs)
print(f'example input_str: {sent}')
print(f'Model returned sentiment probabilities: {predictions}')
Error:
LayerError Traceback (most recent call last)
in
2 inputs = np.array(sent_to_tensor(sent, vocab_dict=Vocab))
3 inputs = inputs[None, :]
----> 4 predictions=model(inputs)
5 print(f'example input_str: {sent}')
6 print(f'Model returned sentiment probabilities: {predictions}')
~/opt/anaconda3/lib/python3.7/site-packages/trax/layers/base.py in call(self, x, weights, state, rng)
190 self.state = state # Needed if the model wasn't fully initialized.
191 state = self.state
--> 192 outputs, new_state = self.pure_fn(x, weights, state, rng)
193 self.state = new_state
194 self.weights = weights
~/opt/anaconda3/lib/python3.7/site-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng, use_cache)
547 name, trace = self._name, _short_traceback(skip=3)
548 raise LayerError(name, 'pure_fn',
--> 549 self._caller, signature(x), trace) from None
550
551 def output_signature(self, input_signature):
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/, line 27
layer input shapes: ShapeDtype{shape:(1, 0), dtype:float32}
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Embedding_437_256 (in pure_fn):
layer created in file [...]/, line 8
layer input shapes: ShapeDtype{shape:(1, 0), dtype:float32}
File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
y = forward(self, x, *args, **kwargs)
File [...]/trax/layers/core.py, line 181, in forward
embedded = jnp.take(self.weights, x, axis=0)
File [...]/_src/numpy/lax_numpy.py, line 4078, in take
slice_sizes=tuple(slice_sizes))
File [...]/_src/lax/lax.py, line 874, in gather
slice_sizes=canonicalize_shape(slice_sizes))
File [...]/site-packages/jax/core.py, line 282, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/site-packages/jax/core.py, line 628, in process_primitive
return primitive.impl(*tracers, **params)
File [...]/jax/interpreters/xla.py, line 238, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
File [...]/jax/_src/util.py, line 198, in wrapper
return cached(bool(FLAGS.jax_enable_x64), *args, **kwargs)
File [...]/jax/_src/util.py, line 191, in cached
return f(*args, **kwargs)
File [...]/jax/interpreters/xla.py, line 263, in xla_primitive_callable
aval_out = prim.abstract_eval(*avals, **params)
File [...]/_src/lax/lax.py, line 1992, in standard_abstract_eval
shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)
File [...]/_src/lax/lax.py, line 4114, in _gather_dtype_rule
raise ValueError("start_indices must have an integer type")
ValueError: start_indices must have an integer type
Beta Was this translation helpful? Give feedback.
All reactions