diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..29343f7 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module git.eve.moe/jackyyf/tsb + +go 1.19 diff --git a/main.go b/main.go new file mode 100644 index 0000000..02eb6e5 --- /dev/null +++ b/main.go @@ -0,0 +1,279 @@ +package main + +import ( + "container/list" + "flag" + "fmt" + "net" + "net/http" + "net/http/pprof" + "net/url" + "os" + "os/signal" + "strings" + "sync" + "sync/atomic" + "syscall" + "time" +) + +type UpstreamConfig map[string]*url.URL + +func (config *UpstreamConfig) String() string { + return fmt.Sprintf("%+v", (*map[string]*url.URL)(config)) +} + +func (config *UpstreamConfig) Set(value string) error { + parts := strings.SplitN(value, "=", 2) + if len(parts) < 2 { + return fmt.Errorf("format error: not in name=url format: %s", value) + } + name := parts[0] + if _, ok := (*config)[name]; ok { + return fmt.Errorf("config error: remote %s already specified", name) + } + u, err := url.Parse(parts[1]) + if err != nil { + return fmt.Errorf("format error: unable to parse url %s: %w", parts[1], err) + } + (*config)[name] = u + fmt.Printf("Added remote %s with url %s\n", name, u) + return nil +} + +var listen = flag.String("listen", "unix:/var/run/tsb/tsb.sock", "Listen address of the service") + +var upstreams = make(UpstreamConfig) + +type UpstreamContext struct { + viewerCond *sync.Cond + viewer int64 + upstreamChan chan []byte + clientsChanList *list.List + newClientsChanList *list.List + newChanLock *sync.Mutex +} + +var upstreamContexts = make(map[string]*UpstreamContext) +var upstreamContextLock = &sync.Mutex{} + +func GetUpstreamContext(name string) *UpstreamContext { + upstreamContextLock.Lock() + defer upstreamContextLock.Unlock() + ctx, ok := upstreamContexts[name] + if !ok { + ctx = &UpstreamContext{ + viewerCond: sync.NewCond(&sync.Mutex{}), + viewer: 0, + upstreamChan: make(chan []byte), + clientsChanList: list.New(), + newClientsChanList: list.New(), + newChanLock: &sync.Mutex{}, + } + upstreamContexts[name] = ctx + go upstreamFiber(name) + go broadcastFiber(name) + } + return ctx +} + +type ClientChannel struct { + Notify chan struct{} + Data chan []byte + ctx *UpstreamContext +} + +func NewClientChannel(name string) *ClientChannel { + ctx := GetUpstreamContext(name) + ch := &ClientChannel{ + Notify: make(chan struct{}), + Data: make(chan []byte, 1024), + ctx: ctx, + } + ctx.newChanLock.Lock() + ctx.newClientsChanList.PushBack(ch) + ctx.newChanLock.Unlock() + if atomic.AddInt64(&ctx.viewer, 1) == 1 { + ctx.viewerCond.Broadcast() + } + fmt.Printf("Current viewers for stream %s: %d\n", name, ctx.viewer) + return ch +} + +// Close the notify channel, to the server will stop sending data and remove this entry. +func (ch *ClientChannel) Close() { + close(ch.Notify) + atomic.AddInt64(&ch.ctx.viewer, -1) +} + +func init() { + flag.Var(&upstreams, "upstream", "Upstream in name=URL format, may specify multiple times.") +} + +func upstreamFiber(name string) { + ctx := GetUpstreamContext(name) + client := &http.Client{} + req := &http.Request{ + Method: "GET", + URL: upstreams[name], + Close: true, + } + for { + ctx.viewerCond.L.Lock() + if ctx.viewer == 0 { + fmt.Printf("No active viewers for stream %s, idling around :)\n", name) + ctx.viewerCond.Wait() + } + ctx.viewerCond.L.Unlock() + fmt.Printf("Connecting to upstream for stream %s ...\n", name) + r := (func() *http.Response { + for { + r, err := client.Do(req) + if err != nil { + fmt.Fprintf(os.Stderr, "Unable to request %s for stream %s: %v\n", upstreams[name], name, err) + time.Sleep(5 * time.Second) + continue + } + if r.StatusCode != 200 { + fmt.Fprintf(os.Stderr, "Request to %s for stream %s failed with status code %d\n", upstreams[name], name, err) + time.Sleep(5 * time.Second) + continue + } + return r + } + })() + body := r.Body + for { + buff := make([]byte, 65536) + if ctx.viewer == 0 { + fmt.Printf("No active viewers for %s, stopping reading from upstream.\n", name) + break + } + n, err := body.Read(buff) + if err != nil { + fmt.Fprintf(os.Stderr, "Upstream for stream %s request reached EOF\n", name) + break + } + ctx.upstreamChan <- buff[:n] + } + body.Close() + } +} + +func broadcastFiber(name string) { + ctx := GetUpstreamContext(name) + for { + chunk := <-ctx.upstreamChan + // Send to all existing clients + e := ctx.clientsChanList.Front() + for e != nil { + ch := e.Value.(*ClientChannel) + select { + case <-ch.Notify: + // Closed client, remove this entry + close(ch.Data) + next := e.Next() + ctx.clientsChanList.Remove(e) + e = next + case ch.Data <- chunk: + e = e.Next() + default: + e = e.Next() + } + } + // Try to serve all new clients, but don't get locked up here :) + if ctx.newClientsChanList.Len() > 0 && ctx.newChanLock.TryLock() { + ctx.clientsChanList.PushBackList(ctx.newClientsChanList) + ctx.newClientsChanList.Init() + ctx.newChanLock.Unlock() + } + } +} + +func clientHandler(w http.ResponseWriter, r *http.Request) { + remote := r.RemoteAddr + if r.Header.Get("X-Remote-Addr") != "" { + remote = r.Header.Get("X-Remote-Addr") + } + fmt.Printf("Client connection from %s accepted.\n", remote) + w.Header().Add("Content-Type", "video/MP2T") + w.WriteHeader(200) + name := r.URL.Path + if _, ok := upstreams[name]; !ok { + http.NotFound(w, r) + return + } + ch := NewClientChannel(name) + defer ch.Close() + for { + chunk := <-ch.Data + _, err := w.Write(chunk) + if err != nil { + fmt.Printf("Write to remote client %s failed, closing...\n", remote) + break + } + } +} + +func viewersHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "text/plain") + w.WriteHeader(200) + name := r.URL.Path + if name == "" { + totalViewer := int64(0) + upstreamContextLock.Lock() + for name, ctx := range upstreamContexts { + totalViewer += ctx.viewer + fmt.Fprintf(w, "%s: %d\n", name, ctx.viewer) + } + upstreamContextLock.Unlock() + fmt.Fprintf(w, "sum: %d\n", totalViewer) + } else { + if _, ok := upstreams[name]; !ok { + http.NotFound(w, r) + return + } + upstreamContextLock.Lock() + ctx, ok := upstreamContexts[name] + if !ok { + fmt.Fprintf(w, "%s: 0\n", name) + } else { + fmt.Fprintf(w, "%s: %d\n", name, ctx.viewer) + } + upstreamContextLock.Unlock() + } +} + +func main() { + flag.Parse() + var listener net.Listener + var err error + if strings.HasPrefix(*listen, "unix:") { + if listener, err = net.Listen("unix", (*listen)[5:]); err != nil { + panic(err) + } + } else { + if listener, err = net.Listen("tcp", *listen); err != nil { + panic(err) + } + } + handler := http.NewServeMux() + handler.Handle("/stream/", http.StripPrefix("/stream/", http.HandlerFunc(clientHandler))) + handler.Handle("/viewers/", http.StripPrefix("/viewers/", http.HandlerFunc(viewersHandler))) + handler.HandleFunc("/debug/pprof/", pprof.Index) + handler.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + handler.HandleFunc("/debug/pprof/profile", pprof.Profile) + handler.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + handler.HandleFunc("/debug/pprof/trace", pprof.Trace) + server := &http.Server{ + Handler: handler, + } + go server.Serve(listener) + ch := make(chan os.Signal) + signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) + select { + case sig := <-ch: + fmt.Fprintf(os.Stderr, "Caught signal %v, exiting...\n", sig) + os.Exit(0) + } +}