How to get generated tokens in T5 training_step for using user-defined metrics?


Issue

I am fine-tuning T5 for generative question answering and want to track additional metrics (e.g., BLEU, ROUGE) on the generated answers, in addition to the loss.

For that, I believe I need the generated tokens (answers) at each training_step. However, even after reading the source code, I have no clue how to get them.

Below is an excerpt of my code. I can extract output.loss and output.logits, but I haven't found a way to get the generated tokens to compute the additional evaluation metrics.

Thanks in advance.

import pytorch_lightning as pl
from transformers import T5ForConditionalGeneration

# MODEL_NAME (the checkpoint being fine-tuned) is defined elsewhere in my code.


class MyQAModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)

    def forward(self, input_ids, attention_mask, labels=None):
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels)

        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        # logits: [batch, target_seq_len, vocab_size]
        loss, logits = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": logits, "labels": labels}

    # ... (code continues)
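For context on what output.logits holds in the forward above: when labels is passed, T5ForConditionalGeneration builds the decoder input itself by shifting the labels one position to the right (prepending the decoder start token, which for T5 is the pad id 0) and replacing -100 with the pad id. The decoder therefore sees the gold answer prefix (teacher forcing), and position i of the logits scores the label at position i. The sketch below mimics that shift on plain Python lists; shift_right and the token ids are hypothetical, written for illustration rather than copied from the library.

```python
from typing import List

PAD_ID = 0  # T5's pad id, which T5 also uses as its decoder start token
IGNORE_ID = -100  # label value that the loss ignores


def shift_right(labels: List[List[int]], start_id: int = PAD_ID) -> List[List[int]]:
    """Hypothetical re-implementation of how T5 turns labels into decoder inputs.

    Prepend the start token, drop the last label, and replace any -100
    with the pad id so that every decoder input is a valid token id.
    """
    shifted = []
    for row in labels:
        new_row = [start_id] + row[:-1]
        shifted.append([PAD_ID if t == IGNORE_ID else t for t in new_row])
    return shifted


# Made-up answer token ids ending in </s> (id 1), padded with -100.
labels = [[71, 19, 1, IGNORE_ID], [32, 1, IGNORE_ID, IGNORE_ID]]
decoder_inputs = shift_right(labels)
print(decoder_inputs)  # [[0, 71, 19, 1], [0, 32, 1, 0]]

# Teacher forcing: step i sees the gold prefix and is scored against labels[i],
# so argmax(logits[:, i]) is a one-step prediction, not a free-running one.
# (A target of -100 is skipped by the loss.)
for step, (inp, target) in enumerate(zip(decoder_inputs[0], labels[0])):
    print(f"step {step}: decoder input {inp} -> target {target}")
```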

Solution

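You can obtain the predicted tokens from output.logits, whose shape is [batch, seq_len, vocab_size], by taking torch.argmax(output.logits, dim=-1), which has shape [batch, seq_len]. Note that when labels are passed, T5 runs the decoder with teacher forcing (it is fed the gold answer shifted one position to the right), so these are one-step greedy predictions given the reference prefix rather than free-running generations. For answers produced the way they would be at inference time, call self.model.generate(input_ids=input_ids, attention_mask=attention_mask), e.g. in validation_step, since autoregressive generation is much slower than a single forward pass. Then, to decode the generated sentences from a batch of token ids, run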

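import torch

# tokenizer is the T5Tokenizer matching the model,
# e.g. T5Tokenizer.from_pretrained(MODEL_NAME)

# [batch, seq_len, vocab_size] -> [batch, seq_len]
predicted_token_ids = torch.argmax(output.logits, dim=-1)
# skip_special_tokens=True drops <pad> and </s> from the decoded text
generated_sentences = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True)

# For getting original sentences (the input questions)
original_sentences = tokenizer.batch_decode(input_ids, skip_special_tokens=True)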
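To connect the decoding step to a user-defined metric, here is a dependency-free sketch of the whole pipeline using a toy vocabulary instead of the real SentencePiece tokenizer. It assumes T5's special ids (pad = 0, </s> = 1) and labels padded with -100 (not a valid token id, so it must be dropped or replaced before decoding). It takes the argmax at each position, skips special and ignored ids the way skip_special_tokens=True does, and scores each prediction against its reference with a word-overlap F1. That score is a simplified stand-in for ROUGE-1, not the official ROUGE or BLEU implementation; TOY_VOCAB, greedy_ids, decode, unigram_f1 and scores_for are hypothetical names.

```python
from collections import Counter
from typing import Dict, List, Sequence

# Assumed T5 special ids: pad = 0, </s> = 1; -100 marks ignored label positions.
PAD_ID, EOS_ID, IGNORE_ID = 0, 1, -100
SPECIAL_IDS = {PAD_ID, EOS_ID}

# Hypothetical toy vocabulary standing in for the real SentencePiece tokenizer.
TOY_VOCAB: Dict[int, str] = {
    0: "<pad>", 1: "</s>", 2: "paris", 3: "is", 4: "the", 5: "capital", 6: "london",
}


def greedy_ids(logits: Sequence[Sequence[Sequence[float]]]) -> List[List[int]]:
    """Argmax over the vocabulary: [batch, seq_len, vocab_size] -> [batch, seq_len]."""
    return [
        [max(range(len(scores)), key=scores.__getitem__) for scores in sequence]
        for sequence in logits
    ]


def decode(ids: Sequence[int]) -> str:
    """Drop special and ignored ids (like skip_special_tokens=True), then map ids to words."""
    return " ".join(TOY_VOCAB[t] for t in ids if t not in SPECIAL_IDS and t != IGNORE_ID)


def unigram_f1(prediction: str, reference: str) -> float:
    """Word-overlap F1, a simplified ROUGE-1; returns 0.0 when nothing overlaps."""
    pred, ref = prediction.split(), reference.split()
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred), overlap / len(ref)
    return 2 * precision * recall / (precision + recall)


def scores_for(token_id: int) -> List[float]:
    """Fake per-position logits peaking at token_id (demo data only)."""
    return [1.0 if i == token_id else 0.0 for i in range(len(TOY_VOCAB))]


# Demo batch: 2 sequences of 6 positions, standing in for output.logits.
logits = [
    [scores_for(t) for t in [2, 3, 4, 5, 1, 0]],  # "paris is the capital </s> <pad>"
    [scores_for(t) for t in [6, 3, 4, 5, 1, 0]],  # "london is the capital </s> <pad>"
]
labels = [[2, 3, 4, 5, 1, IGNORE_ID]] * 2  # both references: "paris is the capital"

predictions = [decode(row) for row in greedy_ids(logits)]
references = [decode(row) for row in labels]
scores = [unigram_f1(p, r) for p, r in zip(predictions, references)]
print(predictions)  # ['paris is the capital', 'london is the capital']
print(scores)  # [1.0, 0.75]
```

In the actual training_step, the equivalents are torch.argmax(logits, dim=-1), tokenizer.batch_decode(..., skip_special_tokens=True) and, if the labels are padded with -100, labels.masked_fill(labels == -100, tokenizer.pad_token_id) before decoding the references. For reported numbers, a maintained implementation such as the ROUGE and BLEU metrics in torchmetrics or Hugging Face evaluate is a safer choice than a hand-rolled score.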

Answered By – joe32140

This answer was collected from Stack Overflow and is licensed under CC BY-SA 2.5, CC BY-SA 3.0, and CC BY-SA 4.0.
