# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import io
import math
import os
import warnings
import sys
import inspect
import paddle.distributed as dist
from paddle.io import Dataset, IterableDataset
from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url
from paddlenlp.utils.env import DATA_HOME
from typing import Iterable, Iterator, Optional, List, Any, Callable, Union
import importlib
from functools import partial
# Names exported when this module is star-imported.
__all__ = ['MapDataset', 'DatasetBuilder', 'IterDataset', 'load_dataset']
# Package prefix that `import_main_class` prepends to a dataset module name
# before importing it.
DATASETS_MODULE_PATH = "paddlenlp.datasets."
def import_main_class(module_path):
    """
    Import a dataset module and return its main builder class.

    Args:
        module_path (str): Module name relative to `DATASETS_MODULE_PATH`.

    Returns:
        The first class in the module namespace that is a subclass of
        `DatasetBuilder` and is not bound to the name 'DatasetBuilder', or
        None when the module contains no such class.
    """
    module = importlib.import_module(DATASETS_MODULE_PATH + module_path)

    # The base class itself is usually imported into every dataset module, so
    # the attribute literally named 'DatasetBuilder' is skipped.
    builder_classes = (
        candidate for attr_name, candidate in vars(module).items()
        if isinstance(candidate, type) and issubclass(
            candidate, DatasetBuilder) and attr_name != 'DatasetBuilder')
    return next(builder_classes, None)
def load_dataset(path_or_read_func,
                 name=None,
                 data_files=None,
                 splits=None,
                 lazy=None,
                 **kwargs):
    """
    Load a dataset either from a built-in dataset module or from a custom
    read function.

    Args:
        path_or_read_func (str|callable): Name of a dataset module under
            `DATASETS_MODULE_PATH`, or a plain function that produces the
            examples (custom mode).
        name (str, optional): Passed to the builder when truthy; in custom
            mode it is offered to the read function. Default: None.
        data_files (optional): Forwarded to `read_datasets`; in custom mode
            it is offered to the read function. Default: None.
        splits (optional): Forwarded to `read_datasets`; in custom mode it is
            offered to the read function. Default: None.
        lazy (bool, optional): Whether to build a lazy dataset. Must not be
            None in custom mode. Default: None.
        kwargs (dict): Extra keyword arguments for the builder, or candidate
            arguments for the read function in custom mode.

    Returns:
        Whatever the builder's `read_datasets` returns, or in custom mode the
        result of `SimpleBuilder.read`.
    """
    if not inspect.isfunction(path_or_read_func):
        builder_cls = import_main_class(path_or_read_func)
        builder_kwargs = dict(kwargs)
        # Only forward `name` when it is truthy so the builder keeps its own
        # default otherwise.
        if name:
            builder_kwargs['name'] = name
        builder = builder_cls(lazy=lazy, **builder_kwargs)
        return builder.read_datasets(data_files=data_files, splits=splits)

    assert lazy is not None, "lazy can not be None in custom mode."
    kwargs.update(name=name, data_files=data_files, splits=splits)

    # Pass the read function only the arguments its signature declares.
    accepted = inspect.signature(path_or_read_func).parameters
    reader_kwargs = {key: kwargs[key] for key in accepted if key in kwargs}

    builder = SimpleBuilder(lazy=lazy, read_func=path_or_read_func)
    return builder.read(**reader_kwargs)
