| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Convert pytorch checkpoints to TensorFlow""" |
|
|
| import argparse |
| import os |
|
|
| from . import ( |
| AlbertConfig, |
| BartConfig, |
| BertConfig, |
| CamembertConfig, |
| CTRLConfig, |
| DistilBertConfig, |
| DPRConfig, |
| ElectraConfig, |
| FlaubertConfig, |
| GPT2Config, |
| LayoutLMConfig, |
| LxmertConfig, |
| OpenAIGPTConfig, |
| RobertaConfig, |
| T5Config, |
| TFAlbertForPreTraining, |
| TFBartForConditionalGeneration, |
| TFBartForSequenceClassification, |
| TFBertForPreTraining, |
| TFBertForQuestionAnswering, |
| TFBertForSequenceClassification, |
| TFCamembertForMaskedLM, |
| TFCTRLLMHeadModel, |
| TFDistilBertForMaskedLM, |
| TFDistilBertForQuestionAnswering, |
| TFDPRContextEncoder, |
| TFDPRQuestionEncoder, |
| TFDPRReader, |
| TFElectraForPreTraining, |
| TFFlaubertWithLMHeadModel, |
| TFGPT2LMHeadModel, |
| TFLayoutLMForMaskedLM, |
| TFLxmertForPreTraining, |
| TFLxmertVisualFeatureEncoder, |
| TFOpenAIGPTLMHeadModel, |
| TFRobertaForCausalLM, |
| TFRobertaForMaskedLM, |
| TFRobertaForSequenceClassification, |
| TFT5ForConditionalGeneration, |
| TFTransfoXLLMHeadModel, |
| TFWav2Vec2Model, |
| TFXLMRobertaForMaskedLM, |
| TFXLMWithLMHeadModel, |
| TFXLNetLMHeadModel, |
| TransfoXLConfig, |
| Wav2Vec2Config, |
| Wav2Vec2Model, |
| XLMConfig, |
| XLMRobertaConfig, |
| XLNetConfig, |
| is_torch_available, |
| load_pytorch_checkpoint_in_tf2_model, |
| ) |
| from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging |
|
|
|
|
| if is_torch_available(): |
| import numpy as np |
| import torch |
|
|
| from . import ( |
| AlbertForPreTraining, |
| BartForConditionalGeneration, |
| BertForPreTraining, |
| BertForQuestionAnswering, |
| BertForSequenceClassification, |
| CamembertForMaskedLM, |
| CTRLLMHeadModel, |
| DistilBertForMaskedLM, |
| DistilBertForQuestionAnswering, |
| DPRContextEncoder, |
| DPRQuestionEncoder, |
| DPRReader, |
| ElectraForPreTraining, |
| FlaubertWithLMHeadModel, |
| GPT2LMHeadModel, |
| LayoutLMForMaskedLM, |
| LxmertForPreTraining, |
| LxmertVisualFeatureEncoder, |
| OpenAIGPTLMHeadModel, |
| RobertaForMaskedLM, |
| RobertaForSequenceClassification, |
| T5ForConditionalGeneration, |
| TransfoXLLMHeadModel, |
| XLMRobertaForMaskedLM, |
| XLMWithLMHeadModel, |
| XLNetLMHeadModel, |
| ) |
|
|
|
|
# Switch the library logger to INFO verbosity for the duration of the script.
logging.set_verbosity_info()


# Registry of convertible architectures. Keys are either a model type or the identifier of one specific
# checkpoint. Each value is a tuple that starts with the configuration class and continues with the
# TensorFlow (`TF`-prefixed) model class(es) followed by the PyTorch model class(es).
# NOTE: entries do not all have the same length (3, 4 or 7 items), so they cannot be unpacked into a
# fixed number of names.
MODEL_CLASSES = {
    "bart": (
        BartConfig,
        TFBartForConditionalGeneration,
        TFBartForSequenceClassification,
        BartForConditionalGeneration,
    ),
    "bert": (BertConfig, TFBertForPreTraining, BertForPreTraining),
    "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad": (
        BertConfig,
        TFBertForQuestionAnswering,
        BertForQuestionAnswering,
    ),
    "google-bert/bert-large-cased-whole-word-masking-finetuned-squad": (
        BertConfig,
        TFBertForQuestionAnswering,
        BertForQuestionAnswering,
    ),
    "google-bert/bert-base-cased-finetuned-mrpc": (
        BertConfig,
        TFBertForSequenceClassification,
        BertForSequenceClassification,
    ),
    "dpr": (
        DPRConfig,
        TFDPRQuestionEncoder,
        TFDPRContextEncoder,
        TFDPRReader,
        DPRQuestionEncoder,
        DPRContextEncoder,
        DPRReader,
    ),
    "openai-community/gpt2": (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel),
    "xlnet": (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel),
    "xlm": (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel),
    "xlm-roberta": (XLMRobertaConfig, TFXLMRobertaForMaskedLM, XLMRobertaForMaskedLM),
    "transfo-xl": (TransfoXLConfig, TFTransfoXLLMHeadModel, TransfoXLLMHeadModel),
    "openai-community/openai-gpt": (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OpenAIGPTLMHeadModel),
    "roberta": (RobertaConfig, TFRobertaForCausalLM, TFRobertaForMaskedLM, RobertaForMaskedLM),
    "layoutlm": (LayoutLMConfig, TFLayoutLMForMaskedLM, LayoutLMForMaskedLM),
    "FacebookAI/roberta-large-mnli": (
        RobertaConfig,
        TFRobertaForSequenceClassification,
        RobertaForSequenceClassification,
    ),
    "camembert": (CamembertConfig, TFCamembertForMaskedLM, CamembertForMaskedLM),
    "flaubert": (FlaubertConfig, TFFlaubertWithLMHeadModel, FlaubertWithLMHeadModel),
    "distilbert": (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM),
    "distilbert-base-distilled-squad": (
        DistilBertConfig,
        TFDistilBertForQuestionAnswering,
        DistilBertForQuestionAnswering,
    ),
    "lxmert": (LxmertConfig, TFLxmertForPreTraining, LxmertForPreTraining),
    "lxmert-visual-feature-encoder": (LxmertConfig, TFLxmertVisualFeatureEncoder, LxmertVisualFeatureEncoder),
    "Salesforce/ctrl": (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel),
    "albert": (AlbertConfig, TFAlbertForPreTraining, AlbertForPreTraining),
    "t5": (T5Config, TFT5ForConditionalGeneration, T5ForConditionalGeneration),
    "electra": (ElectraConfig, TFElectraForPreTraining, ElectraForPreTraining),
    "wav2vec2": (Wav2Vec2Config, TFWav2Vec2Model, Wav2Vec2Model),
}
|
|
|
|
def _resolve_conversion_classes(model_type):
    """Pick the configuration, TensorFlow and PyTorch classes registered for `model_type`.

    `MODEL_CLASSES` entries do not share a single length: each one starts with the configuration class and is
    followed by one or more `TF`-prefixed TensorFlow classes and one or more PyTorch classes. Unpacking an entry
    into a fixed number of names therefore fails. Instead, the first PyTorch class that has a TensorFlow
    counterpart named `"TF" + <PyTorch class name>` in the same entry is returned together with that counterpart.

    Args:
        model_type: A key of `MODEL_CLASSES`.

    Returns:
        A `(config_class, tf_model_class, pt_model_class)` tuple.

    Raises:
        ValueError: If the entry contains no TensorFlow/PyTorch pair with matching names.
    """
    config_class, *model_classes = MODEL_CLASSES[model_type]
    tf_classes_by_name = {cls.__name__: cls for cls in model_classes if cls.__name__.startswith("TF")}
    for candidate in model_classes:
        tf_class = tf_classes_by_name.get(f"TF{candidate.__name__}")
        if tf_class is not None:
            return config_class, tf_class, candidate
    raise ValueError(f"No matching TensorFlow/PyTorch model classes registered for model type {model_type}.")


def convert_pt_checkpoint_to_tf(
    model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
):
    """Convert one PyTorch checkpoint into TensorFlow weights saved in HDF5 format.

    Args:
        model_type: Key of `MODEL_CLASSES` selecting the configuration and model classes.
        pytorch_checkpoint_path: Path to the PyTorch weights file. Anything that is not an existing local file is
            resolved through `cached_file` with `WEIGHTS_NAME`.
        config_file: Path to the JSON configuration file. Anything that is not an existing local file is resolved
            through `cached_file` with `CONFIG_NAME`.
        tf_dump_path: Destination passed to `save_weights` for the converted TensorFlow weights.
        compare_with_pt_model: When true, run both models on their dummy inputs and check that the first outputs
            agree within an absolute tolerance of 2e-2.
        use_cached_models: When false, `cached_file` is called with `force_download=True`.

    Raises:
        ValueError: If `model_type` is not registered or its entry has no usable TensorFlow/PyTorch class pair.
        AssertionError: If `compare_with_pt_model` is set and the outputs differ by more than 2e-2.
    """
    if model_type not in MODEL_CLASSES:
        raise ValueError(f"Unrecognized model type, should be one of {list(MODEL_CLASSES.keys())}.")

    config_class, model_class, pt_model_class = _resolve_conversion_classes(model_type)

    # Initialise the TF model. The registry no longer carries a map of known shortcut names, so any value that
    # is not a local file is handed to `cached_file` for resolution.
    if not os.path.isfile(config_file):
        config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
    config = config_class.from_json_file(config_file)
    config.output_hidden_states = True
    config.output_attentions = True
    print(f"Building TensorFlow model from configuration: {config}")
    tf_model = model_class(config)

    # Load the weights from the PyTorch checkpoint, resolving it the same way as the configuration.
    if not os.path.isfile(pytorch_checkpoint_path):
        pytorch_checkpoint_path = cached_file(
            pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
        )
    tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

    if compare_with_pt_model:
        tfo = tf_model(tf_model.dummy_inputs, training=False)

        state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu", weights_only=True)
        pt_model = pt_model_class.from_pretrained(
            pretrained_model_name_or_path=None, config=config, state_dict=state_dict
        )

        with torch.no_grad():
            pto = pt_model(**pt_model.dummy_inputs)

        np_pt = pto[0].numpy()
        np_tf = tfo[0].numpy()
        diff = np.amax(np.abs(np_pt - np_tf))
        print(f"Max absolute difference between models outputs {diff}")
        # Raised explicitly so the check survives `python -O`; written as `not (<=)` so a NaN difference fails.
        if not diff <= 2e-2:
            raise AssertionError(f"Error, model absolute difference is >2e-2: {diff}")

    # Save the converted TensorFlow weights.
    print(f"Save TensorFlow model to {tf_dump_path}")
    tf_model.save_weights(tf_dump_path, save_format="h5")
|
|
|
|
def _locate_checkpoint_file(name_or_path, filename, use_cached_models):
    """Return `(file_path, fetched)` for a checkpoint or configuration reference.

    An existing local file is returned unchanged. Anything else is resolved with `cached_file`. `fetched` is true
    only when `name_or_path` does not exist on the local filesystem, i.e. when the returned file was not supplied
    by the caller and may be deleted after conversion.
    """
    if os.path.isfile(name_or_path):
        return name_or_path, False
    fetched = not os.path.exists(name_or_path)
    return cached_file(name_or_path, filename, force_download=not use_cached_models), fetched


def convert_all_pt_checkpoints_to_tf(
    args_model_type,
    tf_dump_path,
    model_shortcut_names_or_path=None,
    config_shortcut_names_or_path=None,
    compare_with_pt_model=False,
    use_cached_models=False,
    remove_cached_files=False,
    only_convert_finetuned_models=False,
):
    """Convert a batch of PyTorch checkpoints to TensorFlow weights, one `-tf_model.h5` file per checkpoint.

    Args:
        args_model_type: Key of `MODEL_CLASSES` to convert, or `None` to iterate over every registered key.
        tf_dump_path: Directory under which the converted files are written.
        model_shortcut_names_or_path: List of checkpoint names or local weight files. When `None`, each model type
            is converted from the checkpoint named like its registry key (the registry carries no separate list of
            checkpoints).
        config_shortcut_names_or_path: List of configuration names or local files, paired with the checkpoints by
            position. When `None`, the checkpoint names are reused.
        compare_with_pt_model: Forwarded to `convert_pt_checkpoint_to_tf`.
        use_cached_models: When false, `cached_file` is called with `force_download=True`.
        remove_cached_files: Delete fetched configuration and weight files after each conversion. Files that the
            caller supplied from the local filesystem are never deleted.
        only_convert_finetuned_models: Convert only checkpoints whose name contains `-squad`, `-mrpc` or `-mnli`.
            When false, exactly those checkpoints are skipped.

    Raises:
        ValueError: If a model type is not a key of `MODEL_CLASSES`.
    """
    model_types = list(MODEL_CLASSES) if args_model_type is None else [args_model_type]

    for j, model_type in enumerate(model_types, start=1):
        print("=" * 100)
        print(f" Converting model type {j}/{len(model_types)}: {model_type}")
        print("=" * 100)
        if model_type not in MODEL_CLASSES:
            raise ValueError(f"Unrecognized model type {model_type}, should be one of {list(MODEL_CLASSES.keys())}.")

        # Work on per-iteration lists: writing the defaults back into the parameters would make every later
        # model type reuse the checkpoints computed for the first one.
        checkpoint_names = [model_type] if model_shortcut_names_or_path is None else model_shortcut_names_or_path
        config_names = checkpoint_names if config_shortcut_names_or_path is None else config_shortcut_names_or_path

        for i, (model_shortcut_name, config_shortcut_name) in enumerate(
            zip(checkpoint_names, config_names), start=1
        ):
            print("-" * 100)
            # Kept local so a fine-tuned checkpoint does not change the type used for the following checkpoints.
            checkpoint_model_type = model_type
            if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
                if not only_convert_finetuned_models:
                    print(f" Skipping finetuned checkpoint {model_shortcut_name}")
                    continue
                # Fine-tuned checkpoints have their own registry entry with task-specific heads; fall back to the
                # requested model type when the name (e.g. a local path) is not a registry key.
                if model_shortcut_name in MODEL_CLASSES:
                    checkpoint_model_type = model_shortcut_name
            elif only_convert_finetuned_models:
                print(f" Skipping not finetuned checkpoint {model_shortcut_name}")
                continue
            print(
                f" Converting checkpoint {i}/{len(checkpoint_names)}: {model_shortcut_name} - model_type"
                f" {checkpoint_model_type}"
            )
            print("-" * 100)

            config_file, config_fetched = _locate_checkpoint_file(
                config_shortcut_name, CONFIG_NAME, use_cached_models
            )
            model_file, model_fetched = _locate_checkpoint_file(model_shortcut_name, WEIGHTS_NAME, use_cached_models)

            # A local weights file has no meaningful short name, so use a fixed one for the output file.
            dump_name = "converted_model" if os.path.isfile(model_shortcut_name) else model_shortcut_name
            dump_file = os.path.join(tf_dump_path, dump_name + "-tf_model.h5")
            # Names containing "/" place the output in a sub-directory that may not exist yet.
            os.makedirs(os.path.dirname(dump_file) or ".", exist_ok=True)

            convert_pt_checkpoint_to_tf(
                model_type=checkpoint_model_type,
                pytorch_checkpoint_path=model_file,
                config_file=config_file,
                tf_dump_path=dump_file,
                compare_with_pt_model=compare_with_pt_model,
            )
            if remove_cached_files:
                # Only delete what this function fetched itself, never the caller's own local files.
                if config_fetched:
                    os.remove(config_file)
                if model_fetched:
                    os.remove(model_file)
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options and forward them to `convert_all_pt_checkpoints_to_tf`.
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        help=(
            f"Model type selected in the list of {list(MODEL_CLASSES.keys())}. If not given, will download and "
            "convert all the models from AWS."
        ),
    )
    parser.add_argument(
        "--pytorch_checkpoint_path",
        default=None,
        type=str,
        help=(
            "Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
            "If not given, will download and convert all the checkpoints from AWS."
        ),
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        help=(
            "The config json file corresponding to the pre-trained model. \n"
            "This specifies the model architecture. If not given and "
            "--pytorch_checkpoint_path is not given or is a shortcut name "
            "use the configuration associated to the shortcut name on the AWS"
        ),
    )
    parser.add_argument(
        "--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions."
    )
    parser.add_argument(
        "--use_cached_models",
        action="store_true",
        help="Use cached models if possible instead of updating to latest checkpoint versions.",
    )
    parser.add_argument(
        "--remove_cached_files",
        action="store_true",
        help="Remove pytorch models after conversion (save memory when converting in batches).",
    )
    parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
    args = parser.parse_args()

    # Match the requested model type case-insensitively. Lower-casing it unconditionally would make registry keys
    # that contain upper-case letters (e.g. "FacebookAI/roberta-large-mnli", "Salesforce/ctrl") unreachable.
    # Unknown values are still lower-cased, as before, and rejected by the conversion function.
    model_type = args.model_type
    if model_type is not None and model_type not in MODEL_CLASSES:
        keys_by_lowercase = {key.lower(): key for key in MODEL_CLASSES}
        model_type = keys_by_lowercase.get(model_type.lower(), model_type.lower())

    convert_all_pt_checkpoints_to_tf(
        model_type,
        args.tf_dump_path,
        model_shortcut_names_or_path=[args.pytorch_checkpoint_path]
        if args.pytorch_checkpoint_path is not None
        else None,
        config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
        compare_with_pt_model=args.compare_with_pt_model,
        use_cached_models=args.use_cached_models,
        remove_cached_files=args.remove_cached_files,
        only_convert_finetuned_models=args.only_convert_finetuned_models,
    )
|
|