diff --git a/README.md b/README.md index e7bc968..4134b4e 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,12 @@ # Project - TODO: -[ ] Create a skeleton protocol -[ ] implement most of the skeleton - [ ] Create basic TCP server - [ ] Create basic client that connects to the server - [ ] Send ping message from client to server - [ ] Add more items based on skeleton protocol -[ ] Refine protocol using the implementation (and update stuff that got changed in impl) -[ ] Finish implementing the protocol -[ ] Update the protocol file with the latest structs and stuff +[ ] implement SendMessage at the server +[ ] implement Login + [ ] client + [ ] server +[ ] Figure out how to do the messages themselves +[ ] implement sending messages properly +[ ] implement message acks ## Protocol todo: diff --git a/client/Program.cs b/client/Program.cs index 7dbed7d..6c67243 100644 --- a/client/Program.cs +++ b/client/Program.cs @@ -1,16 +1,20 @@ using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; using System.Net.Sockets; +using System.Runtime.CompilerServices; +using System.Security.Cryptography; using System.Text; using System.Threading.Tasks; -using System.Security.Cryptography; -using System.Linq; -using System.IO; using lib; namespace Client; public class Program { + static byte counter = 0; + static async Task Main(string[] args) { string user = args @@ -48,33 +52,38 @@ public class Program } else { - + // attempt to login here } - var inputTask = Task.Run(async () => await HandleUserInput(client, stream)); - var serverInput = Task.Run(async () => await HandleServerInput(client, stream)); - - _ = Task.WaitAny(inputTask, serverInput); + await HandleUserInput(client, stream, sk, privKey); } async static Task RegisterClient(string user, RSA pub, RSA priv, RSA server, Aes sk, NetworkStream stream) { - byte counter = 0; + Console.WriteLine("Attempting to register with public key:"); + Console.WriteLine(pub.ExportRSAPublicKeyPem()); // Generate aes key and send it forward + Console.WriteLine($"Session key: {string.Join(' ', sk.Key)}"); + Console.WriteLine($"Session IV: {string.Join(' ', sk.IV)}"); byte[] skEnc = server.Encrypt([.. sk.Key, .. sk.IV], RSAEncryptionPadding.OaepSHA256); await stream.WriteAsync(skEnc); + // wait for the server to confirm it recieved the keys + await stream.ReadExactlyAsync(new byte[1]); + // Generate the Register msg + Console.WriteLine("Sending rsa public key thing"); byte[] pubBytes = pub.ExportRSAPublicKey(); byte[] data = new byte[12]; Array.Copy(Utils.NumberToBytes(user), data, 8); Array.Copy(BitConverter.GetBytes(pubBytes.Length), 0, data, 8, 4); byte[] msg = Request.CreateRequest(RequestType.Register, ref counter, data); // Encrypt msg and send it - byte[] enc = sk.EncryptCfb(msg, sk.IV, PaddingMode.PKCS7); - byte[] payload = [.. enc, .. pubBytes]; - await stream.WriteAsync(payload); + byte[] payload = [.. msg, .. pubBytes]; + byte[] enc = sk.EncryptCfb(payload, sk.IV, PaddingMode.PKCS7); + Console.WriteLine($"payload length: {enc.Length}"); + await stream.WriteAsync(enc); // get the 6 digit code (from "secure channel", actually an OK message is expected here but the 6 digit code kinda replaces it) byte[] digits = new byte[6]; @@ -88,6 +97,7 @@ public class Program // get the 6 digit code from the user while (true) { + Console.Write("> "); string? code = Console.ReadLine()?.Trim(); if (code == null || code.Take(6).Any(c => !char.IsDigit(c))) { @@ -104,7 +114,7 @@ public class Program // Sign the 6 digit code & Generate ConfirmRegister message byte[] signed = priv.SignData(codeBytes, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); msg = Request.CreateRequest(RequestType.ConfirmRegister, ref counter, [.. codeBytes, .. BitConverter.GetBytes(signed.Length)]); - enc = sk.EncryptCfb(msg, sk.IV, PaddingMode.PKCS7); // no reason to encrpy the signature + enc = sk.EncryptCfb(msg, sk.IV, PaddingMode.None); // no reason to encrpy the signature payload = [.. enc, .. signed]; // should be 128 (enc) + 256 (signed) await stream.WriteAsync(payload); // wait for OK/NACK response (anything other than OK is a NACK) @@ -137,10 +147,13 @@ public class Program File.WriteAllText($"privkey_{user}.pem", priv.ExportRSAPrivateKeyPem()); } - static async Task HandleUserInput(TcpClient client, NetworkStream stream) + static async Task HandleUserInput(TcpClient client, NetworkStream stream, Aes sk, RSA privKey) { + string? currentChat = null; + Dictionary publicKeys = []; while (client.Connected) { + Console.Write("> "); string? input = Console.ReadLine(); if (input == null) { @@ -149,35 +162,139 @@ public class Program } else if (input.StartsWith('/')) { + string[] words = input.Split(' '); // Commands :D, i like commands - switch (input.ToLower()) + switch (words[0].ToLower()) { case "/quit": case "/exit": case "/q!": return; + case "/chat": + case "/msg": + string? old = currentChat; + currentChat = words.Length > 1 ? words[1] : null; + if (currentChat != null && currentChat.Length > 16) + { + Console.WriteLine("Invalid number: too long"); + currentChat = old; + continue; + } + if (currentChat != null && !publicKeys.ContainsKey(currentChat)) + { + // attempt to get the currnet chat's key, this is also possible to do just before sending a message + // but i decided to do it here since if better fits the current structure + RSA? key = await GetPublicKey(stream, sk, currentChat); + if (key != null) + { + publicKeys[currentChat] = key; + } + else + { + currentChat = old; + Console.WriteLine($"Reverting to previous chat: {currentChat ?? "none"}"); + } + } + break; + case "/read": + case "/get": + case "/fetch": + case "/pull": + await GetMessages(stream, sk, privKey, publicKeys); + break; } } else { - stream.Write(Encoding.ASCII.GetBytes(input)); - Console.WriteLine($"[{DateTime.Now}]Sent to server: {input}"); + if (currentChat == null) + { + Console.WriteLine("No chat is active, please select chat using '/msg [number]' or '/chat [number]'"); + } + else if (publicKeys.TryGetValue(currentChat!, out RSA? key)) + { + // TODO: add signature and origin please yes thank you + byte[] userMsg = key.Encrypt(Encoding.UTF8.GetBytes(input), RSAEncryptionPadding.OaepSHA256); + byte[] req = Request.CreateRequest( + RequestType.SendMessage, + ref counter, + [.. Utils.NumberToBytes(currentChat!), .. BitConverter.GetBytes(userMsg.Length)]); + req = sk.EncryptCfb(req, sk.IV); + stream.Write([.. req, .. userMsg]); + Console.WriteLine($"[{DateTime.Now}] Sent to server: {input}"); + } + else + { + Console.WriteLine($"active chat exists, but no key was found..."); + } } } } - static async Task HandleServerInput(TcpClient client, NetworkStream stream) + static async Task GetPublicKey(NetworkStream stream, Aes sk, string chat) { - byte[] buffer = new byte[1024]; - while (client.Connected) + byte[] req = Request.CreateRequest(RequestType.GetUserKey, ref counter, Utils.NumberToBytes(chat)); + req = sk.EncryptCfb(req, sk.IV, PaddingMode.None); // no need for padding this is exactly 128 bytes + await stream.WriteAsync(req); + byte[] response = new byte[1024]; + int len = await stream.ReadAsync(response); + byte[] key = sk.DecryptCfb(response[..len], sk.IV, PaddingMode.PKCS7); + if (key[0] == 1) { - int readLen = await stream.ReadAsync(buffer); - if (readLen != 0) - { - string fromServer = Encoding.ASCII.GetString(buffer[..readLen]); - Console.WriteLine($"[{DateTime.Now}] From server: {fromServer}"); + Console.WriteLine($"failed getting key for {chat}: {Encoding.UTF8.GetString(key[1..])}"); + return null; + } + else + { + RSA bobsKey = RSA.Create(); + bobsKey.ImportRSAPublicKey(key.AsSpan()[1..], out int _); + Console.WriteLine($"Got key:\n {bobsKey.ExportRSAPublicKeyPem()}\n"); + return bobsKey; + } + } + static async Task GetMessages(NetworkStream stream, Aes sk, RSA privKey, Dictionary publicKeys) + { + byte[] req = Request.CreateRequest(RequestType.GetMessages, ref counter, []); + req = sk.EncryptCfb(req, sk.IV, PaddingMode.None); // no need for padding this is exactly 128 bytes + await stream.WriteAsync(req); + byte[] buffer = new byte[1024]; + int len = await stream.ReadAsync(buffer); + byte[] lengths = buffer[..16]; + lengths = sk.DecryptCfb(lengths, sk.IV, PaddingMode.None); + byte[] msg = new byte[1024]; // msg buffer + + int start = 16; // skip the first 16 bytes since its the lengths message + foreach (byte l in lengths.Take(15)) + { + if (l == 0) { break; } // a 0 means we are done actually, as empty messages shouldn't be allowed + // get the msg + int end = start + l; + if (end > len) + { + // we need to read more as there was use of more than 1024 overall + // todo for now + // TODO: read more incoming bytes when messages exceed the 1024 buffer + } + Console.WriteLine($"got ecnryped message: {Convert.ToBase64String(buffer[start..end])}"); + // decrypt the message + if (privKey.TryDecrypt(buffer.AsSpan()[start..end], msg, RSAEncryptionPadding.OaepSHA256, out int written)) + { + byte[] dec = msg[..written]; + Console.WriteLine($"decrypted message: {Convert.ToBase64String(dec)}"); + Console.WriteLine($"Message: {Encoding.UTF8.GetString(dec)}"); + } + else + { + // what if we cant decrypt the message? well, i doubt there is anything we can do about it + // even if we know who sent the message we dont know which message it is unless the server also knows + // and i dont want the server to be aware of that, i think it makes more sense for the server to act as a relay + // and a buffer than an actual participant, so if the message is failing to decrypt that will go unnoticed. + // supposedly the sender will notice the lack of ACK and send it again + Console.WriteLine("Incoming message failed to decrypt, unknown sender"); } } + + + bool hasMoreMessages = lengths[15] != 0; } } \ No newline at end of file diff --git a/lib/Request.cs b/lib/Request.cs index d22ed3f..4df156b 100644 --- a/lib/Request.cs +++ b/lib/Request.cs @@ -46,7 +46,7 @@ public static class Request string res = $"V: {Request[0]}, "; res += $"Request type: {(RequestType)Request[1]}, "; res += $"Counter: {Request[2]}\n"; - res += $"Extra: {string.Join(' ', Request[3..].Select(b => b.ToString("{b8}")))}"; + res += $"Extra: {string.Join(' ', Request[3..].Select(b => b.ToString("b8")))}\n"; // also display extra data based on the request itself switch ((RequestType)Request[1]) { @@ -68,8 +68,13 @@ public static class Request case RequestType.GetMessages: break; case RequestType.GetUserKey: + phone = Utils.BytesToNumber(Request[3..11]); + res += $"Phone: {phone}"; break; case RequestType.SendMessage: + phone = Utils.BytesToNumber(Request[3..11]); + int msgLen = BitConverter.ToInt32(Request, 11); + res += $"Phone: {phone}, Message length: {msgLen}"; break; default: res += "INVALID REQUEST TYPE"; diff --git a/server/Data.cs b/server/Data.cs index d4bb099..e31512f 100644 --- a/server/Data.cs +++ b/server/Data.cs @@ -1,3 +1,4 @@ +using System; using System.Collections.Generic; using System.Security.Cryptography; @@ -8,20 +9,36 @@ public class Data public Dictionary Keys { set; get; } = []; public Dictionary> Messages { set; get; } = []; - public RSA? GetKey(string Phone) { + public RSA? GetKey(string Phone) + { return Keys.TryGetValue(Phone, out RSA? value) ? value : null; } - public Queue? GetMessages(string Phone) { + public List? GetMessages(string Phone, int limit = -1) + { // Check we have a RSA key for the phone and get the messages - if(!Keys.ContainsKey(Phone)) { return null; } - if(Messages.TryGetValue(Phone, out Queue? value)) { - return value; + if (!Keys.ContainsKey(Phone)) { return null; } + if (Messages.TryGetValue(Phone, out Queue? value)) + { + List msgs = new(limit == -1 ? value.Count : Math.Min(value.Count, limit)); + int count = 0; + while (count != limit && value.TryDequeue(out byte[]? m)) + { + count += 1; + msgs.Add(m); + } + return msgs; } - else { + else + { // generate a new queue because one doesnt already exists Messages[Phone] = new Queue(); - return Messages[Phone]; + return []; // no messages were in the list so no reason to attempt to send any message } } + + public bool PeekMessages(string Phone) + { + return Messages.TryGetValue(Phone, out Queue? value) && value.TryPeek(out var _); + } } \ No newline at end of file diff --git a/server/Program.cs b/server/Program.cs index ca820d3..f6fda95 100644 --- a/server/Program.cs +++ b/server/Program.cs @@ -1,4 +1,6 @@ using System; +using System.Collections; +using System.Collections.Generic; using System.IO; using System.Linq; using System.Net; @@ -37,7 +39,18 @@ public class Program { // Currently, every time it gets a block, it will simply send it back but ToUpper TcpClient client = await server.AcceptTcpClientAsync(); - _ = Task.Run(async () => await HandleClient(client, connectionCounter)); + _ = Task.Run(async () => + { + try + { + await HandleClient(client, connectionCounter, key); + } + catch (Exception ex) + { + Console.WriteLine($"Client crashed: {ex.Message}"); + Console.WriteLine(ex.StackTrace); + } + }); connectionCounter += 1; } } @@ -52,21 +65,30 @@ public class Program } } - static async Task HandleClient(TcpClient client, int id) + static async Task HandleClient(TcpClient client, int id, RSA pubKey) { + Write(id, "Got a new client"); + string clientPhone = ""; NetworkStream stream = client.GetStream(); byte[] buffer = new byte[1024]; byte counter = 0; // Get AES session key int len = await stream.ReadAsync(buffer); + Write(id, $"Got {len} bytes"); + byte[] skBytes = pubKey.Decrypt(buffer[..len], RSAEncryptionPadding.OaepSHA256); Aes sk = Aes.Create(); - sk.Key = buffer[..32]; // just to make sure no one sends a too big to be true key - sk.IV = buffer[32..len]; - Write(id, "key + iv: " + len.ToString()); + sk.Key = skBytes[..32]; // just to make sure no one sends a too big to be true key + sk.IV = skBytes[32..]; + Write(id, $"key: {string.Join(' ', sk.Key)}"); + Write(id, $"IV: {string.Join(' ', sk.IV)}"); + await stream.WriteAsync(new byte[] { 0 }); - // Get first message (should be either login or ) + // Get first message (should be either login or register) len = await stream.ReadAsync(buffer); - byte[] msg = sk.DecryptCfb(buffer[..MSG_LEN], sk.IV, PaddingMode.PKCS7); + Write(id, $"Got {len} bytes"); + byte[] msgDec = sk.DecryptCfb(buffer[..len], sk.IV, PaddingMode.PKCS7); + byte[] msg = msgDec[..MSG_LEN]; + Write(id, Request.RequestToString(msg)); if (msg[0] != 0) { Write(id, "Invalid session id!"); @@ -80,10 +102,12 @@ public class Program // get phone number string phone = Utils.BytesToNumber(msg[3..11]); Write(id, $"Client wants to register as {phone}"); + clientPhone = phone; int keyLen = BitConverter.ToInt32(msg, 11); RSA pub = RSA.Create(); - pub.ImportRSAPublicKey(buffer.AsSpan()[MSG_LEN..], out int bytesRead); + pub.ImportRSAPublicKey(msgDec.AsSpan()[MSG_LEN..], out int bytesRead); Write(id, $"Imported key len: {bytesRead} while client claims it is {keyLen}"); + Write(id, $"Imported key is: \n {pub.ExportRSAPublicKeyPem()}\n"); // generate the 6 digit code and send it byte[] code = [ (byte)Rnd.Next(10), @@ -101,7 +125,8 @@ public class Program tries -= 1; len = await stream.ReadAsync(buffer); Write(id, $"Got 6 digit code with sig, len: {len}"); - msg = sk.DecryptCfb(buffer[..MSG_LEN], sk.IV, PaddingMode.PKCS7); + msg = sk.DecryptCfb(buffer[..MSG_LEN], sk.IV, PaddingMode.None); + Write(id, Request.RequestToString(msg)); byte[] sig = buffer[MSG_LEN..len]; if (msg[0] != 0 || msg[1] != (byte)RequestType.ConfirmRegister || msg[2] != counter) { @@ -145,6 +170,7 @@ public class Program else if (msg[1] == (byte)RequestType.Login) { // verify login + // TODO: Login } else { @@ -162,6 +188,7 @@ public class Program // either by getting new messages for other ppl, or sending back keys/pending messages len = await stream.ReadAsync(buffer); msg = sk.DecryptCfb(buffer[..MSG_LEN], sk.IV, PaddingMode.None); + Write(id, Request.RequestToString(msg)); // verify that the counter message is correct if (msg[0] != 0 || msg[2] != counter) { @@ -173,22 +200,47 @@ public class Program switch ((RequestType)msg[1]) { case RequestType.GetMessages: + byte[] msgsLens = new byte[16]; // 128 bits + // get 15 messages, last byte will indicate if there are more + List msgs = Data.GetMessages(clientPhone, 15) ?? []; + byte[] msgsBytes = new byte[msgs.Select(m => m.Length).Sum()]; + int msgsbytesIndex = 0; + for (int i = 0; i < msgsLens.Length - 1; i += 1) + { + // it is expected that all messages will be less than 255 bytes, hence a single byte to + // denote length is sufficient, but a simple update to the protocol can allow up to 7 messages + // per request (instead of 15), and use an ushort (u16) instead + msgsLens[i] = (byte)(msgs.Count > i ? msgs[i].Length : 0); + if (i < msgs.Count) + { + // copy the message to the msgsBytes array + Array.Copy(msgs[i], 0, msgsBytes, msgsbytesIndex, msgs[i].Length); + } + } + msgsLens[15] = Data.PeekMessages(clientPhone) ? (byte)1 : (byte)0; + // only need to encrypt the lengths of the messages, as the messages themselves are encrypted + msgsLens = sk.EncryptCfb(msgsLens, sk.IV, PaddingMode.None); + byte[] finalPayload = [.. msgsLens, .. msgsBytes]; + await stream.WriteAsync(finalPayload); break; case RequestType.GetUserKey: string phone = Utils.BytesToNumber(msg[3..11]); RSA? key = Data.GetKey(phone); if (key != null) { - msg = sk.EncryptCfb(key.ExportRSAPublicKey(), sk.IV, PaddingMode.PKCS7); + msg = [0, .. key.ExportRSAPublicKey()]; + msg = sk.EncryptCfb(msg, sk.IV, PaddingMode.PKCS7); await stream.WriteAsync(msg); } else { - msg = sk.EncryptCfb(Encoding.UTF8.GetBytes("USER DOES NOT EXIST"), sk.IV, PaddingMode.PKCS7); + msg = [1, .. Encoding.UTF8.GetBytes("USER DOES NOT EXIST")]; + msg = sk.EncryptCfb(msg, sk.IV, PaddingMode.PKCS7); await stream.WriteAsync(msg); } break; case RequestType.SendMessage: + break; default: msg = sk.EncryptCfb(Encoding.UTF8.GetBytes("INVALID REQUEST"), sk.IV, PaddingMode.PKCS7);