using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
using lib;
using server;

namespace Server;

/// <summary>
/// Message relay server. Each TCP client first sends an RSA-encrypted AES session key, then registers or logs in,
/// and afterwards issues GetMessages / GetUserKey / SendMessage requests. Every request starts with a
/// <see cref="MSG_LEN"/>-byte header: byte 0 = protocol version (must be 0), byte 1 = <see cref="RequestType"/>,
/// byte 2 = rolling request counter, followed by request-specific fields.
/// </summary>
/// <remarks>
/// NOTE(review): every message of a session is AES-CFB encrypted with the same key AND the same IV, so the
/// keystream repeats across messages. Fixing that needs a per-message IV/nonce in the protocol (client change),
/// so it is only flagged here.
/// </remarks>
public class Program {
    // Request header length: 128 bits = 16 bytes.
    const int MSG_LEN = 16;

    // The session secret sent by the client is an AES-256 key followed by a 16-byte IV.
    const int SESSION_KEY_LEN = 32;
    const int SESSION_IV_LEN = 16;

    // Number of ConfirmRegister attempts before the connection is dropped and the client must start over.
    const int REGISTER_TRIES = 5;

    // The GetMessages reply header has room for 7 two-byte lengths (bytes 0..13); byte 15 is the "more" flag.
    const int MAX_MESSAGES_PER_BATCH = 7;

    static readonly Data Data = new();

    static async Task Main() {
        // Generally this key would be static, but since this is not production we generate it on every start and
        // export the public half so clients can load it from file.
        // NOTE(review): 1024-bit RSA is below current recommendations; kept because clients depend on the size.
        RSA key = RSA.Create(1024);
        File.WriteAllText("server_key.pem", key.ExportRSAPublicKeyPem());

        int port = 12345;
        TcpListener server = new(IPAddress.Any, port);
        int connectionCounter = 0;
        Log log = new(-1, null);

        Console.CancelKeyPress += delegate {
            Console.WriteLine("\nEXITING!");
            server.Stop();
        };

        try {
            server.Start();
            while (true) {
                TcpClient client = await server.AcceptTcpClientAsync();
                // Take the id now: the lambda runs later on the thread pool, and capturing connectionCounter
                // itself would observe a value already incremented for later connections.
                int clientId = connectionCounter++;
                _ = Task.Run(async () => {
                    try {
                        await HandleClient(client, clientId, key);
                    } catch (Exception ex) {
                        log.System($"Client crashed: {ex.Message}");
                        log.System(ex.StackTrace ?? "MISSING STACK TRACE");
                    }
                });
            }
        } catch (Exception ex) {
            log.System($"Server error: {ex.Message}");
            log.System("Trace: " + ex.StackTrace);
        } finally {
            server.Stop();
        }
    }

    /// <summary>
    /// Runs one client session: session-key exchange, register/login, then the request loop.
    /// The client is always disposed when this method returns or throws.
    /// </summary>
    /// <param name="client">Accepted connection; owned (and disposed) by this method.</param>
    /// <param name="id">Connection number, used only for logging.</param>
    /// <param name="serverKey">Server RSA key pair used to decrypt the client's session key.</param>
    static async Task HandleClient(TcpClient client, int id, RSA serverKey) {
        Log log = new(id, null);
        log.System("Got a new client");
        try {
            NetworkStream stream = client.GetStream();
            // Shared receive buffer: every request (header + payload) has to fit in it.
            byte[] buffer = new byte[1024];

            using Aes? sk = await EstablishSessionKey(stream, buffer, serverKey, log);
            if (sk == null) {
                return;
            }

            (string Phone, byte Counter)? session = await Authenticate(stream, buffer, sk, log);
            if (session is not { } s) {
                return;
            }

            // Client registered/logged in, do the main messages loop.
            await ServeRequests(client, stream, buffer, sk, s.Phone, s.Counter, log);
        } finally {
            client.Dispose();
        }
    }

