Skip to content

Commit

Permalink
cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
yaozhewei committed May 27, 2021
1 parent 1897bf6 commit 542b045
Show file tree
Hide file tree
Showing 13 changed files with 1,205 additions and 626 deletions.
16 changes: 8 additions & 8 deletions inference/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Our sparse kenel is from [Triton](https://github.com/ptillet/triton).
Our sparse kenel is from [Triton](https://github.com/ptillet/triton). Please note that our inference code is tested on AWS g4dn.xlarge instance.

## Usage

Expand All @@ -16,16 +16,16 @@ After that, you can run the inference similar to the following example,
```
export PYTHONUNBUFFERED=1
OUTPUT_PATH=/home/ubuntu/str_prune/result/qqp_structural_distillation/0.5/acc_and_f1best/
BLOCK_PATH=/home/ubuntu/str_prune/result/qqp_structural_blockwise_32_distillation/0.4
OUTPUT_PATH=result/qqp_partial/0.5/checkpoint-209000/
block_rows=32
block_cols=32
BLOCK_PATH=result/qqp_full/${block_rows}_${block_cols}/0.4/checkpoint-209000/
batch_size=32
max_seq_length=512
pruning_method=topK
block_size=32
export CUDA_VISIBLE_DEVICES=0; python masked_bert_parameter_count.py --model_type masked_bert \
export CUDA_VISIBLE_DEVICES=0; python masked_bert_inference.py --model_type masked_bert \
--model_name_or_path ${OUTPUT_PATH} --per_gpu_train_batch_size ${batch_size} \
--max_seq_length ${max_seq_length} --pruning_method ${pruning_method} \
--block_cols ${block_size} --block_rows ${block_size} \
--max_seq_length ${max_seq_length} --pruning_method topK \
--block_cols ${block_cols} --block_rows ${block_rows} \
--block_path ${BLOCK_PATH} --head_pruning
```
239 changes: 152 additions & 87 deletions inference/emmental/modeling_bert_masked.py

Large diffs are not rendered by default.

18 changes: 10 additions & 8 deletions inference/emmental/modules/binarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch import autograd
import math


class TopKBinarizer(autograd.Function):
"""
Top-k Binarizer.
Expand Down Expand Up @@ -35,7 +36,8 @@ def forward(ctx, inputs: torch.tensor, threshold: float, head_split: int):
Binary matrix of the same size as `inputs` acting as a mask (1 - the associated weight is
retained, 0 - the associated weight is pruned).
"""
# Get the subnetwork by sorting the inputs and using the top threshold %
# Get the subnetwork by sorting the inputs and using the top threshold
# %
threshold = torch.sigmoid(threshold).item()

mask = inputs.clone()
Expand All @@ -48,22 +50,22 @@ def forward(ctx, inputs: torch.tensor, threshold: float, head_split: int):
flat_out[idx[j:]] = 0.
flat_out[idx[:j]] = 1.
else:
inputs = inputs.reshape(head_split, -1) # make it as a 12 x 64 matrix! Then do the sorting!
_, idx = inputs.sort(-1, descending=True) # the default is column-wise
j = math.ceil(threshold * inputs.size(1))
# make it as a 12 x 64 matrix! Then do the sorting!
inputs = inputs.reshape(head_split, -1)
# the default is column-wise
_, idx = inputs.sort(-1, descending=True)
j = math.ceil(threshold * inputs.size(1))

#
#
flat_out = mask.reshape(head_split, -1)
for i in range(head_split):
flat_out[i, idx[i, j:]] = 0.
flat_out[i, idx[i, :j]] = 1.
ctx.save_for_backward(mask) # we should try two things
ctx.save_for_backward(mask) # we should try two things

return mask

@staticmethod
def backward(ctx, gradOutput):
mask, = ctx.saved_tensors
return gradOutput, ((gradOutput * mask).sum()).view(-1), None


Loading

0 comments on commit 542b045

Please sign in to comment.