#![deny(warnings)] #![feature(proc_macro_hygiene)] #[macro_use] extern crate diesel; #[macro_use] extern crate rocket; #[macro_use] extern crate rocket_contrib; use log::info; use serde_derive::Serialize; use std::env; use std::fs; use std::path::Path; use biscuit::{ Empty, }; use biscuit::jwa::{ SignatureAlgorithm, Algorithm, }; use biscuit::jwk::{ RSAKeyParameters, CommonParameters, AlgorithmParameters, JWK, JWKSet, }; use num::BigUint; use openssl::rsa::Rsa; use ldap3::{ LdapConn, Scope, SearchEntry }; use rocket::request::{ FlashMessage, Form, FromRequest, Outcome, Request, }; use rocket::response::{ Flash, Redirect, }; use rocket_contrib::json::Json; use rocket_contrib::templates::Template; mod schema; mod models; #[derive(Debug)] struct BasicAuthentication { pub username: String, pub password: String, } #[derive(Debug)] pub enum AuthError { Parse, Decode, LdapBind, LdapConfig, LdapConnection, LdapSearch, } #[derive(Debug)] struct LdapUser { pub dn: String, pub groups: Vec, pub mail: Vec, pub services: Vec, pub username: String, } fn auth_user(auth: &BasicAuthentication) -> Result { let ldap_server_addr = match env::var("LDAP_SERVER_ADDR") { Ok(addr) => addr, _ => return Err(AuthError::LdapConfig), }; let ldap = match LdapConn::new(&ldap_server_addr) { Ok(conn) => conn, Err(_err) => return Err(AuthError::LdapConnection), }; let base = format!("uid={},ou=people,dc=xeentech,dc=com", auth.username); match ldap.simple_bind(&base, &auth.password).unwrap().success() { Ok(_ldap) => println!("Connected and authenticated"), Err(_err) => return Err(AuthError::LdapBind), }; let filter = format!("(uid={})", auth.username); let s = match ldap.search(&base, Scope::Subtree, &filter, vec!["uid", "mail", "enabledService", "memberOf"]) { Ok(result) => { let (rs, _) = result.success().unwrap(); rs }, Err(_err) => return Err(AuthError::LdapSearch), }; // Grab the first, if any, result and discard the rest let se = SearchEntry::construct(s.first().unwrap().to_owned()); let services = match se.attrs.get("enabledService") { Some(services) => services.to_vec(), None => [].to_vec(), }; let mail = match se.attrs.get("mail") { Some(mail) => mail.to_vec(), None => [].to_vec(), }; let groups = match se.attrs.get("memberOf") { Some(groups) => groups.to_vec(), None => [].to_vec(), }; let username = match se.attrs.get("uid") { Some(username) => username[0].to_owned(), None => "".to_string(), }; info!("Authentication success for {:?}", base); Ok(LdapUser { dn: base, groups: groups, mail: mail, services: services, username: username, }) } use models::{ User }; #[derive(FromForm)] struct LoginData { username: String, password: String, } #[derive(Serialize)] struct LoginFormContext { message: Option, } #[get("/login")] fn login_form(flash: Option>) -> Template { let context = LoginFormContext { message: match flash { Some(ref msg) => Some(msg.msg().to_string()), _ => None, }, }; Template::render("login_form", &context) } #[post("/login", data = "")] fn login(form_data: Form, conn: AuthDb) -> Result> { let auth = BasicAuthentication { username: form_data.username.to_owned(), password: form_data.password.to_owned(), }; let ldap_user = match auth_user(&auth) { Ok(ldap_user) => ldap_user, _ => return Err(Flash::error(Redirect::to(uri!(login_form)), "Not able to authenticate with given credentials.")), }; let user = match User::find_or_create(&conn, ldap_user.username) { Ok(user) => user, _ => return Err(Flash::error(Redirect::to(uri!(login_form)), "Failed to fetch user")), }; if ! user.is_active { return Err(Flash::error(Redirect::to(uri!(login_form)), "Account is suspended")); } println!("User: {:?}", user); Ok(Redirect::to("/")) } fn jwk_from_pem(file_path: &Path) -> Result, Box> { let key_bytes = fs::read(file_path)?; let rsa = Rsa::private_key_from_pem(key_bytes.as_slice())?; Ok(JWK { common: CommonParameters { algorithm: Some(Algorithm::Signature(SignatureAlgorithm::RS256)), key_id: Some(file_path.file_name().unwrap().to_str().unwrap().to_string()), ..Default::default() }, algorithm: AlgorithmParameters::RSA(RSAKeyParameters { n: BigUint::from_bytes_be(&rsa.n().to_vec()), e: BigUint::from_bytes_be(&rsa.e().to_vec()), ..Default::default() }), additional: Default::default(), }) } #[get("/oauth2/keys")] fn get_keys() -> Json> { let jwks: Vec> = fs::read_dir("./").unwrap() .filter_map(|dir_entry| { let path = dir_entry.unwrap().path(); let ext = match path.extension() { Some(ext) => ext.to_str().unwrap().to_owned(), None => return None, }; match ext.as_ref() { "pem" => match jwk_from_pem(path.as_path()) { Ok(jwk) => Some(jwk), _ => None, }, _ => None, } }) .collect(); let jwks = JWKSet { keys: jwks }; Json(jwks) } #[derive(Debug, Serialize)] struct OidcConfig { pub jwks_uri: String, } #[get("/.well-known/openid-configuration")] fn oidc_config() -> Json { let config = OidcConfig { jwks_uri: "https://auth.xeen.dev/oauth2/keys".to_string(), }; Json(config) } impl<'a, 'r> FromRequest<'a, 'r> for User { type Error = (); fn from_request(request: &'a Request<'r>) -> Outcome { let mut user_id = match request.cookies().get_private("user_id") { Some(cookie) => cookie.value().to_string(), None => return Outcome::Forward(()), }; let conn = request.guard::().unwrap(); match User::get_with_id(&conn, user_id) { Ok(user) => Outcome::Success(user), _ => Outcome::Forward(()), } } } #[derive(Serialize)] struct HelloContext { username: String, } #[get("/")] fn hello(user: User) -> Template { println!("User: {:?}", &user); let context = HelloContext { username: user.username, }; Template::render("hello", &context) } fn routes() -> Vec { routes![ hello, oidc_config, get_keys, login, login_form, ] } #[database("xeenauth")] struct AuthDb(diesel::PgConnection); fn main() { env_logger::init(); rocket::ignite() .attach(AuthDb::fairing()) .attach(Template::fairing()) .mount("/", routes()) .launch(); }