Hands-On Large Language Models by Jay Alammar & Maarten Grootendorst

Chapter 11: Fine-Tuning Representation Models for Classification

Introduction

Chapter 4 used pretrained classifiers and embedding models with their weights frozen. This chapter unfreezes them. When you have enough data, fine-tuning the model end-to-end produces stronger task-specific results. We work through five scenarios on the Rotten Tomatoes and CoNLL-2003 datasets: full BERT fine-tuning, layer-freezing experiments, SetFit for few-shot classification, continued pretraining with masked language modeling for domain adaptation, and named-entity recognition at the token level.


Section 1: Frozen vs. Fine-Tuned

In Chapter 4, the BERT or embedding model stayed frozen and only a small classifier on top was trained.

Now we let gradients flow through the entire model, so both the BERT encoder and the classification head get updated.

The two pieces co-adapt. BERT learns better representations for this task, and the head learns to consume them.
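
To make the contrast concrete, here is a minimal sketch of the two setups in code, assuming the same bert-base-cased sequence-classification model used throughout this chapter:

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

# Chapter 4 style: freeze the BERT encoder, train only the classification head
for name, param in model.named_parameters():
    if not name.startswith("classifier"):
        param.requires_grad = False

# This chapter's default: full fine-tuning, every parameter stays trainable
for param in model.parameters():
    param.requires_grad = True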


Section 2: Fine-Tuning a Pretrained BERT Model

Same dataset as Chapter 4: Rotten Tomatoes (5,331 positive plus 5,331 negative reviews):

from datasets import load_dataset
tomatoes = load_dataset("rotten_tomatoes")
train_data, test_data = tomatoes["train"], tomatoes["test"]

Load bert-base-cased with a 2-class classification head:

from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding

model_id = "bert-base-cased"
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(model_id)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)

tokenized_train = train_data.map(preprocess_function, batched=True)
tokenized_test = test_data.map(preprocess_function, batched=True)

Define an F1 metric:

import numpy as np
from datasets import load_metric

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    f1 = load_metric("f1").compute(predictions=predictions, references=labels)["f1"]
    return {"f1": f1}
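
Note that load_metric is deprecated in recent versions of the datasets library. If it is unavailable in your environment, an equivalent compute_metrics can be built on the evaluate library (the same one used for NER later in this chapter):

import numpy as np
import evaluate

f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"f1": f1_metric.compute(predictions=predictions, references=labels)["f1"]}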

Train:

from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    "model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    weight_decay=0.01,
    save_strategy="epoch",
    report_to="none",
)

trainer = Trainer(
    model=model, args=training_args,
    train_dataset=tokenized_train, eval_dataset=tokenized_test,
    tokenizer=tokenizer, data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()
trainer.evaluate()
# eval_f1: 0.85

F1 = 0.85 vs. Chapter 4's frozen task-specific RoBERTa at 0.80: fine-tuning beats an out-of-domain pretrained model even after a single epoch that takes only a few minutes to train.


Section 3: Freezing Layers

You can selectively freeze parts of the network. This is useful when you're short on compute or data.

3.1 The BERT-Base Architecture

BERT-base has 12 stacked encoder blocks (each with attention, dense, and layer-norm sublayers), plus token, position, and segment embeddings and a pooler; on top of that we add a classifier.

You can list all parameters by name:

for name, param in model.named_parameters():
    print(name)
# bert.embeddings.word_embeddings.weight
# bert.embeddings.position_embeddings.weight
# ...
# bert.encoder.layer.0.attention.self.query.weight
# ...
# bert.encoder.layer.11.output.LayerNorm.bias
# bert.pooler.dense.weight
# classifier.weight
# classifier.bias

3.2 Freeze Everything Except the Classifier

for name, param in model.named_parameters():
    param.requires_grad = name.startswith("classifier")

Training is much faster but quality plummets to F1 = 0.63. The model can't adapt its representations to your task.

3.3 Freeze All But the Last Two Encoder Blocks

for index, (name, param) in enumerate(model.named_parameters()):
    if index < 165:    # encoder block 10 starts at index 165 (5 embedding params + 10 blocks x 16)
        param.requires_grad = False

Result: F1 = 0.80, close to fully unfrozen (0.85) at a fraction of compute.

The chapter's experiment sweeps the freeze/train boundary across the encoder blocks and measures F1 for each setting.

Training the last 5 of 12 encoder blocks gets you most of the way there.

Bottom line: freezing buys speed, full fine-tuning buys quality. Pick a balance based on your compute budget — and remember that the gap widens as you train for more epochs.
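
A rough sketch of that sweep, reusing the Trainer setup from Section 2; the helper name and the one-epoch budget are illustrative rather than the book's exact code:

from transformers import AutoModelForSequenceClassification, Trainer

def f1_with_first_k_blocks_frozen(k):
    # Fresh model per configuration so earlier runs don't leak into later ones
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

    # 5 embedding parameters come first, then 16 parameters per encoder block
    boundary = 5 + 16 * k
    for index, (name, param) in enumerate(model.named_parameters()):
        if index < boundary:   # freeze the embeddings and the first k encoder blocks
            param.requires_grad = False

    trainer = Trainer(
        model=model, args=training_args,
        train_dataset=tokenized_train, eval_dataset=tokenized_test,
        tokenizer=tokenizer, data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    return trainer.evaluate()["eval_f1"]

# k = 0 freezes only the embeddings; k = 12 trains just the pooler and classifier
results = {k: f1_with_first_k_blocks_frozen(k) for k in range(13)}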


Section 4: Few-Shot Classification with SetFit

What if you only have a handful of labeled examples per class?

SetFit ("Efficient few-shot learning without prompts") fine-tunes a sentence-transformer using just 16 examples per class and remains competitive with full BERT fine-tuning on much larger datasets.

4.1 The SetFit Recipe

The recipe has three steps.

Step 1 — Sample Training Pairs

Treat sentences in the same class as positives and different classes as negatives. With 16 sentences per class, that's 16·15/2 = 120 positive pairs per class for free.
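
A minimal sketch of what that pair sampling amounts to; this is conceptual only, since the setfit library generates the pairs internally:

from itertools import combinations, product

def sample_pairs(sentences_by_class):
    """sentences_by_class maps each class label to its list of sentences."""
    positive_pairs, negative_pairs = [], []
    labels = list(sentences_by_class)

    # In-class combinations become positive pairs: C(16, 2) = 120 per class
    for label in labels:
        positive_pairs += list(combinations(sentences_by_class[label], 2))

    # Cross-class combinations become negative pairs
    for label_a, label_b in combinations(labels, 2):
        negative_pairs += list(product(sentences_by_class[label_a], sentences_by_class[label_b]))

    return positive_pairs, negative_pairs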

Step 2 — Fine-Tune the Embedding Model with Contrastive Learning

This is the same pattern from Chapter 10. Embeddings get nudged so that same-class sentences are close and different-class sentences are far apart.

Step 3 — Train a Lightweight Classifier

Embed all sentences with the freshly-fine-tuned encoder, then train a logistic regression (or a differentiable classification head) on top.

Together, these three steps form the full SetFit pipeline.

4.2 Hands-On

from setfit import sample_dataset, SetFitModel
from setfit import TrainingArguments as SetFitTrainingArguments
from setfit import Trainer as SetFitTrainer

# Sample 16 examples per class — only 32 docs vs 8500
sampled_train_data = sample_dataset(tomatoes["train"], num_samples=16)

model = SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")

args = SetFitTrainingArguments(
    num_epochs=3,        # contrastive learning epochs
    num_iterations=20,   # text pairs to generate per sample
)
args.eval_strategy = args.evaluation_strategy  # mirror the older evaluation_strategy attribute onto the newer eval_strategy name

trainer = SetFitTrainer(
    model=model, args=args,
    train_dataset=sampled_train_data,
    eval_dataset=test_data, metric="f1",
)
trainer.train()
trainer.evaluate()
# {'f1': 0.84}

The training output reports 1,280 sentence pairs generated from 32 input examples (20 pairs × 32 samples × 2 for positive/negative). F1 = 0.84 with 32 labels, within striking distance of the 0.85 we got with 8,500 labels.

SetFit also supports zero-shot: it synthesizes examples like "This example is happy" / "This example is sad" from label names alone.
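
The library can build that synthetic training set for you; a brief sketch, assuming your setfit version exposes get_templated_dataset (the labels and template here are illustrative):

from setfit import get_templated_dataset

# Synthesizes examples such as "This sentence is negative" / "This sentence is positive"
zero_shot_train = get_templated_dataset(
    candidate_labels=["negative", "positive"],
    sample_size=8,
    template="This sentence is {}",
)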

Custom head: pass use_differentiable_head=True and head_params={"out_features": num_classes} to use a neural classification head instead of logistic regression.
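
As a sketch, with out_features set to the two Rotten Tomatoes classes:

model = SetFitModel.from_pretrained(
    "sentence-transformers/all-mpnet-base-v2",
    use_differentiable_head=True,
    head_params={"out_features": 2},  # number of classes
)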


Section 5: Continued Pretraining with Masked Language Modeling

The default flow is pretrain → fine-tune for the task.

If your data is domain-specific (medical, legal, movie reviews, internal jargon), insert a third step: continue pretraining with MLM on your domain corpus before fine-tuning. The model learns domain vocabulary and patterns first, and the classifier on top benefits.

5.1 Setting Up MLM

from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling

model = AutoModelForMaskedLM.from_pretrained("bert-base-cased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)

tokenized_train = train_data.map(preprocess_function, batched=True).remove_columns("label")
tokenized_test = test_data.map(preprocess_function, batched=True).remove_columns("label")

# Mask 15% of tokens at random
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15,
)

5.2 Token Masking vs. Whole-Word Masking

Token masking randomly masks 15% of tokens (which might split a word) and converges faster. Whole-word masking via DataCollatorForWholeWordMask is a harder task that produces better representations but converges more slowly.
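
Switching to whole-word masking is a one-line collator swap; a sketch using the class mentioned above:

from transformers import DataCollatorForWholeWordMask

# Masks all subword tokens of a chosen word rather than individual subword tokens
data_collator = DataCollatorForWholeWordMask(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15,
)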

5.3 Training and Verifying

training_args = TrainingArguments(
    "model", learning_rate=2e-5,
    per_device_train_batch_size=16, per_device_eval_batch_size=16,
    num_train_epochs=10, weight_decay=0.01,
    save_strategy="epoch", report_to="none",
)

trainer = Trainer(
    model=model, args=training_args,
    train_dataset=tokenized_train, eval_dataset=tokenized_test,
    tokenizer=tokenizer, data_collator=data_collator,
)

tokenizer.save_pretrained("mlm")
trainer.train()
model.save_pretrained("mlm")

Verify domain adaptation by filling masks before vs after:

from transformers import pipeline

# Original BERT
mask_filler = pipeline("fill-mask", model="bert-base-cased")
mask_filler("What a horrible [MASK]!")
# → idea, dream, thing, day, thought   (generic words)

# After continued pretraining on movie reviews
mask_filler = pipeline("fill-mask", model="mlm")
mask_filler("What a horrible [MASK]!")
# → movie, film, mess, comedy, story   (movie-specific words!)

The model has clearly absorbed the domain. From here, you'd fine-tune for classification:

model = AutoModelForSequenceClassification.from_pretrained("mlm", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("mlm")

Section 6: Named-Entity Recognition (NER)

NER classifies individual tokens and is used for de-identification, anonymization, and information extraction.

Compared to document classification, the model now emits a label per token rather than per document.

6.1 The CoNLL-2003 Dataset

dataset = load_dataset("conll2003", trust_remote_code=True)

example = dataset["train"][848]
# {'tokens':  ['Dean', 'Palmer', 'hit', 'his', '30th', 'homer', 'for', 'the', 'Rangers', '.'],
#  'ner_tags': [   1,        2,    0,    0,     0,      0,      0,    0,        3,       0]}

The label scheme uses BIO tagging. B-X is the beginning of an entity span, I-X is inside the span, and O is non-entity:

label2id = {"O": 0, "B-PER": 1, "I-PER": 2, "B-ORG": 3, "I-ORG": 4,
            "B-LOC": 5, "I-LOC": 6, "B-MISC": 7, "I-MISC": 8}
id2label = {index: label for label, index in label2id.items()}

So "Dean Palmer" is tagged [B-PER, I-PER]: one person, not two.

6.2 Aligning Word Labels with Subword Tokens

Tokenizing 'homer' produces 'home' plus '##r'. Words tagged B-PER like 'Maarten' become ['Ma', '##arte', '##n']. Only the first subword should keep B-PER, while the rest become I-PER (still inside the same entity).

def align_labels(examples):
    token_ids = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
    labels = examples["ner_tags"]
    updated_labels = []

    for index, label in enumerate(labels):
        word_ids = token_ids.word_ids(batch_index=index)
        previous_word_idx = None
        label_ids = []

        for word_idx in word_ids:
            if word_idx != previous_word_idx:
                # Start of a new word — keep original label, or -100 for special tokens
                previous_word_idx = word_idx
                updated_label = -100 if word_idx is None else label[word_idx]
                label_ids.append(updated_label)
            elif word_idx is None:
                label_ids.append(-100)
            else:
                # Continuation: turn B-X (odd) into I-X (even)
                updated_label = label[word_idx]
                if updated_label % 2 == 1:
                    updated_label += 1
                label_ids.append(updated_label)

        updated_labels.append(label_ids)

    token_ids["labels"] = updated_labels
    return token_ids

tokenized = dataset.map(align_labels, batched=True)

-100 is a special label that PyTorch and HuggingFace ignore during loss computation, applied to [CLS], [SEP], and any padding.
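
This works because PyTorch's cross-entropy loss skips that index by default; a tiny illustration with made-up tensors:

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()          # ignore_index defaults to -100
logits = torch.randn(4, 9)               # 4 tokens, 9 NER labels
labels = torch.tensor([1, 2, -100, 0])   # the third token (e.g., [SEP]) is ignored
loss = loss_fn(logits, labels)           # loss and gradients come from the other three tokens only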

6.3 Training NER

Use DataCollatorForTokenClassification (handles per-token padding) and AutoModelForTokenClassification:

from transformers import AutoModelForTokenClassification, DataCollatorForTokenClassification
import evaluate

model = AutoModelForTokenClassification.from_pretrained(
    "bert-base-cased",
    num_labels=len(id2label), id2label=id2label, label2id=label2id,
)
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

seqeval = evaluate.load("seqeval")
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=2)
    true_predictions, true_labels = [], []
    for prediction, label in zip(predictions, labels):
        for tp, tl in zip(prediction, label):
            if tl != -100:
                true_predictions.append([id2label[tp]])
                true_labels.append([id2label[tl]])
    return {"f1": seqeval.compute(predictions=true_predictions, references=true_labels)["overall_f1"]}

trainer = Trainer(
    model=model, args=training_args,
    train_dataset=tokenized["train"], eval_dataset=tokenized["test"],
    tokenizer=tokenizer, data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()
trainer.evaluate()

6.4 Inference

from transformers import pipeline

trainer.save_model("ner_model")
token_classifier = pipeline("token-classification", model="ner_model")
token_classifier("My name is Maarten.")
# [{'entity': 'B-PER', 'score': 0.995, 'word': 'Ma',    'start': 11, 'end': 13},
#  {'entity': 'I-PER', 'score': 0.993, 'word': '##arte','start': 13, 'end': 17},
#  {'entity': 'I-PER', 'score': 0.995, 'word': '##n',   'start': 17, 'end': 18}]

The three subword tokens correctly stitch back into one person entity ("Maarten"). Use aggregation_strategy="simple" in the pipeline to merge them automatically.
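
A short sketch of that aggregated call; the merged output uses the entity_group format, and the exact score will vary:

token_classifier = pipeline(
    "token-classification", model="ner_model", aggregation_strategy="simple",
)
token_classifier("My name is Maarten.")
# [{'entity_group': 'PER', 'word': 'Maarten', 'start': 11, 'end': 18, 'score': ...}]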


Summary

  • Fine-tuning a pretrained BERT (with its classification head) on Rotten Tomatoes hits F1 = 0.85 in one epoch, beating Chapter 4's frozen-RoBERTa at 0.80.
  • Freezing layers is a compute-vs-quality trade. Freezing everything except the classifier collapses quality (0.63). Training only the last two encoder blocks recovers most of it (0.80). Beyond about 5 trainable blocks, returns are diminishing.
  • SetFit does few-shot classification in three steps. First, build positive and negative pairs from in-class and out-class sentences. Second, contrastive-fine-tune a sentence-transformer. Third, train a small classifier on the resulting embeddings. F1 ≈ 0.84 with 32 labels vs 0.85 with 8,500.
  • Continued pretraining with masked language modeling on your domain corpus, before task fine-tuning, adapts BERT's vocabulary and representations to your domain. Useful for medical, legal, or jargon-heavy internal text.
  • Token vs whole-word masking trades convergence speed for representation quality.
  • NER is per-token classification with BIO tagging (B-PER, I-PER, O, and so on). The trickiest part is aligning word-level labels with subword tokenization. The first subword keeps the B- tag, and subsequent subwords inherit I-. Use -100 to ignore special tokens ([CLS], [SEP], padding) in the loss.
  • Tools: AutoModelForSequenceClassification for document-level classification, AutoModelForMaskedLM for MLM, AutoModelForTokenClassification for NER, plus matching DataCollator* classes.