class MapDataset(Dataset):
    """
    Wraps a dataset-like object as an instance of Dataset, and equips it with
    `map` and other utility methods. All non-magic methods of the raw object
    are also accessible.

    Args:
        data (list|Dataset): A dataset-like object. It can be a list or a
            subclass of Dataset.
        kwargs (dict, optional): Extra information about the dataset. Only
            `label_list` and `vocab_info` are read; both default to None.
    """

    def __init__(self, data, **kwargs):
        self.data = data
        self._transform_pipline = []
        # `new_data` is the working copy that filter/shard/map replace;
        # `data` keeps pointing at the raw object.
        self.new_data = self.data
        self.label_list = kwargs.pop('label_list', None)
        self.vocab_info = kwargs.pop('vocab_info', None)

    def _transform(self, data):
        """Apply every lazily registered transformation, in order."""
        for fn in self._transform_pipline:
            data = fn(data)
        return data

    def __getitem__(self, idx):
        sample = self.new_data[idx]
        return self._transform(sample) if self._transform_pipline else sample

    def __len__(self):
        return len(self.new_data)

    def filter(self, fn):
        """
        Filters samples by the filter function and uses the filtered data to
        update this dataset.

        Args:
            fn (callable): A filter function that takes a sample as input and
                returns a boolean. Samples that return False are discarded.

        Returns:
            MapDataset: This dataset, to allow chaining.
        """
        # Index-based access on purpose: the wrapped object is only required
        # to provide `__len__` and `__getitem__`.
        self.new_data = [
            self.new_data[idx] for idx in range(len(self.new_data))
            if fn(self.new_data[idx])
        ]
        return self

    def shard(self, num_shards=None, index=None):
        """
        Use samples whose indices mod `num_shards` equal `index` to update
        this dataset. When the data cannot be split evenly, shards that would
        be one sample short are padded with a sample taken by wrapping around
        to the beginning of the data, so every shard has the same length.

        Args:
            num_shards (int, optional): An integer representing the number of
                data shards. If None, `num_shards` would be number of trainers.
                Default: None.
            index (int, optional): An integer representing the index of the
                current shard. If None, `index` would be the current trainer
                rank id. Default: None.

        Returns:
            MapDataset: This dataset, to allow chaining.
        """
        if num_shards is None:
            num_shards = dist.get_world_size()
        if index is None:
            index = dist.get_rank()

        source = self.new_data
        total = len(source)
        # Length every shard must have for the split to be even.
        num_samples = int(math.ceil(total * 1.0 / num_shards))
        self.new_data = [
            source[idx] for idx in range(total) if idx % num_shards == index
        ]
        # add extra samples to make it evenly divisible
        if len(self.new_data) < num_samples:
            # Position this shard would read next if the data were repeated;
            # the modulo maps it back into the real data. It must index the
            # pre-shard data: indexing the shard itself can fall out of range.
            pad_idx = ((num_samples - 1) * num_shards + index) % total
            self.new_data.append(source[pad_idx])
        return self

    def map(self, fn, lazy=True, batched=False):
        """
        Performs specific function on the dataset to transform and update every sample.

        Args:
            fn (callable): Transformations to be performed. It receives single
                sample as argument if batched is False. Else it receives all examples.
            lazy (bool, optional): If True, transformations would be delayed and
                performed on demand. Otherwise, transforms all samples at once. Note that if `fn` is
                stochastic, `lazy` should be True or you will get the same
                result on all epochs. Default: True.
            batched (bool, optional): If True, transformations would take all examples as input and
                return a collection of transformed examples. Note that if set True, `lazy` option
                would be ignored. Default: False.

        Returns:
            MapDataset: This dataset, to allow chaining.
        """
        if batched:
            self.new_data = fn(self.new_data)
        elif lazy:
            self._transform_pipline.append(fn)
        else:
            self.new_data = [
                fn(self.new_data[idx]) for idx in range(len(self.new_data))
            ]
        return self

    def __getattr__(self, name):
        # Only reached when normal lookup fails. If `data` itself is missing
        # (e.g. while the object is being copied or unpickled, before
        # `__init__` ran) delegating would recurse forever, so fail cleanly.
        if name == 'data':
            raise AttributeError(name)
        return getattr(self.data, name)
