2022-04-22 17:23:14 -07:00
|
|
|
package fediverse
|
|
|
|
|
|
|
|
import (
|
|
|
|
"crypto/rand"
|
2022-12-23 20:20:59 -08:00
|
|
|
"errors"
|
2022-04-22 17:23:14 -07:00
|
|
|
"io"
|
2022-10-02 14:16:46 -04:00
|
|
|
"strings"
|
2022-12-23 20:20:59 -08:00
|
|
|
"sync"
|
2022-04-22 17:23:14 -07:00
|
|
|
"time"
|
2022-12-23 20:20:59 -08:00
|
|
|
|
|
|
|
log "github.com/sirupsen/logrus"
|
2022-04-22 17:23:14 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
// OTPRegistration represents a single OTP request.
type OTPRegistration struct {
	// Timestamp records when the request was created. Requests older than
	// registrationTimeout are treated as expired and are eventually pruned.
	Timestamp time.Time

	// UserID and UserDisplayName identify the user who started the flow.
	UserID          string
	UserDisplayName string

	// Code is the generated one-time code the user must submit back for
	// validation.
	Code string

	// Account is stored lowercased. NOTE(review): presumably the fediverse
	// account the code is delivered to — confirm against callers.
	Account string
}
|
|
|
|
|
|
|
|
// Key by access token to limit one OTP request for a person
|
|
|
|
// to be active at a time.
|
2022-12-23 20:20:59 -08:00
|
|
|
var (
|
|
|
|
pendingAuthRequests = make(map[string]OTPRegistration)
|
|
|
|
lock = sync.Mutex{}
|
|
|
|
)
|
2022-04-22 17:23:14 -07:00
|
|
|
|
2022-12-23 20:20:59 -08:00
|
|
|
const (
	// registrationTimeout is how long an OTP request stays valid. It is also
	// the interval at which the background pruner runs.
	registrationTimeout = 10 * time.Minute

	// maxPendingRequests caps how many OTP requests may be tracked at once.
	maxPendingRequests = 1000
)
|
|
|
|
|
|
|
|
func init() {
|
|
|
|
go setupExpiredRequestPruner()
|
|
|
|
}
|
|
|
|
|
|
|
|
// Clear out any pending requests that have been pending for greater than
|
|
|
|
// the specified timeout value.
|
|
|
|
func setupExpiredRequestPruner() {
|
|
|
|
pruneExpiredRequestsTimer := time.NewTicker(registrationTimeout)
|
|
|
|
|
|
|
|
for range pruneExpiredRequestsTimer.C {
|
|
|
|
lock.Lock()
|
|
|
|
log.Debugln("Pruning expired OTP requests.")
|
|
|
|
for k, v := range pendingAuthRequests {
|
|
|
|
if time.Since(v.Timestamp) > registrationTimeout {
|
|
|
|
delete(pendingAuthRequests, k)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
lock.Unlock()
|
|
|
|
}
|
|
|
|
}
|
2022-08-02 13:29:06 -07:00
|
|
|
|
2022-04-22 17:23:14 -07:00
|
|
|
// RegisterFediverseOTP will start the OTP flow for a user, creating a new
|
|
|
|
// code and returning it to be sent to a destination.
|
2022-12-23 20:20:59 -08:00
|
|
|
func RegisterFediverseOTP(accessToken, userID, userDisplayName, account string) (OTPRegistration, bool, error) {
|
2022-08-02 13:29:06 -07:00
|
|
|
request, requestExists := pendingAuthRequests[accessToken]
|
|
|
|
|
|
|
|
// If a request is already registered and has not expired then return that
|
|
|
|
// existing request.
|
|
|
|
if requestExists && time.Since(request.Timestamp) < registrationTimeout {
|
2022-12-23 20:20:59 -08:00
|
|
|
return request, false, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
lock.Lock()
|
|
|
|
defer lock.Unlock()
|
|
|
|
|
|
|
|
if len(pendingAuthRequests)+1 > maxPendingRequests {
|
|
|
|
return request, false, errors.New("Please try again later. Too many pending requests.")
|
2022-08-02 13:29:06 -07:00
|
|
|
}
|
|
|
|
|
2022-04-22 17:23:14 -07:00
|
|
|
code, _ := createCode()
|
|
|
|
r := OTPRegistration{
|
|
|
|
Code: code,
|
|
|
|
UserID: userID,
|
|
|
|
UserDisplayName: userDisplayName,
|
2022-10-02 14:16:46 -04:00
|
|
|
Account: strings.ToLower(account),
|
2022-04-22 17:23:14 -07:00
|
|
|
Timestamp: time.Now(),
|
|
|
|
}
|
|
|
|
pendingAuthRequests[accessToken] = r
|
|
|
|
|
2022-12-23 20:20:59 -08:00
|
|
|
return r, true, nil
|
2022-04-22 17:23:14 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
// ValidateFediverseOTP will verify a OTP code for a auth request.
|
|
|
|
func ValidateFediverseOTP(accessToken, code string) (bool, *OTPRegistration) {
|
|
|
|
request, ok := pendingAuthRequests[accessToken]
|
|
|
|
|
2022-08-02 13:29:06 -07:00
|
|
|
if !ok || request.Code != code || time.Since(request.Timestamp) > registrationTimeout {
|
2022-04-22 17:23:14 -07:00
|
|
|
return false, nil
|
|
|
|
}
|
|
|
|
|
2022-12-23 20:20:59 -08:00
|
|
|
lock.Lock()
|
|
|
|
defer lock.Unlock()
|
|
|
|
|
2022-04-22 17:23:14 -07:00
|
|
|
delete(pendingAuthRequests, accessToken)
|
|
|
|
return true, &request
|
|
|
|
}
|
|
|
|
|
|
|
|
// createCode returns a random 6-digit numeric code read from crypto/rand.
//
// Each digit is drawn uniformly. Random bytes at or above 250 (the largest
// multiple of 10 that fits in a byte) are discarded rather than reduced
// modulo 10. Reducing them would make some digits slightly more likely than
// others.
//
// An error is returned only if reading from the system random source fails.
func createCode() (string, error) {
	const (
		digits   = 6
		alphabet = "1234567890"
		// limit is the largest multiple of len(alphabet) that is <= 256.
		limit = 256 - 256%len(alphabet)
	)

	code := make([]byte, 0, digits)
	buf := make([]byte, digits)
	for len(code) < digits {
		if _, err := io.ReadFull(rand.Reader, buf); err != nil {
			return "", err
		}
		for _, b := range buf {
			if int(b) >= limit {
				continue // rejection sampling: avoids modulo bias
			}
			code = append(code, alphabet[int(b)%len(alphabet)])
			if len(code) == digits {
				break
			}
		}
	}
	return string(code), nil
}
|