online_security_project/server/Program.cs

359 lines
No EOL
16 KiB
C#

using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
using lib;
using server;
namespace Server;
/// <summary>
/// Entry point of the messaging server. Each accepted TCP client goes through:
/// 1. an RSA-wrapped AES session-key exchange;
/// 2. a register or login handshake;
/// 3. a request loop for fetching/sending messages and looking up other users' public keys.
/// </summary>
/// <remarks>
/// NOTE(review): every AES-CFB operation in this class reuses the session IV with the same key, so the first
/// block of every message shares a keystream. Fixing that needs a protocol change on the client side too.
/// </remarks>
public class Program
{
    /// <summary>Length of the fixed request header: 128 bits = 16 bytes.</summary>
    const int MSG_LEN = 16;

    /// <summary>AES-256 key length in bytes, taken from the front of the decrypted session blob.</summary>
    const int SESSION_KEY_LEN = 32;

    /// <summary>AES IV length in bytes, taken right after the session key.</summary>
    const int SESSION_IV_LEN = 16;

    /// <summary>Number of decimal digits in the registration confirmation code.</summary>
    const int CODE_LEN = 6;

    /// <summary>Login challenge length: one AES block, so it can be encrypted without padding.</summary>
    const int CHALLENGE_LEN = 16;

    /// <summary>How many ConfirmRegister attempts are allowed before the connection is dropped.</summary>
    const int REGISTER_TRIES = 5;

    /// <summary>
    /// Maximum messages per GetMessages reply. Each length takes 2 bytes of the 16-byte header (bytes 0-13),
    /// and byte 15 flags whether more messages are pending.
    /// </summary>
    const int MAX_MESSAGES_PER_BATCH = 7;

    // NOTE(review): shared by every client task; confirm Data's collections are safe for concurrent access.
    static readonly Data Data = new();

    static async Task Main()
    {
        // Generally this key would be static, but since this is not production we generate a fresh one on
        // every start and write the public half to disk, so clients must load the current key from file.
        // NOTE(review): 1024-bit RSA is below current recommendations.
        RSA key = RSA.Create(1024);
        File.WriteAllText("server_key.pem", key.ExportRSAPublicKeyPem());
        int port = 12345;
        TcpListener server = new(IPAddress.Any, port);
        int connectionCounter = 0;
        Log log = new(-1, null);
        Console.CancelKeyPress += delegate
        {
            Console.WriteLine("\nEXITING!");
            server.Stop();
        };
        try
        {
            server.Start();
            while (true)
            {
                TcpClient client = await server.AcceptTcpClientAsync();
                // Take the id before starting the task: capturing connectionCounter itself would let the
                // lambda observe the value after it has already been incremented for the next client.
                int clientId = connectionCounter++;
                _ = Task.Run(async () =>
                {
                    try
                    {
                        await HandleClient(client, clientId, key);
                    }
                    catch (Exception ex)
                    {
                        log.System($"Client crashed: {ex.Message}");
                        log.System(ex.StackTrace ?? "MISSING STACK TRACE");
                    }
                });
            }
        }
        catch (Exception ex)
        {
            log.System($"Server error: {ex.Message}");
            log.System("Trace: " + ex.StackTrace);
        }
        finally
        {
            server.Stop();
        }
    }

