Harden web/download pipeline and split handler modules

Replace shell-based downloader execution with validated arguments, enforce request hardening and safer defaults, and refactor handlers/router/state so job control is safer and easier to maintain.
This commit is contained in:
2026-04-14 10:21:11 +02:00
parent 6e016b802b
commit 1c82b619c4
25 changed files with 1722 additions and 667 deletions

6
.gitignore vendored
View File

@@ -1 +1,7 @@
config.toml
drmdtool
drmdtool_*
src/DRMDTool
*.exe
uploads/
src/uploads/

View File

@@ -12,6 +12,7 @@ BaseDir = "/path/to/save/downloads"
Format = "mkv"
TempBaseDir = "/tmp/nre"
EnableConsole = true
MaxUploadMB = 32
[WatchFolder]
Path = "/path/to/watched/folder"
@@ -21,6 +22,17 @@ UseInotify = false
[N_m3u8DLRE]
Path = "/path/to/N_m3u8DL-RE"
[Server]
Host = "127.0.0.1"
Port = 8080
ReadTimeoutSec = 30
WriteTimeoutSec = 30
IdleTimeoutSec = 60
ReadHeaderTimeoutS = 10
[Security]
AuthToken = ""
```
### Configuration Options
@@ -30,6 +42,7 @@ Path = "/path/to/N_m3u8DL-RE"
- `Format`: Output format for the downloaded files (e.g., `mkv`, `mp4`).
- `TempBaseDir`: Temporary directory for intermediate files.
- `EnableConsole`: Boolean to enable or disable console output.
- `MaxUploadMB`: Maximum allowed upload size for the web UI.
- **WatchFolder**
- `Path`: Directory to watch for new `.drmd` files.
@@ -40,6 +53,14 @@ Path = "/path/to/N_m3u8DL-RE"
- **N_m3u8DLRE**
- `Path`: Path to the N_m3u8DL-RE executable.
- **Server**
- `Host`: Bind address for the web server (`127.0.0.1` recommended).
- `Port`: Web server port.
- `ReadTimeoutSec`, `WriteTimeoutSec`, `IdleTimeoutSec`, `ReadHeaderTimeoutS`: HTTP timeout settings.
- **Security**
- `AuthToken`: Optional token for protecting all endpoints. Recommended when binding to a non-loopback host.
### Environment Variable Overrides
You can override the configuration options using environment variables. The following environment variables are supported:
@@ -48,10 +69,14 @@ You can override the configuration options using environment variables. The foll
- `FORMAT`: Overrides `General.Format`
- `TEMP_BASE_DIR`: Overrides `General.TempBaseDir`
- `ENABLE_CONSOLE`: Overrides `General.EnableConsole` (set to `true` or `false`)
- `MAX_UPLOAD_MB`: Overrides `General.MaxUploadMB`
- `WATCHED_FOLDER`: Overrides `WatchFolder.Path`
- `USE_POLLING`: Overrides `WatchFolder.UsePolling` (set to `true` or `false`)
- `USE_INOTIFY`: Overrides `WatchFolder.UseInotify` (set to `true` or `false`)
- `POLLING_INTERVAL`: Overrides `WatchFolder.PollingInterval`
- `SERVER_HOST`: Overrides `Server.Host`
- `SERVER_PORT`: Overrides `Server.Port`
- `AUTH_TOKEN`: Overrides `Security.AuthToken`
## Web UI Usage
@@ -62,6 +87,9 @@ You can override the configuration options using environment variables. The foll
2. Open a web browser and go to `http://localhost:8080`
If `Security.AuthToken` is configured, include it as a query parameter:
`http://localhost:8080/?token=YOUR_TOKEN`
3. Use the interface to upload .drmd files and monitor download progress
## CLI Usage

View File

@@ -3,6 +3,7 @@ BaseDir = "/mnt/media"
Format = "mkv"
TempBaseDir = "/tmp/nre"
EnableConsole = true
MaxUploadMB = 32
[WatchFolder]
Path = "/mnt/watched"
@@ -12,3 +13,14 @@ UseInotify = false
[N_m3u8DLRE]
Path = "nre"
[Server]
Host = "127.0.0.1"
Port = 8080
ReadTimeoutSec = 30
WriteTimeoutSec = 30
IdleTimeoutSec = 60
ReadHeaderTimeoutS = 10
[Security]
AuthToken = ""

View File

