Adds HotFlip.is_compatible() implementation.
Isolates HotFlip unit tests.

Separates HotFlip integration tests.

PiperOrigin-RevId: 481735787
RyanMullins authored and LIT team committed Oct 17, 2022
1 parent 2443e4a commit 9b2de92
Showing 3 changed files with 422 additions and 309 deletions.
81 changes: 48 additions & 33 deletions lit_nlp/components/hotflip.py
@@ -35,7 +35,7 @@

import copy
import itertools
from typing import Iterator, List, Optional, Text, Tuple, Type
from typing import cast, Iterator, Optional, Type

from absl import logging
from lit_nlp.api import components as lit_components
@@ -80,9 +80,36 @@ class HotFlip(lit_components.Generator):
prediction flip.
"""

def find_fields(
self, spec: Spec, typ: Type[types.LitType],
align_field: Optional[Text] = None) -> List[Text]:
def is_compatible(self, model: lit_model.Model) -> bool:
"""Returns true if the given model is compatible with HotFlip."""
get_embedding_table = getattr(model, "get_embedding_table", None)
if not callable(get_embedding_table):
return False
try:
table = get_embedding_table()
if not isinstance(table, tuple): return False
vocab, embs_dims = table
if not isinstance(vocab, list): return False
if not isinstance(embs_dims, np.ndarray): return False
# TODO(lit-dev): Further validate the shape of the embeddings table?
except NotImplementedError:
return False

input_spec = model.input_spec()
output_spec = model.output_spec()

for grad_key in utils.find_spec_keys(output_spec, types.TokenGradients):
grad_field = cast(types.TokenGradients, output_spec.get(grad_key))
aligned_field: Optional[types.LitType] = input_spec.get(grad_field.align)
if isinstance(aligned_field, types.Tokens):
return True

return False

def find_fields(self,
spec: Spec,
typ: Type[types.LitType],
align_field: Optional[str] = None) -> list[str]:
# Find fields of provided 'typ'.
fields = utils.find_spec_keys(spec, typ)

@@ -94,11 +121,9 @@ def find_fields(
return [f for f in fields
if getattr(spec[f], "align", None) == align_field]

def _get_tokens_and_gradients(self,
input_spec: JsonDict,
output_spec: JsonDict,
output: JsonDict,
selected_fields: List[str]):
def _get_tokens_and_gradients(self, input_spec: JsonDict,
output_spec: JsonDict, output: JsonDict,
selected_fields: list[str]):
"""Returns a dictionary mapping token fields to tokens and gradients."""
# Find selected token fields.
input_spec_keys = set(utils.find_spec_keys(input_spec, types.Tokens))
@@ -152,11 +177,8 @@ def _subset_exists(self, cand_set, sets):
return False

def _gen_token_idxs_to_flip(
self,
tokens: List[str],
token_grads: np.ndarray,
max_flips: int,
tokens_to_ignore: List[str]) -> Iterator[Tuple[int, ...]]:
self, tokens: list[str], token_grads: np.ndarray, max_flips: int,
tokens_to_ignore: list[str]) -> Iterator[tuple[int, ...]]:
"""Generates sets of token positions that are eligible for flipping."""
# Consider all combinations of tokens up to length max_flips.
# We will iterate through this list (sorted by cardinality) and at each
@@ -180,22 +202,16 @@ def _gen_token_idxs_to_flip(
for s in itertools.combinations(token_idxs_to_flip, i+1):
yield s

def _flip_tokens(self,
tokens: List[str],
token_idxs: Tuple[int, ...],
replacement_tokens: List[str]) -> List[str]:
def _flip_tokens(self, tokens: list[str], token_idxs: tuple[int, ...],
replacement_tokens: list[str]) -> list[str]:
"""Perturbs tokens at the indices specified in 'token_idxs'."""
modified_tokens = [replacement_tokens[j] if j in token_idxs else t
for j, t in enumerate(tokens)]
return modified_tokens

def _create_cf(self,
example: JsonDict,
token_field: str,
text_field: str,
tokens: List[str],
token_idxs: Tuple[int, ...],
replacement_tokens: List[str]) -> JsonDict:
def _create_cf(self, example: JsonDict, token_field: str, text_field: str,
tokens: list[str], token_idxs: tuple[int, ...],
replacement_tokens: list[str]) -> JsonDict:
cf = copy.deepcopy(example)
modified_tokens = self._flip_tokens(
tokens, token_idxs, replacement_tokens)
@@ -207,13 +223,12 @@ def _create_cf(self,
cf[text_field] = " ".join(modified_tokens)
return cf

def _get_replacement_tokens(
self,
embedding_matrix: np.ndarray,
inv_vocab: List[Text],
token_grads: np.ndarray,
orig_output: JsonDict,
direction: int = -1) -> List[str]:
def _get_replacement_tokens(self,
embedding_matrix: np.ndarray,
inv_vocab: list[str],
token_grads: np.ndarray,
orig_output: JsonDict,
direction: int = -1) -> list[str]:
"""Identifies replacement tokens for each token position."""
token_grads = token_grads * direction
# Compute dot product of each input token gradient with the embedding
@@ -231,7 +246,7 @@ def generate(self,
example: JsonDict,
model: lit_model.Model,
dataset: lit_dataset.Dataset,
config: Optional[JsonDict] = None) -> List[JsonDict]:
config: Optional[JsonDict] = None) -> list[JsonDict]:
"""Identify minimal sets of token flips that alter the prediction."""
del dataset # Unused.

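For reference, here is a minimal sketch (not part of this commit) of a model stub that would pass the new HotFlip.is_compatible() check, assuming the LIT API as of this revision: the model exposes get_embedding_table() returning a (vocab list, embedding ndarray) pair, and declares a TokenGradients output field aligned to a Tokens field that also appears in the input spec. The class and field names (ToyHotFlipModel, grad_sentence, tokens_sentence) are illustrative only.

import numpy as np

from lit_nlp.api import model as lit_model
from lit_nlp.api import types
from lit_nlp.components import hotflip


class ToyHotFlipModel(lit_model.Model):
  """Stub exposing just enough API surface for HotFlip.is_compatible()."""

  def __init__(self, vocab=('[UNK]', 'good', 'bad'), emb_dim=4):
    self._vocab = list(vocab)
    # One embedding row per vocab entry.
    self._embs = np.zeros((len(self._vocab), emb_dim), dtype=np.float32)

  def get_embedding_table(self):
    # is_compatible() expects a (list of tokens, np.ndarray) tuple here.
    return self._vocab, self._embs

  def input_spec(self):
    return {
        'sentence': types.TextSegment(),
        'tokens_sentence': types.Tokens(parent='sentence', required=False),
    }

  def output_spec(self):
    return {
        'tokens_sentence': types.Tokens(parent='sentence'),
        # Gradients aligned to a Tokens field that is also in the input spec;
        # this alignment is what is_compatible() looks for.
        'grad_sentence': types.TokenGradients(align='tokens_sentence'),
        'probas': types.MulticlassPreds(vocab=['0', '1']),
    }

  def predict_minibatch(self, inputs):
    raise NotImplementedError('Prediction is not needed for this check.')


assert hotflip.HotFlip().is_compatible(ToyHotFlipModel())

A model without a callable get_embedding_table(), or whose TokenGradients fields are not aligned to a Tokens field in the input spec, would return False from the same check.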
189 changes: 189 additions & 0 deletions lit_nlp/components/hotflip_int_test.py
@@ -0,0 +1,189 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lit_nlp.components.hotflip."""

from absl.testing import absltest
from absl.testing import parameterized
from lit_nlp.components import hotflip
# TODO(lit-dev): Move glue_models out of lit_nlp/examples
from lit_nlp.examples.models import glue_models
import numpy as np


BERT_TINY_PATH = 'https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz' # pylint: disable=line-too-long
STSB_PATH = 'https://storage.googleapis.com/what-if-tool-resources/lit-models/stsb_tiny.tar.gz' # pylint: disable=line-too-long
import transformers
BERT_TINY_PATH = transformers.file_utils.cached_path(BERT_TINY_PATH,
extract_compressed_file=True)
STSB_PATH = transformers.file_utils.cached_path(STSB_PATH,
extract_compressed_file=True)


_CONFIG_CLASSIFICATION = {
hotflip.FIELDS_TO_HOTFLIP_KEY: ['tokens_sentence'],
hotflip.PREDICTION_KEY: 'probas',
}
_CONFIG_REGRESSION = {
hotflip.FIELDS_TO_HOTFLIP_KEY: ['tokens_sentence1', 'tokens_sentence2'],
hotflip.PREDICTION_KEY: 'score',
hotflip.REGRESSION_THRESH_KEY: 2,
}

_SST2_EXAMPLE = {'sentence': 'this long movie is terrible.'}
_STSB_EXAMPLE = {
'sentence1': 'this long movie is terrible.',
'sentence2': 'this short movie is great.'
}


class HotflipIntegrationTest(parameterized.TestCase):

def __init__(self, *args, **kwargs):
super(HotflipIntegrationTest, self).__init__(*args, **kwargs)
self.classification_model = glue_models.SST2Model(BERT_TINY_PATH)
self.regression_model = glue_models.STSBModel(STSB_PATH)

def setUp(self):
super(HotflipIntegrationTest, self).setUp()
self.hotflip = hotflip.HotFlip()

@parameterized.named_parameters(
('0_examples', 0),
('1_examples', 1),
('2_examples', 2),
)
def test_hotflip_num_ex(self, num_examples: int):
config = _CONFIG_CLASSIFICATION | {hotflip.NUM_EXAMPLES_KEY: num_examples}
counterfactuals = self.hotflip.generate(
_SST2_EXAMPLE, self.classification_model, None, config)
self.assertLen(counterfactuals, num_examples)

@parameterized.named_parameters(
('0_examples', 0),
('1_examples', 1),
('2_examples', 2),
)
def test_hotflip_num_ex_multi_input(self, num_examples: int):
config = _CONFIG_REGRESSION | {hotflip.NUM_EXAMPLES_KEY: num_examples}
counterfactuals = self.hotflip.generate(
_STSB_EXAMPLE, self.regression_model, None, config)
self.assertLen(counterfactuals, num_examples)

@parameterized.named_parameters(
('terrible', ['terrible'], [4]),
('long_terrible', ['long', 'terrible'], [1, 4]),
)
def test_hotflip_freeze_tokens(
self, ignore: list[str], exp_indexes: list[int]):
config = _CONFIG_CLASSIFICATION | {
hotflip.NUM_EXAMPLES_KEY: 10,
hotflip.TOKENS_TO_IGNORE_KEY: ignore,
}

counterfactuals = self.hotflip.generate(
_SST2_EXAMPLE, self.classification_model, None, config)
self.assertEqual(len(ignore), len(exp_indexes))
for target, index in zip(ignore, exp_indexes):
for counterfactual in counterfactuals:
tokens = counterfactual['tokens_sentence']
self.assertEqual(target, tokens[index])

def test_hotflip_freeze_tokens_multi_input(self):
config = _CONFIG_REGRESSION | {
hotflip.NUM_EXAMPLES_KEY: 10,
hotflip.TOKENS_TO_IGNORE_KEY: ['terrible', 'long'],
}

counterfactuals = self.hotflip.generate(
_STSB_EXAMPLE, self.regression_model, None, config)
for cf in counterfactuals:
tokens1 = cf['tokens_sentence1']
tokens2 = cf['tokens_sentence2']
self.assertEqual('terrible', tokens1[4])
self.assertEqual('long', tokens1[1])
self.assertEqual('long', tokens2[1])

def test_hotflip_max_flips(self):
config = _CONFIG_CLASSIFICATION
ex = _SST2_EXAMPLE

ex_output = list(self.classification_model.predict([ex]))[0]
ex_tokens = ex_output['tokens_sentence']
cfs = self.hotflip.generate(ex, self.classification_model, None, config)
cf_tokens = list(cfs)[0]['tokens_sentence']
self.assertEqual(1, sum([1 for i, t in enumerate(cf_tokens)
if t != ex_tokens[i]]))

ex = {'sentence': 'this long movie is terrible and horrible.'}
cfs = self.hotflip.generate(ex, self.classification_model, None, config)
self.assertEmpty(cfs)

def test_hotflip_max_flips_multi_input(self):
config = _CONFIG_REGRESSION | {
hotflip.MAX_FLIPS_KEY: 1,
hotflip.NUM_EXAMPLES_KEY: 20,
}
ex = _STSB_EXAMPLE
ex_output = list(self.regression_model.predict([ex]))[0]
ex_tokens1 = ex_output['tokens_sentence1']
ex_tokens2 = ex_output['tokens_sentence2']
cfs = self.hotflip.generate(ex, self.regression_model, None, config)
for cf in cfs:
# Number of flips in each field should be no more than MAX_FLIPS.
cf_tokens1 = cf['tokens_sentence1']
cf_tokens2 = cf['tokens_sentence2']
self.assertLessEqual(sum([1 for i, t in enumerate(cf_tokens1)
if t != ex_tokens1[i]]), 1)
self.assertLessEqual(sum([1 for i, t in enumerate(cf_tokens2)
if t != ex_tokens2[i]]), 1)

def test_hotflip_only_flip_one_field(self):
config = _CONFIG_REGRESSION | {hotflip.NUM_EXAMPLES_KEY: 10}
ex = _STSB_EXAMPLE
cfs = self.hotflip.generate(ex, self.regression_model, None, config)
for cf in cfs:
self.assertTrue(
(cf['sentence1'] == ex['sentence1']) or
(cf['sentence2'] == ex['sentence2']))

def test_hotflip_changes_pred_class(self):
config = _CONFIG_CLASSIFICATION
ex = _SST2_EXAMPLE

ex_output = list(self.classification_model.predict([ex]))[0]
pred_class = str(np.argmax(ex_output['probas']))
cfs = self.hotflip.generate(ex, self.classification_model, None, config)
cf_outputs = self.classification_model.predict(cfs)

self.assertEqual('0', pred_class)
for cf_output in cf_outputs:
self.assertNotEqual(np.argmax(ex_output['probas']),
np.argmax(cf_output['probas']))

def test_hotflip_changes_regression_score(self):
config = _CONFIG_REGRESSION | {hotflip.NUM_EXAMPLES_KEY: 2}
ex = _STSB_EXAMPLE

thresh = config[hotflip.REGRESSION_THRESH_KEY]
ex_output = list(self.regression_model.predict([ex]))[0]
cfs = self.hotflip.generate(ex, self.regression_model, None, config)
cf_outputs = self.regression_model.predict(cfs)
for cf_output in cf_outputs:
self.assertNotEqual((ex_output['score'] <= thresh),
(cf_output['score'] <= thresh))


if __name__ == '__main__':
absltest.main()
(Diff of the third changed file is not shown.)
