diff --git a/client.go b/client.go index e82548f70..6f74077a0 100644 --- a/client.go +++ b/client.go @@ -10,11 +10,9 @@ import ( const channelBufSize = 100 -var maxId int = 0 - // Chat client. type Client struct { - id int + id string ws *websocket.Conn server *Server ch chan *Message @@ -32,11 +30,10 @@ func NewClient(ws *websocket.Conn, server *Server) *Client { panic("server cannot be nil") } - maxId++ ch := make(chan *Message, channelBufSize) doneCh := make(chan bool) - - return &Client{maxId, ws, server, ch, doneCh} + clientID := getClientIDFromRequest(ws.Request()) + return &Client{clientID, ws, server, ch, doneCh} } func (c *Client) Conn() *websocket.Conn { diff --git a/main.go b/main.go index c11d7c5d0..c2c0e5b6f 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "encoding/json" "net/http" + "path" "strconv" log "github.com/sirupsen/logrus" @@ -49,7 +50,15 @@ func startChatServer() { go server.Listen() // static files - http.Handle("/", http.FileServer(http.Dir("webroot"))) + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + http.ServeFile(w, r, path.Join("webroot", r.URL.Path)) + + if path.Ext(r.URL.Path) == ".m3u8" { + clientID := getClientIDFromRequest(r) + stats.SetClientActive(clientID) + } + }) + http.HandleFunc("/status", getStatus) log.Printf("Starting public web server on port %d", configuration.WebServerPort) @@ -81,10 +90,10 @@ func streamDisconnected() { stats.StreamDisconnected() } -func viewerAdded() { - stats.SetViewerCount(server.ClientCount()) +func viewerAdded(clientID string) { + stats.SetClientActive(clientID) } -func viewerRemoved() { - stats.SetViewerCount(server.ClientCount()) +func viewerRemoved(clientID string) { + stats.ViewerDisconnected(clientID) } diff --git a/server.go b/server.go index bb6e6ccf9..3f2c3fd1b 100644 --- a/server.go +++ b/server.go @@ -11,7 +11,7 @@ import ( type Server struct { pattern string messages []*Message - clients map[int]*Client + clients map[string]*Client addCh chan *Client delCh chan *Client sendAllCh chan *Message @@ -22,7 +22,7 @@ type Server struct { // Create new chat server. func NewServer(pattern string) *Server { messages := []*Message{} - clients := make(map[int]*Client) + clients := make(map[string]*Client) addCh := make(chan *Client) delCh := make(chan *Client) sendAllCh := make(chan *Message) @@ -100,17 +100,14 @@ func (s *Server) Listen() { // Add new a client case c := <-s.addCh: - log.Println("Added new client") s.clients[c.id] = c - log.Println("Now", len(s.clients), "clients connected.") - viewerAdded() + viewerAdded(c.id) s.sendPastMessages(c) // del a client case c := <-s.delCh: - log.Println("Delete client") delete(s.clients, c.id) - viewerRemoved() + viewerRemoved(c.id) // broadcast message for all clients case msg := <-s.sendAllCh: diff --git a/stats.go b/stats.go index a811e6228..d7b1bd36e 100644 --- a/stats.go +++ b/stats.go @@ -1,8 +1,18 @@ +/* +Viewer counting doesn't just count the number of websocket clients that are currently connected, +because people may be watching the stream outside of the web browser via any HLS video client. +Instead we keep track of requests and consider each unique IP as a "viewer". +As a signal, however, we do use the websocket disconnect from a client as a signal that a viewer +dropped and we call ViewerDisconnected(). +*/ + package main import ( "encoding/json" + "fmt" "io/ioutil" + "log" "math" "os" "time" @@ -10,40 +20,53 @@ import ( type Stats struct { streamConnected bool `json:"-"` - ViewerCount int `json:"viewerCount"` SessionMaxViewerCount int `json:"sessionMaxViewerCount"` OverallMaxViewerCount int `json:"overallMaxViewerCount"` LastDisconnectTime time.Time `json:"lastDisconnectTime"` + + clients map[string]time.Time } func (s *Stats) Setup() { - ticker := time.NewTicker(2 * time.Minute) - quit := make(chan struct{}) + s.clients = make(map[string]time.Time) + + statsSaveTimer := time.NewTicker(2 * time.Minute) go func() { for { select { - case <-ticker.C: + case <-statsSaveTimer.C: s.save() - case <-quit: - ticker.Stop() - return } } }() + + staleViewerPurgeTimer := time.NewTicker(5 * time.Second) + go func() { + for { + select { + case <-staleViewerPurgeTimer.C: + s.purgeStaleViewers() + } + } + }() +} + +func (s *Stats) purgeStaleViewers() { + for clientID, lastConnectedtime := range s.clients { + timeSinceLastActive := time.Since(lastConnectedtime).Minutes() + if timeSinceLastActive > 2 { + s.ViewerDisconnected(clientID) + } + + } } func (s *Stats) IsStreamConnected() bool { return s.streamConnected } -func (s *Stats) SetViewerCount(count int) { - s.ViewerCount = count - s.SessionMaxViewerCount = int(math.Max(float64(s.ViewerCount), float64(s.SessionMaxViewerCount))) - s.OverallMaxViewerCount = int(math.Max(float64(s.SessionMaxViewerCount), float64(s.OverallMaxViewerCount))) -} - func (s *Stats) GetViewerCount() int { - return s.ViewerCount + return len(s.clients) } func (s *Stats) GetSessionMaxViewerCount() int { @@ -54,10 +77,20 @@ func (s *Stats) GetOverallMaxViewerCount() int { return s.OverallMaxViewerCount } -func (s *Stats) ViewerConnected() { +func (s *Stats) SetClientActive(clientID string) { + fmt.Println("Marking client active:", clientID) + + s.clients[clientID] = time.Now() + s.SessionMaxViewerCount = int(math.Max(float64(s.GetViewerCount()), float64(s.SessionMaxViewerCount))) + s.OverallMaxViewerCount = int(math.Max(float64(s.SessionMaxViewerCount), float64(s.OverallMaxViewerCount))) + + fmt.Println("Now", s.GetViewerCount(), "clients connected.") } -func (s *Stats) ViewerDisconnected() { +func (s *Stats) ViewerDisconnected(clientID string) { + log.Println("Removed client", clientID) + + delete(s.clients, clientID) } func (s *Stats) StreamConnected() { diff --git a/utils.go b/utils.go index 210ec8f9f..4d0f161a0 100644 --- a/utils.go +++ b/utils.go @@ -3,6 +3,7 @@ package main import ( "fmt" "io/ioutil" + "net/http" "os" "path" "path/filepath" @@ -65,3 +66,20 @@ func resetDirectories(configuration Config) { os.MkdirAll(path.Join(configuration.PublicHLSPath, strconv.Itoa(index)), 0777) } } + +func getClientIDFromRequest(req *http.Request) string { + var ipAddress string + xForwardedFor := req.Header.Get("X-FORWARDED-FOR") + if xForwardedFor != "" { + ipAddress = xForwardedFor + } else { + ipAddressString := req.RemoteAddr + ipAddressComponents := strings.Split(ipAddressString, ":") + ipAddressComponents[len(ipAddressComponents)-1] = "" + ipAddress = strings.Join(ipAddressComponents, ":") + } + + // fmt.Println("IP address determined to be", ipAddress) + + return ipAddress +}