Switched to Rocket

This commit is contained in:
Alex Wright 2020-02-29 16:48:07 +01:00
parent e0d1eb8897
commit f32681d95d
1 changed files with 36 additions and 106 deletions

View File

@ -1,21 +1,12 @@
#![deny(warnings)] #![deny(warnings)]
#![feature(proc_macro_hygiene)]
#[macro_use] extern crate rocket;
use log::info; use log::info;
use serde_derive::Serialize; use serde_derive::Serialize;
use std::env; use std::env;
use std::fs; use std::fs;
use std::path::Path; use std::path::Path;
use std::str::{
FromStr,
from_utf8,
};
use std::thread;
use hyper::rt::{Future};
use hyper::{Body, Request, Response, Server, StatusCode};
use hyper::header::{AUTHORIZATION};
use hyper_router::{Route, RouterBuilder, RouterService};
use base64::decode;
use biscuit::{ use biscuit::{
Empty, Empty,
@ -35,6 +26,8 @@ use num::BigUint;
use openssl::rsa::Rsa; use openssl::rsa::Rsa;
use ldap3::{ LdapConn, Scope, SearchEntry }; use ldap3::{ LdapConn, Scope, SearchEntry };
use rocket::request::Form;
use rocket_contrib::json::Json;
#[derive(Debug)] #[derive(Debug)]
struct BasicAuthentication { struct BasicAuthentication {
@ -52,25 +45,6 @@ pub enum AuthError {
LdapSearch, LdapSearch,
} }
impl FromStr for BasicAuthentication {
type Err = AuthError;
fn from_str(s: &str) -> Result<BasicAuthentication, AuthError> {
match decode(s) {
Ok(bytes) => match from_utf8(&bytes) {
Ok(text) => {
let mut pair = text.splitn(2, ":");
Ok(BasicAuthentication {
username: pair.next().unwrap().to_string(),
password: pair.next().unwrap().to_string(),
})
},
Err(_) => Err(AuthError::Parse)
},
Err(_) => Err(AuthError::Decode)
}
}
}
#[derive(Debug)] #[derive(Debug)]
struct LdapUser { struct LdapUser {
pub dn: String, pub dn: String,
@ -114,6 +88,7 @@ fn auth_user(auth: &BasicAuthentication) -> Result<LdapUser, AuthError> {
None => [].to_vec(), None => [].to_vec(),
}; };
info!("Authentication success for {:?}", base);
Ok(LdapUser { Ok(LdapUser {
dn: base, dn: base,
mail: mail, mail: mail,
@ -121,49 +96,22 @@ fn auth_user(auth: &BasicAuthentication) -> Result<LdapUser, AuthError> {
}) })
} }
fn auth_handler(req: Request<Body>) -> Response<Body> { #[derive(FromForm)]
let header = match req.headers().get(AUTHORIZATION) { struct LoginData {
Some(auth_value) => auth_value.to_str().unwrap(), username: String,
None => return Response::builder() password: String,
.status(StatusCode::UNAUTHORIZED) }
.body(Body::from("Authentication header missing"))
.unwrap(),
};
let (auth_type, credentials) = {
let mut split = header.split_ascii_whitespace();
let auth_type = split.next().unwrap();
let credentials = split.next().unwrap();
(auth_type, credentials)
};
if auth_type != "Basic" { #[post("/login", data = "<form_data>")]
return Response::builder() fn login(form_data: Form<LoginData>) -> String {
.status(StatusCode::UNAUTHORIZED) let auth = BasicAuthentication {
.body(Body::from("Basic Authentication was expected")) username: form_data.username.to_owned(),
.unwrap(); password: form_data.password.to_owned(),
};
match auth_user(&auth) {
Ok(ldap_user) => format!("OK! {:?}", ldap_user),
_ => format!("Bad :("),
} }
let auth = BasicAuthentication::from_str(credentials).unwrap();
let worker = thread::spawn(move || {
let user = auth_user(&auth);
user
});
let user = match worker.join().unwrap() {
Ok(ldap_user) => ldap_user,
Err(AuthError::LdapBind) => {
return Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::from("LDAP bind failed"))
.unwrap();
},
_ => {
return Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Something is broken"))
.unwrap();
}
};
Response::new(Body::from(format!("BasicAuthentication {:?}", user)))
} }
fn jwk_from_pem(file_path: &Path) -> Result<JWK<Empty>, Box<dyn std::error::Error + 'static>> { fn jwk_from_pem(file_path: &Path) -> Result<JWK<Empty>, Box<dyn std::error::Error + 'static>> {
@ -184,7 +132,8 @@ fn jwk_from_pem(file_path: &Path) -> Result<JWK<Empty>, Box<dyn std::error::Erro
}) })
} }
fn get_keys(_req: Request<Body>) -> Response<Body> { #[get("/oauth2/keys")]
fn get_keys() -> Json<JWKSet<Empty>> {
let jwks: Vec<JWK<Empty>> = fs::read_dir("./").unwrap() let jwks: Vec<JWK<Empty>> = fs::read_dir("./").unwrap()
.filter_map(|dir_entry| { .filter_map(|dir_entry| {
let path = dir_entry.unwrap().path(); let path = dir_entry.unwrap().path();
@ -202,12 +151,7 @@ fn get_keys(_req: Request<Body>) -> Response<Body> {
}) })
.collect(); .collect();
let jwks = JWKSet { keys: jwks }; let jwks = JWKSet { keys: jwks };
let jwks_json = serde_json::to_string(&jwks).unwrap(); Json(jwks)
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(Body::from(jwks_json))
.unwrap()
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@ -215,44 +159,30 @@ struct OidcConfig {
pub jwks_uri: String, pub jwks_uri: String,
} }
fn oidc_config(_req: Request<Body>) -> Response<Body> { #[get("/.well-known/openid-configuration")]
fn oidc_config() -> Json<OidcConfig> {
let config = OidcConfig { let config = OidcConfig {
jwks_uri: "https://auth.xeen.dev/oauth2/keys".to_string(), jwks_uri: "https://auth.xeen.dev/oauth2/keys".to_string(),
}; };
let config_json = serde_json::to_string(&config).unwrap(); Json(config)
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(Body::from(config_json))
.unwrap()
} }
fn hello(_req: Request<Body>) -> Response<Body> { #[get("/")]
Response::new(Body::from("Hi!")) fn hello() -> &'static str {
"Hello!"
} }
fn router_service() -> Result<RouterService, std::io::Error> { fn routes() -> Vec<rocket::Route> {
let router = RouterBuilder::new() routes![
.add(Route::get("/").using(hello)) hello,
.add(Route::post("/auth").using(auth_handler)) oidc_config,
.add(Route::get("/oauth2/keys").using(get_keys)) get_keys,
.add(Route::get("/.well-known/openid-configuration").using(oidc_config)) login,
.build(); ]
Ok(RouterService::new(router))
} }
fn main() { fn main() {
env_logger::init(); env_logger::init();
let addr_str = match env::var("LISTEN_ADDR") { rocket::ignite().mount("/", routes()).launch();
Ok(addr) => addr,
_ => "0.0.0.0:3000".to_string(),
};
let addr = addr_str.parse().expect("Bad Address");
let server = Server::bind(&addr)
.serve(router_service)
.map_err(|e| eprintln!("server error: {}", e));
info!("Listening on http://{}", addr);
tokio::run(server);
} }