// This program relays HTTP streams from configured upstream URLs to any
// number of HTTP clients. An upstream is only fetched while at least one
// client is watching it; every chunk read from the upstream is fanned out to
// all clients of that stream.
//
// Endpoints:
//
//	/stream/<name>   relays upstream <name> to the client (video/MP2T)
//	/viewers/        lists the viewer count of every stream started so far
//	/viewers/<name>  shows the viewer count of one configured stream
//	/debug/pprof/... runtime profiling
package main

import (
	"container/list"
	"flag"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/pprof"
	"net/url"
	"os"
	"os/signal"
	"sort"
	"strings"
	"sync"
	"sync/atomic"
	"syscall"
	"time"
)

// UpstreamConfig maps a stream name to the upstream URL it is relayed from.
// It implements flag.Value so it can be filled from repeated -upstream flags.
type UpstreamConfig map[string]*url.URL

// String returns a printable form of the configuration (flag.Value).
func (config *UpstreamConfig) String() string {
	return fmt.Sprintf("%+v", (*map[string]*url.URL)(config))
}

// Set parses value in "name=url" form and records it (flag.Value).
// Only the first "=" separates name from URL. It returns an error when the
// "=" is missing, the name is empty, the name was already configured, or the
// URL cannot be parsed.
func (config *UpstreamConfig) Set(value string) error {
	name, rawURL, found := strings.Cut(value, "=")
	if !found {
		return fmt.Errorf("format error: not in name=url format: %s", value)
	}
	if name == "" {
		return fmt.Errorf("format error: empty stream name: %s", value)
	}
	if _, ok := (*config)[name]; ok {
		return fmt.Errorf("config error: remote %s already specified", name)
	}
	u, err := url.Parse(rawURL)
	if err != nil {
		return fmt.Errorf("format error: unable to parse url %s: %w", rawURL, err)
	}
	(*config)[name] = u
	fmt.Printf("Added remote %s with url %s\n", name, u)
	return nil
}

// listen is the address to serve on: "unix:<path>" selects a Unix domain
// socket, anything else is passed to net.Listen as a TCP address.
var listen = flag.String("listen", "unix:/var/run/tsb/tsb.sock", "Listen address of the service")

// upstreams holds every stream configured with -upstream. It is only written
// during flag parsing, before the server starts.
var upstreams = make(UpstreamConfig)

// UpstreamContext is the shared state of one relayed stream.
type UpstreamContext struct {
	// viewer is the number of connected clients. It is accessed only through
	// sync/atomic and is kept as the first field so it is 64-bit aligned on
	// 32-bit platforms.
	viewer int64
	// viewerCond wakes upstreamFiber when viewer goes from 0 to 1. Writers
	// hold viewerCond.L while incrementing so the wakeup cannot be lost.
	viewerCond *sync.Cond
	// upstreamChan carries chunks from upstreamFiber to broadcastFiber.
	upstreamChan chan []byte
	// clientsChanList holds *ClientChannel entries; only broadcastFiber uses it.
	clientsChanList *list.List
	// newClientsChanList holds clients registered since the last merge and is
	// guarded by newChanLock.
	newClientsChanList *list.List
	newChanLock        *sync.Mutex
}

var upstreamContexts = make(map[string]*UpstreamContext)
var upstreamContextLock = &sync.Mutex{}

// GetUpstreamContext returns the context for stream name. On first use it
// creates the context and starts its upstreamFiber and broadcastFiber
// goroutines, which run for the rest of the process lifetime.
func GetUpstreamContext(name string) *UpstreamContext {
	upstreamContextLock.Lock()
	defer upstreamContextLock.Unlock()
	ctx, ok := upstreamContexts[name]
	if !ok {
		ctx = &UpstreamContext{
			viewer:             0,
			viewerCond:         sync.NewCond(&sync.Mutex{}),
			upstreamChan:       make(chan []byte),
			clientsChanList:    list.New(),
			newClientsChanList: list.New(),
			newChanLock:        &sync.Mutex{},
		}
		upstreamContexts[name] = ctx
		go upstreamFiber(name)
		go broadcastFiber(name)
	}
	return ctx
}

// ClientChannel connects one client to the broadcaster of its stream.
type ClientChannel struct {
	// Notify is closed by Close to tell the broadcaster the client is gone.
	Notify chan struct{}
	// Data receives stream chunks; the broadcaster closes it after Notify.
	Data chan []byte
	ctx  *UpstreamContext
}

