Skip to content

Commit cb02dae

Browse files
committed
Started on md5 auth.
Left to figure out: Whats the right format to store user:pw in userlist? hashmap errors? actually do hash:x comparison
1 parent dce72ba commit cb02dae

File tree

10 files changed

+212
-9
lines changed

10 files changed

+212
-9
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
/target
22
*.deb
3+
.idea/*
4+
tests/ruby/.bundle/*
5+
tests/ruby/vendor/*

Cargo.lock

Lines changed: 24 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ sha-1 = "0.10"
1717
toml = "0.5"
1818
serde = "1"
1919
serde_derive = "1"
20+
serde_json = "1"
2021
regex = "1"
2122
num_cpus = "1"
2223
once_cell = "1"

pgcat.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ password = "sharding_user"
4848

4949
# [ host, port, role ]
5050
servers = [
51-
[ "127.0.0.1", 5432, "primary" ],
52-
[ "localhost", 5432, "replica" ],
51+
["127.0.0.1", 5432, "primary"],
52+
["localhost", 5432, "replica"],
5353
# [ "127.0.1.1", 5432, "replica" ],
5454
]
5555
# Database name (e.g. "postgres")
@@ -58,17 +58,17 @@ database = "shard0"
5858
[shards.1]
5959
# [ host, port, role ]
6060
servers = [
61-
[ "127.0.0.1", 5432, "primary" ],
62-
[ "localhost", 5432, "replica" ],
61+
["127.0.0.1", 5432, "primary"],
62+
["localhost", 5432, "replica"],
6363
# [ "127.0.1.1", 5432, "replica" ],
6464
]
6565
database = "shard1"
6666

6767
[shards.2]
6868
# [ host, port, role ]
6969
servers = [
70-
[ "127.0.0.1", 5432, "primary" ],
71-
[ "localhost", 5432, "replica" ],
70+
["127.0.0.1", 5432, "primary"],
71+
["localhost", 5432, "replica"],
7272
# [ "127.0.1.1", 5432, "replica" ],
7373
]
7474
database = "shard2"

src/client.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,18 @@ impl Client {
104104
// Regular startup message.
105105
PROTOCOL_VERSION_NUMBER => {
106106
debug!("Got StartupMessage");
107-
108-
// TODO: perform actual auth.
109107
let parameters = parse_startup(bytes.clone())?;
108+
let mut user_name: String = String::new();
109+
match parameters.get(&"user") {
110+
Some(&user) => user_name = user,
111+
None => return Err(Error::ClientBadStartup),
112+
}
113+
start_auth(&mut stream, &user_name).await?;
110114

111115
// Generate random backend ID and secret key
112116
let process_id: i32 = rand::random();
113117
let secret_key: i32 = rand::random();
114118

115-
auth_ok(&mut stream).await?;
116119
write_all(&mut stream, server_info).await?;
117120
backend_key_data(&mut stream, process_id, secret_key).await?;
118121
ready_for_query(&mut stream).await?;

src/errors.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,7 @@ pub enum Error {
88
// ServerTimeout,
99
// DirtyServer,
1010
BadConfig,
11+
BadUserList,
1112
AllServersDown,
13+
AuthenticationError
1214
}

src/main.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ use std::sync::{Arc, Mutex};
4848

4949
mod client;
5050
mod config;
51+
mod userlist;
5152
mod constants;
5253
mod errors;
5354
mod messages;
@@ -92,6 +93,15 @@ async fn main() {
9293
}
9394
};
9495

96+
// Prepare user list
97+
match userlist::parse("userlist.json").await {
98+
Ok(_) => (),
99+
Err(err) => {
100+
error!("Userlist parse error: {:?}", err);
101+
return;
102+
}
103+
};
104+
95105
let config = get_config();
96106

97107
let addr = format!("{}:{}", config.general.host, config.general.port);

src/messages.rs

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,115 @@
11
/// Helper functions to send one-off protocol messages
22
/// and handle TcpStream (TCP socket).
3+
4+
35
use bytes::{Buf, BufMut, BytesMut};
46
use md5::{Digest, Md5};
57
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
68
use tokio::net::{
79
tcp::{OwnedReadHalf, OwnedWriteHalf},
810
TcpStream,
911
};
12+
use log::{error};
1013

1114
use std::collections::HashMap;
1215

16+
use md5;
17+
use rand::Rng;
18+
1319
use crate::errors::Error;
20+
use crate::userlist::get_user_list;
21+
22+
/**
23+
1. Generate salt (4 bytes of random data)
24+
md5(concat(md5(concat(password, username)), random-salt)))
25+
2. Send md5 auth request
26+
3. recieve PasswordMessage with salt.
27+
4. refactor md5_password function to be reusable
28+
5. check username hash combo against file
29+
6. AuthenticationOk or ErrorResponse
30+
**/
31+
pub async fn start_auth(stream: &mut TcpStream, user_name: &String) -> Result<(), Error> {
32+
let mut rng = rand::thread_rng();
33+
34+
//Generate random 4 byte salt
35+
let salt = rng.gen::<u32>();
36+
37+
// Send AuthenticationMD5Password request
38+
send_md5_request(stream, salt).await?;
39+
40+
let code = match stream.read_u8().await {
41+
Ok(code) => code as char,
42+
Err(_) => return Err(Error::AuthenticationError),
43+
};
44+
45+
match code {
46+
// Password response
47+
'p' => {
48+
fetch_password_and_authenticate(stream, &user_name, &salt).await?;
49+
Ok(auth_ok(stream).await?)
50+
}
51+
_ => {
52+
error!("Unknown code: {}", code);
53+
return Err(Error::AuthenticationError);
54+
}
55+
}
56+
}
57+
58+
pub async fn send_md5_request(stream: &mut TcpStream, salt: u32) -> Result<(), Error> {
59+
let mut authentication_md5password = BytesMut::with_capacity(12);
60+
authentication_md5password.put_u8(b'R');
61+
authentication_md5password.put_i32(12);
62+
authentication_md5password.put_i32(5);
63+
authentication_md5password.put_u32(salt);
64+
65+
// Send AuthenticationMD5Password request
66+
Ok(write_all(stream, authentication_md5password).await?)
67+
}
68+
69+
pub async fn fetch_password_and_authenticate(stream: &mut TcpStream, user_name: &String, salt: &u32) -> Result<(), Error> {
70+
/**
71+
1. How do I store the lists of users and paswords? clear text or hash?? wtf
72+
2. Add auth to tests
73+
**/
74+
75+
let len = match stream.read_i32().await {
76+
Ok(len) => len,
77+
Err(_) => return Err(Error::AuthenticationError),
78+
};
79+
80+
// Read whatever is left.
81+
let mut password_hash = vec![0u8; len as usize - 4];
82+
83+
match stream.read_exact(&mut password_hash).await {
84+
Ok(_) => (),
85+
Err(_) => return Err(Error::AuthenticationError),
86+
};
87+
88+
let user_list = get_user_list();
89+
let mut password: String = String::new();
90+
match user_list.get(&user_name) {
91+
Some(&p) => password = p,
92+
None => return Err(Error::AuthenticationError),
93+
}
94+
95+
let mut md5 = Md5::new();
96+
97+
// concat('md5', md5(concat(md5(concat(password, username)), random-salt)))
98+
// First pass
99+
md5.update(&password.as_bytes());
100+
md5.update(&user_name.as_bytes());
101+
let output = md5.finalize_reset();
102+
// Second pass
103+
md5.update(format!("{:x}", output));
104+
md5.update(salt.to_be_bytes().to_vec());
105+
106+
107+
let password_string: String = String::from_utf8(password_hash).expect("Could not get password hash");
108+
match format!("md5{:x}", md5.finalize()) == password_string {
109+
true => Ok(()),
110+
_ => Err(Error::AuthenticationError)
111+
}
112+
}
14113

