Skip to content

Commit

Permalink
Silence some pytype errors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 446056025
  • Loading branch information
rchen152 authored and LIT team committed May 2, 2022
1 parent 00371b4 commit 085bb67
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion lit_nlp/components/ablation_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def generate(self,
continue

# Create counterfactual and obtain model prediction.
cf = self._create_cf(example, input_spec, ablation_idxs)
cf = self._create_cf(example, input_spec, ablation_idxs) # pytype: disable=wrong-arg-types # enable-nested-classes
cf_output = list(model.predict([cf]))[0]

# Check if counterfactual results in a prediction flip.
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/components/minimal_targeted_counterfactuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def generate(self,
original_pred = list(model.predict([example]))[0]

# Find dataset examples that are flips.
filtered_examples = self._filter_ds_examples(
filtered_examples = self._filter_ds_examples( # pytype: disable=wrong-arg-types # enable-nested-classes
dataset=dataset,
dataset_name=dataset_name,
model=model,
Expand Down
2 changes: 1 addition & 1 deletion lit_nlp/components/shap_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def prediction_fn(examples):
preds.append(max_value)
elif isinstance(pred_info, types.SparseMultilabelPreds):
pred_tuples = pred[pred_key_to_explain]
pred_list = list(map(lambda pred: pred[1], pred_tuples))
pred_list = list(map(lambda pred: pred[1], pred_tuples)) # pytype: disable=annotation-type-mismatch # enable-nested-classes
max_value: float = max(pred_list)
preds.append(max_value)
else:
Expand Down
4 changes: 2 additions & 2 deletions lit_nlp/lib/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _obj_to_json(o: object):
# .tolist() on a NumPy array.
return cast(np.number, o).tolist() # to regular Python scalar
elif isinstance(o, types.LitType):
return o.to_json()
return o.to_json() # pytype: disable=attribute-error # enable-nested-classes
elif isinstance(o, dtypes.DataTuple):
return o.to_json()
elif isinstance(o, tuple):
Expand All @@ -59,7 +59,7 @@ def _obj_to_json_simple(o: object):
# .tolist() on a NumPy array.
return cast(np.number, o).tolist() # to regular Python scalar
elif isinstance(o, types.LitType):
return o.to_json()
return o.to_json() # pytype: disable=attribute-error # enable-nested-classes
elif isinstance(o, dtypes.DataTuple):
return o.to_json()
elif isinstance(o, dtypes.EnumSerializableAsValues):
Expand Down

0 comments on commit 085bb67

Please sign in to comment.