    /// <summary>
    /// Reads the RSA-OAEP(SHA-256) encrypted AES key + IV, acknowledges it with a single 0 byte and returns the
    /// session cipher. Returns null (after logging) if the client disconnected or sent a secret of the wrong size.
    /// </summary>
    static async Task<Aes?> EstablishSessionKey(NetworkStream stream, byte[] buffer, RSA serverKey, Log log) {
        int len = await stream.ReadAsync(buffer);
        if (len == 0) {
            log.System("Client disconnected before sending a session key");
            return null;
        }
        log.Encrypted("Key + IV", buffer[..len]);
        byte[] skBytes = serverKey.Decrypt(buffer[..len], RSAEncryptionPadding.OaepSHA256);
        log.Decrypted("Key + IV", skBytes);

        // Reject anything that is not exactly key + IV instead of failing later inside the Aes setters.
        if (skBytes.Length != SESSION_KEY_LEN + SESSION_IV_LEN) {
            log.System($"Invalid session key length: {skBytes.Length}");
            return null;
        }

        Aes sk = Aes.Create();
        sk.Key = skBytes[..SESSION_KEY_LEN];
        sk.IV = skBytes[SESSION_KEY_LEN..];
        await stream.WriteAsync(new byte[] { 0 });
        return sk;
    }

    /// <summary>
    /// Reads the first request (must be Register or Login) and runs the matching flow.
    /// </summary>
    /// <returns>The client's phone number and the next expected request counter, or null if the client failed to
    /// authenticate (the caller then closes the connection).</returns>
    static async Task<(string Phone, byte Counter)?> Authenticate(NetworkStream stream, byte[] buffer, Aes sk, Log log) {
        int len = await stream.ReadAsync(buffer);
        if (len == 0) {
            log.System("Client disconnected before authenticating");
            return null;
        }
        log.Encrypted("Auth message", buffer[..len]);
        byte[] msgDec = sk.DecryptCfb(buffer[..len], sk.IV, PaddingMode.PKCS7);
        log.Decrypted("Auth message", msgDec);
        if (msgDec.Length < MSG_LEN) {
            log.System($"Auth message too short: {msgDec.Length} bytes");
            return null;
        }

        byte[] msg = msgDec[..MSG_LEN];
        log.Message(Request.RequestToString(msg));
        if (msg[0] != 0) {
            log.System("Invalid message version!");
            return null;
        }

        // Allow the counter to start at a random position chosen by the client.
        byte counter = IncrementCounter(msg[2]);
        string phone = Utils.BytesToNumber(msg[3..11]);
        log.Client = phone;

        byte? nextCounter;
        switch ((RequestType)msg[1]) {
            case RequestType.Register:
                nextCounter = await Register(stream, buffer, sk, msg, msgDec, phone, counter, log);
                break;
            case RequestType.Login:
                nextCounter = await Login(stream, buffer, sk, phone, counter, log);
                break;
            default:
                // invalid connection, quit
                log.System("Client didnt register or login as first message");
                return null;
        }

        if (nextCounter is not byte c) {
            return null;
        }
        return (phone, c);
    }