// NewClientChannel registers a new client for stream name and counts it as a
// viewer, waking the upstream fetcher if it is the first one.
func NewClientChannel(name string) *ClientChannel {
	ctx := GetUpstreamContext(name)
	ch := &ClientChannel{
		Notify: make(chan struct{}),
		Data:   make(chan []byte, 1024),
		ctx:    ctx,
	}
	ctx.newChanLock.Lock()
	ctx.newClientsChanList.PushBack(ch)
	ctx.newChanLock.Unlock()

	// Increment and broadcast under the condition's lock: otherwise the
	// broadcast can fire between upstreamFiber's check and its Wait and be lost.
	ctx.viewerCond.L.Lock()
	viewers := atomic.AddInt64(&ctx.viewer, 1)
	if viewers == 1 {
		ctx.viewerCond.Broadcast()
	}
	ctx.viewerCond.L.Unlock()

	fmt.Printf("Current viewers for stream %s: %d\n", name, viewers)
	return ch
}

// Close closes the notify channel, so the broadcaster stops sending data and
// removes this entry, and decrements the viewer count. It must be called
// exactly once per channel (a second call panics on the closed Notify).
func (ch *ClientChannel) Close() {
	close(ch.Notify)
	atomic.AddInt64(&ch.ctx.viewer, -1)
}

func init() {
	flag.Var(&upstreams, "upstream", "Upstream in name=URL format, may specify multiple times.")
}

// upstreamFiber loops forever for stream name. It idles while the stream has
// no viewers; otherwise it connects to the upstream and forwards everything it
// reads to broadcastFiber until the upstream ends or the last viewer leaves.
func upstreamFiber(name string) {
	ctx := GetUpstreamContext(name)
	client := &http.Client{}
	req := &http.Request{
		Method: "GET",
		URL:    upstreams[name],
		Close:  true,
	}
	for {
		waitForViewers(ctx, name)
		fmt.Printf("Connecting to upstream for stream %s ...\n", name)
		resp := connectUpstream(ctx, name, client, req)
		if resp == nil {
			// Every viewer left while we were still trying to connect.
			continue
		}
		pumpUpstream(ctx, name, resp.Body)
		resp.Body.Close()
	}
}

// waitForViewers blocks until the stream has at least one viewer.
func waitForViewers(ctx *UpstreamContext, name string) {
	ctx.viewerCond.L.Lock()
	defer ctx.viewerCond.L.Unlock()
	if atomic.LoadInt64(&ctx.viewer) == 0 {
		fmt.Printf("No active viewers for stream %s, idling around :)\n", name)
	}
	// Re-check after every wakeup: the count may already be back at zero.
	for atomic.LoadInt64(&ctx.viewer) == 0 {
		ctx.viewerCond.Wait()
	}
}

// connectUpstream sends req until it gets a 200 response, retrying every
// 5 seconds. It returns nil instead if the stream has no viewers left after
// a failed attempt.
func connectUpstream(ctx *UpstreamContext, name string, client *http.Client, req *http.Request) *http.Response {
	for {
		resp, err := client.Do(req)
		switch {
		case err != nil:
			fmt.Fprintf(os.Stderr, "Unable to request %s for stream %s: %v\n", upstreams[name], name, err)
		case resp.StatusCode != http.StatusOK:
			fmt.Fprintf(os.Stderr, "Request to %s for stream %s failed with status code %d\n", upstreams[name], name, resp.StatusCode)
			// The rejected response still holds a connection; release it.
			resp.Body.Close()
		default:
			return resp
		}
		time.Sleep(5 * time.Second)
		if atomic.LoadInt64(&ctx.viewer) == 0 {
			return nil
		}
	}
}

// pumpUpstream reads body in chunks of up to 64 KiB and hands each chunk to
// broadcastFiber until the body ends or fails, or the stream has no viewers.
func pumpUpstream(ctx *UpstreamContext, name string, body io.Reader) {
	for {
		if atomic.LoadInt64(&ctx.viewer) == 0 {
			fmt.Printf("No active viewers for %s, stopping reading from upstream.\n", name)
			return
		}
		// A fresh buffer per chunk: the slice is shared with every client's
		// Data channel, so it must not be overwritten by the next Read.
		buff := make([]byte, 65536)
		n, err := body.Read(buff)
		// Per the io.Reader contract, bytes returned alongside an error are valid.
		if n > 0 {
			ctx.upstreamChan <- buff[:n]
		}
		if err == io.EOF {
			fmt.Fprintf(os.Stderr, "Upstream for stream %s request reached EOF\n", name)
			return
		}
		if err != nil {
			fmt.Fprintf(os.Stderr, "Upstream for stream %s request failed: %v\n", name, err)
			return
		}
	}
}

// broadcastFiber fans every chunk from upstreamChan out to all registered
// clients of stream name. It never blocks on a client: a client whose Data
// buffer is full misses the chunk, and a closed client is removed and its Data
// channel closed.
func broadcastFiber(name string) {
	ctx := GetUpstreamContext(name)
	for chunk := range ctx.upstreamChan {
		for e := ctx.clientsChanList.Front(); e != nil; {
			ch := e.Value.(*ClientChannel)
			// Fetch the successor first: Remove clears e's links.
			next := e.Next()
			select {
			case <-ch.Notify:
				// Closed client, remove this entry.
				close(ch.Data)
				ctx.clientsChanList.Remove(e)
			case ch.Data <- chunk:
			default:
				// Slow client: drop this chunk for it rather than stall everyone.
			}
			e = next
		}
		// Merge newly registered clients, but never block the broadcast on the
		// lock; they will be picked up with a later chunk.
		if ctx.newChanLock.TryLock() {
			if ctx.newClientsChanList.Len() > 0 {
				ctx.clientsChanList.PushBackList(ctx.newClientsChanList)
				ctx.newClientsChanList.Init()
			}
			ctx.newChanLock.Unlock()
		}
	}
}

