1use std::collections::HashSet;
4
5use hyper::{Method, StatusCode};
6
7#[derive(Clone)]
42pub struct RouteNonMatch {
43 status: StatusCode,
44 allow: MethodSet,
45}
46
47impl RouteNonMatch {
48 pub fn new(status: StatusCode) -> RouteNonMatch {
50 RouteNonMatch {
51 status,
52 allow: MethodSet::default(),
53 }
54 }
55
56 pub fn with_allow_list(self, allow: &[Method]) -> RouteNonMatch {
60 RouteNonMatch {
61 allow: allow.into(),
62 ..self
63 }
64 }
65
66 pub fn intersection(self, other: RouteNonMatch) -> RouteNonMatch {
74 let status = match (self.status, other.status) {
75 (StatusCode::METHOD_NOT_ALLOWED, _) | (_, StatusCode::METHOD_NOT_ALLOWED) => {
77 StatusCode::METHOD_NOT_ALLOWED
78 }
79 (StatusCode::NOT_FOUND, rhs) if rhs.is_client_error() => rhs,
81 (lhs, StatusCode::NOT_FOUND) if lhs.is_client_error() => lhs,
82 (StatusCode::NOT_ACCEPTABLE, rhs) if rhs.is_client_error() => rhs,
84 (lhs, StatusCode::NOT_ACCEPTABLE) if lhs.is_client_error() => lhs,
85 (lhs, _) if lhs.is_client_error() => lhs,
88 (_, rhs) if rhs.is_client_error() => rhs,
89 (lhs, _) => lhs,
90 };
91 let allow = self.allow.intersection(other.allow);
92 RouteNonMatch { status, allow }
93 }
94
95 pub fn union(self, other: RouteNonMatch) -> RouteNonMatch {
103 let status = match (self.status, other.status) {
104 (StatusCode::METHOD_NOT_ALLOWED, rhs) if rhs.is_client_error() => rhs,
107 (lhs, StatusCode::METHOD_NOT_ALLOWED) if lhs.is_client_error() => lhs,
108 (StatusCode::NOT_FOUND, rhs) if rhs.is_client_error() => rhs,
110 (lhs, StatusCode::NOT_FOUND) if lhs.is_client_error() => lhs,
111 (StatusCode::NOT_ACCEPTABLE, rhs) if rhs.is_client_error() => rhs,
113 (lhs, StatusCode::NOT_ACCEPTABLE) if lhs.is_client_error() => lhs,
114 (lhs, _) if lhs.is_client_error() => lhs,
117 (_, rhs) if rhs.is_client_error() => rhs,
118 (lhs, _) => lhs,
119 };
120 let allow = self.allow.union(other.allow);
121 RouteNonMatch { status, allow }
122 }
123
124 pub(super) fn deconstruct(self) -> (StatusCode, Vec<Method>) {
125 (self.status, self.allow.into())
126 }
127}
128
129impl From<RouteNonMatch> for StatusCode {
130 fn from(val: RouteNonMatch) -> StatusCode {
131 val.status
132 }
133}
134
135#[derive(Clone, Default)]
138struct MethodSet {
139 connect: bool,
140 delete: bool,
141 get: bool,
142 head: bool,
143 options: bool,
144 patch: bool,
145 post: bool,
146 put: bool,
147 trace: bool,
148 other: HashSet<Method>,
149}
150
151impl MethodSet {
152 fn is_empty(&self) -> bool {
153 !self.connect
154 && !self.delete
155 && !self.get
156 && !self.head
157 && !self.options
158 && !self.patch
159 && !self.post
160 && !self.put
161 && !self.trace
162 && self.other.is_empty()
163 }
164
165 fn intersection(self, other: MethodSet) -> MethodSet {
166 if self.is_empty() {
167 other
168 } else if other.is_empty() {
169 self
170 } else {
171 MethodSet {
172 connect: self.connect && other.connect,
173 delete: self.delete && other.delete,
174 get: self.get && other.get,
175 head: self.head && other.head,
176 options: self.options && other.options,
177 patch: self.patch && other.patch,
178 post: self.post && other.post,
179 put: self.put && other.put,
180 trace: self.trace && other.trace,
181 other: self.other.intersection(&other.other).cloned().collect(),
182 }
183 }
184 }
185
186 fn union(self, other: MethodSet) -> MethodSet {
187 MethodSet {
188 connect: self.connect || other.connect,
189 delete: self.delete || other.delete,
190 get: self.get || other.get,
191 head: self.head || other.head,
192 options: self.options || other.options,
193 patch: self.patch || other.patch,
194 post: self.post || other.post,
195 put: self.put || other.put,
196 trace: self.trace || other.trace,
197 other: self.other.union(&other.other).cloned().collect(),
198 }
199 }
200}
201
202impl<'a> From<&'a [Method]> for MethodSet {
203 fn from(methods: &[Method]) -> MethodSet {
204 let (
205 mut connect,
206 mut delete,
207 mut get,
208 mut head,
209 mut options,
210 mut patch,
211 mut post,
212 mut put,
213 mut trace,
214 ) = (
215 false, false, false, false, false, false, false, false, false,
216 );
217
218 let mut other = HashSet::new();
219
220 for method in methods {
221 match *method {
222 Method::CONNECT => {
223 connect = true;
224 }
225 Method::DELETE => {
226 delete = true;
227 }
228 Method::GET => {
229 get = true;
230 }
231 Method::HEAD => {
232 head = true;
233 }
234 Method::OPTIONS => {
235 options = true;
236 }
237 Method::PATCH => {
238 patch = true;
239 }
240 Method::POST => {
241 post = true;
242 }
243 Method::PUT => {
244 put = true;
245 }
246 Method::TRACE => {
247 trace = true;
248 }
249 _ => {
250 other.insert(method.clone());
251 }
252 }
253 }
254
255 MethodSet {
256 connect,
257 delete,
258 get,
259 head,
260 options,
261 patch,
262 post,
263 put,
264 trace,
265 other,
266 }
267 }
268}
269
270impl From<MethodSet> for Vec<Method> {
271 fn from(method_set: MethodSet) -> Vec<Method> {
272 let methods_with_flags: [(Method, bool); 9] = [
273 (Method::CONNECT, method_set.connect),
274 (Method::DELETE, method_set.delete),
275 (Method::GET, method_set.get),
276 (Method::HEAD, method_set.head),
277 (Method::OPTIONS, method_set.options),
278 (Method::PATCH, method_set.patch),
279 (Method::POST, method_set.post),
280 (Method::PUT, method_set.put),
281 (Method::TRACE, method_set.trace),
282 ];
283
284 let mut result = methods_with_flags
285 .iter()
286 .filter_map(|&(ref method, flag)| if flag { Some(method.clone()) } else { None })
287 .chain(method_set.other)
288 .collect::<Vec<Method>>();
289
290 result.sort_unstable_by(|a, b| a.as_ref().cmp(b.as_ref()));
291 result
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    use hyper::{Method, StatusCode};

    // Test-only convenience: apply an allow list only when one is given, so `None` can stand
    // for "no allow list supplied".
    trait AllowList {
        fn apply_allow_list(self, list: Option<&[Method]>) -> Self;
    }

    impl AllowList for RouteNonMatch {
        fn apply_allow_list(self, list: Option<&[Method]>) -> Self {
            match list {
                Some(list) => self.with_allow_list(list),
                None => self,
            }
        }
    }

    // Every standard method except CONNECT and TRACE, listed in name order so it can be
    // compared directly against the sorted output of `deconstruct`.
    const ALL: [Method; 7] = [
        Method::DELETE,
        Method::GET,
        Method::HEAD,
        Method::OPTIONS,
        Method::PATCH,
        Method::POST,
        Method::PUT,
    ];

    // Asserts that intersecting `code1` and `code2` yields `expected`, in both argument orders.
    fn intersection_assert_status_code(code1: StatusCode, code2: StatusCode, expected: StatusCode) {
        let (status, _) = RouteNonMatch::new(code1)
            .intersection(RouteNonMatch::new(code2))
            .deconstruct();
        assert_eq!(status, expected);
        let (status, _) = RouteNonMatch::new(code2)
            .intersection(RouteNonMatch::new(code1))
            .deconstruct();
        assert_eq!(status, expected);
    }

    #[test]
    fn intersection_test_status_code() {
        intersection_assert_status_code(
            StatusCode::METHOD_NOT_ALLOWED,
            StatusCode::NOT_FOUND,
            StatusCode::METHOD_NOT_ALLOWED,
        );
        // 405 wins an intersection against any other client error and against non-errors.
        intersection_assert_status_code(
            StatusCode::METHOD_NOT_ALLOWED,
            StatusCode::FORBIDDEN,
            StatusCode::METHOD_NOT_ALLOWED,
        );
        intersection_assert_status_code(
            StatusCode::METHOD_NOT_ALLOWED,
            StatusCode::OK,
            StatusCode::METHOD_NOT_ALLOWED,
        );
        intersection_assert_status_code(
            StatusCode::NOT_ACCEPTABLE,
            StatusCode::NOT_FOUND,
            StatusCode::NOT_ACCEPTABLE,
        );
        intersection_assert_status_code(
            StatusCode::NOT_ACCEPTABLE,
            StatusCode::FORBIDDEN,
            StatusCode::FORBIDDEN,
        );
        intersection_assert_status_code(
            StatusCode::OK,
            StatusCode::NOT_FOUND,
            StatusCode::NOT_FOUND,
        );
        intersection_assert_status_code(
            StatusCode::OK,
            StatusCode::NOT_ACCEPTABLE,
            StatusCode::NOT_ACCEPTABLE,
        );

        // With no client error on either side, the left-hand status is kept.
        let (status, _) = RouteNonMatch::new(StatusCode::OK)
            .intersection(RouteNonMatch::new(StatusCode::NO_CONTENT))
            .deconstruct();
        assert_eq!(status, StatusCode::OK);
        let (status, _) = RouteNonMatch::new(StatusCode::NO_CONTENT)
            .intersection(RouteNonMatch::new(StatusCode::OK))
            .deconstruct();
        assert_eq!(status, StatusCode::NO_CONTENT);
    }

    // Asserts that intersecting the two allow lists yields `expected`, in both argument orders.
    fn intersection_assert_allow_list(
        list1: Option<&[Method]>,
        list2: Option<&[Method]>,
        expected: &[Method],
    ) {
        let status = StatusCode::BAD_REQUEST;
        let (_, allow_list) = RouteNonMatch::new(status)
            .apply_allow_list(list1)
            .intersection(RouteNonMatch::new(status).apply_allow_list(list2))
            .deconstruct();
        assert_eq!(&allow_list, &expected);
        let (_, allow_list) = RouteNonMatch::new(status)
            .apply_allow_list(list2)
            .intersection(RouteNonMatch::new(status).apply_allow_list(list1))
            .deconstruct();
        assert_eq!(&allow_list, &expected);
    }

    #[test]
    fn intersection_test_allow_list() {
        intersection_assert_allow_list(None, None, &[]);
        intersection_assert_allow_list(Some(&ALL), None, &ALL);
        intersection_assert_allow_list(Some(&ALL), Some(&[Method::GET]), &[Method::GET]);
        intersection_assert_allow_list(None, Some(&[Method::GET]), &[Method::GET]);
        intersection_assert_allow_list(
            Some(&[Method::GET, Method::POST]),
            Some(&[Method::POST, Method::PUT]),
            &[Method::POST],
        );
        // Disjoint, non-empty lists intersect to an empty list.
        intersection_assert_allow_list(Some(&[Method::GET]), Some(&[Method::POST]), &[]);
        // Extension methods take part in the intersection too.
        let propfind = Method::from_bytes(b"PROPFIND").unwrap();
        intersection_assert_allow_list(
            Some(&[Method::GET, propfind.clone()]),
            Some(&[Method::POST, propfind.clone()]),
            &[propfind],
        );
    }

    // Asserts that the union of `code1` and `code2` yields `expected`, in both argument orders.
    fn union_assert_status_code(code1: StatusCode, code2: StatusCode, expected: StatusCode) {
        let (status, _) = RouteNonMatch::new(code1)
            .union(RouteNonMatch::new(code2))
            .deconstruct();
        assert_eq!(status, expected);
        let (status, _) = RouteNonMatch::new(code2)
            .union(RouteNonMatch::new(code1))
            .deconstruct();
        assert_eq!(status, expected);
    }

    #[test]
    fn union_test_status_code() {
        union_assert_status_code(
            StatusCode::METHOD_NOT_ALLOWED,
            StatusCode::NOT_FOUND,
            StatusCode::NOT_FOUND,
        );
        // In a union, 405 gives way to any other client error...
        union_assert_status_code(
            StatusCode::METHOD_NOT_ALLOWED,
            StatusCode::FORBIDDEN,
            StatusCode::FORBIDDEN,
        );
        union_assert_status_code(
            StatusCode::METHOD_NOT_ALLOWED,
            StatusCode::NOT_ACCEPTABLE,
            StatusCode::NOT_ACCEPTABLE,
        );
        // ...but is kept when the other side is not a client error.
        union_assert_status_code(
            StatusCode::METHOD_NOT_ALLOWED,
            StatusCode::OK,
            StatusCode::METHOD_NOT_ALLOWED,
        );
        union_assert_status_code(
            StatusCode::NOT_ACCEPTABLE,
            StatusCode::NOT_FOUND,
            StatusCode::NOT_ACCEPTABLE,
        );
        union_assert_status_code(
            StatusCode::NOT_ACCEPTABLE,
            StatusCode::FORBIDDEN,
            StatusCode::FORBIDDEN,
        );
        union_assert_status_code(StatusCode::OK, StatusCode::NOT_FOUND, StatusCode::NOT_FOUND);
        union_assert_status_code(
            StatusCode::OK,
            StatusCode::NOT_ACCEPTABLE,
            StatusCode::NOT_ACCEPTABLE,
        );

        // With no client error on either side, the left-hand status is kept.
        let (status, _) = RouteNonMatch::new(StatusCode::OK)
            .union(RouteNonMatch::new(StatusCode::NO_CONTENT))
            .deconstruct();
        assert_eq!(status, StatusCode::OK);
        let (status, _) = RouteNonMatch::new(StatusCode::NO_CONTENT)
            .union(RouteNonMatch::new(StatusCode::OK))
            .deconstruct();
        assert_eq!(status, StatusCode::NO_CONTENT);
    }

    // Asserts that the union of the two allow lists yields `expected`, in both argument orders.
    fn union_assert_allow_list(
        list1: Option<&[Method]>,
        list2: Option<&[Method]>,
        expected: &[Method],
    ) {
        let status = StatusCode::BAD_REQUEST;
        let (_, allow_list) = RouteNonMatch::new(status)
            .apply_allow_list(list1)
            .union(RouteNonMatch::new(status).apply_allow_list(list2))
            .deconstruct();
        assert_eq!(&allow_list, &expected);
        let (_, allow_list) = RouteNonMatch::new(status)
            .apply_allow_list(list2)
            .union(RouteNonMatch::new(status).apply_allow_list(list1))
            .deconstruct();
        assert_eq!(&allow_list, &expected);
    }

    #[test]
    fn union_test_allow_list() {
        union_assert_allow_list(None, None, &[]);
        union_assert_allow_list(Some(&ALL), None, &ALL);
        union_assert_allow_list(Some(&ALL), Some(&[Method::GET]), &ALL);
        union_assert_allow_list(None, Some(&[Method::GET]), &[Method::GET]);
        union_assert_allow_list(
            Some(&[Method::GET, Method::POST]),
            Some(&[Method::POST, Method::PUT]),
            &[Method::GET, Method::POST, Method::PUT],
        );
        // Extension methods are merged and sorted alongside the standard ones.
        let propfind = Method::from_bytes(b"PROPFIND").unwrap();
        union_assert_allow_list(
            Some(&[propfind.clone()]),
            Some(&[Method::GET]),
            &[Method::GET, propfind],
        );
    }

    #[test]
    fn deconstruct_tests() {
        let (_, allow_list) = RouteNonMatch::new(StatusCode::NOT_FOUND)
            .with_allow_list(&[
                Method::DELETE,
                Method::GET,
                Method::HEAD,
                Method::OPTIONS,
                Method::PATCH,
                Method::POST,
                Method::PUT,
                Method::CONNECT,
                Method::TRACE,
                Method::from_bytes(b"PROPFIND").unwrap(),
                Method::from_bytes(b"PROPSET").unwrap(),
            ])
            .deconstruct();

        // Standard and extension methods come back sorted purely by name.
        assert_eq!(
            &allow_list[..],
            &[
                Method::CONNECT,
                Method::DELETE,
                Method::GET,
                Method::HEAD,
                Method::OPTIONS,
                Method::PATCH,
                Method::POST,
                Method::from_bytes(b"PROPFIND").unwrap(),
                Method::from_bytes(b"PROPSET").unwrap(),
                Method::PUT,
                Method::TRACE,
            ]
        );
    }
}