diff --git a/LarpixServer/Account/Requests.cs b/LarpixServer/Account/Requests.cs index 0cc3882..de9be00 100644 --- a/LarpixServer/Account/Requests.cs +++ b/LarpixServer/Account/Requests.cs @@ -65,119 +65,120 @@ public class Requests return; } - await createLock.WaitAsync(); - try + + switch (step) { - switch (step) + case "init": { - case "init": + foreach (var kvp in createHolder) // czyszczenie nieaktywnych od 2 minut requestow { - foreach (var kvp in createHolder) // czyszczenie nieaktywnych od 2 minut requestow + if (kvp.Value.date < DateTimeOffset.UtcNow.AddMinutes(-2)) { - if (kvp.Value.date < DateTimeOffset.UtcNow.AddMinutes(-2)) - { - createHolder.TryRemove(kvp.Key, out _); - } + createHolder.TryRemove(kvp.Key, out _); } + } - context.Response.ContentType = mimeTypes["json"]; + context.Response.ContentType = mimeTypes["json"]; - var serverInfo = Encryption.Encryption.InitHybridKEM(); - KeyExchangePayload payload = new KeyExchangePayload(); - payload.pubX25519 = Convert.ToBase64String(serverInfo.pubX25519); - payload.pubMlKem = Convert.ToBase64String(serverInfo.pubMlKem); - payload.idKey = DateTimeOffset.UtcNow.ToUnixTimeSeconds() + "\n"; + var serverInfo = Encryption.Encryption.InitHybridKEM(); + KeyExchangePayload payload = new KeyExchangePayload(); + payload.pubX25519 = Convert.ToBase64String(serverInfo.pubX25519); + payload.pubMlKem = Convert.ToBase64String(serverInfo.pubMlKem); + payload.idKey = DateTimeOffset.UtcNow.ToUnixTimeSeconds() + "\n"; - while (createHolder.ContainsKey(payload.idKey)) - { - payload.idKey += Random.Shared.Next(0, 10).ToString(); - } + while (createHolder.ContainsKey(payload.idKey)) + { + payload.idKey += Random.Shared.Next(0, 10).ToString(); + } - CreateHolder dataHolder = new(); - dataHolder.date = DateTimeOffset.UtcNow; - dataHolder.privX25519Server = serverInfo.privX25519; - dataHolder.privMlKemServer = serverInfo.privMlKem; + CreateHolder dataHolder = new(); + dataHolder.date = DateTimeOffset.UtcNow; + dataHolder.privX25519Server = serverInfo.privX25519; + dataHolder.privMlKemServer = serverInfo.privMlKem; - createHolder.TryAdd(payload.idKey, dataHolder); + createHolder.TryAdd(payload.idKey, dataHolder); - var serializedPayload = JsonSerializer.Serialize( - payload, - AppJsonSerializerContext.Default.KeyExchangePayload - ); + var serializedPayload = JsonSerializer.Serialize( + payload, + AppJsonSerializerContext.Default.KeyExchangePayload + ); - await context.Response.WriteAsync(serializedPayload); + await context.Response.WriteAsync(serializedPayload); + return; + } + case "register": + { + string body = await LoadBody(bodyReader); + + KeyExchangePayloadClient serializedBody = JsonSerializer.Deserialize( + body, + AppJsonSerializerContext.Default.KeyExchangePayloadClient + ); + if (!createHolder.TryGetValue(serializedBody.idKey, out CreateHolder entry)) + { + await context.Response.WriteAsync("error:account.creation.request.expired"); return; } - case "register": + + entry.date = DateTimeOffset.UtcNow; + + byte[] pubX25519Client = Convert.FromBase64String(serializedBody.pubX25519); + byte[] ciphertextMlKem = Convert.FromBase64String(serializedBody.ciphertextMlKem); + + byte[] sharedKey = Encryption.Encryption.CalcHybridSharedKey( + pubX25519Client, entry.privX25519Server, ciphertextMlKem, entry.privMlKemServer); + + + entry.name = Encryption.Encryption.Decrypt(serializedBody.username, sharedKey); + entry.pass = Encryption.Encryption.Decrypt(serializedBody.password, sharedKey); + + if (!Utils.IsValidUsername(entry.name, out string message)) { - string body = await LoadBody(bodyReader); - - KeyExchangePayloadClient serializedBody = JsonSerializer.Deserialize( - body, - AppJsonSerializerContext.Default.KeyExchangePayloadClient - ); - if (!createHolder.TryGetValue(serializedBody.idKey, out CreateHolder entry)) - { - await context.Response.WriteAsync("error:account.creation.request.expired"); - return; - } - - entry.date = DateTimeOffset.UtcNow; - - byte[] pubX25519Client = Convert.FromBase64String(serializedBody.pubX25519); - byte[] ciphertextMlKem = Convert.FromBase64String(serializedBody.ciphertextMlKem); - - byte[] sharedKey = Encryption.Encryption.CalcHybridSharedKey( - pubX25519Client, entry.privX25519Server, ciphertextMlKem, entry.privMlKemServer); - - - entry.name = Encryption.Encryption.Decrypt(serializedBody.username, sharedKey); - entry.pass = Encryption.Encryption.Decrypt(serializedBody.password, sharedKey); - - if (!Utils.IsValidUsername(entry.name, out string message)) - { - await context.Response.WriteAsync(message); - return; - } - - if (!Utils.IsValidPassword(entry.pass, out message)) - { - await context.Response.WriteAsync(message); - return; - } - - (byte[] ImageBytes, string CaptchaText) captchaResult = Captcha.GenerateCaptcha(); - entry.captcha = captchaResult.CaptchaText; - context.Response.ContentType = mimeTypes["webp"]; - context.Response.ContentLength = captchaResult.ImageBytes.Length; - - await context.Response.Body.WriteAsync(captchaResult.ImageBytes, 0, - captchaResult.ImageBytes.Length); + await context.Response.WriteAsync(message); return; } - case "finish": + + if (!Utils.IsValidPassword(entry.pass, out message)) { - string body = await LoadBody(bodyReader); - - CaptchaPayloadClient serialized = JsonSerializer.Deserialize( - body, - AppJsonSerializerContext.Default.CaptchaPayloadClient - ); - if (!createHolder.TryGetValue(serialized.idKey, out var entry)) - { - await context.Response.WriteAsync("error:account.creation.request.expired"); - return; - } + await context.Response.WriteAsync(message); + return; + } - if (entry.captcha.ToLower() != serialized.captcha.ToLower()) - { - createHolder.TryRemove(serialized.idKey, out _); - await context.Response.WriteAsync("error:incorrect.captcha"); - return; - } + (byte[] ImageBytes, string CaptchaText) captchaResult = Captcha.GenerateCaptcha(); + entry.captcha = captchaResult.CaptchaText; + context.Response.ContentType = mimeTypes["webp"]; + context.Response.ContentLength = captchaResult.ImageBytes.Length; - string lowerName = entry.name.ToLowerInvariant(); + await context.Response.Body.WriteAsync(captchaResult.ImageBytes, 0, + captchaResult.ImageBytes.Length); + return; + } + case "finish": + { + string body = await LoadBody(bodyReader); + CaptchaPayloadClient serialized = JsonSerializer.Deserialize( + body, + AppJsonSerializerContext.Default.CaptchaPayloadClient + ); + if (!createHolder.TryGetValue(serialized.idKey, out var entry)) + { + await context.Response.WriteAsync("error:account.creation.request.expired"); + return; + } + + if (entry.captcha.ToLower() != serialized.captcha.ToLower()) + { + createHolder.TryRemove(serialized.idKey, out _); + await context.Response.WriteAsync("error:incorrect.captcha"); + return; + } + + string lowerName = entry.name.ToLowerInvariant(); + + await createLock.WaitAsync(); + try + { if (Fs.Exists($"{ACCOUNTS_NAME_DIR}/{lowerName}")) { await context.Response.WriteAsync("error:username.taken"); @@ -229,7 +230,6 @@ public class Requests await Fs.WriteFile($"{ACCOUNTS_DIR}/registration", Encoding.UTF8.GetBytes("0;")); } - ulong id = ulong.Parse(await Fs.ReadFile($"{ACCOUNTS_DIR}/last")); id++; var freeid = Path.GetFileName(Directory.EnumerateFiles(ACCOUNTS_FREEID_DIR).FirstOrDefault()); @@ -270,17 +270,19 @@ public class Requests await context.Response.WriteAsync("success:account.created"); return; } + finally + { + createLock.Release(); + } } + } - await next(); - } - finally - { - createLock.Release(); - } + await next(); } - public static async Task Auth(HttpContext context, Func next, IQueryCollection query, StreamReader bodyReader) + + +public static async Task Auth(HttpContext context, Func next, IQueryCollection query, StreamReader bodyReader) { if (!query.TryGetValue("id", out var idQuery)) { diff --git a/LarpixServer/Account/Utils.cs b/LarpixServer/Account/Utils.cs index e78d180..568c514 100644 --- a/LarpixServer/Account/Utils.cs +++ b/LarpixServer/Account/Utils.cs @@ -3,6 +3,7 @@ using System.Numerics; using System.Text; using System.Text.Json; using System.Text.RegularExpressions; +using System.Threading; using LarpixServer.Filesystem; using LarpixServer.Utils.Jsons; using static LarpixServer.Utils.Utils; @@ -15,46 +16,20 @@ public class Utils public static string LOGIN_SUCCESS = "success:login.successful"; - public static ConcurrentDictionary userLocks = new(); - public static ConcurrentQueue keyQueue = new(); + private static SemaphoreSlim[]? _userLocksArray = null; public static SemaphoreSlim GetUserLock(string id) { - while (userLocks.Count >= LOCK_SIZE) + if (_userLocksArray == null) { - if (!keyQueue.TryDequeue(out var firstKey)) break; - - if (userLocks.TryGetValue(firstKey, out var sem)) - { - if (sem.Wait(0)) - { - try - { - userLocks.TryRemove(firstKey, out _); - } - finally { sem.Release(); } - } - else - { - keyQueue.Enqueue(firstKey); - break; - } - } + int size = LOCK_SIZE > 0 ? LOCK_SIZE : 65536; + var newArray = Enumerable.Range(0, size).Select(_ => new SemaphoreSlim(1, 1)).ToArray(); + Interlocked.CompareExchange(ref _userLocksArray, newArray, null); } - - if (!userLocks.TryGetValue(id, out var semLock)) - { - semLock = new SemaphoreSlim(1, 1); - if (userLocks.TryAdd(id, semLock)) - { - keyQueue.Enqueue(id); - } - else - { - semLock = userLocks[id]; - } - } - return semLock; + + int hash = id.GetHashCode(); + if (hash < 0) hash = -hash; // Or use Math.Abs, but hash < 0 logic avoids OverflowException on int.MinValue + return _userLocksArray[hash % _userLocksArray.Length]; } public static string GetIdFromUsernameWD(string usernameWD) @@ -174,16 +149,23 @@ public class Utils public static async Task NonceDecryptBody(string id, string password, string body, bool delEntry = true) { - if (!Requests.nonceHolder.TryGetValue(id, out (string, DateTimeOffset) nonce)) - { - return "error:invalid.nonce"; - } - string decBody = Encryption.Encryption.PacketDecPass(body, password, nonce.Item1); + (string, DateTimeOffset) nonce; if (delEntry) { - Requests.nonceHolder.TryRemove(id, out _); + if (!Requests.nonceHolder.TryRemove(id, out nonce)) + { + return "error:invalid.nonce"; + } + } + else + { + if (!Requests.nonceHolder.TryGetValue(id, out nonce)) + { + return "error:invalid.nonce"; + } } + string decBody = Encryption.Encryption.PacketDecPass(body, password, nonce.Item1); return decBody; } diff --git a/LarpixServer/Filesystem/Fs.cs b/LarpixServer/Filesystem/Fs.cs index d6b7ad8..d1d3260 100644 --- a/LarpixServer/Filesystem/Fs.cs +++ b/LarpixServer/Filesystem/Fs.cs @@ -1,4 +1,5 @@ using System.Collections.Concurrent; +using System.Threading; using static LarpixServer.Utils.Utils; namespace LarpixServer.Filesystem; @@ -9,8 +10,7 @@ public class Fs public static ConcurrentDictionary existCache = new ConcurrentDictionary(); public static ConcurrentDictionary dirCache = new ConcurrentDictionary(); - private static ConcurrentDictionary fileLocks = new(); - + private static SemaphoreSlim[]? _fileLocksArray = null; private static void InvalidateDirCacheFor(string path) { @@ -23,7 +23,16 @@ public class Fs private static SemaphoreSlim GetFileLock(string path) { - return fileLocks.GetOrAdd(path, _ => new SemaphoreSlim(1, 1)); + if (_fileLocksArray == null) + { + int size = LOCK_SIZE > 0 ? LOCK_SIZE : 65536; + var newArray = Enumerable.Range(0, size).Select(_ => new SemaphoreSlim(1, 1)).ToArray(); + Interlocked.CompareExchange(ref _fileLocksArray, newArray, null); + } + + int hash = path.GetHashCode(); + if (hash < 0) hash = -hash; // Or use Math.Abs, but hash < 0 logic avoids OverflowException on int.MinValue + return _fileLocksArray[hash % _fileLocksArray.Length]; } public static void ProcessCacheSpace() @@ -46,12 +55,7 @@ public class Fs if (firstKey != null) dirCache.TryRemove(firstKey, out _); } - while (fileLocks.Count >= LOCK_SIZE) - { - var firstKey = fileLocks.Keys.FirstOrDefault(); - if (firstKey != null) fileLocks.TryRemove(firstKey, out _); } - } public static ulong ClearCache(string pattern) {