From d5a8576f15197617243adf6d6c688aff53e40b0f Mon Sep 17 00:00:00 2001 From: richard3983 Date: Fri, 22 May 2020 17:38:14 -0400 Subject: [PATCH] Load model from model-zoo --- models/bert/__main__.py | 9 ++++----- models/bert/args.py | 3 ++- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/models/bert/__main__.py b/models/bert/__main__.py index 144010c..9a47abf 100644 --- a/models/bert/__main__.py +++ b/models/bert/__main__.py @@ -71,8 +71,7 @@ def evaluate_split(model, processor, tokenizer, args, split='dev'): args.is_hierarchical = False processor = dataset_map[args.dataset]() - pretrained_vocab_path = PRETRAINED_VOCAB_ARCHIVE_MAP[args.model] - tokenizer = BertTokenizer.from_pretrained(pretrained_vocab_path) + tokenizer = BertTokenizer.from_pretrained(args.model) train_examples = None num_train_optimization_steps = None @@ -81,8 +80,8 @@ def evaluate_split(model, processor, tokenizer, args, split='dev'): num_train_optimization_steps = int( len(train_examples) / args.batch_size / args.gradient_accumulation_steps) * args.epochs - pretrained_model_path = args.model if os.path.isfile(args.model) else PRETRAINED_MODEL_ARCHIVE_MAP[args.model] - model = BertForSequenceClassification.from_pretrained(pretrained_model_path, num_labels=args.num_labels) + pretrained_model = args.model + model = BertForSequenceClassification.from_pretrained(pretrained_model, num_labels=args.num_labels) if args.fp16: model.half() @@ -126,7 +125,7 @@ def evaluate_split(model, processor, tokenizer, args, split='dev'): model = torch.load(trainer.snapshot_path) else: - model = BertForSequenceClassification.from_pretrained(pretrained_model_path, num_labels=args.num_labels) + model = BertForSequenceClassification.from_pretrained(pretrained_model, num_labels=args.num_labels) model_ = torch.load(args.trained_model, map_location=lambda storage, loc: storage) state = {} for key in model_.state_dict().keys(): diff --git a/models/bert/args.py b/models/bert/args.py index 477249e..895160f 100644 --- a/models/bert/args.py +++ b/models/bert/args.py @@ -6,7 +6,8 @@ def get_args(): parser = models.args.get_args() - parser.add_argument('--model', default=None, type=str, required=True) + parser.add_argument('--model', default=None, type=str, required=True, + choices=['bert-base-uncased','bert-large-uncased','bert-base-cased','bert-large-cased']) parser.add_argument('--dataset', type=str, default='SST-2', choices=['SST-2', 'AGNews', 'Reuters', 'AAPD', 'IMDB', 'Yelp2014']) parser.add_argument('--save-path', type=str, default=os.path.join('model_checkpoints', 'bert')) parser.add_argument('--cache-dir', default='cache', type=str)