-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtag_accidents.py
77 lines (53 loc) · 2.91 KB
/
tag_accidents.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from concurrent.futures import ProcessPoolExecutor, as_completed
from transformers import pipeline
import pandas as pd
import logging
import os
# ----------------------------------------------------------------------------------- #
#                                  Define constants                                   #
# ----------------------------------------------------------------------------------- #
# Filepaths
ACCIDENT_FILEPATH = os.path.join('data', 'output', 'accident_reports.csv')
TAGS_FILEPATH = os.path.join('data', 'input', 'tags.csv')
TAGGED_ACC_FILEPATH = os.path.join('data', 'output', 'tagged_accidents.csv')
# Other
# Minimum zero-shot score a tag needs in order to be attached to an accident.
LABEL_THRESHOLD = 0.75
# Level of the "transformers" logger *before* it is silenced below, so that exact level can be
# restored with setLevel(). (The root logger's level is not what gets changed, so capturing
# that one would not allow a correct restore.)
ORIG_LOG_CONFIG = logging.getLogger("transformers").level
# ----------------------------------------------------------------------------------- #
#                                   Configurations                                    #
# ----------------------------------------------------------------------------------- #
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
# ----------------------------------------------------------------------------------- #
#                                      Get Tags                                       #
# ----------------------------------------------------------------------------------- #
# Load accident descriptions (this script reads the 'acc_id', 'peakid' and 'accidents' columns).
acc_df = pd.read_csv(ACCIDENT_FILEPATH)
# Load candidate tags as a plain list of strings. Empty cells are dropped: read as NaN they
# would otherwise be scored as the literal label "nan".
tag_df = pd.read_csv(TAGS_FILEPATH)
candidate_tags = tag_df['tag'].dropna().astype(str).tolist()
if not candidate_tags:
    # Fail fast here instead of inside every worker: the zero-shot pipeline rejects an empty
    # label list.
    raise ValueError(f"No tags found in {TAGS_FILEPATH!r}")
# Instantiate classification pipeline.
# Silence transformers messages below ERROR; ORIG_LOG_CONFIG holds the previous level so it can
# be restored later.
logging.getLogger("transformers").setLevel(logging.ERROR)
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
def classify_sequence(row, *, pipe=None, candidate_labels=None, threshold=None):
    """Tag one accident description with every candidate tag scoring at least a threshold.

    Args:
        row: Record (e.g. a DataFrame row from ``iterrows``) with 'acc_id' and 'accidents' keys.
        pipe: Zero-shot classification callable invoked as
            ``pipe(text, labels, multi_label=True)``. Defaults to the module-level ``classifier``.
        candidate_labels: Tags to score. Defaults to the module-level ``candidate_tags``.
        threshold: Minimum score for a tag to be kept. Defaults to ``LABEL_THRESHOLD``.

    Returns:
        ``(acc_id, [(tag, score), ...])``. Tags appear in the order the classifier returns them
        (for the transformers zero-shot pipeline, highest score first). A description that is
        missing (None/NaN, i.e. not a string) or blank yields an empty tag list.
    """
    text = row['accidents']
    # The zero-shot pipeline raises on an empty sequence and on non-string input such as the NaN
    # produced by an empty CSV cell. There is nothing to classify, so report no tags instead of
    # failing the whole batch.
    if not isinstance(text, str) or not text.strip():
        return row['acc_id'], []

    # Resolve defaults at call time. The module-level objects are only needed when the caller
    # (e.g. executor.submit(classify_sequence, row)) does not supply them.
    if pipe is None:
        pipe = classifier
    if candidate_labels is None:
        candidate_labels = candidate_tags
    if threshold is None:
        threshold = LABEL_THRESHOLD

    result = pipe(text, candidate_labels, multi_label=True)
    filtered_tags = [(tag, score) for tag, score in zip(result['labels'], result['scores'])
                     if score >= threshold]
    return row['acc_id'], filtered_tags
def main(num_workers=6):
    """Tag every accident description peak by peak, appending results to TAGGED_ACC_FILEPATH.

    Each peak's results are written as soon as that peak finishes, so an interrupted run keeps
    the peaks already done. Re-running appends again; rows are not de-duplicated.

    Args:
        num_workers: Number of worker processes used for classification.
    """
    try:
        # One pool for the whole run. Under the "spawn" start method every new worker re-imports
        # this module (re-reading the CSVs and re-loading the model), so creating a pool per peak
        # would repeat that cost for every peak.
        with ProcessPoolExecutor(max_workers=num_workers) as executor:
            # groupby instead of DataFrame.query with an interpolated string:
            # - no expression building, so quotes or numeric peakids cannot break the lookup;
            # - dropna=False keeps rows whose peakid is missing, which the string query dropped;
            # - sort=False keeps peaks in order of first appearance, as unique() did.
            for peakid, df in acc_df.groupby('peakid', sort=False, dropna=False):
                print(f"Working on peakid = '{peakid}'")
                rows = (this_row for _, this_row in df.iterrows())
                results = list(executor.map(classify_sequence, rows))
                tagged_df = pd.DataFrame(results, columns=['acc_id', 'tags'])
                tagged_df = tagged_df.sort_values(by='acc_id', ascending=True)
                # Write the header only when the file is being created; always append.
                write_header = not os.path.isfile(TAGGED_ACC_FILEPATH)
                tagged_df.to_csv(TAGGED_ACC_FILEPATH, mode='a', header=write_header, index=False)
    finally:
        # Undo the module-level silencing of the "transformers" logger. logging.basicConfig()
        # configures the root logger and would not restore it.
        logging.getLogger("transformers").setLevel(ORIG_LOG_CONFIG)


if __name__ == "__main__":
    main()