    /// <summary>
    /// Runs the whole protocol for one connected client. The client is always disposed when this returns or throws.
    /// </summary>
    /// <param name="client">The accepted connection.</param>
    /// <param name="id">Connection number, used to tag log lines.</param>
    /// <param name="pubKey">
    /// The server's RSA key pair. It is used to decrypt the client's session key, so despite the name it must
    /// contain the private key.
    /// </param>
    static async Task HandleClient(TcpClient client, int id, RSA pubKey)
    {
        Log log = new(id, null);
        log.System("Got a new client");
        try
        {
            NetworkStream stream = client.GetStream();
            byte[] buffer = new byte[1024];

            using Aes? sk = await ReceiveSessionKeyAsync(stream, buffer, pubKey, log);
            if (sk is null)
            {
                return;
            }

            // Get first message (should be either login or register)
            int len = await stream.ReadAsync(buffer);
            if (len == 0)
            {
                log.System("Client disconnected before authenticating");
                return;
            }
            log.Encrypted("Auth message", buffer[..len]);
            byte[] msgDec = sk.DecryptCfb(buffer[..len], sk.IV, PaddingMode.PKCS7);
            log.Decrypted("Auth message", msgDec);
            if (msgDec.Length < MSG_LEN)
            {
                log.System($"Auth message too short: {msgDec.Length} bytes");
                return;
            }
            byte[] msg = msgDec[..MSG_LEN];
            log.Message(Request.RequestToString(msg));
            if (msg[0] != 0)
            {
                log.System("Invalid message version!");
                return;
            }

            bool isRegister = msg[1] == (byte)RequestType.Register;
            if (!isRegister && msg[1] != (byte)RequestType.Login)
            {
                // invalid connection, quit
                log.System("Client didnt register or login as first message");
                return;
            }

            byte counter = IncrementCounter(msg[2]); // allow counter to start at a random position
            // Header bytes as read in this method: [0]=version, [1]=request type, [2]=counter, [3..11]=phone.
            string clientPhone = Utils.BytesToNumber(msg[3..11]);
            log.Client = clientPhone;

            byte? nextCounter = isRegister
                ? await RegisterAsync(stream, buffer, sk, log, msgDec, clientPhone, counter)
                : await LoginAsync(stream, buffer, sk, log, clientPhone, counter);

            // Only an authenticated client may enter the request loop; a failed register/login ends here.
            if (nextCounter is null)
            {
                return;
            }
            await RequestLoopAsync(client, stream, buffer, sk, log, clientPhone, nextCounter.Value);
        }
        finally
        {
            client.Dispose();
        }
    }

    /// <summary>
    /// Reads the RSA-OAEP(SHA-256)-encrypted session blob (AES key followed by IV), builds the AES session
    /// object, and acknowledges with a single 0 byte.
    /// </summary>
    /// <returns>The session AES instance, or <c>null</c> if the client disconnected or sent a short blob.</returns>
    static async Task<Aes?> ReceiveSessionKeyAsync(NetworkStream stream, byte[] buffer, RSA serverKey, Log log)
    {
        int len = await stream.ReadAsync(buffer);
        if (len == 0)
        {
            log.System("Client disconnected before sending a session key");
            return null;
        }
        log.Encrypted("Key + IV", buffer[..len]);
        byte[] skBytes = serverKey.Decrypt(buffer[..len], RSAEncryptionPadding.OaepSHA256);
        log.Decrypted("Key + IV", skBytes);
        if (skBytes.Length < SESSION_KEY_LEN + SESSION_IV_LEN)
        {
            log.System($"Session key blob too short: {skBytes.Length} bytes");
            return null;
        }
        Aes sk = Aes.Create();
        // Slice exactly key + IV, so an oversized blob cannot produce an oversized key.
        sk.Key = skBytes[..SESSION_KEY_LEN];
        sk.IV = skBytes[SESSION_KEY_LEN..(SESSION_KEY_LEN + SESSION_IV_LEN)];
        await stream.WriteAsync(new byte[] { 0 });
        return sk;
    }