    /// <summary>
    /// Registration flow: imports the client's RSA public key (sent after the header), sends a 6-digit code and
    /// waits for a ConfirmRegister carrying that code plus a signature over it.
    /// </summary>
    /// <returns>The next expected counter on success; null when every attempt failed or the client left.</returns>
    static async Task<byte?> Register(NetworkStream stream, byte[] buffer, Aes sk, byte[] header, byte[] msgDec,
                                      string phone, byte counter, Log log) {
        log.System($"Client wants to register as {phone}");

        int keyLen = BitConverter.ToInt32(header, 11);
        RSA pub = RSA.Create();
        pub.ImportRSAPublicKey(msgDec.AsSpan(MSG_LEN), out int bytesRead);
        if (bytesRead != keyLen) {
            log.System($"Announced key length {keyLen} does not match imported length {bytesRead}");
        }
        log.System($"Imported key is: \n {pub.ExportRSAPublicKeyPem()}\n");

        // Generate the 6 digit code with a CSPRNG: it is a secret the client has to prove it received.
        byte[] code = new byte[6];
        for (int i = 0; i < code.Length; i++) {
            code[i] = (byte)RandomNumberGenerator.GetInt32(10);
        }
        await SendBySecureChannel(stream, code);

        // Wait for the code to come back with a signature; allow a limited number of tries.
        for (int tries = REGISTER_TRIES; tries > 0; tries--) {
            int len = await ReadRequest(stream, buffer);
            if (len == 0) {
                log.System("Client disconnected during registration");
                return null;
            }
            log.Encrypted("ConfirmRegister", buffer[..len]);
            byte[] msg = sk.DecryptCfb(buffer[..MSG_LEN], sk.IV, PaddingMode.None);
            log.Decrypted("ConfirmRegister", msg);
            log.Message(Request.RequestToString(msg));
            byte[] sig = buffer[MSG_LEN..len];
            log.Decrypted("ConfirmRegister (sig)", sig);

            if (msg[0] != 0 || msg[1] != (byte)RequestType.ConfirmRegister || msg[2] != counter) {
                // invalid or unexpected request, someone might be sending dups
                continue;
            }
            counter = IncrementCounter(counter);

            byte[] gottenCode = msg[3..9];
            int expectedSigLen = BitConverter.ToInt32(msg, 9);
            if (expectedSigLen != len - MSG_LEN) {
                log.System($"expected sig len doesnt match read len: {expectedSigLen} / {len - MSG_LEN}");
            }

            if (!CryptographicOperations.FixedTimeEquals(code, gottenCode)) {
                // Perhaps we should only send SIG INVALID (or just FAILED) to hide the reason from code guessers.
                await SendEncrypted(stream, sk, log, "Fail Register", "BAD CODE");
            } else if (pub.VerifyData(gottenCode, sig, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1)) {
                await SendEncrypted(stream, sk, log, "Register Success", "OK");
                Data.Keys[phone] = pub; // save the key
                return counter;
            } else {
                await SendEncrypted(stream, sk, log, "Fail Register", "SIG INVALID");
            }
        }

        // Running out of tries must end the session; falling through would hand an unverified client
        // the authenticated request loop.
        log.System("Client failed to confirm registration, closing connection");
        return null;
    }

    /// <summary>
    /// Login flow: sends a random 16-byte challenge and verifies the client's signature over it with the key stored
    /// at registration.
    /// </summary>
    /// <returns>The next expected counter on success; null on any failure.</returns>
    static async Task<byte?> Login(NetworkStream stream, byte[] buffer, Aes sk, string phone, byte counter, Log log) {
        if (!Data.Keys.TryGetValue(phone, out RSA? clientKey)) {
            log.System($"Client claims to be {phone}, but could not find key in records");
            return null;
        }

        // The challenge must be unpredictable, otherwise a recorded signature could be replayed.
        byte[] challenge = RandomNumberGenerator.GetBytes(16);
        log.Decrypted("Challenge", challenge);
        byte[] encChallenge = sk.EncryptCfb(challenge, sk.IV, PaddingMode.None);
        log.Encrypted("Challenge", encChallenge);
        await stream.WriteAsync(encChallenge);

        int len = await ReadRequest(stream, buffer);
        if (len == 0) {
            log.System("Client disconnected during login");
            return null;
        }
        log.Encrypted("Challenge Response", buffer[..len]);
        byte[] msg = sk.DecryptCfb(buffer[..MSG_LEN], sk.IV, PaddingMode.None);
        log.Decrypted("Challenge Message", msg);
        log.Message(Request.RequestToString(msg));
        if (msg[0] != 0 || msg[2] != counter) {
            log.System($"Invalid counter in login response, quitting");
            return null;
        }
        counter = IncrementCounter(counter);

        byte[] sig = buffer[MSG_LEN..len];
        // The sig wasnt encrypted so this print is somewhat redundant, but it is still useful to see it alone.
        log.Decrypted("Challenge Sig", sig);
        if (!clientKey.VerifyData(challenge, sig, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1)) {
            log.System("Client failed verification, invalid signature");
            await SendEncrypted(stream, sk, log, "Challenge Response", "INVALID SIG");
            return null;
        }

        log.System("Client verification complete");
        await SendEncrypted(stream, sk, log, "Challenge Response", "OK");
        return counter;
    }

