From 0b5ddf433b1f36ac90c3635eda237c3087c7db0a Mon Sep 17 00:00:00 2001 From: Gabe Kangas Date: Tue, 2 Aug 2022 13:29:06 -0700 Subject: [PATCH] Limit OTP requests to one per expiry window. Closes #2000 --- auth/fediverse/fediverse.go | 16 +++++++++++++--- auth/fediverse/fediverse_test.go | 19 ++++++++++++++++++- controllers/auth/fediverse/fediverse.go | 7 ++++++- 3 files changed, 37 insertions(+), 5 deletions(-) diff --git a/auth/fediverse/fediverse.go b/auth/fediverse/fediverse.go index 8f00ee120..404af3916 100644 --- a/auth/fediverse/fediverse.go +++ b/auth/fediverse/fediverse.go @@ -19,9 +19,19 @@ type OTPRegistration struct { // to be active at a time. var pendingAuthRequests = make(map[string]OTPRegistration) +const registrationTimeout = time.Minute * 10 + // RegisterFediverseOTP will start the OTP flow for a user, creating a new // code and returning it to be sent to a destination. -func RegisterFediverseOTP(accessToken, userID, userDisplayName, account string) OTPRegistration { +func RegisterFediverseOTP(accessToken, userID, userDisplayName, account string) (OTPRegistration, bool) { + request, requestExists := pendingAuthRequests[accessToken] + + // If a request is already registered and has not expired then return that + // existing request. + if requestExists && time.Since(request.Timestamp) < registrationTimeout { + return request, false + } + code, _ := createCode() r := OTPRegistration{ Code: code, @@ -32,14 +42,14 @@ func RegisterFediverseOTP(accessToken, userID, userDisplayName, account string) } pendingAuthRequests[accessToken] = r - return r + return r, true } // ValidateFediverseOTP will verify a OTP code for a auth request. func ValidateFediverseOTP(accessToken, code string) (bool, *OTPRegistration) { request, ok := pendingAuthRequests[accessToken] - if !ok || request.Code != code || time.Since(request.Timestamp) > time.Minute*10 { + if !ok || request.Code != code || time.Since(request.Timestamp) > registrationTimeout { return false, nil } diff --git a/auth/fediverse/fediverse_test.go b/auth/fediverse/fediverse_test.go index 8c1d58f66..912736d05 100644 --- a/auth/fediverse/fediverse_test.go +++ b/auth/fediverse/fediverse_test.go @@ -10,7 +10,11 @@ const ( ) func TestOTPFlowValidation(t *testing.T) { - r := RegisterFediverseOTP(accessToken, userID, userDisplayName, account) + r, success := RegisterFediverseOTP(accessToken, userID, userDisplayName, account) + + if !success { + t.Error("Registration should be permitted.") + } if r.Code == "" { t.Error("Code is empty") @@ -41,3 +45,16 @@ func TestOTPFlowValidation(t *testing.T) { t.Error("UserDisplayName is not set correctly") } } + +func TestSingleOTPFlowRequest(t *testing.T) { + r1, _ := RegisterFediverseOTP(accessToken, userID, userDisplayName, account) + r2, s2 := RegisterFediverseOTP(accessToken, userID, userDisplayName, account) + + if r1.Code != r2.Code { + t.Error("Only one registration should be permitted.") + } + + if s2 { + t.Error("Second registration should not be permitted.") + } +} diff --git a/controllers/auth/fediverse/fediverse.go b/controllers/auth/fediverse/fediverse.go index e335a5a81..6192e712e 100644 --- a/controllers/auth/fediverse/fediverse.go +++ b/controllers/auth/fediverse/fediverse.go @@ -29,7 +29,12 @@ func RegisterFediverseOTPRequest(u user.User, w http.ResponseWriter, r *http.Req } accessToken := r.URL.Query().Get("accessToken") - reg := fediverseauth.RegisterFediverseOTP(accessToken, u.ID, u.DisplayName, req.FediverseAccount) + reg, success := fediverseauth.RegisterFediverseOTP(accessToken, u.ID, u.DisplayName, req.FediverseAccount) + if !success { + controllers.WriteSimpleResponse(w, false, "Could not register auth request. One may already be pending. Try again later.") + return + } + msg := fmt.Sprintf("

This is an automated message from %s. If you did not request this message please ignore or block. Your requested one-time code is:

%s

", data.GetServerName(), reg.Code) if err := activitypub.SendDirectFederatedMessage(msg, reg.Account); err != nil { controllers.WriteSimpleResponse(w, false, "Could not send code to fediverse: "+err.Error())