-
Notifications
You must be signed in to change notification settings - Fork 372
/
Copy pathaccess.rs
183 lines (156 loc) · 6.96 KB
/
access.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
use crate::{
rest,
rest::{RequestFactory, RequestServiceHandle},
API,
};
use futures::{
channel::{mpsc, oneshot},
StreamExt,
};
use hyper::StatusCode;
use mullvad_types::account::{AccessToken, AccessTokenData, AccountToken};
use std::collections::HashMap;
use tokio::select;
/// Path prefix of the API's authentication endpoints; the access token endpoint is
/// `{AUTH_URL_PREFIX}/token`.
pub const AUTH_URL_PREFIX: &str = "auth/v1";
/// Handle to a background task that caches API access tokens per account.
///
/// All clones of a handle send their requests to the same background task.
#[derive(Debug, Clone)]
pub struct AccessTokenStore {
    /// Channel used to submit [`StoreAction`]s to the background task.
    tx: mpsc::UnboundedSender<StoreAction>,
}
/// Messages processed by the background task spawned in `AccessTokenStore::new`.
enum StoreAction {
    /// Request an access token for `AccountToken`, or return a saved one if it's not expired.
    /// The result is delivered through the `oneshot::Sender`.
    GetAccessToken(
        AccountToken,
        oneshot::Sender<Result<AccessToken, rest::Error>>,
    ),
    /// Forget the cached access token for `AccountToken`, and drop any in-flight requests.
    /// Callers waiting on a dropped request see their response channel closed.
    InvalidateToken(AccountToken),
}
/// Per-account bookkeeping owned by the background task in
/// `AccessTokenStore::service_requests`.
#[derive(Default)]
struct AccountState {
    /// Most recently fetched access token. It may have expired since it was stored.
    current_access_token: Option<AccessTokenData>,
    /// Fetch task currently running for this account, tagged with the request id that the
    /// task reports its result under. Results carrying any other id are stale and discarded.
    inflight_request: Option<(u64, tokio::task::JoinHandle<()>)>,
    /// Callers waiting for the in-flight fetch to finish.
    response_channels: Vec<oneshot::Sender<Result<AccessToken, rest::Error>>>,
}

impl AccessTokenStore {
    /// Create a store and spawn its background task with `tokio::spawn`.
    ///
    /// Access tokens are requested from `API.host()` through `service`.
    pub(crate) fn new(service: RequestServiceHandle) -> Self {
        let factory = rest::RequestFactory::new(API.host(), None);
        let (tx, rx) = mpsc::unbounded();
        tokio::spawn(Self::service_requests(rx, service, factory));
        Self { tx }
    }

