248 lines
7.0 KiB
Rust
248 lines
7.0 KiB
Rust
#![deny(warnings)]
|
|
extern crate biscuit;
|
|
extern crate base64;
|
|
extern crate hyper;
|
|
extern crate ldap3;
|
|
extern crate tokio;
|
|
|
|
#[macro_use] extern crate log;
|
|
|
|
extern crate serde_derive;
|
|
|
|
use std::env;
|
|
use std::fs;
|
|
use std::path::Path;
|
|
use std::str::{
|
|
FromStr,
|
|
from_utf8,
|
|
};
|
|
use std::thread;
|
|
|
|
use hyper::rt::{Future};
|
|
use hyper::{Body, Request, Response, Server, StatusCode};
|
|
use hyper::header::{AUTHORIZATION};
|
|
use hyper_router::{Route, RouterBuilder, RouterService};
|
|
use base64::decode;
|
|
|
|
use biscuit::{
|
|
Empty,
|
|
};
|
|
use biscuit::jwa::{
|
|
SignatureAlgorithm,
|
|
Algorithm,
|
|
};
|
|
use biscuit::jwk::{
|
|
RSAKeyParameters,
|
|
CommonParameters,
|
|
AlgorithmParameters,
|
|
JWK,
|
|
JWKSet,
|
|
};
|
|
use num::BigUint;
|
|
use openssl::rsa::Rsa;
|
|
|
|
use ldap3::{ LdapConn, Scope, SearchEntry };
|
|
|
|
/// Credentials extracted from an HTTP Basic `Authorization` header.
struct BasicAuthentication {
    /// Text before the first `:` of the decoded credentials.
    pub username: String,
    /// Text after the first `:` of the decoded credentials (may itself contain `:`).
    pub password: String,
}

/// `Debug` is written by hand so that formatting this struct (e.g. in a log line)
/// never reveals the password.
impl std::fmt::Debug for BasicAuthentication {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BasicAuthentication")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
|
|
|
|
/// Reasons an authentication attempt can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthError {
    /// The decoded credentials could not be interpreted (e.g. not valid UTF-8).
    Parse,
    /// The credentials could not be base64-decoded.
    Decode,
    /// The LDAP server did not accept the bind.
    LdapBind,
    /// The LDAP server address could not be obtained from the configuration.
    LdapConfig,
    /// The connection to the LDAP server could not be established.
    LdapConnection,
    /// Searching the directory for the user's attributes failed.
    LdapSearch,
}

/// Human-readable descriptions, so `AuthError` can be logged or boxed as a
/// `std::error::Error`.
impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            AuthError::Parse => "malformed credentials",
            AuthError::Decode => "credentials are not valid base64",
            AuthError::LdapBind => "LDAP bind failed",
            AuthError::LdapConfig => "LDAP server is not configured",
            AuthError::LdapConnection => "could not connect to the LDAP server",
            AuthError::LdapSearch => "LDAP search failed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for AuthError {}
