Source code for paddlenlp.datasets.iwslt15

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import os
import warnings

from paddle.io import Dataset
from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url
from paddlenlp.utils.env import DATA_HOME
from . import DatasetBuilder

__all__ = ['IWSLT15']


class IWSLT15(DatasetBuilder):
    """
    IWSLT'15 English-Vietnamese translation dataset.

    Provides parallel English ("en") / Vietnamese ("vi") sentence pairs for
    the ``train``, ``dev`` (tst2012) and ``test`` (tst2013) splits, plus the
    pre-built vocabulary files for both languages. Data is downloaded to
    ``DATA_HOME/IWSLT15`` on first use and verified by md5.
    """

    URL = "https://paddlenlp.bj.bcebos.com/datasets/iwslt15.en-vi.tar.gz"
    # One record per split: source/target file paths (relative to the data
    # root) and their expected md5 checksums.
    META_INFO = collections.namedtuple(
        'META_INFO', ('src_file', 'tgt_file', 'src_md5', 'tgt_md5'))
    MD5 = 'aca22dc3f90962e42916dbb36d8f3e8e'
    SPLITS = {
        'train': META_INFO(
            os.path.join("iwslt15.en-vi", "train.en"),
            os.path.join("iwslt15.en-vi", "train.vi"),
            "5b6300f46160ab5a7a995546d2eeb9e6",
            "858e884484885af5775068140ae85dab"),
        'dev': META_INFO(
            os.path.join("iwslt15.en-vi", "tst2012.en"),
            os.path.join("iwslt15.en-vi", "tst2012.vi"),
            "c14a0955ed8b8d6929fdabf4606e3875",
            "dddf990faa149e980b11a36fca4a8898"),
        'test': META_INFO(
            os.path.join("iwslt15.en-vi", "tst2013.en"),
            os.path.join("iwslt15.en-vi", "tst2013.vi"),
            "c41c43cb6d3b122c093ee89608ba62bd",
            "a3185b00264620297901b647a4cacf38")
    }
    # (src_vocab_file, tgt_vocab_file, src_vocab_md5, tgt_vocab_md5) — the
    # same file/file/hash/hash layout as META_INFO; get_vocab() relies on
    # elements 0 and 1 being the two file paths.
    VOCAB_INFO = (os.path.join("iwslt15.en-vi", "vocab.en"),
                  os.path.join("iwslt15.en-vi", "vocab.vi"),
                  "98b5011e1f579936277a273fd7f4e9b4",
                  "e8b05f8c26008a798073c619236712b4")
    UNK_TOKEN = '<unk>'
    BOS_TOKEN = '<s>'
    EOS_TOKEN = '</s>'

    @staticmethod
    def _is_file_valid(fullname, data_hash):
        """Return True if `fullname` exists and, when `data_hash` is given,
        its md5 digest matches."""
        return os.path.exists(fullname) and (
            not data_hash or md5file(fullname) == data_hash)

    def _get_data(self, mode, **kwargs):
        """Ensure the files for `mode` ('train'/'dev'/'test') and both vocab
        files are present and checksum-valid, downloading the archive if any
        is missing or corrupt.

        Returns:
            tuple(str, str): absolute paths of the source (en) and target
            (vi) data files for the requested split.
        """
        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
        src_filename, tgt_filename, src_data_hash, tgt_data_hash = self.SPLITS[
            mode]
        src_fullname = os.path.join(default_root, src_filename)
        tgt_fullname = os.path.join(default_root, tgt_filename)

        # BUGFIX: VOCAB_INFO is (src_file, tgt_file, src_md5, tgt_md5); the
        # previous unpack order (file, hash, file, hash) mixed paths and
        # hashes, so validation always failed and the archive was
        # re-downloaded on every call.
        src_vocab_filename, tgt_vocab_filename, src_vocab_hash, tgt_vocab_hash = self.VOCAB_INFO
        src_vocab_fullname = os.path.join(default_root, src_vocab_filename)
        tgt_vocab_fullname = os.path.join(default_root, tgt_vocab_filename)

        files_to_check = (
            (src_fullname, src_data_hash),
            (tgt_fullname, tgt_data_hash),
            (src_vocab_fullname, src_vocab_hash),
            (tgt_vocab_fullname, tgt_vocab_hash),
        )
        if not all(self._is_file_valid(name, digest)
                   for name, digest in files_to_check):
            # Any missing/corrupt file: fetch and extract the whole archive.
            get_path_from_url(self.URL, default_root, self.MD5)

        return src_fullname, tgt_fullname

    def _read(self, filename, *args):
        """Yield {'en': ..., 'vi': ...} dicts from the paired source/target
        files, reading them in lockstep and skipping pairs where both lines
        are empty."""
        src_filename, tgt_filename = filename
        with open(src_filename, 'r', encoding='utf-8') as src_f:
            with open(tgt_filename, 'r', encoding='utf-8') as tgt_f:
                for src_line, tgt_line in zip(src_f, tgt_f):
                    src_line = src_line.strip()
                    tgt_line = tgt_line.strip()
                    if not src_line and not tgt_line:
                        continue
                    yield {"en": src_line, "vi": tgt_line}

    def get_vocab(self):
        """Return vocab metadata for both languages.

        Returns:
            dict: maps 'en'/'vi' to dicts with keys 'filepath', 'unk_token',
            'bos_token', 'eos_token', matching the input expected by
            ``Vocab.load_vocabulary()``.
        """
        en_vocab_fullname = os.path.join(DATA_HOME, self.__class__.__name__,
                                         self.VOCAB_INFO[0])
        vi_vocab_fullname = os.path.join(DATA_HOME, self.__class__.__name__,
                                         self.VOCAB_INFO[1])

        # Construct vocab_info to match the form of the input of
        # `Vocab.load_vocabulary()` function
        vocab_info = {
            'en': {
                'filepath': en_vocab_fullname,
                'unk_token': self.UNK_TOKEN,
                'bos_token': self.BOS_TOKEN,
                'eos_token': self.EOS_TOKEN
            },
            'vi': {
                'filepath': vi_vocab_fullname,
                'unk_token': self.UNK_TOKEN,
                'bos_token': self.BOS_TOKEN,
                'eos_token': self.EOS_TOKEN
            }
        }
        return vocab_info