    /// Background loop: serves [`StoreAction`]s received on `rx` and collects the results
    /// of the fetch tasks it spawns.
    ///
    /// Returns once `rx` is closed, i.e. when every sender (every `AccessTokenStore`
    /// handle) has been dropped.
    async fn service_requests(
        mut rx: mpsc::UnboundedReceiver<StoreAction>,
        service: RequestServiceHandle,
        factory: RequestFactory,
    ) {
        let mut account_states: HashMap<AccountToken, AccountState> = HashMap::new();
        let (completed_tx, mut completed_rx) = mpsc::unbounded();
        // Id assigned to the next fetch task. A task can push its result into
        // `completed_rx` just before `InvalidateToken` aborts it; tagging results with an id
        // lets us recognize such stale results, so they can neither resurrect an
        // invalidated token nor clobber the bookkeeping of a newer request.
        let mut next_request_id: u64 = 0;

        loop {
            select! {
                action = rx.next() => {
                    let Some(action) = action else {
                        // We're done
                        break;
                    };

                    match action {
                        StoreAction::GetAccessToken(account, response_tx) => {
                            let account_state = account_states
                                .entry(account.clone())
                                .or_default();

                            // If there is an unexpired access token, just return it.
                            // Otherwise, generate a new token
                            if let Some(ref access_token) = account_state.current_access_token {
                                if !access_token.is_expired() {
                                    log::trace!("Using stored access token");
                                    let _ = response_tx.send(Ok(access_token.access_token.clone()));
                                    continue;
                                }
                                log::debug!("Replacing expired access token");
                                account_state.current_access_token = None;
                            }

                            // Begin requesting an access token if it's not already underway.
                            // If there's already an inflight request, just save `response_tx`
                            if account_state.inflight_request.is_none() {
                                let request_id = next_request_id;
                                next_request_id += 1;

                                let completed_tx = completed_tx.clone();
                                let service = service.clone();
                                let factory = factory.clone();
                                log::debug!("Fetching access token for an account");
                                let task = tokio::spawn(async move {
                                    let result =
                                        fetch_access_token(service, factory, account.clone()).await;
                                    let _ = completed_tx.unbounded_send((account, request_id, result));
                                });
                                account_state.inflight_request = Some((request_id, task));
                            }

                            // Save the channel to respond to later
                            account_state.response_channels.push(response_tx);
                        }
                        StoreAction::InvalidateToken(account) => {
                            log::debug!("Invalidating access token for an account");

                            // Nothing is cached or in flight for an account we never saw,
                            // so don't create an empty entry for it.
                            let Some(account_state) = account_states.get_mut(&account) else {
                                continue;
                            };

                            // Drop in-flight requests for the account
                            // & forget any existing access token
                            if let Some((_, task)) = account_state.inflight_request.take() {
                                task.abort();
                                let _ = task.await;
                            }
                            // Dropping the senders makes pending `get_token` calls fail
                            account_state.response_channels.clear();
                            account_state.current_access_token = None;
                        }
                    }
                }

                Some((account, request_id, result)) = completed_rx.next() => {
                    let Some(account_state) = account_states.get_mut(&account) else {
                        continue;
                    };

                    // Only accept the result of the request we are currently waiting for.
                    // Anything else comes from a task that was aborted by `InvalidateToken`
                    // after it had already queued its result.
                    let current_id = account_state
                        .inflight_request
                        .as_ref()
                        .map(|(id, _)| *id);
                    if current_id != Some(request_id) {
                        log::debug!("Discarding access token from a cancelled request");
                        continue;
                    }
                    account_state.inflight_request = None;

                    // Send response to all channels
                    for tx in account_state.response_channels.drain(..) {
                        let _ = tx.send(result.clone().map(|data| data.access_token));
                    }

                    if let Ok(access_token) = result {
                        account_state.current_access_token = Some(access_token);
                    }
                }
            }
        }
    }

    /// Obtain an access token for `account`, requesting a new one from the API if no
    /// unexpired token is cached.
    ///
    /// # Errors
    /// Returns the error of the underlying API request. Returns [`rest::Error::Aborted`] if
    /// the pending request was dropped (for example by an invalidation) or if the background
    /// task is no longer running.
    pub async fn get_token(&self, account: &AccountToken) -> Result<AccessToken, rest::Error> {
        let (tx, rx) = oneshot::channel();
        // If sending fails, `tx` is dropped along with the message, so `rx` resolves to
        // `Canceled`, which is reported as `Aborted` below.
        let _ = self
            .tx
            .unbounded_send(StoreAction::GetAccessToken(account.to_owned(), tx));
        rx.await.map_err(|_| rest::Error::Aborted)?
    }

    /// Invalidate the cached access token for `account` if `response` is an API error whose
    /// code equals `crate::INVALID_ACCESS_TOKEN`. Any other response is ignored.
    pub fn check_response<T>(&self, account: &AccountToken, response: &Result<T, rest::Error>) {
        if let Err(rest::Error::ApiError(_status, code)) = response {
            if code == crate::INVALID_ACCESS_TOKEN {
                let _ = self
                    .tx
                    .unbounded_send(StoreAction::InvalidateToken(account.to_owned()));
            }
        }
    }
}
/// Request a new access token for `account_token` from the `{AUTH_URL_PREFIX}/token`
/// endpoint.
///
/// # Errors
/// Fails if the JSON request cannot be built, if the request fails or does not return
/// `200 OK`, or if the response body cannot be deserialized.
async fn fetch_access_token(
    service: RequestServiceHandle,
    factory: RequestFactory,
    account_token: AccountToken,
) -> Result<AccessTokenData, rest::Error> {
    /// JSON body sent to the token endpoint.
    #[derive(serde::Serialize)]
    struct AccessTokenRequest {
        account_number: String,
    }

    let endpoint = format!("{AUTH_URL_PREFIX}/token");
    let body = AccessTokenRequest {
        account_number: account_token,
    };
    let request = factory
        .post_json(&endpoint, &body)?
        .expected_status(&[StatusCode::OK]);

    let response = service.request(request).await?;
    response.deserialize().await
}