
BERT¶

BERT (Bidirectional Encoder Representations from Transformers) is an encoder-only transformer model that is pre-trained on large text corpora and designed to be fine-tuned on downstream NLP tasks.

Use cases:

  • Masked language modeling: The cat [MASK] on the mat. -> BERT -> The cat sat on the mat. (see the sketch after this list)
  • Next sentence prediction: [CLS] The cat sat on the mat. [SEP] The dog slept on the rug. [SEP] -> BERT -> True
  • Text classification: Assign a label to a text sequence: flag an email as spam or not spam, or categorize IT tickets into one of a set of classes.
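
As a quick illustration of the masked language modeling use case, here is a minimal sketch using the Hugging Face fill-mask pipeline (bert-base-uncased is an assumed checkpoint; it is not used elsewhere in this post):

from transformers import pipeline

# minimal masked-language-modeling sketch
fill_mask = pipeline("fill-mask", model="bert-base-uncased")

# BERT ranks candidate tokens for the [MASK] position
for prediction in fill_mask("The cat [MASK] on the mat."):
    print(prediction["token_str"], round(prediction["score"], 3))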

Architecture

Training a BERT model¶

If, during training, the training loss keeps decreasing while the validation loss increases, the model is overfitting: it is memorizing the training data rather than learning patterns that generalize to unseen data.
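
A common way to counter this is early stopping combined with keeping the best checkpoint. The sketch below shows one way to wire this into the Hugging Face Trainer; the output directory and patience value are illustrative, and it assumes a model and tokenized train/validation splits like the ones prepared later in this post:

from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

training_args = TrainingArguments(
    output_dir="bert-early-stopping-demo",  # hypothetical output directory
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,            # restore the checkpoint with the best validation loss
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],  # stop after 2 evaluations without improvement
)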

Pooler: BERT has a pooling layer that takes the final-layer hidden state of the [CLS] token and passes it through a dense layer with a tanh activation, producing a fixed-size vector that a classification head can consume.
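
A minimal sketch of inspecting the pooled output with BertModel (bert-base-uncased is an assumed checkpoint; attribute names follow the Hugging Face API):

from transformers import AutoTokenizer, BertModel
import torch

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("The cat sat on the mat.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)  # (1, sequence_length, 768): one vector per token
print(outputs.pooler_output.shape)      # (1, 768): pooled [CLS] representation used by classification heads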

Fine tuning¶

What is fine-tuning? Taking a pre-trained model and continuing to train (some of) its parameters on a new task. When to fine-tune?

  • Plain prompt engineering is not working.
  • A smaller fine-tuned model can outperform a larger one.

How?

  • Self-supervised
  • Supervised
    • Choose fine-tuning task
    • Prepare dataset
    • Choose a base model
    • Fine-tune via supervised learning
    • Evaluate
    • Deploy
  • Reinforcement learning

Parameter training options:

  • Retrain all parameters
  • Transfer learning: Retrain only the final layer(s), e.g. the classification head.
  • Parameter-efficient fine-tuning: Retrain only a small portion of the parameters.
  • LoRA (low-rank adaptation): Keep the original weights frozen and add a pair of trainable low-rank matrices next to them; the rank r is a small hyperparameter such as 4 or 8 (see the sketch after this list).
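
To make the LoRA idea concrete, here is a minimal, illustrative sketch of wrapping a single linear layer with a low-rank update; the peft library used below does the equivalent for whole models:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen linear layer plus a trainable low-rank update: W x + (alpha / r) * B A x."""
    def __init__(self, base: nn.Linear, r: int = 4, alpha: int = 32):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # original weights stay frozen
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init, so the update starts as a no-op
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + self.scaling * (x @ self.A.T @ self.B.T)

layer = LoRALinear(nn.Linear(768, 768), r=4, alpha=32)
print(sum(p.numel() for p in layer.parameters() if p.requires_grad))  # only A and B are trainable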

Fine tuning BERT with LoRA¶

In [ ]:
!pip install transformers datasets evaluate peft torch --upgrade
In [10]:
from datasets import DatasetDict, Dataset, load_dataset
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, DataCollatorWithPadding, TrainingArguments, Trainer
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig

import evaluate
import torch
import numpy as np

model_checkpoint = "distilbert-base-uncased"

id2label = {0: "Negative", 1: "Positive"}
label2id = {"Negative": 0, "Positive": 1}

model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, 
                                                           num_labels=2, 
                                                           id2label=id2label, 
                                                           label2id=label2id).to("mps")  # "mps" targets Apple Silicon; use "cuda" or "cpu" as available

dataset_dict = load_dataset("shawhin/imdb-truncated")

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)

def tokenize_function(examples):
    text = examples["text"]
    # truncate from the left so that the end of long reviews is kept
    tokenizer.truncation_side = "left"
    tokenized_inputs = tokenizer(text, truncation=True, return_tensors="np", max_length=512)
    return tokenized_inputs

if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    model.resize_token_embeddings(len(tokenizer))

tokenized_dataset = dataset_dict.map(tokenize_function, batched=True)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return {"accuracy": accuracy.compute(predictions=predictions, references=labels)}

text_list = ["I love this movie", "I hate this movie", "Not a fan, don't recommend", "This one is a pass"]

print("untrained model predictions:")
print("-"*20)

for text in text_list:
    inputs = tokenizer.encode(text, return_tensors="pt").to("mps")
    logits = model(inputs).logits
    predictions = torch.argmax(logits)
    print(f"{text} -> {id2label[predictions.tolist()]}")

# Now let's fine-tune the model with LoRA.
# target_modules=["q_lin"] applies LoRA to the query projections of DistilBERT's attention blocks.
peft_config = LoraConfig(task_type="SEQ_CLS", r=4, lora_alpha=32, target_modules=["q_lin"], lora_dropout=0.01)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

lr = 1e-3
batch_size = 4
num_epochs = 1

training_args = TrainingArguments(
    output_dir=model_checkpoint + "-lora-text-classification",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    weight_decay=0.01,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

print("trained model predictions:")
print("-"*20)

for text in text_list:
    inputs = tokenizer.encode(text, return_tensors="pt").to("mps")
    logits = model(inputs).logits
    predictions = torch.argmax(logits)
    print(f"{text} -> {id2label[predictions.tolist()]}")
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
untrained model predictions:
--------------------
I love this movie -> Negative
I hate this movie -> Negative
Not a fan, don't recommend -> Negative
This one is a pass -> Negative
trainable params: 628,994 || all params: 67,584,004 || trainable%: 0.9307
/Users/LSoica/work/AI/blog/.venv/lib/python3.12/site-packages/transformers/training_args.py:1568: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
/var/folders/59/c32_bthx48jd9m2ym5m3tnpw0000j7/T/ipykernel_19744/1863821626.py:78: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.
  trainer = Trainer(
                                                 
100%|██████████| 250/250 [01:47<00:00,  3.20it/s]
{'eval_loss': 0.30286046862602234, 'eval_accuracy': {'accuracy': 0.891}, 'eval_runtime': 34.387, 'eval_samples_per_second': 29.081, 'eval_steps_per_second': 7.27, 'epoch': 1.0}
100%|██████████| 250/250 [01:48<00:00,  2.31it/s]
{'train_runtime': 108.427, 'train_samples_per_second': 9.223, 'train_steps_per_second': 2.306, 'train_loss': 0.4968539123535156, 'epoch': 1.0}
trained model predictions:
--------------------
I love this movie -> Positive
I hate this movie -> Negative
Not a fan, don't recommend -> Negative
This one is a pass -> Negative
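
Since only the LoRA adapter was trained, only the adapter weights need to be saved. A minimal sketch, with an illustrative directory name, using the PeftModel import from the cell above:

# save just the adapter (adapter_config.json + adapter weights)
adapter_dir = "distilbert-lora-imdb-adapter"  # hypothetical directory name
model.save_pretrained(adapter_dir)

# later: rebuild the base model and attach the trained adapter
base_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint,
                                                                num_labels=2,
                                                                id2label=id2label,
                                                                label2id=label2id)
reloaded_model = PeftModel.from_pretrained(base_model, adapter_dir)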

Fine tuning for phishing link detection¶

Import¶

In [1]:
import pandas as pd
from datasets import DatasetDict, Dataset, load_dataset

from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

import evaluate
import numpy as np
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader
import torch
/Users/LSoica/work/AI/blog/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
In [2]:
dataset_dict = load_dataset("shawhin/phishing-site-classification")
In [3]:
# Load model directly
model_path = "google-bert/bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_path)

id2label = {0: "Safe", 1: "Not Safe"}
label2id = {"Safe": 0, "Not Safe": 1}
model = AutoModelForSequenceClassification.from_pretrained(model_path, 
                                                           num_labels=2, 
                                                           id2label=id2label, 
                                                           label2id=label2id,).to("mps")
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
In [4]:
# freeze base model parameters
for name, param in model.base_model.named_parameters():
    param.requires_grad = False

# unfreeze base model pooling layers
for name, param in model.base_model.named_parameters():
    if "pooler" in name:
        param.requires_grad = True
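
To verify that only the pooler layers (plus the classification head, which lives outside base_model) remain trainable, a quick check such as the following can be run (not part of the original notebook):

# count trainable vs. total parameters after freezing
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")

# list the parameter names that will still be updated
print([name for name, p in model.named_parameters() if p.requires_grad])
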
In [5]:
# define text preprocessing
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)
In [6]:
tokenized_data = dataset_dict.map(preprocess_function, batched=True)
In [51]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
In [52]:
# load metrics
accuracy = evaluate.load("accuracy")
auc_score = evaluate.load("roc_auc")

def compute_metrics(eval_pred):
    # get predictions
    predictions, labels = eval_pred
    
    # apply softmax to get probabilities
    probabilities = np.exp(predictions) / np.exp(predictions).sum(-1, keepdims=True)
    # use probabilities of the positive class for ROC AUC
    positive_class_probs = probabilities[:, 1]
    # compute auc
    auc = np.round(auc_score.compute(prediction_scores=positive_class_probs, references=labels)['roc_auc'],3)
    
    # predict most probable class
    predicted_classes = np.argmax(predictions, axis=1)
    # compute accuracy
    acc = np.round(accuracy.compute(predictions=predicted_classes, references=labels)['accuracy'],3)
    
    return {"Accuracy": acc, "AUC": auc}

Train¶

In [53]:
# hyperparameters
lr = 2e-4
batch_size = 8
num_epochs = 10

training_args = TrainingArguments(
    output_dir="bert-phishing-classifier_teacher",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)
In [54]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()
/var/folders/59/c32_bthx48jd9m2ym5m3tnpw0000j7/T/ipykernel_69699/2732273287.py:1: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.
  trainer = Trainer(
 10%|█         | 263/2630 [00:21<03:28, 11.33it/s]
{'loss': 0.4922, 'grad_norm': 1.1272960901260376, 'learning_rate': 0.00018, 'epoch': 1.0}
 10%|█         | 263/2630 [00:26<03:28, 11.33it/s]
{'eval_loss': 0.4166680574417114, 'eval_Accuracy': 0.787, 'eval_AUC': 0.912, 'eval_runtime': 4.9842, 'eval_samples_per_second': 90.285, 'eval_steps_per_second': 11.436, 'epoch': 1.0}
 20%|██        | 526/2630 [00:47<02:45, 12.70it/s]
{'loss': 0.3905, 'grad_norm': 2.191509485244751, 'learning_rate': 0.00016, 'epoch': 2.0}
 20%|██        | 526/2630 [00:52<02:45, 12.70it/s]
{'eval_loss': 0.3610435724258423, 'eval_Accuracy': 0.813, 'eval_AUC': 0.931, 'eval_runtime': 5.0556, 'eval_samples_per_second': 89.01, 'eval_steps_per_second': 11.275, 'epoch': 2.0}
 30%|███       | 789/2630 [01:14<02:28, 12.42it/s]
{'loss': 0.3854, 'grad_norm': 0.820168137550354, 'learning_rate': 0.00014, 'epoch': 3.0}
 30%|███       | 789/2630 [01:19<02:28, 12.42it/s]
{'eval_loss': 0.31684166193008423, 'eval_Accuracy': 0.858, 'eval_AUC': 0.939, 'eval_runtime': 4.8727, 'eval_samples_per_second': 92.351, 'eval_steps_per_second': 11.698, 'epoch': 3.0}
 40%|████      | 1052/2630 [01:40<01:44, 15.12it/s]
{'loss': 0.3592, 'grad_norm': 1.5052204132080078, 'learning_rate': 0.00012, 'epoch': 4.0}
 40%|████      | 1052/2630 [01:44<01:44, 15.12it/s]
{'eval_loss': 0.4752054810523987, 'eval_Accuracy': 0.793, 'eval_AUC': 0.942, 'eval_runtime': 4.7917, 'eval_samples_per_second': 93.913, 'eval_steps_per_second': 11.896, 'epoch': 4.0}
 50%|█████     | 1315/2630 [02:05<01:20, 16.37it/s]
{'loss': 0.3511, 'grad_norm': 3.0983104705810547, 'learning_rate': 0.0001, 'epoch': 5.0}
 50%|█████     | 1315/2630 [02:10<01:20, 16.37it/s]
{'eval_loss': 0.33138200640678406, 'eval_Accuracy': 0.86, 'eval_AUC': 0.946, 'eval_runtime': 4.6846, 'eval_samples_per_second': 96.059, 'eval_steps_per_second': 12.167, 'epoch': 5.0}
 60%|██████    | 1578/2630 [02:31<01:13, 14.24it/s]
{'loss': 0.3536, 'grad_norm': 2.456183671951294, 'learning_rate': 8e-05, 'epoch': 6.0}
 60%|██████    | 1578/2630 [02:36<01:13, 14.24it/s]
{'eval_loss': 0.30289924144744873, 'eval_Accuracy': 0.871, 'eval_AUC': 0.948, 'eval_runtime': 4.9382, 'eval_samples_per_second': 91.126, 'eval_steps_per_second': 11.543, 'epoch': 6.0}
 70%|███████   | 1841/2630 [02:57<00:47, 16.67it/s]
{'loss': 0.3196, 'grad_norm': 2.2076616287231445, 'learning_rate': 6e-05, 'epoch': 7.0}
 70%|███████   | 1841/2630 [03:02<00:47, 16.67it/s]
{'eval_loss': 0.2912053167819977, 'eval_Accuracy': 0.862, 'eval_AUC': 0.949, 'eval_runtime': 5.0116, 'eval_samples_per_second': 89.791, 'eval_steps_per_second': 11.374, 'epoch': 7.0}
 80%|████████  | 2104/2630 [03:24<00:41, 12.64it/s]
{'loss': 0.3285, 'grad_norm': 4.401841163635254, 'learning_rate': 4e-05, 'epoch': 8.0}
 80%|████████  | 2104/2630 [03:28<00:41, 12.64it/s]
{'eval_loss': 0.29782968759536743, 'eval_Accuracy': 0.876, 'eval_AUC': 0.949, 'eval_runtime': 4.7566, 'eval_samples_per_second': 94.605, 'eval_steps_per_second': 11.983, 'epoch': 8.0}
 90%|█████████ | 2367/2630 [03:50<00:29,  9.06it/s]
{'loss': 0.3152, 'grad_norm': 0.3482799828052521, 'learning_rate': 2e-05, 'epoch': 9.0}
 90%|█████████ | 2367/2630 [03:54<00:29,  9.06it/s]
{'eval_loss': 0.28831204771995544, 'eval_Accuracy': 0.864, 'eval_AUC': 0.951, 'eval_runtime': 4.5704, 'eval_samples_per_second': 98.461, 'eval_steps_per_second': 12.472, 'epoch': 9.0}
100%|██████████| 2630/2630 [04:16<00:00, 16.56it/s]
{'loss': 0.3053, 'grad_norm': 4.398026943206787, 'learning_rate': 0.0, 'epoch': 10.0}
100%|██████████| 2630/2630 [04:21<00:00, 16.56it/s]
{'eval_loss': 0.2977932393550873, 'eval_Accuracy': 0.871, 'eval_AUC': 0.951, 'eval_runtime': 4.7019, 'eval_samples_per_second': 95.706, 'eval_steps_per_second': 12.123, 'epoch': 10.0}
100%|██████████| 2630/2630 [04:22<00:00, 10.03it/s]
{'train_runtime': 262.2333, 'train_samples_per_second': 80.081, 'train_steps_per_second': 10.029, 'train_loss': 0.36006521942950925, 'epoch': 10.0}

Out[54]:
TrainOutput(global_step=2630, training_loss=0.36006521942950925, metrics={'train_runtime': 262.2333, 'train_samples_per_second': 80.081, 'train_steps_per_second': 10.029, 'total_flos': 706603239165360.0, 'train_loss': 0.36006521942950925, 'epoch': 10.0})

Evaluate¶

In [55]:
# apply model to validation dataset
predictions = trainer.predict(tokenized_data["test"])

# Extract the logits and labels from the predictions object
logits = predictions.predictions
labels = predictions.label_ids

# Use your compute_metrics function
metrics = compute_metrics((logits, labels))
print(metrics)
100%|██████████| 57/57 [00:04<00:00, 12.49it/s]
{'Accuracy': np.float64(0.864), 'AUC': np.float64(0.951)}

Infer on new data¶

In [119]:
urls = [
  "google.com",
  "yahoo.com",
  "www.yahoo.com",
  "https://www.yahoo.com",
  "https://microsoft.user-account.online/14e84edd29dc7302?l=861",
  "users11.jabry.com/reaseo/Aolupdate.htm",
  "www.allandmedia.com/opencart/system/Cielo/index.html",
  "mrterabit.com/remax/index.php",
  "phishing.org"
]
for url in urls:
  inputs = tokenizer(url, return_tensors="pt", truncation=True, padding=True)

  with torch.no_grad():
    outputs = trainer.model.forward(**inputs)
    probabilities = np.exp(outputs.logits) / np.exp(outputs.logits).sum(-1, keepdims=True)
    print("Safe" if probabilities[0][0].item() > 0.9 else "Not Safe", url)
/var/folders/59/c32_bthx48jd9m2ym5m3tnpw0000j7/T/ipykernel_69699/3933651874.py:17: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)
  probabilities = np.exp(outputs.logits) / np.exp(outputs.logits).sum(-1, keepdims=True)
Safe google.com
Safe yahoo.com
Safe www.yahoo.com
Safe https://www.yahoo.com
Not Safe https://microsoft.user-account.online/14e84edd29dc7302?l=861
Not Safe users11.jabry.com/reaseo/Aolupdate.htm
Not Safe www.allandmedia.com/opencart/system/Cielo/index.html
Not Safe mrterabit.com/remax/index.php
Not Safe phishing.org
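
The DeprecationWarning above comes from calling np.exp directly on a torch tensor; staying in torch avoids it. A minimal alternative sketch, keeping the same 0.9 threshold on the "Safe" probability:

# same inference loop, but computing probabilities with torch.softmax
for url in urls:
    inputs = tokenizer(url, return_tensors="pt", truncation=True).to(trainer.model.device)
    with torch.no_grad():
        probabilities = torch.softmax(trainer.model(**inputs).logits, dim=-1)
    print("Safe" if probabilities[0, 0].item() > 0.9 else "Not Safe", url)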

References¶

BERT - Fine tuning for phishing link detection