generation_utils

class GenerationMixin

Bases: object

This class implements the interface for text generation tasks.

It’s used as the base class of paddlenlp.transformers.PretrainedModel.

generate(input_ids=None, max_length=20, min_length=0, decode_strategy='greedy_search', temperature=1.0, top_k=0, top_p=1.0, num_beams=1, length_penalty=1.0, early_stopping=False, bos_token_id=None, eos_token_id=None, pad_token_id=None, num_return_sequences=1, use_cache=True, **model_kwargs)

The interface for generation tasks. This method generates sequences using a decoding strategy. Currently, three decoding strategies are supported: “greedy_search”, “sampling” and “beam_search”.
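As a rough illustration of the “greedy_search” strategy, the toy loop below picks the highest-scoring token at every step. It also shows how max_length, min_length and eos_token_id typically interact: the end-of-sequence token is masked out until min_length tokens have been produced. This is a minimal sketch under assumptions, not PaddleNLP's implementation (which works on batched Tensors and can reuse the model cache); step_logits and fake_logits are hypothetical stand-ins for one model forward pass.

```python
import numpy as np

def greedy_decode(step_logits, bos_token_id, eos_token_id, max_length=20, min_length=0):
    """Toy greedy search over a single sequence.

    step_logits(prefix) is a hypothetical callable returning next-token
    logits (one value per vocabulary id) for the given token prefix.
    """
    prefix = [bos_token_id]
    generated = []
    for cur_len in range(max_length):
        logits = np.array(step_logits(prefix), dtype=np.float64)
        if cur_len < min_length:
            # Forbid ending the sequence before min_length tokens exist.
            logits[eos_token_id] = -np.inf
        token = int(np.argmax(logits))  # greedy: take the most likely token
        generated.append(token)
        prefix.append(token)
        if token == eos_token_id:
            break
    return generated

# Hypothetical 4-token vocabulary where id 3 (eos) is always the most likely.
def fake_logits(prefix):
    return [0.0, 2.0, 1.0, 5.0]

print(greedy_decode(fake_logits, bos_token_id=0, eos_token_id=3, max_length=10, min_length=2))
```

With min_length=2 the eos token is suppressed for the first two steps, so the loop emits token 1 twice before it is allowed to stop.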

Parameters
  • input_ids (Tensor, optional) – The input sequence ids for the generation. It is a Tensor with shape [batch_size, sequence_length]. The data type should be int32 or int64. Defaults to None, in which case it is initialized as a Tensor with shape [1, 1] filled with bos_token_id.

  • max_length (int, optional) – The maximum length of the sequence to be generated. Defaults to 20.

  • min_length (int, optional) – The minimum length of the sequence to be generated. Defaults to 0.

  • decode_strategy (str, optional) – The decoding strategy used in generation. Currently, three decoding strategies are supported: “greedy_search”, “sampling” and “beam_search”. Defaults to “greedy_search”.

  • temperature (float, optional) – The value used to modulate the next-token probabilities in the “sampling” strategy. Defaults to 1.0, which means no effect.

  • top_k (int, optional) – The number of highest-probability tokens to keep for top-k filtering in the “sampling” strategy. Defaults to 0, which means no effect.

  • top_p (float, optional) – The cumulative probability threshold for top-p (nucleus) filtering in the “sampling” strategy. The value should satisfy \(0 \le top\_p \le 1\). Defaults to 1.0, which means no effect.

  • num_beams (int, optional) – The number of beams in the “beam_search” strategy. Defaults to 1.

  • length_penalty (float, optional) – The exponential penalty applied to the sequence length in the “beam_search” strategy. If \(length\_penalty < 1.0\), the model tends to generate shorter sequences; if \(length\_penalty > 1.0\), it tends to generate longer sequences. Defaults to 1.0, which means no penalty.

  • early_stopping (bool, optional) – Whether to stop the search in the “beam_search” strategy as soon as at least num_beams sentences are finished for each batch element. Defaults to False.

  • bos_token_id (int, optional) – The id of the bos_token (beginning of sequence). Defaults to None.

  • eos_token_id (int, optional) – The id of the eos_token (end of sequence). Defaults to None.

  • pad_token_id (int, optional) – The id of the pad_token. Defaults to None.

  • num_return_sequences (int, optional) – The number of sequences to return for each input sequence in the batch. Defaults to 1.

  • use_cache (bool, optional) – Whether or not to use the model cache to speed up decoding. Defaults to True.

  • model_kwargs (dict) – Additional keyword arguments passed to the model, such as token_type_ids, position_ids and attention_mask in the example below.
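To make the “sampling” parameters concrete, the sketch below turns one step of next-token logits into a filtered distribution: divide by temperature, keep only the top_k logits, then keep the smallest set of most probable tokens whose cumulative probability reaches top_p. It is a NumPy toy under assumptions: the order (temperature, then top-k, then top-p) and the tie handling are choices made here, and filtered_probs is a hypothetical helper, not a PaddleNLP function.

```python
import numpy as np

def filtered_probs(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Toy next-token distribution after temperature, top-k and top-p filtering."""
    logits = np.asarray(logits, dtype=np.float64) / temperature
    if top_k > 0:
        # Mask every logit below the k-th largest one (ties with it are kept).
        kth_largest = np.sort(logits)[-min(top_k, logits.size)]
        logits = np.where(logits < kth_largest, -np.inf, logits)
    # Numerically stable softmax; masked entries become exactly 0.
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    if top_p < 1.0:
        order = np.argsort(-probs)  # token ids, most probable first
        sorted_probs = probs[order]
        # Keep a token if the mass of the tokens ranked above it is below top_p,
        # i.e. keep the smallest prefix whose cumulative mass reaches top_p.
        keep = (np.cumsum(sorted_probs) - sorted_probs) < top_p
        keep[0] = True  # always keep the most probable token
        mask = np.zeros(probs.shape, dtype=bool)
        mask[order[keep]] = True
        probs = np.where(mask, probs, 0.0)
        probs /= probs.sum()
    return probs

# Usage: sample one token id from the filtered distribution.
logits = np.log([0.5, 0.3, 0.15, 0.05])
probs = filtered_probs(logits, top_k=2)  # only ids 0 and 1 survive
rng = np.random.default_rng(0)
token_id = int(rng.choice(len(probs), p=probs))
print(probs, token_id)
```

With the defaults (temperature=1.0, top_k=0, top_p=1.0) the distribution is returned unchanged, which matches the “no effect” defaults described above.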
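The effect of length_penalty can be seen with a common length normalization for finished beams: the total log-probability divided by length raised to length_penalty. This form is an assumption borrowed from widely used beam search implementations; PaddleNLP's own normalization term may differ in detail. The helper names and the two hypotheses are hypothetical.

```python
def length_normalized_score(sum_logprobs, length, length_penalty=1.0):
    """Assumed beam score: total log-probability / length ** length_penalty."""
    return sum_logprobs / (length ** length_penalty)

def best_hypothesis(hypotheses, length_penalty):
    """Return the (name, sum_logprobs, length) tuple with the highest score."""
    return max(hypotheses,
               key=lambda h: length_normalized_score(h[1], h[2], length_penalty))

# Two finished beams: a short one and a longer one with a lower total log-prob.
hyps = [("short", -2.0, 2), ("long", -4.0, 6)]
print(best_hypothesis(hyps, length_penalty=0.5)[0])  # short
print(best_hypothesis(hyps, length_penalty=2.0)[0])  # long
```

Because log-probabilities are negative, a larger exponent shrinks the penalty on long hypotheses more, so values above 1.0 favour longer outputs and values below 1.0 favour shorter ones.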

Returns

A tuple of two Tensors, (ids, scores), with the fields:

  • ids (Tensor) – The ids of the generated sequences, a Tensor with shape [batch_size * num_return_sequences, sequence_length]. The data type is the same as that of the input input_ids.

  • scores (Tensor) – The scores of the generated sequences, a Tensor with shape [batch_size * num_return_sequences, 1]. The data type is float32 or float64, the same as that of the model parameters.
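Because both outputs are flattened to batch_size * num_return_sequences rows, it is often convenient to regroup them per input. The sketch below does this on NumPy arrays (for example ids.numpy() and scores.numpy()). It assumes that the num_return_sequences candidates of each input occupy consecutive rows; group_by_input is a hypothetical helper, not part of PaddleNLP.

```python
import numpy as np

def group_by_input(ids, scores, num_return_sequences):
    """Split flat generate() outputs into one group per input sequence.

    Assumes the num_return_sequences candidates of each input occupy
    consecutive rows of the flattened outputs.
    """
    ids = np.asarray(ids)
    scores = np.asarray(scores)
    if ids.shape[0] % num_return_sequences != 0:
        raise ValueError("first dimension must be a multiple of num_return_sequences")
    batch_size = ids.shape[0] // num_return_sequences
    # [batch_size * n, seq_len] -> [batch_size, n, seq_len]
    grouped_ids = ids.reshape(batch_size, num_return_sequences, ids.shape[1])
    # [batch_size * n, 1] -> [batch_size, n]
    grouped_scores = scores.reshape(batch_size, num_return_sequences)
    return grouped_ids, grouped_scores

# Stand-in arrays shaped like the outputs for batch_size=2, num_return_sequences=2.
flat_ids = np.arange(12).reshape(4, 3)
flat_scores = np.arange(4, dtype=np.float32).reshape(4, 1)
g_ids, g_scores = group_by_input(flat_ids, flat_scores, num_return_sequences=2)
print(g_ids.shape, g_scores.shape)  # (2, 2, 3) (2, 2)
```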

Return type

tuple (Tensor)

Example

import paddle
from paddlenlp.transformers import (
    UnifiedTransformerLMHeadModel,
    UnifiedTransformerTokenizer
)

paddle.seed(2)

model_name_or_path = 'unified_transformer-12L-cn-luge'
model = UnifiedTransformerLMHeadModel.from_pretrained(model_name_or_path)
tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name_or_path)

history = "早上好,今天空气质量不错。"
inputs = tokenizer.dialogue_encode(history, task_type='chitchat',
    add_start_token_as_response=True, return_tensors=True)

# Generate the sequence by using "greedy_search" strategy
ids, scores = model.generate(
    input_ids=inputs['input_ids'],
    token_type_ids=inputs['token_type_ids'],
    position_ids=inputs['position_ids'],
    attention_mask=inputs['attention_mask'],
    decode_strategy="greedy_search")
print(ids.shape, scores.shape)
# [1, 3] [1, 1]
sequence_ids = ids.numpy().tolist()[0]
sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
response = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
print(response)
# 是的

# Generate 2 sequences by using "sampling" strategy (top_k=5)
ids, scores = model.generate(
    input_ids=inputs['input_ids'],
    token_type_ids=inputs['token_type_ids'],
    position_ids=inputs['position_ids'],
    attention_mask=inputs['attention_mask'],
    decode_strategy="sampling",
    top_k=5,
    num_return_sequences=2)
print(ids.shape, scores.shape)
# [2, 7] [2, 1]
response = []
for sequence_ids in ids.numpy().tolist():
    sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
    text = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
    response.append(text)
print(response)
# ['天气好,心情也好', '你也是']

# Generate 2 sequences by using "beam_search" strategy (num_beams=5)
ids, scores = model.generate(
    input_ids=inputs['input_ids'],
    token_type_ids=inputs['token_type_ids'],
    position_ids=inputs['position_ids'],
    attention_mask=inputs['attention_mask'],
    decode_strategy="beam_search",
    num_beams=5,
    num_return_sequences=2)
print(ids.shape, scores.shape)
# [2, 3] [2, 1]
response = []
for sequence_ids in ids.numpy().tolist():
    sequence_ids = sequence_ids[:sequence_ids.index(tokenizer.sep_token_id)]
    text = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
    response.append(text)
print(response)
# ['是的', '嗯嗯']
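The example cuts each sequence with sequence_ids.index(tokenizer.sep_token_id), which raises ValueError if a sequence reaches max_length without emitting the separator. A small guard such as the hypothetical helper below keeps the whole sequence in that case; it is plain Python and does not depend on PaddleNLP.

```python
def truncate_at_stop(sequence_ids, stop_id):
    """Return the tokens before the first stop_id, or all tokens if it is absent.

    list.index raises ValueError when stop_id is missing, e.g. when
    generation hit max_length before emitting the separator.
    """
    if stop_id in sequence_ids:
        return sequence_ids[:sequence_ids.index(stop_id)]
    return list(sequence_ids)

print(truncate_at_stop([5, 6, 2, 0, 0], 2))  # [5, 6]
print(truncate_at_stop([5, 6, 7], 2))        # [5, 6, 7]
```

In the loops above, the truncation line could then read sequence_ids = truncate_at_stop(sequence_ids, tokenizer.sep_token_id).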