# Source code for paddlenlp.datasets.glue

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import json
import os

from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url
from paddlenlp.utils.env import DATA_HOME
from . import DatasetBuilder

class Glue(DatasetBuilder):
    """Dataset builder for the GLUE benchmark tasks.

    ``BUILDER_CONFIGS`` is keyed by task name (``self.name``). Each entry holds
    the download ``url`` and archive ``md5``, a ``labels`` list (``None`` for
    STS-B), and a ``splits`` mapping from split name to::

        [relative_path, file_md5, field_indices, num_discard_samples]

    ``field_indices`` selects the TSV columns that ``_read`` keeps; for
    non-test splits the last selected column is used as the label.
    ``num_discard_samples`` is the number of leading lines skipped (e.g. a
    header row).

    NOTE(review): the key expression used to index ``BUILDER_CONFIGS`` was lost
    in extraction; ``self.name`` (the task name handed to ``DatasetBuilder``)
    is assumed here -- confirm against ``DatasetBuilder``.
    """

    # NOTE(review): every 'url' value below is empty in this copy of the file
    # (they look stripped during extraction); downloads cannot succeed until
    # the real URLs are restored.
    BUILDER_CONFIGS = {
        'cola': {
            'url': "",
            'md5': 'b178a7c2f397b0433c39c7caf50a3543',
            'splits': {
                'train': [
                    os.path.join('CoLA', 'train.tsv'),
                    'c79d4693b8681800338aa044bf9e797b', (3, 1), 0
                ],
                'dev': [
                    os.path.join('CoLA', 'dev.tsv'),
                    'c5475ccefc9e7ca0917294b8bbda783c', (3, 1), 0
                ],
                'test': [
                    os.path.join('CoLA', 'test.tsv'),
                    'd8721b7dedda0dcca73cebb2a9f4259f', (1, ), 1
                ]
            },
            'labels': ["0", "1"]
        },
        'sst-2': {
            'url': "",
            'md5': '9f81648d4199384278b86e315dac217c',
            'splits': {
                'train': [
                    os.path.join('SST-2', 'train.tsv'),
                    'da409a0a939379ed32a470bc0f7fe99a', (0, 1), 1
                ],
                'dev': [
                    os.path.join('SST-2', 'dev.tsv'),
                    '268856b487b2a31a28c0a93daaff7288', (0, 1), 1
                ],
                'test': [
                    os.path.join('SST-2', 'test.tsv'),
                    '3230e4efec76488b87877a56ae49675a', (1, ), 1
                ]
            },
            'labels': ["0", "1"]
        },
        'sts-b': {
            'url': '',
            'md5': 'd573676be38f1a075a5702b90ceab3de',
            'splits': {
                'train': [
                    os.path.join('STS-B', 'train.tsv'),
                    '4f7a86dde15fe4832c18e5b970998672', (7, 8, 9), 1
                ],
                'dev': [
                    os.path.join('STS-B', 'dev.tsv'),
                    '5f4d6b0d2a5f268b1b56db773ab2f1fe', (7, 8, 9), 1
                ],
                'test': [
                    os.path.join('STS-B', 'test.tsv'),
                    '339b5817e414d19d9bb5f593dd94249c', (7, 8), 1
                ]
            },
            'labels': None
        },
        'qqp': {
            'url': '',
            'md5': '884bf26e39c783d757acc510a2a516ef',
            'splits': {
                'train': [
                    os.path.join('QQP', 'train.tsv'),
                    'e003db73d277d38bbd83a2ef15beb442', (3, 4, 5), 1
                ],
                'dev': [
                    os.path.join('QQP', 'dev.tsv'),
                    'cff6a448d1580132367c22fc449ec214', (3, 4, 5), 1
                ],
                'test': [
                    os.path.join('QQP', 'test.tsv'),
                    '73de726db186b1b08f071364b2bb96d0', (1, 2), 1
                ]
            },
            'labels': ["0", "1"]
        },
        'mnli': {
            'url': '',
            'md5': 'e343b4bdf53f927436d0792203b9b9ff',
            'splits': {
                'train': [
                    os.path.join('MNLI', 'train.tsv'),
                    '220192295e23b6705f3545168272c740', (8, 9, 11), 1
                ],
                'dev_matched': [
                    os.path.join('MNLI', 'dev_matched.tsv'),
                    'c3fa2817007f4cdf1a03663611a8ad23', (8, 9, 15), 1
                ],
                'dev_mismatched': [
                    os.path.join('MNLI', 'dev_mismatched.tsv'),
                    'b219e6fe74e4aa779e2f417ffe713053', (8, 9, 15), 1
                ],
                'test_matched': [
                    os.path.join('MNLI', 'test_matched.tsv'),
                    '33ea0389aedda8a43dabc9b3579684d9', (8, 9), 1
                ],
                'test_mismatched': [
                    os.path.join('MNLI', 'test_mismatched.tsv'),
                    '7d2f60a73d54f30d8a65e474b615aeb6', (8, 9), 1
                ]
            },
            'labels': ["contradiction", "entailment", "neutral"]
        },
        'qnli': {
            'url': '',
            'md5': 'b4efd6554440de1712e9b54e14760e82',
            'splits': {
                'train': [
                    os.path.join('QNLI', 'train.tsv'),
                    '5e6063f407b08d1f7c7074d049ace94a', (1, 2, 3), 1
                ],
                'dev': [
                    os.path.join('QNLI', 'dev.tsv'),
                    '1e81e211959605f144ba6c0ad7dc948b', (1, 2, 3), 1
                ],
                'test': [
                    os.path.join('QNLI', 'test.tsv'),
                    'f2a29f83f3fe1a9c049777822b7fa8b0', (1, 2), 1
                ]
            },
            'labels': ["entailment", "not_entailment"]
        },
        'rte': {
            'url': '',
            'md5': 'bef554d0cafd4ab6743488101c638539',
            'splits': {
                'train': [
                    os.path.join('RTE', 'train.tsv'),
                    'd2844f558d111a16503144bb37a8165f', (1, 2, 3), 1
                ],
                'dev': [
                    os.path.join('RTE', 'dev.tsv'),
                    '973cb4178d4534cf745a01c309d4a66c', (1, 2, 3), 1
                ],
                'test': [
                    os.path.join('RTE', 'test.tsv'),
                    '6041008f3f3e48704f57ce1b88ad2e74', (1, 2), 1
                ]
            },
            'labels': ["entailment", "not_entailment"]
        },
        'wnli': {
            'url': '',
            'md5': 'a1b4bd2861017d302d29e42139657a42',
            'splits': {
                'train': [
                    os.path.join('WNLI', 'train.tsv'),
                    '5cdc5a87b7be0c87a6363fa6a5481fc1', (1, 2, 3), 1
                ],
                'dev': [
                    os.path.join('WNLI', 'dev.tsv'),
                    'a79a6dd5d71287bcad6824c892e517ee', (1, 2, 3), 1
                ],
                'test': [
                    os.path.join('WNLI', 'test.tsv'),
                    'a18789ba4f60f6fdc8cb4237e4ba24b5', (1, 2), 1
                ]
            },
            'labels': ["0", "1"]
        },
        'mrpc': {
            # MRPC ships as separate raw files that are turned into the split
            # TSVs locally (see _generate_mrpc_train_dev / _generate_mrpc_test).
            'url': {
                'train_data': '',
                'dev_id': '',
                'test_data': ''
            },
            'md5': {
                'train_data': '793daf7b6224281e75fe61c1f80afe35',
                'dev_id': '7ab59a1b04bd7cb773f98a0717106c9b',
                'test_data': 'e437fdddb92535b820fe8852e2df8a49'
            },
            'splits': {
                'train': [
                    os.path.join('MRPC', 'train.tsv'),
                    'dc2dac669a113866a6480a0b10cd50bf', (3, 4, 0), 1
                ],
                'dev': [
                    os.path.join('MRPC', 'dev.tsv'),
                    '185958e46ba556b38c6a7cc63f3a2135', (3, 4, 0), 1
                ],
                'test': [
                    os.path.join('MRPC', 'test.tsv'),
                    '4825dab4b4832f81455719660b608de5', (3, 4), 1
                ]
            },
            'labels': ["0", "1"]
        }
    }

    def _get_data(self, mode, **kwargs):
        """Return the local path of the TSV file for split ``mode``.

        The file is (re)obtained when it is missing or, if a checksum is
        configured, its md5 does not match. For every task except MRPC the
        task archive is fetched with ``get_path_from_url``. For MRPC the split
        files are generated locally from the raw files.

        Args:
            mode (str): Split name; a key of the task's ``'splits'`` config.

        Returns:
            str: Path of the split file under ``DATA_HOME/<class name>``.
        """
        builder_config = self.BUILDER_CONFIGS[self.name]
        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
        filename, data_hash, _, _ = builder_config['splits'][mode]
        fullname = os.path.join(default_root, filename)
        if not os.path.exists(fullname) or (
                data_hash and md5file(fullname) != data_hash):
            if self.name != 'mrpc':
                get_path_from_url(builder_config['url'], default_root,
                                  builder_config['md5'])
            elif mode in ('train', 'dev'):
                # train.tsv and dev.tsv are produced together from one file.
                self._generate_mrpc_train_dev(builder_config, default_root)
            else:
                self._generate_mrpc_test(builder_config, default_root)
        return fullname

    @staticmethod
    def _generate_mrpc_train_dev(builder_config, default_root):
        """Write MRPC ``train.tsv`` / ``dev.tsv`` from the raw training file.

        Raw rows whose ``(id1, id2)`` pair appears in the dev-id file go to
        ``dev.tsv``; all other rows go to ``train.tsv``. The raw header line
        is copied to both outputs.
        """
        mrpc_dir = os.path.join(default_root, 'MRPC')
        dev_id_path = get_path_from_url(builder_config['url']['dev_id'],
                                        mrpc_dir,
                                        builder_config['md5']['dev_id'])
        train_data_path = get_path_from_url(
            builder_config['url']['train_data'], mrpc_dir,
            builder_config['md5']['train_data'])

        # A set of id tuples gives O(1) membership per training row instead of
        # rescanning a list of id pairs for every row.
        with open(dev_id_path, encoding='utf-8') as ids_fh:
            dev_ids = {tuple(row.strip().split('\t')) for row in ids_fh}

        train_path = os.path.join(mrpc_dir, 'train.tsv')
        dev_path = os.path.join(mrpc_dir, 'dev.tsv')
        with open(train_data_path, encoding='utf-8') as data_fh, \
                open(train_path, 'w', encoding='utf-8') as train_fh, \
                open(dev_path, 'w', encoding='utf-8') as dev_fh:
            header = data_fh.readline()
            train_fh.write(header)
            dev_fh.write(header)
            for row in data_fh:
                # Strict 5-way unpacking: a malformed row raises ValueError.
                label, id1, id2, s1, s2 = row.strip().split('\t')
                example = '%s\t%s\t%s\t%s\t%s\n' % (label, id1, id2, s1, s2)
                if (id1, id2) in dev_ids:
                    dev_fh.write(example)
                else:
                    train_fh.write(example)

    @staticmethod
    def _generate_mrpc_test(builder_config, default_root):
        """Write MRPC ``test.tsv`` from the raw test file.

        The raw header is replaced by a fixed one, and the first raw column is
        replaced by a running 0-based row index.
        """
        mrpc_dir = os.path.join(default_root, 'MRPC')
        test_data_path = get_path_from_url(
            builder_config['url']['test_data'], mrpc_dir,
            builder_config['md5']['test_data'])
        test_path = os.path.join(mrpc_dir, 'test.tsv')
        with open(test_data_path, encoding='utf-8') as data_fh, \
                open(test_path, 'w', encoding='utf-8') as test_fh:
            data_fh.readline()  # drop the raw header
            test_fh.write('index\t#1 ID\t#2 ID\t#1 String\t#2 String\n')
            for idx, row in enumerate(data_fh):
                _label, id1, id2, s1, s2 = row.strip().split('\t')
                test_fh.write('%d\t%s\t%s\t%s\t%s\n' %
                              (idx, id1, id2, s1, s2))

    def _read(self, filename, split):
        """Yield examples from a GLUE split file.

        Args:
            filename (str): Path of the TSV file (as returned by ``_get_data``).
            split (str): Split name, used to look up the column indices and the
                number of leading lines to skip.

        Yields:
            dict: ``{'sentence', 'labels'}`` for CoLA / SST-2, otherwise
            ``{'sentence1', 'sentence2', 'labels'}``. ``'labels'`` is omitted
            when ``'test'`` occurs in ``split``.
        """
        _, _, field_indices, num_discard_samples = self.BUILDER_CONFIGS[
            self.name]['splits'][split]
        single_sentence = self.name in ['cola', 'sst-2']
        is_test = 'test' in split
        with open(filename, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if idx < num_discard_samples:
                    continue
                stripped = line.strip()
                # ''.split('\t') is [''] (truthy), so blank lines have to be
                # caught before splitting; otherwise indexing raises IndexError.
                if not stripped:
                    continue
                fields = stripped.split('\t')
                example = [fields[i] for i in field_indices]
                if single_sentence:
                    record = {'sentence': example[0]}
                else:
                    record = {'sentence1': example[0], 'sentence2': example[1]}
                if not is_test:
                    record['labels'] = example[-1]
                yield record

    def get_labels(self):
        """Return the label list of the current GLUE task.

        Returns:
            list[str] | None: The task's ``labels`` entry; ``None`` for STS-B.
        """
        return self.BUILDER_CONFIGS[self.name]['labels']