online_security_project/client/Program.cs

300 lines
No EOL
13 KiB
C#

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;
using lib;
namespace Client;
/// <summary>
/// Console client for the messaging server.
/// </summary>
/// <remarks>
/// On startup it:
/// <list type="number">
/// <item>loads (or creates) the user's RSA key pair;</item>
/// <item>connects to 127.0.0.1:12345;</item>
/// <item>registers with the server when needed;</item>
/// <item>runs an interactive command loop for selecting chats, sending messages and fetching messages.</item>
/// </list>
/// </remarks>
public class Program
{
    /// <summary>
    /// Request counter passed by reference to <c>Request.CreateRequest</c> for every outgoing request.
    /// </summary>
    // NOTE(review): presumably CreateRequest advances this as a per-session sequence number — confirm in lib.
    static byte counter = 0;

    /// <summary>Size in bytes of the encrypted lengths header at the start of a GetMessages response.</summary>
    const int MessagesHeaderSize = 16;

    /// <summary>Size of the buffer used for single-read server responses.</summary>
    const int ResponseBufferSize = 1024;

    /// <summary>
    /// Entry point for the client.
    /// </summary>
    /// <remarks>
    /// Options:
    /// <list type="bullet">
    /// <item><c>-p/--phone &lt;number&gt;</c> selects the user (default "0000").</item>
    /// <item><c>-fr/--force-register</c> forces the registration handshake even if a key pair already exists.</item>
    /// </list>
    /// </remarks>
    static async Task Main(string[] args)
    {
        string user = args
            .SkipWhile(a => !new[] { "-p", "--phone" }.Contains(a)) // search for the option
            .Skip(1)                                                // skip the `-p/--phone` option itself to get its value
            .FirstOrDefault() ?? "0000";                            // get the value, or the default if it doesn't exist

        // On boot, the server's public key must already be available on disk.
        using RSA? serverKey = LoadRSAFromFile("server_key.pem");
        if (serverKey is null) { Console.WriteLine("Could not find server key, please run server before clients!"); return; }

        // The private key file is the source of truth: an RSA instance holding a private key also
        // exposes the matching public key, so one instance serves both roles. Loading the public key
        // from its own file could pair it with an unrelated private key (or with none at all, which
        // made SaveRSAKeys throw). A key pair is generated only when no private key exists.
        RSA? loadedKey = LoadRSAFromFile($"privkey_{user}.pem");
        bool generatedNewKey = loadedKey is null;
        using RSA privKey = loadedKey ?? RSA.Create(2048);
        RSA pubKey = privKey;
        SaveRSAKeys(user, pubKey, privKey);

        using TcpClient client = new("127.0.0.1", 12345);
        NetworkStream stream = client.GetStream();

        // A freshly generated key is unknown to the server, so it must be registered first.
        bool needsRegister = generatedNewKey || args.Any(a => new[] { "-fr", "--force-register" }.Contains(a));
        using Aes sk = Aes.Create(); // creates an AES-256 session key with a random IV
        if (needsRegister)
        {
            try
            {
                await RegisterClient(user, pubKey, privKey, serverKey, sk, stream);
            }
            catch (Exception ex)
            {
                Console.WriteLine("Failed registration process");
                Console.WriteLine("Exception: " + ex.Message);
                Console.WriteLine("Stack: " + ex.StackTrace);
                return;
            }
        }
        else
        {
            // TODO: attempt to login here.
            // NOTE(review): until login exists, the session key `sk` is never sent to the server on this
            // path, so the server presumably cannot decrypt any request made below — confirm.
        }
        await HandleUserInput(client, stream, sk, privKey);
    }