    /// <summary>
    /// Registration handshake:
    /// 1. imports the client's RSA public key from the auth message;
    /// 2. sends a 6-digit confirmation code;
    /// 3. waits for the code to come back together with a signature over it.
    /// </summary>
    /// <param name="msgDec">The whole decrypted auth message; the client's public key follows the 16-byte header.</param>
    /// <param name="phone">Phone number the client wants to register.</param>
    /// <param name="counter">The counter expected on the next request.</param>
    /// <returns>The next expected counter on success, or <c>null</c> when registration failed.</returns>
    static async Task<byte?> RegisterAsync(NetworkStream stream, byte[] buffer, Aes sk, Log log, byte[] msgDec, string phone, byte counter)
    {
        log.System($"Client wants to register as {phone}");
        RSA pub = RSA.Create();
        pub.ImportRSAPublicKey(msgDec.AsSpan()[MSG_LEN..], out _);
        log.System($"Imported key is: \n {pub.ExportRSAPublicKeyPem()}\n");

        // The code is an authentication secret, so it comes from a CSPRNG. System.Random is predictable,
        // and a shared instance is not safe to use from concurrent client tasks.
        byte[] code = new byte[CODE_LEN];
        for (int i = 0; i < code.Length; i++)
        {
            code[i] = (byte)RandomNumberGenerator.GetInt32(10);
        }
        await SendBySecureChannel(stream, code);

        // Wait for the code to come back with a signature. Allow a few tries before closing the connection.
        for (int attempt = 0; attempt < REGISTER_TRIES; attempt++)
        {
            int len = await stream.ReadAsync(buffer);
            if (len == 0)
            {
                log.System("Client disconnected during registration");
                break;
            }
            log.Encrypted("ConfirmRegister", buffer[..len]);
            if (len < MSG_LEN)
            {
                log.System($"ConfirmRegister too short: {len} bytes");
                continue;
            }
            byte[] msg = sk.DecryptCfb(buffer[..MSG_LEN], sk.IV, PaddingMode.None);
            log.Decrypted("ConfirmRegister", msg);
            log.Message(Request.RequestToString(msg));
            byte[] sig = buffer[MSG_LEN..len];
            log.Decrypted("ConfirmRegister (sig)", sig);
            if (msg[0] != 0 || msg[1] != (byte)RequestType.ConfirmRegister || msg[2] != counter)
            {
                // invalid or unexpected request, someone might be sending duplicates
                continue;
            }
            counter = IncrementCounter(counter);
            // Body bytes as read here: [3..9]=code digits, [9..13]=signature length.
            byte[] gottenCode = msg[3..9];
            int expectedSigLen = BitConverter.ToInt32(msg, 9);
            if (expectedSigLen != len - MSG_LEN)
            {
                log.System($"expected sig len doesnt match read len: {expectedSigLen} / {len - MSG_LEN}");
            }
            // Constant-time comparison, so response timing does not reveal how many digits matched.
            if (!CryptographicOperations.FixedTimeEquals(code, gottenCode))
            {
                // Perhaps we should send only SIG INVALID (or just FAILED) and hide the reason, in case someone
                // is trying to guess the code.
                await SendTextAsync(stream, sk, log, "Fail Register", "BAD CODE");
            }
            else if (pub.VerifyData(gottenCode, sig, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1))
            {
                await SendTextAsync(stream, sk, log, "Register Success", "OK");
                Data.Keys[phone] = pub; // save the key; pub is now owned by Data and must not be disposed here
                return counter;
            }
            else
            {
                await SendTextAsync(stream, sk, log, "Fail Register", "SIG INVALID");
            }
        }

        log.System("Registration failed, closing connection");
        pub.Dispose();
        return null;
    }

    /// <summary>
    /// Login handshake: sends a random one-block challenge (AES-CFB, no padding) and verifies the client's
    /// RSA signature over it, using the public key stored at registration.
    /// </summary>
    /// <param name="clientPhone">Phone number the client claims.</param>
    /// <param name="counter">The counter expected on the challenge response.</param>
    /// <returns>The next expected counter on success, or <c>null</c> when login failed.</returns>
    static async Task<byte?> LoginAsync(NetworkStream stream, byte[] buffer, Aes sk, Log log, string clientPhone, byte counter)
    {
        if (!Data.Keys.TryGetValue(clientPhone, out RSA? clientKey))
        {
            log.System($"Client claims to be {clientPhone}, but could not find key in records");
            return null;
        }

        // The challenge must be unpredictable, so it comes from a CSPRNG rather than System.Random.
        byte[] challenge = RandomNumberGenerator.GetBytes(CHALLENGE_LEN);
        log.Decrypted("Challenge", challenge);
        byte[] response = sk.EncryptCfb(challenge, sk.IV, PaddingMode.None);
        log.Encrypted("Challenge", response);
        await stream.WriteAsync(response);

        int len = await stream.ReadAsync(buffer);
        if (len == 0)
        {
            log.System("Client disconnected during login");
            return null;
        }
        log.Encrypted("Challenge Response", buffer[..len]);
        if (len < MSG_LEN)
        {
            log.System($"Challenge response too short: {len} bytes");
            return null;
        }
        byte[] msg = sk.DecryptCfb(buffer[..MSG_LEN], sk.IV, PaddingMode.None);
        log.Decrypted("Challenge Message", msg);
        log.Message(Request.RequestToString(msg));
        if (msg[2] != counter)
        {
            log.System($"Invalid counter in login response, quitting");
            return null;
        }
        counter = IncrementCounter(counter);
        byte[] sig = buffer[MSG_LEN..len];
        // The sig wasn't encrypted, so this print is somewhat redundant, but it is still useful to see the sig alone.
        log.Decrypted("Challenge Sig", sig);
        if (!clientKey.VerifyData(challenge, sig, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1))
        {
            log.System("Client failed verification, invalid signature");
            await SendTextAsync(stream, sk, log, "Challenge Response", "INVALID SIG");
            return null;
        }
        log.System("Client verification complete");
        await SendTextAsync(stream, sk, log, "Challenge Response", "OK");
        return counter;
    }

