Source code for paddlenlp.transformers.model_utils

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import io
import json
import os
import six
import logging
import inspect

import paddle
from paddle.nn import Layer
# TODO(fangzeyang) Temporary fix and replace by paddle framework downloader later
from paddlenlp.utils.downloader import get_path_from_url
from paddlenlp.utils.env import MODEL_HOME
from paddlenlp.utils.log import logger

from .generation_utils import GenerationMixin
from .utils import InitTrackerMeta, fn_args_to_dict

__all__ = [
    'PretrainedModel',
    'register_base_model',
]


def register_base_model(cls):
    """
    Class decorator that records `cls` as the base model of its architecture.

    The first direct base class of `cls` (which must be a subclass of
    `PretrainedModel`) receives a `base_model_class` attribute pointing at
    `cls`, so that every class derived from that parent can look up the
    base model class of the architecture.

    Args:
        cls (class): The model class being decorated.

    Returns:
        class: `cls` itself, unchanged.
    """
    # Only the first entry of `__bases__` is inspected.
    parent_cls = cls.__bases__[0]
    assert issubclass(
        parent_cls, PretrainedModel
    ), "`register_base_model` should be used on subclasses of PretrainedModel."
    setattr(parent_cls, "base_model_class", cls)
    return cls
@six.add_metaclass(InitTrackerMeta)
class PretrainedModel(Layer, GenerationMixin):
    """
    The base class for all pretrained models. It provides some attributes and
    common methods for all pretrained models, including attributes `init_config`,
    `config` for initialized arguments and methods for saving, loading.

    It also includes some class attributes (should be set by derived classes):

    - `model_config_file` (str): represents the file name for saving and loading
      model configuration, its value is `model_config.json`.
    - `resource_files_names` (dict): use this to map resources to specific file
      names for saving and loading.
    - `pretrained_resource_files_map` (dict): The dict has the same keys as
      `resource_files_names`, the values are also dict mapping specific pretrained
      model name to URL linking to pretrained model.
    - `pretrained_init_configuration` (dict): The dict has pretrained model names
      as keys, and the values are also dict preserving corresponding configuration
      for model initialization.
    - `base_model_prefix` (str): represents the attribute associated to the
      base model in derived classes of the same architecture adding layers on
      top of the base model.

    NOTE(review): `init_config` is read by `save_pretrained` but is not assigned
    anywhere in this class; presumably it is set by `InitTrackerMeta` -- confirm.
    """
    model_config_file = "model_config.json"
    pretrained_init_configuration = {}
    # TODO: more flexible resource handle, namedtuple with fields as:
    # resource_name, saved_file, handle_name_for_load(None for used as __init__
    # arguments), handle_name_for_save
    resource_files_names = {"model_state": "model_state.pdparams"}
    pretrained_resource_files_map = {}
    base_model_prefix = ""

    def _wrap_init(self, original_init, *args, **kwargs):
        """
        It would be hooked after `__init__` to add a dict including arguments of
        `__init__` as an attribute named `config` of the pretrained model instance.

        Args:
            original_init (callable): The `__init__` function whose arguments
                are being recorded.
            *args (tuple): Positional arguments `__init__` was called with
                (without `self`).
            **kwargs (dict): Keyword arguments `__init__` was called with.
        """
        init_dict = fn_args_to_dict(original_init, *((self, ) + args), **kwargs)
        self.config = init_dict

    @property
    def base_model(self):
        """
        The attribute named by `base_model_prefix`, or the model itself when
        no such attribute exists.
        """
        return getattr(self, self.base_model_prefix, self)

    @property
    def model_name_list(self):
        """list: The keys of `pretrained_init_configuration`."""
        return list(self.pretrained_init_configuration.keys())

    def get_input_embeddings(self):
        """
        Delegate to the base model's `get_input_embeddings`.

        Raises:
            NotImplementedError: If this model is its own base model, i.e. no
                attribute named `base_model_prefix` exists on it.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        raise NotImplementedError

    def get_output_embeddings(self):
        """Return None. Overwrite for models with output embeddings."""
        return None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """
        Instantiate an instance of `PretrainedModel` from a predefined
        model specified by name or path.

        Args:
            pretrained_model_name_or_path (str): A name of or a file path to a
                pretrained model.
            *args (tuple): position arguments for `__init__`. If provide, use
                this as position argument values for model initialization.
            **kwargs (dict): keyword arguments for `__init__`. If provide, use
                this to update pre-defined keyword argument values for model
                initialization. If the key is in base model `__init__`, update
                keyword argument of base model; else update keyword argument of
                derived model.

        Returns:
            PretrainedModel: An instance of PretrainedModel. When
            `paddle.in_dynamic_mode()` is false, a tuple `(model, state_dict)`
            is returned instead, where `state_dict` is the state that was
            passed to `set_state_dict`.

        Raises:
            ValueError: If `pretrained_model_name_or_path` is neither a known
                model name nor an existing directory, or if a derived-model
                configuration contains no nested base-model configuration.
        """
        pretrained_models = list(cls.pretrained_init_configuration.keys())
        resource_files = {}
        init_configuration = {}
        if pretrained_model_name_or_path in pretrained_models:
            for file_id, map_list in cls.pretrained_resource_files_map.items():
                resource_files[file_id] = map_list[
                    pretrained_model_name_or_path]
            init_configuration = copy.deepcopy(
                cls.pretrained_init_configuration[
                    pretrained_model_name_or_path])
        elif os.path.isdir(pretrained_model_name_or_path):
            for file_id, file_name in cls.resource_files_names.items():
                resource_files[file_id] = os.path.join(
                    pretrained_model_name_or_path, file_name)
            resource_files["model_config_file"] = os.path.join(
                pretrained_model_name_or_path, cls.model_config_file)
        else:
            raise ValueError(
                "Calling {}.from_pretrained() with a model identifier or the "
                "path to a directory instead. The supported model "
                "identifiers are as follows: {}".format(
                    cls.__name__, cls.pretrained_init_configuration.keys()))

        default_root = os.path.join(MODEL_HOME, pretrained_model_name_or_path)
        resolved_resource_files = {}
        for file_id, file_path in resource_files.items():
            # The None check must come before any string operation on
            # `file_path`; otherwise a None entry raises AttributeError.
            if file_path is None or os.path.isfile(file_path):
                resolved_resource_files[file_id] = file_path
                continue
            path = os.path.join(default_root, file_path.split('/')[-1])
            if os.path.exists(path):
                logger.info("Already cached %s" % path)
                resolved_resource_files[file_id] = path
            else:
                logger.info("Downloading %s and saved to %s" %
                            (file_path, default_root))
                resolved_resource_files[file_id] = get_path_from_url(
                    file_path, default_root)

        # Prepare model initialization kwargs
        # Did we saved some inputs and kwargs to reload ?
        model_config_file = resolved_resource_files.pop("model_config_file",
                                                        None)
        if model_config_file is not None:
            with io.open(model_config_file, encoding="utf-8") as f:
                init_kwargs = json.load(f)
        else:
            init_kwargs = init_configuration
        # position args are stored in kwargs, maybe better not include
        init_args = init_kwargs.pop("init_args", ())
        # class name corresponds to this configuration
        # (`base_model_class` is attached by the `register_base_model` decorator)
        base_class_name = cls.base_model_class.__name__
        init_class = init_kwargs.pop("init_class", base_class_name)

        # Check if the loaded config matches the current model class's __init__
        # arguments. If not match, the loaded config is for the base model class.
        # `base_arg_index` is None (base config is the whole config), an int
        # (position inside `init_args`) or a str (name inside `init_kwargs`).
        base_arg_index = None
        if init_class == base_class_name:
            base_args = init_args
            base_kwargs = init_kwargs
            derived_args = ()
            derived_kwargs = {}
        else:
            # extract config for base model
            derived_args = list(init_args)
            derived_kwargs = init_kwargs
            base_arg = None
            mismatch_msg = "pretrained base model should be {}".format(
                base_class_name)
            for i, arg in enumerate(init_args):
                if isinstance(arg, dict) and "init_class" in arg:
                    # Pop outside the assert so the key is removed even when
                    # asserts are stripped by `python -O`.
                    nested_class = arg.pop("init_class")
                    assert nested_class == base_class_name, mismatch_msg
                    base_arg_index = i
                    base_arg = arg
                    break
            if base_arg is None:
                for arg_name, arg in init_kwargs.items():
                    if isinstance(arg, dict) and "init_class" in arg:
                        nested_class = arg.pop("init_class")
                        assert nested_class == base_class_name, mismatch_msg
                        base_arg_index = arg_name
                        base_arg = arg
                        break
            if base_arg is None:
                raise ValueError(
                    "Cannot find the configuration of base model {} in the "
                    "configuration of {}.".format(base_class_name, init_class))
            base_args = base_arg.pop("init_args", ())
            base_kwargs = base_arg

        if cls == cls.base_model_class:
            # Update with newly provided args and kwargs for base model
            base_args = base_args if not args else args
            base_kwargs.update(kwargs)
            model = cls(*base_args, **base_kwargs)
        else:
            # Update with newly provided args and kwargs for derived model
            base_parameters_dict = inspect.signature(
                cls.base_model_class.__init__).parameters
            for k, v in kwargs.items():
                if k in base_parameters_dict:
                    base_kwargs[k] = v
            base_model = cls.base_model_class(*base_args, **base_kwargs)
            if base_arg_index is None:
                derived_args = (base_model, )  # assume at the first position
            elif isinstance(base_arg_index, int):
                derived_args[base_arg_index] = base_model
            else:
                # The base config was stored as a keyword argument: replace it
                # by the instantiated base model under the same name.
                derived_kwargs[base_arg_index] = base_model
            derived_args = derived_args if not args else args
            derived_parameters_dict = inspect.signature(
                cls.__init__).parameters
            for k, v in kwargs.items():
                if k in derived_parameters_dict:
                    derived_kwargs[k] = v
            model = cls(*derived_args, **derived_kwargs)

        # Maybe need more ways to load resources.
        weight_path = list(resolved_resource_files.values())[0]
        assert weight_path.endswith(
            ".pdparams"), "suffix of weight must be .pdparams"

        state_dict = paddle.load(weight_path)

        # Make sure we are able to load base models as well as derived models
        # (with heads)
        start_prefix = ""
        model_to_load = model
        state_to_load = state_dict
        unexpected_keys = []
        missing_keys = []
        if not hasattr(model, cls.base_model_prefix) and any(
                s.startswith(cls.base_model_prefix)
                for s in state_dict.keys()):
            # base model
            state_to_load = {}
            start_prefix = cls.base_model_prefix + "."
            for k, v in state_dict.items():
                if k.startswith(cls.base_model_prefix):
                    state_to_load[k[len(start_prefix):]] = v
                else:
                    unexpected_keys.append(k)
        if hasattr(model, cls.base_model_prefix) and not any(
                s.startswith(cls.base_model_prefix)
                for s in state_dict.keys()):
            # derived model (base model with heads)
            model_to_load = getattr(model, cls.base_model_prefix)
            for k in model.state_dict().keys():
                if not k.startswith(cls.base_model_prefix):
                    missing_keys.append(k)
        if len(missing_keys) > 0:
            logger.info(
                "Weights of {} not initialized from pretrained model: {}".
                format(model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".
                        format(model.__class__.__name__, unexpected_keys))
        model_to_load.set_state_dict(state_to_load)
        if paddle.in_dynamic_mode():
            return model
        return model, state_to_load

    def save_pretrained(self, save_directory):
        """
        Save model configuration and related resources (model state) to files
        under `save_directory`.

        The configuration is written as JSON to `model_config_file` and the
        state dict is saved under the first file name of
        `resource_files_names`.

        Args:
            save_directory (str): Directory to save files into. It must
                already exist.
        """
        assert os.path.isdir(
            save_directory
        ), "Saving directory ({}) should be a directory".format(save_directory)
        # save model config
        model_config_file = os.path.join(save_directory, self.model_config_file)
        # Build a separate dict instead of rewriting `self.init_config`, so
        # that saving does not replace the sub-models held in the live
        # configuration by plain dicts.
        model_config = {}
        # If init_config contains a Layer, use the layer's init_config to save
        for key, value in self.init_config.items():
            if key == "init_args":
                model_config[key] = tuple(
                    arg.init_config if isinstance(arg, PretrainedModel) else arg
                    for arg in value)
            elif isinstance(value, PretrainedModel):
                model_config[key] = value.init_config
            else:
                model_config[key] = value
        with io.open(model_config_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(model_config, ensure_ascii=False))
        # save model
        file_name = os.path.join(save_directory,
                                 list(self.resource_files_names.values())[0])
        paddle.save(self.state_dict(), file_name)