auth-server/src/main.rs
Alex Wright e6f9148865 Use biscuit's JWK structures to format /get_keys.
New method to load a .pem and a loop over current dir for any .pem
2020-02-23 19:42:32 +01:00

259 lines
7.3 KiB
Rust

#![deny(warnings)]
extern crate biscuit;
extern crate base64;
extern crate hyper;
extern crate ldap3;
extern crate tokio;
#[macro_use]
extern crate serde_derive;
use std::env;
use std::io;
use std::fs;
use std::fs::File;
use std::path::Path;
use std::str::{
FromStr,
from_utf8,
};
use std::thread;
use hyper::rt::{Future};
use hyper::{Body, Request, Response, Server, StatusCode};
use hyper::header::{AUTHORIZATION};
use hyper_router::{Route, RouterBuilder, RouterService};
use base64::decode;
use ring::signature::RsaKeyPair;
use biscuit::{
ClaimsSet,
Empty,
JWT,
RegisteredClaims,
SingleOrMultiple,
};
use biscuit::jwa::{
SignatureAlgorithm,
Algorithm,
};
use biscuit::jwk::{
RSAKeyParameters,
CommonParameters,
AlgorithmParameters,
JWK,
JWKSet,
};
use biscuit::jws::{
Secret,
RegisteredHeader,
};
use num::BigUint;
use openssl::bn::BigNum;
use openssl::rsa::Rsa;
use openssl::rsa::RsaPrivateKeyBuilder;
use ldap3::{ LdapConn, Scope, SearchEntry };
/// Credentials decoded from an HTTP `Authorization: Basic ...` header.
///
/// `Debug` is implemented by hand (instead of derived) so the password is
/// never written out if this value ends up in a log line or panic message.
struct BasicAuthentication {
    pub username: String,
    pub password: String,
}