    /// <summary>
    /// Serves requests from an authenticated client until it disconnects or an error occurs.
    /// Every request header must carry version 0 and the expected counter; otherwise the client gets INVALID REQUEST.
    /// </summary>
    static async Task RequestLoopAsync(TcpClient client, NetworkStream stream, byte[] buffer, Aes sk, Log log, string clientPhone, byte counter)
    {
        try
        {
            while (client.Connected)
            {
                int len = await stream.ReadAsync(buffer);
                if (len == 0)
                {
                    // ReadAsync returns 0 once the peer has closed the connection, even if Connected is still true.
                    log.System("Client disconnected");
                    break;
                }
                log.Encrypted("Request", buffer[..len]);
                if (len < MSG_LEN)
                {
                    log.System($"Request too short: {len} bytes");
                    await SendTextAsync(stream, sk, log, "Invalid Request", "INVALID REQUEST");
                    continue;
                }
                byte[] msg = sk.DecryptCfb(buffer[..MSG_LEN], sk.IV, PaddingMode.None);
                log.Decrypted("Request message", msg);
                log.Message(Request.RequestToString(msg));
                // Check the version and that the counter is the expected one, which rejects replayed or duplicate requests.
                if (msg[0] != 0 || msg[2] != counter)
                {
                    await SendTextAsync(stream, sk, log, "Invalid Request", "INVALID REQUEST");
                    continue;
                }
                counter = IncrementCounter(counter);
                switch ((RequestType)msg[1])
                {
                    case RequestType.GetMessages:
                        await SendPendingMessagesAsync(stream, sk, log, clientPhone);
                        break;
                    case RequestType.GetUserKey:
                        await SendUserKeyAsync(stream, sk, log, Utils.BytesToNumber(msg[3..11]));
                        break;
                    case RequestType.SendMessage:
                        StoreMessage(msg, buffer, len, log);
                        break;
                    default:
                        await SendTextAsync(stream, sk, log, "Invalid Request", "INVALID REQUEST");
                        break;
                }
            }
        }
        catch (Exception ex)
        {
            log.System($"Client failed with error {ex.Message}");
            log.System($"Stack: {ex.StackTrace}");
        }
    }

    /// <summary>
    /// Replies to GetMessages with:
    /// - a 16-byte AES-CFB-encrypted header holding up to 7 big-endian 16-bit message lengths (bytes 0-13)
    ///   and a "more pending" flag (byte 15);
    /// - followed by the message bytes exactly as they were stored.
    /// </summary>
    static async Task SendPendingMessagesAsync(NetworkStream stream, Aes sk, Log log, string clientPhone)
    {
        byte[] msgsLens = new byte[MSG_LEN];
        List<byte[]> msgs = Data.GetMessages(clientPhone, MAX_MESSAGES_PER_BATCH) ?? [];
        byte[] msgsBytes = new byte[msgs.Sum(m => m.Length)];
        int offset = 0;
        for (int i = 0; i < msgs.Count; i++)
        {
            // A 16-bit length leaves headroom for messages larger than the current key size produces.
            msgsLens[2 * i] = (byte)(msgs[i].Length >> 8);
            msgsLens[(2 * i) + 1] = (byte)msgs[i].Length;
            Array.Copy(msgs[i], 0, msgsBytes, offset, msgs[i].Length);
            offset += msgs[i].Length;
        }
        msgsLens[MSG_LEN - 1] = Data.PeekMessages(clientPhone) ? (byte)1 : (byte)0;
        // Only the length header is encrypted here; the message bodies are sent as stored.
        log.Decrypted("GetMessages", msgsLens);
        msgsLens = sk.EncryptCfb(msgsLens, sk.IV, PaddingMode.None);
        log.Encrypted("GetMessages", msgsLens);
        byte[] finalPayload = [.. msgsLens, .. msgsBytes];
        log.Encrypted("GetMessages Final", finalPayload);
        await stream.WriteAsync(finalPayload);
    }

