|
3 | 3 | pub mod page;
|
4 | 4 |
|
5 | 5 | use crate::utils::get_correct_docsrs_style_file;
|
6 |
| -use crate::utils::report_error; |
7 |
| -use anyhow::anyhow; |
| 6 | +use crate::utils::{report_error, spawn_blocking}; |
| 7 | +use anyhow::{anyhow, bail, Context as _}; |
8 | 8 | use serde_json::Value;
|
9 | 9 | use tracing::{info, instrument};
|
10 | 10 |
|
@@ -92,7 +92,7 @@ mod source;
|
92 | 92 | mod statics;
|
93 | 93 | mod strangler;
|
94 | 94 |
|
95 |
| -use crate::{impl_axum_webpage, impl_webpage, Context}; |
| 95 | +use crate::{db::Pool, impl_axum_webpage, impl_webpage, Context}; |
96 | 96 | use anyhow::Error;
|
97 | 97 | use axum::{
|
98 | 98 | extract::Extension,
|
@@ -123,6 +123,7 @@ use std::{borrow::Cow, net::SocketAddr, sync::Arc};
|
123 | 123 | use strangler::StranglerService;
|
124 | 124 | use tower::ServiceBuilder;
|
125 | 125 | use tower_http::trace::TraceLayer;
|
| 126 | +use url::form_urlencoded; |
126 | 127 |
|
127 | 128 | /// Duration of static files for staticfile and DatabaseFileHandler (in seconds)
|
128 | 129 | const STATIC_FILE_CACHE_DURATION: u64 = 60 * 60 * 24 * 30 * 12; // 12 months
|
@@ -428,6 +429,26 @@ fn match_version(
|
428 | 429 | Err(Nope::VersionNotFound)
|
429 | 430 | }
|
430 | 431 |
|
| 432 | +// temporary wrapper around `match_version` for axum handlers. |
| 433 | +// |
| 434 | +// FIXME: this can go when we fully migrated to axum / async in web |
| 435 | +async fn match_version_axum( |
| 436 | + pool: &Pool, |
| 437 | + name: &str, |
| 438 | + input_version: Option<&str>, |
| 439 | +) -> Result<MatchVersion, Error> { |
| 440 | + spawn_blocking({ |
| 441 | + let name = name.to_owned(); |
| 442 | + let input_version = input_version.map(str::to_owned); |
| 443 | + let pool = pool.clone(); |
| 444 | + move || { |
| 445 | + let mut conn = pool.get()?; |
| 446 | + Ok(match_version(&mut conn, &name, input_version.as_deref())?) |
| 447 | + } |
| 448 | + }) |
| 449 | + .await |
| 450 | +} |
| 451 | + |
431 | 452 | #[instrument(skip_all)]
|
432 | 453 | pub(crate) fn build_axum_app(
|
433 | 454 | context: &dyn Context,
|
@@ -539,15 +560,29 @@ fn redirect(url: Url) -> Response {
|
539 | 560 | resp
|
540 | 561 | }
|
541 | 562 |
|
542 |
| -fn axum_redirect(url: &str) -> Result<impl IntoResponse, Error> { |
543 |
| - if !url.starts_with('/') || url.starts_with("//") { |
544 |
| - return Err(anyhow!("invalid redirect URL: {}", url)); |
| 563 | +fn axum_redirect<U>(uri: U) -> Result<impl IntoResponse, Error> |
| 564 | +where |
| 565 | + U: TryInto<http::Uri>, |
| 566 | + <U as TryInto<http::Uri>>::Error: std::fmt::Debug, |
| 567 | +{ |
| 568 | + let uri: http::Uri = uri |
| 569 | + .try_into() |
| 570 | + .map_err(|err| anyhow!("invalid URI: {:?}", err))?; |
| 571 | + |
| 572 | + if let Some(path_and_query) = uri.path_and_query() { |
| 573 | + if path_and_query.as_str().starts_with("//") { |
| 574 | + bail!("protocol relative redirects are forbidden"); |
| 575 | + } |
| 576 | + } else { |
| 577 | + // we always want a path to redirect to, even when it's just `/` |
| 578 | + bail!("missing path in URI"); |
545 | 579 | }
|
| 580 | + |
546 | 581 | Ok((
|
547 | 582 | StatusCode::FOUND,
|
548 | 583 | [(
|
549 | 584 | http::header::LOCATION,
|
550 |
| - http::HeaderValue::try_from(url).expect("invalid url for redirect"), |
| 585 | + http::HeaderValue::try_from(uri.to_string()).context("invalid uri for redirect")?, |
551 | 586 | )],
|
552 | 587 | ))
|
553 | 588 | }
|
@@ -605,6 +640,29 @@ where
|
605 | 640 | }
|
606 | 641 | }
|
607 | 642 |
|
| 643 | +/// Parse an URI into a http::Uri struct. |
| 644 | +/// When `queries` are given these are added to the URL, |
| 645 | +/// with empty `queries` the `?` will be omitted. |
| 646 | +pub(crate) fn axum_parse_uri_with_params<I, K, V>(uri: &str, queries: I) -> Result<http::Uri, Error> |
| 647 | +where |
| 648 | + I: IntoIterator, |
| 649 | + I::Item: Borrow<(K, V)>, |
| 650 | + K: AsRef<str>, |
| 651 | + V: AsRef<str>, |
| 652 | +{ |
| 653 | + let mut queries = queries.into_iter().peekable(); |
| 654 | + if queries.peek().is_some() { |
| 655 | + let query_params: String = form_urlencoded::Serializer::new(String::new()) |
| 656 | + .extend_pairs(queries) |
| 657 | + .finish(); |
| 658 | + format!("{uri}?{}", query_params) |
| 659 | + .parse::<http::Uri>() |
| 660 | + .context("error parsing URL") |
| 661 | + } else { |
| 662 | + uri.parse::<http::Uri>().context("error parsing URL") |
| 663 | + } |
| 664 | +} |
| 665 | + |
608 | 666 | /// MetaData used in header
|
609 | 667 | #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
610 | 668 | pub(crate) struct MetaData {
|
|
0 commit comments