diff --git a/core/data/data.go b/core/data/data.go index be2410491..896aa35e5 100644 --- a/core/data/data.go +++ b/core/data/data.go @@ -75,7 +75,7 @@ func SetupPersistence(file string) error { _, _ = db.Exec("pragma temp_store = memory") _, _ = db.Exec("pragma wal_checkpoint(full)") - createWebhooksTable() + tables.CreateWebhooksTable(db) tables.CreateUsersTable(db) tables.CreateAccessTokenTable(db) diff --git a/core/webhooks/webhooks.go b/core/webhooks/webhooks.go index c3bcb39e2..74478c4d1 100644 --- a/core/webhooks/webhooks.go +++ b/core/webhooks/webhooks.go @@ -4,8 +4,8 @@ import ( "sync" "time" - "github.com/owncast/owncast/core/data" "github.com/owncast/owncast/models" + "github.com/owncast/owncast/persistence/webhookrepository" ) // WebhookEvent represents an event sent as a webhook. @@ -31,7 +31,8 @@ func SendEventToWebhooks(payload WebhookEvent) { } func sendEventToWebhooks(payload WebhookEvent, wg *sync.WaitGroup) { - webhooks := data.GetWebhooksForEvent(payload.Type) + webhooksRepo := webhookrepository.Get() + webhooks := webhooksRepo.GetWebhooksForEvent(payload.Type) for _, webhook := range webhooks { // Use wg to track the number of notifications to be sent. diff --git a/core/webhooks/webhooks_test.go b/core/webhooks/webhooks_test.go index e10b02ee8..1b035471c 100644 --- a/core/webhooks/webhooks_test.go +++ b/core/webhooks/webhooks_test.go @@ -15,6 +15,7 @@ import ( "github.com/owncast/owncast/core/chat/events" "github.com/owncast/owncast/core/data" "github.com/owncast/owncast/models" + "github.com/owncast/owncast/persistence/webhookrepository" jsonpatch "gopkg.in/evanphx/json-patch.v5" ) @@ -62,12 +63,14 @@ func TestPublicSend(t *testing.T) { })) defer svr.Close() - hook, err := data.InsertWebhook(svr.URL, []models.EventType{models.MessageSent}) + webhooksRepo := webhookrepository.Get() + + hook, err := webhooksRepo.InsertWebhook(svr.URL, []models.EventType{models.MessageSent}) if err != nil { t.Fatal(err) } defer func() { - if err := data.DeleteWebhook(hook); err != nil { + if err := webhooksRepo.DeleteWebhook(hook); err != nil { t.Error(err) } }() @@ -107,13 +110,15 @@ func TestRouting(t *testing.T) { })) defer svr.Close() + webhooksRepo := webhookrepository.Get() + for _, eventType := range eventTypes { - hook, err := data.InsertWebhook(svr.URL+"/"+eventType, []models.EventType{eventType}) + hook, err := webhooksRepo.InsertWebhook(svr.URL+"/"+eventType, []models.EventType{eventType}) if err != nil { t.Fatal(err) } defer func() { - if err := data.DeleteWebhook(hook); err != nil { + if err := webhooksRepo.DeleteWebhook(hook); err != nil { t.Error(err) } }() @@ -148,13 +153,15 @@ func TestMultiple(t *testing.T) { })) defer svr.Close() + webhooksRepo := webhookrepository.Get() + for i := 0; i < times; i++ { - hook, err := data.InsertWebhook(fmt.Sprintf("%v/%v", svr.URL, i), []models.EventType{models.MessageSent}) + hook, err := webhooksRepo.InsertWebhook(fmt.Sprintf("%v/%v", svr.URL, i), []models.EventType{models.MessageSent}) if err != nil { t.Fatal(err) } defer func() { - if err := data.DeleteWebhook(hook); err != nil { + if err := webhooksRepo.DeleteWebhook(hook); err != nil { t.Error(err) } }() @@ -186,14 +193,16 @@ func TestTimestamps(t *testing.T) { })) defer svr.Close() + webhooksRepo := webhookrepository.Get() + for i, eventType := range eventTypes { - hook, err := data.InsertWebhook(svr.URL+"/"+eventType, []models.EventType{eventType}) + hook, err := webhooksRepo.InsertWebhook(svr.URL+"/"+eventType, []models.EventType{eventType}) if err != nil { t.Fatal(err) } handlerIds[i] = hook defer func() { - if err := data.DeleteWebhook(hook); err != nil { + if err := webhooksRepo.DeleteWebhook(hook); err != nil { t.Error(err) } }() @@ -209,7 +218,7 @@ func TestTimestamps(t *testing.T) { wg.Wait() - hooks, err := data.GetWebhooks() + hooks, err := webhooksRepo.GetWebhooks() if err != nil { t.Fatal(err) } @@ -285,12 +294,14 @@ func TestParallel(t *testing.T) { })) defer svr.Close() - hook, err := data.InsertWebhook(svr.URL, []models.EventType{models.MessageSent}) + webhooksRepo := webhookrepository.Get() + + hook, err := webhooksRepo.InsertWebhook(svr.URL, []models.EventType{models.MessageSent}) if err != nil { t.Fatal(err) } defer func() { - if err := data.DeleteWebhook(hook); err != nil { + if err := webhooksRepo.DeleteWebhook(hook); err != nil { t.Error(err) } }() @@ -320,13 +331,15 @@ func checkPayload(t *testing.T, eventType models.EventType, send func(), expecte })) defer svr.Close() + webhooksRepo := webhookrepository.Get() + // Subscribe to the webhook. - hook, err := data.InsertWebhook(svr.URL, []models.EventType{eventType}) + hook, err := webhooksRepo.InsertWebhook(svr.URL, []models.EventType{eventType}) if err != nil { t.Fatal(err) } defer func() { - if err := data.DeleteWebhook(hook); err != nil { + if err := webhooksRepo.DeleteWebhook(hook); err != nil { t.Error(err) } }() diff --git a/core/webhooks/workerpool.go b/core/webhooks/workerpool.go index 2134ef1d2..69f2b4fa5 100644 --- a/core/webhooks/workerpool.go +++ b/core/webhooks/workerpool.go @@ -9,8 +9,8 @@ import ( log "github.com/sirupsen/logrus" - "github.com/owncast/owncast/core/data" "github.com/owncast/owncast/models" + "github.com/owncast/owncast/persistence/webhookrepository" ) // webhookWorkerPoolSize defines the number of concurrent HTTP webhook requests. @@ -87,7 +87,8 @@ func sendWebhook(job Job) error { defer resp.Body.Close() - if err := data.SetWebhookAsUsed(job.webhook); err != nil { + webhooksRepo := webhookrepository.Get() + if err := webhooksRepo.SetWebhookAsUsed(job.webhook); err != nil { log.Warnln(err) } diff --git a/persistence/tables/webhooks.go b/persistence/tables/webhooks.go new file mode 100644 index 000000000..2bfb78a41 --- /dev/null +++ b/persistence/tables/webhooks.go @@ -0,0 +1,28 @@ +package tables + +import ( + "database/sql" + + log "github.com/sirupsen/logrus" +) + +func CreateWebhooksTable(db *sql.DB) { + log.Traceln("Creating webhooks table...") + + createTableSQL := `CREATE TABLE IF NOT EXISTS webhooks ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "url" string NOT NULL, + "events" TEXT NOT NULL, + "timestamp" DATETIME DEFAULT CURRENT_TIMESTAMP, + "last_used" DATETIME + );` + + stmt, err := db.Prepare(createTableSQL) + if err != nil { + log.Fatal(err) + } + defer stmt.Close() + if _, err = stmt.Exec(); err != nil { + log.Warnln(err) + } +} diff --git a/core/data/webhooks.go b/persistence/webhookrepository/webhookrepository.go similarity index 69% rename from core/data/webhooks.go rename to persistence/webhookrepository/webhookrepository.go index 2ae09885b..4a36b8299 100644 --- a/core/data/webhooks.go +++ b/persistence/webhookrepository/webhookrepository.go @@ -1,4 +1,4 @@ -package data +package webhookrepository import ( "errors" @@ -6,38 +6,51 @@ import ( "strings" "time" + "github.com/owncast/owncast/core/data" "github.com/owncast/owncast/models" log "github.com/sirupsen/logrus" ) -func createWebhooksTable() { - log.Traceln("Creating webhooks table...") +type WebhookRepository interface { + InsertWebhook(url string, events []models.EventType) (int, error) + DeleteWebhook(id int) error + GetWebhooksForEvent(event models.EventType) []models.Webhook + GetWebhooks() ([]models.Webhook, error) + SetWebhookAsUsed(webhook models.Webhook) error +} - createTableSQL := `CREATE TABLE IF NOT EXISTS webhooks ( - "id" INTEGER PRIMARY KEY AUTOINCREMENT, - "url" string NOT NULL, - "events" TEXT NOT NULL, - "timestamp" DATETIME DEFAULT CURRENT_TIMESTAMP, - "last_used" DATETIME - );` +type SqlWebhookRepository struct { + datastore *data.Datastore +} - stmt, err := _db.Prepare(createTableSQL) - if err != nil { - log.Fatal(err) +// NOTE: This is temporary during the transition period. +var temporaryGlobalInstance WebhookRepository + +// Get will return the user repository. +func Get() WebhookRepository { + if temporaryGlobalInstance == nil { + i := New(data.GetDatastore()) + temporaryGlobalInstance = i } - defer stmt.Close() - if _, err = stmt.Exec(); err != nil { - log.Warnln(err) + return temporaryGlobalInstance +} + +// New will create a new instance of the UserRepository. +func New(datastore *data.Datastore) WebhookRepository { + r := SqlWebhookRepository{ + datastore: datastore, } + + return &r } // InsertWebhook will add a new webhook to the database. -func InsertWebhook(url string, events []models.EventType) (int, error) { +func (r *SqlWebhookRepository) InsertWebhook(url string, events []models.EventType) (int, error) { log.Traceln("Adding new webhook") eventsString := strings.Join(events, ",") - tx, err := _db.Begin() + tx, err := r.datastore.DB.Begin() if err != nil { return 0, err } @@ -65,10 +78,10 @@ func InsertWebhook(url string, events []models.EventType) (int, error) { } // DeleteWebhook will delete a webhook from the database. -func DeleteWebhook(id int) error { +func (r *SqlWebhookRepository) DeleteWebhook(id int) error { log.Traceln("Deleting webhook") - tx, err := _db.Begin() + tx, err := r.datastore.DB.Begin() if err != nil { return err } @@ -96,7 +109,7 @@ func DeleteWebhook(id int) error { } // GetWebhooksForEvent will return all of the webhooks that want to be notified about an event type. -func GetWebhooksForEvent(event models.EventType) []models.Webhook { +func (r *SqlWebhookRepository) GetWebhooksForEvent(event models.EventType) []models.Webhook { webhooks := make([]models.Webhook, 0) query := `SELECT * FROM ( @@ -111,9 +124,9 @@ func GetWebhooksForEvent(event models.EventType) []models.Webhook { SELECT id, url, event FROM split WHERE event <> '' - ) AS webhook WHERE event IS "` + event + `"` + ) AS webhook WHERE event IS ?` - rows, err := _db.Query(query) + rows, err := r.datastore.DB.Query(query, event) if err != nil || rows.Err() != nil { log.Fatal(err) } @@ -140,12 +153,12 @@ func GetWebhooksForEvent(event models.EventType) []models.Webhook { } // GetWebhooks will return all the webhooks. -func GetWebhooks() ([]models.Webhook, error) { //nolint +func (r *SqlWebhookRepository) GetWebhooks() ([]models.Webhook, error) { //nolint webhooks := make([]models.Webhook, 0) query := "SELECT * FROM webhooks" - rows, err := _db.Query(query) + rows, err := r.datastore.DB.Query(query) if err != nil { return webhooks, err } @@ -193,8 +206,8 @@ func GetWebhooks() ([]models.Webhook, error) { //nolint } // SetWebhookAsUsed will update the last used time for a webhook. -func SetWebhookAsUsed(webhook models.Webhook) error { - tx, err := _db.Begin() +func (r *SqlWebhookRepository) SetWebhookAsUsed(webhook models.Webhook) error { + tx, err := r.datastore.DB.Begin() if err != nil { return err } diff --git a/webserver/handlers/admin/webhooks.go b/webserver/handlers/admin/webhooks.go index 719edc6fb..39b381f73 100644 --- a/webserver/handlers/admin/webhooks.go +++ b/webserver/handlers/admin/webhooks.go @@ -6,8 +6,8 @@ import ( "net/http" "time" - "github.com/owncast/owncast/core/data" "github.com/owncast/owncast/models" + "github.com/owncast/owncast/persistence/webhookrepository" "github.com/owncast/owncast/webserver/handlers/generated" webutils "github.com/owncast/owncast/webserver/utils" ) @@ -32,7 +32,8 @@ func CreateWebhook(w http.ResponseWriter, r *http.Request) { return } - newWebhookID, err := data.InsertWebhook(request.URL, request.Events) + webhooksrepo := webhookrepository.Get() + newWebhookID, err := webhooksrepo.InsertWebhook(request.URL, request.Events) if err != nil { webutils.InternalErrorHandler(w, err) return @@ -49,7 +50,8 @@ func CreateWebhook(w http.ResponseWriter, r *http.Request) { // GetWebhooks will return all webhooks. func GetWebhooks(w http.ResponseWriter, r *http.Request) { - webhooks, err := data.GetWebhooks() + webhooksrepo := webhookrepository.Get() + webhooks, err := webhooksrepo.GetWebhooks() if err != nil { webutils.InternalErrorHandler(w, err) return @@ -72,7 +74,8 @@ func DeleteWebhook(w http.ResponseWriter, r *http.Request) { return } - if err := data.DeleteWebhook(*request.Id); err != nil { + webhooksrepo := webhookrepository.Get() + if err := webhooksrepo.DeleteWebhook(*request.Id); err != nil { webutils.InternalErrorHandler(w, err) return }