diff --git a/.gitignore b/.gitignore index 5b6c096..b47ac8c 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,7 @@ config.toml +drmdtool +drmdtool_* +src/DRMDTool +*.exe +uploads/ +src/uploads/ diff --git a/README.md b/README.md index 04bee7e..02ac850 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ BaseDir = "/path/to/save/downloads" Format = "mkv" TempBaseDir = "/tmp/nre" EnableConsole = true +MaxUploadMB = 32 [WatchFolder] Path = "/path/to/watched/folder" @@ -21,6 +22,17 @@ UseInotify = false [N_m3u8DLRE] Path = "/path/to/N_m3u8DL-RE" + +[Server] +Host = "127.0.0.1" +Port = 8080 +ReadTimeoutSec = 30 +WriteTimeoutSec = 30 +IdleTimeoutSec = 60 +ReadHeaderTimeoutS = 10 + +[Security] +AuthToken = "" ``` ### Configuration Options @@ -30,6 +42,7 @@ Path = "/path/to/N_m3u8DL-RE" - `Format`: Output format for the downloaded files (e.g., `mkv`, `mp4`). - `TempBaseDir`: Temporary directory for intermediate files. - `EnableConsole`: Boolean to enable or disable console output. + - `MaxUploadMB`: Maximum allowed upload size for the web UI. - **WatchFolder** - `Path`: Directory to watch for new `.drmd` files. @@ -40,6 +53,14 @@ Path = "/path/to/N_m3u8DL-RE" - **N_m3u8DLRE** - `Path`: Path to the N_m3u8DL-RE executable. +- **Server** + - `Host`: Bind address for the web server (`127.0.0.1` recommended). + - `Port`: Web server port. + - `ReadTimeoutSec`, `WriteTimeoutSec`, `IdleTimeoutSec`, `ReadHeaderTimeoutS`: HTTP timeout settings. + +- **Security** + - `AuthToken`: Optional token for protecting all endpoints. Recommended when binding to a non-loopback host. + ### Environment Variable Overrides You can override the configuration options using environment variables. The following environment variables are supported: @@ -48,10 +69,14 @@ You can override the configuration options using environment variables. The foll - `FORMAT`: Overrides `General.Format` - `TEMP_BASE_DIR`: Overrides `General.TempBaseDir` - `ENABLE_CONSOLE`: Overrides `General.EnableConsole` (set to `true` or `false`) +- `MAX_UPLOAD_MB`: Overrides `General.MaxUploadMB` - `WATCHED_FOLDER`: Overrides `WatchFolder.Path` - `USE_POLLING`: Overrides `WatchFolder.UsePolling` (set to `true` or `false`) - `USE_INOTIFY`: Overrides `WatchFolder.UseInotify` (set to `true` or `false`) - `POLLING_INTERVAL`: Overrides `WatchFolder.PollingInterval` +- `SERVER_HOST`: Overrides `Server.Host` +- `SERVER_PORT`: Overrides `Server.Port` +- `AUTH_TOKEN`: Overrides `Security.AuthToken` ## Web UI Usage @@ -62,6 +87,9 @@ You can override the configuration options using environment variables. The foll 2. Open a web browser and go to `http://localhost:8080` + If `Security.AuthToken` is configured, include it as a query parameter: + `http://localhost:8080/?token=YOUR_TOKEN` + 3. Use the interface to upload .drmd files and monitor download progress ## CLI Usage diff --git a/config.template.toml b/config.template.toml index 54bcb89..ae7d927 100644 --- a/config.template.toml +++ b/config.template.toml @@ -3,6 +3,7 @@ BaseDir = "/mnt/media" Format = "mkv" TempBaseDir = "/tmp/nre" EnableConsole = true +MaxUploadMB = 32 [WatchFolder] Path = "/mnt/watched" @@ -12,3 +13,14 @@ UseInotify = false [N_m3u8DLRE] Path = "nre" + +[Server] +Host = "127.0.0.1" +Port = 8080 +ReadTimeoutSec = 30 +WriteTimeoutSec = 30 +IdleTimeoutSec = 60 +ReadHeaderTimeoutS = 10 + +[Security] +AuthToken = "" diff --git a/src/config.go b/src/config.go index 1081f12..0d2a5a5 100644 --- a/src/config.go +++ b/src/config.go @@ -1,9 +1,11 @@ package main import ( + "errors" "fmt" "io" "os" + "os/exec" "strconv" "strings" @@ -16,6 +18,7 @@ type Config struct { Format string TempBaseDir string EnableConsole bool + MaxUploadMB int } WatchFolder struct { Path string @@ -23,6 +26,17 @@ type Config struct { UseInotify bool PollingInterval int } + Server struct { + Host string + Port int + ReadTimeoutSec int + WriteTimeoutSec int + IdleTimeoutSec int + ReadHeaderTimeoutS int + } + Security struct { + AuthToken string + } N_m3u8DLRE struct { Path string } @@ -30,15 +44,19 @@ type Config struct { var config Config -func loadConfig() { - configFile, err := os.Open("config.toml") +func loadConfig(path string) { + configFile, err := os.Open(path) if err != nil { logger.LogError("Config", fmt.Sprintf("Error opening config file: %v", err)) os.Exit(1) } defer configFile.Close() - byteValue, _ := io.ReadAll(configFile) + byteValue, err := io.ReadAll(configFile) + if err != nil { + logger.LogError("Config", fmt.Sprintf("Error reading config file: %v", err)) + os.Exit(1) + } if _, err := toml.Decode(string(byteValue), &config); err != nil { logger.LogError("Config", fmt.Sprintf("Error decoding config file: %v", err)) @@ -46,6 +64,7 @@ func loadConfig() { } overrideConfigWithEnv() + setDefaultConfigValues() if err := validatePaths(); err != nil { logger.LogError("Config", fmt.Sprintf("Configuration error: %v", err)) @@ -59,6 +78,36 @@ func loadConfig() { logConfig() } +func setDefaultConfigValues() { + if config.General.MaxUploadMB <= 0 { + config.General.MaxUploadMB = 32 + } + + if strings.TrimSpace(config.Server.Host) == "" { + config.Server.Host = "127.0.0.1" + } + + if config.Server.Port <= 0 { + config.Server.Port = 8080 + } + + if config.Server.ReadTimeoutSec <= 0 { + config.Server.ReadTimeoutSec = 30 + } + + if config.Server.WriteTimeoutSec <= 0 { + config.Server.WriteTimeoutSec = 30 + } + + if config.Server.IdleTimeoutSec <= 0 { + config.Server.IdleTimeoutSec = 60 + } + + if config.Server.ReadHeaderTimeoutS <= 0 { + config.Server.ReadHeaderTimeoutS = 10 + } +} + func overrideConfigWithEnv() { if envBaseDir := os.Getenv("BASE_DIR"); envBaseDir != "" { config.General.BaseDir = envBaseDir @@ -86,14 +135,40 @@ func overrideConfigWithEnv() { config.WatchFolder.PollingInterval = interval } } + if envMaxUploadMB := os.Getenv("MAX_UPLOAD_MB"); envMaxUploadMB != "" { + if value, err := strconv.Atoi(envMaxUploadMB); err == nil { + config.General.MaxUploadMB = value + } + } + if envHost := os.Getenv("SERVER_HOST"); envHost != "" { + config.Server.Host = envHost + } + if envPort := os.Getenv("SERVER_PORT"); envPort != "" { + if value, err := strconv.Atoi(envPort); err == nil { + config.Server.Port = value + } + } + if envAuthToken := os.Getenv("AUTH_TOKEN"); envAuthToken != "" { + config.Security.AuthToken = envAuthToken + } } func validatePaths() error { + if strings.TrimSpace(config.General.Format) == "" { + return errors.New("format is not specified") + } + + allowedFormats := map[string]bool{"mkv": true, "mp4": true} + if !allowedFormats[strings.ToLower(config.General.Format)] { + return fmt.Errorf("unsupported format: %s (supported: mkv, mp4)", config.General.Format) + } + paths := []struct { name string path string }{ {"BaseDir", config.General.BaseDir}, + {"TempBaseDir", config.General.TempBaseDir}, } for _, p := range paths { @@ -101,6 +176,12 @@ func validatePaths() error { return fmt.Errorf("%s is not specified", p.name) } if _, err := os.Stat(p.path); os.IsNotExist(err) { + if p.name == "TempBaseDir" { + if mkErr := os.MkdirAll(p.path, 0755); mkErr != nil { + return fmt.Errorf("unable to create %s: %v", p.name, mkErr) + } + continue + } return fmt.Errorf("%s does not exist: %s", p.name, p.path) } else if err != nil { return fmt.Errorf("error accessing %s: %v", p.name, err) @@ -118,6 +199,18 @@ func validatePaths() error { } } + if strings.TrimSpace(config.N_m3u8DLRE.Path) == "" { + return errors.New("N_m3u8DLRE path is not specified") + } + + if _, err := exec.LookPath(config.N_m3u8DLRE.Path); err != nil { + return fmt.Errorf("N_m3u8DLRE executable not found in PATH: %s", config.N_m3u8DLRE.Path) + } + + if config.Server.Port <= 0 || config.Server.Port > 65535 { + return fmt.Errorf("invalid server port: %d", config.Server.Port) + } + return nil } @@ -129,15 +222,28 @@ Configuration Loaded: Format: %s TempBaseDir: %s EnableConsole: %t + MaxUploadMB: %d WatchFolder: Path: %s UsePolling: %t UseInotify: %t PollingInterval: %d + Server: + Host: %s + Port: %d + ReadTimeoutSec: %d + WriteTimeoutSec: %d + IdleTimeoutSec: %d + ReadHeaderTimeoutS: %d + Security: + AuthTokenConfigured: %t N_m3u8DLRE: Path: %s `, config.General.BaseDir, config.General.Format, config.General.TempBaseDir, config.General.EnableConsole, + config.General.MaxUploadMB, config.WatchFolder.Path, config.WatchFolder.UsePolling, config.WatchFolder.UseInotify, config.WatchFolder.PollingInterval, + config.Server.Host, config.Server.Port, config.Server.ReadTimeoutSec, config.Server.WriteTimeoutSec, config.Server.IdleTimeoutSec, config.Server.ReadHeaderTimeoutS, + strings.TrimSpace(config.Security.AuthToken) != "", config.N_m3u8DLRE.Path) logger.LogInfo("Config", configInfo) diff --git a/src/downloaders.go b/src/downloaders.go index bea2500..22387a3 100644 --- a/src/downloaders.go +++ b/src/downloaders.go @@ -1,18 +1,31 @@ package main import ( - "bytes" "encoding/base64" + "errors" "fmt" "io" - "net/http" "os" "os/exec" "path/filepath" + "regexp" "strings" - "time" ) +var decryptionKeyRegex = regexp.MustCompile(`^[0-9a-fA-F]{32}:[0-9a-fA-F]{32}$`) + +type websocketBroadcastWriter struct { + filename string +} + +func (w websocketBroadcastWriter) Write(p []byte) (int, error) { + if config.General.EnableConsole && len(p) > 0 { + message := append([]byte(nil), p...) + broadcast(w.filename, message) + } + return len(p), nil +} + func removeBOM(input []byte) []byte { if len(input) >= 3 && input[0] == 0xEF && input[1] == 0xBB && input[2] == 0xBF { return input[3:] @@ -23,14 +36,17 @@ func removeBOM(input []byte) []byte { func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error { logger.LogInfo("Download File", fmt.Sprintf("Starting download for: %s", item.Filename)) - tempDir := filepath.Join(config.General.TempBaseDir, sanitizeFilename(item.Filename)) - err := os.MkdirAll(tempDir, 0755) + if err := os.MkdirAll(config.General.TempBaseDir, 0755); err != nil { + logger.LogError("Download File", fmt.Sprintf("Error creating temp base dir: %v", err)) + return fmt.Errorf("error creating temp base dir: %v", err) + } + tempDir, err := os.MkdirTemp(config.General.TempBaseDir, sanitizeFilename(item.Filename)+"_") if err != nil { logger.LogError("Download File", fmt.Sprintf("Error creating temporary directory: %v", err)) return fmt.Errorf("error creating temporary directory: %v", err) } - jobInfo.TempDir = tempDir + jobInfo.SetTempDir(tempDir) mpdPath := item.MPD if !isValidURL(item.MPD) { @@ -58,18 +74,11 @@ func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error { mpdPath = tempFile.Name() } else if strings.HasPrefix(item.MPD, "https://pubads.g.doubleclick.net") { - resp, err := http.Get(item.MPD) + mpdContent, err := fetchRemoteContent(item.MPD) if err != nil { logger.LogError("Download File", fmt.Sprintf("Error downloading MPD: %v", err)) return fmt.Errorf("error downloading MPD: %v", err) } - defer resp.Body.Close() - - mpdContent, err := io.ReadAll(resp.Body) - if err != nil { - logger.LogError("Download File", fmt.Sprintf("Error reading MPD content: %v", err)) - return fmt.Errorf("error reading MPD content: %v", err) - } fixedMPDContent, err := fixGoPlay(string(mpdContent)) if err != nil { @@ -96,7 +105,10 @@ func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error { mpdPath = tempFile.Name() } - command := getDownloadCommand(item, mpdPath, tempDir) + args, err := getDownloadArgs(item, mpdPath, tempDir) + if err != nil { + return err + } if item.Subtitles != "" { subtitlePaths, err := downloadAndConvertSubtitles(item.Subtitles) @@ -105,17 +117,17 @@ func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error { } else { for _, path := range subtitlePaths { logger.LogInfo("Download File", fmt.Sprintf("Adding subtitle: %s", path)) - command += fmt.Sprintf(" --mux-import \"path=%s:lang=nl:name=Nederlands\"", path) + args = append(args, "--mux-import", fmt.Sprintf("path=%s:lang=nl:name=Nederlands", path)) } } } - cmd := exec.Command("bash", "-c", command) - jobInfo.Cmd = cmd + cmd := exec.Command(config.N_m3u8DLRE.Path, args...) + jobInfo.SetCmd(cmd) - var outputBuffer bytes.Buffer - cmd.Stdout = io.MultiWriter(&outputBuffer) - cmd.Stderr = os.Stderr + broadcastWriter := websocketBroadcastWriter{filename: drmdFilename} + cmd.Stdout = io.MultiWriter(os.Stdout, broadcastWriter) + cmd.Stderr = io.MultiWriter(os.Stderr, broadcastWriter) err = cmd.Start() if err != nil { @@ -128,33 +140,26 @@ func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error { done <- cmd.Wait() }() - go func() { - for { - if outputBuffer.Len() > 0 { - message := outputBuffer.Bytes() - if config.General.EnableConsole { - broadcast(drmdFilename, message) - } - outputBuffer.Reset() - } - time.Sleep(1 * time.Second) - } - }() - select { case <-jobInfo.AbortChan: - if cmd.Process != nil { - cmd.Process.Kill() - } - os.RemoveAll(tempDir) + jobInfo.KillProcess() + _ = os.RemoveAll(tempDir) logger.LogInfo("Download File", "Download aborted") - return fmt.Errorf("download aborted") + return ErrDownloadAborted case err := <-done: - if jobInfo.Paused { + if jobInfo.IsPaused() { logger.LogInfo("Download File", "Download paused") - return fmt.Errorf("download paused") + return ErrDownloadPaused } if err != nil { + if jobInfo.IsAborted() { + return ErrDownloadAborted + } + + if jobInfo.IsPaused() { + return ErrDownloadPaused + } + logger.LogError("Download File", fmt.Sprintf("Error executing download command: %v", err)) return fmt.Errorf("error executing download command: %v", err) } @@ -164,26 +169,24 @@ func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error { return nil } -func getDownloadCommand(item Item, mpdPath string, tempDir string) string { +func getDownloadArgs(item Item, mpdPath string, tempDir string) ([]string, error) { metadata := parseMetadata(item.Metadata) keys := getKeys(item.Keys) - command := fmt.Sprintf("%s \"%s\"", config.N_m3u8DLRE.Path, mpdPath) + args := []string{mpdPath} for _, key := range keys { - if key != "" { - command += fmt.Sprintf(" --key %s", key) + if !decryptionKeyRegex.MatchString(key) { + return nil, fmt.Errorf("invalid decryption key format") } + args = append(args, "--key", key) } - command += " --auto-select" + args = append(args, "--auto-select") sanitizedFilename := sanitizeFilename(item.Filename) - - filename := fmt.Sprintf("\"%s\"", sanitizedFilename) - command += fmt.Sprintf(" --save-name %s", filename) - - command += fmt.Sprintf(" --mux-after-done format=%s", config.General.Format) + args = append(args, "--save-name", sanitizedFilename) + args = append(args, "--mux-after-done", fmt.Sprintf("format=%s", config.General.Format)) saveDir := config.General.BaseDir if metadata.Type == "serie" { @@ -191,15 +194,19 @@ func getDownloadCommand(item Item, mpdPath string, tempDir string) string { } else { saveDir = filepath.Join(saveDir, "Movies", metadata.Title) } - command += fmt.Sprintf(" --save-dir \"%s\"", saveDir) + if err := os.MkdirAll(saveDir, 0755); err != nil { + return nil, fmt.Errorf("unable to create save directory: %w", err) + } + args = append(args, "--save-dir", saveDir) + args = append(args, "--tmp-dir", tempDir) - command += fmt.Sprintf(" --tmp-dir \"%s\"", tempDir) - - if globalSpeedLimit != "" { - command += fmt.Sprintf(" -R %s", globalSpeedLimit) + currentSpeedLimit := getGlobalSpeedLimit() + if currentSpeedLimit != "" { + if !speedLimitRegex.MatchString(currentSpeedLimit) { + return nil, errors.New("invalid speed limit format") + } + args = append(args, "-R", currentSpeedLimit) } - fmt.Println(command) - - return command + return args, nil } diff --git a/src/go.mod b/src/go.mod index f8ce615..20fa35e 100644 --- a/src/go.mod +++ b/src/go.mod @@ -1,6 +1,6 @@ module DRMDTool -go 1.23.0 +go 1.25.0 require ( github.com/BurntSushi/toml v1.4.0 @@ -8,13 +8,13 @@ require ( github.com/beevik/etree v1.4.1 ) -require golang.org/x/sys v0.4.0 // indirect +require golang.org/x/sys v0.43.0 // indirect require ( github.com/asticode/go-astikit v0.20.0 // indirect github.com/asticode/go-astits v1.8.0 // indirect github.com/fsnotify/fsnotify v1.7.0 github.com/gorilla/websocket v1.5.3 - golang.org/x/net v0.0.0-20200904194848-62affa334b73 // indirect - golang.org/x/text v0.3.2 // indirect + golang.org/x/net v0.53.0 // indirect + golang.org/x/text v0.36.0 // indirect ) diff --git a/src/go.sum b/src/go.sum index cb34f19..a65b877 100644 --- a/src/go.sum +++ b/src/go.sum @@ -14,8 +14,6 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/pkg/exec v0.0.0-20150614095509-0bd164ad2a5a h1:EN123kAtAAE2pg/+TvBsUBZfHCWNNFyL2ZBPPfNWAc0= -github.com/pkg/exec v0.0.0-20150614095509-0bd164ad2a5a/go.mod h1:b95YoNrAnScjaWG+asr8lxqlrsPUcT2ZEBcjvVGshMo= github.com/pkg/profile v1.4.0/go.mod h1:NWz/XGvpEW1FyYQ7fCx4dqYBLlfTcE+A9FLAkNKqjFE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -25,16 +23,18 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20200904194848-62affa334b73 h1:MXfv8rhZWmFeqX3GNZRsd6vOLoaCHjYEX3qkRo3YBUA= golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= -golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= diff --git a/src/handlers.go b/src/handlers.go deleted file mode 100644 index da28c77..0000000 --- a/src/handlers.go +++ /dev/null @@ -1,478 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "os" - "path/filepath" - "strings" - "sync" - - "github.com/gorilla/websocket" -) - -type ProgressInfo struct { - Percentage float64 - CurrentFile string - Paused bool -} - -func handleRoot(w http.ResponseWriter, r *http.Request) { - progressMutex.Lock() - defer progressMutex.Unlock() - - jobsInfo := make(map[string]struct { - Percentage float64 - CurrentFile string - Paused bool - }) - - for filename, info := range progress { - jobsInfo[filename] = struct { - Percentage float64 - CurrentFile string - Paused bool - }{ - Percentage: info.Percentage, - CurrentFile: info.CurrentFile, - Paused: info.Paused, - } - } - - err := templates.ExecuteTemplate(w, "index", struct { - Jobs map[string]struct { - Percentage float64 - CurrentFile string - Paused bool - } - GlobalSpeedLimit string - }{ - Jobs: jobsInfo, - GlobalSpeedLimit: globalSpeedLimit, - }) - if err != nil { - logger.LogError("Handle Root", fmt.Sprintf("Error executing template: %v", err)) - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -func handleUpload(w http.ResponseWriter, r *http.Request) { - logger.LogInfo("Handle Upload", "Starting file upload") - err := r.ParseMultipartForm(32 << 20) - if err != nil { - logger.LogError("Handle Upload", fmt.Sprintf("Error parsing multipart form: %v", err)) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - files := r.MultipartForm.File["files"] - if len(files) == 0 { - logger.LogError("Handle Upload", "No files uploaded") - http.Error(w, "No files uploaded", http.StatusBadRequest) - return - } - - uploadedFiles := []string{} - - for _, fileHeader := range files { - file, err := fileHeader.Open() - if err != nil { - logger.LogError("Handle Upload", fmt.Sprintf("Error opening file: %v", err)) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer file.Close() - - tempFile, err := os.CreateTemp(uploadDir, fileHeader.Filename) - if err != nil { - logger.LogError("Handle Upload", fmt.Sprintf("Error creating temporary file: %v", err)) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer tempFile.Close() - - _, err = io.Copy(tempFile, file) - if err != nil { - logger.LogError("Handle Upload", fmt.Sprintf("Error copying file: %v", err)) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - uploadedFiles = append(uploadedFiles, filepath.Base(tempFile.Name())) - - _, err = parseInputFile(tempFile.Name()) - if err != nil { - logger.LogError("Handle Upload", fmt.Sprintf("Error parsing input file: %v", err)) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } - - validFiles := []string{} - for _, file := range uploadedFiles { - if file != "" { - validFiles = append(validFiles, file) - } - } - - if len(validFiles) == 0 { - logger.LogError("Handle Upload", "No valid files were uploaded") - http.Error(w, "No valid files were uploaded", http.StatusBadRequest) - return - } - - logger.LogInfo("Handle Upload", fmt.Sprintf("Redirecting to select with files: %v", validFiles)) - http.Redirect(w, r, "/select?files="+url.QueryEscape(strings.Join(validFiles, ",")), http.StatusSeeOther) -} - -func handleSelect(w http.ResponseWriter, r *http.Request) { - filesParam := r.URL.Query().Get("files") - filenames := strings.Split(filesParam, ",") - - allItems := make(map[string]map[string][]Item) - - for _, filename := range filenames { - if filename == "" { - continue - } - fullPath := filepath.Join(uploadDir, filename) - - if _, err := os.Stat(fullPath); os.IsNotExist(err) { - logger.LogError("Handle Select", fmt.Sprintf("File does not exist: %s", fullPath)) - continue - } - - items, err := parseInputFile(fullPath) - if err != nil { - logger.LogError("Handle Select", fmt.Sprintf("Error parsing input file: %v", err)) - continue - } - - sortItems(items) - - groupedItems := groupItemsBySeason(items) - allItems[filename] = groupedItems - } - - if len(allItems) == 0 { - logger.LogError("Handle Select", "No valid files were processed") - http.Error(w, "No valid files were processed", http.StatusBadRequest) - return - } - - err := templates.ExecuteTemplate(w, "select", struct { - Filenames string - AllItems map[string]map[string][]Item - }{ - Filenames: filesParam, - AllItems: allItems, - }) - if err != nil { - logger.LogError("Handle Select", fmt.Sprintf("Error executing template: %v", err)) - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -func handleProcess(w http.ResponseWriter, r *http.Request) { - logger.LogInfo("Handle Process", "Starting process") - if err := r.ParseForm(); err != nil { - logger.LogError("Handle Process", fmt.Sprintf("Error parsing form: %v", err)) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - selectedItems := r.Form["items"] - if len(selectedItems) == 0 { - logger.LogError("Handle Process", "No items selected") - http.Error(w, "No items selected", http.StatusBadRequest) - return - } - - itemsByFile := make(map[string][]string) - for _, item := range selectedItems { - parts := strings.SplitN(item, ":", 2) - if len(parts) != 2 { - logger.LogError("Handle Process", "Invalid item format") - continue - } - filename, itemName := parts[0], parts[1] - itemsByFile[filename] = append(itemsByFile[filename], itemName) - } - - for filename, items := range itemsByFile { - logger.LogInfo("Handle Process", fmt.Sprintf("Processing file: %s", filename)) - fullPath := filepath.Join(uploadDir, filename) - allItems, err := parseInputFile(fullPath) - if err != nil { - logger.LogError("Handle Process", fmt.Sprintf("Error parsing input file: %v", err)) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - selectedItems := filterSelectedItems(allItems, items) - sortItems(selectedItems) - go processItems(filename, selectedItems) - } - - http.Redirect(w, r, "/", http.StatusSeeOther) -} - -func handleProgress(w http.ResponseWriter, r *http.Request) { - filename := r.URL.Query().Get("filename") - - if r.Header.Get("Accept") == "application/json" { - progressInfo := getProgress(filename) - - if progressInfo == nil { - w.WriteHeader(http.StatusNotFound) - json.NewEncoder(w).Encode(map[string]string{"error": "No progress information found"}) - return - } - - w.Header().Set("Content-Type", "application/json") - err := json.NewEncoder(w).Encode(progressInfo) - if err != nil { - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - return - } - - err := templates.ExecuteTemplate(w, "progress", struct{ Filename string }{filename}) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -func handlePause(w http.ResponseWriter, r *http.Request) { - filename := r.URL.Query().Get("filename") - if filename == "" { - logger.LogError("Pause Handler", "Filename is required") - http.Error(w, "Filename is required", http.StatusBadRequest) - return - } - - jobsMutex.Lock() - jobInfo, exists := jobs[filename] - jobsMutex.Unlock() - - if !exists { - logger.LogError("Pause Handler", "Job not found") - http.Error(w, "Job not found", http.StatusNotFound) - return - } - - jobInfo.Paused = true - if jobInfo.Cmd != nil && jobInfo.Cmd.Process != nil { - logger.LogJobState(filename, "pausing") - jobInfo.Cmd.Process.Kill() - } - - progressMutex.Lock() - if progressInfo, ok := progress[filename]; ok { - progressInfo.Paused = true - } - progressMutex.Unlock() - - fmt.Fprintf(w, "Pause signal sent for %s", filename) -} - -func handleResume(w http.ResponseWriter, r *http.Request) { - filename := r.URL.Query().Get("filename") - if filename == "" { - http.Error(w, "Filename is required", http.StatusBadRequest) - return - } - - jobsMutex.Lock() - jobInfo, exists := jobs[filename] - jobsMutex.Unlock() - - if !exists { - http.Error(w, "Job not found", http.StatusNotFound) - return - } - - jobInfo.Paused = false - jobInfo.ResumeChan <- struct{}{} - - progressMutex.Lock() - if progressInfo, ok := progress[filename]; ok { - progressInfo.Paused = false - } - progressMutex.Unlock() - - fmt.Fprintf(w, "Resume signal sent for %s", filename) -} - -func handleAbort(w http.ResponseWriter, r *http.Request) { - filename := r.URL.Query().Get("filename") - if filename == "" { - http.Error(w, "Filename is required", http.StatusBadRequest) - return - } - - jobsMutex.Lock() - jobInfo, exists := jobs[filename] - jobsMutex.Unlock() - - if !exists { - http.Error(w, "Job not found", http.StatusNotFound) - return - } - - close(jobInfo.AbortChan) - if jobInfo.Cmd != nil && jobInfo.Cmd.Process != nil { - jobInfo.Cmd.Process.Kill() - } - - if jobInfo.TempDir != "" { - os.RemoveAll(jobInfo.TempDir) - } - - fmt.Fprintf(w, "Abort signal sent for %s", filename) -} - -func handleClearCompleted(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - clearCompletedJobs() - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]bool{"success": true}) -} - -func clearCompletedJobs() { - progressMutex.Lock() - defer progressMutex.Unlock() - - for filename, info := range progress { - if info.Percentage >= 100 { - delete(progress, filename) - } - } -} - -func updateProgress(filename string, value float64, currentFile string) { - progressMutex.Lock() - defer progressMutex.Unlock() - - jobsMutex.Lock() - jobInfo, exists := jobs[filename] - jobsMutex.Unlock() - - paused := false - if exists { - paused = jobInfo.Paused - } - - if existingProgress, ok := progress[filename]; ok { - existingProgress.Percentage = value - existingProgress.CurrentFile = currentFile - existingProgress.Paused = paused - } else { - progress[filename] = &ProgressInfo{ - Percentage: value, - CurrentFile: currentFile, - Paused: paused, - } - } -} - -var upgrader = websocket.Upgrader{} -var clients = make(map[string]map[*websocket.Conn]bool) -var mu sync.Mutex - -func handleWebSocket(w http.ResponseWriter, r *http.Request) { - fmt.Println(config.General.EnableConsole) - if !config.General.EnableConsole { - http.Error(w, "Console output is disabled", http.StatusForbidden) - return - } - - filename := r.URL.Query().Get("filename") - if filename == "" { - http.Error(w, "Filename is required", http.StatusBadRequest) - return - } - - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - logger.LogError("WebSocket", fmt.Sprintf("Error while upgrading connection: %v", err)) - return - } - defer conn.Close() - - logger.LogInfo("WebSocket", fmt.Sprintf("WebSocket connection established for filename: %s", filename)) - - mu.Lock() - if clients[filename] == nil { - clients[filename] = make(map[*websocket.Conn]bool) - } - clients[filename][conn] = true - mu.Unlock() - - for { - if _, _, err := conn.NextReader(); err != nil { - break - } - } - - mu.Lock() - delete(clients[filename], conn) - mu.Unlock() - - logger.LogInfo("WebSocket", fmt.Sprintf("WebSocket connection closed for filename: %s", filename)) -} - -func broadcast(filename string, message []byte) { - if !config.General.EnableConsole { - return - } - - mu.Lock() - defer mu.Unlock() - - for client := range clients[filename] { - if err := client.WriteMessage(websocket.TextMessage, message); err != nil { - client.Close() - delete(clients[filename], client) - logger.LogError("Broadcast", fmt.Sprintf("Error writing message to client: %v", err)) - } - } -} - -func handleSetSpeedLimit(w http.ResponseWriter, r *http.Request) { - logger.LogInfo("Set Speed Limit", "Received request to set speed limit") - - if r.Method != http.MethodPost { - logger.LogError("Set Speed Limit", "Invalid method") - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - var requestData struct { - SpeedLimit string `json:"speedLimit"` - } - - if err := json.NewDecoder(r.Body).Decode(&requestData); err != nil { - logger.LogError("Set Speed Limit", "Invalid request body") - http.Error(w, "Invalid request", http.StatusBadRequest) - return - } - - if requestData.SpeedLimit == "unlimited" { - globalSpeedLimit = "" - } else { - globalSpeedLimit = requestData.SpeedLimit - } - - logger.LogInfo("Set Speed Limit", fmt.Sprintf("Global speed limit set to: %s", globalSpeedLimit)) - w.WriteHeader(http.StatusOK) -} diff --git a/src/handlers_common.go b/src/handlers_common.go new file mode 100644 index 0000000..cdf3ab3 --- /dev/null +++ b/src/handlers_common.go @@ -0,0 +1,30 @@ +package main + +import ( + "net/url" + "regexp" + "strings" +) + +type ProgressInfo struct { + Percentage float64 + CurrentFile string + Paused bool + Status string +} + +var speedLimitRegex = regexp.MustCompile(`^([1-9]\d*(\.\d+)?)(KBps|MBps|GBps)$`) + +func withToken(path string) string { + token := strings.TrimSpace(config.Security.AuthToken) + if token == "" { + return path + } + + separator := "?" + if strings.Contains(path, "?") { + separator = "&" + } + + return path + separator + "token=" + url.QueryEscape(token) +} diff --git a/src/handlers_jobs.go b/src/handlers_jobs.go new file mode 100644 index 0000000..9b409f1 --- /dev/null +++ b/src/handlers_jobs.go @@ -0,0 +1,174 @@ +package main + +import ( + "encoding/json" + "fmt" + "net/http" + "os" + "strings" +) + +func handlePause(w http.ResponseWriter, r *http.Request) { + if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) { + return + } + + filename := r.URL.Query().Get("filename") + if filename == "" { + logger.LogError("Pause Handler", "Filename is required") + http.Error(w, "Filename is required", http.StatusBadRequest) + return + } + + jobInfo, exists := getJob(filename) + if !exists { + logger.LogError("Pause Handler", "Job not found") + http.Error(w, "Job not found", http.StatusNotFound) + return + } + + jobInfo.SetPaused(true) + logger.LogJobState(filename, "pausing") + jobInfo.KillProcess() + + paused := true + setProgressStatus(filename, &paused, "paused") + + _, _ = fmt.Fprintf(w, "Pause signal sent for %s", filename) +} + +func handleResume(w http.ResponseWriter, r *http.Request) { + if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) { + return + } + + filename := r.URL.Query().Get("filename") + if filename == "" { + http.Error(w, "Filename is required", http.StatusBadRequest) + return + } + + jobInfo, exists := getJob(filename) + if !exists { + http.Error(w, "Job not found", http.StatusNotFound) + return + } + + jobInfo.SetPaused(false) + jobInfo.SignalResume() + + paused := false + setProgressStatus(filename, &paused, "running") + + _, _ = fmt.Fprintf(w, "Resume signal sent for %s", filename) +} + +func handleAbort(w http.ResponseWriter, r *http.Request) { + if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) { + return + } + + filename := r.URL.Query().Get("filename") + if filename == "" { + http.Error(w, "Filename is required", http.StatusBadRequest) + return + } + + jobInfo, exists := getJob(filename) + if !exists { + http.Error(w, "Job not found", http.StatusNotFound) + return + } + + jobInfo.Abort() + jobInfo.KillProcess() + + if tempDir := jobInfo.GetTempDir(); tempDir != "" { + _ = os.RemoveAll(tempDir) + } + + paused := false + setProgressStatus(filename, &paused, "aborted") + + _, _ = fmt.Fprintf(w, "Abort signal sent for %s", filename) +} + +func handleClearCompleted(w http.ResponseWriter, r *http.Request) { + if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) { + return + } + + clearCompletedJobs() + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]bool{"success": true}) +} + +func clearCompletedJobs() { + progressMutex.Lock() + defer progressMutex.Unlock() + + for filename, info := range progress { + if info.Percentage >= 100 || info.Status == "failed" || info.Status == "aborted" { + delete(progress, filename) + } + } +} + +func updateProgress(filename string, value float64, currentFile, status string) { + paused := false + if jobInfo, exists := getJob(filename); exists { + paused = jobInfo.IsPaused() + } + + progressMutex.Lock() + defer progressMutex.Unlock() + + if existingProgress, ok := progress[filename]; ok { + existingProgress.Percentage = value + existingProgress.CurrentFile = currentFile + existingProgress.Paused = paused + if status != "" { + existingProgress.Status = status + } + } else { + progress[filename] = &ProgressInfo{ + Percentage: value, + CurrentFile: currentFile, + Paused: paused, + Status: status, + } + } +} + +func handleSetSpeedLimit(w http.ResponseWriter, r *http.Request) { + if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) { + return + } + + logger.LogInfo("Set Speed Limit", "Received request to set speed limit") + + var requestData struct { + SpeedLimit string `json:"speedLimit"` + } + + if err := json.NewDecoder(r.Body).Decode(&requestData); err != nil { + logger.LogError("Set Speed Limit", "Invalid request body") + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + requestData.SpeedLimit = strings.TrimSpace(requestData.SpeedLimit) + if requestData.SpeedLimit == "unlimited" { + setGlobalSpeedLimit("") + } else { + if !speedLimitRegex.MatchString(requestData.SpeedLimit) { + http.Error(w, "Invalid speed limit format", http.StatusBadRequest) + return + } + setGlobalSpeedLimit(requestData.SpeedLimit) + } + + logger.LogInfo("Set Speed Limit", fmt.Sprintf("Global speed limit set to: %s", getGlobalSpeedLimit())) + w.WriteHeader(http.StatusOK) +} diff --git a/src/handlers_pages.go b/src/handlers_pages.go new file mode 100644 index 0000000..436c777 --- /dev/null +++ b/src/handlers_pages.go @@ -0,0 +1,128 @@ +package main + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "os" + "strings" +) + +func handleRoot(w http.ResponseWriter, r *http.Request) { + if !ensureMethod(w, r, http.MethodGet) || !ensureAuthorized(w, r) { + return + } + + err := templates.ExecuteTemplate(w, "index", struct { + Jobs map[string]ProgressInfo + GlobalSpeedLimit string + AuthToken string + Nonce string + }{ + Jobs: snapshotProgress(), + GlobalSpeedLimit: getGlobalSpeedLimit(), + AuthToken: url.QueryEscape(config.Security.AuthToken), + Nonce: cspNonce(r), + }) + if err != nil { + logger.LogError("Handle Root", fmt.Sprintf("Error executing template: %v", err)) + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func handleSelect(w http.ResponseWriter, r *http.Request) { + if !ensureMethod(w, r, http.MethodGet) || !ensureAuthorized(w, r) { + return + } + + filesParam := r.URL.Query().Get("files") + filenames := strings.Split(filesParam, ",") + + allItems := make(map[string]map[string][]Item) + + for _, filename := range filenames { + if filename == "" { + continue + } + + fullPath, pathErr := safeUploadPath(filename) + if pathErr != nil { + logger.LogError("Handle Select", fmt.Sprintf("Invalid filename %s: %v", filename, pathErr)) + continue + } + + if _, err := os.Stat(fullPath); os.IsNotExist(err) { + logger.LogError("Handle Select", fmt.Sprintf("File does not exist: %s", fullPath)) + continue + } + + items, err := parseInputFile(fullPath) + if err != nil { + logger.LogError("Handle Select", fmt.Sprintf("Error parsing input file: %v", err)) + continue + } + + sortItems(items) + + groupedItems := groupItemsBySeason(items) + allItems[filename] = groupedItems + } + + if len(allItems) == 0 { + logger.LogError("Handle Select", "No valid files were processed") + http.Error(w, "No valid files were processed", http.StatusBadRequest) + return + } + + err := templates.ExecuteTemplate(w, "select", struct { + Filenames string + AllItems map[string]map[string][]Item + AuthToken string + Nonce string + }{ + Filenames: filesParam, + AllItems: allItems, + AuthToken: url.QueryEscape(config.Security.AuthToken), + Nonce: cspNonce(r), + }) + if err != nil { + logger.LogError("Handle Select", fmt.Sprintf("Error executing template: %v", err)) + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func handleProgress(w http.ResponseWriter, r *http.Request) { + if !ensureMethod(w, r, http.MethodGet) || !ensureAuthorized(w, r) { + return + } + + filename := r.URL.Query().Get("filename") + + if r.Header.Get("Accept") == "application/json" { + progressInfo := getProgress(filename) + + if progressInfo == nil { + w.WriteHeader(http.StatusNotFound) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "No progress information found"}) + return + } + + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(progressInfo) + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + return + } + + err := templates.ExecuteTemplate(w, "progress", struct { + Filename string + AuthToken string + Nonce string + }{Filename: filename, AuthToken: url.QueryEscape(config.Security.AuthToken), Nonce: cspNonce(r)}) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} diff --git a/src/handlers_upload_process.go b/src/handlers_upload_process.go new file mode 100644 index 0000000..71e45e2 --- /dev/null +++ b/src/handlers_upload_process.go @@ -0,0 +1,158 @@ +package main + +import ( + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" +) + +func handleUpload(w http.ResponseWriter, r *http.Request) { + if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) { + return + } + + logger.LogInfo("Handle Upload", "Starting file upload") + r.Body = http.MaxBytesReader(w, r.Body, maxUploadBytes()) + const multipartMemory = 4 << 20 + err := r.ParseMultipartForm(multipartMemory) + if err != nil { + logger.LogError("Handle Upload", fmt.Sprintf("Error parsing multipart form: %v", err)) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + files := r.MultipartForm.File["files"] + if len(files) == 0 { + logger.LogError("Handle Upload", "No files uploaded") + http.Error(w, "No files uploaded", http.StatusBadRequest) + return + } + + uploadedFiles := []string{} + + for _, fileHeader := range files { + if strings.ToLower(filepath.Ext(fileHeader.Filename)) != ".drmd" { + http.Error(w, "Only .drmd files are allowed", http.StatusBadRequest) + return + } + + file, err := fileHeader.Open() + if err != nil { + logger.LogError("Handle Upload", fmt.Sprintf("Error opening file: %v", err)) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + pattern := sanitizeFilename(fileHeader.Filename) + "_*.drmd" + tempFile, err := os.CreateTemp(uploadDir, pattern) + if err != nil { + _ = file.Close() + logger.LogError("Handle Upload", fmt.Sprintf("Error creating temporary file: %v", err)) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + _, err = io.Copy(tempFile, file) + _ = file.Close() + if err != nil { + _ = tempFile.Close() + _ = os.Remove(tempFile.Name()) + logger.LogError("Handle Upload", fmt.Sprintf("Error copying file: %v", err)) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if err := tempFile.Close(); err != nil { + _ = os.Remove(tempFile.Name()) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + uploadedFiles = append(uploadedFiles, filepath.Base(tempFile.Name())) + + _, err = parseInputFile(tempFile.Name()) + if err != nil { + _ = os.Remove(tempFile.Name()) + logger.LogError("Handle Upload", fmt.Sprintf("Error parsing input file: %v", err)) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + + validFiles := []string{} + for _, file := range uploadedFiles { + if file != "" { + validFiles = append(validFiles, file) + } + } + + if len(validFiles) == 0 { + logger.LogError("Handle Upload", "No valid files were uploaded") + http.Error(w, "No valid files were uploaded", http.StatusBadRequest) + return + } + + logger.LogInfo("Handle Upload", fmt.Sprintf("Redirecting to select with files: %v", validFiles)) + http.Redirect(w, r, withToken("/select?files="+url.QueryEscape(strings.Join(validFiles, ","))), http.StatusSeeOther) +} + +func handleProcess(w http.ResponseWriter, r *http.Request) { + if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) { + return + } + + logger.LogInfo("Handle Process", "Starting process") + if err := r.ParseForm(); err != nil { + logger.LogError("Handle Process", fmt.Sprintf("Error parsing form: %v", err)) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + selectedItems := r.Form["items"] + if len(selectedItems) == 0 { + logger.LogError("Handle Process", "No items selected") + http.Error(w, "No items selected", http.StatusBadRequest) + return + } + + itemsByFile := make(map[string][]string) + for _, item := range selectedItems { + parts := strings.SplitN(item, ":", 2) + if len(parts) != 2 { + logger.LogError("Handle Process", "Invalid item format") + continue + } + filename, itemName := parts[0], parts[1] + itemsByFile[filename] = append(itemsByFile[filename], itemName) + } + + for filename, items := range itemsByFile { + logger.LogInfo("Handle Process", fmt.Sprintf("Processing file: %s", filename)) + fullPath, pathErr := safeUploadPath(filename) + if pathErr != nil { + http.Error(w, pathErr.Error(), http.StatusBadRequest) + return + } + + allItems, err := parseInputFile(fullPath) + if err != nil { + logger.LogError("Handle Process", fmt.Sprintf("Error parsing input file: %v", err)) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + selectedItems := filterSelectedItems(allItems, items) + sortItems(selectedItems) + go func(targetFilename string, targetItems []Item) { + if err := processItems(targetFilename, targetItems); err != nil { + logger.LogError("Handle Process", fmt.Sprintf("Error processing %s: %v", targetFilename, err)) + } + }(filename, selectedItems) + } + + http.Redirect(w, r, withToken("/"), http.StatusSeeOther) +} diff --git a/src/handlers_ws.go b/src/handlers_ws.go new file mode 100644 index 0000000..4b9f574 --- /dev/null +++ b/src/handlers_ws.go @@ -0,0 +1,148 @@ +package main + +import ( + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +const ( + wsSendBuffer = 64 + wsWriteWait = 10 * time.Second +) + +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + origin := strings.TrimSpace(r.Header.Get("Origin")) + if origin == "" { + return strings.TrimSpace(config.Security.AuthToken) == "" + } + parsedOrigin, err := url.Parse(origin) + if err != nil { + return false + } + return strings.EqualFold(parsedOrigin.Host, r.Host) + }, +} + +type wsClient struct { + conn *websocket.Conn + send chan []byte + once sync.Once +} + +func (c *wsClient) close() { + c.once.Do(func() { + close(c.send) + }) +} + +var clients = make(map[string]map[*wsClient]struct{}) +var mu sync.Mutex + +func handleWebSocket(w http.ResponseWriter, r *http.Request) { + if !ensureMethod(w, r, http.MethodGet) || !ensureAuthorized(w, r) { + return + } + + if !config.General.EnableConsole { + http.Error(w, "Console output is disabled", http.StatusForbidden) + return + } + + filename := r.URL.Query().Get("filename") + if filename == "" { + http.Error(w, "Filename is required", http.StatusBadRequest) + return + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + logger.LogError("WebSocket", fmt.Sprintf("Error while upgrading connection: %v", err)) + return + } + + logger.LogInfo("WebSocket", fmt.Sprintf("WebSocket connection established for filename: %s", filename)) + + client := &wsClient{conn: conn, send: make(chan []byte, wsSendBuffer)} + + mu.Lock() + if clients[filename] == nil { + clients[filename] = make(map[*wsClient]struct{}) + } + clients[filename][client] = struct{}{} + mu.Unlock() + + go writePump(filename, client) + + for { + if _, _, err := conn.NextReader(); err != nil { + break + } + } + + mu.Lock() + if set, ok := clients[filename]; ok { + if _, exists := set[client]; exists { + delete(set, client) + client.close() + } + if len(set) == 0 { + delete(clients, filename) + } + } + mu.Unlock() + _ = conn.Close() + + logger.LogInfo("WebSocket", fmt.Sprintf("WebSocket connection closed for filename: %s", filename)) +} + +func writePump(filename string, client *wsClient) { + defer client.conn.Close() + + for message := range client.send { + if err := client.conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil { + logger.LogError("WebSocket", fmt.Sprintf("SetWriteDeadline error: %v", err)) + return + } + if err := client.conn.WriteMessage(websocket.TextMessage, message); err != nil { + logger.LogError("Broadcast", fmt.Sprintf("Error writing to client for %s: %v", filename, err)) + return + } + } +} + +func broadcast(filename string, message []byte) { + if !config.General.EnableConsole { + return + } + + mu.Lock() + set := clients[filename] + targets := make([]*wsClient, 0, len(set)) + for c := range set { + targets = append(targets, c) + } + mu.Unlock() + + for _, client := range targets { + select { + case client.send <- message: + default: + mu.Lock() + if set, ok := clients[filename]; ok { + if _, exists := set[client]; exists { + delete(set, client) + client.close() + } + } + mu.Unlock() + logger.LogError("Broadcast", fmt.Sprintf("Dropping slow client for %s", filename)) + } + } +} diff --git a/src/integration_test.go b/src/integration_test.go new file mode 100644 index 0000000..1dfdd7b --- /dev/null +++ b/src/integration_test.go @@ -0,0 +1,109 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func resetStateForTest() { + jobsMutex.Lock() + jobs = make(map[string]*JobInfo) + jobsMutex.Unlock() + + progressMutex.Lock() + progress = make(map[string]*ProgressInfo) + progressMutex.Unlock() + + setGlobalSpeedLimit("") + config = Config{} + setDefaultConfigValues() +} + +func TestAuthTokenProtection(t *testing.T) { + resetStateForTest() + config.Security.AuthToken = "secret" + + handler := newRouter() + + req := httptest.NewRequest(http.MethodPost, "/clear-completed", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code) + } + + req = httptest.NewRequest(http.MethodPost, "/clear-completed?token=secret", nil) + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code) + } +} + +func TestPauseResumeAbortFlow(t *testing.T) { + resetStateForTest() + handler := newRouter() + + filename := "job.drmd" + setJob(filename, NewJobInfo()) + updateProgress(filename, 10, "episode1", "running") + + req := httptest.NewRequest(http.MethodPost, "/pause?filename="+filename, nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("pause expected %d got %d", http.StatusOK, rr.Code) + } + + job, ok := getJob(filename) + if !ok || !job.IsPaused() { + t.Fatalf("expected paused job state") + } + + progressInfo := getProgress(filename) + if progressInfo == nil || progressInfo.Status != "paused" { + t.Fatalf("expected paused progress state") + } + + req = httptest.NewRequest(http.MethodPost, "/resume?filename="+filename, nil) + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("resume expected %d got %d", http.StatusOK, rr.Code) + } + + if job.IsPaused() { + t.Fatalf("expected resumed job state") + } + + progressInfo = getProgress(filename) + if progressInfo == nil || progressInfo.Status != "running" { + t.Fatalf("expected running progress state") + } + + req = httptest.NewRequest(http.MethodPost, "/abort?filename="+filename, nil) + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("abort expected %d got %d", http.StatusOK, rr.Code) + } + + if !job.IsAborted() { + t.Fatalf("expected aborted job state") + } + + progressInfo = getProgress(filename) + if progressInfo == nil || progressInfo.Status != "aborted" { + t.Fatalf("expected aborted progress state") + } + + req = httptest.NewRequest(http.MethodPost, "/abort?filename="+filename, nil) + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("second abort expected %d got %d", http.StatusOK, rr.Code) + } +} diff --git a/src/logger.go b/src/logger.go index 4375a0f..c1e8f18 100644 --- a/src/logger.go +++ b/src/logger.go @@ -19,7 +19,7 @@ const ( func NewLogger(prefix string) *Logger { return &Logger{ - Logger: log.New(os.Stdout, prefix, log.Ldate|log.Ltime|log.Lshortfile), + Logger: log.New(os.Stdout, prefix, log.Ldate|log.Ltime), } } diff --git a/src/main.go b/src/main.go index cd02a00..58184ec 100644 --- a/src/main.go +++ b/src/main.go @@ -1,13 +1,17 @@ package main import ( + "context" "flag" "fmt" "html/template" "net/http" "os" + "os/signal" "strings" "sync" + "syscall" + "time" "embed" ) @@ -34,9 +38,6 @@ type Metadata struct { Season string } -var progressMutex sync.Mutex -var progress = make(map[string]*ProgressInfo) - const uploadDir = "uploads" var templates *template.Template @@ -45,6 +46,19 @@ var templates *template.Template var templateFS embed.FS var globalSpeedLimit string +var globalSpeedLimitMutex sync.RWMutex + +func getGlobalSpeedLimit() string { + globalSpeedLimitMutex.RLock() + defer globalSpeedLimitMutex.RUnlock() + return globalSpeedLimit +} + +func setGlobalSpeedLimit(value string) { + globalSpeedLimitMutex.Lock() + defer globalSpeedLimitMutex.Unlock() + globalSpeedLimit = value +} func init() { if err := os.MkdirAll(uploadDir, 0755); err != nil { @@ -57,10 +71,12 @@ func init() { } func main() { - loadConfig() + configPath := flag.String("config", "config.toml", "Path to config file") inputFile := flag.String("f", "", "Path to the input JSON file") flag.Parse() + loadConfig(*configPath) + if *inputFile == "" { go watchFolder() startWebServer() @@ -70,35 +86,67 @@ func main() { logger.LogError("Main", fmt.Sprintf("Error parsing input file: %v", err)) return } - processItems(*inputFile, items) + if err := processItems(*inputFile, items); err != nil { + logger.LogError("Main", fmt.Sprintf("Error processing items: %v", err)) + } } } func startWebServer() { - http.HandleFunc("/", handleRoot) - http.HandleFunc("/upload", handleUpload) - http.HandleFunc("/select", handleSelect) - http.HandleFunc("/process", handleProcess) - http.HandleFunc("/progress", handleProgress) - http.HandleFunc("/abort", handleAbort) - http.HandleFunc("/pause", handlePause) - http.HandleFunc("/resume", handleResume) - http.HandleFunc("/clear-completed", handleClearCompleted) - http.HandleFunc("/ws", handleWebSocket) - http.HandleFunc("/set-speed-limit", handleSetSpeedLimit) + server := &http.Server{ + Addr: serverAddr(), + Handler: newRouter(), + ReadTimeout: time.Duration(config.Server.ReadTimeoutSec) * time.Second, + ReadHeaderTimeout: time.Duration(config.Server.ReadHeaderTimeoutS) * time.Second, + WriteTimeout: time.Duration(config.Server.WriteTimeoutSec) * time.Second, + IdleTimeout: time.Duration(config.Server.IdleTimeoutSec) * time.Second, + } - logger.LogInfo("Main", "Starting web server on http://0.0.0.0:8080") - http.ListenAndServe(":8080", nil) + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + go func() { + <-ctx.Done() + logger.LogInfo("Main", "Shutdown signal received, aborting jobs") + abortAllJobs() + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := server.Shutdown(shutdownCtx); err != nil { + logger.LogError("Main", fmt.Sprintf("Server shutdown error: %v", err)) + } + }() + + logger.LogInfo("Main", fmt.Sprintf("Starting web server on http://%s", server.Addr)) + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.LogError("Main", fmt.Sprintf("Web server error: %v", err)) + } } -func getProgress(filename string) *ProgressInfo { - progressMutex.Lock() - defer progressMutex.Unlock() - return progress[filename] +func abortAllJobs() { + jobsMutex.RLock() + jobList := make([]*JobInfo, 0, len(jobs)) + for _, j := range jobs { + jobList = append(jobList, j) + } + jobsMutex.RUnlock() + + for _, j := range jobList { + j.Abort() + j.KillProcess() + } } func getKeys(keys string) []string { - return strings.Split(keys, ",") + parts := strings.Split(keys, ",") + out := parts[:0] + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + out = append(out, p) + } + } + return out } func parseMetadata(metadata string) Metadata { diff --git a/src/main_test.go b/src/main_test.go index f613369..498dc1e 100644 --- a/src/main_test.go +++ b/src/main_test.go @@ -3,6 +3,7 @@ package main import ( "encoding/json" "os" + "path/filepath" "reflect" "testing" ) @@ -123,3 +124,53 @@ func TestGroupItemsBySeason(t *testing.T) { } } } + +func TestSafeUploadPath(t *testing.T) { + tests := []struct { + name string + input string + wantError bool + }{ + {name: "valid filename", input: "file.drmd", wantError: false}, + {name: "directory traversal", input: "../file.drmd", wantError: true}, + {name: "empty", input: "", wantError: true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + path, err := safeUploadPath(tc.input) + if tc.wantError && err == nil { + t.Fatalf("expected error for input %q", tc.input) + } + if !tc.wantError { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if filepath.Base(path) != tc.input { + t.Fatalf("expected basename %q, got %q", tc.input, filepath.Base(path)) + } + } + }) + } +} + +func TestValidateRemoteURL(t *testing.T) { + tests := []struct { + name string + input string + wantError bool + }{ + {name: "reject localhost", input: "http://localhost/test.vtt", wantError: true}, + {name: "reject invalid scheme", input: "file:///tmp/test", wantError: true}, + {name: "reject malformed", input: "::://", wantError: true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := validateRemoteURL(tc.input) + if tc.wantError && err == nil { + t.Fatalf("expected error for %q", tc.input) + } + }) + } +} diff --git a/src/router.go b/src/router.go new file mode 100644 index 0000000..20b9524 --- /dev/null +++ b/src/router.go @@ -0,0 +1,19 @@ +package main + +import "net/http" + +func newRouter() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/", handleRoot) + mux.HandleFunc("/upload", handleUpload) + mux.HandleFunc("/select", handleSelect) + mux.HandleFunc("/process", handleProcess) + mux.HandleFunc("/progress", handleProgress) + mux.HandleFunc("/abort", handleAbort) + mux.HandleFunc("/pause", handlePause) + mux.HandleFunc("/resume", handleResume) + mux.HandleFunc("/clear-completed", handleClearCompleted) + mux.HandleFunc("/ws", handleWebSocket) + mux.HandleFunc("/set-speed-limit", handleSetSpeedLimit) + return withSecurityHeaders(mux) +} diff --git a/src/security.go b/src/security.go new file mode 100644 index 0000000..a4afb01 --- /dev/null +++ b/src/security.go @@ -0,0 +1,237 @@ +package main + +import ( + "context" + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "time" +) + +var blockedCIDRs = mustParseCIDRs([]string{ + "127.0.0.0/8", + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "169.254.0.0/16", + "100.64.0.0/10", + "0.0.0.0/8", + "::1/128", + "fc00::/7", + "fe80::/10", +}) + +var secureDialer = &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, +} + +var secureTransport = &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + for _, ip := range ips { + if isBlockedIP(ip.IP) { + return nil, fmt.Errorf("dial blocked: %s resolves to %s", host, ip.IP) + } + } + if len(ips) == 0 { + return nil, fmt.Errorf("no IPs for %s", host) + } + return secureDialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port)) + }, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, +} + +var secureHTTPClient = &http.Client{ + Timeout: 30 * time.Second, + Transport: secureTransport, + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + _, err := validateRemoteURL(req.URL.String()) + return err + }, +} + +const maxRemoteFetchBytes int64 = 20 << 20 + +func mustParseCIDRs(cidrs []string) []*net.IPNet { + networks := make([]*net.IPNet, 0, len(cidrs)) + for _, cidr := range cidrs { + _, network, err := net.ParseCIDR(cidr) + if err != nil { + panic(err) + } + networks = append(networks, network) + } + return networks +} + +func maxUploadBytes() int64 { + return int64(config.General.MaxUploadMB) << 20 +} + +func serverAddr() string { + return fmt.Sprintf("%s:%d", config.Server.Host, config.Server.Port) +} + +func ensureMethod(w http.ResponseWriter, r *http.Request, method string) bool { + if r.Method != method { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return false + } + return true +} + +func ensureAuthorized(w http.ResponseWriter, r *http.Request) bool { + token := strings.TrimSpace(config.Security.AuthToken) + if token == "" { + return true + } + + providedToken := strings.TrimSpace(r.Header.Get("X-DRMD-Token")) + if providedToken == "" { + providedToken = strings.TrimSpace(r.URL.Query().Get("token")) + } + + if subtle.ConstantTimeCompare([]byte(providedToken), []byte(token)) != 1 { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return false + } + + return true +} + +func validateRemoteURL(rawURL string) (*url.URL, error) { + parsedURL, err := url.Parse(strings.TrimSpace(rawURL)) + if err != nil { + return nil, fmt.Errorf("invalid URL: %w", err) + } + + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return nil, errors.New("only http(s) URLs are allowed") + } + + host := parsedURL.Hostname() + if host == "" { + return nil, errors.New("URL host is required") + } + + if strings.EqualFold(host, "localhost") { + return nil, errors.New("localhost URLs are not allowed") + } + + ips, err := net.LookupIP(host) + if err != nil { + return nil, fmt.Errorf("unable to resolve URL host: %w", err) + } + + for _, ip := range ips { + if isBlockedIP(ip) { + return nil, fmt.Errorf("URL host resolves to blocked IP range: %s", ip.String()) + } + } + + return parsedURL, nil +} + +func isBlockedIP(ip net.IP) bool { + if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() || ip.IsLinkLocalMulticast() || ip.IsLinkLocalUnicast() || ip.IsMulticast() { + return true + } + + for _, network := range blockedCIDRs { + if network.Contains(ip) { + return true + } + } + + return false +} + +func fetchRemoteContent(rawURL string) ([]byte, error) { + validatedURL, err := validateRemoteURL(rawURL) + if err != nil { + return nil, err + } + + req, err := http.NewRequest(http.MethodGet, validatedURL.String(), nil) + if err != nil { + return nil, err + } + + resp, err := secureHTTPClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode) + } + + if resp.ContentLength > maxRemoteFetchBytes { + return nil, fmt.Errorf("remote content too large: %d bytes", resp.ContentLength) + } + + limited := io.LimitReader(resp.Body, maxRemoteFetchBytes+1) + body, err := io.ReadAll(limited) + if err != nil { + return nil, err + } + + if int64(len(body)) > maxRemoteFetchBytes { + return nil, fmt.Errorf("remote content exceeded %d bytes", maxRemoteFetchBytes) + } + + return body, nil +} + +type ctxKey string + +const nonceCtxKey ctxKey = "csp-nonce" + +func generateNonce() string { + var b [16]byte + if _, err := rand.Read(b[:]); err != nil { + return "" + } + return base64.RawStdEncoding.EncodeToString(b[:]) +} + +func cspNonce(r *http.Request) string { + if v, ok := r.Context().Value(nonceCtxKey).(string); ok { + return v + } + return "" +} + +func withSecurityHeaders(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nonce := generateNonce() + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("Referrer-Policy", "no-referrer") + csp := fmt.Sprintf( + "default-src 'self'; connect-src 'self'; img-src 'self' data:; style-src 'self' 'nonce-%s'; script-src 'self' 'nonce-%s'; base-uri 'self'; form-action 'self'; frame-ancestors 'none'", + nonce, nonce, + ) + w.Header().Set("Content-Security-Policy", csp) + + ctx := context.WithValue(r.Context(), nonceCtxKey, nonce) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/src/state.go b/src/state.go new file mode 100644 index 0000000..3251dc9 --- /dev/null +++ b/src/state.go @@ -0,0 +1,67 @@ +package main + +import "sync" + +var ( + jobsMutex sync.RWMutex + jobs = make(map[string]*JobInfo) + + progressMutex sync.RWMutex + progress = make(map[string]*ProgressInfo) +) + +func setJob(filename string, jobInfo *JobInfo) { + jobsMutex.Lock() + defer jobsMutex.Unlock() + jobs[filename] = jobInfo +} + +func getJob(filename string) (*JobInfo, bool) { + jobsMutex.RLock() + defer jobsMutex.RUnlock() + jobInfo, ok := jobs[filename] + return jobInfo, ok +} + +func removeJob(filename string) { + jobsMutex.Lock() + defer jobsMutex.Unlock() + delete(jobs, filename) +} + +func getProgress(filename string) *ProgressInfo { + progressMutex.RLock() + defer progressMutex.RUnlock() + if p, ok := progress[filename]; ok { + snapshot := *p + return &snapshot + } + return nil +} + +func snapshotProgress() map[string]ProgressInfo { + progressMutex.RLock() + defer progressMutex.RUnlock() + + result := make(map[string]ProgressInfo, len(progress)) + for filename, info := range progress { + result[filename] = *info + } + + return result +} + +func setProgressStatus(filename string, paused *bool, status string) { + progressMutex.Lock() + defer progressMutex.Unlock() + info, ok := progress[filename] + if !ok { + return + } + if paused != nil { + info.Paused = *paused + } + if status != "" { + info.Status = status + } +} diff --git a/src/subtitles.go b/src/subtitles.go index 91b22b8..12b4af5 100644 --- a/src/subtitles.go +++ b/src/subtitles.go @@ -1,9 +1,9 @@ package main import ( + "bytes" "fmt" "io" - "net/http" "os" "strings" @@ -36,12 +36,11 @@ func downloadAndConvertSubtitles(subtitlesURLs string) ([]string, error) { func downloadSubtitle(url string) (string, error) { logger.LogInfo("Download Subtitle", fmt.Sprintf("Starting download from %s", url)) - resp, err := http.Get(url) + body, err := fetchRemoteContent(url) if err != nil { logger.LogError("Download Subtitle", fmt.Sprintf("Error getting subtitle URL: %v", err)) return "", err } - defer resp.Body.Close() tempFile, err := os.CreateTemp("", "subtitle_*.vtt") if err != nil { @@ -50,7 +49,7 @@ func downloadSubtitle(url string) (string, error) { } defer tempFile.Close() - _, err = io.Copy(tempFile, resp.Body) + _, err = io.Copy(tempFile, bytes.NewReader(body)) if err != nil { logger.LogError("Download Subtitle", fmt.Sprintf("Error copying to temp file: %v", err)) return "", err @@ -62,8 +61,13 @@ func downloadSubtitle(url string) (string, error) { func convertVTTtoSRT(vttPath string) (string, error) { srtPath := strings.TrimSuffix(vttPath, ".vtt") + ".srt" - s1, _ := astisub.OpenFile(vttPath) - s1.Write(srtPath) + s1, err := astisub.OpenFile(vttPath) + if err != nil { + return "", err + } + if err := s1.Write(srtPath); err != nil { + return "", err + } logger.LogInfo("Convert VTT to SRT", fmt.Sprintf("Converted %s to %s", vttPath, srtPath)) return srtPath, nil } diff --git a/src/templates/index b/src/templates/index index 8e2135d..b387c23 100644 --- a/src/templates/index +++ b/src/templates/index @@ -4,7 +4,7 @@