    /// <summary>
    /// Replies to GetUserKey with the encrypted payload:
    /// - <c>0</c> followed by the user's DER-encoded RSA public key, when the user is known;
    /// - <c>1</c> followed by "USER DOES NOT EXIST", otherwise.
    /// </summary>
    static async Task SendUserKeyAsync(NetworkStream stream, Aes sk, Log log, string phone)
    {
        RSA? key = Data.GetKey(phone);
        byte[] reply;
        if (key != null)
        {
            reply = [0, .. key.ExportRSAPublicKey()];
        }
        else
        {
            reply = [1, .. Encoding.UTF8.GetBytes("USER DOES NOT EXIST")];
        }
        await SendEncryptedAsync(stream, sk, log, "GetUserKey", reply);
    }

    /// <summary>
    /// Handles SendMessage. The header carries:
    /// - the recipient phone in bytes 3-10;
    /// - the payload length in bytes 11-14, as an int in machine byte order.
    /// The payload follows the header and is queued for the recipient unchanged. No reply is sent.
    /// </summary>
    static void StoreMessage(byte[] msg, byte[] buffer, int len, Log log)
    {
        string recv = Utils.BytesToNumber(msg[3..11]);
        int msgLen = BitConverter.ToInt32(msg, 11);
        int available = len - MSG_LEN;
        if (msgLen != available)
        {
            log.System($"Got message to {recv} of length {available} but expected {msgLen}");
        }
        if (msgLen < 0 || msgLen > available)
        {
            // The declared length (client-controlled) runs past what was actually read. Slicing would either
            // throw or pick up stale bytes from an earlier request, so drop the message.
            log.System($"Dropping message to {recv}: declared length {msgLen} exceeds received payload {available}");
            return;
        }
        byte[] clientMsg = buffer[MSG_LEN..(MSG_LEN + msgLen)];
        log.Encrypted("SendMessage message", clientMsg);
        bool added = Data.AddMessage(recv, clientMsg);
        log.System($"Added message to {recv} of length {msgLen}: {added}");
    }

    /// <summary>
    /// Encrypts <paramref name="plaintext"/> with the session key (AES-CFB, PKCS7 padding), logs the plaintext
    /// and the ciphertext under <paramref name="label"/>, and writes the ciphertext to the stream.
    /// </summary>
    static async Task SendEncryptedAsync(NetworkStream stream, Aes sk, Log log, string label, byte[] plaintext)
    {
        log.Decrypted(label, plaintext);
        byte[] cipher = sk.EncryptCfb(plaintext, sk.IV, PaddingMode.PKCS7);
        log.Encrypted(label, cipher);
        await stream.WriteAsync(cipher);
    }

    /// <summary>Sends a UTF-8 status string through <see cref="SendEncryptedAsync"/>.</summary>
    static Task SendTextAsync(NetworkStream stream, Aes sk, Log log, string label, string text) =>
        SendEncryptedAsync(stream, sk, log, label, Encoding.UTF8.GetBytes(text));

    /// <summary>Returns the next request counter, wrapping from 255 back to 0.</summary>
    static byte IncrementCounter(byte counter) => unchecked((byte)(counter + 1));

    /// <summary>
    /// Delivers the registration code to the client.
    /// NOTE(review): despite the name, this writes the raw, unencrypted code to the same TCP stream. It is
    /// presumably a stand-in for a separate out-of-band channel; confirm before treating the code as a second factor.
    /// </summary>
    static async Task SendBySecureChannel(NetworkStream stream, byte[] code)
    {
        await stream.WriteAsync(code);
    }
}