|
|
|
|
impl FromStr for BasicAuthentication {
|
|
type Err = AuthError;
|
|
fn from_str(s: &str) -> Result<BasicAuthentication, AuthError> {
|
|
match decode(s) {
|
|
Ok(bytes) => match from_utf8(&bytes) {
|
|
Ok(text) => {
|
|
let mut pair = text.splitn(2, ":");
|
|
Ok(BasicAuthentication {
|
|
username: pair.next().unwrap().to_string(),
|
|
password: pair.next().unwrap().to_string(),
|
|
})
|
|
},
|
|
Err(_) => Err(AuthError::Parse)
|
|
},
|
|
Err(_) => Err(AuthError::Decode)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Directory information for a user who successfully bound to LDAP.
#[derive(Debug)]
struct LdapUser {
    /// Distinguished name that was used for the bind and the search base.
    pub dn: String,
    /// Every value of the entry's `mail` attribute; empty when the attribute is absent.
    pub mail: Vec<String>,
    /// Every value of the entry's `enabledService` attribute; empty when absent.
    pub services: Vec<String>,
}
|
|
|
|
fn auth_user(auth: &BasicAuthentication) -> Result<LdapUser, AuthError> {
|
|
let ldap_server_addr = match env::var("LDAP_SERVER_ADDR") {
|
|
Ok(addr) => addr,
|
|
_ => return Err(AuthError::LdapConfig),
|
|
};
|
|
let ldap = match LdapConn::new(&ldap_server_addr) {
|
|
Ok(conn) => conn,
|
|
Err(_err) => return Err(AuthError::LdapConnection),
|
|
};
|
|
|
|
let base = format!("uid={},ou=people,dc=xeentech,dc=com", auth.username);
|
|
match ldap.simple_bind(&base, &auth.password).unwrap().success() {
|
|
Ok(_ldap) => println!("Connected and authenticated"),
|
|
Err(_err) => return Err(AuthError::LdapBind),
|
|
};
|
|
|
|
let filter = format!("(uid={})", auth.username);
|
|
let s = match ldap.search(&base, Scope::Subtree, &filter, vec!["mail", "enabledService"]) {
|
|
Ok(result) => {
|
|
let (rs, _) = result.success().unwrap();
|
|
rs
|
|
},
|
|
Err(_err) => return Err(AuthError::LdapSearch),
|
|
};
|
|
|
|
// Grab the first, if any, result and discard the rest
|
|
let se = SearchEntry::construct(s.first().unwrap().to_owned());
|
|
let services = match se.attrs.get("enabledService") {
|
|
Some(services) => services.to_vec(),
|
|
None => [].to_vec(),
|
|
};
|
|
let mail = match se.attrs.get("mail") {
|
|
Some(mail) => mail.to_vec(),
|
|
None => [].to_vec(),
|
|
};
|
|
|
|
Ok(LdapUser {
|
|
dn: base,
|
|
mail: mail,
|
|
services: services,
|
|
})
|
|
}
|
|
|
|
fn auth_handler(req: Request<Body>) -> Response<Body> {
|
|
let header = match req.headers().get(AUTHORIZATION) {
|
|
Some(auth_value) => auth_value.to_str().unwrap(),
|
|
None => return Response::builder()
|
|
.status(StatusCode::UNAUTHORIZED)
|
|
.body(Body::from("Authentication header missing"))
|
|
.unwrap(),
|
|
};
|
|
let (auth_type, credentials) = {
|
|
let mut split = header.split_ascii_whitespace();
|
|
let auth_type = split.next().unwrap();
|
|
let credentials = split.next().unwrap();
|
|
(auth_type, credentials)
|
|
};
|
|
|
|
if auth_type != "Basic" {
|
|
return Response::builder()
|
|
.status(StatusCode::UNAUTHORIZED)
|
|
.body(Body::from("Basic Authentication was expected"))
|
|
.unwrap();
|
|
}
|
|
let auth = BasicAuthentication::from_str(credentials).unwrap();
|
|
let worker = thread::spawn(move || {
|
|
let user = auth_user(&auth);
|
|
user
|
|
});
|
|
let user = match worker.join().unwrap() {
|
|
Ok(ldap_user) => ldap_user,
|
|
Err(AuthError::LdapBind) => {
|
|
return Response::builder()
|
|
.status(StatusCode::UNAUTHORIZED)
|
|
.body(Body::from("LDAP bind failed"))
|
|
.unwrap();
|
|
},
|
|
_ => {
|
|
return Response::builder()
|
|
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
|
.body(Body::from("Something is broken"))
|
|
.unwrap();
|
|
}
|
|
};
|
|
|
|
Response::new(Body::from(format!("BasicAuthentication {:?}", user)))
|
|
}
|
|
|
|
fn jwk_from_pem(file_path: &Path) -> Result<JWK<Empty>, Box<dyn std::error::Error + 'static>> {
|
|
let key_bytes = fs::read(file_path)?;
|
|
let rsa = Rsa::private_key_from_pem(key_bytes.as_slice())?;
|
|
Ok(JWK {
|
|
common: CommonParameters {
|
|
algorithm: Some(Algorithm::Signature(SignatureAlgorithm::RS256)),
|
|
key_id: Some(file_path.file_name().unwrap().to_str().unwrap().to_string()),
|
|
..Default::default()
|
|
},
|
|
algorithm: AlgorithmParameters::RSA(RSAKeyParameters {
|
|
n: BigUint::from_bytes_be(&rsa.n().to_vec()),
|
|
e: BigUint::from_bytes_be(&rsa.e().to_vec()),
|
|
..Default::default()
|
|
}),
|
|
additional: Default::default(),
|
|
})
|
|
}
|
|
|
|
fn get_keys(_req: Request<Body>) -> Response<Body> {
|
|
let jwks: Vec<JWK<Empty>> = fs::read_dir("./").unwrap()
|
|
.filter_map(|dir_entry| {
|
|
let path = dir_entry.unwrap().path();
|
|
let ext = match path.extension() {
|
|
Some(ext) => ext.to_str().unwrap().to_owned(),
|
|
None => return None,
|
|
};
|
|
match ext.as_ref() {
|
|
"pem" => match jwk_from_pem(path.as_path()) {
|
|
Ok(jwk) => Some(jwk),
|
|
_ => None,
|
|
},
|
|
_ => None,
|
|
}
|
|
})
|
|
.collect();
|
|
let jwks = JWKSet { keys: jwks };
|
|
let jwks_json = serde_json::to_string(&jwks).unwrap();
|
|
Response::builder()
|
|
.status(StatusCode::OK)
|
|
.header("Content-Type", "application/json")
|
|
.body(Body::from(jwks_json))
|
|
.unwrap()
|
|
}
|
|
|
|
fn hello(_req: Request<Body>) -> Response<Body> {
|
|
Response::new(Body::from("Hi!"))
|
|
}
|
|
|
|
fn router_service() -> Result<RouterService, std::io::Error> {
|
|
let router = RouterBuilder::new()
|
|
.add(Route::get("/").using(hello))
|
|
.add(Route::post("/auth").using(auth_handler))
|
|
.add(Route::get("/oauth2/keys").using(get_keys))
|
|
.build();
|
|
Ok(RouterService::new(router))
|
|
}
|
|
|
|
fn main() {
|
|
env_logger::init();
|
|
|
|
let addr_str = match env::var("LISTEN_ADDR") {
|
|
Ok(addr) => addr,
|
|
_ => "0.0.0.0:3000".to_string(),
|
|
};
|
|
let addr = addr_str.parse().expect("Bad Address");
|
|
let server = Server::bind(&addr)
|
|
.serve(router_service)
|
|
.map_err(|e| eprintln!("server error: {}", e));
|
|
|
|
info!("Listening on http://{}", addr);
|
|
tokio::run(server);
|
|
}
|