-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathbus_collection_builder.py
239 lines (195 loc) · 9.37 KB
/
bus_collection_builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
from collections import defaultdict
from collections.abc import Sequence
from itertools import combinations
from math import radians
from operator import itemgetter
import networkx as nx
import numpy as np
from rapidfuzz.fuzz import token_ratio
from rapidfuzz.process import extract
from scipy.optimize import linear_sum_assignment
from sentry_sdk import trace
from sklearn.neighbors import BallTree
from config import BUS_COLLECTION_SEARCH_AREA
from cython_lib.geoutils import haversine_distance, radians_tuple
from models.fetch_relation import FetchRelationBusStop, FetchRelationBusStopCollection, PublicTransport
from utils import extract_numbers
@trace
def build_bus_stop_collections(bus_stops: Sequence[FetchRelationBusStop]) -> list[FetchRelationBusStopCollection]:
    """Pair nearby, same-named bus stop elements into platform/stop collections.

    The pipeline is:

    1. group elements by area (connected components of the
       "within search radius" graph),
    2. group every area by ``groupName``,
    3. discard unnamed non-platform elements when the area also has named ones,
    4. merge short-name groups into matching long-name groups,
    5. for each name group, pair platforms with stop positions.

    Args:
        bus_stops: Elements to group. The code reads ``latLng``, ``groupName``,
            ``public_transport``, ``highway``, ``id``, ``name`` and ``nice_id``
            from each element.

    Returns:
        The built collections; an empty list when ``bus_stops`` is empty.

    Raises:
        NotImplementedError: If an element's ``public_transport`` is neither
            ``PublicTransport.PLATFORM`` nor ``PublicTransport.STOP_POSITION``.
    """
    if not bus_stops:
        return []

    collections: list[FetchRelationBusStopCollection] = []

    for area_group in _group_by_area(bus_stops):
        name_groups = _group_by_name(area_group)
        _expand_short_name_groups(name_groups)

        for name_key, name_group in name_groups.items():
            collections.extend(_pair_name_group(name_key, name_group))

    return collections


def _group_by_area(bus_stops: Sequence[FetchRelationBusStop]) -> list[tuple[FetchRelationBusStop, ...]]:
    """Split elements into areas: connected components of the radius-neighbour graph."""
    # NOTE(review): BUS_COLLECTION_SEARCH_AREA is presumably in meters and
    # 111_111 the approximate meters per degree -- confirm against config.
    search_degrees = BUS_COLLECTION_SEARCH_AREA / 111_111
    search_radians = radians(search_degrees)

    coordinates = tuple(radians_tuple(bus_stop.latLng) for bus_stop in bus_stops)
    tree = BallTree(coordinates, metric='haversine')

    # distances are discarded, but sort_results=True is only accepted together
    # with return_distance=True
    neighbor_indices, _ = tree.query_radius(
        coordinates,
        r=search_radians,
        return_distance=True,
        sort_results=True,
    )

    graph = nx.Graph()

    # register every element first so isolated ones form their own component
    graph.add_nodes_from(range(len(bus_stops)))

    for neighbors in neighbor_indices:
        for i, j in combinations(neighbors, 2):
            graph.add_edge(i, j)

    return [
        tuple(bus_stops[member_index] for member_index in component)
        for component in nx.connected_components(graph)
    ]


def _group_by_name(area_group: Sequence[FetchRelationBusStop]) -> dict[str, list[FetchRelationBusStop]]:
    """Group an area by ``groupName``, dropping unnamed non-platforms if named groups exist."""
    name_groups: dict[str, list[FetchRelationBusStop]] = defaultdict(list)

    for bus_stop in area_group:
        name_groups[bus_stop.groupName].append(bus_stop)

    # discard unnamed if in area with named
    if len(name_groups) > 1 and (unnamed := name_groups.get('')):
        # never discard platforms
        unnamed_platforms = [s for s in unnamed if s.public_transport == PublicTransport.PLATFORM]

        if unnamed_platforms:
            name_groups[''] = unnamed_platforms
        else:
            name_groups.pop('')

    return name_groups


def _expand_short_name_groups(name_groups: dict[str, list[FetchRelationBusStop]]) -> None:
    """Merge, in place, short-name groups into similarly named long-name groups."""
    if len(name_groups) <= 1:
        return

    # fuzzy-match every group name against all group names of the area
    matches_by_key = {
        expand_key: extract(expand_key, name_groups.keys(), scorer=token_ratio, score_cutoff=89)
        for expand_key in name_groups
    }

    # highest total match score first; ties go to the shorter name
    expand_order = sorted(
        matches_by_key.items(),
        key=lambda t: (sum(map(itemgetter(1), t[1])), -len(t[0])),
        reverse=True,
    )

    for expand_key, target_data in expand_order:
        expand_key_n = extract_numbers(expand_key)
        expand_group = name_groups[expand_key]
        expand_group_public_transports = {bus_stop.public_transport for bus_stop in expand_group}
        expanded = False

        for target_key, name_score, _ in target_data:
            if target_key == expand_key:
                continue

            # only expand when the numbers of the short name are a subset of
            # the target's numbers (non-numeric -> numeric, or matching numbers)
            if not expand_key_n.issubset(extract_numbers(target_key)):
                continue

            target_group = name_groups.get(target_key)

            # skip if target_group was expanded/popped
            if not target_group:
                continue

            # expand only if target doesn't share any public_transport types
            target_group_public_transports = (bus_stop.public_transport for bus_stop in target_group)
            if not expand_group_public_transports.isdisjoint(target_group_public_transports):
                continue

            print(
                f'[COLL] [{name_score:5.1f}] Expanded {expand_key!r} to {target_key!r}, '
                f'ID={expand_group[0].nice_id!r}'
            )

            # no break: one short-name group may be merged into several targets
            target_group.extend(expand_group)
            expanded = True

        if expanded:
            name_groups.pop(expand_key)


