How to quantize ViT model with quantization aware training #374

Open
Linsop2 opened this issue Oct 30, 2024 · 3 comments
Labels: bug (Something isn't working)

Comments


Linsop2 commented Oct 30, 2024

I can train the ViT model from Hugging Face Transformers,
but converting it to a TFLite model raises an error that I can't resolve.
My TinyNN settings and the error message are below.

Transformers version is 4.26.0

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from tinynn.graph.tracer import import_patcher
from tinynn.converter import TFLiteConverter
from tinynn.graph.quantization.quantizer import QATQuantizer

with import_patcher():
    from transformers import ViTForImageClassification

device = 'cuda'

transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

class ViTWrapper(nn.Module):
    def __init__(self, vit):
        super().__init__()
        self.vit = vit

    def forward(self, x):
        return self.vit(x).logits
    
model_url = 'google/vit-base-patch16-224'

with import_patcher():
    model = ViTForImageClassification.from_pretrained(model_url)
    Vit = ViTWrapper(model)

################ QAT SET #######################
dummy = torch.rand([1,3,224,224])
mean = (0.485 + 0.456 + 0.406) / 3 * 255 
std = (0.229 + 0.224 + 0.225) / 3 * 255
config={'asymmetric': True, 'per_tensor': False, 'backend':'fbgemm',"rewrite_graph": False,'quantized_input_stats':[(float(mean), float(std))]}
quantizer = QATQuantizer(Vit, dummy, work_dir='out',config=config)
Vit = quantizer.quantize()
Vit.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
Vit = torch.quantization.prepare_qat(Vit)
Vit.cuda()

#####Skip  training  Step#####

##### Convert tflite Step#####
    if((epoch) % 1 == 0):
        with torch.no_grad():
            device = 'cpu'
            Vit.eval()
            Vit.to(device)
            Vit = quantizer.convert(Vit)
            torch.backends.quantized.engine = quantizer.backend
            converter = TFLiteConverter(Vit, dummy, tflite_path='./qat_model.tflite',fuse_quant_dequant=True)
            converter.convert()
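
For reference, the `quantized_input_stats` values in the snippet above are the ImageNet normalization constants averaged over the three channels and scaled to the 0-255 pixel range, which gives a mean of about 114.5 and a std of about 57.6. The sketch below reproduces that arithmetic and shows what the pair implies for a uint8 input under the TensorFlow Lite convention `real = (quantized - mean) / std`. Assuming TinyNN interprets the pair the same way is an assumption here, and the helper names `channel_average_stats` and `dequantize` are hypothetical.

```python
# ImageNet normalization constants, as used in the transform above.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def channel_average_stats(mean, std, scale=255.0):
    """Average per-channel stats and rescale them to the 0-255 pixel range."""
    avg_mean = sum(mean) / len(mean) * scale
    avg_std = sum(std) / len(std) * scale
    return avg_mean, avg_std


def dequantize(q, mean, std):
    """TensorFlow Lite convention (assumed here): real = (q - mean) / std."""
    return (q - mean) / std


mean, std = channel_average_stats(IMAGENET_MEAN, IMAGENET_STD)

# Real-valued range covered by a uint8 input under these stats.
low = dequantize(0, mean, std)
high = dequantize(255, mean, std)

print(f"mean={mean:.3f} std={std:.3f} range=[{low:.3f}, {high:.3f}]")
```

Because the three channels share one averaged pair, the implied range only approximates the per-channel normalization done by `transforms.Normalize`.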

The error message:

Traceback (most recent call last):
  File "c:\Users\user\Desktop\ViT\tinynn_PTQ\QAT.py", line 100, in <module>
    converter.convert()
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\tinynn\converter\base.py", line 512, in convert
    self.init_jit_graph()
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\tinynn\converter\base.py", line 237, in init_jit_graph
    script = torch.jit.trace(self.model, self.dummy_input)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\jit\_trace.py", line 1000, in trace
    traced_func = _trace_impl(
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\jit\_trace.py", line 695, in _trace_impl
    return trace_module(
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\jit\_trace.py", line 1275, in trace_module
    module._c._create_method_from_trace(
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "c:\Users\user\Desktop\ViT\tinynn_PTQ\QAT.py", line 39, in forward
    return self.vit(x).logits
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\transformers\models\vit\modeling_vit.py", line 789, in forward
    outputs = self.vit(
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\transformers\models\vit\modeling_vit.py", line 571, in forward
    expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
AttributeError: 'function' object has no attribute 'dtype'
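
The last frame reads `self.embeddings.patch_embeddings.projection.weight.dtype`, and the message says that `weight` resolved to a plain Python function instead of a tensor at that point. The traceback does not show what replaced the attribute. The sketch below only reproduces the message in isolation, using hypothetical stand-ins (`Projection`, `FakeTensor`, and `weight_fn` are not part of transformers or TinyNN).

```python
class FakeTensor:
    """Stand-in for a tensor: it only carries a dtype attribute."""
    dtype = "float32"


class Projection:
    """Stand-in for a layer whose `weight` is expected to be tensor-like."""

    def __init__(self, weight):
        self.weight = weight


def expected_dtype(projection):
    # Same access pattern as the failing line: <layer>.weight.dtype
    return projection.weight.dtype


def weight_fn():
    """A function sitting where a tensor was expected."""
    return FakeTensor()


# Normal case: weight is tensor-like, so .dtype resolves.
ok_dtype = expected_dtype(Projection(FakeTensor()))

# Failing case: weight is a function, which has no .dtype attribute.
try:
    expected_dtype(Projection(weight_fn))
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

print(ok_dtype)
print(error_message)
```

A quick check in the real script would be to print `type(...)` of that same `weight` attribute just before calling `converter.convert()`.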
peterjc123 added the bug and question labels and removed the bug label on Oct 30, 2024
peterjc123 (Collaborator) commented

I just noticed that you are not using TinyNN's quantized graph rewrite: your code sets "rewrite_graph": False and calls torch.quantization.prepare_qat. Just FYI, the TFLiteConverter won't work in most cases if you don't use our quantization tools.


Linsop2 commented Oct 31, 2024

I modified the QAT settings, but now a different error appears:

Traceback (most recent call last):
  File "c:\Users\user\Desktop\ViT\QAT.py", line 91, in <module>
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\tinynn\converter\base.py", line 513, in convert
    self.init_lowered_module()
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\tinynn\converter\base.py", line 299, in init_lowered_module
    torch._C._jit_pass_lower_all_tuples(graph)
RuntimeError: prim::TupleUnpack not matched to tuple construct

Here is my full code:

import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from tqdm import tqdm
from tinynn.graph.tracer import import_patcher
from tinynn.converter import TFLiteConverter
from tinynn.graph.quantization.quantizer import QATQuantizer
from torch.optim.lr_scheduler import StepLR

with import_patcher():
    from transformers import ViTForImageClassification

BATCH_SIZE = 16
LEARNING_RATE = 0.0003
EPOCH = 1
N_CLASSES = 1000

device = 'cuda'
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])

trainData = dsets.ImageFolder(r"C:\Users\user\Desktop\dataset\train", transform)
testData =  dsets.ImageFolder(r"C:\Users\user\Desktop\dataset\test", transform)
trainLoader = torch.utils.data.DataLoader(dataset=trainData, batch_size=BATCH_SIZE, shuffle=True)
testLoader  = torch.utils.data.DataLoader(dataset=testData, batch_size=BATCH_SIZE, shuffle=False)

class ViTWrapper(nn.Module):
    def __init__(self, vit):
        super().__init__()
        self.vit = vit

    def forward(self, x):
        return self.vit(x).logits
    
model_url = 'google/vit-base-patch16-224'

with import_patcher():
    model = ViTForImageClassification.from_pretrained(model_url)
    Vit = ViTWrapper(model)
        
print(f'Learning rate is : {LEARNING_RATE}')
print(f'Batch size is : {BATCH_SIZE}')

################ QAT SET #######################
dummy = torch.rand([1,3,224,224])
config={'asymmetric': True, 'per_tensor': False, 'backend':'fbgemm','rewrite_graph':True}
quantizer = QATQuantizer(Vit, dummy, work_dir='out',config=config)
Vit = quantizer.quantize()
Vit.cuda()
################ QAT SET #######################

################ Training #######################
# Loss, Optimizer & Scheduler
cost = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(Vit.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=5e-4)
scheduler = StepLR(optimizer, step_size=1, gamma=0.1)

# Train the model
for epoch in range(1,EPOCH+1):
    Vit.train()
    avg_loss = 0
    cnt = 0
    progress_bar = tqdm(trainLoader, desc=f"Epoch {epoch}/{EPOCH}", leave=False)
    for images, labels in progress_bar:
        images = images.cuda()
        labels = labels.cuda()
        optimizer.zero_grad()
        outputs = Vit(images)
        loss = cost(outputs, labels)
        avg_loss += loss.item()
        cnt += 1
        loss.backward()
        optimizer.step()
        progress_bar.set_postfix({'loss': f"{loss.item():.4f}", 'avg_loss': f"{avg_loss/cnt:.4f}"})
    
    print(f"[Epoch: {epoch}] avg_loss: {avg_loss/cnt:.4f}")
    scheduler.step()
    torch.save(Vit.state_dict(), f'./epoch_qat{epoch}.pth')

################### Convert model #########################
    if((epoch) % EPOCH == 0):
        with torch.no_grad():
            Vit.eval()
            Vit.cpu()
            Vit = quantizer.convert(Vit)
            torch.backends.quantized.engine = quantizer.backend
            converter = TFLiteConverter(Vit, dummy, tflite_path='./qat_model.tflite')
            converter.convert()

peterjc123 added the bug label and removed the question label on Oct 31, 2024
peterjc123 (Collaborator) commented

We will take a look.
