-
-
Notifications
You must be signed in to change notification settings - Fork 227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initial jax+dask example. #158
base: main
Are you sure you want to change the base?
Conversation
Check out this pull request on Review Jupyter notebook visual diffs & provide feedback on notebooks. Powered by ReviewNB |
machine-learning/jax-haiku-dask-dataframe-distributed-example.ipynb
Outdated
Show resolved
Hide resolved
…ipynb Co-authored-by: Matthew Rocklin <[email protected]>
" df_one_partition = ddf_one_partition.compute()\n", | ||
" scaled_x = jnp.array(df_one_partition[[\"scaled_x\"]].values)\n", | ||
" y = jnp.array(df_one_partition[[\"y\"]].values)\n", | ||
" params, opt_state = update(params, opt_state, scaled_x, y)" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be worth taking a look at some of the functionality in dask-ml, which might do some of these things for you already if you're interested.
cc'ing @stsievert and @TomAugspurger
" futures = []\n", | ||
" for ddf_one_partition in ddf_train.partitions:\n", | ||
" # Compute the gradients in parallel\n", | ||
" futures.append(client.submit(dask_compute_grads_one_partition_wrapper, ddf_one_partition, params))\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I recommend instead ...
from dask.distributed import futures_of
futures = futures_of(df.map_partitions(func, **params).persist())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I've tried this but .map_partitions()
requires you to return either a Dask.DataFrame or Dask.Series (I think?). My function returns a set of gradients, grads
, which is a Python dictionary (with more python dicts inside, i.e. a tree-like structure), so I don't think this will work in this case (please correct me if I am mistaken).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can probably work around that with to_delayed()
instead of map_partitions. I can take a closer look later.
machine-learning/jax-haiku-dask-dataframe-distributed-example.ipynb
Outdated
Show resolved
Hide resolved
" # Bring the gradients back to the client, and update the model with the optimizer on the client\n", | ||
" grads = future.result()\n", | ||
" updates, opt_state = optimizer.update(grads, opt_state)\n", | ||
" params = optix.apply_updates(params, updates)" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is also the kind of thing for which Actors is probably a decent fit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I've been trying to think how to perform training with shared parameters (and optimizer state) among workers via Actors. Haven't quite got my head around how this might work yet.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might be a start: https://docs.dask.org/en/latest/futures.html#example-parameter-server
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That example doesn't run, maybe a bad merge. I've put in a PR to correct that: dask/dask#6449
This notebook example is a learning exercise during the Scipy2020 Dask sprint to establish how dask might be used to parallelize jax/dm-haiku deep learning model training and prediction.
I've committed my notebook that is working end-to-end, and demonstrates a neural network for learning the sine function.