gotham/router/tree/
node.rs

1//! Defines `Node` for `Tree`.
2
3use hyper::{Body, StatusCode};
4use log::trace;
5
6use crate::helpers::http::PercentDecoded;
7use crate::router::non_match::RouteNonMatch;
8use crate::router::route::{Delegation, Route};
9use crate::router::tree::segment::{SegmentMapping, SegmentType};
10use crate::state::{request_id, State};
11
12use std::cmp::Ordering;
13use std::collections::HashMap;
14
15/// A recursive member of `Tree`, representative of segment(s) in a request path.
16///
17/// Each node includes `0..n` `Route` instances, which can be further evaluated by the `Router`
18/// based on a match. Every node may also have `0..n` children to provide the recursive tree
19/// representation.
20pub struct Node {
21    segment: String,
22    segment_type: SegmentType,
23    routes: Vec<Box<dyn Route<ResBody = Body> + Send + Sync>>,
24    children: Vec<Node>,
25}
26
27impl Node {
28    /// Creates new `Node` for the given segment and type.
29    pub fn new(segment: &str, segment_type: SegmentType) -> Self {
30        Node {
31            segment_type,
32            segment: segment.to_string(),
33            routes: vec![],
34            children: vec![],
35        }
36    }
37
38    /// Adds a new child `Node` instance to this `Node`.
39    pub fn add_child(&mut self, node: Node) -> &mut Self {
40        self.children.push(node);
41        self.children.sort();
42        self
43    }
44
45    /// Adds a `Route` to this `Node`, to be potentially evaluated by the `Router`.
46    pub fn add_route(&mut self, route: Box<dyn Route<ResBody = Body> + Send + Sync>) -> &mut Self {
47        self.routes.push(route);
48        self
49    }
50
51    /// Borrows a child `Node` based on the defined segment bounds.
52    pub fn borrow_child(&self, segment: &str, segment_type: SegmentType) -> Option<&Node> {
53        self.children
54            .iter()
55            .find(|n| n.segment_type == segment_type && n.segment == segment)
56    }
57
58    /// Borrows a mutable child `Node` based on the defined segment bounds.
59    pub fn borrow_child_mut(
60        &mut self,
61        segment: &str,
62        segment_type: SegmentType,
63    ) -> Option<&mut Node> {
64        self.children
65            .iter_mut()
66            .find(|n| n.segment_type == segment_type && n.segment == segment)
67    }
68
69    /// Determines if a child exists based on the defined segment bounds.
70    pub fn has_child(&self, segment: &str, segment_type: SegmentType) -> bool {
71        self.borrow_child(segment, segment_type).is_some()
72    }
73
74    /// Determines if this `Node` has any valid `Route` values attached.
75    pub fn is_routable(&self) -> bool {
76        !self.routes.is_empty()
77    }
78
79    /// Traverses this `Node` and its children, attempting to a locate a path of `Node` instances
80    /// which match all segments of the provided `Request` path. The final `Node` must have at
81    /// least a single `Route` attached in order to be returned.
82    ///
83    /// Only the first matching path is returned from this method, and the value is wrapped in
84    /// an `Option` as there may be no matching node.
85    ///
86    /// Children are searched in a most to least specific order of segments, based on the node
87    /// `SegmentType` value:
88    ///
89    /// 1. Static
90    /// 2. Constrained
91    /// 3. Dynamic
92    /// 4. Glob
93    ///
94    /// This method is a wrapping of an internal recursive implementation to mask the required
95    /// types needed for the recursion.
96    pub fn match_node<'a>(
97        &'a self,
98        segments: &'a [PercentDecoded],
99    ) -> Option<(&'a Node, SegmentMapping<'a>, usize)> {
100        // accumulators for recursion
101        let mut params = HashMap::new();
102        let mut processed = 0;
103
104        // process and map the results through to the required form
105        self.inner_match_node(segments, &mut params, &mut processed)
106            .map(|node| (node, params, processed))
107    }
108
109    /// Retrieves a reference to the contained segment value.
110    ///
111    /// This is required for lifetime related annotations.
112    pub fn segment<'a>(&'a self) -> &'a str {
113        &self.segment
114    }
115
116    /// Determines if a `Route` instance associated with this `Node` is willing to `Handle` the
117    /// request.
118    ///
119    /// Where multiple `Route` instances could possibly handle the `Request` only the first, ordered
120    /// per creation, is invoked.
121    ///
122    /// Where no `Route` instances will accept the `Request` the resulting Error will be the
123    /// union of the `RouteNonMatch` values returned from each `Route`.
124    ///
125    /// In the situation where all these avenues are exhausted an InternalServerError will be
126    /// provided.
127    pub fn select_route(
128        &self,
129        state: &State,
130    ) -> Result<&Box<dyn Route<ResBody = Body> + Send + Sync>, RouteNonMatch> {
131        let mut err = Ok(());
132
133        // check for matching routes
134        for r in self.routes.iter() {
135            match r.is_match(state) {
136                Ok(()) => {
137                    trace!("[{}] found matching route", request_id(state));
138                    return Ok(r);
139                }
140                Err(e) => {
141                    // concat errors
142                    err = match err {
143                        Err(e0) => Err(e.union(e0)),
144                        Ok(()) => Err(e),
145                    }
146                }
147            }
148        }
149
150        // unpack required for types
151        if let Err(e) = err {
152            trace!(
153                "[{}] no matching route, using error status code from route",
154                request_id(state)
155            );
156            return Err(e);
157        }
158
159        trace!(
160            "[{}] invalid state, no routes. sending internal server error",
161            request_id(state)
162        );
163
164        // error because we shouldn't arrive here due to match_node/1
165        Err(RouteNonMatch::new(StatusCode::INTERNAL_SERVER_ERROR))
166    }
167
168    /// Recursive implementation of `match_route` to populate parameters and keep
169    /// track of the number of visited nodes.
170    ///
171    /// There's space for optimizations in here (perhaps), but it seems to perform
172    /// faster than the previous implementation of the router, so all is well for now.
173    fn inner_match_node<'a>(
174        &'a self,
175        segments: &'a [PercentDecoded],
176        params: &mut SegmentMapping<'a>,
177        processed: &mut usize,
178    ) -> Option<&'a Node> {
179        let next_segment = segments.split_first();
180
181        // stop if we're done
182        if next_segment.is_none() {
183            if !self.is_routable() {
184                return None;
185            }
186            return Some(self);
187        }
188
189        // check for external delegates, and stop
190        if let Some(route) = self.routes.first() {
191            if route.delegation() == Delegation::External {
192                return Some(self);
193            }
194        }
195
196        let (segment, remaining) = next_segment.unwrap();
197
198        *processed += 1;
199
200        // check all children first
201        for child in &self.children {
202            match child.segment_type {
203                // Globbing matches everything, so we append the segment value
204                // to the parameters against the child segment name.
205                SegmentType::Glob => {
206                    params.entry(&child.segment).or_default().push(segment);
207                }
208
209                // Static matches based on a raw string match, so we simply
210                // compare the value of the current segment with that of the
211                // child node we're currently iterating.
212                SegmentType::Static => {
213                    // check for raw string match
214                    if child.segment != segment.as_ref() {
215                        continue;
216                    }
217                }
218
219                // Constrained matches are based on a contained pattern the
220                // segment value must match. If the segment matches, we need
221                // to make sure to store the value inside the parameters map.
222                SegmentType::Constrained { ref regex } => {
223                    // check for regex matching
224                    if !regex.is_match(segment.as_ref()) {
225                        continue;
226                    }
227                    // if there's a match, store the value
228                    params.insert(&child.segment, vec![segment]);
229                }
230
231                // Dynamic matches match every value, so we just attach the
232                // segment value to the parameters list (just like with the
233                // constrained type).
234                SegmentType::Dynamic => {
235                    // if there's a match, store the value
236                    params.insert(&child.segment, vec![segment]);
237                }
238            };
239
240            // If we hit this point, we've determined that the child node is
241            // the correct node to delegate to, so we continue the recursion
242            // on the child node, passing in the same parameters.
243            return child.inner_match_node(remaining, params, processed);
244        }
245
246        // If there are no children, but this is a globbing node, then we can
247        // continue the nesting by just shifting the path segments and calling
248        // `inner_match_node` on ourself again (to simulate wildcards).
249        if let SegmentType::Glob = self.segment_type {
250            // push the segment to the parameters of the glob
251            if let Some(path) = params.get_mut(self.segment()) {
252                path.push(segment);
253            }
254            // call again, but after shifting the segments to the next
255            return self.inner_match_node(remaining, params, processed);
256        }
257
258        None
259    }
260}
261
262impl Eq for Node {}
263impl PartialEq for Node {
264    /// Compares two `Node` values for equality based on the segments they represent.
265    fn eq(&self, other: &Node) -> bool {
266        self.segment == other.segment && self.segment_type == other.segment_type
267    }
268}
269
270impl Ord for Node {
271    /// Compares two `Node` values to determine an appropriate `Ordering`.
272    fn cmp(&self, other: &Node) -> Ordering {
273        (&self.segment_type, &self.segment).cmp(&(&other.segment_type, &other.segment))
274    }
275}
276
277impl PartialOrd for Node {
278    /// Compares two `Node` values to determine an appropriate `Ordering`.
279    fn partial_cmp(&self, other: &Node) -> Option<Ordering> {
280        Some(self.cmp(other))
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    use std::panic::RefUnwindSafe;

    use hyper::{HeaderMap, Method, Response};

    use crate::extractor::{NoopPathExtractor, NoopQueryStringExtractor};
    use crate::helpers::http::request::path::RequestPathSegments;
    use crate::helpers::http::PercentDecoded;
    use crate::pipeline::{finalize_pipeline_set, new_pipeline_set, PipelineSet};
    use crate::router::route::dispatch::DispatcherImpl;
    use crate::router::route::matcher::MethodOnlyRouteMatcher;
    use crate::router::route::{Delegation, Extractors, Route, RouteImpl};
    use crate::router::tree::regex::ConstrainedSegmentRegex;
    use crate::state::{set_request_id, State};

    /// Handler used by every test route; always answers with an empty body.
    fn handler(state: State) -> (State, Response<Body>) {
        (state, Response::new(Body::empty()))
    }

    /// Builds an internally delegated route that accepts exactly `methods` and dispatches to
    /// `handler` through `pipeline_set`.
    fn route_for<P>(
        methods: Vec<Method>,
        pipeline_set: PipelineSet<P>,
    ) -> Box<dyn Route<ResBody = Body> + Send + Sync>
    where
        P: Send + Sync + RefUnwindSafe + 'static,
    {
        let matcher = MethodOnlyRouteMatcher::new(methods);
        let dispatcher = DispatcherImpl::new(|| Ok(handler), (), pipeline_set);
        let extractors: Extractors<NoopPathExtractor, NoopQueryStringExtractor> = Extractors::new();
        let route = RouteImpl::new(
            matcher,
            Box::new(dispatcher),
            extractors,
            Delegation::Internal,
        );
        Box::new(route)
    }

    /// Builds a `GET`-only route; see `route_for`.
    fn get_route<P>(pipeline_set: PipelineSet<P>) -> Box<dyn Route<ResBody = Body> + Send + Sync>
    where
        P: Send + Sync + RefUnwindSafe + 'static,
    {
        route_for(vec![Method::GET], pipeline_set)
    }

    /// Builds a `State` carrying `method`, an empty `HeaderMap` and a request id.
    fn state_with(method: Method) -> State {
        let mut state = State::new();
        state.put(method);
        state.put(HeaderMap::new());
        set_request_id(&mut state);
        state
    }

    /// Builds the route tree shared by most tests:
    ///
    /// ```text
    /// /
    ///   seg1                               [GET, HEAD]
    ///   seg2                               [POST], [PATCH]
    ///   seg3
    ///     seg4                             [GET]
    ///   seg5
    ///     seg6                             [GET]
    ///     :segdyn1 (dynamic)
    ///       seg7                           [GET]
    ///   resource
    ///     id (constrained: [0-9]+)         [GET]
    ///   seg8 (glob)
    ///     seg9
    ///       seg10 (glob)                   [GET]
    /// ```
    fn test_structure() -> Node {
        let mut root = Node::new("/", SegmentType::Static);
        let pipeline_set = finalize_pipeline_set(new_pipeline_set());

        // Two methods, same path, same handler
        // [Get|Head]: /seg1
        let mut seg1 = Node::new("seg1", SegmentType::Static);
        seg1.add_route(route_for(
            vec![Method::GET, Method::HEAD],
            pipeline_set.clone(),
        ));
        root.add_child(seg1);

        // Two methods, same path, different handlers
        // Post: /seg2
        // Patch: /seg2
        let mut seg2 = Node::new("seg2", SegmentType::Static);
        seg2.add_route(route_for(vec![Method::POST], pipeline_set.clone()));
        seg2.add_route(route_for(vec![Method::PATCH], pipeline_set.clone()));
        root.add_child(seg2);

        // Ensure basic traversal
        // Get: /seg3/seg4
        let mut seg3 = Node::new("seg3", SegmentType::Static);
        let mut seg4 = Node::new("seg4", SegmentType::Static);
        seg4.add_route(get_route(pipeline_set.clone()));
        seg3.add_child(seg4);
        root.add_child(seg3);

        // Ensure regex matching works and that it's anchored to the segment and does not allow
        // for overzealous matching
        // GET: /resource/<id> where id: [0-9]+
        let mut seg_resource = Node::new("resource", SegmentType::Static);
        let mut seg_id = Node::new(
            "id",
            SegmentType::Constrained {
                regex: Box::new(ConstrainedSegmentRegex::new("[0-9]+")),
            },
        );
        seg_id.add_route(get_route(pipeline_set.clone()));
        seg_resource.add_child(seg_id);
        root.add_child(seg_resource);

        // Static children are tried before dynamic ones. `/seg5/seg6` must resolve to the
        // static `seg6` node even though the dynamic `:segdyn1` sibling would also accept
        // "seg6". `/seg5/<anything>/seg7` resolves through `:segdyn1`.
        //
        // Get /seg5/:segdyn1/seg7
        // Get /seg5/seg6
        let mut seg5 = Node::new("seg5", SegmentType::Static);
        let mut seg6 = Node::new("seg6", SegmentType::Static);
        seg6.add_route(get_route(pipeline_set.clone()));

        let mut segdyn1 = Node::new(":segdyn1", SegmentType::Dynamic);
        let mut seg7 = Node::new("seg7", SegmentType::Static);
        seg7.add_route(get_route(pipeline_set.clone()));

        // Ensure traversal will respect Globs
        // Get /*seg8/seg9/*seg10
        let mut seg8 = Node::new("seg8", SegmentType::Glob);
        let mut seg9 = Node::new("seg9", SegmentType::Static);

        let mut seg10 = Node::new("seg10", SegmentType::Glob);
        seg10.add_route(get_route(pipeline_set));

        segdyn1.add_child(seg7);
        seg5.add_child(seg6);
        seg5.add_child(segdyn1);
        root.add_child(seg5);

        seg9.add_child(seg10);
        seg8.add_child(seg9);
        root.add_child(seg8);

        root
    }

    #[test]
    fn manages_children() {
        let root = test_structure();

        assert!(root.borrow_child("seg1", SegmentType::Static).is_some());
        assert!(root.borrow_child("seg2", SegmentType::Static).is_some());
        assert!(root.borrow_child("seg1", SegmentType::Dynamic).is_none());
        assert!(root.borrow_child("seg0", SegmentType::Static).is_none());
    }

    #[test]
    fn traverses_children() {
        let root = test_structure();

        // GET /seg3/seg4
        let rs = RequestPathSegments::new("/seg3/seg4");
        match root.match_node(rs.segments()) {
            Some((node, _params, processed)) => {
                assert_eq!(node.segment, "seg4");
                assert_eq!(processed, 2);
            }
            None => panic!("traversal should have succeeded here"),
        }

        // GET /seg3/seg4/seg5
        let rs = RequestPathSegments::new("/seg3/seg4/seg5");
        assert!(root.match_node(rs.segments()).is_none());

        // GET /seg5/seg6
        let rs = RequestPathSegments::new("/seg5/seg6");
        match root.match_node(rs.segments()) {
            Some((node, _params, processed)) => {
                assert_eq!(node.segment, "seg6");
                assert_eq!(processed, 2);
            }
            None => panic!("traversal should have succeeded here"),
        }

        // GET /seg5/someval/seg7
        let rs = RequestPathSegments::new("/seg5/someval/seg7");
        match root.match_node(rs.segments()) {
            Some((node, _params, processed)) => {
                assert_eq!(node.segment, "seg7");
                assert_eq!(processed, 3);
            }
            None => panic!("traversal should have succeeded here"),
        }

        // GET /some/path/seg9/another/branch
        let rs = RequestPathSegments::new("/some/path/seg9/another/branch");
        match root.match_node(rs.segments()) {
            Some((node, _params, processed)) => {
                assert_eq!(node.segment, "seg10");
                assert_eq!(processed, 5);
            }
            None => panic!("traversal should have succeeded here"),
        }

        // GET /resource/5001 - the constrained `id` segment captures the value
        let rs = RequestPathSegments::new("/resource/5001");
        match root.match_node(rs.segments()) {
            Some((node, params, processed)) => {
                assert_eq!(node.segment, "id");
                assert_eq!(processed, 2);

                let captured = &params["id"];
                assert_eq!(captured.len(), 1);
                let value: &str = captured[0].as_ref();
                assert_eq!(value, "5001");
            }
            None => panic!("traversal should have succeeded here"),
        }
    }

    #[test]
    fn non_matching_routes_allow_list_tests() {
        let root = test_structure();
        let state = state_with(Method::OPTIONS);

        let rs = RequestPathSegments::new("/seg2");
        match root.match_node(rs.segments()) {
            Some((node, _params, _processed)) => match node.select_route(&state) {
                Err(e) => {
                    let (status, mut allow_list) = e.deconstruct();
                    assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
                    allow_list.sort_by(|a, b| a.as_ref().cmp(b.as_ref()));
                    assert_eq!(allow_list, vec![Method::PATCH, Method::POST]);
                }
                Ok(_) => panic!("expected mismatched route to test allow header"),
            },
            None => panic!("traversal should have succeeded here"),
        }

        let rs = RequestPathSegments::new("/resource/100");
        match root.match_node(rs.segments()) {
            Some((node, _params, _processed)) => match node.select_route(&state) {
                Err(e) => {
                    let (status, mut allow_list) = e.deconstruct();
                    assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
                    allow_list.sort_by(|a, b| a.as_ref().cmp(b.as_ref()));
                    assert_eq!(allow_list, vec![Method::GET]);
                }
                Ok(_) => panic!("expected mismatched route to test allow header"),
            },
            None => panic!("traversal should have succeeded here"),
        }
    }

    #[test]
    fn selects_matching_route() {
        let root = test_structure();

        // A route accepting several methods matches any of them.
        let rs = RequestPathSegments::new("/seg1");
        let (node, _params, processed) = root
            .match_node(rs.segments())
            .expect("traversal should have succeeded here");
        assert_eq!(node.segment(), "seg1");
        assert_eq!(processed, 1);
        assert!(node.select_route(&state_with(Method::HEAD)).is_ok());

        // With several routes on one node, a later route is still found.
        let rs = RequestPathSegments::new("/seg2");
        let (node, _params, _processed) = root
            .match_node(rs.segments())
            .expect("traversal should have succeeded here");
        assert!(node.select_route(&state_with(Method::PATCH)).is_ok());
    }

    #[test]
    fn node_traversal_tests() {
        let pipeline_set = finalize_pipeline_set(new_pipeline_set());
        let mut root_node = Node::new("/", SegmentType::Static);
        let mut activate_node = Node::new("activate", SegmentType::Static);

        let mut workflow_node = Node::new("workflow", SegmentType::Static);
        workflow_node.add_route(get_route(pipeline_set));

        activate_node.add_child(workflow_node);
        root_node.add_child(activate_node);

        match root_node.match_node(&[
            PercentDecoded::new("activate").unwrap(),
            PercentDecoded::new("workflow").unwrap(),
        ]) {
            Some((node, _params, processed)) => {
                assert!(node.is_routable());
                assert_eq!(processed, 2)
            }
            None => panic!(),
        }
    }
}
567}