SFTTrainer loss function and formatting_func

I would like to know how the masks are set up when a formatting_func is used. In my case, I am trying to fine-tune Tx Gemma. I have refactored my dataset into two columns: the SMILES string and the expected output (in my case (A) or (B)).

I use the tdc_prompt in the formatting_func and I concatenate the generated prompt and the expected output, so that each example ends up as a single text string at the end of the data refactoring pipeline.

def formatting_func(example):
    text = f"{tdc_prompts_json[task_name].replace(input_type, example['input_text'])} {example['output_text']}<eos>"
    return text

My text is as follows:
“Instructions: Answer the following question about drug properties.\nContext: Human ether-à-go-go related gene (hERG) is crucial for the coordination of the heart’s beating. Thus, if a drug blocks the hERG, it could lead to severe adverse effects. Therefore, reliable prediction of hERG liability in the early stages of drug design is quite important to reduce the risk of cardiotoxicity-related attritions in the later development stages.\nQuestion: Given a drug SMILES string, predict whether it\n(A) does not inhibit hERG (B) inhibits hERG\nDrug SMILES: O=C(NC[C@H]1CCCC[NH2+]1)c1cc(OCC(F)(F)F)ccc1OCC(F)(F)F\nAnswer: (B)”

When I create my trainer object, I just pass the dataset containing the [‘input_text’] and [‘output_text’] columns plus the formatting_func, and I assume that the trainer will take care of the mask so that the loss is computed only on the generated tokens and not on the entire prompt.

trainer = SFTTrainer(
    model=model,
    train_dataset=refactored_dataset_dict['train'],
    eval_dataset=refactored_dataset_dict['val'],
    args=SFTConfig(),
    formatting_func=formatting_func,
    peft_config=lora_config
)

The training goes “well” as the loss decreases, but when I test the fine-tuned model, I realize that it has learned nothing: the generated output is always the same, a hint that the loss has been computed on the entire prompt, not only on the generated output.

What did I miss?

Thanks for the help

Jerome


Using completion_only_loss=True seems like the smartest approach.


You didn’t miss a magical “auto–mask the prompt” feature.
With your current setup, SFTTrainer is computing loss on the entire formatted string (prompt + SMILES + (A)/(B)), except for padding. When you pass only a formatting_func, SFTTrainer treats your data as plain language–model text, not as “prompt vs. completion.”

I’ll walk through:

  1. What SFTTrainer actually does when you pass formatting_func
  2. What the masks (labels and attention_mask) look like in that case
  3. Why that makes Tx Gemma “learn nothing” for your (A)/(B) label
  4. What you need to change to get the mask you expected

1. What SFTTrainer does with formatting_func

Step 1: formatting_func transforms your rows to pure text

You wrote:

def formatting_func(example):
    text = f"{tdc_prompts_json[task_name].replace(input_type, example['input_text'])} {example['output_text']}<eos>"
    return text

So for each dataset row { "input_text": SMILES, "output_text": "(B)" } you produce a single string:

"Instructions: ... Question: ... Drug SMILES: ... Answer: (B)<eos>"

Inside SFTTrainer, when you pass formatting_func, it:

  1. Applies formatting_func to each example in train_dataset and eval_dataset.
  2. Replaces the original columns by this formatted text (internally it becomes something like a "text" field).
  3. Then tokenizes that text into input_ids and attention_mask.

At this point, SFTTrainer has no explicit notion of “prompt” vs. “answer” anymore. It just sees a long text sequence per example. This is treated as a language modeling (LM) dataset. (Hugging Face)
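
Conceptually, the effect of passing formatting_func is roughly equivalent to the following (a sketch, not TRL's exact internals):

# Sketch of what the trainer effectively does with formatting_func:
# every row is collapsed into a single "text" string, then tokenized.
formatted = refactored_dataset_dict["train"].map(
    lambda ex: {"text": formatting_func(ex)}
)
tokenized = formatted.map(
    lambda ex: tokenizer(ex["text"]),
    remove_columns=formatted.column_names,
)
print(tokenized[0].keys())  # input_ids, attention_mask -- no prompt/completion split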

Step 2: default collator = normal LM objective

If you do not pass a custom data_collator and you do not use the new prompt–completion dataset type, TRL chooses a standard LM-style collator (conceptually the same as DataCollatorForLanguageModeling):

  • It builds input_ids, attention_mask.
  • It builds labels equal to input_ids for all non-pad tokens.
  • It sets labels = -100 only for padding positions so they are ignored. (Hugging Face)

There is no special masking for the prompt when you only use formatting_func.

That’s the key missing piece.
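
You can see this with your own eyes by running a couple of formatted examples through the stock LM collator (a sketch; TRL's internal default collator is equivalent in effect here):

from transformers import DataCollatorForLanguageModeling

# Causal-LM collator: labels are a copy of input_ids, with -100 only on padding.
lm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

features = [tokenizer(formatting_func(refactored_dataset_dict["train"][i]))
            for i in range(2)]
batch = lm_collator(features)

print(batch["labels"][0][:20])          # same IDs as input_ids, no -100 inside the prompt
print((batch["labels"] == -100).sum())  # only the padding positions are masked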


2. What the masks actually look like in your setup

Let’s look at one example conceptually.

Your formatted text:

"Instructions: Answer the following question about drug properties.
Context: Human ether-à-go-go related gene (hERG) ...
Question: Given a drug SMILES string, predict whether it
(A) does not inhibit hERG (B) inhibits hERG
Drug SMILES: O=C(NC[C@H]1CCCC[NH2+]1)c1cc(OCC(F)(F)F)ccc1OCC(F)(F)F
Answer: (B)<eos>"

After tokenization, you get:

input_ids = [Inst, ructions, :, Answer, the, following, ..., '(', 'B', ')', <eos>]
attention_mask = [1, 1, 1, 1, ..., 1]

The default SFTTrainer collator (LM style) will then set:

labels      = [Inst, ructions, :, Answer, the, following, ..., '(', 'B', ')', <eos>]

with labels = -100 only on any padding added when batching.

So:

  • Every token of the instructions, context, question, SMILES, “Answer:” and (B) contributes to the loss.
  • Only padding is ignored.
  • There is no “loss only on (B)<eos>” behaviour.

This matches the old “train on completions only” warning in the TRL docs: you must use a special collator or a prompt–completion dataset format to get completion-only loss. If you just pass text, the trainer uses plain LM loss on all non-padding tokens. (Hugging Face)


3. Why this makes Tx Gemma look like it “learned nothing”

Two effects combine:

3.1 The prompt dominates the loss

Your answer (A) or (B) is just a handful of tokens; your prompt is long (instructions + context + question + SMILES).

If we roughly say:

  • Prompt ≈ 150 tokens
  • Answer ≈ 3–5 tokens

then >95% of the supervised tokens are in the prompt.

The model’s gradients are dominated by:

  • Getting the instructions right
  • Getting the context paragraph right
  • Getting the SMILES and surrounding text right

The tiny suffix (A) / (B) contributes almost nothing to the total loss compared to the prompt.
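
You can quantify this imbalance for your own prompts in a couple of lines (a rough check, assuming tokenizer is the Tx Gemma tokenizer):

example = refactored_dataset_dict["train"][0]
prompt = tdc_prompts_json[task_name].replace(input_type, example["input_text"])
answer = f" {example['output_text']}<eos>"

n_prompt = len(tokenizer(prompt)["input_ids"])
n_answer = len(tokenizer(answer, add_special_tokens=False)["input_ids"])
print(n_prompt, n_answer, n_answer / (n_prompt + n_answer))  # answer share of supervised tokens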

3.2 The prompt is mostly constant; the label is tiny

You re-use the same TDC-style instructions and question each time, with only the SMILES and (A)/(B) changing. So the model can reduce loss substantially simply by:

  • Memorizing the fixed instruction & context text.
  • Modestly improving how it reproduces SMILES-like patterns.

It does not need to learn the mapping “SMILES → (A)/(B)” to lower the loss meaningfully, because the supervision on those few label tokens is dwarfed by supervision on the fixed prompt.

3.3 Tx Gemma has known multiple-choice positional bias

There is a GitHub issue from the Gemma cookbook showing Tx Gemma on a hERG task where, for multiple-choice prompts like:

(A) is a hERG blocker (B) isn't ...
(A) isn't a hERG blocker (B) is ...
1: is a hERG blocker  0: isn't ...
0: isn't a hERG blocker 1: is ...

the model tends to always pick the first choice, regardless of content. (GitHub)

If your fine-tuning signal barely touches the answer tokens (A)/(B) because masking is wrong, that pre-existing positional bias remains. At inference, you observe:

  • “The model always outputs the same answer.”
  • “It seems to have learned nothing.”

This is exactly what you described.


4. What you need to change to get the mask you expected

You have two clean options depending on which TRL style you want.

Option 1: use prompt–completion dataset + completion_only_loss=True (no formatting_func)

This is the newer / recommended SFTTrainer pattern. (Hugging Face)

  1. Preprocess your dataset into explicit prompt and completion fields:

    def preprocess(example):
        # prompt = everything up to and including "Answer:"
        # (the TDC template already ends with "\nAnswer:", so no extra marker is needed)
        prompt = tdc_prompts_json[task_name].replace(
            input_type,
            example["input_text"],  # SMILES
        )

        completion = f" {example['output_text']}<eos>"  # " (A)" or " (B)"

        return {"prompt": prompt, "completion": completion}
    
  2. Map over your dataset:

    ds = refactored_dataset_dict.map(
        preprocess,
        remove_columns=["input_text", "output_text"],
    )
    
  3. Train with SFTTrainer configured for prompt–completion data:

    from trl import SFTTrainer, SFTConfig
    
    training_args = SFTConfig(
        output_dir="txgemma-herg",
        completion_only_loss=True,  # ensures loss on completion only
        max_seq_length=512,
        # other hyperparams...
    )
    
    trainer = SFTTrainer(
        model=model,
        train_dataset=ds["train"],
        eval_dataset=ds["val"],
        args=training_args,
        peft_config=lora_config,
        # no formatting_func needed now
    )
    

In this mode:

  • SFTTrainer knows which tokens belong to the prompt and which belong to the completion. (Hugging Face)

  • It automatically masks labels so that:

    • labels = -100 for all prompt tokens.
    • labels = token_id only for completion tokens.

That gives you the “loss only on generated tokens (the answer)” behaviour you expected.


Option 2: keep formatting_func, add DataCollatorForCompletionOnlyLM

If you really want to keep a single formatted text string (your current style), then you must explicitly instruct the collator where the answer starts, using DataCollatorForCompletionOnlyLM. (Hugging Face)

  1. Keep your formatting_func, making sure the text contains a stable answer prefix (the TDC template already ends with "\nAnswer:", so you only append the label):

    def formatting_func(example):
        prompt = tdc_prompts_json[task_name].replace(input_type, example["input_text"])
        # the TDC template already ends with "\nAnswer:", so only append the label:
        text = f"{prompt} {example['output_text']}<eos>"
        return text
    
  2. Create the collator with a response_template that matches the prefix before the answer:

    from trl import DataCollatorForCompletionOnlyLM
    
    response_template = "\nAnswer:"  # exact substring before the answer
    
    collator = DataCollatorForCompletionOnlyLM(
        response_template=response_template,
        tokenizer=tokenizer,
    )
    

    TRL docs and multiple guides show this pattern for “train on completions only” with formatting_func. (Hugging Face)

  3. Instantiate SFTTrainer with packing=False (required for this collator) and pass the collator:

    trainer = SFTTrainer(
        model=model,
        train_dataset=refactored_dataset_dict["train"],
        eval_dataset=refactored_dataset_dict["val"],
        args=SFTConfig(
            output_dir="txgemma-herg",
            packing=False,   # important with DataCollatorForCompletionOnlyLM
            max_seq_length=512,
            # ...
        ),
        formatting_func=formatting_func,
        data_collator=collator,
        peft_config=lora_config,
    )
    

Now the masking looks like:

  • The collator locates the token sequence of the response_template inside each tokenized example.
  • All tokens before that point get labels = -100 (ignored).
  • All tokens after that point (including (A)/(B) and <eos>) retain their token IDs as labels. (Gist)

That is exactly “compute loss only on the generated output,” while still using formatting_func to build the text.
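
A quick way to sanity-check the masking (a sketch; it runs one formatted example through the collator defined above):

example = refactored_dataset_dict["train"][0]
tokenized = tokenizer(formatting_func(example))
batch = collator([tokenized])

labels = batch["labels"][0]
supervised = labels[labels != -100]
print(tokenizer.decode(supervised))  # should print just the answer, e.g. " (B)<eos>"

One caveat: if the collator warns that it cannot find the response template in some examples, the TRL docs suggest passing the template's token IDs (encoded with a bit of preceding context) instead of the raw string, because some tokenizers encode "\nAnswer:" differently depending on what precedes it.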


5. Short mental model answer to your question

“How are masks set up when a formatting_func is used, in my case?”

With only formatting_func and no special collator or prompt–completion dataset:

  • SFTTrainer:

    • Applies formatting_func → gets a single text per example.

    • Tokenizes to input_ids, attention_mask.

    • Sets:

      • attention_mask = 1 for all real tokens, 0 for padding.
      • labels = input_ids for all non-padding tokens.
      • labels = -100 only for padding tokens.

So:

  • Prompt + SMILES + “Answer: (B)” are all supervised.
  • There is no built-in logic that “knows” where the answer starts.
  • As a result, the loss is dominated by the long, mostly constant prompt, and Tx Gemma’s pre-existing bias (e.g., “always pick first option”) remains, giving you almost the same output every time.

To get the loss mask you wanted (only on (A) / (B)), you must either:

  • Switch to a prompt / completion dataset with completion_only_loss=True, or
  • Keep formatting_func but add DataCollatorForCompletionOnlyLM with a proper response_template and packing=False.

Once you do that, the masks will be:

  • Prompt tokens: labels = -100 (ignored)
  • Answer tokens: real label IDs → loss is computed only there.

Thanks John. I have tested the approach that you recommended (prompt, completion). Unfortunately, with my super tiny dataset (around 700 examples), the first fine-tuning run (250 epochs) totally destroyed the model output (the model now always answers with the same tokens), even though the loss decreases quickly. Again, it seems that the loss has been computed on the entire prompt and not just the completion, despite the new formatting of the dataset. I will continue my investigation by looking at the masks to ensure that they have been set up correctly.

Thanks again

Jerome


Hmm… Overfitting?


What you are seeing now (loss drops fast, model starts outputting almost always the same tokens) is very typical of:

  1. A tiny dataset (≈700 examples)
  2. Very aggressive training (250 epochs)
  3. A short-answer classification task ((A)/(B))
  4. A large pretrained LM (Tx Gemma) with strong prior biases

Even if the masking is now correct and the loss is computed only on the completion, this setup is enough to cause severe overfitting and catastrophic forgetting, where the model memorizes a narrow pattern and “forgets” how to behave like a general model. This looks like “the model is destroyed” and “always produces the same tokens”, even though the training loss looks great. (arXiv)

Below I’ll split things into:

  1. What should happen with a prompt–completion dataset (masking)
  2. Likely causes of your collapse (even with correct masks)
  3. How to verify the masks concretely
  4. Concrete fixes for your setting (tiny hERG dataset + Tx Gemma)
  5. If it still outputs the same answer after fixing hyperparameters

1. What should happen with prompt–completion + SFTTrainer

In current TRL, if you:

  • provide a prompt–completion dataset (fields like {"prompt": "...", "completion": "..."}), and
  • set (or keep default) completion_only_loss=True in SFTConfig,

then the docs are very explicit:

“To train on completion only, use a prompt-completion dataset. By default, the trainer computes the loss on the completion tokens only, ignoring the prompt tokens. If you want to train on the full sequence, set completion_only_loss=False.” (Hugging Face)

So if:

  • your dataset really has prompt and completion fields,
  • you are not passing a formatting_func anymore, and
  • you have not set completion_only_loss=False,

then the mask should be:

  • labels = -100 (ignored) for all prompt tokens
  • labels = token_id for completion tokens only

This is the behaviour you originally wanted.

There are two common configuration pitfalls:

  1. Using formatting_func and prompt–completion together.
    In modern TRL this is considered incompatible with completion_only_loss=True: using a formatter converts the dataset into a pure LM type. Some stacks even raise the exact error “A formatting function was provided while completion_only_loss=True, which is incompatible…” (GitHub)

  2. Accidentally setting completion_only_loss=False.
    Then you do train on both prompt and completion, even with prompt–completion data.

If you’ve removed formatting_func and explicitly set completion_only_loss=True, you are probably masked correctly now. The fact that the model collapsed is then much more about overfitting and forgetting than about the prompt mask.


2. Likely causes of “model destroyed” after 250 epochs on 700 examples

Assuming the masking is now correct, there are three big causes to focus on.

2.1 Extreme overfitting and catastrophic forgetting

Fine-tuning LLMs on very small datasets with many epochs is exactly the setting where catastrophic forgetting shows up: the model rapidly adapts to the tiny dataset and overwrites useful general behaviours. Recent work on catastrophic forgetting in foundation models notes that overfitting to small fine-tuning sets is a primary cause, and that simply tuning longer or harder on small data pushes the model to forget its original capabilities. (arXiv)

For 700 examples:

  • 250 epochs means each example is seen 250 times.
  • If your effective batch size is small (say 4–16), this is tens of thousands of gradient updates on the same 700 samples.
  • With a typical LR for LoRA (e.g. 5e-5–2e-4), that is more than enough to drive the adapter to essentially memorize a narrow behaviour such as “when asked this hERG-style question, always answer X”.

Because your completions are extremely short (e.g. (A) or (B)), the model can reduce the loss significantly by:

  • Pushing logits for one of the tokens (say (B)) very high in general, so that the answer is essentially always (B).

The training loss goes down, but the model is no longer a good general language model and no longer sensitive to the input SMILES in a meaningful way.

This is overfitting + forgetting, not necessarily a masking bug.

2.2 Dataset structure: short labels, possible imbalance, strong prior bias

Your task is:

  • Binary (A/B)
  • Labels very short (a few tokens)
  • Prompt long and almost constant across examples
  • SMILES are varied but quite opaque to the LM a priori

Even with correct completion-only loss, if the dataset is:

  • Imbalanced (e.g. 75% “(B) inhibits hERG”, 25% “(A) does not”), or
  • Very small (700 examples) and noisy,

the cross-entropy optimum for a generative model might be something trivial like “always predict the majority label”. That behaviour gives very low loss on the training set and fits exactly what you see: “answers always the same tokens”.

On top of this, Tx Gemma itself has a documented positional bias in multiple-choice hERG prompts: tests have shown that it tends to favor the first option in (A)/(B)/0/1 style prompts almost regardless of semantics. (arXiv)

If your fine-tuning doesn’t provide a strong, diverse signal, the model can easily:

  • Keep or amplify that bias (e.g. “always A” or “always B”),
  • While still reducing loss, because the dataset does not strongly contradict that behaviour.

2.3 Hyperparameters (LR, LoRA rank, etc.)

With a tiny dataset, common hyperparameters that are harmless on larger datasets can be destructive:

  • Learning rate too high for LoRA (e.g. 1e-4–2e-4 or more)
  • LoRA rank too large, or applied to too many layers (so too many degrees of freedom)
  • No early stopping and training for fixed 250 epochs
  • No regularization (weight decay, dropout, etc.)

This combination means each small batch can substantially change the adapter weights, so the model quickly converges to a narrow, degenerate solution.

Even parameter-efficient methods like LoRA can still suffer from forgetting under such small-data regimes; analyses of PEFT methods explicitly point out that dataset size is often more critical than the exact adapter mechanism. (Obsidian)


3. How to verify definitively whether loss is only on the completion

You’re already planning to inspect the masks; that’s the right move. Do it once and you’ll remove all doubt.

3.1 Inspect a training batch

After instantiating your trainer (with prompt–completion, and completion_only_loss=True):

batch = next(iter(trainer.get_train_dataloader()))

for k, v in batch.items():
    print(k, v.shape)

print("input_ids:", batch["input_ids"][0])
print("labels:", batch["labels"][0])

Then:

  1. Decode the first sequence to see the text:

    print(tokenizer.decode(batch["input_ids"][0], skip_special_tokens=False))
    
  2. Manually locate the boundary between prompt and completion in that decoded text.

  3. Look at labels[0]:

    • All positions corresponding to the prompt should be -100.
    • All positions corresponding to the completion should be ≥ 0 (true token IDs).

If this is true, then masking is correct and the issue is not “loss on the entire prompt”, but overfitting / forgetting / dataset structure.
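
Continuing from the batch above, a quick programmatic version of the same check (it counts and decodes the supervised positions):

labels = batch["labels"][0]
print((labels != -100).sum().item())  # should be only a few tokens, not hundreds

# decode exactly what the loss is computed on
print(tokenizer.decode(batch["input_ids"][0][labels != -100]))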

If you see labels ≥ 0 for the prompt tokens as well, then:

  • Check that your dataset really uses {"prompt", "completion"} fields.
  • Check that you are not passing a formatting_func anymore.
  • Check that completion_only_loss=True in SFTConfig. (Hugging Face)

4. Concrete solutions for your exact setup

Assuming you confirm the masks are correct, here is what I would change for Tx Gemma + hERG + 700 examples.

4.1 Drastically reduce training intensity

For 700 examples, something like:

  • Epochs: start with 1–3 epochs, not 250.
  • Batch size: as large as fits in memory (e.g. 16–64 effective batch via gradient accumulation).
  • Learning rate (LoRA): small, e.g. 5e-5 or even 1e-5.
  • Warmup ratio: 0.03–0.1, to avoid large updates early.
  • Max steps: consider capping total steps explicitly instead of epochs.

This is consistent with general recommendations for preventing catastrophic forgetting: use a smaller learning rate, some regularization, and avoid extensive over-training on tiny data. (Hugging Face Forums)
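
Put into SFTConfig, this could look roughly as follows (a sketch; the exact values are starting points, not fixed recommendations):

from trl import SFTConfig

training_args = SFTConfig(
    output_dir="txgemma-herg",
    num_train_epochs=2,                # instead of 250
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,     # effective batch size of 32
    learning_rate=5e-5,                # small LR for LoRA on tiny data
    warmup_ratio=0.05,
    weight_decay=0.01,
    completion_only_loss=True,
    logging_steps=10,
)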

4.2 Make LoRA truly “small” and localized

If you’re not already doing so, constrain LoRA:

  • Use a small LoRA rank (e.g. r=4–8 instead of 16–64).
  • Apply LoRA only to a few later layers or attention projections (q_proj, v_proj, o_proj) rather than everything.
  • Keep the base model frozen.

This limits how much the adapter can distort the model’s behaviour and reduces forgetting. Parameter-efficient methods are exactly about updating a small subset to keep the base model stable. (SuperAnnotate)
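
A minimal LoRA config along those lines (a sketch; the module names are the usual Gemma attention projections, adjust to your model):

from peft import LoraConfig

lora_config = LoraConfig(
    r=8,                                  # small rank
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # attention only, not every projection
    bias="none",
    task_type="CAUSAL_LM",
)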

4.3 Add some regularization and early stopping

  • Use weight decay (e.g. 0.01).
  • Enable early stopping based on validation loss (or accuracy on a held-out subset).
  • Consider a small amount of dropout in the adapter layers if available.

This won’t fix an absurd 250-epoch regime, but it helps once you bring the epoch count down.
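
Since SFTTrainer is a subclass of the regular transformers Trainer, early stopping is the standard callback. A sketch, reusing the prompt/completion dataset ds from before (on older transformers versions the argument is evaluation_strategy instead of eval_strategy):

from transformers import EarlyStoppingCallback
from trl import SFTConfig, SFTTrainer

training_args = SFTConfig(
    output_dir="txgemma-herg",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,        # required by EarlyStoppingCallback
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    weight_decay=0.01,
)

trainer = SFTTrainer(
    model=model,
    train_dataset=ds["train"],
    eval_dataset=ds["val"],
    args=training_args,
    peft_config=lora_config,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)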

4.4 Stabilize the classification formulation

For such a tiny dataset and very short labels, consider:

  1. Always use one fixed label surface form.
    E.g. always (A) for non-inhibitor, (B) for inhibitor, never switch them. This fights the known positional bias of Tx Gemma in multiple-choice prompts. (arXiv)

  2. Use log-prob scoring instead of free generation during evaluation.
    At inference time, for a given SMILES prompt:

    • Construct the same prompt you used for training, ending in "Answer:".
    • Compute log P("(A)" | prompt) and log P("(B)" | prompt) using the model (no sampling).
    • Predict the label with higher log-probability.

    This avoids generation randomness and focuses on the key logits for A/B (see the sketch after this list).

  3. Check label balance.
    If your dataset is heavily skewed (e.g. 80% “(B)”), then “always (B)” is actually a strong local optimum for cross-entropy. You may need to:

    • Downsample the majority class, or
    • Use a class-weighted loss (not trivial with SFTTrainer, but possible if you write a custom loss wrapper), or
    • Accept that with 700 examples you might be mostly learning a prior.
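
Here is a sketch of the log-prob scoring mentioned above (it assumes tokenizer and the fine-tuned model are loaded, and that smiles holds the SMILES string you want to classify):

import torch

@torch.no_grad()
def score_choice(prompt: str, answer: str) -> float:
    # Summed log-probability of the answer tokens given the prompt.
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    answer_ids = tokenizer(answer, add_special_tokens=False, return_tensors="pt").input_ids.to(model.device)
    input_ids = torch.cat([prompt_ids, answer_ids], dim=-1)
    logits = model(input_ids=input_ids).logits
    logprobs = torch.log_softmax(logits[:, :-1, :], dim=-1)      # position i predicts token i+1
    answer_rows = logprobs[0, prompt_ids.shape[1] - 1 :, :]      # rows that predict the answer tokens
    token_logprobs = answer_rows.gather(-1, answer_ids[0].unsqueeze(-1)).squeeze(-1)
    return token_logprobs.sum().item()

prompt = tdc_prompts_json[task_name].replace(input_type, smiles)  # template already ends in "Answer:"
pred = "(A)" if score_choice(prompt, " (A)") > score_choice(prompt, " (B)") else "(B)"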

4.5 Consider mixing in some “generic” text (optional)

Research on catastrophic forgetting in foundation models suggests that mixing some of the original pretraining or instruction data during fine-tuning helps preserve general capability. (arXiv)

Practically:

  • Build a mixed dataset: e.g. 70–80% generic Tx Gemma–style prompts, 20–30% your hERG classification prompts.
  • Train on this mixture with appropriate sampling.

This way, the model is gently nudged toward the hERG task without being forced to adapt solely to a tiny, narrow dataset.
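
A sketch of such a mixture with datasets.interleave_datasets (generic_ds is a placeholder for whatever general instruction data you mix in; ds is the prompt/completion dataset from before):

from datasets import interleave_datasets

# both datasets must share the same columns (e.g. prompt/completion)
mixed_train = interleave_datasets(
    [generic_ds, ds["train"]],
    probabilities=[0.75, 0.25],        # ~75% generic, ~25% hERG
    seed=42,
    stopping_strategy="all_exhausted",
)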

4.6 If fine-tuning continues to be unstable, use Tx Gemma as a frozen feature extractor

Given only 700 samples, an alternative is:

  1. Use Tx Gemma to encode your full prompt (instructions + SMILES but no answer).
  2. Extract a fixed representation (e.g. mean-pooled last-layer hidden state).
  3. Train a small separate classifier (logistic regression, MLP) on top of these embeddings, instead of fine-tuning the LM at all.

This completely avoids catastrophic forgetting in the LM (it stays frozen), and the classifier is easy to train robustly on 700 examples.
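
A minimal sketch of that pipeline (mean-pooled last hidden state plus a scikit-learn classifier; prompts and labels are placeholder lists built from your 700 examples, with no answer included in the prompts):

import torch
from sklearn.linear_model import LogisticRegression

@torch.no_grad()
def embed(prompt: str) -> torch.Tensor:
    # Mean-pooled last-layer hidden state of the frozen model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    hidden = model(**inputs, output_hidden_states=True).hidden_states[-1]
    return hidden.mean(dim=1).squeeze(0).float().cpu()

X = torch.stack([embed(p) for p in prompts]).numpy()  # prompts: full prompts without answers
y = labels                                            # labels: 0/1 for (A)/(B)

clf = LogisticRegression(max_iter=1000).fit(X, y)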


5. If the model still outputs the same tokens after these changes

If, after:

  • verifying that masks are correct,
  • reducing epochs,
  • lowering LR,
  • tightening LoRA,

you still see “always the same answer”:

  1. Check your evaluation code:

    • Are you always sending exactly the same prompt by mistake?
    • Are you truncating the SMILES so they are identical?
    • Are you decoding only the first token (which might be (A) or (B) because of priors)?
  2. Check generation settings:

    • Use greedy decoding: do_sample=False (temperature and top_p are then ignored).
    • Limit max_new_tokens to something small (e.g. 4–8).
    • Ensure you start generation from the same "Answer:" ending as during training.
  3. Drop the adapter and compare:

    • Save the base model separately.
    • Remove/disable the LoRA adapter.
    • Compare base vs fine-tuned outputs on the same SMILES.

If the base model behaves reasonably (or at least differently per SMILES) and the fine-tuned one is constant, you have confirmed that the collapse is purely a fine-tuning effect, not a bug in your eval pipeline.


Short summary

  • With a proper prompt–completion dataset and completion_only_loss=True, SFTTrainer does compute loss only on the completion tokens, not the prompt. (Hugging Face)
  • Your symptom (loss decreases, model outputs always the same tokens) is classic overfitting + catastrophic forgetting on a tiny dataset with very many epochs and short binary labels, amplified by Tx Gemma’s known positional bias on multiple-choice prompts. (arXiv)
  • First, verify the masks by printing a batch and checking that prompt tokens have labels -100.
  • Then, fix the training regime: drastically reduce epochs (1–3), lower LR, use small LoRA, add early stopping and slight regularization.
  • Stabilize the classification formulation (fixed label format, log-prob scoring, check label balance) and, if needed, mix in some generic data or instead train a small classifier on top of frozen Tx Gemma embeddings.

Yes, I came to the same conclusion about overfitting. But honestly, with a tiny rank of 8, I was not anticipating such a deep impact of the adapters, inducing catastrophic forgetting and the model collapse. I am now training on a larger dataset (300k examples). I am running it on my local machine (RTX 4090) for fun, and I will start the same run on Vertex AI. I hope that this time the model will learn something, or at least not collapse :slight_smile:

Thanks

Jerome


Also, by default, I use the same configuration for the adapter locations, with almost all the projection layers included. I will also reduce the number of projection layers included in the adaptation… r=8 is quite small, but when applied to all kinds of projections, in the end it is quite significant :slight_smile:



Thanks for sharing. It helps me a lot.


I have found the origin of the “apparent” model collapse:

1- the training was good BUT


2- apparently the SFTTrainer leaves the model in its training state. And in my notebook, I did not persist the model before running the evaluation, nor even call model.eval().

Result: the model was still in training mode, with dropout and so on -> it was predicting the first token and then looping on it until the max length was reached.
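
In code, the fix boils down to putting the model back in eval mode before generating (a sketch; prompt stands for one of the evaluation prompts ending in "Answer:"):

import torch

model.eval()  # back to inference mode: disables dropout (including LoRA dropout)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=8, do_sample=False)
print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))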

The HF documentation could be updated to mention this SFTTrainer behavior if not already.

On the Google side we will update the colab material too because it is not indicated.

Thanks everyone for support and help :wink:

Jerome