class IterDataset(IterableDataset):
    """
    Wraps a dataset-like object as an instance of IterableDataset, and equips
    it with `map` and other utility methods. All non-magic methods of the raw
    object are also accessible.

    Args:
        data (Iterable|callable): A dataset-like object. It can be an
            Iterable, or a function that returns a fresh Iterable every time
            it is called.
        kwargs (dict, optional): Extra information about the dataset. Only
            `label_list` and `vocab_info` are read; both default to None.
    """

    def __init__(self, data, **kwargs):
        self.data = data
        self._transform_pipline = []
        self._filter_pipline = []
        self.label_list = kwargs.pop('label_list', None)
        self.vocab_info = kwargs.pop('vocab_info', None)

    def _transform(self, data):
        """Apply every registered transformation, in order."""
        for fn in self._transform_pipline:
            data = fn(data)
        return data

    def _shard_filter(self, num_samples):
        """Default shard predicate: keep every sample. Replaced by `shard`."""
        return True

    def _filter(self, data):
        """Return True only if `data` passes every registered filter."""
        return all(fn(data) for fn in self._filter_pipline)

    def __iter__(self):
        if inspect.isfunction(self.data):
            # A function source is called on every pass, so the dataset can
            # be iterated more than once.
            examples = self.data()
        else:
            if inspect.isgenerator(self.data):
                warnings.warn(
                    'Reciving generator as data source, data can only be iterated once'
                )
            examples = self.data

        num_samples = 0
        for example in examples:
            # The filters must receive the example itself.
            if self._filter(example) and self._shard_filter(
                    num_samples=num_samples):
                yield self._transform(example)
            # Count every source example, yielded or not; otherwise a shard
            # whose index is not 0 would never see a matching position.
            num_samples += 1

    def filter(self, fn):
        """
        Filters samples by the filter function and uses the filtered data to
        update this dataset.

        Args:
            fn (callable): A filter function that takes a sample as input and
                returns a boolean. Samples that return False are discarded.

        Returns:
            IterDataset: This dataset, to allow chaining.
        """
        self._filter_pipline.append(fn)
        return self

    def shard(self, num_shards=None, index=None):
        """
        Use samples whose positions mod `num_shards` equal `index` to update
        this dataset.

        Args:
            num_shards (int, optional): An integer representing the number of
                data shards. If None, `num_shards` would be number of trainers.
                Default: None.
            index (int, optional): An integer representing the index of the
                current shard. If None, `index` would be the current trainer
                rank id. Default: None.

        Returns:
            IterDataset: This dataset, to allow chaining.
        """
        if num_shards is None:
            num_shards = dist.get_world_size()
        if index is None:
            index = dist.get_rank()

        def sharder(num_shards, index, num_samples):
            return num_samples % num_shards == index

        # Stored on the instance so it shadows the keep-everything default.
        self._shard_filter = partial(
            sharder, num_shards=num_shards, index=index)
        return self

    def map(self, fn):
        """
        Performs specific function on the dataset to transform and update every sample.

        Args:
            fn (callable): Transformations to be performed. It receives single
                sample as argument.

        Returns:
            IterDataset: This dataset, to allow chaining.
        """
        self._transform_pipline.append(fn)
        return self

    def __getattr__(self, name):
        # Only reached when normal lookup fails. If `data` itself is missing
        # (e.g. while the object is being copied or unpickled, before
        # `__init__` ran) delegating would recurse forever, so fail cleanly.
        if name == 'data':
            raise AttributeError(name)
        return getattr(self.data, name)