    /// <summary>
    /// Main loop for an authenticated client: reads requests and answers GetMessages / GetUserKey / SendMessage until
    /// the client disconnects. Errors are logged and end the session.
    /// </summary>
    static async Task ServeRequests(TcpClient client, NetworkStream stream, byte[] buffer, Aes sk,
                                    string clientPhone, byte counter, Log log) {
        try {
            while (client.Connected) {
                int len = await ReadRequest(stream, buffer);
                if (len == 0) {
                    // A 0-byte read means the peer closed the connection; Connected may still report true.
                    log.System("Client disconnected");
                    break;
                }
                log.Encrypted("Request", buffer[..len]);
                byte[] msg = sk.DecryptCfb(buffer[..MSG_LEN], sk.IV, PaddingMode.None);
                log.Decrypted("Request message", msg);
                log.Message(Request.RequestToString(msg));

                // Verify the version byte and that the counter is the one we expect next.
                if (msg[0] != 0 || msg[2] != counter) {
                    await SendEncrypted(stream, sk, log, "Invalid Request", "INVALID REQUEST");
                    continue;
                }
                counter = IncrementCounter(counter);

                switch ((RequestType)msg[1]) {
                    case RequestType.GetMessages:
                        await SendPendingMessages(stream, sk, clientPhone, log);
                        break;
                    case RequestType.GetUserKey:
                        await SendUserKey(stream, sk, Utils.BytesToNumber(msg[3..11]), log);
                        break;
                    case RequestType.SendMessage:
                        await StoreMessage(stream, buffer, len, msg, log);
                        break;
                    default:
                        await SendEncrypted(stream, sk, log, "Invalid Request", "INVALID REQUEST");
                        break;
                }
            }
        } catch (Exception ex) {
            log.System($"Client failed with error {ex.Message}");
            log.System($"Stack: {ex.StackTrace}");
        }
    }

    /// <summary>
    /// Sends up to <see cref="MAX_MESSAGES_PER_BATCH"/> pending messages. The reply is an encrypted 16-byte header
    /// (big-endian 16-bit length per message in bytes 0..13, byte 15 = 1 when more messages are pending) followed by
    /// the message bodies, which are sent as stored since they are already encrypted by their sender.
    /// </summary>
    static async Task SendPendingMessages(NetworkStream stream, Aes sk, string clientPhone, Log log) {
        byte[] msgsLens = new byte[MSG_LEN];
        // NOTE(review): assumes Data.GetMessages never returns more than the requested count; the header has no
        // room for an 8th length.
        List<byte[]> msgs = Data.GetMessages(clientPhone, MAX_MESSAGES_PER_BATCH) ?? [];
        byte[] msgsBytes = new byte[msgs.Sum(m => m.Length)];

        int offset = 0;
        for (int i = 0; i < msgs.Count; i++) {
            // A full 16-bit length leaves room for larger RSA key sizes without changing the header layout.
            msgsLens[2 * i] = (byte)(msgs[i].Length >> 8);
            msgsLens[(2 * i) + 1] = (byte)msgs[i].Length;
            Array.Copy(msgs[i], 0, msgsBytes, offset, msgs[i].Length);
            offset += msgs[i].Length;
        }
        msgsLens[MSG_LEN - 1] = Data.PeekMessages(clientPhone) ? (byte)1 : (byte)0;

        // Only the lengths need encrypting; the messages themselves are already encrypted.
        log.Decrypted("GetMessages", msgsLens);
        byte[] encLens = sk.EncryptCfb(msgsLens, sk.IV, PaddingMode.None);
        log.Encrypted("GetMessages", encLens);
        byte[] finalPayload = [.. encLens, .. msgsBytes];
        log.Encrypted("GetMessages Final", finalPayload);
        await stream.WriteAsync(finalPayload);
    }