15114
/// Tell the client that authentication handshake completed successfully.
16115
pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {

src/userlist.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"sven": "clear_text_password",
3+
"sharding_user": "sharding_user"
4+
}

src/userlist.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
use arc_swap::{ArcSwap, Guard};
2+
use log::{error};
3+
use once_cell::sync::Lazy;
4+
use tokio::fs::File;
5+
use tokio::io::AsyncReadExt;
6+
7+
use std::collections::{HashMap};
8+
use std::sync::Arc;
9+
10+
use crate::errors::Error;
11+
12+
pub type UserList = HashMap<String, String>;
13+
static USER_LIST: Lazy<ArcSwap<UserList>> = Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));
14+
15+
pub fn get_user_list() -> Guard<Arc<UserList>> {
16+
USER_LIST.load()
17+
}
18+
19+
/// Parse the user list.
20+
pub async fn parse(path: &str) -> Result<(), Error> {
21+
let mut contents = String::new();
22+
let mut file = match File::open(path).await {
23+
Ok(file) => file,
24+
Err(err) => {
25+
error!("Could not open '{}': {}", path, err.to_string());
26+
return Err(Error::BadConfig);
27+
}
28+
};
29+
30+
match file.read_to_string(&mut contents).await {
31+
Ok(_) => (),
32+
Err(err) => {
33+
error!("Could not read config file: {}", err.to_string());
34+
return Err(Error::BadConfig);
35+
}
36+
};
37+
38+
let map: HashMap<String, String> = serde_json::from_str(&contents).expect("JSON was not well-formatted");
39+
40+
41+
42+
USER_LIST.store(Arc::new(map.clone()));
43+
44+
Ok(())
45+
}
46+
47+
#[cfg(test)]
48+
mod test {
49+
use super::*;
50+
51+
#[tokio::test]
52+
async fn test_config() {
53+
parse("userlist.json").await.unwrap();
54+
assert_eq!(get_user_list()["sven"], "clear_text_password");
55+
assert_eq!(get_user_list()["sharding_user"], "sharding_user");
56+
}
57+
}

0 commit comments

Comments
 (0)