// clientHandler streams the upstream named by the request path to the client
// until the client goes away or a write fails. It responds 404 for streams
// that are not configured.
func clientHandler(w http.ResponseWriter, r *http.Request) {
	remote := r.RemoteAddr
	if forwarded := r.Header.Get("X-Remote-Addr"); forwarded != "" {
		remote = forwarded
	}
	fmt.Printf("Client connection from %s accepted.\n", remote)

	// Validate before writing any header so NotFound can still set the status.
	name := r.URL.Path
	if _, ok := upstreams[name]; !ok {
		http.NotFound(w, r)
		return
	}
	w.Header().Add("Content-Type", "video/MP2T")
	w.WriteHeader(http.StatusOK)

	ch := NewClientChannel(name)
	defer ch.Close()
	for {
		select {
		case <-r.Context().Done():
			// Without this, a client leaving during upstream silence would be
			// held (and counted as a viewer) forever.
			fmt.Printf("Remote client %s disconnected, closing...\n", remote)
			return
		case chunk, ok := <-ch.Data:
			if !ok {
				return
			}
			if _, err := w.Write(chunk); err != nil {
				fmt.Printf("Write to remote client %s failed, closing...\n", remote)
				return
			}
		}
	}
}

// viewersHandler reports viewer counts as plain text. An empty path lists
// every started stream (sorted by name) followed by the sum. A configured
// stream name reports that stream alone (0 if it was never started). Unknown
// names get 404.
func viewersHandler(w http.ResponseWriter, r *http.Request) {
	name := r.URL.Path
	if name != "" {
		if _, ok := upstreams[name]; !ok {
			http.NotFound(w, r)
			return
		}
	}

	// Snapshot under the global lock, then write without it, so a slow reader
	// cannot block GetUpstreamContext for everyone else.
	upstreamContextLock.Lock()
	snapshot := make(map[string]*UpstreamContext, len(upstreamContexts))
	for n, ctx := range upstreamContexts {
		snapshot[n] = ctx
	}
	upstreamContextLock.Unlock()

	w.Header().Add("Content-Type", "text/plain")
	w.WriteHeader(http.StatusOK)

	if name != "" {
		var viewers int64
		if ctx, ok := snapshot[name]; ok {
			viewers = atomic.LoadInt64(&ctx.viewer)
		}
		fmt.Fprintf(w, "%s: %d\n", name, viewers)
		return
	}

	names := make([]string, 0, len(snapshot))
	for n := range snapshot {
		names = append(names, n)
	}
	sort.Strings(names)
	totalViewer := int64(0)
	for _, n := range names {
		viewers := atomic.LoadInt64(&snapshot[n].viewer)
		totalViewer += viewers
		fmt.Fprintf(w, "%s: %d\n", n, viewers)
	}
	fmt.Fprintf(w, "sum: %d\n", totalViewer)
}

func main() {
	flag.Parse()

	var listener net.Listener
	var err error
	if strings.HasPrefix(*listen, "unix:") {
		listener, err = net.Listen("unix", strings.TrimPrefix(*listen, "unix:"))
	} else {
		listener, err = net.Listen("tcp", *listen)
	}
	if err != nil {
		panic(err)
	}

	handler := http.NewServeMux()
	handler.Handle("/stream/", http.StripPrefix("/stream/", http.HandlerFunc(clientHandler)))
	handler.Handle("/viewers/", http.StripPrefix("/viewers/", http.HandlerFunc(viewersHandler)))
	handler.HandleFunc("/debug/pprof/", pprof.Index)
	handler.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
	handler.HandleFunc("/debug/pprof/profile", pprof.Profile)
	handler.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
	handler.HandleFunc("/debug/pprof/trace", pprof.Trace)

	server := &http.Server{
		Handler: handler,
	}
	go func() {
		// Exit instead of hanging forever if serving fails.
		if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
			fmt.Fprintf(os.Stderr, "Server stopped: %v\n", err)
			os.Exit(1)
		}
	}()

	// signal.Notify does not block when delivering, so the channel is buffered.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	sig := <-sigCh
	fmt.Fprintf(os.Stderr, "Caught signal %v, exiting...\n", sig)
	// Closing the server closes the listener; for a Unix socket created by
	// net.Listen this also removes the socket file so a restart can bind again.
	server.Close()
	os.Exit(0)
}