using System.Security.Cryptography;
using System.Text;
using System.Text.RegularExpressions;
using VPNAuth.Server.Database;
using VPNAuth.Server.Responses;

namespace VPNAuth.Server.Api;

/// <summary>
/// HTTP handlers for the OAuth 2.0 authorization-code flow with PKCE (S256).
/// </summary>
public static class OAuth2
{
    /// <summary>
    /// Marks the pending authorization request <paramref name="id"/> as accepted by the signed-in user and
    /// redirects the browser to the client's registered redirect URI with the authorization code and state.
    /// </summary>
    /// <remarks>
    /// Responds 404 when the request does not exist, was already accepted, or its client is no longer configured;
    /// responds 403 when the request belongs to a different user.
    /// </remarks>
    public static async Task AcceptAuthHandler(HttpContext context, int id)
    {
        using var db = new Database.Database();
        var authRequest = await db.AuthRequests.FindAsync(id);
        if (authRequest == null || authRequest.Accepted)
        {
            context.Response.StatusCode = StatusCodes.Status404NotFound;
            return;
        }

        if (authRequest.Username != context.GetUser()?.Username)
        {
            context.Response.StatusCode = StatusCodes.Status403Forbidden;
            return;
        }

        // Resolve the redirect target before persisting the acceptance, so a client that was removed from the
        // config yields a clean 404 instead of a NullReferenceException after the request was already accepted.
        var redirectUri = Config.Read().FindApp(authRequest.ClientId)?.RedirectUri;
        if (redirectUri == null)
        {
            context.Response.StatusCode = StatusCodes.Status404NotFound;
            return;
        }

        authRequest.Accepted = true;
        await db.SaveChangesAsync();

        context.Response.StatusCode = StatusCodes.Status302Found;
        context.Response.Headers["Location"] = BuildRedirectUri(redirectUri, authRequest.Code, authRequest.State);
    }

    /// <summary>
    /// Appends the authorization response parameters to <paramref name="redirectUri"/>.
    /// </summary>
    /// <remarks>
    /// Values are percent-encoded because the client-supplied state may contain reserved characters
    /// such as '&amp;', '=' or '#'. A query string already present on the registered redirect URI is preserved,
    /// as RFC 6749 §3.1.2 allows. A null state is omitted.
    /// </remarks>
    private static string BuildRedirectUri(string redirectUri, string code, string? state)
    {
        var builder = new StringBuilder(redirectUri)
            .Append(redirectUri.Contains('?') ? '&' : '?')
            .Append("code=").Append(Uri.EscapeDataString(code));

        if (state != null)
        {
            builder.Append("&state=").Append(Uri.EscapeDataString(state));
        }

        return builder.ToString();
    }

    /// <summary>
    /// Computes the PKCE S256 code challenge for <paramref name="codeVerifier"/>:
    /// BASE64URL(SHA256(ASCII(code_verifier))) without padding (RFC 7636 §4.2).
    /// </summary>
    private static string HashCodeVerifier(string codeVerifier)
    {
        var digest = SHA256.HashData(Encoding.ASCII.GetBytes(codeVerifier));
        // Convert standard Base64 to the unpadded base64url alphabet.
        return Convert.ToBase64String(digest)
            .TrimEnd('=')
            .Replace('+', '-')
            .Replace('/', '_');
    }

