cfg as gradient experiment - unsuccessful #742

Open: wants to merge 2 commits into main
15 changes: 14 additions & 1 deletion aqt/jax/v2/aqt_dot_general.py
@@ -31,6 +31,8 @@
from aqt.jax.v2 import aqt_tensor
from aqt.jax.v2 import transpose
from aqt.jax.v2 import utils
from aqt.jax.v2.flax.delayed_scaling_calibration import \
    DelayedScalingCalibrationOWG
from aqt.jax.v2.numerics import fp8_numerics
import jax
from jax import lax
@@ -1206,7 +1208,18 @@ def grad_dot_general(
      cfg.drhs,
      True,
  )

  def update_differentiable_config_values(cfg):
    for dg in [cfg.fwd, cfg.dlhs, cfg.drhs]:
      for calibrator, calibration_config in (
          (dg.dg_quantizer.lhs._calibrator, dg.dg_quantizer.lhs.calibration_config),
          (dg.dg_quantizer.rhs._calibrator, dg.dg_quantizer.rhs.calibration_config),
      ):
        if isinstance(calibrator, DelayedScalingCalibrationOWG):
          calibration_config.amax_history = calibrator.amax_history
          calibration_config.bound = calibrator.bound
      dg.dg_quantizer.lhs._calibrator = None
      dg.dg_quantizer.rhs._calibrator = None
    return cfg

  # fwd_dimension_numbers are marked as nondiff_argnums instead of returning
  # None as grad to it. This is because it is a tuple of Python integers
  # that cannot be traced by Jax.
  return (dlhs, drhs, None, None, None)
  return (dlhs, drhs, None, None, update_differentiable_config_values(cfg))
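For context on the trick being tried here: `jax.custom_vjp` lets the backward rule return an arbitrary cotangent pytree for each differentiable argument, so state collected during the backward pass (here the calibrator's `amax_history` and `bound`) can be handed back to the caller in the gradient slot of `cfg`; `update_differentiable_config_values` copies that state into `cfg` and drops the calibrator references, presumably so that `cfg` remains a clean pytree of arrays. A minimal, self-contained sketch of the mechanism in isolation, not AQT code (all names and shapes below are illustrative):

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def scaled_matmul(x, w, cfg):
  # A real kernel would use cfg (e.g. for scaling); this toy one ignores it.
  return x @ w


def scaled_matmul_fwd(x, w, cfg):
  return x @ w, (x, w, cfg)


def scaled_matmul_bwd(res, g):
  x, w, cfg = res
  dx = g @ w.T
  dw = x.T @ g
  # Instead of returning zeros for cfg, return an updated copy: the caller can
  # read it back from jax.vjp / jax.grad and overwrite its own state with it.
  new_cfg = {**cfg, 'amax': jnp.max(jnp.abs(g)).astype(cfg['amax'].dtype)}
  return dx, dw, new_cfg


scaled_matmul.defvjp(scaled_matmul_fwd, scaled_matmul_bwd)

# The third "gradient" is the refreshed cfg, not a real derivative.
x, w = jnp.ones((4, 8)), jnp.ones((8, 2))
cfg = {'amax': jnp.zeros((), jnp.float32)}
out, vjp_fn = jax.vjp(scaled_matmul, x, w, cfg)
dx, dw, new_cfg = vjp_fn(jnp.ones_like(out))
```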
5 changes: 4 additions & 1 deletion aqt/jax/v2/aqt_quantizer.py
@@ -53,13 +53,16 @@ class Quantizer:
  # TODO(yichizh): Factor out auxiliary dataclasses into a separate file.
  context: utils.Context

  calibration_config: None | utils.CalibrationConfig = None


  # We need to speed up this initialization for the backward pass; ideally it
  # happens outside of the bwd pass.
  def init_calibration(self):
    assert self._calibrator is None, "second call to self.init_calibration()"
    if self.calibration is not None:
      self._calibrator = self.calibration(dtype=self.scale_dtype)
      self._calibrator.init_calibration()
      self._calibrator.init_calibration(self.calibration_config)

  # TODO(yichizh): Need to add type annotation back to cfg.
  def quant(
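`utils.CalibrationConfig` itself is not part of this diff. Given the fields used in `aqt_dot_general.py` above (`amax_history`, `bound`) and the fact that it has to travel through the backward pass, a plausible shape is a small pytree dataclass; the following is an assumption for illustration, not the actual definition:

```python
import flax.struct
import jax


@flax.struct.dataclass
class CalibrationConfig:
  # Hypothetical stand-in for utils.CalibrationConfig: array leaves, so the
  # object is a pytree whose state can ride along as a cotangent.
  amax_history: None | jax.Array = None
  bound: None | jax.Array = None
```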
2 changes: 1 addition & 1 deletion aqt/jax/v2/calibration.py
@@ -58,7 +58,7 @@ def get_scale_and_bias(
    # quantized value aligns with the minimum quantization bucket.
    pass

  def init_calibration(self):
  def init_calibration(self, calibration_config: None | utils.CalibrationConfig):
    pass


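The base-class signature change means concrete calibrators now receive the config when they are initialized. `DelayedScalingCalibrationOWG` is not shown in this diff; the sketch below only illustrates how such a calibrator might consume the new argument, resuming `amax_history` and `bound` when a config carried over from a previous step is passed in (class name, defaults, and shapes are invented):

```python
import jax.numpy as jnp


class DelayedScalingCalibratorSketch:
  """Illustrative calibrator; not the real DelayedScalingCalibrationOWG."""

  def __init__(self, dtype=jnp.float32, history_len=16):
    self.dtype = dtype
    self.history_len = history_len
    self.amax_history = None
    self.bound = None

  def init_calibration(self, calibration_config=None):
    if calibration_config is not None and calibration_config.amax_history is not None:
      # Resume from state threaded back through the backward pass.
      self.amax_history = calibration_config.amax_history
      self.bound = calibration_config.bound
    else:
      self.amax_history = jnp.zeros((self.history_len,), dtype=self.dtype)
      self.bound = jnp.ones((), dtype=self.dtype)
```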
1 change: 1 addition & 0 deletions aqt/jax/v2/examples/flax_e2e_model.py
@@ -184,6 +184,7 @@ def update_model_params_with_grads(state, grads, updated_var):
  updates, new_opt_state = state.tx.update(param_grad, state.opt_state, params)
  new_params = optax.apply_updates(params, updates)
  updated_var.update(params=new_params)
  updated_var['_overwrite_with_gradient'] = grads['_overwrite_with_gradient']
  return state.replace(
      model=updated_var,
      opt_state=new_opt_state,
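The added line follows the overwrite-with-gradient convention: for variables in the `_overwrite_with_gradient` collection the backward pass produces the variable's next value rather than a derivative, so the training step writes it back verbatim instead of running it through the optimizer. A toy contrast, not code from this example (names and values are made up):

```python
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3,))}
grads_params = {'w': jnp.full((3,), 0.5)}
grads_owg = {'amax_history': jnp.array([1.0, 0.0, 0.0, 0.0])}  # next value, not a derivative

# Regular parameters take an optimizer step...
tx = optax.sgd(0.1)
opt_state = tx.init(params)
updates, opt_state = tx.update(grads_params, opt_state, params)
params = optax.apply_updates(params, updates)  # w -> 1.0 - 0.1 * 0.5

# ...while OWG state is simply overwritten, with no learning rate involved.
owg = grads_owg
```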