class DatasetBuilder:
    """
    A base class for all DatasetBuilder. It provides a `read()` function to turn
    a data file into a MapDataset or IterDataset.

    `_get_data()` function and `_read()` function should be implemented to download
    data file and read data file into a `Iterable` of the examples.

    Args:
        lazy (bool, optional): Overrides the class-level `lazy` flag when not
            None. Default: None.
        name (str, optional): Stored as `self.name`. Default: None.
        config (dict): Any other keyword arguments, stored as `self.config`.
    """
    lazy = False

    def __init__(self, lazy=None, name=None, **config):
        if lazy is not None:
            self.lazy = lazy
        self.name = name
        self.config = config

    def read_datasets(self, splits=None, data_files=None):
        """
        Read one dataset per entry of `data_files` and/or `splits`.

        Args:
            splits (str|list[str]|tuple[str], optional): Split name(s). Each
                one is resolved to a file through `_get_data()`.
            data_files (str|list|tuple|dict, optional): File path(s) to read.
                A string, list or tuple is read as the 'train' split; a dict
                maps split names to file paths.

        Returns:
            A single dataset when exactly one was read, otherwise a list of
            datasets (those from `data_files` first, then those from `splits`).
        """
        datasets = []
        assert splits or data_files, "`data_files` and `splits` can not both be None."

        if data_files:
            assert isinstance(
                data_files, (str, dict, tuple, list)
            ), "`data_files` should be a string or tuple or list or a dictionary whose key is split name ande value is a path of data file."
            if isinstance(data_files, str):
                datasets.append(self.read(filename=data_files, split='train'))
            elif isinstance(data_files, (tuple, list)):
                datasets += [
                    self.read(
                        filename=filename, split='train')
                    for filename in data_files
                ]
            else:
                datasets += [
                    self.read(
                        filename=filename, split=split)
                    for split, filename in data_files.items()
                ]

        if splits:
            assert isinstance(splits, str) or (
                isinstance(splits, (list, tuple)) and
                isinstance(splits[0], str)
            ), "`splits` should be a string or list of string or a tuple of string."
            split_names = [splits] if isinstance(splits, str) else splits
            for split in split_names:
                filename = self._get_data(split)
                datasets.append(self.read(filename=filename, split=split))

        return datasets if len(datasets) > 1 else datasets[0]

    def read(self, filename, split='train'):
        """
        Returns a dataset containing all the examples that can be read from the file path.

        If `self.lazy` is `False`, this eagerly reads all instances from `self._read()`
        and returns a `MapDataset`.

        If `self.lazy` is `True`, this returns an `IterDataset`, which internally
        relies on the generator created from `self._read()` to lazily produce examples.
        In this case your implementation of `_read()` must also be lazy
        (that is, not load all examples into memory at once). `_read()` is
        called again on every pass over the returned dataset.

        Class labels stored under `labels` or `label` are converted to label
        ids when `get_labels()` returns a label list and the label value is
        truthy.

        Args:
            filename (str): Path of the data file.
            split (str, optional): Split name, forwarded to `_read()` only
                when it accepts it. Default: 'train'.

        Raises:
            ValueError: In eager mode, when no example was read.
        """
        label_list = self.get_labels()
        vocab_info = self.get_vocab()

        if self.lazy:

            def generate_examples():
                # Built once per pass rather than once per example.
                label_dict = None if label_list is None else {
                    label: i
                    for i, label in enumerate(label_list)
                }
                for example in self._call_read(filename, split):
                    label_col = self._label_column(example)
                    # Truthiness check on purpose: empty placeholder labels
                    # are left untouched.
                    if label_dict is not None and example.get(label_col, None):
                        self._convert_labels(example, label_col, label_dict)
                    yield example

            # Hand over the function, not a generator object, so the dataset
            # can be iterated for more than one epoch.
            return IterDataset(
                generate_examples,
                label_list=label_list,
                vocab_info=vocab_info)

        examples = self._call_read(filename, split)

        # Then some validation.
        if not isinstance(examples, list):
            examples = list(examples)
        if not examples:
            raise ValueError(
                "No instances were read from the given filepath {}. "
                "Is the path correct?".format(filename))

        # The first example decides the label column and whether labels are
        # converted at all.
        label_col = self._label_column(examples[0])
        if label_list is not None and examples[0].get(label_col, None):
            label_dict = {label: i for i, label in enumerate(label_list)}
            for example in examples:
                self._convert_labels(example, label_col, label_dict)

        return MapDataset(
            examples, label_list=label_list, vocab_info=vocab_info)

    def _call_read(self, filename, split):
        """Call `_read()`, passing `split` only if `_read` declares more than two positional parameters."""
        if self._read.__code__.co_argcount > 2:
            return self._read(filename, split)
        return self._read(filename)

    @staticmethod
    def _label_column(example):
        """Return the name of the label column of `example` ('labels' or 'label'), or None."""
        # For now we only allow `label` or `labels` to be the name of label column.
        for candidate in ('labels', 'label'):
            if candidate in example:
                return candidate
        return None

    @staticmethod
    def _convert_labels(example, label_col, label_dict):
        """Replace the class label(s) under `label_col` of `example` with their label ids."""
        labels = example[label_col]
        if isinstance(labels, list):
            for pos, label in enumerate(labels):
                labels[pos] = label_dict[label]
        elif isinstance(labels, tuple):
            # Tuples do not support item assignment; store a converted copy.
            example[label_col] = tuple(label_dict[label] for label in labels)
        else:
            example[label_col] = label_dict[labels]

    def _read(self, filename: str, *args):
        """
        Reads examples from the given file_path and returns them as an
        `Iterable` (which could be a list or could be a generator).
        """
        raise NotImplementedError

    def _get_data(self, mode: str):
        """
        Download examples from the given URL and customized split informations and returns a filepath.
        """
        raise NotImplementedError

    def get_labels(self):
        """
        Return list of class labels of the dataset if specified.
        """
        return None

    def get_vocab(self):
        """
        Return vocab file path of the dataset if specified.
        """
        return None
class SimpleBuilder(DatasetBuilder):
    """
    A builder that wraps a user-supplied read function.

    Args:
        lazy (bool): If truthy, `read()` returns an `IterDataset`; otherwise
            a `MapDataset`.
        read_func (callable): Function that produces the examples. It is
            called with the keyword arguments given to `read()`.
    """

    def __init__(self, lazy, read_func):
        self.lazy = lazy
        # Stored on the instance, so it is called without `self`.
        self._read = read_func

    def read(self, **kwargs):
        """
        Build a dataset from the wrapped read function.

        Args:
            kwargs (dict): Keyword arguments forwarded to the read function.

        Returns:
            IterDataset when `self.lazy` is truthy, otherwise MapDataset.
        """
        if not self.lazy:
            loaded = self._read(**kwargs)
            # Objects that already offer `__len__` and `__getitem__` are
            # wrapped as they are; anything else is materialised first.
            if hasattr(loaded, '__len__') and hasattr(loaded, '__getitem__'):
                return MapDataset(loaded)
            return MapDataset(list(loaded))

        def generate_examples():
            for item in self._read(**kwargs):
                yield item

        # The function itself is handed over so that every pass over the
        # dataset calls the read function again.
        return IterDataset(generate_examples)