    /// <summary>
    /// Token endpoint: exchanges an accepted authorization code plus its PKCE code verifier for a bearer access token.
    /// </summary>
    /// <remarks>
    /// Expects a form-encoded body with grant_type=authorization_code, client_id, code, code_verifier and, for
    /// clients configured with a secret, client_secret.
    /// <list type="bullet">
    /// <item>Responds 400 for a malformed request: not a form, wrong grant type, or a missing or repeated parameter.</item>
    /// <item>Responds 403 when client authentication or grant validation fails.</item>
    /// <item>Responds 404 when the code is unknown.</item>
    /// </list>
    /// A code can be redeemed only once.
    /// </remarks>
    public static async Task AccessTokenHandler(HttpContext context)
    {
        // Reading Request.Form on a non-form body throws; report that as a client error instead of a 500.
        if (!context.Request.HasFormContentType)
        {
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
            return;
        }

        var form = await context.Request.ReadFormAsync();

        if (SingleFormValue(form, "grant_type") != "authorization_code")
        {
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
            return;
        }

        var clientId = SingleFormValue(form, "client_id");
        if (clientId == null)
        {
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
            return;
        }

        var app = Config.Read().FindApp(clientId);
        if (app == null)
        {
            // Unknown client: fail client authentication rather than dereferencing null.
            context.Response.StatusCode = StatusCodes.Status403Forbidden;
            return;
        }

        // Clients without a configured secret (public clients) rely on PKCE alone.
        if (app.Secret != null && !SecretMatches(app.Secret, SingleFormValue(form, "client_secret")))
        {
            context.Response.StatusCode = StatusCodes.Status403Forbidden;
            return;
        }

        var code = SingleFormValue(form, "code");
        if (code == null)
        {
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
            return;
        }

        using var db = new Database.Database();
        var authRequest = db.AuthRequests.FirstOrDefault(request => request.Code == code);
        if (authRequest == null)
        {
            context.Response.StatusCode = StatusCodes.Status404NotFound;
            return;
        }

        // The code must have been issued to the authenticating client (RFC 6749 §4.1.3)
        // and approved by the user via AcceptAuthHandler.
        if (authRequest.ClientId != clientId || !authRequest.Accepted)
        {
            context.Response.StatusCode = StatusCodes.Status403Forbidden;
            return;
        }

        var codeVerifier = SingleFormValue(form, "code_verifier");
        if (codeVerifier == null)
        {
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
            return;
        }

        if (HashCodeVerifier(codeVerifier) != authRequest.CodeChallenge)
        {
            context.Response.StatusCode = StatusCodes.Status403Forbidden;
            return;
        }

        var accessTokenEntry = db.AccessTokens.Add(new AccessToken
        {
            ClientId = authRequest.ClientId,
            Scopes = authRequest.Scopes,
            // NOTE(review): local time, kept as-is; switching to DateTime.UtcNow requires checking every reader of CreationTime.
            CreationTime = DateTime.Now,
            Token = PkceUtils.GenerateToken(),
            Username = authRequest.Username
        });

        // Authorization codes are single-use (RFC 6749 §4.1.2). The request is removed in the same SaveChanges
        // that stores the token, so a concurrent second redemption cannot also persist a token.
        db.AuthRequests.Remove(authRequest);
        await db.SaveChangesAsync();

        // Token responses must not be cached (RFC 6749 §5.1).
        context.Response.Headers["Cache-Control"] = "no-store";
        context.Response.Headers["Pragma"] = "no-cache";
        await context.Response.WriteAsJsonAsync(new Token
        {
            AccessToken = accessTokenEntry.Entity.Token,
            TokenType = "Bearer",
            Expires = 0 // TODO: change to actual value
        });
    }

    /// <summary>
    /// Returns the value of form field <paramref name="key"/>, or null when it is absent or sent more than once.
    /// RFC 6749 §3.2 forbids repeating request parameters.
    /// </summary>
    private static string? SingleFormValue(IFormCollection form, string key) =>
        form.TryGetValue(key, out var values) && values.Count == 1 ? values[0] : null;

    /// <summary>
    /// Compares a presented client secret with the configured one in constant time.
    /// Only the length can leak through timing. A missing secret never matches.
    /// </summary>
    private static bool SecretMatches(string expected, string? presented) =>
        presented != null &&
        CryptographicOperations.FixedTimeEquals(Encoding.UTF8.GetBytes(expected), Encoding.UTF8.GetBytes(presented));
}