diff --git a/core/rtmp/rtmp.go b/core/rtmp/rtmp.go index a94705c2c..20b303792 100644 --- a/core/rtmp/rtmp.go +++ b/core/rtmp/rtmp.go @@ -5,7 +5,6 @@ import ( "io" "net" "os" - "strings" "syscall" "time" @@ -78,9 +77,7 @@ func HandleConn(c *rtmp.Conn, nc net.Conn) { return } - streamingKeyComponents := strings.Split(c.URL.Path, "/") - streamingKey := streamingKeyComponents[len(streamingKeyComponents)-1] - if streamingKey != data.GetStreamKey() { + if !secretMatch(data.GetStreamKey(), c.URL.Path) { log.Errorln("invalid streaming key; rejecting incoming stream") nc.Close() return diff --git a/core/rtmp/utils.go b/core/rtmp/utils.go index 8897a74ba..0fa59894e 100644 --- a/core/rtmp/utils.go +++ b/core/rtmp/utils.go @@ -9,6 +9,7 @@ import ( "github.com/nareix/joy5/format/flv/flvio" "github.com/owncast/owncast/models" + log "github.com/sirupsen/logrus" ) const unknownString = "Unknown" @@ -76,3 +77,15 @@ func getVideoCodec(codec interface{}) string { return unknownString } + +func secretMatch(configStreamKey string, path string) bool { + prefix := "/live/" + + if !strings.HasPrefix(path, prefix) { + log.Debug("RTMP path does not start with " + prefix) + return false // We need the path to begin with $prefix + } + + streamingKey := path[len(prefix):] // Remove $prefix + return streamingKey == configStreamKey +} diff --git a/core/rtmp/utils_test.go b/core/rtmp/utils_test.go new file mode 100644 index 000000000..233c09b48 --- /dev/null +++ b/core/rtmp/utils_test.go @@ -0,0 +1,35 @@ +package rtmp + +import "testing" + +func Test_secretMatch(t *testing.T) { + tests := []struct { + name string + streamKey string + path string + want bool + }{ + {"positive", "abc", "/live/abc", true}, + {"negative", "abc", "/live/def", false}, + {"positive with numbers", "abc123", "/live/abc123", true}, + {"negative with numbers", "abc123", "/live/def456", false}, + {"positive with url chars", "one/two/three", "/live/one/two/three", true}, + {"negative with url chars", "one/two/three", "/live/four/five/six", false}, + {"check the entire secret", "three", "/live/one/two/three", false}, + {"with /live/ in secret", "one/live/three", "/live/one/live/three", true}, + {"bad path", "anything", "nonsense", false}, + {"missing secret", "abc", "/live/", false}, + {"missing secret and missing last slash", "abc", "/live", false}, + {"streamkey before /live/", "streamkey", "/streamkey/live", false}, + {"missing /live/", "anything", "/something/else", false}, + {"stuff before and after /live/", "after", "/before/live/after", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := secretMatch(tt.streamKey, tt.path); got != tt.want { + t.Errorf("secretMatch() = %v, want %v", got, tt.want) + } + }) + } +}