# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import json
import os
from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url
from paddlenlp.utils.env import DATA_HOME
from . import DatasetBuilder
class Glue(DatasetBuilder):
    """
    Dataset builder for the GLUE tasks listed in ``BUILDER_CONFIGS``.

    ``BUILDER_CONFIGS`` maps a task name (``'cola'``, ``'sst-2'``,
    ``'sts-b'``, ``'qqp'``, ``'mnli'``, ``'qnli'``, ``'rte'``, ``'wnli'``,
    ``'mrpc'``) to a dict with the following keys:

    * ``'url'`` / ``'md5'``: download location and checksum. For every task
      except ``'mrpc'`` these are plain strings; for ``'mrpc'`` they are
      dicts keyed by ``'train_data'``, ``'dev_id'`` and ``'test_data'``.
    * ``'splits'``: maps a split name to a 4-item list
      ``[relative_path, md5, field_indices, num_discard_samples]`` where
      ``field_indices`` are the tab-separated columns picked by ``_read``
      (sentence column(s) first, label column last when present) and
      ``num_discard_samples`` is the number of leading lines ``_read``
      skips.
    * ``'labels'``: the list returned by ``get_labels`` (``None`` for
      ``'sts-b'``).

    The active task is read from ``self.name``, which is not set in this
    class; presumably it is provided by ``DatasetBuilder``.
    """

    BUILDER_CONFIGS = {
        'cola': {
            'url': "https://dataset.bj.bcebos.com/glue/CoLA.zip",
            'md5': 'b178a7c2f397b0433c39c7caf50a3543',
            'splits': {
                'train': [
                    os.path.join('CoLA', 'train.tsv'),
                    'c79d4693b8681800338aa044bf9e797b', (3, 1), 0
                ],
                'dev': [
                    os.path.join('CoLA', 'dev.tsv'),
                    'c5475ccefc9e7ca0917294b8bbda783c', (3, 1), 0
                ],
                'test': [
                    os.path.join('CoLA', 'test.tsv'),
                    'd8721b7dedda0dcca73cebb2a9f4259f', (1, ), 1
                ]
            },
            'labels': ["0", "1"]
        },
        'sst-2': {
            'url': "https://dataset.bj.bcebos.com/glue/SST.zip",
            'md5': '9f81648d4199384278b86e315dac217c',
            'splits': {
                'train': [
                    os.path.join('SST-2', 'train.tsv'),
                    'da409a0a939379ed32a470bc0f7fe99a', (0, 1), 1
                ],
                'dev': [
                    os.path.join('SST-2', 'dev.tsv'),
                    '268856b487b2a31a28c0a93daaff7288', (0, 1), 1
                ],
                'test': [
                    os.path.join('SST-2', 'test.tsv'),
                    '3230e4efec76488b87877a56ae49675a', (1, ), 1
                ]
            },
            'labels': ["0", "1"]
        },
        'sts-b': {
            'url': 'https://dataset.bj.bcebos.com/glue/STS.zip',
            'md5': 'd573676be38f1a075a5702b90ceab3de',
            'splits': {
                'train': [
                    os.path.join('STS-B', 'train.tsv'),
                    '4f7a86dde15fe4832c18e5b970998672', (7, 8, 9), 1
                ],
                'dev': [
                    os.path.join('STS-B', 'dev.tsv'),
                    '5f4d6b0d2a5f268b1b56db773ab2f1fe', (7, 8, 9), 1
                ],
                'test': [
                    os.path.join('STS-B', 'test.tsv'),
                    '339b5817e414d19d9bb5f593dd94249c', (7, 8), 1
                ]
            },
            'labels': None
        },
        'qqp': {
            'url': 'https://dataset.bj.bcebos.com/glue/QQP.zip',
            'md5': '884bf26e39c783d757acc510a2a516ef',
            'splits': {
                'train': [
                    os.path.join('QQP', 'train.tsv'),
                    'e003db73d277d38bbd83a2ef15beb442', (3, 4, 5), 1
                ],
                'dev': [
                    os.path.join('QQP', 'dev.tsv'),
                    'cff6a448d1580132367c22fc449ec214', (3, 4, 5), 1
                ],
                'test': [
                    os.path.join('QQP', 'test.tsv'),
                    '73de726db186b1b08f071364b2bb96d0', (1, 2), 1
                ]
            },
            'labels': ["0", "1"]
        },
        'mnli': {
            'url': 'https://dataset.bj.bcebos.com/glue/MNLI.zip',
            'md5': 'e343b4bdf53f927436d0792203b9b9ff',
            'splits': {
                'train': [
                    os.path.join('MNLI', 'train.tsv'),
                    '220192295e23b6705f3545168272c740', (8, 9, 11), 1
                ],
                'dev_matched': [
                    os.path.join('MNLI', 'dev_matched.tsv'),
                    'c3fa2817007f4cdf1a03663611a8ad23', (8, 9, 15), 1
                ],
                'dev_mismatched': [
                    os.path.join('MNLI', 'dev_mismatched.tsv'),
                    'b219e6fe74e4aa779e2f417ffe713053', (8, 9, 15), 1
                ],
                'test_matched': [
                    os.path.join('MNLI', 'test_matched.tsv'),
                    '33ea0389aedda8a43dabc9b3579684d9', (8, 9), 1
                ],
                'test_mismatched': [
                    os.path.join('MNLI', 'test_mismatched.tsv'),
                    '7d2f60a73d54f30d8a65e474b615aeb6', (8, 9), 1
                ]
            },
            'labels': ["contradiction", "entailment", "neutral"]
        },
        'qnli': {
            'url': 'https://dataset.bj.bcebos.com/glue/QNLI.zip',
            'md5': 'b4efd6554440de1712e9b54e14760e82',
            'splits': {
                'train': [
                    os.path.join('QNLI', 'train.tsv'),
                    '5e6063f407b08d1f7c7074d049ace94a', (1, 2, 3), 1
                ],
                'dev': [
                    os.path.join('QNLI', 'dev.tsv'),
                    '1e81e211959605f144ba6c0ad7dc948b', (1, 2, 3), 1
                ],
                'test': [
                    os.path.join('QNLI', 'test.tsv'),
                    'f2a29f83f3fe1a9c049777822b7fa8b0', (1, 2), 1
                ]
            },
            'labels': ["entailment", "not_entailment"]
        },
        'rte': {
            'url': 'https://dataset.bj.bcebos.com/glue/RTE.zip',
            'md5': 'bef554d0cafd4ab6743488101c638539',
            'splits': {
                'train': [
                    os.path.join('RTE', 'train.tsv'),
                    'd2844f558d111a16503144bb37a8165f', (1, 2, 3), 1
                ],
                'dev': [
                    os.path.join('RTE', 'dev.tsv'),
                    '973cb4178d4534cf745a01c309d4a66c', (1, 2, 3), 1
                ],
                'test': [
                    os.path.join('RTE', 'test.tsv'),
                    '6041008f3f3e48704f57ce1b88ad2e74', (1, 2), 1
                ]
            },
            'labels': ["entailment", "not_entailment"]
        },
        'wnli': {
            'url': 'https://dataset.bj.bcebos.com/glue/WNLI.zip',
            'md5': 'a1b4bd2861017d302d29e42139657a42',
            'splits': {
                'train': [
                    os.path.join('WNLI', 'train.tsv'),
                    '5cdc5a87b7be0c87a6363fa6a5481fc1', (1, 2, 3), 1
                ],
                'dev': [
                    os.path.join('WNLI', 'dev.tsv'),
                    'a79a6dd5d71287bcad6824c892e517ee', (1, 2, 3), 1
                ],
                'test': [
                    os.path.join('WNLI', 'test.tsv'),
                    'a18789ba4f60f6fdc8cb4237e4ba24b5', (1, 2), 1
                ]
            },
            'labels': ["0", "1"]
        },
        'mrpc': {
            'url': {
                'train_data':
                'https://dataset.bj.bcebos.com/glue/mrpc/msr_paraphrase_train.txt',
                'dev_id': 'https://dataset.bj.bcebos.com/glue/mrpc/dev_ids.tsv',
                'test_data':
                'https://dataset.bj.bcebos.com/glue/mrpc/msr_paraphrase_test.txt'
            },
            'md5': {
                'train_data': '793daf7b6224281e75fe61c1f80afe35',
                'dev_id': '7ab59a1b04bd7cb773f98a0717106c9b',
                'test_data': 'e437fdddb92535b820fe8852e2df8a49'
            },
            'splits': {
                'train': [
                    os.path.join('MRPC', 'train.tsv'),
                    'dc2dac669a113866a6480a0b10cd50bf', (3, 4, 0), 1
                ],
                'dev': [
                    os.path.join('MRPC', 'dev.tsv'),
                    '185958e46ba556b38c6a7cc63f3a2135', (3, 4, 0), 1
                ],
                'test': [
                    os.path.join('MRPC', 'test.tsv'),
                    '4825dab4b4832f81455719660b608de5', (3, 4), 1
                ]
            },
            'labels': ["0", "1"]
        }
    }

    def _get_data(self, mode, **kwargs):
        """
        Return the local path of the split file for ``mode``, fetching or
        generating it first when it is missing or its MD5 does not match.

        Args:
            mode (str): Split name; must be a key of the active task's
                ``'splits'`` config (e.g. ``'train'``, ``'dev'``,
                ``'test'``, or the MNLI ``*_matched`` variants).
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            str: ``DATA_HOME/<class name>/<relative split path>``.

        Raises:
            KeyError: If ``self.name`` or ``mode`` is not configured.
        """
        builder_config = self.BUILDER_CONFIGS[self.name]
        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
        filename, data_hash, _, _ = builder_config['splits'][mode]
        fullname = os.path.join(default_root, filename)

        # The file is stale when it is absent, or when a checksum is
        # configured and the file on disk does not match it.
        needs_fetch = not os.path.exists(fullname) or (
            data_hash and md5file(fullname) != data_hash)
        if not needs_fetch:
            return fullname

        if self.name != 'mrpc':
            # One archive per task; presumably get_path_from_url downloads
            # and unpacks it under default_root (verify against Paddle).
            get_path_from_url(builder_config['url'], default_root,
                              builder_config['md5'])
        elif mode in ('train', 'dev'):
            # MRPC ships raw text files; train and dev are carved out of the
            # same raw training file, so both are generated together.
            self._build_mrpc_train_dev(builder_config, default_root)
        else:
            self._build_mrpc_test(builder_config, default_root)
        return fullname

    def _build_mrpc_train_dev(self, builder_config, default_root):
        """
        Write ``MRPC/train.tsv`` and ``MRPC/dev.tsv`` under ``default_root``.

        Rows of the raw training file whose ``(id1, id2)`` pair appears in
        the dev-id file go to ``dev.tsv``; all other rows go to
        ``train.tsv``. Both outputs start with the raw file's header line.
        """
        mrpc_dir = os.path.join(default_root, 'MRPC')
        dev_id_path = get_path_from_url(builder_config['url']['dev_id'],
                                        mrpc_dir,
                                        builder_config['md5']['dev_id'])
        train_data_path = get_path_from_url(
            builder_config['url']['train_data'], mrpc_dir,
            builder_config['md5']['train_data'])

        # A set of tuples gives constant-time membership tests in the row
        # loop below, instead of scanning a list of lists for every row.
        with open(dev_id_path, encoding='utf-8') as ids_fh:
            dev_ids = {tuple(row.strip().split('\t')) for row in ids_fh}

        train_path = os.path.join(mrpc_dir, 'train.tsv')
        dev_path = os.path.join(mrpc_dir, 'dev.tsv')
        with open(train_data_path, encoding='utf-8') as data_fh, \
                open(train_path, 'w', encoding='utf-8') as train_fh, \
                open(dev_path, 'w', encoding='utf-8') as dev_fh:
            header = data_fh.readline()
            train_fh.write(header)
            dev_fh.write(header)
            for row in data_fh:
                label, id1, id2, s1, s2 = row.strip().split('\t')
                example = '%s\t%s\t%s\t%s\t%s\n' % (label, id1, id2, s1, s2)
                if (id1, id2) in dev_ids:
                    dev_fh.write(example)
                else:
                    train_fh.write(example)

    def _build_mrpc_test(self, builder_config, default_root):
        """
        Write ``MRPC/test.tsv`` under ``default_root``.

        The raw header is replaced by a fixed one, and each row's first
        column (the label) is replaced by a zero-based running index.
        """
        mrpc_dir = os.path.join(default_root, 'MRPC')
        test_data_path = get_path_from_url(
            builder_config['url']['test_data'], mrpc_dir,
            builder_config['md5']['test_data'])
        test_path = os.path.join(mrpc_dir, 'test.tsv')
        with open(test_data_path, encoding='utf-8') as data_fh, \
                open(test_path, 'w', encoding='utf-8') as test_fh:
            # Consume the raw header; it is not copied to the output.
            data_fh.readline()
            test_fh.write('index\t#1 ID\t#2 ID\t#1 String\t#2 String\n')
            for idx, row in enumerate(data_fh):
                _, id1, id2, s1, s2 = row.strip().split('\t')
                test_fh.write('%d\t%s\t%s\t%s\t%s\n' %
                              (idx, id1, id2, s1, s2))

    def _read(self, filename, split):
        """
        Yield one example dict per non-blank data line of ``filename``.

        Args:
            filename (str): Path of a tab-separated split file.
            split (str): Split name; selects the column indices and the
                number of leading lines to skip from ``BUILDER_CONFIGS``.
                Any split whose name contains ``'test'`` yields no label.

        Yields:
            dict: ``{'sentence': ...}`` for ``'cola'`` and ``'sst-2'``,
            otherwise ``{'sentence1': ..., 'sentence2': ...}``; non-test
            splits additionally carry ``'labels'`` (the last selected
            column, as a string).

        Raises:
            IndexError: If a non-blank line has fewer columns than the
                configured field indices require.
        """
        _, _, field_indices, num_discard_samples = self.BUILDER_CONFIGS[
            self.name]['splits'][split]
        single_sentence = self.name in ['cola', 'sst-2']
        is_test = 'test' in split
        with open(filename, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if idx < num_discard_samples:
                    continue
                # Test for emptiness BEFORE splitting: ''.split('\t') is
                # [''], which is truthy, so checking the split result would
                # never skip a blank line and indexing it would fail.
                stripped = line.strip()
                if not stripped:
                    continue
                fields = stripped.split('\t')
                example = [fields[index] for index in field_indices]
                if single_sentence:
                    record = {'sentence': example[0]}
                else:
                    record = {
                        'sentence1': example[0],
                        'sentence2': example[1]
                    }
                if not is_test:
                    record['labels'] = example[-1]
                yield record

    def get_labels(self):
        """
        Return labels of the Glue task.

        Returns:
            list[str] or None: The ``'labels'`` entry of the active task's
            config; ``None`` for ``'sts-b'``.
        """
        return self.BUILDER_CONFIGS[self.name]['labels']