1#[cfg(feature = "openapi")]
2use crate::openapi::{
3 builder::{OpenapiBuilder, OpenapiInfo},
4 router::OpenapiRouter
5};
6use crate::{response::ResourceError, Endpoint, FromBody, IntoResponse, Resource, Response};
7#[cfg(feature = "cors")]
8use gotham::router::route::matcher::AccessControlRequestMethodMatcher;
9use gotham::{
10 handler::HandlerError,
11 helpers::http::response::{create_empty_response, create_response},
12 hyper::{body::to_bytes, header::CONTENT_TYPE, Body, HeaderMap, Method, StatusCode},
13 mime::{Mime, APPLICATION_JSON},
14 pipeline::PipelineHandleChain,
15 prelude::*,
16 router::{
17 builder::{RouterBuilder, ScopeBuilder},
18 route::matcher::{AcceptHeaderRouteMatcher, ContentTypeHeaderRouteMatcher, RouteMatcher},
19 RouteNonMatch
20 },
21 state::{FromState, State}
22};
23#[cfg(feature = "openapi")]
24use openapi_type::OpenapiType;
25use std::{any::TypeId, panic::RefUnwindSafe};
26
/// Gotham path extractor with a single `id` field.
///
/// NOTE(review): gotham deserializes path segments into extractor fields by name, so
/// this presumably binds to a route segment called `id` — confirm at the call sites.
#[derive(Clone, Copy, Debug, Deserialize, StateData, StaticResponseExtender)]
#[cfg_attr(feature = "openapi", derive(OpenapiType))]
pub struct PathExtractor<ID: RefUnwindSafe + Send + 'static> {
    /// The deserialized value of the `id` path segment.
    pub id: ID
}
33
#[cfg(feature = "openapi")]
/// Lets a router builder register endpoints through an [`OpenapiRouter`].
///
/// The implementations in this module create a fresh `OpenapiBuilder` from `info`.
/// They then call `block` exactly once with an `OpenapiRouter` that borrows both `self`
/// and that builder. The builder only lives for the duration of the `with_openapi` call.
pub trait WithOpenapi<D> {
    fn with_openapi<F>(&mut self, info: OpenapiInfo, block: F)
    where
        F: FnOnce(OpenapiRouter<'_, D>);
}
43
#[_private_openapi_trait(DrawResourcesWithSchema)]
/// Registers a resource on a router builder under a path prefix.
///
/// `resource::<R>(path)` mounts resource `R` at `path`. The implementations in this
/// module strip a single leading `/` from `path` and then call `R::setup` with the pair
/// `(self, path)`. That pair implements `DrawResourceRoutes`.
///
/// NOTE(review): `_private_openapi_trait` presumably also emits a `DrawResourcesWithSchema`
/// variant that uses the `openapi_bound` bound instead of the `non_openapi_bound` one —
/// confirm against the proc macro.
pub trait DrawResources {
    #[openapi_bound(R: crate::ResourceWithSchema)]
    #[non_openapi_bound(R: crate::Resource)]
    fn resource<R>(&mut self, path: &str);
}
52
#[_private_openapi_trait(DrawResourceRoutesWithSchema)]
/// Registers individual endpoints of a resource.
///
/// This module implements it for `(&mut RouterBuilder, &str)` and
/// `(&mut ScopeBuilder, &str)`, where the string is the resource's path prefix.
/// `endpoint::<E>()` mounts `E` at `{prefix}/{E::uri()}`. The route matches on
/// `E::http_method()` and checks the `Accept` header against `E::Output::accepted_types()`.
///
/// NOTE(review): `_private_openapi_trait` presumably also emits a
/// `DrawResourceRoutesWithSchema` variant using the `openapi_bound` bound — confirm
/// against the proc macro.
pub trait DrawResourceRoutes {
    #[openapi_bound(E: crate::EndpointWithSchema)]
    #[non_openapi_bound(E: crate::Endpoint)]
    fn endpoint<E: 'static>(&mut self);
}
61
62fn response_from(res: Response, state: &State) -> gotham::hyper::Response<Body> {
63 let mut r = create_empty_response(state, res.status);
64 let headers = r.headers_mut();
65 if let Some(mime) = res.mime {
66 headers.insert(CONTENT_TYPE, mime.as_ref().parse().unwrap());
67 }
68 let mut last_name = None;
69 for (name, value) in res.headers {
70 if name.is_some() {
71 last_name = name;
72 }
73 let name = last_name.clone().unwrap();
75 headers.insert(name, value);
76 }
77
78 let method = Method::borrow_from(state);
79 if method != Method::HEAD {
80 *r.body_mut() = res.body;
81 }
82
83 #[cfg(feature = "cors")]
84 crate::cors::handle_cors(state, &mut r);
85
86 r
87}
88
89async fn endpoint_handler<E>(
90 state: &mut State
91) -> Result<gotham::hyper::Response<Body>, HandlerError>
92where
93 E: Endpoint,
94 <E::Output as IntoResponse>::Err: Into<HandlerError>
95{
96 trace!("entering endpoint_handler");
97 let placeholders = E::Placeholders::take_from(state);
98 if TypeId::of::<E::Placeholders>() == TypeId::of::<E::Params>() {
101 state.put(placeholders.clone());
102 }
103 let params = E::Params::take_from(state);
104
105 let body = match E::needs_body() {
106 true => {
107 let body = to_bytes(Body::take_from(state)).await?;
108
109 let content_type: Mime = match HeaderMap::borrow_from(state).get(CONTENT_TYPE) {
110 Some(content_type) => content_type.to_str().unwrap().parse().unwrap(),
111 None => {
112 debug!("Missing Content-Type: Returning 415 Response");
113 let res = create_empty_response(state, StatusCode::UNSUPPORTED_MEDIA_TYPE);
114 return Ok(res);
115 }
116 };
117
118 match E::Body::from_body(body, content_type) {
119 Ok(body) => Some(body),
120 Err(e) => {
121 debug!("Invalid Body: Returning 400 Response");
122 let error: ResourceError = e.into();
123 let json = serde_json::to_string(&error)?;
124 let res =
125 create_response(state, StatusCode::BAD_REQUEST, APPLICATION_JSON, json);
126 return Ok(res);
127 }
128 }
129 },
130 false => None
131 };
132
133 let out = E::handle(state, placeholders, params, body).await;
134 let res = out.into_response().await.map_err(Into::into)?;
135 debug!("Returning response {res:?}");
136 Ok(response_from(res, state))
137}
138
/// Route matcher that checks the `Accept` header only if it wraps an
/// [`AcceptHeaderRouteMatcher`]. With `matcher: None`, every request matches.
#[derive(Clone)]
struct MaybeMatchAcceptHeader {
    /// The actual `Accept` matcher, or `None` to skip the check.
    matcher: Option<AcceptHeaderRouteMatcher>
}
143
144impl RouteMatcher for MaybeMatchAcceptHeader {
145 fn is_match(&self, state: &State) -> Result<(), RouteNonMatch> {
146 match &self.matcher {
147 Some(matcher) => matcher.is_match(state),
148 None => Ok(())
149 }
150 }
151}
152
153impl MaybeMatchAcceptHeader {
154 fn new(types: Option<Vec<Mime>>) -> Self {
155 let types = match types {
156 Some(types) if types.is_empty() => None,
157 types => types
158 };
159 Self {
160 matcher: types.map(AcceptHeaderRouteMatcher::new)
161 }
162 }
163}
164
165impl From<Option<Vec<Mime>>> for MaybeMatchAcceptHeader {
166 fn from(types: Option<Vec<Mime>>) -> Self {
167 Self::new(types)
168 }
169}
170
/// Route matcher that checks the `Content-Type` header only if it wraps a
/// [`ContentTypeHeaderRouteMatcher`]. With `matcher: None`, every request matches.
#[derive(Clone)]
struct MaybeMatchContentTypeHeader {
    /// The actual `Content-Type` matcher, or `None` to skip the check.
    matcher: Option<ContentTypeHeaderRouteMatcher>
}
175
176impl RouteMatcher for MaybeMatchContentTypeHeader {
177 fn is_match(&self, state: &State) -> Result<(), RouteNonMatch> {
178 match &self.matcher {
179 Some(matcher) => matcher.is_match(state),
180 None => Ok(())
181 }
182 }
183}
184
185impl MaybeMatchContentTypeHeader {
186 fn new(types: Option<Vec<Mime>>) -> Self {
187 Self {
188 matcher: types.map(|types| ContentTypeHeaderRouteMatcher::new(types).allow_no_type())
189 }
190 }
191}
192
193impl From<Option<Vec<Mime>>> for MaybeMatchContentTypeHeader {
194 fn from(types: Option<Vec<Mime>>) -> Self {
195 Self::new(types)
196 }
197}
198
/// Implements the resource-drawing traits for one of gotham's router builder types.
///
/// It implements `DrawResources`, `DrawResourceRoutes` for the pair
/// `(&mut builder, &str)`, and `WithOpenapi` when the `openapi` feature is enabled.
/// `$builder` must name a type of the shape `$builder<'a, C, P>`, such as `RouterBuilder`
/// or `ScopeBuilder`.
macro_rules! implDrawResourceRoutes {
    ($builder:ident) => {
        #[cfg(feature = "openapi")]
        impl<'a, C, P> WithOpenapi<Self> for $builder<'a, C, P>
        where
            C: PipelineHandleChain<P> + Copy + Send + Sync + 'static,
            P: RefUnwindSafe + Send + Sync + 'static
        {
            fn with_openapi<F>(&mut self, info: OpenapiInfo, block: F)
            where
                F: FnOnce(OpenapiRouter<'_, $builder<'a, C, P>>)
            {
                // The OpenAPI builder only lives for the duration of this call.
                let mut openapi_builder = OpenapiBuilder::new(info);
                block(OpenapiRouter {
                    router: self,
                    scope: None,
                    openapi_builder: &mut openapi_builder
                });
            }
        }

        impl<'a, C, P> DrawResources for $builder<'a, C, P>
        where
            C: PipelineHandleChain<P> + Copy + Send + Sync + 'static,
            P: RefUnwindSafe + Send + Sync + 'static
        {
            fn resource<R: Resource>(&mut self, path: &str) {
                // Drop at most one leading slash; endpoints are later mounted at `{path}/{uri}`.
                let trimmed = path.strip_prefix('/').unwrap_or(path);
                R::setup((self, trimmed));
            }
        }

        impl<'a, C, P> DrawResourceRoutes for (&mut $builder<'a, C, P>, &str)
        where
            C: PipelineHandleChain<P> + Copy + Send + Sync + 'static,
            P: RefUnwindSafe + Send + Sync + 'static
        {
            fn endpoint<E: Endpoint + 'static>(&mut self) {
                let uri = format!("{prefix}/{suffix}", prefix = self.1, suffix = E::uri());
                debug!("Registering endpoint for {uri}");
                self.0.associate(&uri, |route| {
                    route
                        .request(vec![E::http_method()])
                        .add_route_matcher(MaybeMatchAcceptHeader::new(E::Output::accepted_types()))
                        .with_path_extractor::<E::Placeholders>()
                        .with_query_string_extractor::<E::Params>()
                        .to_async_borrowing(endpoint_handler::<E>);

                    // Non-GET endpoints also get an OPTIONS route for CORS preflight requests,
                    // matched on the endpoint's method via `AccessControlRequestMethodMatcher`.
                    #[cfg(feature = "cors")]
                    if E::http_method() != Method::GET {
                        route
                            .options()
                            .add_route_matcher(AccessControlRequestMethodMatcher::new(
                                E::http_method()
                            ))
                            .to(crate::cors::cors_preflight_handler);
                    }
                });
            }
        }
    };
}
263
264implDrawResourceRoutes!(RouterBuilder);
265implDrawResourceRoutes!(ScopeBuilder);