1use hyper::{Body, StatusCode};
4use log::trace;
5
6use crate::helpers::http::PercentDecoded;
7use crate::router::non_match::RouteNonMatch;
8use crate::router::route::{Delegation, Route};
9use crate::router::tree::segment::{SegmentMapping, SegmentType};
10use crate::state::{request_id, State};
11
12use std::cmp::Ordering;
13use std::collections::HashMap;
14
15pub struct Node {
21 segment: String,
22 segment_type: SegmentType,
23 routes: Vec<Box<dyn Route<ResBody = Body> + Send + Sync>>,
24 children: Vec<Node>,
25}
26
27impl Node {
28 pub fn new(segment: &str, segment_type: SegmentType) -> Self {
30 Node {
31 segment_type,
32 segment: segment.to_string(),
33 routes: vec![],
34 children: vec![],
35 }
36 }
37
38 pub fn add_child(&mut self, node: Node) -> &mut Self {
40 self.children.push(node);
41 self.children.sort();
42 self
43 }
44
45 pub fn add_route(&mut self, route: Box<dyn Route<ResBody = Body> + Send + Sync>) -> &mut Self {
47 self.routes.push(route);
48 self
49 }
50
51 pub fn borrow_child(&self, segment: &str, segment_type: SegmentType) -> Option<&Node> {
53 self.children
54 .iter()
55 .find(|n| n.segment_type == segment_type && n.segment == segment)
56 }
57
58 pub fn borrow_child_mut(
60 &mut self,
61 segment: &str,
62 segment_type: SegmentType,
63 ) -> Option<&mut Node> {
64 self.children
65 .iter_mut()
66 .find(|n| n.segment_type == segment_type && n.segment == segment)
67 }
68
69 pub fn has_child(&self, segment: &str, segment_type: SegmentType) -> bool {
71 self.borrow_child(segment, segment_type).is_some()
72 }
73
74 pub fn is_routable(&self) -> bool {
76 !self.routes.is_empty()
77 }
78
79 pub fn match_node<'a>(
97 &'a self,
98 segments: &'a [PercentDecoded],
99 ) -> Option<(&'a Node, SegmentMapping<'a>, usize)> {
100 let mut params = HashMap::new();
102 let mut processed = 0;
103
104 self.inner_match_node(segments, &mut params, &mut processed)
106 .map(|node| (node, params, processed))
107 }
108
109 pub fn segment<'a>(&'a self) -> &'a str {
113 &self.segment
114 }
115
116 pub fn select_route(
128 &self,
129 state: &State,
130 ) -> Result<&Box<dyn Route<ResBody = Body> + Send + Sync>, RouteNonMatch> {
131 let mut err = Ok(());
132
133 for r in self.routes.iter() {
135 match r.is_match(state) {
136 Ok(()) => {
137 trace!("[{}] found matching route", request_id(state));
138 return Ok(r);
139 }
140 Err(e) => {
141 err = match err {
143 Err(e0) => Err(e.union(e0)),
144 Ok(()) => Err(e),
145 }
146 }
147 }
148 }
149
150 if let Err(e) = err {
152 trace!(
153 "[{}] no matching route, using error status code from route",
154 request_id(state)
155 );
156 return Err(e);
157 }
158
159 trace!(
160 "[{}] invalid state, no routes. sending internal server error",
161 request_id(state)
162 );
163
164 Err(RouteNonMatch::new(StatusCode::INTERNAL_SERVER_ERROR))
166 }
167
168 fn inner_match_node<'a>(
174 &'a self,
175 segments: &'a [PercentDecoded],
176 params: &mut SegmentMapping<'a>,
177 processed: &mut usize,
178 ) -> Option<&'a Node> {
179 let next_segment = segments.split_first();
180
181 if next_segment.is_none() {
183 if !self.is_routable() {
184 return None;
185 }
186 return Some(self);
187 }
188
189 if let Some(route) = self.routes.first() {
191 if route.delegation() == Delegation::External {
192 return Some(self);
193 }
194 }
195
196 let (segment, remaining) = next_segment.unwrap();
197
198 *processed += 1;
199
200 for child in &self.children {
202 match child.segment_type {
203 SegmentType::Glob => {
206 params.entry(&child.segment).or_default().push(segment);
207 }
208
209 SegmentType::Static => {
213 if child.segment != segment.as_ref() {
215 continue;
216 }
217 }
218
219 SegmentType::Constrained { ref regex } => {
223 if !regex.is_match(segment.as_ref()) {
225 continue;
226 }
227 params.insert(&child.segment, vec![segment]);
229 }
230
231 SegmentType::Dynamic => {
235 params.insert(&child.segment, vec![segment]);
237 }
238 };
239
240 return child.inner_match_node(remaining, params, processed);
244 }
245
246 if let SegmentType::Glob = self.segment_type {
250 if let Some(path) = params.get_mut(self.segment()) {
252 path.push(segment);
253 }
254 return self.inner_match_node(remaining, params, processed);
256 }
257
258 None
259 }
260}
261
262impl Eq for Node {}
263impl PartialEq for Node {
264 fn eq(&self, other: &Node) -> bool {
266 self.segment == other.segment && self.segment_type == other.segment_type
267 }
268}
269
270impl Ord for Node {
271 fn cmp(&self, other: &Node) -> Ordering {
273 (&self.segment_type, &self.segment).cmp(&(&other.segment_type, &other.segment))
274 }
275}
276
277impl PartialOrd for Node {
278 fn partial_cmp(&self, other: &Node) -> Option<Ordering> {
280 Some(self.cmp(other))
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    use std::panic::RefUnwindSafe;

    use hyper::{HeaderMap, Method, Response};

    use crate::extractor::{NoopPathExtractor, NoopQueryStringExtractor};
    use crate::helpers::http::request::path::RequestPathSegments;
    use crate::helpers::http::PercentDecoded;
    use crate::pipeline::{finalize_pipeline_set, new_pipeline_set, PipelineSet};
    use crate::router::route::dispatch::DispatcherImpl;
    use crate::router::route::matcher::MethodOnlyRouteMatcher;
    use crate::router::route::{Delegation, Extractors, Route, RouteImpl};
    use crate::router::tree::regex::ConstrainedSegmentRegex;
    use crate::state::{set_request_id, State};

    /// Handler shared by every test route; it returns an empty response.
    fn handler(state: State) -> (State, Response<Body>) {
        (state, Response::new(Body::empty()))
    }

    /// Builds an internally-delegated route whose matcher is a `MethodOnlyRouteMatcher`
    /// over `methods`.
    fn route_for<P>(
        methods: Vec<Method>,
        pipeline_set: PipelineSet<P>,
    ) -> Box<dyn Route<ResBody = Body> + Send + Sync>
    where
        P: Send + Sync + RefUnwindSafe + 'static,
    {
        let matcher = MethodOnlyRouteMatcher::new(methods);
        let dispatcher = DispatcherImpl::new(|| Ok(handler), (), pipeline_set);
        let extractors: Extractors<NoopPathExtractor, NoopQueryStringExtractor> = Extractors::new();
        Box::new(RouteImpl::new(
            matcher,
            Box::new(dispatcher),
            extractors,
            Delegation::Internal,
        ))
    }

    /// Builds an internally-delegated `GET`-only route.
    fn get_route<P>(pipeline_set: PipelineSet<P>) -> Box<dyn Route<ResBody = Body> + Send + Sync>
    where
        P: Send + Sync + RefUnwindSafe + 'static,
    {
        route_for(vec![Method::GET], pipeline_set)
    }

    /// Returns, as owned strings, the request segments captured under `name` in `params`.
    fn captured(params: &SegmentMapping<'_>, name: &str) -> Vec<String> {
        params
            .get(name)
            .unwrap_or_else(|| panic!("no values captured for {}", name))
            .iter()
            .map(|segment| {
                let value: &str = segment.as_ref();
                value.to_owned()
            })
            .collect()
    }

    // Builds the tree shared by most tests (methods listed per routable node):
    //
    //   /
    //   +-- seg1                       GET, HEAD
    //   +-- seg2                       POST route, then PATCH route
    //   +-- seg3
    //   |   +-- seg4                   GET
    //   +-- resource
    //   |   +-- id (constrained [0-9]+) GET
    //   +-- seg5
    //   |   +-- seg6                   GET
    //   |   +-- :segdyn1 (dynamic)
    //   |       +-- seg7               GET
    //   +-- seg8 (glob)
    //       +-- seg9
    //           +-- seg10 (glob)       GET
    fn test_structure() -> Node {
        let pipeline_set = finalize_pipeline_set(new_pipeline_set());
        let mut root = Node::new("/", SegmentType::Static);

        let mut seg1 = Node::new("seg1", SegmentType::Static);
        seg1.add_route(route_for(
            vec![Method::GET, Method::HEAD],
            pipeline_set.clone(),
        ));
        root.add_child(seg1);

        // Two separate routes on one node, so a non-match must union both allow lists.
        let mut seg2 = Node::new("seg2", SegmentType::Static);
        seg2.add_route(route_for(vec![Method::POST], pipeline_set.clone()));
        seg2.add_route(route_for(vec![Method::PATCH], pipeline_set.clone()));
        root.add_child(seg2);

        let mut seg4 = Node::new("seg4", SegmentType::Static);
        seg4.add_route(get_route(pipeline_set.clone()));
        let mut seg3 = Node::new("seg3", SegmentType::Static);
        seg3.add_child(seg4);
        root.add_child(seg3);

        let mut seg_id = Node::new(
            "id",
            SegmentType::Constrained {
                regex: Box::new(ConstrainedSegmentRegex::new("[0-9]+")),
            },
        );
        seg_id.add_route(get_route(pipeline_set.clone()));
        let mut seg_resource = Node::new("resource", SegmentType::Static);
        seg_resource.add_child(seg_id);
        root.add_child(seg_resource);

        let mut seg6 = Node::new("seg6", SegmentType::Static);
        seg6.add_route(get_route(pipeline_set.clone()));
        let mut seg7 = Node::new("seg7", SegmentType::Static);
        seg7.add_route(get_route(pipeline_set.clone()));
        let mut segdyn1 = Node::new(":segdyn1", SegmentType::Dynamic);
        segdyn1.add_child(seg7);
        let mut seg5 = Node::new("seg5", SegmentType::Static);
        seg5.add_child(seg6);
        seg5.add_child(segdyn1);
        root.add_child(seg5);

        let mut seg10 = Node::new("seg10", SegmentType::Glob);
        seg10.add_route(get_route(pipeline_set));
        let mut seg9 = Node::new("seg9", SegmentType::Static);
        seg9.add_child(seg10);
        let mut seg8 = Node::new("seg8", SegmentType::Glob);
        seg8.add_child(seg9);
        root.add_child(seg8);

        root
    }

    #[test]
    fn manages_children() {
        let root = test_structure();

        assert!(root.borrow_child("seg1", SegmentType::Static).is_some());
        assert!(root.borrow_child("seg2", SegmentType::Static).is_some());
        assert!(root.borrow_child("seg1", SegmentType::Dynamic).is_none());
        assert!(root.borrow_child("seg0", SegmentType::Static).is_none());

        assert!(root.has_child("seg5", SegmentType::Static));
        assert!(root.has_child("seg8", SegmentType::Glob));
        assert!(!root.has_child("seg8", SegmentType::Static));
    }

    #[test]
    fn traverses_children() {
        let root = test_structure();

        let rs = RequestPathSegments::new("/seg3/seg4");
        match root.match_node(rs.segments()) {
            Some((node, _params, processed)) => {
                assert_eq!(node.segment, "seg4");
                assert_eq!(processed, 2);
            }
            None => panic!("traversal should have succeeded here"),
        }

        let rs = RequestPathSegments::new("/seg3/seg4/seg5");
        assert!(root.match_node(rs.segments()).is_none());

        // A purely static path captures nothing.
        let rs = RequestPathSegments::new("/seg5/seg6");
        match root.match_node(rs.segments()) {
            Some((node, params, processed)) => {
                assert_eq!(node.segment, "seg6");
                assert_eq!(processed, 2);
                assert!(params.is_empty());
            }
            None => panic!("traversal should have succeeded here"),
        }

        // A dynamic segment captures exactly the segment it matched.
        let rs = RequestPathSegments::new("/seg5/someval/seg7");
        match root.match_node(rs.segments()) {
            Some((node, params, processed)) => {
                assert_eq!(node.segment, "seg7");
                assert_eq!(processed, 3);
                assert_eq!(captured(&params, ":segdyn1"), vec!["someval"]);
            }
            None => panic!("traversal should have succeeded here"),
        }

        // Globs absorb segments until a child accepts the next one.
        let rs = RequestPathSegments::new("/some/path/seg9/another/branch");
        match root.match_node(rs.segments()) {
            Some((node, params, processed)) => {
                assert_eq!(node.segment, "seg10");
                assert_eq!(processed, 5);
                assert_eq!(captured(&params, "seg8"), vec!["some", "path"]);
                assert_eq!(captured(&params, "seg10"), vec!["another", "branch"]);
            }
            None => panic!("traversal should have succeeded here"),
        }

        // A constrained segment captures the value that satisfied its regex.
        let rs = RequestPathSegments::new("/resource/5001");
        let expected_segment = "id";
        match root.match_node(rs.segments()) {
            Some((node, params, processed)) => {
                assert_eq!(node.segment, expected_segment);
                assert_eq!(processed, 2);
                assert_eq!(captured(&params, "id"), vec!["5001"]);
            }
            None => panic!("traversal should have succeeded here"),
        }
    }

    #[test]
    fn non_matching_routes_allow_list_tests() {
        let root = test_structure();

        let mut state = State::new();
        state.put(Method::OPTIONS);
        state.put(HeaderMap::new());
        set_request_id(&mut state);

        let rs = RequestPathSegments::new("/seg2");
        match root.match_node(rs.segments()) {
            Some((node, _params, _processed)) => match node.select_route(&state) {
                Err(e) => {
                    let (status, mut allow_list) = e.deconstruct();
                    assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
                    allow_list.sort_by(|a, b| a.as_ref().cmp(b.as_ref()));
                    assert_eq!(allow_list, vec![Method::PATCH, Method::POST]);
                }
                Ok(_) => panic!("expected mismatched route to test allow header"),
            },
            None => panic!("traversal should have succeeded here"),
        }

        let rs = RequestPathSegments::new("/resource/100");
        match root.match_node(rs.segments()) {
            Some((node, _params, _processed)) => match node.select_route(&state) {
                Err(e) => {
                    let (status, mut allow_list) = e.deconstruct();
                    assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
                    allow_list.sort_by(|a, b| a.as_ref().cmp(b.as_ref()));
                    assert_eq!(allow_list, vec![Method::GET]);
                }
                Ok(_) => panic!("expected mismatched route to test allow header"),
            },
            None => panic!("traversal should have succeeded here"),
        }
    }

    #[test]
    fn node_traversal_tests() {
        let pipeline_set = finalize_pipeline_set(new_pipeline_set());

        let mut workflow_node = Node::new("workflow", SegmentType::Static);
        workflow_node.add_route(get_route(pipeline_set));

        let mut activate_node = Node::new("activate", SegmentType::Static);
        activate_node.add_child(workflow_node);

        let mut root_node = Node::new("/", SegmentType::Static);
        root_node.add_child(activate_node);

        match root_node.match_node(&[
            PercentDecoded::new("activate").unwrap(),
            PercentDecoded::new("workflow").unwrap(),
        ]) {
            Some((node, _params, processed)) => {
                assert!(node.is_routable());
                assert_eq!(processed, 2)
            }
            None => panic!(),
        }
    }
}