Skip to content

Commit

Permalink
🧑‍💻 More Flake changes, Isort
Browse files (browse the repository at this point in the history)
Fixed a lot of bad code.
  • Loading branch information
ItsNiklas committed Aug 22, 2023
1 parent 6341263 commit 29c7d39
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 84 deletions.
13 changes: 9 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ repos:
rev: 23.7.0
hooks:
- id: black
name: black
# It is recommended to specify the latest version of Python
# supported by your project here, or alternatively use
# pre-commit's default_language_version, see
Expand All @@ -13,11 +14,15 @@ repos:
rev: 6.1.0 # Use the ref you want to pin
hooks:
- id: flake8
name: flake8
#additional_dependencies: [flake8-import-order]
args: ['--select=F,E9']
- repo: https://github.com/asottile/reorder-python-imports
rev: v3.10.0
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: reorder-python-imports
args: ['--py38-plus']
- id: isort
name: isort
args: ['--profile', 'black']



26 changes: 13 additions & 13 deletions base_bert.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import os
import re
from typing import Optional, Union

from torch import device
from torch import dtype
import torch
from torch import dtype, nn

from config import BertConfig
from config import PretrainedConfig
from utils import *
from config import BertConfig, PretrainedConfig
from utils import (
WEIGHTS_NAME,
cached_path,
get_parameter_dtype,
hf_bucket_url,
is_remote_url,
)


class BertPreTrainedModel(nn.Module):
Expand Down Expand Up @@ -102,8 +109,7 @@ def from_pretrained(
local_files_only=local_files_only,
use_auth_token=use_auth_token,
)
except EnvironmentError as err:
# logger.error(err)
except EnvironmentError:
msg = (
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
Expand Down Expand Up @@ -253,10 +259,4 @@ def load(module: nn.Module, prefix=""):
}
return model, loading_info

if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
import torch_xla.core.xla_model as xm

model = xm.send_cpu_data_to_device(model, xm.xla_device())
model.to(xm.xla_device())

return model
8 changes: 1 addition & 7 deletions bert.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
import math
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from base_bert import BertPreTrainedModel
from utils import *
from utils import Tensor, get_extended_attention_mask


class BertSelfAttention(nn.Module):
Expand Down
17 changes: 7 additions & 10 deletions classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,15 @@
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score, f1_score
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from AttentionLayer import AttentionLayer
from bert import BertModel
from optimizer import AdamW
from optimizer import SophiaG
from optimizer import AdamW, SophiaG
from tokenizer import BertTokenizer

# change it with respect to the original model
Expand Down Expand Up @@ -189,7 +186,7 @@ def model_eval(dataloader, model, device):
y_pred = []
sents = []
sent_ids = []
for step, batch in enumerate(tqdm(dataloader, desc=f"eval", disable=TQDM_DISABLE)):
for step, batch in enumerate(tqdm(dataloader, desc="eval", disable=TQDM_DISABLE)):
b_ids, b_mask, b_labels, b_sents, b_sent_ids = (
batch["token_ids"],
batch["attention_mask"],
Expand Down Expand Up @@ -222,7 +219,7 @@ def model_test_eval(dataloader, model, device):
y_pred = []
sents = []
sent_ids = []
for step, batch in enumerate(tqdm(dataloader, desc=f"eval", disable=TQDM_DISABLE)):
for step, batch in enumerate(tqdm(dataloader, desc="eval", disable=TQDM_DISABLE)):
b_ids, b_mask, b_sents, b_sent_ids = (
batch["token_ids"],
batch["attention_mask"],
Expand Down Expand Up @@ -424,12 +421,12 @@ def test(args):
print("DONE Test")
with open(args.dev_out, "w+") as f:
print(f"dev acc :: {dev_acc :.3f}")
f.write(f"id \t Predicted_Sentiment \n")
f.write("id \t Predicted_Sentiment \n")
for p, s in zip(dev_sent_ids, dev_pred):
f.write(f"{p} , {s} \n")

with open(args.test_out, "w+") as f:
f.write(f"id \t Predicted_Sentiment \n")
f.write("id \t Predicted_Sentiment \n")
for p, s in zip(test_sent_ids, test_pred):
f.write(f"{p} , {s} \n")

Expand Down
12 changes: 3 additions & 9 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
import json
import os
from typing import Any
from typing import Dict
from typing import Tuple
from typing import Union

from utils import cached_path
from utils import CONFIG_NAME
from utils import hf_bucket_url
from utils import is_remote_url
from typing import Any, Dict, Tuple, Union

from utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url


class PretrainedConfig(object):
Expand Down
15 changes: 8 additions & 7 deletions evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@
"""
import numpy as np
import torch
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import DataLoader
from tqdm import tqdm

from datasets import load_multitask_data
from datasets import SentenceClassificationDataset
from datasets import SentenceClassificationTestDataset
from datasets import SentencePairDataset
from datasets import SentencePairTestDataset
from datasets import (
SentenceClassificationDataset,
SentenceClassificationTestDataset,
SentencePairDataset,
SentencePairTestDataset,
load_multitask_data,
)

TQDM_DISABLE = False

Expand Down
14 changes: 7 additions & 7 deletions multitask_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@

from AttentionLayer import AttentionLayer
from bert import BertModel
from datasets import load_multitask_data
from datasets import SentenceClassificationDataset
from datasets import SentencePairDataset
from evaluation import model_eval_multitask
from evaluation import test_model_multitask
from optimizer import AdamW
from optimizer import SophiaH
from datasets import (
SentenceClassificationDataset,
SentencePairDataset,
load_multitask_data,
)
from evaluation import model_eval_multitask, test_model_multitask
from optimizer import AdamW, SophiaH

TQDM_DISABLE = False

Expand Down
4 changes: 1 addition & 3 deletions optimizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from typing import Callable
from typing import Iterable
from typing import Tuple
from typing import Callable, Iterable, Tuple

import torch
from torch.optim import Optimizer
Expand Down
29 changes: 13 additions & 16 deletions tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,27 @@
import os
import re
import unicodedata
from collections import OrderedDict
from collections import UserDict
from collections import OrderedDict, UserDict
from contextlib import contextmanager
from enum import Enum
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import (
Any,
Dict,
List,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
overload,
)

import numpy as np
import requests
from tokenizers import AddedToken
from tokenizers import Encoding as EncodingFast

from utils import cached_path
from utils import hf_bucket_url
from utils import is_remote_url
from utils import is_torch_available

from utils import cached_path, hf_bucket_url, is_remote_url, is_torch_available

VERY_LARGE_INTEGER = int(
1e30
Expand Down
10 changes: 2 additions & 8 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,9 @@
from hashlib import sha256
from io import UnsupportedOperation
from pathlib import Path
from typing import BinaryIO
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from typing import BinaryIO, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
from zipfile import is_zipfile
from zipfile import ZipFile
from zipfile import ZipFile, is_zipfile

import importlib_metadata
import requests
Expand Down

0 comments on commit 29c7d39

Please sign in to comment.