To reproduce, first download the Flax 2b-it model checkpoint from Kaggle: https://www.kaggle.com/models/google/gemma/flax/2b-it
Clone the repository and install the dependencies:
git clone https://github.com/google-deepmind/gemma.git
cd gemma
python3 -m venv .
./bin/pip install jax-metal absl-py sentencepiece orbax chex flax
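Before running the sampler, it can help to confirm that JAX in this venv actually picked up the jax-metal plugin and can run a float32 computation at all. The following is a minimal sketch that assumes the venv above is active. The variable names are illustrative, and on a machine without the Metal plugin JAX will simply report its CPU backend.

```python
import jax
import jax.numpy as jnp

# Which backend did JAX select? With jax-metal installed on Apple silicon this
# should be the Metal plugin; without it, JAX falls back to the CPU backend.
backend = jax.default_backend()
devices = jax.devices()
print("default backend:", backend)
for d in devices:
    print(" ", d.platform, d.device_kind)

# Tiny float32 matmul on the default device as a smoke test of float32 support.
x = jnp.ones((2, 2), dtype=jnp.float32)
result = x @ x
print(result)
```

Save it to a file and run it with ./bin/python3 so it uses the venv's interpreter and packages.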
Patch gemma/params.py so that the parameters are cast to float32 when loaded:
sed -i.bu 's/param_state = jax.tree_util.tree_map(jnp.array, params)/param_state = jax.tree_util.tree_map(lambda p: jnp.array(p, jnp.float32), params)/' gemma/params.py
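The sed command rewrites the tree_map call in gemma/params.py so that every leaf of the loaded parameter tree is converted to float32 instead of keeping the dtype it was stored in. The sketch below shows the before/after behaviour on a hypothetical toy parameter tree with bfloat16 leaves standing in for the real checkpoint. The tree layout and the bfloat16 dtype are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the loaded checkpoint: a nested dict of bfloat16 arrays.
toy_params = {
    "embedder": {"input_embedding": jnp.ones((4, 8), dtype=jnp.bfloat16)},
    "layer_0": {"attn": {"w": jnp.zeros((8, 8), dtype=jnp.bfloat16)}},
}

# Unpatched behaviour: jnp.array copies each leaf but keeps its dtype.
original = jax.tree_util.tree_map(jnp.array, toy_params)

# Patched behaviour (what the sed command substitutes): cast every leaf to float32.
patched = jax.tree_util.tree_map(lambda p: jnp.array(p, jnp.float32), toy_params)

for name, tree in (("original", original), ("patched", patched)):
    print(name, [leaf.dtype.name for leaf in jax.tree_util.tree_leaves(tree)])
```

The tree structure is unchanged; only the leaf dtypes differ. A float32 element takes 4 bytes versus 2 for bfloat16, so the patched tree needs twice the memory of a bfloat16 one.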
Run the sampling example and observe the segfault (adjust both paths to point at the checkpoint and tokenizer downloaded in the first step):
PYTHONPATH=$(pwd) ./bin/python3 examples/sampling.py --path_checkpoint ~/models/gemma_2b_it/2b-it --path_tokenizer ~/models/gemma_2b_it/tokenizer.model