add roberta option (#135)
rasbt authored Apr 28, 2024
1 parent ca47c5e commit 70cd174
Showing 2 changed files with 64 additions and 17 deletions.
9 changes: 9 additions & 0 deletions ch06/03_bonus_imdb-classification/README.md
@@ -97,6 +97,15 @@ Test accuracy: 90.81%

---

A 355M-parameter encoder-style [RoBERTa](https://arxiv.org/abs/1907.11692) model, starting from the pretrained weights and training only the last transformer block plus the output layers:


```bash
python train-bert-hf.py --bert_model roberta
```
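
For intuition, training only the last transformer block plus the output layers comes down to the freezing pattern sketched below (a minimal sketch against the Hugging Face `transformers` API; the script handles this via `--trainable_layers`, and the freeze-all loop is an assumption, since the diff below only shows the unfreezing side):

```python
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "FacebookAI/roberta-large", num_labels=2
)

# Freeze every weight first ...
for param in model.parameters():
    param.requires_grad = False

# ... then unfreeze only the last encoder block and the classification head
for param in model.roberta.encoder.layer[-1].parameters():
    param.requires_grad = True
for param in model.classifier.parameters():
    param.requires_grad = True
```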

---

A scikit-learn Logistic Regression model as a baseline.

```bash
...
```
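
Such a baseline typically follows the standard scikit-learn recipe (a hypothetical sketch; the baseline script's actual name, flags, and column names are not shown in this diff):

```python
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Hypothetical file and column names, matching the CSVs used elsewhere in this folder
train_df = pd.read_csv("train.csv")
test_df = pd.read_csv("test.csv")

vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform(train_df["text"])
X_test = vectorizer.transform(test_df["text"])

clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, train_df["label"])
print(f"Test accuracy: {accuracy_score(test_df['label'], clf.predict(X_test)):.2%}")
```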
72 changes: 55 additions & 17 deletions ch06/03_bonus_imdb-classification/train-bert-hf.py
@@ -164,32 +164,71 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
"Which layers to train. Options: 'all', 'last_block', 'last_layer'."
)
)
parser.add_argument(
"--bert_model",
type=str,
default="distilbert",
help=(
"Which layers to train. Options: 'all', 'last_block', 'last_layer'."
)
)
args = parser.parse_args()

###############################
# Load model
###############################
model = AutoModelForSequenceClassification.from_pretrained(
"distilbert-base-uncased", num_labels=2
)

torch.manual_seed(123)
model.out_head = torch.nn.Linear(in_features=768, out_features=2)

if args.trainable_layers == "last_layer":
pass
elif args.trainable_layers == "last_block":
for param in model.pre_classifier.parameters():
param.requires_grad = True
for param in model.distilbert.transformer.layer[-1].parameters():
param.requires_grad = True
elif args.trainable_layers == "all":
for param in model.parameters():
param.requires_grad = True
if args.bert_model == "distilbert":

model = AutoModelForSequenceClassification.from_pretrained(
"distilbert-base-uncased", num_labels=2
)
model.out_head = torch.nn.Linear(in_features=768, out_features=2)

if args.trainable_layers == "last_layer":
pass
elif args.trainable_layers == "last_block":
for param in model.pre_classifier.parameters():
param.requires_grad = True
for param in model.distilbert.transformer.layer[-1].parameters():
param.requires_grad = True
elif args.trainable_layers == "all":
for param in model.parameters():
param.requires_grad = True
else:
raise ValueError("Invalid --trainable_layers argument.")

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

elif args.bert_model == "roberta":

model = AutoModelForSequenceClassification.from_pretrained(
"FacebookAI/roberta-large", num_labels=2
)
model.classifier.out_proj = torch.nn.Linear(in_features=1024, out_features=2)

if args.trainable_layers == "last_layer":
pass
elif args.trainable_layers == "last_block":
for param in model.classifier.parameters():
param.requires_grad = True
for param in model.roberta.encoder.layer[-1].parameters():
param.requires_grad = True
elif args.trainable_layers == "all":
for param in model.parameters():
param.requires_grad = True
else:
raise ValueError("Invalid --trainable_layers argument.")

tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-large")

else:
raise ValueError("Invalid --trainable_layers argument.")
raise ValueError("Selected --bert_model not supported.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

###############################
# Instantiate dataloaders
@@ -204,7 +243,6 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
file_names = ["train.csv", "val.csv", "test.csv"]
all_exist = all((base_path / file_name).exists() for file_name in file_names)

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
pad_token_id = tokenizer.pad_token_id  # integer pad id; tokenizer.encode(...) would return a list

train_dataset = IMDBDataset(base_path / "train.csv", max_length=256, tokenizer=tokenizer, pad_token_id=pad_token_id)
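
As a quick sanity check of what a given `--trainable_layers` setting leaves trainable, regardless of the chosen `--bert_model`, a parameter count like the following can help (a generic sketch, not part of the script):

```python
def count_trainable_params(model):
    # Count only weights that will receive gradient updates
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# e.g., right after the freezing logic:
# print(f"Trainable parameters: {count_trainable_params(model):,}")
```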