impl std::fmt::Debug for BasicAuthentication {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BasicAuthentication")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
/// Failure modes of decoding Basic credentials and authenticating them
/// against LDAP.
#[derive(Debug)]
pub enum AuthError {
    /// The decoded credentials could not be interpreted (e.g. not valid UTF-8).
    Parse,
    /// The credential string was not valid base64.
    Decode,
    /// The supplied credentials were not accepted for an LDAP bind.
    LdapBind,
    /// LDAP configuration is missing (`LDAP_SERVER_ADDR` unset or not Unicode).
    LdapConfig,
    /// A connection to the LDAP server could not be established.
    LdapConnection,
    /// Looking up the user's attributes in LDAP failed.
    LdapSearch,
}

impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            AuthError::Parse => "malformed credentials",
            AuthError::Decode => "credentials are not valid base64",
            AuthError::LdapBind => "LDAP bind failed",
            AuthError::LdapConfig => "LDAP server address is not configured",
            AuthError::LdapConnection => "could not connect to the LDAP server",
            AuthError::LdapSearch => "LDAP search failed",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for AuthError {}
impl FromStr for BasicAuthentication {
type Err = AuthError;
fn from_str(s: &str) -> Result<BasicAuthentication, AuthError> {
match decode(s) {
Ok(bytes) => match from_utf8(&bytes) {
Ok(text) => {
let mut pair = text.splitn(2, ":");
Ok(BasicAuthentication {
username: pair.next().unwrap().to_string(),
password: pair.next().unwrap().to_string(),
})
},
Err(_) => Err(AuthError::Parse)
},
Err(_) => Err(AuthError::Decode)
}
}
}
/// Directory data for a user whose LDAP simple bind succeeded.
///
/// NOTE: field order is observable. The derived `Debug` rendering of this
/// struct is returned in the `/auth` response body, so reordering fields
/// changes that output.
#[derive(Debug)]
struct LdapUser {
/// DN used for the bind: `uid=<username>,ou=people,dc=xeentech,dc=com`.
pub dn: String,
/// Values of the entry's `mail` attribute (empty if the attribute is absent).
pub mail: Vec<String>,
/// Values of the entry's `enabledService` attribute (empty if absent).
pub services: Vec<String>,
}
/// Returns `true` if `username` is non-empty and contains no character with
/// special meaning in an LDAP DN (RFC 4514) or search filter (RFC 4515). Such
/// a username can be interpolated into both without escaping.
fn is_safe_username(username: &str) -> bool {
    const SPECIAL: &[char] = &[',', '+', '"', '\\', '<', '>', ';', '=', '#', '*', '(', ')'];
    !username.is_empty()
        && !username.starts_with(' ')
        && !username.ends_with(' ')
        && !username
            .chars()
            .any(|c| c.is_control() || SPECIAL.contains(&c))
}

/// Authenticates `auth` against the LDAP server named by `LDAP_SERVER_ADDR`.
/// On success, returns the user's `mail` and `enabledService` attributes.
///
/// The user's DN is `uid=<username>,ou=people,dc=xeentech,dc=com`. A
/// successful simple bind with that DN and the supplied password is what
/// authenticates the user.
///
/// # Errors
/// * [`AuthError::LdapBind`] if the credentials are not accepted. This
///   includes an empty password or a username containing DN/filter
///   metacharacters, both of which are refused before contacting the server.
/// * [`AuthError::LdapConfig`] if `LDAP_SERVER_ADDR` is unset or not Unicode.
/// * [`AuthError::LdapConnection`] if the server cannot be reached, or the
///   bind request fails at the transport level.
/// * [`AuthError::LdapSearch`] if the attribute lookup fails or returns no entry.
fn auth_user(auth: &BasicAuthentication) -> Result<LdapUser, AuthError> {
    // With an empty password, a simple bind becomes an "unauthenticated bind"
    // (RFC 4513 §5.1.2), which many servers report as success. Without this
    // check, any known username plus an empty password would authenticate.
    if auth.password.is_empty() || !is_safe_username(&auth.username) {
        return Err(AuthError::LdapBind);
    }
    let ldap_server_addr = env::var("LDAP_SERVER_ADDR").map_err(|_| AuthError::LdapConfig)?;
    let ldap = LdapConn::new(&ldap_server_addr).map_err(|_| AuthError::LdapConnection)?;

    let base = format!("uid={},ou=people,dc=xeentech,dc=com", auth.username);
    ldap.simple_bind(&base, &auth.password)
        .map_err(|_| AuthError::LdapConnection)?
        .success()
        .map_err(|_| AuthError::LdapBind)?;
    println!("Connected and authenticated");

    let filter = format!("(uid={})", auth.username);
    let search = ldap
        .search(&base, Scope::Subtree, &filter, vec!["mail", "enabledService"])
        .map_err(|_| AuthError::LdapSearch)?;
    let (entries, _) = search.success().map_err(|_| AuthError::LdapSearch)?;

    // Use the first result, if any, and discard the rest.
    let entry = entries.into_iter().next().ok_or(AuthError::LdapSearch)?;
    let mut attrs = SearchEntry::construct(entry).attrs;
    let services = attrs.remove("enabledService").unwrap_or_default();
    let mail = attrs.remove("mail").unwrap_or_default();

    Ok(LdapUser {
        dn: base,
        mail,
        services,
    })
}
fn auth_handler(req: Request<Body>) -> Response<Body> {
let header = match req.headers().get(AUTHORIZATION) {
Some(auth_value) => auth_value.to_str().unwrap(),
None => return Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::from("Authentication header missing"))
.unwrap(),
};
let (auth_type, credentials) = {
let mut split = header.split_ascii_whitespace();
let auth_type = split.next().unwrap();
let credentials = split.next().unwrap();
(auth_type, credentials)
};
if auth_type != "Basic" {
return Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::from("Basic Authentication was expected"))
.unwrap();
}
let auth = BasicAuthentication::from_str(credentials).unwrap();
let worker = thread::spawn(move || {
let user = auth_user(&auth);
user
});
let user = match worker.join().unwrap() {
Ok(ldap_user) => ldap_user,
Err(AuthError::LdapBind) => {
return Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::from("LDAP bind failed"))
.unwrap();
},
_ => {
return Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from("Something is broken"))
.unwrap();
}
};
Response::new(Body::from(format!("BasicAuthentication {:?}", user)))
}
/// Loads the PEM-encoded RSA private key at `file_path` and returns its public
/// half (`n`, `e`) as an RS256 JSON Web Key.
///
/// The key id (`kid`) is the file's name, e.g. `signing.pem`. No private key
/// material is copied into the returned JWK.
///
/// # Errors
/// * Any I/O error from reading the file.
/// * `InvalidData` if the file is not a PEM-encoded RSA private key
///   (previously a panic).
/// * `InvalidInput` if `file_path` has no UTF-8 file name to use as the key id
///   (previously a panic).
fn jwk_from_pem(file_path: &Path) -> Result<JWK<Empty>, io::Error> {
    let key_bytes = fs::read(file_path)?;
    let rsa = Rsa::private_key_from_pem(&key_bytes).map_err(|err| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("{} is not a PEM RSA private key: {}", file_path.display(), err),
        )
    })?;
    let key_id = file_path
        .file_name()
        .and_then(|name| name.to_str())
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("{} has no UTF-8 file name to use as key id", file_path.display()),
            )
        })?
        .to_string();
    Ok(JWK {
        common: CommonParameters {
            algorithm: Some(Algorithm::Signature(SignatureAlgorithm::RS256)),
            key_id: Some(key_id),
            ..Default::default()
        },
        algorithm: AlgorithmParameters::RSA(RSAKeyParameters {
            n: BigUint::from_bytes_be(&rsa.n().to_vec()),
            e: BigUint::from_bytes_be(&rsa.e().to_vec()),
            ..Default::default()
        }),
        additional: Default::default(),
    })
}
/// `GET /oauth2/certs`: publishes every `*.pem` RSA key in the current working
/// directory as a JSON Web Key Set.
///
/// Only public key parameters are exposed (see `jwk_from_pem`). A `.pem` file
/// that cannot be read or parsed (e.g. a certificate or public key) is skipped
/// and reported on stderr; it no longer takes down the whole endpoint. If the
/// directory cannot be listed or the set cannot be serialized, a `500` is
/// returned.
fn get_keys(_req: Request<Body>) -> Response<Body> {
    fn server_error() -> Response<Body> {
        Response::builder()
            .status(StatusCode::INTERNAL_SERVER_ERROR)
            .body(Body::from("Something is broken"))
            .expect("a valid StatusCode and static body always build a response")
    }

    let entries = match fs::read_dir("./") {
        Ok(entries) => entries,
        Err(err) => {
            eprintln!("cannot list key directory: {}", err);
            return server_error();
        }
    };

    let keys: Vec<JWK<Empty>> = entries
        .filter_map(|entry| entry.ok().map(|entry| entry.path()))
        .filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("pem"))
        .filter_map(|path| match jwk_from_pem(&path) {
            Ok(jwk) => Some(jwk),
            Err(err) => {
                eprintln!("skipping {}: {}", path.display(), err);
                None
            }
        })
        .collect();

    let jwks_json = match serde_json::to_string(&JWKSet { keys }) {
        Ok(json) => json,
        Err(err) => {
            eprintln!("cannot serialize JWK set: {}", err);
            return server_error();
        }
    };
    Response::builder()
        .status(StatusCode::OK)
        .header("Content-Type", "application/json")
        .body(Body::from(jwks_json))
        .expect("static status and header always build a response")
}
/// `GET /hello`: trivial endpoint that answers every request with the body
/// `Hi!` and the default (`200 OK`) status.
fn hello(_req: Request<Body>) -> Response<Body> {
    let greeting: Body = "Hi!".into();
    Response::new(greeting)
}
/// Builds the router that maps request paths to handlers:
/// `GET /hello`, `POST /auth` and `GET /oauth2/certs`.
///
/// Used as the service factory passed to `Server::serve`; it never returns `Err`.
fn router_service() -> Result<RouterService, std::io::Error> {
    let routes = RouterBuilder::new()
        .add(Route::get("/hello").using(hello))
        .add(Route::post("/auth").using(auth_handler))
        .add(Route::get("/oauth2/certs").using(get_keys));
    Ok(RouterService::new(routes.build()))
}
/// Entry point. Binds the HTTP server to `LISTEN_ADDR` (falling back to
/// `0.0.0.0:3000` when it is unset or not Unicode) and runs it on tokio.
///
/// Panics with "Bad Address" if `LISTEN_ADDR` is not a valid socket address.
fn main() {
    let listen_on = env::var("LISTEN_ADDR").unwrap_or_else(|_| "0.0.0.0:3000".to_string());
    let addr: std::net::SocketAddr = listen_on.parse().expect("Bad Address");
    let server = Server::bind(&addr)
        .serve(router_service)
        .map_err(|err| eprintln!("server error: {}", err));
    println!("Listening on http://{}", addr);
    tokio::run(server);
}