def _pair_name_group(
    name_key: str,
    name_group: Sequence[FetchRelationBusStop],
) -> list[FetchRelationBusStopCollection]:
    """Pair the platforms and stop positions of one name group into collections.

    Priority: explicit platforms, then explicit stops, then implicit platforms
    with implicit stops, then whichever implicit kind is present on its own.
    """
    platforms: list[FetchRelationBusStop] = []
    stops: list[FetchRelationBusStop] = []

    for bus_stop in name_group:
        if bus_stop.public_transport == PublicTransport.PLATFORM:
            platforms.append(bus_stop)
        elif bus_stop.public_transport == PublicTransport.STOP_POSITION:
            stops.append(bus_stop)
        else:
            raise NotImplementedError(f'Unknown public transport type: {bus_stop.public_transport}')

    # for deterministic results
    platforms.sort(key=lambda p: p.id)
    stops.sort(key=lambda s: s.id)

    platforms_explicit, platforms_implicit = _pick_best(platforms)
    stops_explicit, stops_implicit = _pick_best(stops)

    if platforms_explicit and stops_explicit:
        collection_name = next(s.name for s in name_group if s.groupName == name_key)
        # \U0001f6a7 is the construction-sign emoji; the escape keeps this
        # source ASCII-safe so the character cannot be mis-decoded
        print(
            f'\U0001f6a7 Warning: Invalid explicit platforms and stops for {collection_name!r}, '
            f'ID={stops_explicit[0].nice_id!r}'
        )

    if platforms_explicit:
        # explicit platforms take precedence and may share a stop
        assigned_stops = _assign(platforms_explicit, stops, allow_element_reuse=True)
        return [
            FetchRelationBusStopCollection(platform=platform, stop=stop)
            for platform, stop in zip(platforms_explicit, assigned_stops)
        ]

    if stops_explicit:
        # explicit stops never share a platform
        assigned_platforms = _assign(stops_explicit, platforms, allow_element_reuse=False)
        return [
            FetchRelationBusStopCollection(platform=platform, stop=stop)
            for stop, platform in zip(stops_explicit, assigned_platforms)
        ]

    if platforms_implicit and stops_implicit:
        assigned_stops = _assign(platforms_implicit, stops, allow_element_reuse=True)
        return [
            FetchRelationBusStopCollection(platform=platform, stop=stop)
            for platform, stop in zip(platforms_implicit, assigned_stops)
        ]

    if platforms_implicit:  # and not stops_implicit
        return [FetchRelationBusStopCollection(platform=platform, stop=None) for platform in platforms_implicit]

    # only implicit stops remain (possibly none)
    return [FetchRelationBusStopCollection(platform=None, stop=stop) for stop in stops_implicit]
def _pick_best(
    elements: list[FetchRelationBusStop],
) -> tuple[tuple[FetchRelationBusStop, ...], tuple[FetchRelationBusStop, ...]]:
    """Partition elements into explicit and implicit ones, keeping input order.

    An element is explicit when its ``highway`` equals ``'bus_stop'``; every
    other element is implicit.

    Returns:
        ``(explicit, implicit)`` as two tuples; both are empty for empty input.
    """
    explicit: list[FetchRelationBusStop] = []
    implicit: list[FetchRelationBusStop] = []

    for element in elements:
        bucket = explicit if element.highway == 'bus_stop' else implicit
        bucket.append(element)

    return tuple(explicit), tuple(implicit)
@trace
def _assign(
    primary: Sequence[FetchRelationBusStop],
    elements: Sequence[FetchRelationBusStop],
    *,
    allow_element_reuse: bool,
) -> list[FetchRelationBusStop] | list[None]:
    """Choose one entry of ``elements`` (or ``None``) for every entry of ``primary``.

    The result is aligned with ``primary``:

    * no elements: all ``None``;
    * one element: it is given to every primary, unless reuse is disallowed
      and there is more than one primary (then all ``None``);
    * at least as many elements as primaries: a one-to-one assignment that
      minimizes the summed ``haversine_distance``;
    * fewer elements than primaries: each primary gets its nearest element
      when reuse is allowed, otherwise all ``None``.
    """
    primary_count = len(primary)
    element_count = len(elements)

    if element_count == 0:
        return [None] * primary_count

    if element_count == 1:
        # a single element can only serve several primaries when shared
        if primary_count > 1 and not allow_element_reuse:
            return [None] * primary_count

        return [elements[0]] * primary_count

    if element_count >= primary_count:
        # enough elements for a one-to-one match: minimize the total distance
        cost = np.zeros((primary_count, element_count))

        for row, primary_item in enumerate(primary):
            for col, element in enumerate(elements):
                cost[row, col] = haversine_distance(primary_item.latLng, element.latLng)

        # optimal assignment over the cost matrix
        rows, cols = linear_sum_assignment(cost)

        # order the pairs by primary index so the output lines up with `primary`
        ordered_pairs = sorted(zip(rows, cols, strict=False))

        return [elements[col] for _, col in ordered_pairs]

    # fewer elements than primaries: some element would have to be shared
    if not allow_element_reuse:
        return [None] * primary_count

    # pick the closest element for each primary
    element_tree = BallTree(tuple(radians_tuple(e.latLng) for e in elements), metric='haversine')
    nearest = element_tree.query(
        tuple(radians_tuple(p.latLng) for p in primary),
        k=1,
        return_distance=False,
        sort_results=False,
    )

    return [elements[index] for index in nearest[:, 0]]