diff --git a/http.go b/http.go index e4fe926..e8968e7 100644 --- a/http.go +++ b/http.go @@ -2,13 +2,13 @@ package main import ( "embed" + "fmt" "html/template" - "log" - "math" "net" "net/http" "strings" "sync" + "sync/atomic" "time" ) @@ -19,8 +19,8 @@ var htmlTemplate *template.Template type clientState int const ( - INITIAL = 0 - ACTIVE = iota + INITIAL = iota + ACTIVE ) type client struct { @@ -29,43 +29,25 @@ type client struct { } var ( - clientCounter uint32 = 0 - clientCounterMutex sync.Mutex - - clients = make(map[uint32]*client) - clientMutex sync.RWMutex + clients = make([]*client, 0) + clientMutex sync.Mutex ) -func getClientID() uint32 { - clientCounterMutex.Lock() - defer clientCounterMutex.Unlock() - clientCounter++ - if clientCounter == math.MaxUint32 { - clientCounter = 0 - clearClients() +func deleteClient(client *client) { + for i := 0; i < len(clients); i++ { + if clients[i] == client { + clients[i] = clients[len(clients)-1] + clients = clients[:len(clients)-1] + return + } } - return clientCounter } -func clearClients() { - clientMutex.Lock() - defer clientMutex.Unlock() - for _, c := range clients { - close(c.channel) - } - clients = make(map[uint32]*client) -} - -func httpServer() { - var err error - +func httpServer() error { htmlTemplate = template.Must(template.ParseFS(embedFS, "template.html")) http.HandleFunc("/stream", stream) http.HandleFunc("/", serveRoot) - err = http.ListenAndServe(":9090", nil) - if err != nil { - log.Fatalln(err) - } + return http.ListenAndServe(":9090", nil) } func getInterfaceBaseIP() string { @@ -107,8 +89,7 @@ func getInterfaceBaseIP() string { func serveRoot(w http.ResponseWriter, r *http.Request) { if r.RequestURI != "/" { - w.WriteHeader(404) - _, _ = w.Write([]byte("404 not found")) + http.Error(w, "Not found", http.StatusNotFound) return } type pageData struct { @@ -126,64 +107,65 @@ func serveRoot(w http.ResponseWriter, r *http.Request) { CanvasWidth: 512, }) if err != nil { - log.Println(err) + fmt.Println("Error executing HTML template:", err) } } +var streamServerRunning atomic.Bool + func streamServer() { - for { - clientMutex.RLock() - if len(clients) == 0 { - for { - if len(clients) == 0 { - clientMutex.RUnlock() - time.Sleep(1 * time.Second) - clientMutex.RLock() + if !streamServerRunning.CompareAndSwap(false, true) { + return + } + go func() { + for { + clientMutex.Lock() + if len(clients) == 0 { + streamServerRunning.Store(false) + clientMutex.Unlock() + return + } + + requiresInitial := false + requiresUpdate := false + for _, v := range clients { + if v.state == INITIAL { + requiresInitial = true } else { + requiresUpdate = true + } + if requiresInitial && requiresUpdate { break } } - } - requiresInitial := false - requiresUpdate := false - for _, v := range clients { - if v.state == INITIAL { - requiresInitial = true - } else { - requiresUpdate = true - } - if requiresInitial && requiresUpdate { - break - } - } - - dataInitial, dataUpdate := getPicture(requiresInitial, requiresUpdate) - - for clientID, v := range clients { - if v.state == INITIAL { - v.state = ACTIVE - select { - case v.channel <- dataInitial: - default: - continue - } - } else { - if dataUpdate != "0" { + dataInitial, dataUpdate := getPicture(requiresInitial, requiresUpdate) + tmp := clients[:0] + for _, v := range clients { + if v.state == INITIAL { + v.state = ACTIVE select { - case v.channel <- dataUpdate: + case v.channel <- dataInitial: default: - // Client cannot keep up - close(v.channel) - delete(clients, clientID) - continue + } + } else { + if dataUpdate != "0" { + select { + case v.channel <- dataUpdate: + default: + // Client cannot keep up + close(v.channel) + continue + } } } + tmp = append(tmp, v) } + clients = tmp + clientMutex.Unlock() + time.Sleep(500 * time.Millisecond) } - clientMutex.RUnlock() - time.Sleep(500 * time.Millisecond) - } + }() } func stream(w http.ResponseWriter, r *http.Request) { @@ -192,17 +174,17 @@ func stream(w http.ResponseWriter, r *http.Request) { http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) return } + streamServer() w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") messageChan := make(chan string, 40) - id := getClientID() - newClient := client{ + newClient := &client{ channel: messageChan, state: INITIAL, } clientMutex.Lock() - clients[id] = &newClient + clients = append(clients, newClient) clientMutex.Unlock() // For when clients are removed prior to connection closed, to avoid a call to delete(clients, id) @@ -212,7 +194,7 @@ func stream(w http.ResponseWriter, r *http.Request) { <-r.Context().Done() clientMutex.Lock() if !channelClosedFirst { - delete(clients, id) + deleteClient(newClient) } close(messageChan) clientMutex.Unlock() diff --git a/main.go b/main.go index 0788bc7..84ceb94 100644 --- a/main.go +++ b/main.go @@ -9,24 +9,32 @@ import ( "image" "image/color" "image/png" - "log" + "os" + "runtime" "sync" ) const interfaceName = "canvas" -const handlerCount = 4 func main() { prePopulatePixelArray() packetChan := make(chan *[]byte, 1000) - for i := 0; i < handlerCount; i++ { + for i := 0; i < runtime.NumCPU(); i++ { go packetHandler(packetChan) } - go startInterface(packetChan) - go streamServer() + go func() { + err := startInterface(packetChan) + if err != nil { + fmt.Println("Interface handler error:", err) + os.Exit(0) + } + }() fmt.Println("Kioubit ColorPing started") fmt.Println("Interface name:", interfaceName, "HTTP server port: 9090") - httpServer() + if err := httpServer(); err != nil { + fmt.Println("Error starting HTTP server:", err) + return + } } func prePopulatePixelArray() { @@ -45,21 +53,21 @@ var pktPool = sync.Pool{ New: func() interface{} { return make([]byte, 2000) }, } -func startInterface(packetChan chan *[]byte) { +func startInterface(packetChan chan *[]byte) error { config := water.Config{ DeviceType: water.TUN, } config.Name = interfaceName iFace, err := water.New(config) if err != nil { - log.Fatal(err) + return err } for { packet := pktPool.Get().([]byte) n, err := iFace.Read(packet) if err != nil { - log.Fatal(err) + return err } packet = packet[:n] packetChan <- &packet @@ -167,7 +175,7 @@ func getPicture(fullUpdate bool, incrementalUpdate bool) (string, string) { buff := new(bytes.Buffer) err := encoder.Encode(buff, canvasIncrementalUpdate) if err != nil { - log.Println(err.Error()) + fmt.Println("PNG encoding error:", err) } incrementalUpdateResult = "event: u\ndata:" + base64.StdEncoding.EncodeToString(buff.Bytes()) + "\n\n" } @@ -177,7 +185,7 @@ func getPicture(fullUpdate bool, incrementalUpdate bool) (string, string) { buff := new(bytes.Buffer) err := encoder.Encode(buff, canvasFullUpdate) if err != nil { - log.Println(err.Error()) + fmt.Println("PNG encoding error:", err) } fullUpdateResult = "event: u\ndata:" + base64.StdEncoding.EncodeToString(buff.Bytes()) + "\n\n" } diff --git a/template.html b/template.html index 16f6564..254fd77 100644 --- a/template.html +++ b/template.html @@ -103,13 +103,9 @@