# Source code for paddlenlp.transformers.utils
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import paddle
from paddle.nn import Layer
def fn_args_to_dict(func, *args, **kwargs):
    """
    Inspect function `func` and the arguments it would be called with, and
    build a dict mapping argument names to their effective values.

    Precedence, lowest to highest: declared defaults of positional-or-keyword
    parameters, values passed positionally in `args`, values in `kwargs`.

    Args:
        func (callable): The function whose signature is inspected.
        *args: Positional argument values, matched in order to `func`'s named
            parameters. Values beyond the named parameters (e.g. ones that
            would land in a `*args` of `func`) are dropped.
        **kwargs: Keyword argument values. All of them are included, even
            names that are not parameters of `func`.

    Returns:
        dict: Argument name -> value. Keys appear in this order: positional
        arguments, then remaining defaulted parameters, then any other
        keyword arguments. Defaults of keyword-only parameters are not
        included.
    """
    if hasattr(inspect, 'getfullargspec'):
        (spec_args, _, _, spec_defaults, _, _,
         _) = inspect.getfullargspec(func)
    else:
        # Python 2 fallback. `getargspec` was removed in Python 3.11, but
        # `getfullargspec` always exists on Python 3, so this is not reached.
        (spec_args, _, _, spec_defaults) = inspect.getargspec(func)
    # add positional argument values
    init_dict = dict(zip(spec_args, args))
    # Fill in declared defaults only for parameters not already supplied
    # positionally, so a default never overwrites an explicit value.
    if spec_defaults:
        defaulted_names = spec_args[-len(spec_defaults):]
        for arg_name, default in zip(defaulted_names, spec_defaults):
            init_dict.setdefault(arg_name, default)
    # explicit keyword arguments take precedence over everything else
    init_dict.update(kwargs)
    return init_dict
class InitTrackerMeta(type(Layer)):
    """
    Metaclass that records the constructor configuration of its instances.

    Every class created with this metaclass gets its `__init__` replaced by
    the wrapper built in `init_and_track_conf`. After the real `__init__`
    returns, the wrapper stores the call's keyword arguments in
    `self.init_config`. It adds the positional arguments under the key
    `init_args` (only when there are any) and the class name under the key
    `init_class`.

    When a class body defines its own `__init__` and the class has a
    `_wrap_init` attribute, that callable is invoked right after `__init__`
    as `_wrap_init(self, init_func, *args, **kwargs)`.

    The base is `type(Layer)` rather than plain `type`. Classes using this
    metaclass (e.g. pretrained model classes) are `Layer` subclasses, and
    Layer's own metaclass is not `type`. Deriving from it avoids a metaclass
    conflict.
    """

    def __init__(cls, name, bases, attrs):
        original_init = cls.__init__
        # Attach the `_wrap_init` hook only when this class body defines its
        # own `__init__`. An inherited `__init__` is assumed to have been
        # wrapped (and hooked) already, when the parent class was created by
        # this metaclass.
        # TODO: remove the redundant tracker added when `__init__` is
        # inherited from the parent class.
        post_init_hook = None
        if '__init__' in attrs:
            post_init_hook = getattr(cls, '_wrap_init', None)
        cls.__init__ = InitTrackerMeta.init_and_track_conf(
            original_init, post_init_hook)
        super(InitTrackerMeta, cls).__init__(name, bases, attrs)

    @staticmethod
    def init_and_track_conf(init_func, help_func=None):
        """
        Wrap `init_func` so that calls record their arguments in an
        `init_config` attribute of the instance.

        Args:
            init_func (callable): The `__init__` method of a class.
            help_func (callable, optional): If provided (truthy), it is
                called right after `init_func` as
                `help_func(self, init_func, *args, **kwargs)`.
                Default None.

        Returns:
            function: The wrapped `__init__`. It carries `init_func`'s
            metadata via `functools.wraps`.
        """

        @functools.wraps(init_func)
        def tracked_init(self, *args, **kwargs):
            init_func(self, *args, **kwargs)
            # hook registered through `_wrap_init`, if any
            if help_func:
                help_func(self, init_func, *args, **kwargs)
            # `kwargs` is a fresh dict for each call. It is stored as-is and
            # then extended in place, so `init_config` also receives the
            # `init_args` / `init_class` entries added below.
            recorded = kwargs
            self.init_config = recorded
            if args:
                recorded['init_args'] = args
            recorded['init_class'] = self.__class__.__name__

        return tracked_init