use std::{num::ParseIntError, str::FromStr}; use http::{header::AsHeaderName, HeaderMap, StatusCode}; use serde::Serialize; use crate::{ error::{ApiError, DeserializeError, FromHttpError, HeaderError, IntoHttpError}, AuthRequirement, Context, Metadata, }; /// `Pagination` struct is used to specify the page number and the maximum number of items to be shown per page. /// /// Default values are `page = 1` and `limit = 10`. #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize)] pub struct Pagination { pub page: usize, pub limit: usize, } impl Default for Pagination { fn default() -> Self { Self::DEFAULT } } impl Pagination { const DEFAULT: Self = Self::new(1, 10); #[inline] #[must_use] pub const fn new(page: usize, limit: usize) -> Self { Self { page, limit } } } /// `PaginationResponse` struct is used to store the paginated response from the API. #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct PaginationResponse { pub items: Vec, pub current_page: usize, pub items_per_page: usize, pub total_pages: usize, pub total_items: usize, } impl PaginationResponse { /// Create a new `PaginationResponse` instance from items and Trakt.tv API response headers. /// /// # Errors /// /// Returns a `DeserializeError` if the headers are missing or if the header values are not valid. pub fn from_headers(items: Vec, map: &HeaderMap) -> Result { let current_page = parse_from_header(map, "X-Pagination-Page")?; let items_per_page = parse_from_header(map, "X-Pagination-Limit")?; let total_pages = parse_from_header(map, "X-Pagination-Page-Count")?; let total_items = parse_from_header(map, "X-Pagination-Item-Count")?; Ok(Self { items, current_page, items_per_page, total_pages, total_items, }) } #[inline] #[must_use] pub const fn next_page(&self) -> Option { if self.current_page < self.total_pages { Some(Pagination::new(self.current_page + 1, self.items_per_page)) } else { None } } } /// Helper function to parse a header value to an integer. /// /// # Errors /// /// Returns a `DeserializeError` if the header is missing, if the header value is not a valid /// string, or if the string value cannot be parsed to an integer. pub fn parse_from_header(map: &HeaderMap, key: K) -> Result where T: FromStr, K: AsHeaderName, { map.get(key) .ok_or(HeaderError::MissingHeader)? .to_str() .map_err(HeaderError::ToStrError)? .parse() .map_err(DeserializeError::ParseInt) } /// Helper function to handle the response body from the API. /// /// Will check if the response has the expected status code and will try to deserialize the /// response body. /// /// # Errors /// /// Returns a `FromHttpError` if the response status code is not the expected one or if the body /// failed to be deserialized. pub fn handle_response_body( response: &http::Response, expected: StatusCode, ) -> Result where B: AsRef<[u8]>, T: serde::de::DeserializeOwned, { if response.status() == expected { Ok(serde_json::from_slice(response.body().as_ref()).map_err(DeserializeError::Json)?) } else { Err(FromHttpError::Api(ApiError::from(response.status()))) } } /// Helper function to construct an HTTP request using the given context, metadata, and /// path/query/body values. /// /// # Errors /// /// Returns an `IntoHttpError` if the http request cannot be constructed. pub fn construct_req( ctx: &Context, md: &Metadata, path: &impl Serialize, query: &impl Serialize, body: B, ) -> Result, IntoHttpError> { let url = crate::construct_url(ctx.base_url, md.endpoint, path, query)?; let request = http::Request::builder() .method(&md.method) .uri(url) .header("Content-Type", "application/json") .header("trakt-api-version", "2") .header("trakt-api-key", ctx.client_id); let request = match (md.auth, ctx.oauth_token) { (AuthRequirement::None, _) | (AuthRequirement::Optional, None) => request, (AuthRequirement::Optional | AuthRequirement::Required, Some(token)) => { request.header("Authorization", format!("Bearer {token}")) } (AuthRequirement::Required, None) => { return Err(IntoHttpError::MissingToken); } }; Ok(request.body(body)?) } #[cfg(test)] mod tests { use http::HeaderValue; use super::*; #[test] fn test_parse_from_header() { let mut map = HeaderMap::new(); map.insert("B", HeaderValue::from_bytes(b"hello\xfa").unwrap()); map.insert("C", HeaderValue::from_static("hello")); map.insert("D", HeaderValue::from_static("10")); assert!(matches!( parse_from_header::(&map, "A"), Err(DeserializeError::Header(HeaderError::MissingHeader)) )); assert!(matches!( parse_from_header::(&map, "B"), Err(DeserializeError::Header(HeaderError::ToStrError(_))) )); assert!(matches!( parse_from_header::(&map, "C"), Err(DeserializeError::ParseInt(_)) )); assert_eq!(parse_from_header::(&map, "D").unwrap(), 10); } #[test] fn test_handle_response_body_ok() { let response = http::Response::builder() .status(StatusCode::OK) .body(b"\"hello\"") .unwrap(); assert_eq!( handle_response_body::<_, String>(&response, StatusCode::OK).unwrap(), "hello" ); } #[test] fn test_handle_response_body_bad_request() { let response = http::Response::builder() .status(StatusCode::BAD_REQUEST) .body(b"\"hello\"") .unwrap(); assert!(matches!( handle_response_body::<_, String>(&response, StatusCode::OK), Err(FromHttpError::Api(ApiError::BadRequest)) )); } #[test] fn test_handle_response_body_deserialize_error() { let response = http::Response::builder() .status(StatusCode::OK) .body(b"\"hello\xfa\"") .unwrap(); assert!(matches!( handle_response_body::<_, String>(&response, StatusCode::OK), Err(FromHttpError::Deserialize(DeserializeError::Json(_))) )); } #[allow(clippy::cognitive_complexity)] #[test] fn test_construct_req() { let mut ctx = Context { base_url: "https://api.trakt.tv", client_id: "client id", oauth_token: None, }; let mut md = Metadata { endpoint: "/test", method: http::Method::GET, auth: AuthRequirement::None, }; let req = construct_req(&ctx, &md, &(), &(), "body").unwrap(); assert_eq!(req.method(), &http::Method::GET); assert_eq!(req.uri(), "https://api.trakt.tv/test"); assert_eq!( req.headers().get("Content-Type").unwrap(), "application/json" ); assert_eq!(req.headers().get("trakt-api-version").unwrap(), "2"); assert_eq!(req.headers().get("trakt-api-key").unwrap(), "client id"); assert!(req.headers().get("Authorization").is_none()); assert_eq!(req.into_body(), "body"); md.auth = AuthRequirement::Required; ctx.oauth_token = Some("token"); let req = construct_req(&ctx, &md, &(), &(), "body").unwrap(); assert_eq!(req.method(), &http::Method::GET); assert_eq!(req.uri(), "https://api.trakt.tv/test"); assert_eq!( req.headers().get("Content-Type").unwrap(), "application/json" ); assert_eq!(req.headers().get("trakt-api-version").unwrap(), "2"); assert_eq!(req.headers().get("trakt-api-key").unwrap(), "client id"); assert_eq!(req.headers().get("Authorization").unwrap(), "Bearer token"); assert_eq!(req.into_body(), "body"); md.auth = AuthRequirement::Required; ctx.oauth_token = None; let result = construct_req(&ctx, &md, &(), &(), "body").unwrap_err(); assert!(matches!(result, IntoHttpError::MissingToken)); md.auth = AuthRequirement::Optional; ctx.oauth_token = None; let req = construct_req(&ctx, &md, &(), &(), "body").unwrap(); assert_eq!(req.method(), &http::Method::GET); assert_eq!(req.uri(), "https://api.trakt.tv/test"); assert_eq!( req.headers().get("Content-Type").unwrap(), "application/json" ); assert_eq!(req.headers().get("trakt-api-version").unwrap(), "2"); assert_eq!(req.headers().get("trakt-api-key").unwrap(), "client id"); assert!(req.headers().get("Authorization").is_none()); assert_eq!(req.into_body(), "body"); md.auth = AuthRequirement::Optional; ctx.oauth_token = Some("token"); let req = construct_req(&ctx, &md, &(), &(), "body").unwrap(); assert_eq!(req.method(), &http::Method::GET); assert_eq!(req.uri(), "https://api.trakt.tv/test"); assert_eq!( req.headers().get("Content-Type").unwrap(), "application/json" ); assert_eq!(req.headers().get("trakt-api-version").unwrap(), "2"); assert_eq!(req.headers().get("trakt-api-key").unwrap(), "client id"); assert_eq!(req.headers().get("Authorization").unwrap(), "Bearer token"); assert_eq!(req.into_body(), "body"); } }