#![deny(warnings)] extern crate biscuit; extern crate base64; extern crate hyper; extern crate ldap3; extern crate tokio; #[macro_use] extern crate log; extern crate serde_derive; use std::env; use std::fs; use std::path::Path; use std::str::{ FromStr, from_utf8, }; use std::thread; use hyper::rt::{Future}; use hyper::{Body, Request, Response, Server, StatusCode}; use hyper::header::{AUTHORIZATION}; use hyper_router::{Route, RouterBuilder, RouterService}; use base64::decode; use biscuit::{ Empty, }; use biscuit::jwa::{ SignatureAlgorithm, Algorithm, }; use biscuit::jwk::{ RSAKeyParameters, CommonParameters, AlgorithmParameters, JWK, JWKSet, }; use num::BigUint; use openssl::rsa::Rsa; use ldap3::{ LdapConn, Scope, SearchEntry }; #[derive(Debug)] struct BasicAuthentication { pub username: String, pub password: String, } #[derive(Debug)] pub enum AuthError { Parse, Decode, LdapBind, LdapConfig, LdapConnection, LdapSearch, } impl FromStr for BasicAuthentication { type Err = AuthError; fn from_str(s: &str) -> Result { match decode(s) { Ok(bytes) => match from_utf8(&bytes) { Ok(text) => { let mut pair = text.splitn(2, ":"); Ok(BasicAuthentication { username: pair.next().unwrap().to_string(), password: pair.next().unwrap().to_string(), }) }, Err(_) => Err(AuthError::Parse) }, Err(_) => Err(AuthError::Decode) } } } #[derive(Debug)] struct LdapUser { pub dn: String, pub mail: Vec, pub services: Vec, } fn auth_user(auth: &BasicAuthentication) -> Result { let ldap_server_addr = match env::var("LDAP_SERVER_ADDR") { Ok(addr) => addr, _ => return Err(AuthError::LdapConfig), }; let ldap = match LdapConn::new(&ldap_server_addr) { Ok(conn) => conn, Err(_err) => return Err(AuthError::LdapConnection), }; let base = format!("uid={},ou=people,dc=xeentech,dc=com", auth.username); match ldap.simple_bind(&base, &auth.password).unwrap().success() { Ok(_ldap) => println!("Connected and authenticated"), Err(_err) => return Err(AuthError::LdapBind), }; let filter = format!("(uid={})", auth.username); let s = match ldap.search(&base, Scope::Subtree, &filter, vec!["mail", "enabledService"]) { Ok(result) => { let (rs, _) = result.success().unwrap(); rs }, Err(_err) => return Err(AuthError::LdapSearch), }; // Grab the first, if any, result and discard the rest let se = SearchEntry::construct(s.first().unwrap().to_owned()); let services = match se.attrs.get("enabledService") { Some(services) => services.to_vec(), None => [].to_vec(), }; let mail = match se.attrs.get("mail") { Some(mail) => mail.to_vec(), None => [].to_vec(), }; Ok(LdapUser { dn: base, mail: mail, services: services, }) } fn auth_handler(req: Request) -> Response { let header = match req.headers().get(AUTHORIZATION) { Some(auth_value) => auth_value.to_str().unwrap(), None => return Response::builder() .status(StatusCode::UNAUTHORIZED) .body(Body::from("Authentication header missing")) .unwrap(), }; let (auth_type, credentials) = { let mut split = header.split_ascii_whitespace(); let auth_type = split.next().unwrap(); let credentials = split.next().unwrap(); (auth_type, credentials) }; if auth_type != "Basic" { return Response::builder() .status(StatusCode::UNAUTHORIZED) .body(Body::from("Basic Authentication was expected")) .unwrap(); } let auth = BasicAuthentication::from_str(credentials).unwrap(); let worker = thread::spawn(move || { let user = auth_user(&auth); user }); let user = match worker.join().unwrap() { Ok(ldap_user) => ldap_user, Err(AuthError::LdapBind) => { return Response::builder() .status(StatusCode::UNAUTHORIZED) .body(Body::from("LDAP bind failed")) .unwrap(); }, _ => { return Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(Body::from("Something is broken")) .unwrap(); } }; Response::new(Body::from(format!("BasicAuthentication {:?}", user))) } fn jwk_from_pem(file_path: &Path) -> Result, Box> { let key_bytes = fs::read(file_path)?; let rsa = Rsa::private_key_from_pem(key_bytes.as_slice())?; Ok(JWK { common: CommonParameters { algorithm: Some(Algorithm::Signature(SignatureAlgorithm::RS256)), key_id: Some(file_path.file_name().unwrap().to_str().unwrap().to_string()), ..Default::default() }, algorithm: AlgorithmParameters::RSA(RSAKeyParameters { n: BigUint::from_bytes_be(&rsa.n().to_vec()), e: BigUint::from_bytes_be(&rsa.e().to_vec()), ..Default::default() }), additional: Default::default(), }) } fn get_keys(_req: Request) -> Response { let jwks: Vec> = fs::read_dir("./").unwrap() .filter_map(|dir_entry| { let path = dir_entry.unwrap().path(); let ext = match path.extension() { Some(ext) => ext.to_str().unwrap().to_owned(), None => return None, }; match ext.as_ref() { "pem" => match jwk_from_pem(path.as_path()) { Ok(jwk) => Some(jwk), _ => None, }, _ => None, } }) .collect(); let jwks = JWKSet { keys: jwks }; let jwks_json = serde_json::to_string(&jwks).unwrap(); Response::builder() .status(StatusCode::OK) .header("Content-Type", "application/json") .body(Body::from(jwks_json)) .unwrap() } fn hello(_req: Request) -> Response { Response::new(Body::from("Hi!")) } fn router_service() -> Result { let router = RouterBuilder::new() .add(Route::get("/").using(hello)) .add(Route::post("/auth").using(auth_handler)) .add(Route::get("/oauth2/keys").using(get_keys)) .build(); Ok(RouterService::new(router)) } fn main() { env_logger::init(); let addr_str = match env::var("LISTEN_ADDR") { Ok(addr) => addr, _ => "0.0.0.0:3000".to_string(), }; let addr = addr_str.parse().expect("Bad Address"); let server = Server::bind(&addr) .serve(router_service) .map_err(|e| eprintln!("server error: {}", e)); info!("Listening on http://{}", addr); tokio::run(server); }