Fix fp8 implementation, which had bit-rotted a bit
I only tested with "on-the-fly" bf16 -> fp8 conversion, not the "load
from fp8" codepath.

YAML I tested with:

```
providers:
  - provider_id: quantized
    provider_type: meta-reference-quantized
    config:
      model: Llama3.1-8B-Instruct
      quantization:
        type: fp8
```
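
For reference, the "on-the-fly" path quantizes the bf16 checkpoint weights to fp8 as they are loaded. Below is a minimal sketch of that idea in plain PyTorch, assuming PyTorch >= 2.1 for `torch.float8_e4m3fn`; the function name and the row-wise scaling scheme are illustrative and are not the provider's actual `quantize_fp8` helper.

```python
# Illustrative sketch (not the repo's quantize_fp8): convert a bf16 weight
# tensor to fp8 (e4m3) with a per-output-row scale.
import torch

FP8_E4M3_MAX = 448.0  # largest representable magnitude in e4m3


def quantize_rowwise_fp8(w_bf16: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (fp8 weight, per-row scale) so that w ~= w_fp8.float() * scale."""
    # One scale per output row keeps that row's largest element at the fp8 max.
    amax = w_bf16.abs().amax(dim=1, keepdim=True).float().clamp(min=1e-12)
    scale = amax / FP8_E4M3_MAX
    w_fp8 = (
        (w_bf16.float() / scale)
        .clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
        .to(torch.float8_e4m3fn)
    )
    return w_fp8, scale.squeeze(1)


if __name__ == "__main__":
    w = torch.randn(4096, 4096, dtype=torch.bfloat16)
    w_fp8, scale = quantize_rowwise_fp8(w)
    # Dequantize to check that the round-trip error stays small.
    err = (w_fp8.float() * scale[:, None] - w.float()).abs().max().item()
    print(w_fp8.dtype, scale.shape, f"max abs error {err:.4f}")
```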
ashwinb committed Oct 15, 2024 · 1 parent 80ada04 · commit 09b793c
Showing 2 changed files with 10 additions and 7 deletions.
File 1 of 2:

```diff
@@ -138,7 +138,7 @@ def build(
             else:
                 model = Transformer(model_args)
             model.load_state_dict(state_dict, strict=False)
-            model = convert_to_quantized_model(model, config)
+            model = convert_to_quantized_model(model, config, ckpt_dir)
         else:
             if torch.cuda.is_bf16_supported():
                 torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
@@ -228,8 +228,7 @@ def generate(
                 ignore_index=pad_id,
             )
 
-        stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
-
+        stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
         for cur_pos in range(min_prompt_len, total_len):
             if is_vision:
                 position_ids = torch.arange(
```
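
The second hunk above pins the stop-token tensor to the GPU instead of relying on the default tensor type, which the quantized load path does not necessarily leave pointing at a CUDA dtype. A small sketch of the failure mode this avoids, assuming a CUDA device is available; the token ids are made up, and `torch.isin` stands in for whatever comparison the generation loop does against the sampled tokens, which live on the GPU.

```python
# Sketch only: why stop_tokens needs device="cuda".
import torch

stop_token_ids = [128001, 128009]                       # hypothetical stop-token ids
next_token = torch.tensor([128009], device="cuda")      # sampled on the GPU

stop_cpu = torch.tensor(stop_token_ids)                 # no device= -> CPU under default settings
stop_gpu = torch.tensor(stop_token_ids, device="cuda")  # the fix: create on CUDA

try:
    torch.isin(next_token, stop_cpu)                    # CUDA vs CPU tensors -> error
except RuntimeError as exc:
    print("device mismatch:", exc)

print(torch.isin(next_token, stop_gpu))                 # tensor([True], device='cuda:0')
```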
File 2 of 2:

```diff
@@ -13,9 +13,10 @@
 import torch
 
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 
 from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 
+from llama_models.sku_list import resolve_model
 from termcolor import cprint
 from torch import Tensor
@@ -39,6 +40,7 @@ def swiglu_wrapper(
 def convert_to_quantized_model(
     model: Transformer,
     config: MetaReferenceQuantizedInferenceConfig,
+    checkpoint_dir: str,
     fp8_activation_scale_ub: Optional[float] = 1200.0,
 ) -> Transformer:
     if config.quantization.type == QuantizationType.bf16.value:
@@ -49,12 +51,14 @@ def convert_to_quantized_model(
     from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
 
-    checkpoint = config.checkpoint_config.checkpoint
+    llama_model = resolve_model(config.model)
+    assert llama_model is not None, f"Model {config.model} not found"
 
     # Move weights to GPU with quantization
-    if checkpoint.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
+    if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
         cprint("Loading fp8 scales...", "yellow")
         fp8_scales_path = os.path.join(
-            checkpoint.checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
+            checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt"
         )
         assert os.path.isfile(
             fp8_scales_path
```
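
The loader now takes the checkpoint directory explicitly and resolves the model's quantization format via `resolve_model(config.model)` instead of the old `config.checkpoint_config.checkpoint` object. Below is a sketch of the per-rank fp8-scales lookup that the new `checkpoint_dir` argument feeds into, assuming the `fp8_scales_<rank>.pt` layout implied by the path in the diff; `load_fp8_scales` is a hypothetical helper, not the provider's `load_fp8`.

```python
# Hypothetical helper illustrating the per-rank fp8 scales lookup; the real
# loader presumably hands these scales to load_fp8 while quantizing weights.
import os

import torch


def load_fp8_scales(checkpoint_dir: str, model_parallel_rank: int) -> dict:
    fp8_scales_path = os.path.join(
        checkpoint_dir, f"fp8_scales_{model_parallel_rank}.pt"
    )
    assert os.path.isfile(
        fp8_scales_path
    ), f"fp8 scales not found at {fp8_scales_path}"
    # The scales file is small; load it on CPU and move entries to the GPU as
    # the corresponding weights are quantized.
    return torch.load(fp8_scales_path, map_location="cpu")


# Usage sketch (the rank normally comes from get_model_parallel_rank()):
# scales = load_fp8_scales(checkpoint_dir, get_model_parallel_rank())
```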