@@ -1,9 +1,11 @@
package main
import (
"errors"
"fmt"
"io"
"os"
"os/exec"
"strconv"
"strings"
@@ -16,6 +18,7 @@ type Config struct {
Format string
TempBaseDir string
EnableConsole bool
MaxUploadMB int
}
WatchFolder struct {
Path string
@@ -23,6 +26,17 @@ type Config struct {
UseInotify bool
PollingInterval int
}
Server struct {
Host string
Port int
ReadTimeoutSec int
WriteTimeoutSec int
IdleTimeoutSec int
ReadHeaderTimeoutS int
}
Security struct {
AuthToken string
}
N_m3u8DLRE struct {
Path string
}
@@ -30,15 +44,19 @@ type Config struct {
var config Config
func loadConfig() {
configFile, err := os.Open("config.toml")
func loadConfig(path string) {
configFile, err := os.Open(path)
if err != nil {
logger.LogError("Config", fmt.Sprintf("Error opening config file: %v", err))
os.Exit(1)
}
defer configFile.Close()
byteValue, _ := io.ReadAll(configFile)
byteValue, err := io.ReadAll(configFile)
if err != nil {
logger.LogError("Config", fmt.Sprintf("Error reading config file: %v", err))
os.Exit(1)
}
if _, err := toml.Decode(string(byteValue), &config); err != nil {
logger.LogError("Config", fmt.Sprintf("Error decoding config file: %v", err))
@@ -46,6 +64,7 @@ func loadConfig() {
}
overrideConfigWithEnv()
setDefaultConfigValues()
if err := validatePaths(); err != nil {
logger.LogError("Config", fmt.Sprintf("Configuration error: %v", err))
@@ -59,6 +78,36 @@ func loadConfig() {
logConfig()
}
func setDefaultConfigValues() {
if config.General.MaxUploadMB <= 0 {
config.General.MaxUploadMB = 32
}
if strings.TrimSpace(config.Server.Host) == "" {
config.Server.Host = "127.0.0.1"
}
if config.Server.Port <= 0 {
config.Server.Port = 8080
}
if config.Server.ReadTimeoutSec <= 0 {
config.Server.ReadTimeoutSec = 30
}
if config.Server.WriteTimeoutSec <= 0 {
config.Server.WriteTimeoutSec = 30
}
if config.Server.IdleTimeoutSec <= 0 {
config.Server.IdleTimeoutSec = 60
}
if config.Server.ReadHeaderTimeoutS <= 0 {
config.Server.ReadHeaderTimeoutS = 10
}
}
func overrideConfigWithEnv() {
if envBaseDir := os.Getenv("BASE_DIR"); envBaseDir != "" {
config.General.BaseDir = envBaseDir
@@ -86,14 +135,40 @@ func overrideConfigWithEnv() {
config.WatchFolder.PollingInterval = interval
}
}
if envMaxUploadMB := os.Getenv("MAX_UPLOAD_MB"); envMaxUploadMB != "" {
if value, err := strconv.Atoi(envMaxUploadMB); err == nil {
config.General.MaxUploadMB = value
}
}
if envHost := os.Getenv("SERVER_HOST"); envHost != "" {
config.Server.Host = envHost
}
if envPort := os.Getenv("SERVER_PORT"); envPort != "" {
if value, err := strconv.Atoi(envPort); err == nil {
config.Server.Port = value
}
}
if envAuthToken := os.Getenv("AUTH_TOKEN"); envAuthToken != "" {
config.Security.AuthToken = envAuthToken
}
}
func validatePaths() error {
if strings.TrimSpace(config.General.Format) == "" {
return errors.New("format is not specified")
}
allowedFormats := map[string]bool{"mkv": true, "mp4": true}
if !allowedFormats[strings.ToLower(config.General.Format)] {
return fmt.Errorf("unsupported format: %s (supported: mkv, mp4)", config.General.Format)
}
paths := []struct {
name string
path string
}{
{"BaseDir", config.General.BaseDir},
{"TempBaseDir", config.General.TempBaseDir},
}
for _, p := range paths {
@@ -101,6 +176,12 @@ func validatePaths() error {
return fmt.Errorf("%s is not specified", p.name)
}
if _, err := os.Stat(p.path); os.IsNotExist(err) {
if p.name == "TempBaseDir" {
if mkErr := os.MkdirAll(p.path, 0755); mkErr != nil {
return fmt.Errorf("unable to create %s: %v", p.name, mkErr)
}
continue
}
return fmt.Errorf("%s does not exist: %s", p.name, p.path)
} else if err != nil {
return fmt.Errorf("error accessing %s: %v", p.name, err)
@@ -118,6 +199,18 @@ func validatePaths() error {
}
}
if strings.TrimSpace(config.N_m3u8DLRE.Path) == "" {
return errors.New("N_m3u8DLRE path is not specified")
}
if _, err := exec.LookPath(config.N_m3u8DLRE.Path); err != nil {
return fmt.Errorf("N_m3u8DLRE executable not found in PATH: %s", config.N_m3u8DLRE.Path)
}
if config.Server.Port <= 0 || config.Server.Port > 65535 {
return fmt.Errorf("invalid server port: %d", config.Server.Port)
}
return nil
}
@@ -129,15 +222,28 @@ Configuration Loaded:
Format: %s
TempBaseDir: %s
EnableConsole: %t
MaxUploadMB: %d
WatchFolder:
Path: %s
UsePolling: %t
UseInotify: %t
PollingInterval: %d
Server:
Host: %s
Port: %d
ReadTimeoutSec: %d
WriteTimeoutSec: %d
IdleTimeoutSec: %d
ReadHeaderTimeoutS: %d
Security:
AuthTokenConfigured: %t
N_m3u8DLRE:
Path: %s
`, config.General.BaseDir, config.General.Format, config.General.TempBaseDir, config.General.EnableConsole,
config.General.MaxUploadMB,
config.WatchFolder.Path, config.WatchFolder.UsePolling, config.WatchFolder.UseInotify, config.WatchFolder.PollingInterval,
config.Server.Host, config.Server.Port, config.Server.ReadTimeoutSec, config.Server.WriteTimeoutSec, config.Server.IdleTimeoutSec, config.Server.ReadHeaderTimeoutS,
strings.TrimSpace(config.Security.AuthToken) != "",
config.N_m3u8DLRE.Path)
logger.LogInfo("Config", configInfo)

View File

@@ -1,18 +1,31 @@
package main
import (
"bytes"
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"
"time"
)
var decryptionKeyRegex = regexp.MustCompile(`^[0-9a-fA-F]{32}:[0-9a-fA-F]{32}$`)
type websocketBroadcastWriter struct {
filename string
}
func (w websocketBroadcastWriter) Write(p []byte) (int, error) {
if config.General.EnableConsole && len(p) > 0 {
message := append([]byte(nil), p...)
broadcast(w.filename, message)
}
return len(p), nil
}
func removeBOM(input []byte) []byte {
if len(input) >= 3 && input[0] == 0xEF && input[1] == 0xBB && input[2] == 0xBF {
return input[3:]
@@ -23,14 +36,17 @@ func removeBOM(input []byte) []byte {
func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error {
logger.LogInfo("Download File", fmt.Sprintf("Starting download for: %s", item.Filename))
tempDir := filepath.Join(config.General.TempBaseDir, sanitizeFilename(item.Filename))
err := os.MkdirAll(tempDir, 0755)
if err := os.MkdirAll(config.General.TempBaseDir, 0755); err != nil {
logger.LogError("Download File", fmt.Sprintf("Error creating temp base dir: %v", err))
return fmt.Errorf("error creating temp base dir: %v", err)
}
tempDir, err := os.MkdirTemp(config.General.TempBaseDir, sanitizeFilename(item.Filename)+"_")
if err != nil {
logger.LogError("Download File", fmt.Sprintf("Error creating temporary directory: %v", err))
return fmt.Errorf("error creating temporary directory: %v", err)
}
jobInfo.TempDir = tempDir
jobInfo.SetTempDir(tempDir)
mpdPath := item.MPD
if !isValidURL(item.MPD) {
@@ -58,18 +74,11 @@ func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error {
mpdPath = tempFile.Name()
} else if strings.HasPrefix(item.MPD, "https://pubads.g.doubleclick.net") {
resp, err := http.Get(item.MPD)
mpdContent, err := fetchRemoteContent(item.MPD)
if err != nil {
logger.LogError("Download File", fmt.Sprintf("Error downloading MPD: %v", err))
return fmt.Errorf("error downloading MPD: %v", err)
}
defer resp.Body.Close()
mpdContent, err := io.ReadAll(resp.Body)
if err != nil {
logger.LogError("Download File", fmt.Sprintf("Error reading MPD content: %v", err))
return fmt.Errorf("error reading MPD content: %v", err)
}
fixedMPDContent, err := fixGoPlay(string(mpdContent))
if err != nil {
@@ -96,7 +105,10 @@ func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error {
mpdPath = tempFile.Name()
}
command := getDownloadCommand(item, mpdPath, tempDir)
args, err := getDownloadArgs(item, mpdPath, tempDir)
if err != nil {
return err
}
if item.Subtitles != "" {
subtitlePaths, err := downloadAndConvertSubtitles(item.Subtitles)
@@ -105,17 +117,17 @@ func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error {
} else {
for _, path := range subtitlePaths {
logger.LogInfo("Download File", fmt.Sprintf("Adding subtitle: %s", path))
command += fmt.Sprintf(" --mux-import \"path=%s:lang=nl:name=Nederlands\"", path)
args = append(args, "--mux-import", fmt.Sprintf("path=%s:lang=nl:name=Nederlands", path))
}
}
}
cmd := exec.Command("bash", "-c", command)
jobInfo.Cmd = cmd
cmd := exec.Command(config.N_m3u8DLRE.Path, args...)
jobInfo.SetCmd(cmd)
var outputBuffer bytes.Buffer
cmd.Stdout = io.MultiWriter(&outputBuffer)
cmd.Stderr = os.Stderr
broadcastWriter := websocketBroadcastWriter{filename: drmdFilename}
cmd.Stdout = io.MultiWriter(os.Stdout, broadcastWriter)
cmd.Stderr = io.MultiWriter(os.Stderr, broadcastWriter)
err = cmd.Start()
if err != nil {
@@ -128,33 +140,26 @@ func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error {
done <- cmd.Wait()
}()
go func() {
for {
if outputBuffer.Len() > 0 {
message := outputBuffer.Bytes()
if config.General.EnableConsole {
broadcast(drmdFilename, message)
}
outputBuffer.Reset()
}
time.Sleep(1 * time.Second)
}
}()
select {
case <-jobInfo.AbortChan:
if cmd.Process != nil {
cmd.Process.Kill()
}
os.RemoveAll(tempDir)
jobInfo.KillProcess()
_ = os.RemoveAll(tempDir)
logger.LogInfo("Download File", "Download aborted")
return fmt.Errorf("download aborted")
return ErrDownloadAborted
case err := <-done:
if jobInfo.Paused {
if jobInfo.IsPaused() {
logger.LogInfo("Download File", "Download paused")
return fmt.Errorf("download paused")
return ErrDownloadPaused
}
if err != nil {
if jobInfo.IsAborted() {
return ErrDownloadAborted
}
if jobInfo.IsPaused() {
return ErrDownloadPaused
}
logger.LogError("Download File", fmt.Sprintf("Error executing download command: %v", err))
return fmt.Errorf("error executing download command: %v", err)
}
@@ -164,26 +169,24 @@ func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error {
return nil
}
func getDownloadCommand(item Item, mpdPath string, tempDir string) string {
func getDownloadArgs(item Item, mpdPath string, tempDir string) ([]string, error) {
metadata := parseMetadata(item.Metadata)
keys := getKeys(item.Keys)
command := fmt.Sprintf("%s '%s'", config.N_m3u8DLRE.Path, mpdPath)
args := []string{mpdPath}
for _, key := range keys {
if key != "" {
command += fmt.Sprintf(" --key %s", key)
if !decryptionKeyRegex.MatchString(key) {
return nil, fmt.Errorf("invalid decryption key format")
}
args = append(args, "--key", key)
}
command += " --auto-select"
args = append(args, "--auto-select")
sanitizedFilename := sanitizeFilename(item.Filename)
filename := fmt.Sprintf("\"%s\"", sanitizedFilename)
command += fmt.Sprintf(" --save-name %s", filename)
command += fmt.Sprintf(" --mux-after-done format=%s", config.General.Format)
args = append(args, "--save-name", sanitizedFilename)
args = append(args, "--mux-after-done", fmt.Sprintf("format=%s", config.General.Format))
saveDir := config.General.BaseDir
if metadata.Type == "serie" {
@@ -191,15 +194,19 @@ func getDownloadCommand(item Item, mpdPath string, tempDir string) string {
} else {
saveDir = filepath.Join(saveDir, "Movies", metadata.Title)
}
command += fmt.Sprintf(" --save-dir \"%s\"", saveDir)
if err := os.MkdirAll(saveDir, 0755); err != nil {
return nil, fmt.Errorf("unable to create save directory: %w", err)
}
args = append(args, "--save-dir", saveDir)
args = append(args, "--tmp-dir", tempDir)
command += fmt.Sprintf(" --tmp-dir \"%s\"", tempDir)
if globalSpeedLimit != "" {
command += fmt.Sprintf(" -R %s", globalSpeedLimit)
currentSpeedLimit := getGlobalSpeedLimit()
if currentSpeedLimit != "" {
if !speedLimitRegex.MatchString(currentSpeedLimit) {
return nil, errors.New("invalid speed limit format")
}
args = append(args, "-R", currentSpeedLimit)
}
fmt.Println(command)
return command
return args, nil
}

View File

@@ -1,6 +1,6 @@
module DRMDTool
go 1.23.0
go 1.25.0
require (
github.com/BurntSushi/toml v1.4.0
@@ -8,13 +8,13 @@ require (
github.com/beevik/etree v1.4.1
)
require golang.org/x/sys v0.4.0 // indirect
require golang.org/x/sys v0.43.0 // indirect
require (
github.com/asticode/go-astikit v0.20.0 // indirect
github.com/asticode/go-astits v1.8.0 // indirect
github.com/fsnotify/fsnotify v1.7.0
github.com/gorilla/websocket v1.5.3
golang.org/x/net v0.0.0-20200904194848-62affa334b73 // indirect
golang.org/x/text v0.3.2 // indirect
golang.org/x/net v0.53.0 // indirect
golang.org/x/text v0.36.0 // indirect
)

View File

@@ -14,8 +14,6 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/pkg/exec v0.0.0-20150614095509-0bd164ad2a5a h1:EN123kAtAAE2pg/+TvBsUBZfHCWNNFyL2ZBPPfNWAc0=
github.com/pkg/exec v0.0.0-20150614095509-0bd164ad2a5a/go.mod h1:b95YoNrAnScjaWG+asr8lxqlrsPUcT2ZEBcjvVGshMo=
github.com/pkg/profile v1.4.0/go.mod h1:NWz/XGvpEW1FyYQ7fCx4dqYBLlfTcE+A9FLAkNKqjFE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -25,16 +23,18 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20200904194848-62affa334b73 h1:MXfv8rhZWmFeqX3GNZRsd6vOLoaCHjYEX3qkRo3YBUA=
golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=

View File

@@ -1,478 +0,0 @@
package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"github.com/gorilla/websocket"
)
type ProgressInfo struct {
Percentage float64
CurrentFile string
Paused bool
}
func handleRoot(w http.ResponseWriter, r *http.Request) {
progressMutex.Lock()
defer progressMutex.Unlock()
jobsInfo := make(map[string]struct {
Percentage float64
CurrentFile string
Paused bool
})
for filename, info := range progress {
jobsInfo[filename] = struct {
Percentage float64
CurrentFile string
Paused bool
}{
Percentage: info.Percentage,
CurrentFile: info.CurrentFile,
Paused: info.Paused,
}
}
err := templates.ExecuteTemplate(w, "index", struct {
Jobs map[string]struct {
Percentage float64
CurrentFile string
Paused bool
}
GlobalSpeedLimit string
}{
Jobs: jobsInfo,
GlobalSpeedLimit: globalSpeedLimit,
})
if err != nil {
logger.LogError("Handle Root", fmt.Sprintf("Error executing template: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func handleUpload(w http.ResponseWriter, r *http.Request) {
logger.LogInfo("Handle Upload", "Starting file upload")
err := r.ParseMultipartForm(32 << 20)
if err != nil {
logger.LogError("Handle Upload", fmt.Sprintf("Error parsing multipart form: %v", err))
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
files := r.MultipartForm.File["files"]
if len(files) == 0 {
logger.LogError("Handle Upload", "No files uploaded")
http.Error(w, "No files uploaded", http.StatusBadRequest)
return
}
uploadedFiles := []string{}
for _, fileHeader := range files {
file, err := fileHeader.Open()
if err != nil {
logger.LogError("Handle Upload", fmt.Sprintf("Error opening file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer file.Close()
tempFile, err := os.CreateTemp(uploadDir, fileHeader.Filename)
if err != nil {
logger.LogError("Handle Upload", fmt.Sprintf("Error creating temporary file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer tempFile.Close()
_, err = io.Copy(tempFile, file)
if err != nil {
logger.LogError("Handle Upload", fmt.Sprintf("Error copying file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
uploadedFiles = append(uploadedFiles, filepath.Base(tempFile.Name()))
_, err = parseInputFile(tempFile.Name())
if err != nil {
logger.LogError("Handle Upload", fmt.Sprintf("Error parsing input file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
validFiles := []string{}
for _, file := range uploadedFiles {
if file != "" {
validFiles = append(validFiles, file)
}
}
if len(validFiles) == 0 {
logger.LogError("Handle Upload", "No valid files were uploaded")
http.Error(w, "No valid files were uploaded", http.StatusBadRequest)
return
}
logger.LogInfo("Handle Upload", fmt.Sprintf("Redirecting to select with files: %v", validFiles))
http.Redirect(w, r, "/select?files="+url.QueryEscape(strings.Join(validFiles, ",")), http.StatusSeeOther)
}
func handleSelect(w http.ResponseWriter, r *http.Request) {
filesParam := r.URL.Query().Get("files")
filenames := strings.Split(filesParam, ",")
allItems := make(map[string]map[string][]Item)
for _, filename := range filenames {
if filename == "" {
continue
}
fullPath := filepath.Join(uploadDir, filename)
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
logger.LogError("Handle Select", fmt.Sprintf("File does not exist: %s", fullPath))
continue
}
items, err := parseInputFile(fullPath)
if err != nil {
logger.LogError("Handle Select", fmt.Sprintf("Error parsing input file: %v", err))
continue
}
sortItems(items)
groupedItems := groupItemsBySeason(items)
allItems[filename] = groupedItems
}
if len(allItems) == 0 {
logger.LogError("Handle Select", "No valid files were processed")
http.Error(w, "No valid files were processed", http.StatusBadRequest)
return
}
err := templates.ExecuteTemplate(w, "select", struct {
Filenames string
AllItems map[string]map[string][]Item
}{
Filenames: filesParam,
AllItems: allItems,
})
if err != nil {
logger.LogError("Handle Select", fmt.Sprintf("Error executing template: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func handleProcess(w http.ResponseWriter, r *http.Request) {
logger.LogInfo("Handle Process", "Starting process")
if err := r.ParseForm(); err != nil {
logger.LogError("Handle Process", fmt.Sprintf("Error parsing form: %v", err))
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
selectedItems := r.Form["items"]
if len(selectedItems) == 0 {
logger.LogError("Handle Process", "No items selected")
http.Error(w, "No items selected", http.StatusBadRequest)
return
}
itemsByFile := make(map[string][]string)
for _, item := range selectedItems {
parts := strings.SplitN(item, ":", 2)
if len(parts) != 2 {
logger.LogError("Handle Process", "Invalid item format")
continue
}
filename, itemName := parts[0], parts[1]
itemsByFile[filename] = append(itemsByFile[filename], itemName)
}
for filename, items := range itemsByFile {
logger.LogInfo("Handle Process", fmt.Sprintf("Processing file: %s", filename))
fullPath := filepath.Join(uploadDir, filename)
allItems, err := parseInputFile(fullPath)
if err != nil {
logger.LogError("Handle Process", fmt.Sprintf("Error parsing input file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
selectedItems := filterSelectedItems(allItems, items)
sortItems(selectedItems)
go processItems(filename, selectedItems)
}
http.Redirect(w, r, "/", http.StatusSeeOther)
}
func handleProgress(w http.ResponseWriter, r *http.Request) {
filename := r.URL.Query().Get("filename")
if r.Header.Get("Accept") == "application/json" {
progressInfo := getProgress(filename)
if progressInfo == nil {
w.WriteHeader(http.StatusNotFound)
json.NewEncoder(w).Encode(map[string]string{"error": "No progress information found"})
return
}
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(progressInfo)
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
return
}
err := templates.ExecuteTemplate(w, "progress", struct{ Filename string }{filename})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func handlePause(w http.ResponseWriter, r *http.Request) {
filename := r.URL.Query().Get("filename")
if filename == "" {
logger.LogError("Pause Handler", "Filename is required")
http.Error(w, "Filename is required", http.StatusBadRequest)
return
}
jobsMutex.Lock()
jobInfo, exists := jobs[filename]
jobsMutex.Unlock()
if !exists {
logger.LogError("Pause Handler", "Job not found")
http.Error(w, "Job not found", http.StatusNotFound)
return
}
jobInfo.Paused = true
if jobInfo.Cmd != nil && jobInfo.Cmd.Process != nil {
logger.LogJobState(filename, "pausing")
jobInfo.Cmd.Process.Kill()
}
progressMutex.Lock()
if progressInfo, ok := progress[filename]; ok {
progressInfo.Paused = true
}
progressMutex.Unlock()
fmt.Fprintf(w, "Pause signal sent for %s", filename)
}
func handleResume(w http.ResponseWriter, r *http.Request) {
filename := r.URL.Query().Get("filename")
if filename == "" {
http.Error(w, "Filename is required", http.StatusBadRequest)
return
}
jobsMutex.Lock()
jobInfo, exists := jobs[filename]
jobsMutex.Unlock()
if !exists {
http.Error(w, "Job not found", http.StatusNotFound)
return
}
jobInfo.Paused = false
jobInfo.ResumeChan <- struct{}{}
progressMutex.Lock()
if progressInfo, ok := progress[filename]; ok {
progressInfo.Paused = false
}
progressMutex.Unlock()
fmt.Fprintf(w, "Resume signal sent for %s", filename)
}
func handleAbort(w http.ResponseWriter, r *http.Request) {
filename := r.URL.Query().Get("filename")
if filename == "" {
http.Error(w, "Filename is required", http.StatusBadRequest)
return
}
jobsMutex.Lock()
jobInfo, exists := jobs[filename]
jobsMutex.Unlock()
if !exists {
http.Error(w, "Job not found", http.StatusNotFound)
return
}
close(jobInfo.AbortChan)
if jobInfo.Cmd != nil && jobInfo.Cmd.Process != nil {
jobInfo.Cmd.Process.Kill()
}
if jobInfo.TempDir != "" {
os.RemoveAll(jobInfo.TempDir)
}
fmt.Fprintf(w, "Abort signal sent for %s", filename)
}
func handleClearCompleted(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
clearCompletedJobs()
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
func clearCompletedJobs() {
progressMutex.Lock()
defer progressMutex.Unlock()
for filename, info := range progress {
if info.Percentage >= 100 {
delete(progress, filename)
}
}
}
func updateProgress(filename string, value float64, currentFile string) {
progressMutex.Lock()
defer progressMutex.Unlock()
jobsMutex.Lock()
jobInfo, exists := jobs[filename]
jobsMutex.Unlock()
paused := false
if exists {
paused = jobInfo.Paused
}
if existingProgress, ok := progress[filename]; ok {
existingProgress.Percentage = value
existingProgress.CurrentFile = currentFile
existingProgress.Paused = paused
} else {
progress[filename] = &ProgressInfo{
Percentage: value,
CurrentFile: currentFile,
Paused: paused,
}
}
}
var upgrader = websocket.Upgrader{}
var clients = make(map[string]map[*websocket.Conn]bool)
var mu sync.Mutex
func handleWebSocket(w http.ResponseWriter, r *http.Request) {
fmt.Println(config.General.EnableConsole)
if !config.General.EnableConsole {
http.Error(w, "Console output is disabled", http.StatusForbidden)
return
}
filename := r.URL.Query().Get("filename")
if filename == "" {
http.Error(w, "Filename is required", http.StatusBadRequest)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logger.LogError("WebSocket", fmt.Sprintf("Error while upgrading connection: %v", err))
return
}
defer conn.Close()
logger.LogInfo("WebSocket", fmt.Sprintf("WebSocket connection established for filename: %s", filename))
mu.Lock()
if clients[filename] == nil {
clients[filename] = make(map[*websocket.Conn]bool)
}
clients[filename][conn] = true
mu.Unlock()
for {
if _, _, err := conn.NextReader(); err != nil {
break
}
}
mu.Lock()
delete(clients[filename], conn)
mu.Unlock()
logger.LogInfo("WebSocket", fmt.Sprintf("WebSocket connection closed for filename: %s", filename))
}
func broadcast(filename string, message []byte) {
if !config.General.EnableConsole {
return
}
mu.Lock()
defer mu.Unlock()
for client := range clients[filename] {
if err := client.WriteMessage(websocket.TextMessage, message); err != nil {
client.Close()
delete(clients[filename], client)
logger.LogError("Broadcast", fmt.Sprintf("Error writing message to client: %v", err))
}
}
}
func handleSetSpeedLimit(w http.ResponseWriter, r *http.Request) {
logger.LogInfo("Set Speed Limit", "Received request to set speed limit")
if r.Method != http.MethodPost {
logger.LogError("Set Speed Limit", "Invalid method")
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var requestData struct {
SpeedLimit string `json:"speedLimit"`
}
if err := json.NewDecoder(r.Body).Decode(&requestData); err != nil {
logger.LogError("Set Speed Limit", "Invalid request body")
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
if requestData.SpeedLimit == "unlimited" {
globalSpeedLimit = ""
} else {
globalSpeedLimit = requestData.SpeedLimit
}
logger.LogInfo("Set Speed Limit", fmt.Sprintf("Global speed limit set to: %s", globalSpeedLimit))
w.WriteHeader(http.StatusOK)
}

30
src/handlers_common.go Normal file
View File

@@ -0,0 +1,30 @@
package main
import (
"net/url"
"regexp"
"strings"
)
type ProgressInfo struct {
Percentage float64
CurrentFile string
Paused bool
Status string
}
var speedLimitRegex = regexp.MustCompile(`^([1-9]\d*(\.\d+)?)(KBps|MBps|GBps)$`)
func withToken(path string) string {
token := strings.TrimSpace(config.Security.AuthToken)
if token == "" {
return path
}
separator := "?"
if strings.Contains(path, "?") {
separator = "&"
}
return path + separator + "token=" + url.QueryEscape(token)
}

174
src/handlers_jobs.go Normal file
View File

@@ -0,0 +1,174 @@
package main
import (
"encoding/json"
"fmt"
"net/http"
"os"
"strings"
)
func handlePause(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) {
return
}
filename := r.URL.Query().Get("filename")
if filename == "" {
logger.LogError("Pause Handler", "Filename is required")
http.Error(w, "Filename is required", http.StatusBadRequest)
return
}
jobInfo, exists := getJob(filename)
if !exists {
logger.LogError("Pause Handler", "Job not found")
http.Error(w, "Job not found", http.StatusNotFound)
return
}
jobInfo.SetPaused(true)
logger.LogJobState(filename, "pausing")
jobInfo.KillProcess()
paused := true
setProgressStatus(filename, &paused, "paused")
_, _ = fmt.Fprintf(w, "Pause signal sent for %s", filename)
}
func handleResume(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) {
return
}
filename := r.URL.Query().Get("filename")
if filename == "" {
http.Error(w, "Filename is required", http.StatusBadRequest)
return
}
jobInfo, exists := getJob(filename)
if !exists {
http.Error(w, "Job not found", http.StatusNotFound)
return
}
jobInfo.SetPaused(false)
jobInfo.SignalResume()
paused := false
setProgressStatus(filename, &paused, "running")
_, _ = fmt.Fprintf(w, "Resume signal sent for %s", filename)
}
func handleAbort(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) {
return
}
filename := r.URL.Query().Get("filename")
if filename == "" {
http.Error(w, "Filename is required", http.StatusBadRequest)
return
}
jobInfo, exists := getJob(filename)
if !exists {
http.Error(w, "Job not found", http.StatusNotFound)
return
}
jobInfo.Abort()
jobInfo.KillProcess()
if tempDir := jobInfo.GetTempDir(); tempDir != "" {
_ = os.RemoveAll(tempDir)
}
paused := false
setProgressStatus(filename, &paused, "aborted")
_, _ = fmt.Fprintf(w, "Abort signal sent for %s", filename)
}
func handleClearCompleted(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) {
return
}
clearCompletedJobs()
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
func clearCompletedJobs() {
progressMutex.Lock()
defer progressMutex.Unlock()
for filename, info := range progress {
if info.Percentage >= 100 || info.Status == "failed" || info.Status == "aborted" {
delete(progress, filename)
}
}
}
func updateProgress(filename string, value float64, currentFile, status string) {
paused := false
if jobInfo, exists := getJob(filename); exists {
paused = jobInfo.IsPaused()
}
progressMutex.Lock()
defer progressMutex.Unlock()
if existingProgress, ok := progress[filename]; ok {
existingProgress.Percentage = value
existingProgress.CurrentFile = currentFile
existingProgress.Paused = paused
if status != "" {
existingProgress.Status = status
}
} else {
progress[filename] = &ProgressInfo{
Percentage: value,
CurrentFile: currentFile,
Paused: paused,
Status: status,
}
}
}
func handleSetSpeedLimit(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) {
return
}
logger.LogInfo("Set Speed Limit", "Received request to set speed limit")
var requestData struct {
SpeedLimit string `json:"speedLimit"`
}
if err := json.NewDecoder(r.Body).Decode(&requestData); err != nil {
logger.LogError("Set Speed Limit", "Invalid request body")
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
requestData.SpeedLimit = strings.TrimSpace(requestData.SpeedLimit)
if requestData.SpeedLimit == "unlimited" {
setGlobalSpeedLimit("")
} else {
if !speedLimitRegex.MatchString(requestData.SpeedLimit) {
http.Error(w, "Invalid speed limit format", http.StatusBadRequest)
return
}
setGlobalSpeedLimit(requestData.SpeedLimit)
}
logger.LogInfo("Set Speed Limit", fmt.Sprintf("Global speed limit set to: %s", getGlobalSpeedLimit()))
w.WriteHeader(http.StatusOK)
}

128
src/handlers_pages.go Normal file
View File

@@ -0,0 +1,128 @@
package main
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"strings"
)
func handleRoot(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodGet) || !ensureAuthorized(w, r) {
return
}
err := templates.ExecuteTemplate(w, "index", struct {
Jobs map[string]ProgressInfo
GlobalSpeedLimit string
AuthToken string
Nonce string
}{
Jobs: snapshotProgress(),
GlobalSpeedLimit: getGlobalSpeedLimit(),
AuthToken: url.QueryEscape(config.Security.AuthToken),
Nonce: cspNonce(r),
})
if err != nil {
logger.LogError("Handle Root", fmt.Sprintf("Error executing template: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func handleSelect(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodGet) || !ensureAuthorized(w, r) {
return
}
filesParam := r.URL.Query().Get("files")
filenames := strings.Split(filesParam, ",")
allItems := make(map[string]map[string][]Item)
for _, filename := range filenames {
if filename == "" {
continue
}
fullPath, pathErr := safeUploadPath(filename)
if pathErr != nil {
logger.LogError("Handle Select", fmt.Sprintf("Invalid filename %s: %v", filename, pathErr))
continue
}
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
logger.LogError("Handle Select", fmt.Sprintf("File does not exist: %s", fullPath))
continue
}
items, err := parseInputFile(fullPath)
if err != nil {
logger.LogError("Handle Select", fmt.Sprintf("Error parsing input file: %v", err))
continue
}
sortItems(items)
groupedItems := groupItemsBySeason(items)
allItems[filename] = groupedItems
}
if len(allItems) == 0 {
logger.LogError("Handle Select", "No valid files were processed")
http.Error(w, "No valid files were processed", http.StatusBadRequest)
return
}
err := templates.ExecuteTemplate(w, "select", struct {
Filenames string
AllItems map[string]map[string][]Item
AuthToken string
Nonce string
}{
Filenames: filesParam,
AllItems: allItems,
AuthToken: url.QueryEscape(config.Security.AuthToken),
Nonce: cspNonce(r),
})
if err != nil {
logger.LogError("Handle Select", fmt.Sprintf("Error executing template: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func handleProgress(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodGet) || !ensureAuthorized(w, r) {
return
}
filename := r.URL.Query().Get("filename")
if r.Header.Get("Accept") == "application/json" {
progressInfo := getProgress(filename)
if progressInfo == nil {
w.WriteHeader(http.StatusNotFound)
_ = json.NewEncoder(w).Encode(map[string]string{"error": "No progress information found"})
return
}
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(progressInfo)
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
return
}
err := templates.ExecuteTemplate(w, "progress", struct {
Filename string
AuthToken string
Nonce string
}{Filename: filename, AuthToken: url.QueryEscape(config.Security.AuthToken), Nonce: cspNonce(r)})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}

View File

@@ -0,0 +1,158 @@
package main
import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
)
func handleUpload(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) {
return
}
logger.LogInfo("Handle Upload", "Starting file upload")
r.Body = http.MaxBytesReader(w, r.Body, maxUploadBytes())
const multipartMemory = 4 << 20
err := r.ParseMultipartForm(multipartMemory)
if err != nil {
logger.LogError("Handle Upload", fmt.Sprintf("Error parsing multipart form: %v", err))
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
files := r.MultipartForm.File["files"]
if len(files) == 0 {
logger.LogError("Handle Upload", "No files uploaded")
http.Error(w, "No files uploaded", http.StatusBadRequest)
return
}
uploadedFiles := []string{}
for _, fileHeader := range files {
if strings.ToLower(filepath.Ext(fileHeader.Filename)) != ".drmd" {
http.Error(w, "Only .drmd files are allowed", http.StatusBadRequest)
return
}
file, err := fileHeader.Open()
if err != nil {
logger.LogError("Handle Upload", fmt.Sprintf("Error opening file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
pattern := sanitizeFilename(fileHeader.Filename) + "_*.drmd"
tempFile, err := os.CreateTemp(uploadDir, pattern)
if err != nil {
_ = file.Close()
logger.LogError("Handle Upload", fmt.Sprintf("Error creating temporary file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_, err = io.Copy(tempFile, file)
_ = file.Close()
if err != nil {
_ = tempFile.Close()
_ = os.Remove(tempFile.Name())
logger.LogError("Handle Upload", fmt.Sprintf("Error copying file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if err := tempFile.Close(); err != nil {
_ = os.Remove(tempFile.Name())
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
uploadedFiles = append(uploadedFiles, filepath.Base(tempFile.Name()))
_, err = parseInputFile(tempFile.Name())
if err != nil {
_ = os.Remove(tempFile.Name())
logger.LogError("Handle Upload", fmt.Sprintf("Error parsing input file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
validFiles := []string{}
for _, file := range uploadedFiles {
if file != "" {
validFiles = append(validFiles, file)
}
}
if len(validFiles) == 0 {
logger.LogError("Handle Upload", "No valid files were uploaded")
http.Error(w, "No valid files were uploaded", http.StatusBadRequest)
return
}
logger.LogInfo("Handle Upload", fmt.Sprintf("Redirecting to select with files: %v", validFiles))
http.Redirect(w, r, withToken("/select?files="+url.QueryEscape(strings.Join(validFiles, ","))), http.StatusSeeOther)
}
func handleProcess(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) {
return
}
logger.LogInfo("Handle Process", "Starting process")
if err := r.ParseForm(); err != nil {
logger.LogError("Handle Process", fmt.Sprintf("Error parsing form: %v", err))
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
selectedItems := r.Form["items"]
if len(selectedItems) == 0 {
logger.LogError("Handle Process", "No items selected")
http.Error(w, "No items selected", http.StatusBadRequest)
return
}
itemsByFile := make(map[string][]string)
for _, item := range selectedItems {
parts := strings.SplitN(item, ":", 2)
if len(parts) != 2 {
logger.LogError("Handle Process", "Invalid item format")
continue
}
filename, itemName := parts[0], parts[1]
itemsByFile[filename] = append(itemsByFile[filename], itemName)
}
for filename, items := range itemsByFile {
logger.LogInfo("Handle Process", fmt.Sprintf("Processing file: %s", filename))
fullPath, pathErr := safeUploadPath(filename)
if pathErr != nil {
http.Error(w, pathErr.Error(), http.StatusBadRequest)
return
}
allItems, err := parseInputFile(fullPath)
if err != nil {
logger.LogError("Handle Process", fmt.Sprintf("Error parsing input file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
selectedItems := filterSelectedItems(allItems, items)
sortItems(selectedItems)
go func(targetFilename string, targetItems []Item) {
if err := processItems(targetFilename, targetItems); err != nil {
logger.LogError("Handle Process", fmt.Sprintf("Error processing %s: %v", targetFilename, err))
}
}(filename, selectedItems)
}
http.Redirect(w, r, withToken("/"), http.StatusSeeOther)
}

148
src/handlers_ws.go Normal file
View File

@@ -0,0 +1,148 @@
package main
import (
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
)
const (
wsSendBuffer = 64
wsWriteWait = 10 * time.Second
)
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
origin := strings.TrimSpace(r.Header.Get("Origin"))
if origin == "" {
return strings.TrimSpace(config.Security.AuthToken) == ""
}
parsedOrigin, err := url.Parse(origin)
if err != nil {
return false
}
return strings.EqualFold(parsedOrigin.Host, r.Host)
},
}
type wsClient struct {
conn *websocket.Conn
send chan []byte
once sync.Once
}
func (c *wsClient) close() {
c.once.Do(func() {
close(c.send)
})
}
var clients = make(map[string]map[*wsClient]struct{})
var mu sync.Mutex
func handleWebSocket(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodGet) || !ensureAuthorized(w, r) {
return
}
if !config.General.EnableConsole {
http.Error(w, "Console output is disabled", http.StatusForbidden)
return
}
filename := r.URL.Query().Get("filename")
if filename == "" {
http.Error(w, "Filename is required", http.StatusBadRequest)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logger.LogError("WebSocket", fmt.Sprintf("Error while upgrading connection: %v", err))
return
}
logger.LogInfo("WebSocket", fmt.Sprintf("WebSocket connection established for filename: %s", filename))
client := &wsClient{conn: conn, send: make(chan []byte, wsSendBuffer)}
mu.Lock()
if clients[filename] == nil {
clients[filename] = make(map[*wsClient]struct{})
}
clients[filename][client] = struct{}{}
mu.Unlock()
go writePump(filename, client)
for {
if _, _, err := conn.NextReader(); err != nil {
break
}
}
mu.Lock()
if set, ok := clients[filename]; ok {
if _, exists := set[client]; exists {
delete(set, client)
client.close()
}
if len(set) == 0 {
delete(clients, filename)
}
}
mu.Unlock()
_ = conn.Close()
logger.LogInfo("WebSocket", fmt.Sprintf("WebSocket connection closed for filename: %s", filename))
}
func writePump(filename string, client *wsClient) {
defer client.conn.Close()
for message := range client.send {
if err := client.conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
logger.LogError("WebSocket", fmt.Sprintf("SetWriteDeadline error: %v", err))
return
}
if err := client.conn.WriteMessage(websocket.TextMessage, message); err != nil {
logger.LogError("Broadcast", fmt.Sprintf("Error writing to client for %s: %v", filename, err))
return
}
}
}
func broadcast(filename string, message []byte) {
if !config.General.EnableConsole {
return
}
mu.Lock()
set := clients[filename]
targets := make([]*wsClient, 0, len(set))
for c := range set {
targets = append(targets, c)
}
mu.Unlock()
for _, client := range targets {
select {
case client.send <- message:
default:
mu.Lock()
if set, ok := clients[filename]; ok {
if _, exists := set[client]; exists {
delete(set, client)
client.close()
}
}
mu.Unlock()
logger.LogError("Broadcast", fmt.Sprintf("Dropping slow client for %s", filename))
}
}
}

109
src/integration_test.go Normal file
View File

@@ -0,0 +1,109 @@
package main
import (
"net/http"
"net/http/httptest"
"testing"
)
func resetStateForTest() {
jobsMutex.Lock()
jobs = make(map[string]*JobInfo)
jobsMutex.Unlock()
progressMutex.Lock()
progress = make(map[string]*ProgressInfo)
progressMutex.Unlock()
setGlobalSpeedLimit("")
config = Config{}
setDefaultConfigValues()
}
func TestAuthTokenProtection(t *testing.T) {
resetStateForTest()
config.Security.AuthToken = "secret"
handler := newRouter()
req := httptest.NewRequest(http.MethodPost, "/clear-completed", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code)
}
req = httptest.NewRequest(http.MethodPost, "/clear-completed?token=secret", nil)
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code)
}
}
func TestPauseResumeAbortFlow(t *testing.T) {
resetStateForTest()
handler := newRouter()
filename := "job.drmd"
setJob(filename, NewJobInfo())
updateProgress(filename, 10, "episode1", "running")
req := httptest.NewRequest(http.MethodPost, "/pause?filename="+filename, nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("pause expected %d got %d", http.StatusOK, rr.Code)
}
job, ok := getJob(filename)
if !ok || !job.IsPaused() {
t.Fatalf("expected paused job state")
}
progressInfo := getProgress(filename)
if progressInfo == nil || progressInfo.Status != "paused" {
t.Fatalf("expected paused progress state")
}
req = httptest.NewRequest(http.MethodPost, "/resume?filename="+filename, nil)
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("resume expected %d got %d", http.StatusOK, rr.Code)
}
if job.IsPaused() {
t.Fatalf("expected resumed job state")
}
progressInfo = getProgress(filename)
if progressInfo == nil || progressInfo.Status != "running" {
t.Fatalf("expected running progress state")
}
req = httptest.NewRequest(http.MethodPost, "/abort?filename="+filename, nil)
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("abort expected %d got %d", http.StatusOK, rr.Code)
}
if !job.IsAborted() {
t.Fatalf("expected aborted job state")
}
progressInfo = getProgress(filename)
if progressInfo == nil || progressInfo.Status != "aborted" {
t.Fatalf("expected aborted progress state")
}
req = httptest.NewRequest(http.MethodPost, "/abort?filename="+filename, nil)
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("second abort expected %d got %d", http.StatusOK, rr.Code)
}
}

View File

@@ -1,13 +1,17 @@
package main
import (
"context"
"flag"
"fmt"
"html/template"
"net/http"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"embed"
)
@@ -34,9 +38,6 @@ type Metadata struct {
Season string
}
var progressMutex sync.Mutex
var progress = make(map[string]*ProgressInfo)
const uploadDir = "uploads"
var templates *template.Template
@@ -45,6 +46,19 @@ var templates *template.Template
var templateFS embed.FS
var globalSpeedLimit string
var globalSpeedLimitMutex sync.RWMutex
func getGlobalSpeedLimit() string {
globalSpeedLimitMutex.RLock()
defer globalSpeedLimitMutex.RUnlock()
return globalSpeedLimit
}
func setGlobalSpeedLimit(value string) {
globalSpeedLimitMutex.Lock()
defer globalSpeedLimitMutex.Unlock()
globalSpeedLimit = value
}
func init() {
if err := os.MkdirAll(uploadDir, 0755); err != nil {
@@ -57,10 +71,12 @@ func init() {
}
func main() {
loadConfig()
configPath := flag.String("config", "config.toml", "Path to config file")
inputFile := flag.String("f", "", "Path to the input JSON file")
flag.Parse()
loadConfig(*configPath)
if *inputFile == "" {
go watchFolder()
startWebServer()
@@ -70,35 +86,67 @@ func main() {
logger.LogError("Main", fmt.Sprintf("Error parsing input file: %v", err))
return
}
processItems(*inputFile, items)
if err := processItems(*inputFile, items); err != nil {
logger.LogError("Main", fmt.Sprintf("Error processing items: %v", err))
}
}
}
func startWebServer() {
http.HandleFunc("/", handleRoot)
http.HandleFunc("/upload", handleUpload)
http.HandleFunc("/select", handleSelect)
http.HandleFunc("/process", handleProcess)
http.HandleFunc("/progress", handleProgress)
http.HandleFunc("/abort", handleAbort)
http.HandleFunc("/pause", handlePause)
http.HandleFunc("/resume", handleResume)
http.HandleFunc("/clear-completed", handleClearCompleted)
http.HandleFunc("/ws", handleWebSocket)
http.HandleFunc("/set-speed-limit", handleSetSpeedLimit)
server := &http.Server{
Addr: serverAddr(),
Handler: newRouter(),
ReadTimeout: time.Duration(config.Server.ReadTimeoutSec) * time.Second,
ReadHeaderTimeout: time.Duration(config.Server.ReadHeaderTimeoutS) * time.Second,
WriteTimeout: time.Duration(config.Server.WriteTimeoutSec) * time.Second,
IdleTimeout: time.Duration(config.Server.IdleTimeoutSec) * time.Second,
}
logger.LogInfo("Main", "Starting web server on http://0.0.0.0:8080")
http.ListenAndServe(":8080", nil)
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
go func() {
<-ctx.Done()
logger.LogInfo("Main", "Shutdown signal received, aborting jobs")
abortAllJobs()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
logger.LogError("Main", fmt.Sprintf("Server shutdown error: %v", err))
}
}()
logger.LogInfo("Main", fmt.Sprintf("Starting web server on http://%s", server.Addr))
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.LogError("Main", fmt.Sprintf("Web server error: %v", err))
}
}
func getProgress(filename string) *ProgressInfo {
progressMutex.Lock()
defer progressMutex.Unlock()
return progress[filename]
func abortAllJobs() {
jobsMutex.RLock()
jobList := make([]*JobInfo, 0, len(jobs))
for _, j := range jobs {
jobList = append(jobList, j)
}
jobsMutex.RUnlock()
for _, j := range jobList {
j.Abort()
j.KillProcess()
}
}
func getKeys(keys string) []string {
return strings.Split(keys, ",")
parts := strings.Split(keys, ",")
out := parts[:0]
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
out = append(out, p)
}
}
return out
}
func parseMetadata(metadata string) Metadata {

View File

@@ -3,6 +3,7 @@ package main
import (
"encoding/json"
"os"
"path/filepath"
"reflect"
"testing"
)
@@ -123,3 +124,53 @@ func TestGroupItemsBySeason(t *testing.T) {
}
}
}
func TestSafeUploadPath(t *testing.T) {
tests := []struct {
name string
input string
wantError bool
}{
{name: "valid filename", input: "file.drmd", wantError: false},
{name: "directory traversal", input: "../file.drmd", wantError: true},
{name: "empty", input: "", wantError: true},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
path, err := safeUploadPath(tc.input)
if tc.wantError && err == nil {
t.Fatalf("expected error for input %q", tc.input)
}
if !tc.wantError {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filepath.Base(path) != tc.input {
t.Fatalf("expected basename %q, got %q", tc.input, filepath.Base(path))
}
}
})
}
}
func TestValidateRemoteURL(t *testing.T) {
tests := []struct {
name string
input string
wantError bool
}{
{name: "reject localhost", input: "http://localhost/test.vtt", wantError: true},
{name: "reject invalid scheme", input: "file:///tmp/test", wantError: true},
{name: "reject malformed", input: "::://", wantError: true},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := validateRemoteURL(tc.input)
if tc.wantError && err == nil {
t.Fatalf("expected error for %q", tc.input)
}
})
}
}

19
src/router.go Normal file
View File

@@ -0,0 +1,19 @@
package main
import "net/http"
func newRouter() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/", handleRoot)
mux.HandleFunc("/upload", handleUpload)
mux.HandleFunc("/select", handleSelect)
mux.HandleFunc("/process", handleProcess)
mux.HandleFunc("/progress", handleProgress)
mux.HandleFunc("/abort", handleAbort)
mux.HandleFunc("/pause", handlePause)
mux.HandleFunc("/resume", handleResume)
mux.HandleFunc("/clear-completed", handleClearCompleted)
mux.HandleFunc("/ws", handleWebSocket)
mux.HandleFunc("/set-speed-limit", handleSetSpeedLimit)
return withSecurityHeaders(mux)
}

237
src/security.go Normal file
View File

@@ -0,0 +1,237 @@
package main
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
)
var blockedCIDRs = mustParseCIDRs([]string{
"127.0.0.0/8",
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"169.254.0.0/16",
"100.64.0.0/10",
"0.0.0.0/8",
"::1/128",
"fc00::/7",
"fe80::/10",
})
var secureDialer = &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
var secureTransport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, err
}
for _, ip := range ips {
if isBlockedIP(ip.IP) {
return nil, fmt.Errorf("dial blocked: %s resolves to %s", host, ip.IP)
}
}
if len(ips) == 0 {
return nil, fmt.Errorf("no IPs for %s", host)
}
return secureDialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port))
},
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
}
var secureHTTPClient = &http.Client{
Timeout: 30 * time.Second,
Transport: secureTransport,
CheckRedirect: func(req *http.Request, _ []*http.Request) error {
_, err := validateRemoteURL(req.URL.String())
return err
},
}
const maxRemoteFetchBytes int64 = 20 << 20
func mustParseCIDRs(cidrs []string) []*net.IPNet {
networks := make([]*net.IPNet, 0, len(cidrs))
for _, cidr := range cidrs {
_, network, err := net.ParseCIDR(cidr)
if err != nil {
panic(err)
}
networks = append(networks, network)
}
return networks
}
func maxUploadBytes() int64 {
return int64(config.General.MaxUploadMB) << 20
}
func serverAddr() string {
return fmt.Sprintf("%s:%d", config.Server.Host, config.Server.Port)
}
func ensureMethod(w http.ResponseWriter, r *http.Request, method string) bool {
if r.Method != method {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return false
}
return true
}
func ensureAuthorized(w http.ResponseWriter, r *http.Request) bool {
token := strings.TrimSpace(config.Security.AuthToken)
if token == "" {
return true
}
providedToken := strings.TrimSpace(r.Header.Get("X-DRMD-Token"))
if providedToken == "" {
providedToken = strings.TrimSpace(r.URL.Query().Get("token"))
}
if subtle.ConstantTimeCompare([]byte(providedToken), []byte(token)) != 1 {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return false
}
return true
}
func validateRemoteURL(rawURL string) (*url.URL, error) {
parsedURL, err := url.Parse(strings.TrimSpace(rawURL))
if err != nil {
return nil, fmt.Errorf("invalid URL: %w", err)
}
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return nil, errors.New("only http(s) URLs are allowed")
}
host := parsedURL.Hostname()
if host == "" {
return nil, errors.New("URL host is required")
}
if strings.EqualFold(host, "localhost") {
return nil, errors.New("localhost URLs are not allowed")
}
ips, err := net.LookupIP(host)
if err != nil {
return nil, fmt.Errorf("unable to resolve URL host: %w", err)
}
for _, ip := range ips {
if isBlockedIP(ip) {
return nil, fmt.Errorf("URL host resolves to blocked IP range: %s", ip.String())
}
}
return parsedURL, nil
}
func isBlockedIP(ip net.IP) bool {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() || ip.IsLinkLocalMulticast() || ip.IsLinkLocalUnicast() || ip.IsMulticast() {
return true
}
for _, network := range blockedCIDRs {
if network.Contains(ip) {
return true
}
}
return false
}
func fetchRemoteContent(rawURL string) ([]byte, error) {
validatedURL, err := validateRemoteURL(rawURL)
if err != nil {
return nil, err
}
req, err := http.NewRequest(http.MethodGet, validatedURL.String(), nil)
if err != nil {
return nil, err
}
resp, err := secureHTTPClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
}
if resp.ContentLength > maxRemoteFetchBytes {
return nil, fmt.Errorf("remote content too large: %d bytes", resp.ContentLength)
}
limited := io.LimitReader(resp.Body, maxRemoteFetchBytes+1)
body, err := io.ReadAll(limited)
if err != nil {
return nil, err
}
if int64(len(body)) > maxRemoteFetchBytes {
return nil, fmt.Errorf("remote content exceeded %d bytes", maxRemoteFetchBytes)
}
return body, nil
}
type ctxKey string
const nonceCtxKey ctxKey = "csp-nonce"
func generateNonce() string {
var b [16]byte
if _, err := rand.Read(b[:]); err != nil {
return ""
}
return base64.RawStdEncoding.EncodeToString(b[:])
}
func cspNonce(r *http.Request) string {
if v, ok := r.Context().Value(nonceCtxKey).(string); ok {
return v
}
return ""
}
func withSecurityHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nonce := generateNonce()
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("Referrer-Policy", "no-referrer")
csp := fmt.Sprintf(
"default-src 'self'; connect-src 'self'; img-src 'self' data:; style-src 'self' 'nonce-%s'; script-src 'self' 'nonce-%s'; base-uri 'self'; form-action 'self'; frame-ancestors 'none'",
nonce, nonce,
)
w.Header().Set("Content-Security-Policy", csp)
ctx := context.WithValue(r.Context(), nonceCtxKey, nonce)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

67
src/state.go Normal file
View File

@@ -0,0 +1,67 @@
package main
import "sync"
var (
jobsMutex sync.RWMutex
jobs = make(map[string]*JobInfo)
progressMutex sync.RWMutex
progress = make(map[string]*ProgressInfo)
)
func setJob(filename string, jobInfo *JobInfo) {
jobsMutex.Lock()
defer jobsMutex.Unlock()
jobs[filename] = jobInfo
}
func getJob(filename string) (*JobInfo, bool) {
jobsMutex.RLock()
defer jobsMutex.RUnlock()
jobInfo, ok := jobs[filename]
return jobInfo, ok
}
func removeJob(filename string) {
jobsMutex.Lock()
defer jobsMutex.Unlock()
delete(jobs, filename)
}
func getProgress(filename string) *ProgressInfo {
progressMutex.RLock()
defer progressMutex.RUnlock()
if p, ok := progress[filename]; ok {
snapshot := *p
return &snapshot
}
return nil
}
func snapshotProgress() map[string]ProgressInfo {
progressMutex.RLock()
defer progressMutex.RUnlock()
result := make(map[string]ProgressInfo, len(progress))
for filename, info := range progress {
result[filename] = *info
}
return result
}
func setProgressStatus(filename string, paused *bool, status string) {
progressMutex.Lock()
defer progressMutex.Unlock()
info, ok := progress[filename]
if !ok {
return
}
if paused != nil {
info.Paused = *paused
}
if status != "" {
info.Status = status
}
}

View File

@@ -1,9 +1,9 @@
package main
import (
"bytes"
"fmt"
"io"
"net/http"
"os"
"strings"
@@ -36,12 +36,11 @@ func downloadAndConvertSubtitles(subtitlesURLs string) ([]string, error) {
func downloadSubtitle(url string) (string, error) {
logger.LogInfo("Download Subtitle", fmt.Sprintf("Starting download from %s", url))
resp, err := http.Get(url)
body, err := fetchRemoteContent(url)
if err != nil {
logger.LogError("Download Subtitle", fmt.Sprintf("Error getting subtitle URL: %v", err))
return "", err
}
defer resp.Body.Close()
tempFile, err := os.CreateTemp("", "subtitle_*.vtt")
if err != nil {
@@ -50,7 +49,7 @@ func downloadSubtitle(url string) (string, error) {
}
defer tempFile.Close()
_, err = io.Copy(tempFile, resp.Body)
_, err = io.Copy(tempFile, bytes.NewReader(body))
if err != nil {
logger.LogError("Download Subtitle", fmt.Sprintf("Error copying to temp file: %v", err))
return "", err
@@ -62,8 +61,13 @@ func downloadSubtitle(url string) (string, error) {
func convertVTTtoSRT(vttPath string) (string, error) {
srtPath := strings.TrimSuffix(vttPath, ".vtt") + ".srt"
s1, _ := astisub.OpenFile(vttPath)
s1.Write(srtPath)
s1, err := astisub.OpenFile(vttPath)
if err != nil {
return "", err
}
if err := s1.Write(srtPath); err != nil {
return "", err
}
logger.LogInfo("Convert VTT to SRT", fmt.Sprintf("Converted %s to %s", vttPath, srtPath))
return srtPath, nil
}

View File

@@ -4,7 +4,7 @@
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Simple Downloader</title>
<style>
<style nonce="{{.Nonce}}">
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
background-color: #1e1e1e;
@@ -203,7 +203,7 @@
</head>
<body>
<h1>Simple Downloader</h1>
<form action="/upload" method="post" enctype="multipart/form-data">
<form action="{{if .AuthToken}}/upload?token={{.AuthToken}}{{else}}/upload{{end}}" method="post" enctype="multipart/form-data">
<input type="file" name="files" accept=".drmd" multiple>
<input type="submit" value="Upload and Process">
</form>
@@ -212,11 +212,12 @@
{{range $filename, $info := .Jobs}}
<li>
<div class="job-title">
<a href="/progress?filename={{$filename}}">{{$filename}}</a>
<a href="{{if $.AuthToken}}/progress?filename={{$filename}}&token={{$.AuthToken}}{{else}}/progress?filename={{$filename}}{{end}}">{{$filename}}</a>
</div>
<div class="job-info">
Progress: <span class="progress-text">{{printf "%5.1f%%" $info.Percentage}}</span>
Current file: {{$info.CurrentFile}}
Status: {{$info.Status}}
{{if $info.Paused}}
<span class="paused">(Paused)</span>
{{end}}
@@ -226,7 +227,7 @@
<li>No active jobs</li>
{{end}}
</ul>
<button id="clear-completed" onclick="clearCompleted()">Clear Completed Jobs</button>
<button id="clear-completed">Clear Completed Jobs</button>
<div class="settings-section">
<h2>Settings</h2>
@@ -239,14 +240,26 @@
<option value="MBps" selected>MBps</option>
<option value="KBps">KBps</option>
</select>
<button type="button" onclick="updateSpeedLimit(event)">Set Limit</button>
<button id="set-speed-limit" type="button">Set Limit</button>
</div>
</div>
</div>
<script>
<script nonce="{{.Nonce}}">
const authToken = "{{.AuthToken}}";
const currentSpeedLimitRaw = "{{if .GlobalSpeedLimit}}{{.GlobalSpeedLimit}}{{else}}0{{end}}";
function withToken(path) {
if (!authToken) {
return path;
}
const separator = path.includes('?') ? '&' : '?';
return `${path}${separator}token=${authToken}`;
}
function clearCompleted() {
fetch('/clear-completed', { method: 'POST' })
fetch(withToken('/clear-completed'), { method: 'POST' })
.then(response => response.json())
.then(data => {
if (data.success) {
@@ -269,7 +282,7 @@
return;
}
fetch('/set-speed-limit', {
fetch(withToken('/set-speed-limit'), {
method: 'POST',
headers: {
'Content-Type': 'application/json',
@@ -278,7 +291,6 @@
}).then(response => {
if (response.ok) {
alert('Speed limit updated successfully');
document.getElementById('currentSpeedLimit').textContent = speedLimit;
} else {
alert('Failed to update speed limit');
}
@@ -291,18 +303,20 @@
}
document.addEventListener('DOMContentLoaded', function() {
const currentSpeedLimit = "{{if .GlobalSpeedLimit}}{{.GlobalSpeedLimit}}{{else}}0{{end}}";
const speedLimitValueInput = document.getElementById('speedLimitValue');
const speedLimitUnitSelect = document.getElementById('speedLimitUnit');
const match = currentSpeedLimit.match(/(\d+(\.\d+)?)([A-Za-z]+)/);
const match = currentSpeedLimitRaw.match(/(\d+(\.\d+)?)([A-Za-z]+)/);
if (match) {
speedLimitValueInput.value = match[1];
speedLimitUnitSelect.value = match[3];
} else {
speedLimitValueInput.value = "0";
speedLimitValueInput.value = "0";
speedLimitUnitSelect.value = "MBps";
}
document.getElementById('clear-completed').addEventListener('click', clearCompleted);
document.getElementById('set-speed-limit').addEventListener('click', updateSpeedLimit);
});
</script>
</body>

View File

@@ -4,7 +4,7 @@
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Processing {{.Filename}}</title>
<style>
<style nonce="{{.Nonce}}">
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
background-color: #1e1e1e;
@@ -136,19 +136,29 @@
<div id="currentFile"></div>
</div>
<div>
<button id="abort-button" onclick="abortDownload()">Abort Download</button>
<button id="pause-button" onclick="pauseDownload()">Pause Download</button>
<button id="resume-button" onclick="resumeDownload()" style="display: none;">Resume Download</button>
<button id="abort-button">Abort Download</button>
<button id="pause-button">Pause Download</button>
<button id="resume-button">Resume Download</button>
<button id="toggle-console">Toggle Console View</button>
<button id="back-button" onclick="window.location.href='/'">Back to Index</button>
<button id="back-button">Back to Index</button>
</div>
<div style="display: none;" id="console"></div>
<script>
<div id="console"></div>
<script nonce="{{.Nonce}}">
let isPaused = false;
const filename = "{{.Filename}}";
const authToken = "{{.AuthToken}}";
function withToken(path) {
if (!authToken) {
return path;
}
const separator = path.includes('?') ? '&' : '?';
return `${path}${separator}token=${authToken}`;
}
function updateProgress() {
fetch(`/progress?filename=${filename}`, {
fetch(withToken(`/progress?filename=${encodeURIComponent(filename)}`), {
headers: {
'Accept': 'application/json'
}
@@ -180,7 +190,7 @@
}
function abortDownload() {
fetch(`/abort?filename=${filename}`, { method: 'POST' })
fetch(withToken(`/abort?filename=${encodeURIComponent(filename)}`), { method: 'POST' })
.then(response => {
if (response.ok) {
console.log('Abort signal sent. The download will stop soon.');
@@ -191,7 +201,7 @@
}
function pauseDownload() {
fetch(`/pause?filename=${filename}`, { method: 'POST' })
fetch(withToken(`/pause?filename=${encodeURIComponent(filename)}`), { method: 'POST' })
.then(response => {
if (response.ok) {
console.log('Pause signal sent. The download will pause soon.');
@@ -204,7 +214,7 @@
}
function resumeDownload() {
fetch(`/resume?filename=${filename}`, { method: 'POST' })
fetch(withToken(`/resume?filename=${encodeURIComponent(filename)}`), { method: 'POST' })
.then(response => {
if (response.ok) {
console.log('Resume signal sent. The download will resume soon.');
@@ -219,7 +229,7 @@
const consoleDiv = document.getElementById('console');
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const ws = new WebSocket(`${protocol}//${window.location.host}/ws?filename=${filename}`);
const ws = new WebSocket(`${protocol}//${window.location.host}${withToken(`/ws?filename=${encodeURIComponent(filename)}`)}`);
ws.onmessage = function(event) {
consoleDiv.textContent += event.data;
@@ -234,13 +244,17 @@
console.error('WebSocket error:', error);
};
document.getElementById('toggle-console').onclick = function() {
if (consoleDiv.style.display === "none") {
consoleDiv.style.display = "block";
} else {
consoleDiv.style.display = "none";
}
};
document.addEventListener('DOMContentLoaded', function() {
document.getElementById('abort-button').addEventListener('click', abortDownload);
document.getElementById('pause-button').addEventListener('click', pauseDownload);
document.getElementById('resume-button').addEventListener('click', resumeDownload);
document.getElementById('back-button').addEventListener('click', function() {
window.location.href = withToken('/');
});
document.getElementById('toggle-console').addEventListener('click', function() {
consoleDiv.style.display = consoleDiv.style.display === "none" ? "block" : "none";
});
});
updateProgress();
</script>

View File

@@ -4,7 +4,7 @@
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Select Items to Download</title>
<style>
<style nonce="{{.Nonce}}">
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
background-color: #1e1e1e;
@@ -66,14 +66,14 @@
</head>
<body>
<h1>Select Items to Download</h1>
<form action="/process" method="post">
<form action="{{if .AuthToken}}/process?token={{.AuthToken}}{{else}}/process{{end}}" method="post">
<input type="hidden" name="filenames" value="{{.Filenames}}">
{{range $filename, $fileItems := .AllItems}}
<h2>{{$filename}}</h2>
{{range $season, $items := $fileItems}}
<div class="season" id="season-{{$filename}}-{{$season}}">
<div class="season-title">
<input type="checkbox" class="season-checkbox" id="season-checkbox-{{$filename}}-{{$season}}" checked onchange="toggleSeason('{{$filename}}-{{$season}}')">
<input type="checkbox" class="season-checkbox" id="season-checkbox-{{$filename}}-{{$season}}" data-season-key="{{$filename}}-{{$season}}" checked>
<label for="season-checkbox-{{$filename}}-{{$season}}">{{$season}}</label>
</div>
<div class="season-items">
@@ -90,12 +90,12 @@
{{end}}
{{end}}
<div>
<button type="button" onclick="selectAll(true)">Select All</button>
<button type="button" onclick="selectAll(false)">Select None</button>
<button type="button" id="select-all">Select All</button>
<button type="button" id="select-none">Select None</button>
<input type="submit" value="Start Download">
</div>
</form>
<script>
<script nonce="{{.Nonce}}">
function selectAll(checked) {
var checkboxes = document.getElementsByName('items');
for (var i = 0; i < checkboxes.length; i++) {
@@ -114,6 +114,17 @@
episodeCheckboxes[i].checked = seasonCheckbox.checked;
}
}
document.addEventListener('DOMContentLoaded', function() {
document.getElementById('select-all').addEventListener('click', function() { selectAll(true); });
document.getElementById('select-none').addEventListener('click', function() { selectAll(false); });
var boxes = document.getElementsByClassName('season-checkbox');
for (var i = 0; i < boxes.length; i++) {
boxes[i].addEventListener('change', function() {
toggleSeason(this.dataset.seasonKey);
});
}
});
</script>
</body>
</html>

View File

@@ -2,6 +2,7 @@ package main
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/url"
@@ -21,23 +22,131 @@ type JobInfo struct {
AbortChan chan struct{}
ResumeChan chan struct{}
Cmd *exec.Cmd
Paused bool
TempDir string
mu sync.RWMutex
paused bool
abortOnce sync.Once
}
func NewJobInfo() *JobInfo {
return &JobInfo{
AbortChan: make(chan struct{}),
ResumeChan: make(chan struct{}, 1),
}
}
func (j *JobInfo) SetPaused(value bool) {
j.mu.Lock()
defer j.mu.Unlock()
j.paused = value
}
func (j *JobInfo) IsPaused() bool {
j.mu.RLock()
defer j.mu.RUnlock()
return j.paused
}
func (j *JobInfo) SetCmd(cmd *exec.Cmd) {
j.mu.Lock()
defer j.mu.Unlock()
j.Cmd = cmd
}
func (j *JobInfo) SetTempDir(tempDir string) {
j.mu.Lock()
defer j.mu.Unlock()
j.TempDir = tempDir
}
func (j *JobInfo) GetTempDir() string {
j.mu.RLock()
defer j.mu.RUnlock()
return j.TempDir
}
func (j *JobInfo) KillProcess() {
j.mu.RLock()
cmd := j.Cmd
j.mu.RUnlock()
if cmd != nil && cmd.Process != nil {
_ = cmd.Process.Kill()
}
}
func (j *JobInfo) SignalResume() {
select {
case j.ResumeChan <- struct{}{}:
default:
}
}
func (j *JobInfo) Abort() {
j.abortOnce.Do(func() {
close(j.AbortChan)
})
}
func (j *JobInfo) IsAborted() bool {
select {
case <-j.AbortChan:
return true
default:
return false
}
}
var (
jobsMutex sync.Mutex
jobs = make(map[string]*JobInfo)
ErrDownloadPaused = errors.New("download paused")
ErrDownloadAborted = errors.New("download aborted")
)
var sanitizeFilenameRegex = regexp.MustCompile(`[<>:"/\\|?*\x00-\x1f]`)
var windowsReservedNames = map[string]struct{}{
"CON": {}, "PRN": {}, "AUX": {}, "NUL": {},
"COM1": {}, "COM2": {}, "COM3": {}, "COM4": {}, "COM5": {}, "COM6": {}, "COM7": {}, "COM8": {}, "COM9": {},
"LPT1": {}, "LPT2": {}, "LPT3": {}, "LPT4": {}, "LPT5": {}, "LPT6": {}, "LPT7": {}, "LPT8": {}, "LPT9": {},
}
func sanitizeFilename(filename string) string {
filename = regexp.MustCompile(`[<>:"/\\|?*]`).ReplaceAllString(filename, "_")
filename = sanitizeFilenameRegex.ReplaceAllString(filename, "_")
filename = strings.Trim(filename, ". ")
filename = strings.Trim(filename, ".")
base := filename
if idx := strings.LastIndex(filename, "."); idx > 0 {
base = filename[:idx]
}
if _, reserved := windowsReservedNames[strings.ToUpper(base)]; reserved {
filename = "_" + filename
}
if filename == "" {
filename = "_"
}
return filename
}
func safeUploadPath(filename string) (string, error) {
cleanName := strings.TrimSpace(filename)
if cleanName == "" {
return "", fmt.Errorf("filename is required")
}
baseName := filepath.Base(cleanName)
if baseName != cleanName {
return "", fmt.Errorf("invalid filename")
}
if strings.Contains(baseName, "..") {
return "", fmt.Errorf("invalid filename")
}
return filepath.Join(uploadDir, baseName), nil
}
func isValidURL(toTest string) bool {
_, err := url.ParseRequestURI(toTest)
return err == nil
@@ -185,13 +294,14 @@ func groupItemsBySeason(items []Item) map[string][]Item {
}
func filterSelectedItems(items []Item, selectedItems []string) []Item {
var filtered []Item
set := make(map[string]struct{}, len(selectedItems))
for _, s := range selectedItems {
set[s] = struct{}{}
}
filtered := make([]Item, 0, len(selectedItems))
for _, item := range items {
for _, selected := range selectedItems {
if item.Filename == selected {
filtered = append(filtered, item)
break
}
if _, ok := set[item.Filename]; ok {
filtered = append(filtered, item)
}
}
return filtered
@@ -225,31 +335,27 @@ func extractNumber(s string) int {
return num
}
var episodeNumberRegex = regexp.MustCompile(`(?i)S\d+E(\d+)`)
func extractEpisodeNumber(filename string) int {
parts := strings.Split(filename, "E")
if len(parts) > 1 {
num, _ := strconv.Atoi(parts[1])
return num
match := episodeNumberRegex.FindStringSubmatch(filename)
if len(match) < 2 {
return 0
}
return 0
num, _ := strconv.Atoi(match[1])
return num
}
func processItems(filename string, items []Item) error {
jobsMutex.Lock()
jobInfo := &JobInfo{
AbortChan: make(chan struct{}),
ResumeChan: make(chan struct{}),
}
jobs[filename] = jobInfo
jobsMutex.Unlock()
jobInfo := NewJobInfo()
setJob(filename, jobInfo)
defer func() {
jobsMutex.Lock()
delete(jobs, filename)
jobsMutex.Unlock()
removeJob(filename)
if jobInfo.TempDir != "" {
os.RemoveAll(jobInfo.TempDir)
tempDir := jobInfo.GetTempDir()
if tempDir != "" {
_ = os.RemoveAll(tempDir)
}
}()
@@ -258,34 +364,45 @@ func processItems(filename string, items []Item) error {
for i := 0; i < len(items); i++ {
select {
case <-jobInfo.AbortChan:
updateProgress(filename, 100, "Aborted")
updateProgress(filename, 100, "Aborted", "aborted")
logger.LogJobState(filename, "aborted")
return fmt.Errorf("download aborted")
return ErrDownloadAborted
default:
if jobInfo.Paused {
if jobInfo.IsPaused() {
select {
case <-jobInfo.ResumeChan:
jobInfo.Paused = false
jobInfo.SetPaused(false)
logger.LogJobState(filename, "resumed")
case <-jobInfo.AbortChan:
updateProgress(filename, 100, "Aborted")
updateProgress(filename, 100, "Aborted", "aborted")
logger.LogJobState(filename, "aborted")
return fmt.Errorf("download aborted")
return ErrDownloadAborted
}
}
updateProgress(filename, float64(i)/float64(len(items))*100, items[i].Filename)
updateProgress(filename, float64(i)/float64(len(items))*100, items[i].Filename, "running")
err := downloadFile(filename, items[i], jobInfo)
if err != nil {
if err.Error() == "download paused" {
if errors.Is(err, ErrDownloadPaused) {
logger.LogJobState(filename, "paused")
removeCompletedEpisodes(filename, items[:i])
if remErr := removeCompletedEpisodes(filename, items[:i]); remErr != nil {
logger.LogError("Process Items", fmt.Sprintf("Error updating partial progress file: %v", remErr))
}
i--
continue
}
if errors.Is(err, ErrDownloadAborted) {
updateProgress(filename, 100, "Aborted", "aborted")
logger.LogJobState(filename, "aborted")
return ErrDownloadAborted
}
updateProgress(filename, float64(i)/float64(len(items))*100, items[i].Filename, "failed")
logger.LogError("Process Items", fmt.Sprintf("Error downloading item %s: %v", items[i].Filename, err))
return fmt.Errorf("error downloading %s: %w", items[i].Filename, err)
}
}
}
updateProgress(filename, 100, "")
updateProgress(filename, 100, "", "completed")
logger.LogJobState(filename, "completed successfully")
return nil
}

View File

@@ -3,10 +3,10 @@ package main
import (
"fmt"
"io"
"log"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/fsnotify/fsnotify"
@@ -22,10 +22,32 @@ func watchFolder() {
}
}
var watchedProcessing = struct {
mu sync.Mutex
files map[string]bool
}{files: make(map[string]bool)}
func beginWatching(filePath string) bool {
watchedProcessing.mu.Lock()
defer watchedProcessing.mu.Unlock()
if watchedProcessing.files[filePath] {
return false
}
watchedProcessing.files[filePath] = true
return true
}
func doneWatching(filePath string) {
watchedProcessing.mu.Lock()
defer watchedProcessing.mu.Unlock()
delete(watchedProcessing.files, filePath)
}
func inotifyWatch() {
watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Fatal(err)
logger.LogError("Watcher", fmt.Sprintf("Failed to create inotify watcher: %v", err))
return
}
defer watcher.Close()
@@ -40,22 +62,23 @@ func inotifyWatch() {
}
if event.Op&fsnotify.Create == fsnotify.Create {
if strings.HasSuffix(event.Name, ".drmd") {
fmt.Println("New .drmd detected:", event.Name)
processWatchedFile(event.Name)
logger.LogInfo("Watcher", fmt.Sprintf("New .drmd detected: %s", event.Name))
go processWatchedFile(event.Name)
}
}
case err, ok := <-watcher.Errors:
if !ok {
return
}
log.Println("Error:", err)
logger.LogError("Watcher", fmt.Sprintf("Inotify error: %v", err))
}
}
}()
err = watcher.Add(config.WatchFolder.Path)
if err != nil {
log.Fatal(err)
logger.LogError("Watcher", fmt.Sprintf("Failed to add watch folder %s: %v", config.WatchFolder.Path, err))
return
}
<-done
}
@@ -67,18 +90,28 @@ func pollFolder() {
for range ticker.C {
files, err := filepath.Glob(filepath.Join(config.WatchFolder.Path, "*.drmd"))
if err != nil {
log.Println("Error polling folder:", err)
logger.LogError("Watcher", fmt.Sprintf("Error polling folder: %v", err))
continue
}
for _, file := range files {
fmt.Println("New .drmd detected via polling:", file)
logger.LogInfo("Watcher", fmt.Sprintf("New .drmd detected via polling: %s", file))
go processWatchedFile(file)
}
}
}
func processWatchedFile(filePath string) {
if !beginWatching(filePath) {
return
}
releaseDedupe := true
defer func() {
if releaseDedupe {
doneWatching(filePath)
}
}()
for {
initialSize, err := getFileSize(filePath)
if err != nil {
@@ -112,16 +145,24 @@ func processWatchedFile(filePath string) {
logger.LogError("Watcher", fmt.Sprintf("Error creating temporary file: %v", err))
return
}
defer tempFile.Close()
_, err = io.Copy(tempFile, file)
if err != nil {
if _, err := io.Copy(tempFile, file); err != nil {
_ = tempFile.Close()
_ = os.Remove(tempFile.Name())
logger.LogError("Watcher", fmt.Sprintf("Error copying file: %v", err))
return
}
if err := tempFile.Close(); err != nil {
_ = os.Remove(tempFile.Name())
logger.LogError("Watcher", fmt.Sprintf("Error closing temp file: %v", err))
return
}
if err := os.Remove(filePath); err != nil {
logger.LogError("Watcher", fmt.Sprintf("Error deleting original file: %v", err))
logger.LogError("Watcher", fmt.Sprintf("Error deleting original file; keeping dedupe entry to avoid reprocessing: %v", err))
_ = os.Remove(tempFile.Name())
releaseDedupe = false
return
}
items, err := parseInputFile(tempFile.Name())
@@ -130,7 +171,11 @@ func processWatchedFile(filePath string) {
return
}
go processItems(filepath.Base(tempFile.Name()), items)
go func(targetFilename string, targetItems []Item) {
if err := processItems(targetFilename, targetItems); err != nil {
logger.LogError("Watcher", fmt.Sprintf("Error processing watched file %s: %v", targetFilename, err))
}
}(filepath.Base(tempFile.Name()), items)
}
func getFileSize(filePath string) (int64, error) {