    /// <summary>
    /// Replies with the stored public key of <paramref name="phone"/>: status byte 0 + key, or status byte 1 + an
    /// error text when the user is unknown.
    /// </summary>
    static async Task SendUserKey(NetworkStream stream, Aes sk, string phone, Log log) {
        RSA? key = Data.GetKey(phone);
        byte[] reply;
        if (key != null) {
            reply = [0, .. key.ExportRSAPublicKey()];
        } else {
            reply = [1, .. Encoding.UTF8.GetBytes("USER DOES NOT EXIST")];
        }
        await SendEncrypted(stream, sk, log, "GetUserKey", reply);
    }

    /// <summary>
    /// Stores the body of a SendMessage request for its recipient. The announced body length must fit the receive
    /// buffer; a body split across TCP reads is completed from the stream instead of using stale buffer bytes.
    /// Extra received bytes beyond the announced length are ignored.
    /// </summary>
    static async Task StoreMessage(NetworkStream stream, byte[] buffer, int len, byte[] msg, Log log) {
        string recv = Utils.BytesToNumber(msg[3..11]);
        int msgLen = BitConverter.ToInt32(msg, 11);
        int received = len - MSG_LEN;

        // msgLen is client-controlled: never let it index outside the buffer.
        if (msgLen < 0 || msgLen > buffer.Length - MSG_LEN) {
            log.System($"Rejected message to {recv}: announced length {msgLen} does not fit the receive buffer");
            return;
        }
        if (msgLen != received) {
            log.System($"Got message to {recv} of length {received} but expected {msgLen}");
        }
        if (received < msgLen) {
            await stream.ReadExactlyAsync(buffer.AsMemory(len, msgLen - received));
        }

        byte[] clientMsg = buffer[MSG_LEN..(MSG_LEN + msgLen)];
        log.Encrypted("SendMessage message", clientMsg);
        // simply add the clientMsg to the "Data"
        bool added = Data.AddMessage(recv, clientMsg);
        log.System($"Added message to {recv} of length {msgLen}: {added}");
    }

    /// <summary>
    /// Reads one request into <paramref name="buffer"/> and returns how many bytes it holds (0 = the client closed
    /// the connection). A header split across TCP segments is completed, so a non-zero result is always at least
    /// <see cref="MSG_LEN"/>.
    /// </summary>
    static async Task<int> ReadRequest(NetworkStream stream, byte[] buffer) {
        int len = await stream.ReadAsync(buffer);
        if (len > 0 && len < MSG_LEN) {
            await stream.ReadExactlyAsync(buffer.AsMemory(len, MSG_LEN - len));
            len = MSG_LEN;
        }
        return len;
    }

    /// <summary>Encrypts a UTF-8 text reply with the session cipher (PKCS7), logs both forms and sends it.</summary>
    static Task SendEncrypted(NetworkStream stream, Aes sk, Log log, string label, string text) {
        return SendEncrypted(stream, sk, log, label, Encoding.UTF8.GetBytes(text));
    }

    /// <summary>Encrypts <paramref name="plain"/> with the session cipher (PKCS7), logs both forms and sends it.</summary>
    static async Task SendEncrypted(NetworkStream stream, Aes sk, Log log, string label, byte[] plain) {
        byte[] cipher = sk.EncryptCfb(plain, sk.IV, PaddingMode.PKCS7);
        log.Decrypted(label, plain);
        log.Encrypted(label, cipher);
        await stream.WriteAsync(cipher);
    }

    /// <summary>Advances the one-byte request counter, wrapping from 255 back to 0.</summary>
    static byte IncrementCounter(byte counter) {
        return counter == byte.MaxValue ? (byte)0 : (byte)(counter + 1);
    }

    /// <summary>
    /// Delivers the registration code to the client. Currently it is written unencrypted on the same TCP stream.
    /// NOTE(review): presumably a stand-in for an out-of-band channel (e.g. SMS); confirm before relying on it.
    /// </summary>
    static async Task SendBySecureChannel(NetworkStream stream, byte[] code) {
        await stream.WriteAsync(code);
    }
}