diff --git a/PARAMETERS.md b/PARAMETERS.md
new file mode 100644
index 00000000..94d63798
--- /dev/null
+++ b/PARAMETERS.md
@@ -0,0 +1,87 @@
+## LoraConfig Parameters
+
+Adjusting the `LoraConfig` parameters allows you to balance model quality and computational efficiency in Low-Rank Adaptation (LoRA). Here's a concise breakdown of the key parameters; a short configuration sketch follows the list.
+
+**r**
+- **Description**: Rank of the low-rank decomposition used to factorize the weight updates.
+- **Impact**:
+  - **Higher**: Retains more information but increases the number of trainable parameters and the computational load.
+  - **Lower**: Fewer parameters and more efficient training, with a potential performance drop if set too small.
+
+**lora_alpha**
+- **Description**: Scaling factor for the low-rank matrices' contribution.
+- **Impact**:
+  - **Higher**: Stronger adapter influence and faster convergence, but risks instability or overfitting.
+  - **Lower**: Subtler effect; may require more training steps.
+
+**lora_dropout**
+- **Description**: Probability of zeroing out elements of the low-rank matrices for regularization.
+- **Impact**:
+  - **Higher**: More regularization, which helps prevent overfitting but may slow training and degrade performance.
+  - **Lower**: Less regularization; may speed up training but risks overfitting.
+
+**loftq_config**
+- **Description**: Configuration for LoftQ, a method that quantizes the backbone weights and initializes the LoRA layers accordingly.
+- **Impact**:
+  - **Not None**: LoftQ quantizes the backbone weights and initializes the LoRA layers. Requires setting `init_lora_weights='loftq'`.
+  - **None**: LoftQ quantization is not applied.
+  - **Note**: Do not pass an already quantized model when using LoftQ, as LoftQ handles the quantization itself.
+
+**use_rslora**
+- **Description**: Enables Rank-Stabilized LoRA (rsLoRA).
+- **Impact**:
+  - **True**: Sets the adapter scaling factor to `lora_alpha/math.sqrt(r)`, which was shown to work better in the [Rank-Stabilized LoRA paper](https://doi.org/10.48550/arXiv.2312.03732).
+  - **False**: Uses the original default scaling factor `lora_alpha/r`.
+
+The next three parameters belong to the trainer's `TrainingArguments` rather than `LoraConfig`, but they are commonly tuned alongside the settings above:
+
+**gradient_accumulation_steps**
+- **Default**: 1
+- **Description**: The number of steps over which gradients are accumulated before performing an optimizer update.
+- **Impact**:
+  - **Higher**: Accumulates gradients over more steps, increasing the effective batch size without requiring additional memory. This can improve training stability and convergence, especially with large models and limited hardware.
+  - **Lower**: More frequent optimizer updates with a smaller effective batch size; gradient estimates are noisier, which can make training less stable.
+
+**weight_decay**
+- **Default**: 0.01
+- **Description**: Regularization technique that applies a small penalty to the weights during training.
+- **Impact**:
+  - **Non-zero value (e.g., 0.01)**: Adds a penalty proportional to the magnitude of the weights to the loss function, helping to prevent overfitting by discouraging large weights.
+  - **Zero**: No weight decay is applied, which can lead to overfitting, especially in large models or with small datasets.
+
+**learning_rate**
+- **Default**: 2e-4
+- **Description**: The step size used when updating the model's parameters during training.
+- **Impact**:
+  - **Higher**: Faster convergence, but risks overshooting the optimum and destabilizing training.
+  - **Lower**: More stable and precise updates, but convergence is slower and more training steps may be needed to reach good performance.
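+A minimal configuration sketch is shown below, assuming the usual Unsloth workflow (`FastLanguageModel` plus a Hugging Face `TrainingArguments`). The model name and all numeric values are illustrative placeholders, not tuned recommendations:
+
+```python
+from unsloth import FastLanguageModel
+from transformers import TrainingArguments
+
+# Load a 4-bit base model (any supported checkpoint works here; values are illustrative).
+model, tokenizer = FastLanguageModel.from_pretrained(
+    model_name = "unsloth/llama-3-8b-bnb-4bit",
+    max_seq_length = 2048,
+    load_in_4bit = True,
+)
+
+# Attach LoRA adapters; these keyword arguments mirror the parameters described above.
+model = FastLanguageModel.get_peft_model(
+    model,
+    r = 16,                    # rank of the low-rank decomposition
+    lora_alpha = 16,           # scaling factor for the adapter update
+    lora_dropout = 0,          # dropout applied to the low-rank matrices
+    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
+                      "gate_proj", "up_proj", "down_proj"],  # described in the next section
+    use_rslora = False,        # True switches the scaling to lora_alpha / sqrt(r)
+    loftq_config = None,       # pass a LoftQ config to quantize and initialize with LoftQ
+)
+
+# gradient_accumulation_steps, weight_decay and learning_rate are trainer settings, not LoraConfig fields.
+training_args = TrainingArguments(
+    per_device_train_batch_size = 2,
+    gradient_accumulation_steps = 4,   # effective batch size = 2 * 4 = 8
+    learning_rate = 2e-4,
+    weight_decay = 0.01,
+    max_steps = 60,
+    output_dir = "outputs",
+)
+```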
+
+## Target Modules
+
+These are the linear projection layers inside each transformer block that LoRA adapters are usually attached to via `target_modules` (as in the sketch above).
+
+**q_proj (query projection)**
+- **Description**: Part of the attention mechanism in transformer models, responsible for projecting the input into the query space.
+- **Impact**: Transforms the input into query vectors that are used to compute attention scores.
+
+**k_proj (key projection)**
+- **Description**: Projects the input into the key space of the attention mechanism.
+- **Impact**: Produces key vectors that are compared with query vectors to determine attention weights.
+
+**v_proj (value projection)**
+- **Description**: Projects the input into the value space of the attention mechanism.
+- **Impact**: Produces value vectors that are weighted by the attention scores and combined to form the output.
+
+**o_proj (output projection)**
+- **Description**: Projects the output of the attention mechanism back into the model's hidden dimension.
+- **Impact**: Transforms the combined, attention-weighted value vectors back to the input dimension, integrating attention results into the model.
+
+**gate_proj (gate projection)**
+- **Description**: The gating projection of the gated feed-forward (e.g. SwiGLU) block used in LLaMA-style transformers.
+- **Impact**: Controls how much of the up-projected signal flows through the MLP, allowing selective information passage based on learned weights.
+
+**up_proj (up projection)**
+- **Description**: The up-projection of the feed-forward block, expanding the hidden size to the larger intermediate size.
+- **Impact**: Expands the input to a higher-dimensional space in which the feed-forward non-linearity is applied.
+
+**down_proj (down projection)**
+- **Description**: The down-projection of the feed-forward block, reducing the intermediate representation back to the hidden size.
+- **Impact**: Compresses the input back to the lower-dimensional hidden space, keeping computational cost and model size under control.
diff --git a/unsloth/__init__.py b/unsloth/__init__.py index d85eca00..93960e2f 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -14,8 +14,20 @@ import os import warnings import importlib +import sys +from packaging.version import Version -# Currently only supports 1 GPU, or else seg faults will occur. +# Define a list of modules to check +MODULES_TO_CHECK = ["peft", "bitsandbytes"] + +# Check if any of the modules in the list have been imported +for module in MODULES_TO_CHECK: + if module in sys.modules: + raise ImportError(f"Unsloth: Please import Unsloth before {module}.") + pass +pass + +# Currently only supports 1 GPU, or else seg faults will occur.
if "CUDA_VISIBLE_DEVICES" in os.environ: os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" devices = os.environ["CUDA_VISIBLE_DEVICES"] @@ -66,8 +78,14 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Try loading bitsandbytes and triton import bitsandbytes as bnb + import triton -from triton.common.build import libcuda_dirs +libcuda_dirs = lambda: None +if Version(triton.__version__) >= Version("3.0.0"): + try: from triton.backends.nvidia.driver import libcuda_dirs + except: pass +else: from triton.common.build import libcuda_dirs + import os import re import numpy as np @@ -103,8 +121,11 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 importlib.reload(bnb) importlib.reload(triton) try: - import bitsandbytes as bnb - from triton.common.build import libcuda_dirs + libcuda_dirs = lambda: None + if Version(triton.__version__) >= Version("3.0.0"): + try: from triton.backends.nvidia.driver import libcuda_dirs + except: pass + else: from triton.common.build import libcuda_dirs cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 libcuda_dirs() except: diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index 4c782326..2e3761f5 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -1286,7 +1286,7 @@ def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf") pass for prompt in prompts: - command = f"./llama.cpp/main -m {gguf_model} -n 0 --temp 0.0 --verbose-prompt "\ + command = f"./llama.cpp/llama-cli -m {gguf_model} -n 0 --temp 0.0 --verbose-prompt "\ f"--check-tensors -p '{prompt}'" datas = [] diff --git a/unsloth/kernels/__init__.py b/unsloth/kernels/__init__.py index b1fdba83..ebea02af 100644 --- a/unsloth/kernels/__init__.py +++ b/unsloth/kernels/__init__.py @@ -24,6 +24,7 @@ ) from .fast_lora import ( get_lora_parameters, + get_lora_parameters_bias, apply_lora_mlp_swiglu, apply_lora_mlp_geglu_exact, apply_lora_mlp_geglu_approx, diff --git a/unsloth/kernels/fast_lora.py b/unsloth/kernels/fast_lora.py index aba44f02..8f7aea58 100644 --- a/unsloth/kernels/fast_lora.py +++ b/unsloth/kernels/fast_lora.py @@ -13,7 +13,13 @@ # limitations under the License. import torch -from .utils import fast_dequantize, QUANT_STATE, get_lora_parameters, matmul_lora +from .utils import ( + fast_dequantize, + QUANT_STATE, + get_lora_parameters, + get_lora_parameters_bias, + matmul_lora, +) class LoRA_MLP(torch.autograd.Function): diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 3bc091b3..de1e2e57 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -33,11 +33,8 @@ def _get_model_name(model_name, load_in_4bit = True): - # First try replacing lowercase 'b' with uppercase 'B' - model_name = model_name.lower() - if not SUPPORTS_FOURBIT and model_name in INT_TO_FLOAT_MAPPER: - model_name = INT_TO_FLOAT_MAPPER[model_name] + model_name = INT_TO_FLOAT_MAPPER[model_name.lower()] logger.warning_once( f"Unsloth: Your transformers version of {transformers_version} does not support native "\ f"4bit loading.\nThe minimum required version is 4.37.\n"\ @@ -47,7 +44,7 @@ def _get_model_name(model_name, load_in_4bit = True): ) elif not load_in_4bit and model_name in INT_TO_FLOAT_MAPPER: - new_model_name = INT_TO_FLOAT_MAPPER[model_name] + new_model_name = INT_TO_FLOAT_MAPPER[model_name.lower()] # logger.warning_once( # f"Unsloth: You passed in `{model_name}` which is a 4bit model, yet you set\n"\ # f"`load_in_4bit = False`. We shall load `{new_model_name}` instead." 
@@ -55,7 +52,7 @@ def _get_model_name(model_name, load_in_4bit = True): model_name = new_model_name elif load_in_4bit and SUPPORTS_FOURBIT and model_name in FLOAT_TO_INT_MAPPER: - new_model_name = FLOAT_TO_INT_MAPPER[model_name] + new_model_name = FLOAT_TO_INT_MAPPER[model_name.lower()] # logger.warning_once( # f"Unsloth: You passed in `{model_name}` and `load_in_4bit = True`.\n"\ # f"We shall load `{new_model_name}` for 4x faster loading." @@ -70,17 +67,18 @@ def _get_model_name(model_name, load_in_4bit = True): class FastLanguageModel(FastLlamaModel): @staticmethod def from_pretrained( - model_name = "unsloth/llama-3-8b-bnb-4bit", - max_seq_length = None, - dtype = None, - load_in_4bit = True, - token = None, - device_map = "sequential", - rope_scaling = None, - fix_tokenizer = True, - trust_remote_code = False, - use_gradient_checkpointing = True, - resize_model_vocab = None, + model_name = "unsloth/llama-3-8b-bnb-4bit", + max_seq_length = None, + dtype = None, + load_in_4bit = True, + token = None, + device_map = "sequential", + rope_scaling = None, + fix_tokenizer = True, + trust_remote_code = False, + use_gradient_checkpointing = "unsloth", + resize_model_vocab = None, + revision = None, *args, **kwargs, ): if token is None and "HF_TOKEN" in os.environ: @@ -95,12 +93,12 @@ def from_pretrained( # First check if it's a normal model via AutoConfig is_peft = False try: - model_config = AutoConfig.from_pretrained(model_name, token = token) + model_config = AutoConfig.from_pretrained(model_name, token = token, revision = revision) is_peft = False except: try: # Most likely a PEFT model - peft_config = PeftConfig.from_pretrained(model_name, token = token) + peft_config = PeftConfig.from_pretrained(model_name, token = token, revision = revision) except: raise RuntimeError(f"Unsloth: `{model_name}` is not a full model or a PEFT model.") @@ -143,22 +141,24 @@ def from_pretrained( pass model, tokenizer = dispatch_model.from_pretrained( - model_name = model_name, - max_seq_length = max_seq_length, - dtype = dtype, - load_in_4bit = load_in_4bit, - token = token, - device_map = device_map, - rope_scaling = rope_scaling, - fix_tokenizer = fix_tokenizer, - model_patcher = dispatch_model, - tokenizer_name = tokenizer_name, + model_name = model_name, + max_seq_length = max_seq_length, + dtype = dtype, + load_in_4bit = load_in_4bit, + token = token, + device_map = device_map, + rope_scaling = rope_scaling, + fix_tokenizer = fix_tokenizer, + model_patcher = dispatch_model, + tokenizer_name = tokenizer_name, trust_remote_code = trust_remote_code, + revision = revision if not is_peft else None, *args, **kwargs, ) if resize_model_vocab is not None: model.resize_token_embeddings(resize_model_vocab) + pass # In case the model supports tagging, add the unsloth tag. if hasattr(model, "add_model_tags"): @@ -188,8 +188,16 @@ def from_pretrained( pass if is_peft: + # From https://github.com/huggingface/peft/issues/184 # Now add PEFT adapters - model = PeftModel.from_pretrained(model, old_model_name, token = token) + model.enable_input_require_grads() + model = PeftModel.from_pretrained( + model, + old_model_name, + token = token, + revision = revision, + is_trainable = True, + ) # Patch it as well! 
model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing) pass diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 73aa06ca..5ef75839 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -186,6 +186,9 @@ "unsloth/Qwen2-70B-Instruct-bnb-4bit" : ( "Qwen/Qwen2-70B-Instruct", ), + "mistralai/Codestral-22B-v0.1" : ( + "mistral-community/Codestral-22B-v0.1", + ), } INT_TO_FLOAT_MAPPER = {} diff --git a/unsloth/save.py b/unsloth/save.py index 3ad2f346..cae59cae 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -22,7 +22,7 @@ import pickle import gc from transformers.models.llama.modeling_llama import logger -from .kernels import fast_dequantize, QUANT_STATE, get_lora_parameters +from .kernels import fast_dequantize, QUANT_STATE, get_lora_parameters_bias import subprocess import psutil import re @@ -132,9 +132,10 @@ def _free_cached_model(model): def _merge_lora(layer, name): + bias = None if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit, Peft_Linear)): # Is LoRA so we need to merge! - W, quant_state, A, B, s = get_lora_parameters(layer) + W, quant_state, A, B, s, bias = get_lora_parameters_bias(layer) if quant_state is not None: dtype = quant_state.dtype if type(quant_state) is not list else quant_state[2] W = fast_dequantize(W, quant_state) @@ -156,7 +157,7 @@ def _merge_lora(layer, name): W = W.t().to(dtype) else: W = layer.weight - return W + return W, bias pass @@ -527,7 +528,12 @@ def unsloth_save_model( for item in LLAMA_WEIGHTS: proj = eval(f"layer.{item}") name = f"model.layers.{j}.{item}.weight" - W = _merge_lora(proj, name) + W, bias = _merge_lora(proj, name) + + # Bias term + if bias is not None: + state_dict[f"model.layers.{j}.{item}.bias"] = bias + pass if (torch.cuda.memory_allocated() + W.nbytes) < max_vram: # Save to GPU memory @@ -643,7 +649,8 @@ def unsloth_save_model( model.config = new_config # Save! - + + save_pretrained_settings["selected_adapters"] = None # Check if pushing to an organization if save_pretrained_settings["push_to_hub"] and (username != actual_username): print(f"Unsloth: Saving to organization with address {new_save_directory}") @@ -785,7 +792,7 @@ def install_llama_cpp_old(version = -10): pass pass # Check if successful - if not os.path.exists("llama.cpp/quantize"): + if not os.path.exists("llama.cpp/quantize") and not os.path.exists("llama.cpp/llama-quantize"): raise RuntimeError( "Unsloth: llama.cpp GGUF seems to be too buggy to install.\n"\ "File a report to llama.cpp's main repo since this is not an Unsloth issue." @@ -794,7 +801,7 @@ def install_llama_cpp_old(version = -10): pass -def install_llama_cpp_blocking(use_cuda = True): +def install_llama_cpp_blocking(use_cuda = False): # https://github.com/ggerganov/llama.cpp/issues/7062 # Weirdly GPU conversion for GGUF breaks?? # use_cuda = "LLAMA_CUDA=1" if use_cuda else "" @@ -822,49 +829,6 @@ def install_llama_cpp_blocking(use_cuda = True): pass -def _fix_gemma_gguf(): - # Fixes Gemma saving to GGUF to float32 instead of float16! 
- with open("llama.cpp/convert-hf-to-gguf.py", "rb") as file: - text = file.read() - pass - - gemma_start = text.find(b"class GemmaModel(Model):") - if gemma_start == -1: return - - gemma_end = text.find(b"self.gguf_writer.add_tensor(new_name, data)", gemma_start) - if gemma_end == -1: return - - gemma_text = text[gemma_start : gemma_end] - bad_text = \ -b""" data = data.astype(np.float32) - - # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: - data = data.astype(np.float16)""" - good_text = \ -b""" # if f32 desired, convert any float16 to float32 - if self.ftype == 0 and data_dtype == np.float16: - data = data.astype(np.float32) - - # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 - if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: - data = data.astype(np.float32) - - # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: - data = data.astype(np.float16)""" - find_bad = gemma_text.find(bad_text) - if find_bad == -1: return - - gemma_text = gemma_text[:find_bad] + good_text + gemma_text[find_bad + len(bad_text):] - text = text[:gemma_start] + gemma_text + text[gemma_end:] - - with open("llama.cpp/convert-hf-to-gguf.py", "w+b") as file: - file.write(text) - pass -pass - - def save_to_gguf( model_type : str, model_dtype : str, @@ -930,7 +894,7 @@ def save_to_gguf( # Check first_conversion format if first_conversion == "f16" : pass - if first_conversion == "bf16" : pass + elif first_conversion == "bf16" : pass elif first_conversion == "f32" : pass elif first_conversion == "q8_0" : pass else: @@ -946,8 +910,20 @@ def save_to_gguf( error = 0 install_llama_cpp_blocking() pass + # Check if successful. If not install 10th latest release - if error != 0 or not os.path.exists("llama.cpp/quantize"): + + # Careful llama.cpp/quantize changed to llama.cpp/llama-quantize + # and llama.cpp/main changed to llama.cpp/llama-cli + # See https://github.com/ggerganov/llama.cpp/pull/7809 + quantize_location = None + if os.path.exists("llama.cpp/quantize"): + quantize_location = "llama.cpp/quantize" + elif os.path.exists("llama.cpp/llama-quantize"): + quantize_location = "llama.cpp/llama-quantize" + pass + + if error != 0 or quantize_location is None: print(f"Unsloth: llama.cpp error code = {error}.") install_llama_cpp_old(-10) pass @@ -1017,9 +993,6 @@ def save_to_gguf( f"--outfile {final_location} --vocab-type {vocab_type} "\ f"--outtype {first_conversion} --concurrency {n_cpus} --pad-vocab" else: - # Need to fix convert-hf-to-gguf.py for some models! - # _fix_gemma_gguf() - command = f"python llama.cpp/convert-hf-to-gguf.py {model_directory} "\ f"--outfile {final_location} "\ f"--outtype {first_conversion}" @@ -1065,7 +1038,7 @@ def save_to_gguf( print(f"Unsloth: [2] Converting GGUF 16bit into {quantization_method}. 
This will take 20 minutes...") final_location = f"./{model_directory}-unsloth.{quantization_method.upper()}.gguf" - command = f"./llama.cpp/quantize {old_location} "\ + command = f"./{quantize_location} {old_location} "\ f"{final_location} {quantization_method} {n_cpus}" # quantize uses stderr @@ -1654,6 +1627,140 @@ def unsloth_push_to_hub_gguf( pass pass +# Corrected function to save LoRA to a custom directory +def save_lora_to_custom_dir(model, tokenizer, save_directory): + # Create the custom directory if it doesn't exist + os.makedirs(save_directory, exist_ok=True) + + # Call the unsloth_save_model function with the custom directory + unsloth_save_model( + model, + tokenizer, + save_directory=save_directory, + save_method="lora", + push_to_hub=False, + ) + +# Corrected method within the model class to convert LoRA to GGML and push to Hugging Face Hub +def unsloth_convert_lora_to_ggml_and_push_to_hub( + self, + tokenizer, + repo_id: str, + use_temp_dir: Optional[bool] = None, + commit_message: Optional[str] = "Converted LoRA to GGML with Unsloth", + private: Optional[bool] = None, + token: Union[bool, str, None] = None, + create_pr: bool = False, + revision: str = None, + commit_description: str = "Convert LoRA to GGML format using Unsloth", + temporary_location: str = "_unsloth_temporary_saved_buffers", + maximum_memory_usage: float = 0.85, +): + if not os.path.exists("llama.cpp"): + if IS_KAGGLE_ENVIRONMENT: + python_install = install_python_non_blocking(["protobuf"]) + python_install.wait() + install_llama_cpp_blocking(use_cuda=False) + makefile = None + else: + git_clone = install_llama_cpp_clone_non_blocking() + python_install = install_python_non_blocking(["protobuf"]) + git_clone.wait() + makefile = install_llama_cpp_make_non_blocking() + python_install.wait() + else: + makefile = None + + for _ in range(3): + gc.collect() + + lora_directory_push = "lora-to-ggml-push" + save_lora_to_custom_dir(self, tokenizer, lora_directory_push) + + model_type = self.config.model_type + output_file = os.path.join(lora_directory_push, "ggml-adapter-model.bin") + + print(f"Unsloth: Converting auto-saved LoRA adapters at {lora_directory_push} to GGML format.") + print(f"The output file will be {output_file}") + + command = f"python3 llama.cpp/convert-lora-to-ggml.py {lora_directory_push} {output_file} llama" + + try: + with subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1, universal_newlines=True) as sp: + for line in sp.stdout: + print(line, end="", flush=True) + for line in sp.stderr: + print(line, end="", flush=True) + sp.wait() + if sp.returncode != 0: + raise subprocess.CalledProcessError(sp.returncode, command) + except subprocess.CalledProcessError as e: + print(f"Error: Conversion failed with return code {e.returncode}") + return + + print(f"Unsloth: Conversion completed! Output file: {output_file}") + + print("Unsloth: Uploading GGML file to Hugging Face Hub...") + username = upload_to_huggingface( + self, repo_id, token, + "GGML converted LoRA", "ggml", output_file, None, private, + ) + link = f"{repo_id.lstrip('/')}" + print("Unsloth: Done.") + print(f"Converted LoRA to GGML and uploaded to https://huggingface.co/{link}") + print("\nThis GGML making function was made by Maheswar. 
Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!") + +def unsloth_convert_lora_to_ggml_and_save_locally( + self, + save_directory: str, # Added parameter for the folder name + tokenizer, + temporary_location: str = "_unsloth_temporary_saved_buffers", + maximum_memory_usage: float = 0.85, +): + if not os.path.exists("llama.cpp"): + if IS_KAGGLE_ENVIRONMENT: + python_install = install_python_non_blocking(["protobuf"]) + python_install.wait() + install_llama_cpp_blocking(use_cuda=False) + makefile = None + else: + git_clone = install_llama_cpp_clone_non_blocking() + python_install = install_python_non_blocking(["protobuf"]) + git_clone.wait() + makefile = install_llama_cpp_make_non_blocking() + python_install.wait() + else: + makefile = None + + for _ in range(3): + gc.collect() + + # Use the provided save_directory for local saving + save_lora_to_custom_dir(self, tokenizer, save_directory) + + model_type = self.config.model_type + output_file = os.path.join(save_directory, "ggml-adapter-model.bin") + + print(f"Unsloth: Converting auto-saved LoRA adapters at {save_directory} to GGML format.") + print(f"The output file will be {output_file}") + + command = f"python3 llama.cpp/convert-lora-to-ggml.py {save_directory} {output_file} llama" + + try: + with subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1, universal_newlines=True) as sp: + for line in sp.stdout: + print(line, end="", flush=True) + for line in sp.stderr: + print(line, end="", flush=True) + sp.wait() + if sp.returncode != 0: + raise subprocess.CalledProcessError(sp.returncode, command) + except subprocess.CalledProcessError as e: + print(f"Error: Conversion failed with return code {e.returncode}") + return + print("Unsloth: Done.") + print(f"Unsloth: Conversion completed! Output file: {output_file}") + print("\nThis GGML making function was made by Maheswar. 
Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!") def patch_saving_functions(model): import inspect @@ -1746,10 +1853,12 @@ def patch_saving_functions(model): # Add saving methods to top level model if hasattr(model, "config"): # Counteract tokenizers - model.push_to_hub_merged = types.MethodType(unsloth_push_to_hub_merged, model) - model.save_pretrained_merged = types.MethodType(unsloth_save_pretrained_merged, model) - model.push_to_hub_gguf = types.MethodType(unsloth_push_to_hub_gguf, model) - model.save_pretrained_gguf = types.MethodType(unsloth_save_pretrained_gguf, model) + model.push_to_hub_merged = types.MethodType(unsloth_push_to_hub_merged, model) + model.save_pretrained_merged = types.MethodType(unsloth_save_pretrained_merged, model) + model.push_to_hub_gguf = types.MethodType(unsloth_push_to_hub_gguf, model) + model.save_pretrained_gguf = types.MethodType(unsloth_save_pretrained_gguf, model) + model.push_to_hub_ggml = types.MethodType(unsloth_convert_lora_to_ggml_and_push_to_hub, model) + model.save_pretrained_ggml = types.MethodType(unsloth_convert_lora_to_ggml_and_save_locally, model) pass return model pass diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index f10b2c0a..395c3b73 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -232,62 +232,6 @@ def convert_to_fast_tokenizer( "{% endif %}"\ "{% endfor %}" pass - - -def select_correct_slow_tokenizer( - tokenizer_name, - model_max_length = None, - padding_side = "right", - token = None, - trust_remote_code = False, - cache_dir = "huggingface_tokenizers_cache", -): - """ - Returns 'correct' tokenizer by checking if the chat templates are - actually tokenized correctly. - """ - messages = [ - {"role": "user", "content": "What is 2+2?"}, - {"role": "assistant", "content": "It's 4."}, - ] - - settings = ( - (False, False, True,), - (False, True, True,), - (True, False, True,), - (True, False, False,), - ) - - for (use_fast, legacy, from_slow,) in settings: - # Default as mentioned by Arthur from HF: - slow_tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, - model_max_length = model_max_length, - padding_side = padding_side, - token = token, - trust_remote_code = trust_remote_code, - # Cannot just use use_fast = False as per https://twitter.com/danielhanchen/status/1789659394302718373 - use_fast = use_fast, - legacy = legacy, - from_slow = from_slow, - cache_dir = cache_dir, - ) - slow_tokenizer_chat_template = slow_tokenizer.chat_template - - slow_tokenizer.chat_template = llama_template - result1 = slow_tokenizer.decode(slow_tokenizer.apply_chat_template(messages)) - slow_tokenizer.chat_template = mistral_template - result2 = slow_tokenizer.decode(slow_tokenizer.apply_chat_template(messages)) - - # If 2 spaces seen, normally wrong! 
- if " "*2 not in result1 and " "*2 not in result2: - slow_tokenizer.chat_template = slow_tokenizer_chat_template - return slow_tokenizer - pass - pass - # Return fast version as default - return slow_tokenizer -pass def assert_same_tokenization(slow_tokenizer, fast_tokenizer): @@ -508,13 +452,17 @@ def load_correct_tokenizer( # Mainly to solve Deepseek models with no tokenizer.model file slow_tokenizer = None try: - slow_tokenizer = select_correct_slow_tokenizer( + slow_tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, - model_max_length = model_max_length, - padding_side = padding_side, - token = token, + model_max_length = model_max_length, + padding_side = padding_side, + token = token, trust_remote_code = trust_remote_code, - cache_dir = cache_dir, + # Cannot just use use_fast = False as per https://twitter.com/danielhanchen/status/1789659394302718373 + use_fast = False, + legacy = False, + from_slow = True, + cache_dir = cache_dir, ) except: pass @@ -786,7 +734,7 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps = 1e-16): pass # Count all the possible bad tokens - final_counts = np.zeros(len(tokenizer), dtype = np.int64) + final_counts = np.zeros(max(len(tokenizer), embedding_matrix.shape[0]), dtype = np.int64) def mapping(examples): input_ids = examples["input_ids"] counter = np.fromiter(itertools.chain.from_iterable(input_ids), dtype = np.int32) @@ -972,7 +920,7 @@ def patch_sft_trainer_tokenizer(): check_text = \ "\n"\ - "test_text = dataset[0][dataset_text_field] if (formatting_func is None or not use_formatting_func) else formatting_func(dataset[0])\n"\ + "test_text = dataset[0][dataset_text_field] if (formatting_func is None or not use_formatting_func) else formatting_func(dataset[0])[0]\n"\ "chat_template = getattr(tokenizer, 'chat_template', None)\n"\ "chat_template = '' if chat_template is None else chat_template\n"\ "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\