    /// <summary>
    /// Runs the registration handshake with the server.
    /// </summary>
    /// <remarks>
    /// Steps:
    /// <list type="number">
    /// <item>Send the AES session key and IV, RSA-OAEP encrypted with the server key, and wait for a 1-byte ack.</item>
    /// <item>Send a Register request carrying the phone number and the DER-encoded public key.</item>
    /// <item>Receive the 6 digit code and ask the user to type it back.</item>
    /// <item>Send the code with an RSA signature in a ConfirmRegister request, repeating until the server answers "OK".</item>
    /// </list>
    /// </remarks>
    /// <param name="user">Phone number identifying the user.</param>
    /// <param name="pub">Key whose public half is registered.</param>
    /// <param name="priv">Private key used to sign the confirmation code.</param>
    /// <param name="server">The server's public key.</param>
    /// <param name="sk">AES session key shared with the server by this method.</param>
    /// <param name="stream">Connection to the server.</param>
    /// <exception cref="EndOfStreamException">The server connection or standard input closed mid-handshake.</exception>
    static async Task RegisterClient(string user, RSA pub, RSA priv, RSA server, Aes sk, NetworkStream stream)
    {
        Console.WriteLine("Attempting to register with public key:");
        Console.WriteLine(pub.ExportRSAPublicKeyPem());

        // Send the session key and IV forward, encrypted for the server.
        // NOTE(review): debug output — printing the session key and IV exposes the session secret.
        Console.WriteLine($"Session key: {string.Join(' ', sk.Key)}");
        Console.WriteLine($"Session IV: {string.Join(' ', sk.IV)}");
        byte[] skEnc = server.Encrypt([.. sk.Key, .. sk.IV], RSAEncryptionPadding.OaepSHA256);
        await stream.WriteAsync(skEnc);
        // Wait for the server to confirm it received the keys.
        await stream.ReadExactlyAsync(new byte[1]);

        // Register request data: 8 bytes of phone number followed by the 4-byte public key length.
        // The public key itself is appended after the request and encrypted together with it.
        Console.WriteLine("Sending rsa public key thing");
        byte[] pubBytes = pub.ExportRSAPublicKey();
        byte[] data = new byte[12];
        Array.Copy(Utils.NumberToBytes(user), data, 8);
        Array.Copy(BitConverter.GetBytes(pubBytes.Length), 0, data, 8, 4);
        byte[] msg = Request.CreateRequest(RequestType.Register, ref counter, data);
        byte[] payload = [.. msg, .. pubBytes];
        // NOTE(review): every request reuses the same key and sk.IV, so CFB encryption is deterministic:
        // identical plaintext prefixes yield identical ciphertext prefixes. Fixing that needs a server change.
        byte[] enc = sk.EncryptCfb(payload, sk.IV, PaddingMode.PKCS7);
        Console.WriteLine($"payload length: {enc.Length}");
        await stream.WriteAsync(enc);

        // The server answers with the 6 digit code (it takes the place of an OK message).
        // A single ReadAsync may return fewer than 6 bytes, so read exactly 6. This throws
        // EndOfStreamException if the connection closes, instead of looping forever.
        byte[] digits = new byte[6];
        await stream.ReadExactlyAsync(digits);
        Console.WriteLine($"[{DateTime.Now}] 6 digit code: {string.Join(' ', digits.Select(d => d.ToString()))}");

        byte[] response = new byte[ResponseBufferSize];
        while (true)
        {
            Console.Write("> ");
            string? code = Console.ReadLine()?.Trim();
            if (code is null)
            {
                // End of standard input: the user can never enter the code, so abort registration.
                throw new EndOfStreamException("Standard input closed before the 6 digit code was entered");
            }
            // Exactly six ASCII digits. Shorter codes would break the fixed message layout, and
            // non-ASCII Unicode digits cannot be mapped to a 0-9 byte value.
            if (code.Length != 6 || !code.All(char.IsAsciiDigit))
            {
                Console.WriteLine("Invalid code!");
                continue;
            }
            byte[] codeBytes = code.Select(c => (byte)(c - '0')).ToArray();
            // Debug print of the parsed value.
            Console.WriteLine(string.Join(' ', codeBytes.Select(b => b.ToString())));

            // Sign the code and send it in a ConfirmRegister request. The signature is appended unencrypted.
            byte[] signed = priv.SignData(codeBytes, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
            msg = Request.CreateRequest(RequestType.ConfirmRegister, ref counter, [.. codeBytes, .. BitConverter.GetBytes(signed.Length)]);
            enc = sk.EncryptCfb(msg, sk.IV, PaddingMode.None);
            payload = [.. enc, .. signed]; // expected: 128 bytes of encrypted request + 256 bytes of signature
            await stream.WriteAsync(payload);

            // Wait for the OK/NACK response (anything other than "OK" is a NACK).
            int incoming = await stream.ReadAsync(response);
            if (incoming == 0)
            {
                throw new EndOfStreamException("Server closed the connection during registration");
            }
            string r = Encoding.UTF8.GetString(sk.DecryptCfb(response[..incoming], sk.IV, PaddingMode.PKCS7));
            if (r == "OK")
            {
                Console.WriteLine("Registration process complete");
                return;
            }
            Console.WriteLine("Code was not accepted, please try again");
        }
    }

    /// <summary>
    /// Loads an RSA key from a PEM file.
    /// </summary>
    /// <param name="file">Path of the PEM file.</param>
    /// <returns>The imported key, or <c>null</c> when the file does not exist.</returns>
    static RSA? LoadRSAFromFile(string file)
    {
        if (!File.Exists(file))
        {
            return null;
        }
        string content = File.ReadAllText(file);
        RSA k = RSA.Create();
        try
        {
            k.ImportFromPem(content);
        }
        catch
        {
            // Don't leak the native key handle when the file content is not a valid key.
            k.Dispose();
            throw;
        }
        return k;
    }

    /// <summary>
    /// Writes the user's public key to <c>pubkey_{user}.pem</c> and the private key to <c>privkey_{user}.pem</c>.
    /// </summary>
    /// <remarks>
    /// <paramref name="priv"/> must contain private parameters, otherwise the export throws.
    /// </remarks>
    // NOTE(review): the private key is written as unencrypted PEM with default file permissions.
    static void SaveRSAKeys(string user, RSA pub, RSA priv)
    {
        File.WriteAllText($"pubkey_{user}.pem", pub.ExportRSAPublicKeyPem());
        File.WriteAllText($"privkey_{user}.pem", priv.ExportRSAPrivateKeyPem());
    }

    /// <summary>
    /// Interactive loop that runs until the user quits, standard input ends, or the connection drops.
    /// </summary>
    /// <remarks>
    /// Commands:
    /// <list type="bullet">
    /// <item><c>/quit</c>, <c>/exit</c>, <c>/q!</c> leave the loop.</item>
    /// <item><c>/chat [number]</c> or <c>/msg [number]</c> selects the active chat and fetches its public key.</item>
    /// <item><c>/read</c>, <c>/get</c>, <c>/fetch</c>, <c>/pull</c> fetch pending messages.</item>
    /// </list>
    /// Any other line is sent as a message to the active chat.
    /// </remarks>
    static async Task HandleUserInput(TcpClient client, NetworkStream stream, Aes sk, RSA privKey)
    {
        string? currentChat = null;
        Dictionary<string, RSA> publicKeys = [];
        try
        {
            while (client.Connected)
            {
                Console.Write("> ");
                string? input = Console.ReadLine();
                if (input is null)
                {
                    // End of standard input: ReadLine will keep returning null, so stop instead of spinning.
                    return;
                }
                if (input.StartsWith('/'))
                {
                    string[] words = input.Split(' ', StringSplitOptions.RemoveEmptyEntries);
                    // Commands are case-insensitive. Use invariant casing so that e.g. "/QUIT" works in every culture.
                    switch (words[0].ToLowerInvariant())
                    {
                        case "/quit":
                        case "/exit":
                        case "/q!":
                            return;
                        case "/chat":
                        case "/msg":
                            string? old = currentChat;
                            currentChat = words.Length > 1 ? words[1] : null;
                            if (currentChat != null && currentChat.Length > 16)
                            {
                                Console.WriteLine("Invalid number: too long");
                                currentChat = old;
                                break;
                            }
                            if (currentChat != null && !publicKeys.ContainsKey(currentChat))
                            {
                                // Fetch the chat's key now rather than just before sending, which fits the current structure better.
                                RSA? fetchedKey = await GetPublicKey(stream, sk, currentChat);
                                if (fetchedKey != null)
                                {
                                    publicKeys[currentChat] = fetchedKey;
                                }
                                else
                                {
                                    currentChat = old;
                                    Console.WriteLine($"Reverting to previous chat: {currentChat ?? "none"}");
                                }
                            }
                            break;
                        case "/read":
                        case "/get":
                        case "/fetch":
                        case "/pull":
                            await GetMessages(stream, sk, privKey, publicKeys);
                            break;
                        default:
                            Console.WriteLine($"Unknown command: {words[0]}");
                            break;
                    }
                }
                else if (currentChat is null)
                {
                    Console.WriteLine("No chat is active, please select chat using '/msg [number]' or '/chat [number]'");
                }
                else if (publicKeys.TryGetValue(currentChat, out RSA? recipientKey))
                {
                    await SendChatMessage(stream, sk, currentChat, recipientKey, input);
                }
                else
                {
                    Console.WriteLine($"active chat exists, but no key was found...");
                }
            }
        }
        catch (IOException ex)
        {
            // Includes EndOfStreamException. The server is gone, so leave cleanly instead of crashing.
            Console.WriteLine($"Connection to the server was lost: {ex.Message}");
        }
        finally
        {
            foreach (RSA cachedKey in publicKeys.Values)
            {
                cachedKey.Dispose();
            }
        }
    }

    /// <summary>
    /// Encrypts <paramref name="text"/> with the recipient's public key and sends it as a SendMessage request.
    /// </summary>
    /// <remarks>
    /// The request carries the recipient number and the ciphertext length. The RSA ciphertext is appended
    /// after the AES-encrypted request.
    /// </remarks>
    static async Task SendChatMessage(NetworkStream stream, Aes sk, string recipient, RSA recipientKey, string text)
    {
        byte[] userMsg;
        try
        {
            // RSA-OAEP can only encrypt a short plaintext (190 bytes for a 2048-bit key with SHA-256).
            // A longer message makes Encrypt throw, so report it instead of crashing the client.
            userMsg = recipientKey.Encrypt(Encoding.UTF8.GetBytes(text), RSAEncryptionPadding.OaepSHA256);
        }
        catch (CryptographicException ex)
        {
            Console.WriteLine($"Could not encrypt message (it may be too long): {ex.Message}");
            return;
        }
        // TODO: add signature and origin please yes thank you
        byte[] req = Request.CreateRequest(
            RequestType.SendMessage,
            ref counter,
            [.. Utils.NumberToBytes(recipient), .. BitConverter.GetBytes(userMsg.Length)]);
        // NOTE(review): unlike the other fixed-size requests, this one uses the default (PKCS7) padding,
        // which lengthens the encrypted request. Confirm the server expects that before changing it.
        req = sk.EncryptCfb(req, sk.IV);
        await stream.WriteAsync((byte[])[.. req, .. userMsg]);
        Console.WriteLine($"[{DateTime.Now}] Sent to server: {text}");
    }

    /// <summary>
    /// Requests the public key registered for <paramref name="chat"/>.
    /// </summary>
    /// <returns>The key, or <c>null</c> (after printing why) when the server reports a failure or the response is unusable.</returns>
    /// <exception cref="EndOfStreamException">The server closed the connection.</exception>
    static async Task<RSA?> GetPublicKey(NetworkStream stream, Aes sk, string chat)
    {
        byte[] req = Request.CreateRequest(RequestType.GetUserKey, ref counter, Utils.NumberToBytes(chat));
        req = sk.EncryptCfb(req, sk.IV, PaddingMode.None); // no need for padding this is exactly 128 bytes
        await stream.WriteAsync(req);

        // NOTE(review): the response has no length prefix, so this relies on it arriving in a single read.
        byte[] response = new byte[ResponseBufferSize];
        int len = await stream.ReadAsync(response);
        if (len == 0)
        {
            throw new EndOfStreamException("Server closed the connection while fetching a public key");
        }

        byte[] key;
        try
        {
            key = sk.DecryptCfb(response[..len], sk.IV, PaddingMode.PKCS7);
        }
        catch (CryptographicException ex)
        {
            Console.WriteLine($"failed getting key for {chat}: {ex.Message}");
            return null;
        }

        // The first byte is a status. 1 means failure, followed by a UTF-8 error text.
        // Otherwise a DER-encoded RSA public key follows.
        if (key.Length == 0)
        {
            Console.WriteLine($"failed getting key for {chat}: empty response");
            return null;
        }
        if (key[0] == 1)
        {
            Console.WriteLine($"failed getting key for {chat}: {Encoding.UTF8.GetString(key[1..])}");
            return null;
        }

        RSA bobsKey = RSA.Create();
        try
        {
            bobsKey.ImportRSAPublicKey(key.AsSpan(1), out _);
        }
        catch (CryptographicException ex)
        {
            bobsKey.Dispose();
            Console.WriteLine($"failed getting key for {chat}: {ex.Message}");
            return null;
        }
        Console.WriteLine($"Got key:\n {bobsKey.ExportRSAPublicKeyPem()}\n");
        return bobsKey;
    }

    /// <summary>
    /// Fetches pending messages from the server, decrypts them with <paramref name="privKey"/>, and prints them.
    /// </summary>
    /// <remarks>
    /// Response layout:
    /// <list type="bullet">
    /// <item>A 16-byte AES-encrypted header. Bytes 0..14 are the lengths of up to 15 messages; a 0 ends the list.</item>
    /// <item>Byte 15 of the header is non-zero when more messages remain on the server.</item>
    /// <item>The RSA ciphertexts follow the header back to back.</item>
    /// </list>
    /// <paramref name="publicKeys"/> is currently unused.
    /// </remarks>
    /// <exception cref="EndOfStreamException">The server closed the connection mid-response.</exception>
    // NOTE(review): a 2048-bit RSA-OAEP ciphertext is 256 bytes, which does not fit in a single length byte.
    // Confirm how the server encodes message lengths.
    static async Task GetMessages(NetworkStream stream, Aes sk, RSA privKey, Dictionary<string, RSA> publicKeys)
    {
        byte[] req = Request.CreateRequest(RequestType.GetMessages, ref counter, []);
        req = sk.EncryptCfb(req, sk.IV, PaddingMode.None); // no need for padding this is exactly 128 bytes
        await stream.WriteAsync(req);

        byte[] buffer = new byte[ResponseBufferSize];
        int len = await stream.ReadAsync(buffer);
        if (len == 0)
        {
            throw new EndOfStreamException("Server closed the connection while fetching messages");
        }
        if (len < MessagesHeaderSize)
        {
            // Make sure the whole lengths header is present before decrypting it.
            await stream.ReadExactlyAsync(buffer.AsMemory(len, MessagesHeaderSize - len));
            len = MessagesHeaderSize;
        }
        byte[] lengths = sk.DecryptCfb(buffer[..MessagesHeaderSize], sk.IV, PaddingMode.None);

        // A 0 length ends the list, as empty messages shouldn't be allowed.
        byte[] messageLengths = lengths.Take(MessagesHeaderSize - 1).TakeWhile(l => l != 0).ToArray();
        int total = MessagesHeaderSize + messageLengths.Sum(l => (int)l);
        if (len < total)
        {
            // The messages did not all arrive in the first read (or exceed the initial buffer):
            // read exactly the remaining bytes announced by the header.
            if (buffer.Length < total)
            {
                Array.Resize(ref buffer, total);
            }
            await stream.ReadExactlyAsync(buffer.AsMemory(len, total - len));
        }

        int start = MessagesHeaderSize; // skip the lengths header
        foreach (byte l in messageLengths)
        {
            int end = start + l;
            byte[] encrypted = buffer[start..end];
            Console.WriteLine($"got ecnryped message: {Convert.ToBase64String(encrypted)}");
            try
            {
                byte[] dec = privKey.Decrypt(encrypted, RSAEncryptionPadding.OaepSHA256);
                Console.WriteLine($"decrypted message: {Convert.ToBase64String(dec)}");
                Console.WriteLine($"Message: {Encoding.UTF8.GetString(dec)}");
            }
            catch (CryptographicException)
            {
                // RSA decryption signals failure by throwing. The sender is unknown and the server is
                // deliberately only a relay/buffer, so the failure can only be reported locally.
                // The sender will presumably notice the missing reply and resend.
                Console.WriteLine("Incoming message failed to decrypt, unknown sender");
            }
            start = end; // advance to the next message
        }

        bool hasMoreMessages = lengths[MessagesHeaderSize - 1] != 0;
        if (hasMoreMessages)
        {
            Console.WriteLine("More messages are waiting on the server, use /read again to fetch them");
        }
    }
}