29 Commits

Author SHA1 Message Date
toor
070d4a862c Fix stream selection flags for bytohgn fork 2026-04-14 11:07:58 +02:00
6543d3a170 Merge pull request 'bugfixes' (#12) from bugfixes into main
Reviewed-on: #12
2026-04-14 10:26:04 +02:00
0e08b1669c Merge branch 'main' into bugfixes 2026-04-14 10:25:53 +02:00
f4310ed688 Simplify logger output format
Drop source file/line prefixes from logs so console output is cleaner and easier to scan during long-running jobs.
2026-04-14 10:21:21 +02:00
1c82b619c4 Harden web/download pipeline and split handler modules
Replace shell-based downloader execution with validated arguments, enforce request hardening and safer defaults, and refactor handlers/router/state so job control is safer and easier to maintain.
2026-04-14 10:21:11 +02:00
6e016b802b ye 2025-06-01 19:44:24 +02:00
78ed401392 Merge pull request 'bugfixes' (#11) from bugfixes into main
Reviewed-on: #11
2025-05-21 09:45:11 +02:00
72b85ec281 Add template 2025-05-21 09:44:20 +02:00
b2e3268ad1 Stop tracking config.toml 2025-05-21 09:42:35 +02:00
1af43b111c enclose MPD url 2025-05-21 09:40:45 +02:00
03312a0079 Fix accidentally registering same handler twice 2024-12-31 01:50:58 +01:00
a91163f845 config 2024-12-30 16:57:05 +01:00
7d28d1cea8 Merge pull request 'speedLimiter' (#10) from speedLimiter into main
Reviewed-on: #10
2024-12-30 16:47:45 +01:00
3fda737af2 Merge branch 'main' into speedLimiter 2024-12-30 16:47:21 +01:00
8cf3d4dda8 Show limit 2024-12-30 16:46:30 +01:00
2e18921a27 Show limit 2024-12-30 16:45:56 +01:00
f1efb1d67c impl speed limit 2024-12-30 16:32:12 +01:00
457ede5b62 Speed 2024-12-30 16:20:48 +01:00
7eb724d01f Speed 2024-12-30 16:16:51 +01:00
189bbb0874 Speed 2024-12-30 16:16:37 +01:00
68da5f9658 Speed 2024-12-30 16:16:21 +01:00
83cd0b722b style 2024-12-30 16:04:37 +01:00
ca176e1a76 Update README.md 2024-10-07 13:02:40 +02:00
54656f2630 Update README.md 2024-10-07 13:02:17 +02:00
f38b0c69d9 Merge pull request 'Poller' (#9) from Poller into main
Reviewed-on: #9
2024-10-07 12:59:18 +02:00
b1ba08933a Console should also beable to be controlled by env var 2024-10-07 12:48:37 +02:00
a049610291 Implement polling, update readme 2024-10-07 12:46:38 +02:00
c46538a55f Change the config paths according to new layout 2024-10-07 12:46:26 +02:00
fe6b7c78f6 Add options for polling, path validation, env variables 2024-10-07 12:45:49 +02:00
27 changed files with 2073 additions and 655 deletions

6
.gitignore vendored
View File

@@ -1 +1,7 @@
config.toml config.toml
drmdtool
drmdtool_*
src/DRMDTool
*.exe
uploads/
src/uploads/

View File

@@ -7,17 +7,76 @@ drmdtool is a utility for processing .drmd files using N_m3u8DL-RE.
Create a `config.toml` file in the same directory as the drmdtool executable: Create a `config.toml` file in the same directory as the drmdtool executable:
```toml ```toml
[General]
BaseDir = "/path/to/save/downloads" BaseDir = "/path/to/save/downloads"
Format = "mkv" Format = "mkv"
TempBaseDir = "/tmp/nre" TempBaseDir = "/tmp/nre"
EnableConsole = true EnableConsole = true
WatchedFolder = "/path/to/watched/folder" MaxUploadMB = 32
[N_m3u8DL-RE] [WatchFolder]
Path = "/path/to/watched/folder"
PollingInterval = 10
UsePolling = true
UseInotify = false
[N_m3u8DLRE]
Path = "/path/to/N_m3u8DL-RE" Path = "/path/to/N_m3u8DL-RE"
[Server]
Host = "127.0.0.1"
Port = 8080
ReadTimeoutSec = 30
WriteTimeoutSec = 30
IdleTimeoutSec = 60
ReadHeaderTimeoutS = 10
[Security]
AuthToken = ""
``` ```
Adjust the paths and format as needed. (mkv, mp4) ### Configuration Options
- **General**
- `BaseDir`: Directory where downloaded files will be saved.
- `Format`: Output format for the downloaded files (e.g., `mkv`, `mp4`).
- `TempBaseDir`: Temporary directory for intermediate files.
- `EnableConsole`: Boolean to enable or disable console output.
- `MaxUploadMB`: Maximum allowed upload size for the web UI.
- **WatchFolder**
- `Path`: Directory to watch for new `.drmd` files.
- `PollingInterval`: Interval in seconds for polling the watch folder.
- `UsePolling`: Boolean to enable or disable folder polling.
- `UseInotify`: Boolean to enable or disable inotify for file watching.
- **N_m3u8DLRE**
- `Path`: Path to the N_m3u8DL-RE executable.
- **Server**
- `Host`: Bind address for the web server (`127.0.0.1` recommended).
- `Port`: Web server port.
- `ReadTimeoutSec`, `WriteTimeoutSec`, `IdleTimeoutSec`, `ReadHeaderTimeoutS`: HTTP timeout settings.
- **Security**
- `AuthToken`: Optional token for protecting all endpoints. Recommended when binding to a non-loopback host.
### Environment Variable Overrides
You can override the configuration options using environment variables. The following environment variables are supported:
- `BASE_DIR`: Overrides `General.BaseDir`
- `FORMAT`: Overrides `General.Format`
- `TEMP_BASE_DIR`: Overrides `General.TempBaseDir`
- `ENABLE_CONSOLE`: Overrides `General.EnableConsole` (set to `true` or `false`)
- `MAX_UPLOAD_MB`: Overrides `General.MaxUploadMB`
- `WATCHED_FOLDER`: Overrides `WatchFolder.Path`
- `USE_POLLING`: Overrides `WatchFolder.UsePolling` (set to `true` or `false`)
- `USE_INOTIFY`: Overrides `WatchFolder.UseInotify` (set to `true` or `false`)
- `POLLING_INTERVAL`: Overrides `WatchFolder.PollingInterval`
- `SERVER_HOST`: Overrides `Server.Host`
- `SERVER_PORT`: Overrides `Server.Port`
- `AUTH_TOKEN`: Overrides `Security.AuthToken`
## Web UI Usage ## Web UI Usage
@@ -28,8 +87,10 @@ Adjust the paths and format as needed. (mkv, mp4)
2. Open a web browser and go to `http://localhost:8080` 2. Open a web browser and go to `http://localhost:8080`
3. Use the interface to upload .drmd files and monitor download progress If `Security.AuthToken` is configured, include it as a query parameter:
`http://localhost:8080/?token=YOUR_TOKEN`
3. Use the interface to upload .drmd files and monitor download progress
## CLI Usage ## CLI Usage
@@ -41,7 +102,6 @@ To process a file directly from the command line:
This will download the file and save it in the base directory specified in the config. This will download the file and save it in the base directory specified in the config.
# Previews # Previews
## Index Page ## Index Page

26
config.template.toml Normal file
View File

@@ -0,0 +1,26 @@
[General]
BaseDir = "/mnt/media"
Format = "mkv"
TempBaseDir = "/tmp/nre"
EnableConsole = true
MaxUploadMB = 32
[WatchFolder]
Path = "/mnt/watched"
PollingInterval = 10
UsePolling = false
UseInotify = false
[N_m3u8DLRE]
Path = "nre"
[Server]
Host = "127.0.0.1"
Port = 8080
ReadTimeoutSec = 30
WriteTimeoutSec = 30
IdleTimeoutSec = 60
ReadHeaderTimeoutS = 10
[Security]
AuthToken = ""

View File

@@ -1,8 +0,0 @@
BaseDir = "/mnt/media"
Format = "mkv"
TempBaseDir = "/tmp/nre"
EnableConsole = true
WatchedFolder = "/mnt/watched"
[N_m3u8DLRE]
Path = "nre"

View File

@@ -1,48 +1,250 @@
package main package main
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
"os/exec"
"strconv"
"strings"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
) )
type Config struct { type Config struct {
General struct {
BaseDir string BaseDir string
Format string Format string
TempBaseDir string TempBaseDir string
EnableConsole bool
MaxUploadMB int
}
WatchFolder struct {
Path string
UsePolling bool
UseInotify bool
PollingInterval int
}
Server struct {
Host string
Port int
ReadTimeoutSec int
WriteTimeoutSec int
IdleTimeoutSec int
ReadHeaderTimeoutS int
}
Security struct {
AuthToken string
}
N_m3u8DLRE struct { N_m3u8DLRE struct {
Path string Path string
} }
EnableConsole bool
WatchedFolder string
} }
var config Config var config Config
func loadConfig() { func loadConfig(path string) {
configFile, err := os.Open("config.toml") configFile, err := os.Open(path)
if err != nil { if err != nil {
fmt.Println("Error opening config file:", err) logger.LogError("Config", fmt.Sprintf("Error opening config file: %v", err))
return os.Exit(1)
} }
defer configFile.Close() defer configFile.Close()
byteValue, _ := io.ReadAll(configFile) byteValue, err := io.ReadAll(configFile)
if err != nil {
logger.LogError("Config", fmt.Sprintf("Error reading config file: %v", err))
os.Exit(1)
}
if _, err := toml.Decode(string(byteValue), &config); err != nil { if _, err := toml.Decode(string(byteValue), &config); err != nil {
fmt.Println("Error decoding config file:", err) logger.LogError("Config", fmt.Sprintf("Error decoding config file: %v", err))
return os.Exit(1)
} }
if config.N_m3u8DLRE.Path == "" { overrideConfigWithEnv()
fmt.Println("Error: N_m3u8DL-RE path is not specified in the config file") setDefaultConfigValues()
return
if err := validatePaths(); err != nil {
logger.LogError("Config", fmt.Sprintf("Configuration error: %v", err))
os.Exit(1)
} }
if config.WatchedFolder == "" { if config.WatchFolder.PollingInterval <= 0 {
fmt.Println("Error: Watched folder is not specified in the config file") config.WatchFolder.PollingInterval = 10
return }
logConfig()
}
func setDefaultConfigValues() {
if config.General.MaxUploadMB <= 0 {
config.General.MaxUploadMB = 32
}
if strings.TrimSpace(config.Server.Host) == "" {
config.Server.Host = "127.0.0.1"
}
if config.Server.Port <= 0 {
config.Server.Port = 8080
}
if config.Server.ReadTimeoutSec <= 0 {
config.Server.ReadTimeoutSec = 30
}
if config.Server.WriteTimeoutSec <= 0 {
config.Server.WriteTimeoutSec = 30
}
if config.Server.IdleTimeoutSec <= 0 {
config.Server.IdleTimeoutSec = 60
}
if config.Server.ReadHeaderTimeoutS <= 0 {
config.Server.ReadHeaderTimeoutS = 10
} }
} }
func overrideConfigWithEnv() {
if envBaseDir := os.Getenv("BASE_DIR"); envBaseDir != "" {
config.General.BaseDir = envBaseDir
}
if envFormat := os.Getenv("FORMAT"); envFormat != "" {
config.General.Format = envFormat
}
if envTempBaseDir := os.Getenv("TEMP_BASE_DIR"); envTempBaseDir != "" {
config.General.TempBaseDir = envTempBaseDir
}
if envEnableConsole := os.Getenv("ENABLE_CONSOLE"); envEnableConsole != "" {
config.General.EnableConsole = strings.ToLower(envEnableConsole) == "true"
}
if envWatchedFolder := os.Getenv("WATCHED_FOLDER"); envWatchedFolder != "" {
config.WatchFolder.Path = envWatchedFolder
}
if envUsePolling := os.Getenv("USE_POLLING"); envUsePolling != "" {
config.WatchFolder.UsePolling = strings.ToLower(envUsePolling) == "true"
}
if envUseInotify := os.Getenv("USE_INOTIFY"); envUseInotify != "" {
config.WatchFolder.UseInotify = strings.ToLower(envUseInotify) == "true"
}
if envPollingInterval := os.Getenv("POLLING_INTERVAL"); envPollingInterval != "" {
if interval, err := strconv.Atoi(envPollingInterval); err == nil {
config.WatchFolder.PollingInterval = interval
}
}
if envMaxUploadMB := os.Getenv("MAX_UPLOAD_MB"); envMaxUploadMB != "" {
if value, err := strconv.Atoi(envMaxUploadMB); err == nil {
config.General.MaxUploadMB = value
}
}
if envHost := os.Getenv("SERVER_HOST"); envHost != "" {
config.Server.Host = envHost
}
if envPort := os.Getenv("SERVER_PORT"); envPort != "" {
if value, err := strconv.Atoi(envPort); err == nil {
config.Server.Port = value
}
}
if envAuthToken := os.Getenv("AUTH_TOKEN"); envAuthToken != "" {
config.Security.AuthToken = envAuthToken
}
}
func validatePaths() error {
if strings.TrimSpace(config.General.Format) == "" {
return errors.New("format is not specified")
}
allowedFormats := map[string]bool{"mkv": true, "mp4": true}
if !allowedFormats[strings.ToLower(config.General.Format)] {
return fmt.Errorf("unsupported format: %s (supported: mkv, mp4)", config.General.Format)
}
paths := []struct {
name string
path string
}{
{"BaseDir", config.General.BaseDir},
{"TempBaseDir", config.General.TempBaseDir},
}
for _, p := range paths {
if p.path == "" {
return fmt.Errorf("%s is not specified", p.name)
}
if _, err := os.Stat(p.path); os.IsNotExist(err) {
if p.name == "TempBaseDir" {
if mkErr := os.MkdirAll(p.path, 0755); mkErr != nil {
return fmt.Errorf("unable to create %s: %v", p.name, mkErr)
}
continue
}
return fmt.Errorf("%s does not exist: %s", p.name, p.path)
} else if err != nil {
return fmt.Errorf("error accessing %s: %v", p.name, err)
}
}
if config.WatchFolder.UsePolling || config.WatchFolder.UseInotify {
if config.WatchFolder.Path == "" {
return fmt.Errorf("WatchedFolder is not specified")
}
if _, err := os.Stat(config.WatchFolder.Path); os.IsNotExist(err) {
return fmt.Errorf("WatchedFolder does not exist: %s", config.WatchFolder.Path)
} else if err != nil {
return fmt.Errorf("error accessing WatchedFolder: %v", err)
}
}
if strings.TrimSpace(config.N_m3u8DLRE.Path) == "" {
return errors.New("N_m3u8DLRE path is not specified")
}
if _, err := exec.LookPath(config.N_m3u8DLRE.Path); err != nil {
return fmt.Errorf("N_m3u8DLRE executable not found in PATH: %s", config.N_m3u8DLRE.Path)
}
if config.Server.Port <= 0 || config.Server.Port > 65535 {
return fmt.Errorf("invalid server port: %d", config.Server.Port)
}
return nil
}
func logConfig() {
configInfo := fmt.Sprintf(`
Configuration Loaded:
General:
BaseDir: %s
Format: %s
TempBaseDir: %s
EnableConsole: %t
MaxUploadMB: %d
WatchFolder:
Path: %s
UsePolling: %t
UseInotify: %t
PollingInterval: %d
Server:
Host: %s
Port: %d
ReadTimeoutSec: %d
WriteTimeoutSec: %d
IdleTimeoutSec: %d
ReadHeaderTimeoutS: %d
Security:
AuthTokenConfigured: %t
N_m3u8DLRE:
Path: %s
`, config.General.BaseDir, config.General.Format, config.General.TempBaseDir, config.General.EnableConsole,
config.General.MaxUploadMB,
config.WatchFolder.Path, config.WatchFolder.UsePolling, config.WatchFolder.UseInotify, config.WatchFolder.PollingInterval,
config.Server.Host, config.Server.Port, config.Server.ReadTimeoutSec, config.Server.WriteTimeoutSec, config.Server.IdleTimeoutSec, config.Server.ReadHeaderTimeoutS,
strings.TrimSpace(config.Security.AuthToken) != "",
config.N_m3u8DLRE.Path)
logger.LogInfo("Config", configInfo)
}

View File

@@ -1,18 +1,31 @@
package main package main
import ( import (
"bytes"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"io" "io"
"net/http"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"regexp"
"strings" "strings"
"time"
) )
var decryptionKeyRegex = regexp.MustCompile(`^[0-9a-fA-F]{32}:[0-9a-fA-F]{32}$`)
type websocketBroadcastWriter struct {
filename string
}
func (w websocketBroadcastWriter) Write(p []byte) (int, error) {
if config.General.EnableConsole && len(p) > 0 {
message := append([]byte(nil), p...)
broadcast(w.filename, message)
}
return len(p), nil
}
func removeBOM(input []byte) []byte { func removeBOM(input []byte) []byte {
if len(input) >= 3 && input[0] == 0xEF && input[1] == 0xBB && input[2] == 0xBF { if len(input) >= 3 && input[0] == 0xEF && input[1] == 0xBB && input[2] == 0xBF {
return input[3:] return input[3:]
@@ -23,14 +36,17 @@ func removeBOM(input []byte) []byte {
func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error { func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error {
logger.LogInfo("Download File", fmt.Sprintf("Starting download for: %s", item.Filename)) logger.LogInfo("Download File", fmt.Sprintf("Starting download for: %s", item.Filename))
tempDir := filepath.Join(config.TempBaseDir, sanitizeFilename(item.Filename)) if err := os.MkdirAll(config.General.TempBaseDir, 0755); err != nil {
err := os.MkdirAll(tempDir, 0755) logger.LogError("Download File", fmt.Sprintf("Error creating temp base dir: %v", err))
return fmt.Errorf("error creating temp base dir: %v", err)
}
tempDir, err := os.MkdirTemp(config.General.TempBaseDir, sanitizeFilename(item.Filename)+"_")
if err != nil { if err != nil {
logger.LogError("Download File", fmt.Sprintf("Error creating temporary directory: %v", err)) logger.LogError("Download File", fmt.Sprintf("Error creating temporary directory: %v", err))
return fmt.Errorf("error creating temporary directory: %v", err) return fmt.Errorf("error creating temporary directory: %v", err)
} }
jobInfo.TempDir = tempDir jobInfo.SetTempDir(tempDir)
mpdPath := item.MPD mpdPath := item.MPD
if !isValidURL(item.MPD) { if !isValidURL(item.MPD) {
@@ -58,18 +74,11 @@ func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error {
mpdPath = tempFile.Name() mpdPath = tempFile.Name()
} else if strings.HasPrefix(item.MPD, "https://pubads.g.doubleclick.net") { } else if strings.HasPrefix(item.MPD, "https://pubads.g.doubleclick.net") {
resp, err := http.Get(item.MPD) mpdContent, err := fetchRemoteContent(item.MPD)
if err != nil { if err != nil {
logger.LogError("Download File", fmt.Sprintf("Error downloading MPD: %v", err)) logger.LogError("Download File", fmt.Sprintf("Error downloading MPD: %v", err))
return fmt.Errorf("error downloading MPD: %v", err) return fmt.Errorf("error downloading MPD: %v", err)
} }
defer resp.Body.Close()
mpdContent, err := io.ReadAll(resp.Body)
if err != nil {
logger.LogError("Download File", fmt.Sprintf("Error reading MPD content: %v", err))
return fmt.Errorf("error reading MPD content: %v", err)
}
fixedMPDContent, err := fixGoPlay(string(mpdContent)) fixedMPDContent, err := fixGoPlay(string(mpdContent))
if err != nil { if err != nil {
@@ -96,7 +105,10 @@ func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error {
mpdPath = tempFile.Name() mpdPath = tempFile.Name()
} }
command := getDownloadCommand(item, mpdPath, tempDir) args, err := getDownloadArgs(item, mpdPath, tempDir)
if err != nil {
return err
}
if item.Subtitles != "" { if item.Subtitles != "" {
subtitlePaths, err := downloadAndConvertSubtitles(item.Subtitles) subtitlePaths, err := downloadAndConvertSubtitles(item.Subtitles)
@@ -105,17 +117,17 @@ func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error {
} else { } else {
for _, path := range subtitlePaths { for _, path := range subtitlePaths {
logger.LogInfo("Download File", fmt.Sprintf("Adding subtitle: %s", path)) logger.LogInfo("Download File", fmt.Sprintf("Adding subtitle: %s", path))
command += fmt.Sprintf(" --mux-import \"path=%s:lang=nl:name=Nederlands\"", path) args = append(args, "--mux-import", fmt.Sprintf("path=%s:lang=nl:name=Nederlands", path))
} }
} }
} }
cmd := exec.Command("bash", "-c", command) cmd := exec.Command(config.N_m3u8DLRE.Path, args...)
jobInfo.Cmd = cmd jobInfo.SetCmd(cmd)
var outputBuffer bytes.Buffer broadcastWriter := websocketBroadcastWriter{filename: drmdFilename}
cmd.Stdout = io.MultiWriter(&outputBuffer) cmd.Stdout = io.MultiWriter(os.Stdout, broadcastWriter)
cmd.Stderr = os.Stderr cmd.Stderr = io.MultiWriter(os.Stderr, broadcastWriter)
err = cmd.Start() err = cmd.Start()
if err != nil { if err != nil {
@@ -128,33 +140,26 @@ func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error {
done <- cmd.Wait() done <- cmd.Wait()
}() }()
go func() {
for {
if outputBuffer.Len() > 0 {
message := outputBuffer.Bytes()
if config.EnableConsole {
broadcast(drmdFilename, message)
}
outputBuffer.Reset()
}
time.Sleep(1 * time.Second)
}
}()
select { select {
case <-jobInfo.AbortChan: case <-jobInfo.AbortChan:
if cmd.Process != nil { jobInfo.KillProcess()
cmd.Process.Kill() _ = os.RemoveAll(tempDir)
}
os.RemoveAll(tempDir)
logger.LogInfo("Download File", "Download aborted") logger.LogInfo("Download File", "Download aborted")
return fmt.Errorf("download aborted") return ErrDownloadAborted
case err := <-done: case err := <-done:
if jobInfo.Paused { if jobInfo.IsPaused() {
logger.LogInfo("Download File", "Download paused") logger.LogInfo("Download File", "Download paused")
return fmt.Errorf("download paused") return ErrDownloadPaused
} }
if err != nil { if err != nil {
if jobInfo.IsAborted() {
return ErrDownloadAborted
}
if jobInfo.IsPaused() {
return ErrDownloadPaused
}
logger.LogError("Download File", fmt.Sprintf("Error executing download command: %v", err)) logger.LogError("Download File", fmt.Sprintf("Error executing download command: %v", err))
return fmt.Errorf("error executing download command: %v", err) return fmt.Errorf("error executing download command: %v", err)
} }
@@ -164,38 +169,48 @@ func downloadFile(drmdFilename string, item Item, jobInfo *JobInfo) error {
return nil return nil
} }
func getDownloadCommand(item Item, mpdPath string, tempDir string) string { func getDownloadArgs(item Item, mpdPath string, tempDir string) ([]string, error) {
metadata := parseMetadata(item.Metadata) metadata := parseMetadata(item.Metadata)
keys := getKeys(item.Keys) keys := getKeys(item.Keys)
command := fmt.Sprintf("%s %s", config.N_m3u8DLRE.Path, mpdPath) args := []string{mpdPath}
for _, key := range keys { for _, key := range keys {
if key != "" { if !decryptionKeyRegex.MatchString(key) {
command += fmt.Sprintf(" --key %s", key) return nil, fmt.Errorf("invalid decryption key format")
} }
args = append(args, "--key", key)
} }
command += " --auto-select" args = append(args,
"--select-video", "best",
"--select-audio", "all",
"--select-subtitle", "all",
)
sanitizedFilename := sanitizeFilename(item.Filename) sanitizedFilename := sanitizeFilename(item.Filename)
args = append(args, "--save-name", sanitizedFilename)
args = append(args, "--mux-after-done", fmt.Sprintf("format=%s", config.General.Format))
filename := fmt.Sprintf("\"%s\"", sanitizedFilename) saveDir := config.General.BaseDir
command += fmt.Sprintf(" --save-name %s", filename)
command += fmt.Sprintf(" --mux-after-done format=%s", config.Format)
saveDir := config.BaseDir
if metadata.Type == "serie" { if metadata.Type == "serie" {
saveDir = filepath.Join(saveDir, "Series", metadata.Title, metadata.Season) saveDir = filepath.Join(saveDir, "Series", metadata.Title, metadata.Season)
} else { } else {
saveDir = filepath.Join(saveDir, "Movies", metadata.Title) saveDir = filepath.Join(saveDir, "Movies", metadata.Title)
} }
command += fmt.Sprintf(" --save-dir \"%s\"", saveDir) if err := os.MkdirAll(saveDir, 0755); err != nil {
return nil, fmt.Errorf("unable to create save directory: %w", err)
}
args = append(args, "--save-dir", saveDir)
args = append(args, "--tmp-dir", tempDir)
command += fmt.Sprintf(" --tmp-dir \"%s\"", tempDir) currentSpeedLimit := getGlobalSpeedLimit()
if currentSpeedLimit != "" {
if !speedLimitRegex.MatchString(currentSpeedLimit) {
return nil, errors.New("invalid speed limit format")
}
args = append(args, "-R", currentSpeedLimit)
}
fmt.Println(command) return args, nil
return command
} }

View File

@@ -1,6 +1,6 @@
module DRMDTool module DRMDTool
go 1.23.0 go 1.25.0
require ( require (
github.com/BurntSushi/toml v1.4.0 github.com/BurntSushi/toml v1.4.0
@@ -8,13 +8,13 @@ require (
github.com/beevik/etree v1.4.1 github.com/beevik/etree v1.4.1
) )
require golang.org/x/sys v0.4.0 // indirect require golang.org/x/sys v0.43.0 // indirect
require ( require (
github.com/asticode/go-astikit v0.20.0 // indirect github.com/asticode/go-astikit v0.20.0 // indirect
github.com/asticode/go-astits v1.8.0 // indirect github.com/asticode/go-astits v1.8.0 // indirect
github.com/fsnotify/fsnotify v1.7.0 github.com/fsnotify/fsnotify v1.7.0
github.com/gorilla/websocket v1.5.3 github.com/gorilla/websocket v1.5.3
golang.org/x/net v0.0.0-20200904194848-62affa334b73 // indirect golang.org/x/net v0.53.0 // indirect
golang.org/x/text v0.3.2 // indirect golang.org/x/text v0.36.0 // indirect
) )

View File

@@ -14,8 +14,6 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/pkg/exec v0.0.0-20150614095509-0bd164ad2a5a h1:EN123kAtAAE2pg/+TvBsUBZfHCWNNFyL2ZBPPfNWAc0=
github.com/pkg/exec v0.0.0-20150614095509-0bd164ad2a5a/go.mod h1:b95YoNrAnScjaWG+asr8lxqlrsPUcT2ZEBcjvVGshMo=
github.com/pkg/profile v1.4.0/go.mod h1:NWz/XGvpEW1FyYQ7fCx4dqYBLlfTcE+A9FLAkNKqjFE= github.com/pkg/profile v1.4.0/go.mod h1:NWz/XGvpEW1FyYQ7fCx4dqYBLlfTcE+A9FLAkNKqjFE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -25,16 +23,18 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20200904194848-62affa334b73 h1:MXfv8rhZWmFeqX3GNZRsd6vOLoaCHjYEX3qkRo3YBUA=
golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=

View File

@@ -1,445 +0,0 @@
package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"github.com/gorilla/websocket"
)
type ProgressInfo struct {
Percentage float64
CurrentFile string
Paused bool
}
func handleRoot(w http.ResponseWriter, r *http.Request) {
progressMutex.Lock()
defer progressMutex.Unlock()
jobsInfo := make(map[string]struct {
Percentage float64
CurrentFile string
Paused bool
})
for filename, info := range progress {
jobsInfo[filename] = struct {
Percentage float64
CurrentFile string
Paused bool
}{
Percentage: info.Percentage,
CurrentFile: info.CurrentFile,
Paused: info.Paused,
}
}
err := templates.ExecuteTemplate(w, "index", struct {
Jobs map[string]struct {
Percentage float64
CurrentFile string
Paused bool
}
}{jobsInfo})
if err != nil {
logger.LogError("Handle Root", fmt.Sprintf("Error executing template: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func handleUpload(w http.ResponseWriter, r *http.Request) {
logger.LogInfo("Handle Upload", "Starting file upload")
err := r.ParseMultipartForm(32 << 20)
if err != nil {
logger.LogError("Handle Upload", fmt.Sprintf("Error parsing multipart form: %v", err))
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
files := r.MultipartForm.File["files"]
if len(files) == 0 {
logger.LogError("Handle Upload", "No files uploaded")
http.Error(w, "No files uploaded", http.StatusBadRequest)
return
}
uploadedFiles := []string{}
for _, fileHeader := range files {
file, err := fileHeader.Open()
if err != nil {
logger.LogError("Handle Upload", fmt.Sprintf("Error opening file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer file.Close()
tempFile, err := os.CreateTemp(uploadDir, fileHeader.Filename)
if err != nil {
logger.LogError("Handle Upload", fmt.Sprintf("Error creating temporary file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer tempFile.Close()
_, err = io.Copy(tempFile, file)
if err != nil {
logger.LogError("Handle Upload", fmt.Sprintf("Error copying file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
uploadedFiles = append(uploadedFiles, filepath.Base(tempFile.Name()))
_, err = parseInputFile(tempFile.Name())
if err != nil {
logger.LogError("Handle Upload", fmt.Sprintf("Error parsing input file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
validFiles := []string{}
for _, file := range uploadedFiles {
if file != "" {
validFiles = append(validFiles, file)
}
}
if len(validFiles) == 0 {
logger.LogError("Handle Upload", "No valid files were uploaded")
http.Error(w, "No valid files were uploaded", http.StatusBadRequest)
return
}
logger.LogInfo("Handle Upload", fmt.Sprintf("Redirecting to select with files: %v", validFiles))
http.Redirect(w, r, "/select?files="+url.QueryEscape(strings.Join(validFiles, ",")), http.StatusSeeOther)
}
func handleSelect(w http.ResponseWriter, r *http.Request) {
filesParam := r.URL.Query().Get("files")
filenames := strings.Split(filesParam, ",")
allItems := make(map[string]map[string][]Item)
for _, filename := range filenames {
if filename == "" {
continue
}
fullPath := filepath.Join(uploadDir, filename)
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
logger.LogError("Handle Select", fmt.Sprintf("File does not exist: %s", fullPath))
continue
}
items, err := parseInputFile(fullPath)
if err != nil {
logger.LogError("Handle Select", fmt.Sprintf("Error parsing input file: %v", err))
continue
}
sortItems(items)
groupedItems := groupItemsBySeason(items)
allItems[filename] = groupedItems
}
if len(allItems) == 0 {
logger.LogError("Handle Select", "No valid files were processed")
http.Error(w, "No valid files were processed", http.StatusBadRequest)
return
}
err := templates.ExecuteTemplate(w, "select", struct {
Filenames string
AllItems map[string]map[string][]Item
}{
Filenames: filesParam,
AllItems: allItems,
})
if err != nil {
logger.LogError("Handle Select", fmt.Sprintf("Error executing template: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func handleProcess(w http.ResponseWriter, r *http.Request) {
logger.LogInfo("Handle Process", "Starting process")
if err := r.ParseForm(); err != nil {
logger.LogError("Handle Process", fmt.Sprintf("Error parsing form: %v", err))
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
selectedItems := r.Form["items"]
if len(selectedItems) == 0 {
logger.LogError("Handle Process", "No items selected")
http.Error(w, "No items selected", http.StatusBadRequest)
return
}
itemsByFile := make(map[string][]string)
for _, item := range selectedItems {
parts := strings.SplitN(item, ":", 2)
if len(parts) != 2 {
logger.LogError("Handle Process", "Invalid item format")
continue
}
filename, itemName := parts[0], parts[1]
itemsByFile[filename] = append(itemsByFile[filename], itemName)
}
for filename, items := range itemsByFile {
logger.LogInfo("Handle Process", fmt.Sprintf("Processing file: %s", filename))
fullPath := filepath.Join(uploadDir, filename)
allItems, err := parseInputFile(fullPath)
if err != nil {
logger.LogError("Handle Process", fmt.Sprintf("Error parsing input file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
selectedItems := filterSelectedItems(allItems, items)
sortItems(selectedItems)
go processItems(filename, selectedItems)
}
http.Redirect(w, r, "/", http.StatusSeeOther)
}
func handleProgress(w http.ResponseWriter, r *http.Request) {
filename := r.URL.Query().Get("filename")
if r.Header.Get("Accept") == "application/json" {
progressInfo := getProgress(filename)
if progressInfo == nil {
w.WriteHeader(http.StatusNotFound)
json.NewEncoder(w).Encode(map[string]string{"error": "No progress information found"})
return
}
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(progressInfo)
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
return
}
err := templates.ExecuteTemplate(w, "progress", struct{ Filename string }{filename})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func handlePause(w http.ResponseWriter, r *http.Request) {
filename := r.URL.Query().Get("filename")
if filename == "" {
logger.LogError("Pause Handler", "Filename is required")
http.Error(w, "Filename is required", http.StatusBadRequest)
return
}
jobsMutex.Lock()
jobInfo, exists := jobs[filename]
jobsMutex.Unlock()
if !exists {
logger.LogError("Pause Handler", "Job not found")
http.Error(w, "Job not found", http.StatusNotFound)
return
}
jobInfo.Paused = true
if jobInfo.Cmd != nil && jobInfo.Cmd.Process != nil {
logger.LogJobState(filename, "pausing")
jobInfo.Cmd.Process.Kill()
}
progressMutex.Lock()
if progressInfo, ok := progress[filename]; ok {
progressInfo.Paused = true
}
progressMutex.Unlock()
fmt.Fprintf(w, "Pause signal sent for %s", filename)
}
func handleResume(w http.ResponseWriter, r *http.Request) {
filename := r.URL.Query().Get("filename")
if filename == "" {
http.Error(w, "Filename is required", http.StatusBadRequest)
return
}
jobsMutex.Lock()
jobInfo, exists := jobs[filename]
jobsMutex.Unlock()
if !exists {
http.Error(w, "Job not found", http.StatusNotFound)
return
}
jobInfo.Paused = false
jobInfo.ResumeChan <- struct{}{}
progressMutex.Lock()
if progressInfo, ok := progress[filename]; ok {
progressInfo.Paused = false
}
progressMutex.Unlock()
fmt.Fprintf(w, "Resume signal sent for %s", filename)
}
func handleAbort(w http.ResponseWriter, r *http.Request) {
filename := r.URL.Query().Get("filename")
if filename == "" {
http.Error(w, "Filename is required", http.StatusBadRequest)
return
}
jobsMutex.Lock()
jobInfo, exists := jobs[filename]
jobsMutex.Unlock()
if !exists {
http.Error(w, "Job not found", http.StatusNotFound)
return
}
close(jobInfo.AbortChan)
if jobInfo.Cmd != nil && jobInfo.Cmd.Process != nil {
jobInfo.Cmd.Process.Kill()
}
if jobInfo.TempDir != "" {
os.RemoveAll(jobInfo.TempDir)
}
fmt.Fprintf(w, "Abort signal sent for %s", filename)
}
func handleClearCompleted(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
clearCompletedJobs()
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
func clearCompletedJobs() {
progressMutex.Lock()
defer progressMutex.Unlock()
for filename, info := range progress {
if info.Percentage >= 100 {
delete(progress, filename)
}
}
}
func updateProgress(filename string, value float64, currentFile string) {
progressMutex.Lock()
defer progressMutex.Unlock()
jobsMutex.Lock()
jobInfo, exists := jobs[filename]
jobsMutex.Unlock()
paused := false
if exists {
paused = jobInfo.Paused
}
if existingProgress, ok := progress[filename]; ok {
existingProgress.Percentage = value
existingProgress.CurrentFile = currentFile
existingProgress.Paused = paused
} else {
progress[filename] = &ProgressInfo{
Percentage: value,
CurrentFile: currentFile,
Paused: paused,
}
}
}
var upgrader = websocket.Upgrader{}
var clients = make(map[string]map[*websocket.Conn]bool)
var mu sync.Mutex
func handleWebSocket(w http.ResponseWriter, r *http.Request) {
fmt.Println(config.EnableConsole)
if !config.EnableConsole {
http.Error(w, "Console output is disabled", http.StatusForbidden)
return
}
filename := r.URL.Query().Get("filename")
if filename == "" {
http.Error(w, "Filename is required", http.StatusBadRequest)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logger.LogError("WebSocket", fmt.Sprintf("Error while upgrading connection: %v", err))
return
}
defer conn.Close()
logger.LogInfo("WebSocket", fmt.Sprintf("WebSocket connection established for filename: %s", filename))
mu.Lock()
if clients[filename] == nil {
clients[filename] = make(map[*websocket.Conn]bool)
}
clients[filename][conn] = true
mu.Unlock()
for {
if _, _, err := conn.NextReader(); err != nil {
break
}
}
mu.Lock()
delete(clients[filename], conn)
mu.Unlock()
logger.LogInfo("WebSocket", fmt.Sprintf("WebSocket connection closed for filename: %s", filename))
}
func broadcast(filename string, message []byte) {
if !config.EnableConsole {
return
}
mu.Lock()
defer mu.Unlock()
for client := range clients[filename] {
if err := client.WriteMessage(websocket.TextMessage, message); err != nil {
client.Close()
delete(clients[filename], client)
logger.LogError("Broadcast", fmt.Sprintf("Error writing message to client: %v", err))
}
}
}

30
src/handlers_common.go Normal file
View File

@@ -0,0 +1,30 @@
package main
import (
"net/url"
"regexp"
"strings"
)
type ProgressInfo struct {
Percentage float64
CurrentFile string
Paused bool
Status string
}
var speedLimitRegex = regexp.MustCompile(`^([1-9]\d*(\.\d+)?)(KBps|MBps|GBps)$`)
func withToken(path string) string {
token := strings.TrimSpace(config.Security.AuthToken)
if token == "" {
return path
}
separator := "?"
if strings.Contains(path, "?") {
separator = "&"
}
return path + separator + "token=" + url.QueryEscape(token)
}

174
src/handlers_jobs.go Normal file
View File

@@ -0,0 +1,174 @@
package main
import (
"encoding/json"
"fmt"
"net/http"
"os"
"strings"
)
func handlePause(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) {
return
}
filename := r.URL.Query().Get("filename")
if filename == "" {
logger.LogError("Pause Handler", "Filename is required")
http.Error(w, "Filename is required", http.StatusBadRequest)
return
}
jobInfo, exists := getJob(filename)
if !exists {
logger.LogError("Pause Handler", "Job not found")
http.Error(w, "Job not found", http.StatusNotFound)
return
}
jobInfo.SetPaused(true)
logger.LogJobState(filename, "pausing")
jobInfo.KillProcess()
paused := true
setProgressStatus(filename, &paused, "paused")
_, _ = fmt.Fprintf(w, "Pause signal sent for %s", filename)
}
func handleResume(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) {
return
}
filename := r.URL.Query().Get("filename")
if filename == "" {
http.Error(w, "Filename is required", http.StatusBadRequest)
return
}
jobInfo, exists := getJob(filename)
if !exists {
http.Error(w, "Job not found", http.StatusNotFound)
return
}
jobInfo.SetPaused(false)
jobInfo.SignalResume()
paused := false
setProgressStatus(filename, &paused, "running")
_, _ = fmt.Fprintf(w, "Resume signal sent for %s", filename)
}
func handleAbort(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) {
return
}
filename := r.URL.Query().Get("filename")
if filename == "" {
http.Error(w, "Filename is required", http.StatusBadRequest)
return
}
jobInfo, exists := getJob(filename)
if !exists {
http.Error(w, "Job not found", http.StatusNotFound)
return
}
jobInfo.Abort()
jobInfo.KillProcess()
if tempDir := jobInfo.GetTempDir(); tempDir != "" {
_ = os.RemoveAll(tempDir)
}
paused := false
setProgressStatus(filename, &paused, "aborted")
_, _ = fmt.Fprintf(w, "Abort signal sent for %s", filename)
}
func handleClearCompleted(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) {
return
}
clearCompletedJobs()
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]bool{"success": true})
}
func clearCompletedJobs() {
progressMutex.Lock()
defer progressMutex.Unlock()
for filename, info := range progress {
if info.Percentage >= 100 || info.Status == "failed" || info.Status == "aborted" {
delete(progress, filename)
}
}
}
func updateProgress(filename string, value float64, currentFile, status string) {
paused := false
if jobInfo, exists := getJob(filename); exists {
paused = jobInfo.IsPaused()
}
progressMutex.Lock()
defer progressMutex.Unlock()
if existingProgress, ok := progress[filename]; ok {
existingProgress.Percentage = value
existingProgress.CurrentFile = currentFile
existingProgress.Paused = paused
if status != "" {
existingProgress.Status = status
}
} else {
progress[filename] = &ProgressInfo{
Percentage: value,
CurrentFile: currentFile,
Paused: paused,
Status: status,
}
}
}
func handleSetSpeedLimit(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) {
return
}
logger.LogInfo("Set Speed Limit", "Received request to set speed limit")
var requestData struct {
SpeedLimit string `json:"speedLimit"`
}
if err := json.NewDecoder(r.Body).Decode(&requestData); err != nil {
logger.LogError("Set Speed Limit", "Invalid request body")
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
requestData.SpeedLimit = strings.TrimSpace(requestData.SpeedLimit)
if requestData.SpeedLimit == "unlimited" {
setGlobalSpeedLimit("")
} else {
if !speedLimitRegex.MatchString(requestData.SpeedLimit) {
http.Error(w, "Invalid speed limit format", http.StatusBadRequest)
return
}
setGlobalSpeedLimit(requestData.SpeedLimit)
}
logger.LogInfo("Set Speed Limit", fmt.Sprintf("Global speed limit set to: %s", getGlobalSpeedLimit()))
w.WriteHeader(http.StatusOK)
}

128
src/handlers_pages.go Normal file
View File

@@ -0,0 +1,128 @@
package main
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"strings"
)
func handleRoot(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodGet) || !ensureAuthorized(w, r) {
return
}
err := templates.ExecuteTemplate(w, "index", struct {
Jobs map[string]ProgressInfo
GlobalSpeedLimit string
AuthToken string
Nonce string
}{
Jobs: snapshotProgress(),
GlobalSpeedLimit: getGlobalSpeedLimit(),
AuthToken: url.QueryEscape(config.Security.AuthToken),
Nonce: cspNonce(r),
})
if err != nil {
logger.LogError("Handle Root", fmt.Sprintf("Error executing template: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func handleSelect(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodGet) || !ensureAuthorized(w, r) {
return
}
filesParam := r.URL.Query().Get("files")
filenames := strings.Split(filesParam, ",")
allItems := make(map[string]map[string][]Item)
for _, filename := range filenames {
if filename == "" {
continue
}
fullPath, pathErr := safeUploadPath(filename)
if pathErr != nil {
logger.LogError("Handle Select", fmt.Sprintf("Invalid filename %s: %v", filename, pathErr))
continue
}
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
logger.LogError("Handle Select", fmt.Sprintf("File does not exist: %s", fullPath))
continue
}
items, err := parseInputFile(fullPath)
if err != nil {
logger.LogError("Handle Select", fmt.Sprintf("Error parsing input file: %v", err))
continue
}
sortItems(items)
groupedItems := groupItemsBySeason(items)
allItems[filename] = groupedItems
}
if len(allItems) == 0 {
logger.LogError("Handle Select", "No valid files were processed")
http.Error(w, "No valid files were processed", http.StatusBadRequest)
return
}
err := templates.ExecuteTemplate(w, "select", struct {
Filenames string
AllItems map[string]map[string][]Item
AuthToken string
Nonce string
}{
Filenames: filesParam,
AllItems: allItems,
AuthToken: url.QueryEscape(config.Security.AuthToken),
Nonce: cspNonce(r),
})
if err != nil {
logger.LogError("Handle Select", fmt.Sprintf("Error executing template: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func handleProgress(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodGet) || !ensureAuthorized(w, r) {
return
}
filename := r.URL.Query().Get("filename")
if r.Header.Get("Accept") == "application/json" {
progressInfo := getProgress(filename)
if progressInfo == nil {
w.WriteHeader(http.StatusNotFound)
_ = json.NewEncoder(w).Encode(map[string]string{"error": "No progress information found"})
return
}
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(progressInfo)
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
return
}
err := templates.ExecuteTemplate(w, "progress", struct {
Filename string
AuthToken string
Nonce string
}{Filename: filename, AuthToken: url.QueryEscape(config.Security.AuthToken), Nonce: cspNonce(r)})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}

View File

@@ -0,0 +1,158 @@
package main
import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
)
func handleUpload(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) {
return
}
logger.LogInfo("Handle Upload", "Starting file upload")
r.Body = http.MaxBytesReader(w, r.Body, maxUploadBytes())
const multipartMemory = 4 << 20
err := r.ParseMultipartForm(multipartMemory)
if err != nil {
logger.LogError("Handle Upload", fmt.Sprintf("Error parsing multipart form: %v", err))
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
files := r.MultipartForm.File["files"]
if len(files) == 0 {
logger.LogError("Handle Upload", "No files uploaded")
http.Error(w, "No files uploaded", http.StatusBadRequest)
return
}
uploadedFiles := []string{}
for _, fileHeader := range files {
if strings.ToLower(filepath.Ext(fileHeader.Filename)) != ".drmd" {
http.Error(w, "Only .drmd files are allowed", http.StatusBadRequest)
return
}
file, err := fileHeader.Open()
if err != nil {
logger.LogError("Handle Upload", fmt.Sprintf("Error opening file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
pattern := sanitizeFilename(fileHeader.Filename) + "_*.drmd"
tempFile, err := os.CreateTemp(uploadDir, pattern)
if err != nil {
_ = file.Close()
logger.LogError("Handle Upload", fmt.Sprintf("Error creating temporary file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_, err = io.Copy(tempFile, file)
_ = file.Close()
if err != nil {
_ = tempFile.Close()
_ = os.Remove(tempFile.Name())
logger.LogError("Handle Upload", fmt.Sprintf("Error copying file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if err := tempFile.Close(); err != nil {
_ = os.Remove(tempFile.Name())
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
uploadedFiles = append(uploadedFiles, filepath.Base(tempFile.Name()))
_, err = parseInputFile(tempFile.Name())
if err != nil {
_ = os.Remove(tempFile.Name())
logger.LogError("Handle Upload", fmt.Sprintf("Error parsing input file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
validFiles := []string{}
for _, file := range uploadedFiles {
if file != "" {
validFiles = append(validFiles, file)
}
}
if len(validFiles) == 0 {
logger.LogError("Handle Upload", "No valid files were uploaded")
http.Error(w, "No valid files were uploaded", http.StatusBadRequest)
return
}
logger.LogInfo("Handle Upload", fmt.Sprintf("Redirecting to select with files: %v", validFiles))
http.Redirect(w, r, withToken("/select?files="+url.QueryEscape(strings.Join(validFiles, ","))), http.StatusSeeOther)
}
func handleProcess(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodPost) || !ensureAuthorized(w, r) {
return
}
logger.LogInfo("Handle Process", "Starting process")
if err := r.ParseForm(); err != nil {
logger.LogError("Handle Process", fmt.Sprintf("Error parsing form: %v", err))
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
selectedItems := r.Form["items"]
if len(selectedItems) == 0 {
logger.LogError("Handle Process", "No items selected")
http.Error(w, "No items selected", http.StatusBadRequest)
return
}
itemsByFile := make(map[string][]string)
for _, item := range selectedItems {
parts := strings.SplitN(item, ":", 2)
if len(parts) != 2 {
logger.LogError("Handle Process", "Invalid item format")
continue
}
filename, itemName := parts[0], parts[1]
itemsByFile[filename] = append(itemsByFile[filename], itemName)
}
for filename, items := range itemsByFile {
logger.LogInfo("Handle Process", fmt.Sprintf("Processing file: %s", filename))
fullPath, pathErr := safeUploadPath(filename)
if pathErr != nil {
http.Error(w, pathErr.Error(), http.StatusBadRequest)
return
}
allItems, err := parseInputFile(fullPath)
if err != nil {
logger.LogError("Handle Process", fmt.Sprintf("Error parsing input file: %v", err))
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
selectedItems := filterSelectedItems(allItems, items)
sortItems(selectedItems)
go func(targetFilename string, targetItems []Item) {
if err := processItems(targetFilename, targetItems); err != nil {
logger.LogError("Handle Process", fmt.Sprintf("Error processing %s: %v", targetFilename, err))
}
}(filename, selectedItems)
}
http.Redirect(w, r, withToken("/"), http.StatusSeeOther)
}

148
src/handlers_ws.go Normal file
View File

@@ -0,0 +1,148 @@
package main
import (
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
)
const (
wsSendBuffer = 64
wsWriteWait = 10 * time.Second
)
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
origin := strings.TrimSpace(r.Header.Get("Origin"))
if origin == "" {
return strings.TrimSpace(config.Security.AuthToken) == ""
}
parsedOrigin, err := url.Parse(origin)
if err != nil {
return false
}
return strings.EqualFold(parsedOrigin.Host, r.Host)
},
}
type wsClient struct {
conn *websocket.Conn
send chan []byte
once sync.Once
}
func (c *wsClient) close() {
c.once.Do(func() {
close(c.send)
})
}
var clients = make(map[string]map[*wsClient]struct{})
var mu sync.Mutex
func handleWebSocket(w http.ResponseWriter, r *http.Request) {
if !ensureMethod(w, r, http.MethodGet) || !ensureAuthorized(w, r) {
return
}
if !config.General.EnableConsole {
http.Error(w, "Console output is disabled", http.StatusForbidden)
return
}
filename := r.URL.Query().Get("filename")
if filename == "" {
http.Error(w, "Filename is required", http.StatusBadRequest)
return
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logger.LogError("WebSocket", fmt.Sprintf("Error while upgrading connection: %v", err))
return
}
logger.LogInfo("WebSocket", fmt.Sprintf("WebSocket connection established for filename: %s", filename))
client := &wsClient{conn: conn, send: make(chan []byte, wsSendBuffer)}
mu.Lock()
if clients[filename] == nil {
clients[filename] = make(map[*wsClient]struct{})
}
clients[filename][client] = struct{}{}
mu.Unlock()
go writePump(filename, client)
for {
if _, _, err := conn.NextReader(); err != nil {
break
}
}
mu.Lock()
if set, ok := clients[filename]; ok {
if _, exists := set[client]; exists {
delete(set, client)
client.close()
}
if len(set) == 0 {
delete(clients, filename)
}
}
mu.Unlock()
_ = conn.Close()
logger.LogInfo("WebSocket", fmt.Sprintf("WebSocket connection closed for filename: %s", filename))
}
func writePump(filename string, client *wsClient) {
defer client.conn.Close()
for message := range client.send {
if err := client.conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
logger.LogError("WebSocket", fmt.Sprintf("SetWriteDeadline error: %v", err))
return
}
if err := client.conn.WriteMessage(websocket.TextMessage, message); err != nil {
logger.LogError("Broadcast", fmt.Sprintf("Error writing to client for %s: %v", filename, err))
return
}
}
}
func broadcast(filename string, message []byte) {
if !config.General.EnableConsole {
return
}
mu.Lock()
set := clients[filename]
targets := make([]*wsClient, 0, len(set))
for c := range set {
targets = append(targets, c)
}
mu.Unlock()
for _, client := range targets {
select {
case client.send <- message:
default:
mu.Lock()
if set, ok := clients[filename]; ok {
if _, exists := set[client]; exists {
delete(set, client)
client.close()
}
}
mu.Unlock()
logger.LogError("Broadcast", fmt.Sprintf("Dropping slow client for %s", filename))
}
}
}

109
src/integration_test.go Normal file
View File

@@ -0,0 +1,109 @@
package main
import (
"net/http"
"net/http/httptest"
"testing"
)
func resetStateForTest() {
jobsMutex.Lock()
jobs = make(map[string]*JobInfo)
jobsMutex.Unlock()
progressMutex.Lock()
progress = make(map[string]*ProgressInfo)
progressMutex.Unlock()
setGlobalSpeedLimit("")
config = Config{}
setDefaultConfigValues()
}
func TestAuthTokenProtection(t *testing.T) {
resetStateForTest()
config.Security.AuthToken = "secret"
handler := newRouter()
req := httptest.NewRequest(http.MethodPost, "/clear-completed", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rr.Code)
}
req = httptest.NewRequest(http.MethodPost, "/clear-completed?token=secret", nil)
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, rr.Code)
}
}
func TestPauseResumeAbortFlow(t *testing.T) {
resetStateForTest()
handler := newRouter()
filename := "job.drmd"
setJob(filename, NewJobInfo())
updateProgress(filename, 10, "episode1", "running")
req := httptest.NewRequest(http.MethodPost, "/pause?filename="+filename, nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("pause expected %d got %d", http.StatusOK, rr.Code)
}
job, ok := getJob(filename)
if !ok || !job.IsPaused() {
t.Fatalf("expected paused job state")
}
progressInfo := getProgress(filename)
if progressInfo == nil || progressInfo.Status != "paused" {
t.Fatalf("expected paused progress state")
}
req = httptest.NewRequest(http.MethodPost, "/resume?filename="+filename, nil)
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("resume expected %d got %d", http.StatusOK, rr.Code)
}
if job.IsPaused() {
t.Fatalf("expected resumed job state")
}
progressInfo = getProgress(filename)
if progressInfo == nil || progressInfo.Status != "running" {
t.Fatalf("expected running progress state")
}
req = httptest.NewRequest(http.MethodPost, "/abort?filename="+filename, nil)
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("abort expected %d got %d", http.StatusOK, rr.Code)
}
if !job.IsAborted() {
t.Fatalf("expected aborted job state")
}
progressInfo = getProgress(filename)
if progressInfo == nil || progressInfo.Status != "aborted" {
t.Fatalf("expected aborted progress state")
}
req = httptest.NewRequest(http.MethodPost, "/abort?filename="+filename, nil)
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("second abort expected %d got %d", http.StatusOK, rr.Code)
}
}

View File

@@ -19,7 +19,7 @@ const (
func NewLogger(prefix string) *Logger { func NewLogger(prefix string) *Logger {
return &Logger{ return &Logger{
Logger: log.New(os.Stdout, prefix, log.Ldate|log.Ltime|log.Lshortfile), Logger: log.New(os.Stdout, prefix, log.Ldate|log.Ltime),
} }
} }

View File

@@ -1,13 +1,17 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"html/template" "html/template"
"net/http" "net/http"
"os" "os"
"os/signal"
"strings" "strings"
"sync" "sync"
"syscall"
"time"
"embed" "embed"
) )
@@ -34,9 +38,6 @@ type Metadata struct {
Season string Season string
} }
var progressMutex sync.Mutex
var progress = make(map[string]*ProgressInfo)
const uploadDir = "uploads" const uploadDir = "uploads"
var templates *template.Template var templates *template.Template
@@ -44,6 +45,21 @@ var templates *template.Template
//go:embed templates //go:embed templates
var templateFS embed.FS var templateFS embed.FS
var globalSpeedLimit string
var globalSpeedLimitMutex sync.RWMutex
func getGlobalSpeedLimit() string {
globalSpeedLimitMutex.RLock()
defer globalSpeedLimitMutex.RUnlock()
return globalSpeedLimit
}
func setGlobalSpeedLimit(value string) {
globalSpeedLimitMutex.Lock()
defer globalSpeedLimitMutex.Unlock()
globalSpeedLimit = value
}
func init() { func init() {
if err := os.MkdirAll(uploadDir, 0755); err != nil { if err := os.MkdirAll(uploadDir, 0755); err != nil {
fmt.Printf("Error creating upload directory: %v\n", err) fmt.Printf("Error creating upload directory: %v\n", err)
@@ -55,10 +71,12 @@ func init() {
} }
func main() { func main() {
loadConfig() configPath := flag.String("config", "config.toml", "Path to config file")
inputFile := flag.String("f", "", "Path to the input JSON file") inputFile := flag.String("f", "", "Path to the input JSON file")
flag.Parse() flag.Parse()
loadConfig(*configPath)
if *inputFile == "" { if *inputFile == "" {
go watchFolder() go watchFolder()
startWebServer() startWebServer()
@@ -68,34 +86,67 @@ func main() {
logger.LogError("Main", fmt.Sprintf("Error parsing input file: %v", err)) logger.LogError("Main", fmt.Sprintf("Error parsing input file: %v", err))
return return
} }
processItems(*inputFile, items) if err := processItems(*inputFile, items); err != nil {
logger.LogError("Main", fmt.Sprintf("Error processing items: %v", err))
}
} }
} }
func startWebServer() { func startWebServer() {
http.HandleFunc("/", handleRoot) server := &http.Server{
http.HandleFunc("/upload", handleUpload) Addr: serverAddr(),
http.HandleFunc("/select", handleSelect) Handler: newRouter(),
http.HandleFunc("/process", handleProcess) ReadTimeout: time.Duration(config.Server.ReadTimeoutSec) * time.Second,
http.HandleFunc("/progress", handleProgress) ReadHeaderTimeout: time.Duration(config.Server.ReadHeaderTimeoutS) * time.Second,
http.HandleFunc("/abort", handleAbort) WriteTimeout: time.Duration(config.Server.WriteTimeoutSec) * time.Second,
http.HandleFunc("/pause", handlePause) IdleTimeout: time.Duration(config.Server.IdleTimeoutSec) * time.Second,
http.HandleFunc("/resume", handleResume) }
http.HandleFunc("/clear-completed", handleClearCompleted)
http.HandleFunc("/ws", handleWebSocket)
fmt.Println("Starting web server on http://0.0.0.0:8080") ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
http.ListenAndServe(":8080", nil) defer stop()
go func() {
<-ctx.Done()
logger.LogInfo("Main", "Shutdown signal received, aborting jobs")
abortAllJobs()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
logger.LogError("Main", fmt.Sprintf("Server shutdown error: %v", err))
}
}()
logger.LogInfo("Main", fmt.Sprintf("Starting web server on http://%s", server.Addr))
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.LogError("Main", fmt.Sprintf("Web server error: %v", err))
}
} }
func getProgress(filename string) *ProgressInfo { func abortAllJobs() {
progressMutex.Lock() jobsMutex.RLock()
defer progressMutex.Unlock() jobList := make([]*JobInfo, 0, len(jobs))
return progress[filename] for _, j := range jobs {
jobList = append(jobList, j)
}
jobsMutex.RUnlock()
for _, j := range jobList {
j.Abort()
j.KillProcess()
}
} }
func getKeys(keys string) []string { func getKeys(keys string) []string {
return strings.Split(keys, ",") parts := strings.Split(keys, ",")
out := parts[:0]
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
out = append(out, p)
}
}
return out
} }
func parseMetadata(metadata string) Metadata { func parseMetadata(metadata string) Metadata {

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"encoding/json" "encoding/json"
"os" "os"
"path/filepath"
"reflect" "reflect"
"testing" "testing"
) )
@@ -123,3 +124,53 @@ func TestGroupItemsBySeason(t *testing.T) {
} }
} }
} }
func TestSafeUploadPath(t *testing.T) {
tests := []struct {
name string
input string
wantError bool
}{
{name: "valid filename", input: "file.drmd", wantError: false},
{name: "directory traversal", input: "../file.drmd", wantError: true},
{name: "empty", input: "", wantError: true},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
path, err := safeUploadPath(tc.input)
if tc.wantError && err == nil {
t.Fatalf("expected error for input %q", tc.input)
}
if !tc.wantError {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if filepath.Base(path) != tc.input {
t.Fatalf("expected basename %q, got %q", tc.input, filepath.Base(path))
}
}
})
}
}
func TestValidateRemoteURL(t *testing.T) {
tests := []struct {
name string
input string
wantError bool
}{
{name: "reject localhost", input: "http://localhost/test.vtt", wantError: true},
{name: "reject invalid scheme", input: "file:///tmp/test", wantError: true},
{name: "reject malformed", input: "::://", wantError: true},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, err := validateRemoteURL(tc.input)
if tc.wantError && err == nil {
t.Fatalf("expected error for %q", tc.input)
}
})
}
}

19
src/router.go Normal file
View File

@@ -0,0 +1,19 @@
package main
import "net/http"
func newRouter() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/", handleRoot)
mux.HandleFunc("/upload", handleUpload)
mux.HandleFunc("/select", handleSelect)
mux.HandleFunc("/process", handleProcess)
mux.HandleFunc("/progress", handleProgress)
mux.HandleFunc("/abort", handleAbort)
mux.HandleFunc("/pause", handlePause)
mux.HandleFunc("/resume", handleResume)
mux.HandleFunc("/clear-completed", handleClearCompleted)
mux.HandleFunc("/ws", handleWebSocket)
mux.HandleFunc("/set-speed-limit", handleSetSpeedLimit)
return withSecurityHeaders(mux)
}

237
src/security.go Normal file
View File

@@ -0,0 +1,237 @@
package main
import (
"context"
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
)
var blockedCIDRs = mustParseCIDRs([]string{
"127.0.0.0/8",
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"169.254.0.0/16",
"100.64.0.0/10",
"0.0.0.0/8",
"::1/128",
"fc00::/7",
"fe80::/10",
})
var secureDialer = &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
var secureTransport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil {
return nil, err
}
for _, ip := range ips {
if isBlockedIP(ip.IP) {
return nil, fmt.Errorf("dial blocked: %s resolves to %s", host, ip.IP)
}
}
if len(ips) == 0 {
return nil, fmt.Errorf("no IPs for %s", host)
}
return secureDialer.DialContext(ctx, network, net.JoinHostPort(ips[0].IP.String(), port))
},
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
}
var secureHTTPClient = &http.Client{
Timeout: 30 * time.Second,
Transport: secureTransport,
CheckRedirect: func(req *http.Request, _ []*http.Request) error {
_, err := validateRemoteURL(req.URL.String())
return err
},
}
const maxRemoteFetchBytes int64 = 20 << 20
func mustParseCIDRs(cidrs []string) []*net.IPNet {
networks := make([]*net.IPNet, 0, len(cidrs))
for _, cidr := range cidrs {
_, network, err := net.ParseCIDR(cidr)
if err != nil {
panic(err)
}
networks = append(networks, network)
}
return networks
}
func maxUploadBytes() int64 {
return int64(config.General.MaxUploadMB) << 20
}
func serverAddr() string {
return fmt.Sprintf("%s:%d", config.Server.Host, config.Server.Port)
}
func ensureMethod(w http.ResponseWriter, r *http.Request, method string) bool {
if r.Method != method {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return false
}
return true
}
func ensureAuthorized(w http.ResponseWriter, r *http.Request) bool {
token := strings.TrimSpace(config.Security.AuthToken)
if token == "" {
return true
}
providedToken := strings.TrimSpace(r.Header.Get("X-DRMD-Token"))
if providedToken == "" {
providedToken = strings.TrimSpace(r.URL.Query().Get("token"))
}
if subtle.ConstantTimeCompare([]byte(providedToken), []byte(token)) != 1 {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return false
}
return true
}
func validateRemoteURL(rawURL string) (*url.URL, error) {
parsedURL, err := url.Parse(strings.TrimSpace(rawURL))
if err != nil {
return nil, fmt.Errorf("invalid URL: %w", err)
}
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return nil, errors.New("only http(s) URLs are allowed")
}
host := parsedURL.Hostname()
if host == "" {
return nil, errors.New("URL host is required")
}
if strings.EqualFold(host, "localhost") {
return nil, errors.New("localhost URLs are not allowed")
}
ips, err := net.LookupIP(host)
if err != nil {
return nil, fmt.Errorf("unable to resolve URL host: %w", err)
}
for _, ip := range ips {
if isBlockedIP(ip) {
return nil, fmt.Errorf("URL host resolves to blocked IP range: %s", ip.String())
}
}
return parsedURL, nil
}
func isBlockedIP(ip net.IP) bool {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() || ip.IsLinkLocalMulticast() || ip.IsLinkLocalUnicast() || ip.IsMulticast() {
return true
}
for _, network := range blockedCIDRs {
if network.Contains(ip) {
return true
}
}
return false
}
func fetchRemoteContent(rawURL string) ([]byte, error) {
validatedURL, err := validateRemoteURL(rawURL)
if err != nil {
return nil, err
}
req, err := http.NewRequest(http.MethodGet, validatedURL.String(), nil)
if err != nil {
return nil, err
}
resp, err := secureHTTPClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
}
if resp.ContentLength > maxRemoteFetchBytes {
return nil, fmt.Errorf("remote content too large: %d bytes", resp.ContentLength)
}
limited := io.LimitReader(resp.Body, maxRemoteFetchBytes+1)
body, err := io.ReadAll(limited)
if err != nil {
return nil, err
}
if int64(len(body)) > maxRemoteFetchBytes {
return nil, fmt.Errorf("remote content exceeded %d bytes", maxRemoteFetchBytes)
}
return body, nil
}
type ctxKey string
const nonceCtxKey ctxKey = "csp-nonce"
func generateNonce() string {
var b [16]byte
if _, err := rand.Read(b[:]); err != nil {
return ""
}
return base64.RawStdEncoding.EncodeToString(b[:])
}
func cspNonce(r *http.Request) string {
if v, ok := r.Context().Value(nonceCtxKey).(string); ok {
return v
}
return ""
}
func withSecurityHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nonce := generateNonce()
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("Referrer-Policy", "no-referrer")
csp := fmt.Sprintf(
"default-src 'self'; connect-src 'self'; img-src 'self' data:; style-src 'self' 'nonce-%s'; script-src 'self' 'nonce-%s'; base-uri 'self'; form-action 'self'; frame-ancestors 'none'",
nonce, nonce,
)
w.Header().Set("Content-Security-Policy", csp)
ctx := context.WithValue(r.Context(), nonceCtxKey, nonce)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

67
src/state.go Normal file
View File

@@ -0,0 +1,67 @@
package main
import "sync"
var (
jobsMutex sync.RWMutex
jobs = make(map[string]*JobInfo)
progressMutex sync.RWMutex
progress = make(map[string]*ProgressInfo)
)
func setJob(filename string, jobInfo *JobInfo) {
jobsMutex.Lock()
defer jobsMutex.Unlock()
jobs[filename] = jobInfo
}
func getJob(filename string) (*JobInfo, bool) {
jobsMutex.RLock()
defer jobsMutex.RUnlock()
jobInfo, ok := jobs[filename]
return jobInfo, ok
}
func removeJob(filename string) {
jobsMutex.Lock()
defer jobsMutex.Unlock()
delete(jobs, filename)
}
func getProgress(filename string) *ProgressInfo {
progressMutex.RLock()
defer progressMutex.RUnlock()
if p, ok := progress[filename]; ok {
snapshot := *p
return &snapshot
}
return nil
}
func snapshotProgress() map[string]ProgressInfo {
progressMutex.RLock()
defer progressMutex.RUnlock()
result := make(map[string]ProgressInfo, len(progress))
for filename, info := range progress {
result[filename] = *info
}
return result
}
func setProgressStatus(filename string, paused *bool, status string) {
progressMutex.Lock()
defer progressMutex.Unlock()
info, ok := progress[filename]
if !ok {
return
}
if paused != nil {
info.Paused = *paused
}
if status != "" {
info.Status = status
}
}

View File

@@ -1,9 +1,9 @@
package main package main
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"net/http"
"os" "os"
"strings" "strings"
@@ -36,12 +36,11 @@ func downloadAndConvertSubtitles(subtitlesURLs string) ([]string, error) {
func downloadSubtitle(url string) (string, error) { func downloadSubtitle(url string) (string, error) {
logger.LogInfo("Download Subtitle", fmt.Sprintf("Starting download from %s", url)) logger.LogInfo("Download Subtitle", fmt.Sprintf("Starting download from %s", url))
resp, err := http.Get(url) body, err := fetchRemoteContent(url)
if err != nil { if err != nil {
logger.LogError("Download Subtitle", fmt.Sprintf("Error getting subtitle URL: %v", err)) logger.LogError("Download Subtitle", fmt.Sprintf("Error getting subtitle URL: %v", err))
return "", err return "", err
} }
defer resp.Body.Close()
tempFile, err := os.CreateTemp("", "subtitle_*.vtt") tempFile, err := os.CreateTemp("", "subtitle_*.vtt")
if err != nil { if err != nil {
@@ -50,7 +49,7 @@ func downloadSubtitle(url string) (string, error) {
} }
defer tempFile.Close() defer tempFile.Close()
_, err = io.Copy(tempFile, resp.Body) _, err = io.Copy(tempFile, bytes.NewReader(body))
if err != nil { if err != nil {
logger.LogError("Download Subtitle", fmt.Sprintf("Error copying to temp file: %v", err)) logger.LogError("Download Subtitle", fmt.Sprintf("Error copying to temp file: %v", err))
return "", err return "", err
@@ -62,8 +61,13 @@ func downloadSubtitle(url string) (string, error) {
func convertVTTtoSRT(vttPath string) (string, error) { func convertVTTtoSRT(vttPath string) (string, error) {
srtPath := strings.TrimSuffix(vttPath, ".vtt") + ".srt" srtPath := strings.TrimSuffix(vttPath, ".vtt") + ".srt"
s1, _ := astisub.OpenFile(vttPath) s1, err := astisub.OpenFile(vttPath)
s1.Write(srtPath) if err != nil {
return "", err
}
if err := s1.Write(srtPath); err != nil {
return "", err
}
logger.LogInfo("Convert VTT to SRT", fmt.Sprintf("Converted %s to %s", vttPath, srtPath)) logger.LogInfo("Convert VTT to SRT", fmt.Sprintf("Converted %s to %s", vttPath, srtPath))
return srtPath, nil return srtPath, nil
} }

View File

@@ -4,7 +4,7 @@
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Simple Downloader</title> <title>Simple Downloader</title>
<style> <style nonce="{{.Nonce}}">
body { body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
background-color: #1e1e1e; background-color: #1e1e1e;
@@ -28,7 +28,6 @@
input[type="file"], input[type="submit"] { input[type="file"], input[type="submit"] {
background-color: #2d2d2d; background-color: #2d2d2d;
color: #d4d4d4; color: #d4d4d4;
border: 1px solid #444;
padding: 8px 12px; padding: 8px 12px;
border-radius: 4px; border-radius: 4px;
margin-bottom: 10px; margin-bottom: 10px;
@@ -77,6 +76,11 @@
.paused { .paused {
color: #ffa500; color: #ffa500;
} }
.speed-limit {
font-size: 1em;
color: #a0a0a0;
margin-top: 10px;
}
@media (max-width: 600px) { @media (max-width: 600px) {
body { body {
padding: 10px; padding: 10px;
@@ -107,11 +111,99 @@
#clear-completed:hover { #clear-completed:hover {
background-color: #d32f2f; background-color: #d32f2f;
} }
/* New CSS for speed limit form */
.settings-section {
margin-top: 30px;
}
.speed-limit-form {
display: flex;
align-items: center;
justify-content: space-between;
gap: 10px;
margin-bottom: 20px;
}
.speed-limit-form .form-group {
display: flex;
align-items: center;
gap: 10px;
}
.speed-limit-form input[type="number"],
.speed-limit-form select,
.speed-limit-form button {
background-color: #2d2d2d;
color: #d4d4d4;
border: 1px solid #444;
padding: 8px 12px;
border-radius: 4px;
}
.speed-limit-form button {
cursor: pointer;
background-color: #4CAF50;
color: white;
}
.speed-limit-form button:hover {
background-color: #45a049;
}
.speed-limit-container {
display: flex;
align-items: center;
margin-bottom: 20px;
background-color: #2d2d2d;
padding: 8px 12px;
border-radius: 4px;
}
.speed-limit-container .form-group {
display: flex;
align-items: center;
gap: 10px;
width: 100%;
}
.speed-limit-container input[type="number"] {
background-color: #2d2d2d;
color: #d4d4d4;
border: 1px solid #444;
padding: 8px 12px;
border-radius: 4px;
height: 40px;
box-sizing: border-box;
flex-grow: 1;
}
.speed-limit-container select,
.speed-limit-container button {
background-color: #2d2d2d;
color: #d4d4d4;
border: 1px solid #444;
padding: 8px 12px;
border-radius: 4px;
height: 40px;
box-sizing: border-box;
}
.speed-limit-container button {
cursor: pointer;
background-color: #4CAF50;
color: white;
}
.speed-limit-container button:hover {
background-color: #45a049;
}
.speed-limit-container .speed-limit {
color: #d4d4d4;
margin-left: auto;
display: flex;
align-items: center;
}
.speed-limit-container .speed-limit span {
margin-left: 5px;
}
.current-speed-limit {
color: #d4d4d4;
margin-top: 10px;
}
</style> </style>
</head> </head>
<body> <body>
<h1>Simple Downloader</h1> <h1>Simple Downloader</h1>
<form action="/upload" method="post" enctype="multipart/form-data"> <form action="{{if .AuthToken}}/upload?token={{.AuthToken}}{{else}}/upload{{end}}" method="post" enctype="multipart/form-data">
<input type="file" name="files" accept=".drmd" multiple> <input type="file" name="files" accept=".drmd" multiple>
<input type="submit" value="Upload and Process"> <input type="submit" value="Upload and Process">
</form> </form>
@@ -120,11 +212,12 @@
{{range $filename, $info := .Jobs}} {{range $filename, $info := .Jobs}}
<li> <li>
<div class="job-title"> <div class="job-title">
<a href="/progress?filename={{$filename}}">{{$filename}}</a> <a href="{{if $.AuthToken}}/progress?filename={{$filename}}&token={{$.AuthToken}}{{else}}/progress?filename={{$filename}}{{end}}">{{$filename}}</a>
</div> </div>
<div class="job-info"> <div class="job-info">
Progress: <span class="progress-text">{{printf "%5.1f%%" $info.Percentage}}</span> Progress: <span class="progress-text">{{printf "%5.1f%%" $info.Percentage}}</span>
Current file: {{$info.CurrentFile}} Current file: {{$info.CurrentFile}}
Status: {{$info.Status}}
{{if $info.Paused}} {{if $info.Paused}}
<span class="paused">(Paused)</span> <span class="paused">(Paused)</span>
{{end}} {{end}}
@@ -134,10 +227,39 @@
<li>No active jobs</li> <li>No active jobs</li>
{{end}} {{end}}
</ul> </ul>
<button id="clear-completed" onclick="clearCompleted()">Clear Completed Jobs</button> <button id="clear-completed">Clear Completed Jobs</button>
<script>
<div class="settings-section">
<h2>Settings</h2>
<div class="speed-limit-container">
<div class="form-group">
<label for="speedLimitValue">Speed Limit:</label>
<input type="number" id="speedLimitValue" name="speedLimitValue" min="0" step="0.01" required>
<select id="speedLimitUnit" name="speedLimitUnit">
<option value="GBps">GBps</option>
<option value="MBps" selected>MBps</option>
<option value="KBps">KBps</option>
</select>
<button id="set-speed-limit" type="button">Set Limit</button>
</div>
</div>
</div>
<script nonce="{{.Nonce}}">
const authToken = "{{.AuthToken}}";
const currentSpeedLimitRaw = "{{if .GlobalSpeedLimit}}{{.GlobalSpeedLimit}}{{else}}0{{end}}";
function withToken(path) {
if (!authToken) {
return path;
}
const separator = path.includes('?') ? '&' : '?';
return `${path}${separator}token=${authToken}`;
}
function clearCompleted() { function clearCompleted() {
fetch('/clear-completed', { method: 'POST' }) fetch(withToken('/clear-completed'), { method: 'POST' })
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
if (data.success) { if (data.success) {
@@ -147,6 +269,55 @@
} }
}); });
} }
function updateSpeedLimit(event) {
event.preventDefault();
const speedLimitValue = document.getElementById('speedLimitValue').value;
const speedLimitUnit = document.getElementById('speedLimitUnit').value;
const speedLimit = speedLimitValue === "0" ? "unlimited" : speedLimitValue + speedLimitUnit;
if (!validateSpeedLimit(speedLimitValue)) {
alert('Please enter a valid speed limit.');
return;
}
fetch(withToken('/set-speed-limit'), {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ speedLimit }),
}).then(response => {
if (response.ok) {
alert('Speed limit updated successfully');
} else {
alert('Failed to update speed limit');
}
});
}
function validateSpeedLimit(value) {
const number = parseFloat(value);
return !isNaN(number) && number >= 0;
}
document.addEventListener('DOMContentLoaded', function() {
const speedLimitValueInput = document.getElementById('speedLimitValue');
const speedLimitUnitSelect = document.getElementById('speedLimitUnit');
const match = currentSpeedLimitRaw.match(/(\d+(\.\d+)?)([A-Za-z]+)/);
if (match) {
speedLimitValueInput.value = match[1];
speedLimitUnitSelect.value = match[3];
} else {
speedLimitValueInput.value = "0";
speedLimitUnitSelect.value = "MBps";
}
document.getElementById('clear-completed').addEventListener('click', clearCompleted);
document.getElementById('set-speed-limit').addEventListener('click', updateSpeedLimit);
});
</script> </script>
</body> </body>
</html> </html>

View File

@@ -4,7 +4,7 @@
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Processing {{.Filename}}</title> <title>Processing {{.Filename}}</title>
<style> <style nonce="{{.Nonce}}">
body { body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
background-color: #1e1e1e; background-color: #1e1e1e;
@@ -136,19 +136,29 @@
<div id="currentFile"></div> <div id="currentFile"></div>
</div> </div>
<div> <div>
<button id="abort-button" onclick="abortDownload()">Abort Download</button> <button id="abort-button">Abort Download</button>
<button id="pause-button" onclick="pauseDownload()">Pause Download</button> <button id="pause-button">Pause Download</button>
<button id="resume-button" onclick="resumeDownload()" style="display: none;">Resume Download</button> <button id="resume-button">Resume Download</button>
<button id="toggle-console">Toggle Console View</button> <button id="toggle-console">Toggle Console View</button>
<button id="back-button" onclick="window.location.href='/'">Back to Index</button> <button id="back-button">Back to Index</button>
</div> </div>
<div style="display: none;" id="console"></div> <div id="console"></div>
<script> <script nonce="{{.Nonce}}">
let isPaused = false; let isPaused = false;
const filename = "{{.Filename}}"; const filename = "{{.Filename}}";
const authToken = "{{.AuthToken}}";
function withToken(path) {
if (!authToken) {
return path;
}
const separator = path.includes('?') ? '&' : '?';
return `${path}${separator}token=${authToken}`;
}
function updateProgress() { function updateProgress() {
fetch(`/progress?filename=${filename}`, { fetch(withToken(`/progress?filename=${encodeURIComponent(filename)}`), {
headers: { headers: {
'Accept': 'application/json' 'Accept': 'application/json'
} }
@@ -180,7 +190,7 @@
} }
function abortDownload() { function abortDownload() {
fetch(`/abort?filename=${filename}`, { method: 'POST' }) fetch(withToken(`/abort?filename=${encodeURIComponent(filename)}`), { method: 'POST' })
.then(response => { .then(response => {
if (response.ok) { if (response.ok) {
console.log('Abort signal sent. The download will stop soon.'); console.log('Abort signal sent. The download will stop soon.');
@@ -191,7 +201,7 @@
} }
function pauseDownload() { function pauseDownload() {
fetch(`/pause?filename=${filename}`, { method: 'POST' }) fetch(withToken(`/pause?filename=${encodeURIComponent(filename)}`), { method: 'POST' })
.then(response => { .then(response => {
if (response.ok) { if (response.ok) {
console.log('Pause signal sent. The download will pause soon.'); console.log('Pause signal sent. The download will pause soon.');
@@ -204,7 +214,7 @@
} }
function resumeDownload() { function resumeDownload() {
fetch(`/resume?filename=${filename}`, { method: 'POST' }) fetch(withToken(`/resume?filename=${encodeURIComponent(filename)}`), { method: 'POST' })
.then(response => { .then(response => {
if (response.ok) { if (response.ok) {
console.log('Resume signal sent. The download will resume soon.'); console.log('Resume signal sent. The download will resume soon.');
@@ -219,7 +229,7 @@
const consoleDiv = document.getElementById('console'); const consoleDiv = document.getElementById('console');
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const ws = new WebSocket(`${protocol}//${window.location.host}/ws?filename=${filename}`); const ws = new WebSocket(`${protocol}//${window.location.host}${withToken(`/ws?filename=${encodeURIComponent(filename)}`)}`);
ws.onmessage = function(event) { ws.onmessage = function(event) {
consoleDiv.textContent += event.data; consoleDiv.textContent += event.data;
@@ -234,13 +244,17 @@
console.error('WebSocket error:', error); console.error('WebSocket error:', error);
}; };
document.getElementById('toggle-console').onclick = function() { document.addEventListener('DOMContentLoaded', function() {
if (consoleDiv.style.display === "none") { document.getElementById('abort-button').addEventListener('click', abortDownload);
consoleDiv.style.display = "block"; document.getElementById('pause-button').addEventListener('click', pauseDownload);
} else { document.getElementById('resume-button').addEventListener('click', resumeDownload);
consoleDiv.style.display = "none"; document.getElementById('back-button').addEventListener('click', function() {
} window.location.href = withToken('/');
}; });
document.getElementById('toggle-console').addEventListener('click', function() {
consoleDiv.style.display = consoleDiv.style.display === "none" ? "block" : "none";
});
});
updateProgress(); updateProgress();
</script> </script>

View File

@@ -4,7 +4,7 @@
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Select Items to Download</title> <title>Select Items to Download</title>
<style> <style nonce="{{.Nonce}}">
body { body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
background-color: #1e1e1e; background-color: #1e1e1e;
@@ -66,14 +66,14 @@
</head> </head>
<body> <body>
<h1>Select Items to Download</h1> <h1>Select Items to Download</h1>
<form action="/process" method="post"> <form action="{{if .AuthToken}}/process?token={{.AuthToken}}{{else}}/process{{end}}" method="post">
<input type="hidden" name="filenames" value="{{.Filenames}}"> <input type="hidden" name="filenames" value="{{.Filenames}}">
{{range $filename, $fileItems := .AllItems}} {{range $filename, $fileItems := .AllItems}}
<h2>{{$filename}}</h2> <h2>{{$filename}}</h2>
{{range $season, $items := $fileItems}} {{range $season, $items := $fileItems}}
<div class="season" id="season-{{$filename}}-{{$season}}"> <div class="season" id="season-{{$filename}}-{{$season}}">
<div class="season-title"> <div class="season-title">
<input type="checkbox" class="season-checkbox" id="season-checkbox-{{$filename}}-{{$season}}" checked onchange="toggleSeason('{{$filename}}-{{$season}}')"> <input type="checkbox" class="season-checkbox" id="season-checkbox-{{$filename}}-{{$season}}" data-season-key="{{$filename}}-{{$season}}" checked>
<label for="season-checkbox-{{$filename}}-{{$season}}">{{$season}}</label> <label for="season-checkbox-{{$filename}}-{{$season}}">{{$season}}</label>
</div> </div>
<div class="season-items"> <div class="season-items">
@@ -90,12 +90,12 @@
{{end}} {{end}}
{{end}} {{end}}
<div> <div>
<button type="button" onclick="selectAll(true)">Select All</button> <button type="button" id="select-all">Select All</button>
<button type="button" onclick="selectAll(false)">Select None</button> <button type="button" id="select-none">Select None</button>
<input type="submit" value="Start Download"> <input type="submit" value="Start Download">
</div> </div>
</form> </form>
<script> <script nonce="{{.Nonce}}">
function selectAll(checked) { function selectAll(checked) {
var checkboxes = document.getElementsByName('items'); var checkboxes = document.getElementsByName('items');
for (var i = 0; i < checkboxes.length; i++) { for (var i = 0; i < checkboxes.length; i++) {
@@ -114,6 +114,17 @@
episodeCheckboxes[i].checked = seasonCheckbox.checked; episodeCheckboxes[i].checked = seasonCheckbox.checked;
} }
} }
document.addEventListener('DOMContentLoaded', function() {
document.getElementById('select-all').addEventListener('click', function() { selectAll(true); });
document.getElementById('select-none').addEventListener('click', function() { selectAll(false); });
var boxes = document.getElementsByClassName('season-checkbox');
for (var i = 0; i < boxes.length; i++) {
boxes[i].addEventListener('change', function() {
toggleSeason(this.dataset.seasonKey);
});
}
});
</script> </script>
</body> </body>
</html> </html>

View File

@@ -2,6 +2,7 @@ package main
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/url" "net/url"
@@ -21,23 +22,131 @@ type JobInfo struct {
AbortChan chan struct{} AbortChan chan struct{}
ResumeChan chan struct{} ResumeChan chan struct{}
Cmd *exec.Cmd Cmd *exec.Cmd
Paused bool
TempDir string TempDir string
mu sync.RWMutex
paused bool
abortOnce sync.Once
}
func NewJobInfo() *JobInfo {
return &JobInfo{
AbortChan: make(chan struct{}),
ResumeChan: make(chan struct{}, 1),
}
}
func (j *JobInfo) SetPaused(value bool) {
j.mu.Lock()
defer j.mu.Unlock()
j.paused = value
}
func (j *JobInfo) IsPaused() bool {
j.mu.RLock()
defer j.mu.RUnlock()
return j.paused
}
func (j *JobInfo) SetCmd(cmd *exec.Cmd) {
j.mu.Lock()
defer j.mu.Unlock()
j.Cmd = cmd
}
func (j *JobInfo) SetTempDir(tempDir string) {
j.mu.Lock()
defer j.mu.Unlock()
j.TempDir = tempDir
}
func (j *JobInfo) GetTempDir() string {
j.mu.RLock()
defer j.mu.RUnlock()
return j.TempDir
}
func (j *JobInfo) KillProcess() {
j.mu.RLock()
cmd := j.Cmd
j.mu.RUnlock()
if cmd != nil && cmd.Process != nil {
_ = cmd.Process.Kill()
}
}
func (j *JobInfo) SignalResume() {
select {
case j.ResumeChan <- struct{}{}:
default:
}
}
func (j *JobInfo) Abort() {
j.abortOnce.Do(func() {
close(j.AbortChan)
})
}
func (j *JobInfo) IsAborted() bool {
select {
case <-j.AbortChan:
return true
default:
return false
}
} }
var ( var (
jobsMutex sync.Mutex ErrDownloadPaused = errors.New("download paused")
jobs = make(map[string]*JobInfo) ErrDownloadAborted = errors.New("download aborted")
) )
var sanitizeFilenameRegex = regexp.MustCompile(`[<>:"/\\|?*\x00-\x1f]`)
var windowsReservedNames = map[string]struct{}{
"CON": {}, "PRN": {}, "AUX": {}, "NUL": {},
"COM1": {}, "COM2": {}, "COM3": {}, "COM4": {}, "COM5": {}, "COM6": {}, "COM7": {}, "COM8": {}, "COM9": {},
"LPT1": {}, "LPT2": {}, "LPT3": {}, "LPT4": {}, "LPT5": {}, "LPT6": {}, "LPT7": {}, "LPT8": {}, "LPT9": {},
}
func sanitizeFilename(filename string) string { func sanitizeFilename(filename string) string {
filename = regexp.MustCompile(`[<>:"/\\|?*]`).ReplaceAllString(filename, "_") filename = sanitizeFilenameRegex.ReplaceAllString(filename, "_")
filename = strings.Trim(filename, ". ")
filename = strings.Trim(filename, ".") base := filename
if idx := strings.LastIndex(filename, "."); idx > 0 {
base = filename[:idx]
}
if _, reserved := windowsReservedNames[strings.ToUpper(base)]; reserved {
filename = "_" + filename
}
if filename == "" {
filename = "_"
}
return filename return filename
} }
func safeUploadPath(filename string) (string, error) {
cleanName := strings.TrimSpace(filename)
if cleanName == "" {
return "", fmt.Errorf("filename is required")
}
baseName := filepath.Base(cleanName)
if baseName != cleanName {
return "", fmt.Errorf("invalid filename")
}
if strings.Contains(baseName, "..") {
return "", fmt.Errorf("invalid filename")
}
return filepath.Join(uploadDir, baseName), nil
}
func isValidURL(toTest string) bool { func isValidURL(toTest string) bool {
_, err := url.ParseRequestURI(toTest) _, err := url.ParseRequestURI(toTest)
return err == nil return err == nil
@@ -185,13 +294,14 @@ func groupItemsBySeason(items []Item) map[string][]Item {
} }
func filterSelectedItems(items []Item, selectedItems []string) []Item { func filterSelectedItems(items []Item, selectedItems []string) []Item {
var filtered []Item set := make(map[string]struct{}, len(selectedItems))
for _, item := range items { for _, s := range selectedItems {
for _, selected := range selectedItems { set[s] = struct{}{}
if item.Filename == selected {
filtered = append(filtered, item)
break
} }
filtered := make([]Item, 0, len(selectedItems))
for _, item := range items {
if _, ok := set[item.Filename]; ok {
filtered = append(filtered, item)
} }
} }
return filtered return filtered
@@ -225,31 +335,27 @@ func extractNumber(s string) int {
return num return num
} }
var episodeNumberRegex = regexp.MustCompile(`(?i)S\d+E(\d+)`)
func extractEpisodeNumber(filename string) int { func extractEpisodeNumber(filename string) int {
parts := strings.Split(filename, "E") match := episodeNumberRegex.FindStringSubmatch(filename)
if len(parts) > 1 { if len(match) < 2 {
num, _ := strconv.Atoi(parts[1])
return num
}
return 0 return 0
}
num, _ := strconv.Atoi(match[1])
return num
} }
func processItems(filename string, items []Item) error { func processItems(filename string, items []Item) error {
jobsMutex.Lock() jobInfo := NewJobInfo()
jobInfo := &JobInfo{ setJob(filename, jobInfo)
AbortChan: make(chan struct{}),
ResumeChan: make(chan struct{}),
}
jobs[filename] = jobInfo
jobsMutex.Unlock()
defer func() { defer func() {
jobsMutex.Lock() removeJob(filename)
delete(jobs, filename)
jobsMutex.Unlock()
if jobInfo.TempDir != "" { tempDir := jobInfo.GetTempDir()
os.RemoveAll(jobInfo.TempDir) if tempDir != "" {
_ = os.RemoveAll(tempDir)
} }
}() }()
@@ -258,34 +364,45 @@ func processItems(filename string, items []Item) error {
for i := 0; i < len(items); i++ { for i := 0; i < len(items); i++ {
select { select {
case <-jobInfo.AbortChan: case <-jobInfo.AbortChan:
updateProgress(filename, 100, "Aborted") updateProgress(filename, 100, "Aborted", "aborted")
logger.LogJobState(filename, "aborted") logger.LogJobState(filename, "aborted")
return fmt.Errorf("download aborted") return ErrDownloadAborted
default: default:
if jobInfo.Paused { if jobInfo.IsPaused() {
select { select {
case <-jobInfo.ResumeChan: case <-jobInfo.ResumeChan:
jobInfo.Paused = false jobInfo.SetPaused(false)
logger.LogJobState(filename, "resumed") logger.LogJobState(filename, "resumed")
case <-jobInfo.AbortChan: case <-jobInfo.AbortChan:
updateProgress(filename, 100, "Aborted") updateProgress(filename, 100, "Aborted", "aborted")
logger.LogJobState(filename, "aborted") logger.LogJobState(filename, "aborted")
return fmt.Errorf("download aborted") return ErrDownloadAborted
} }
} }
updateProgress(filename, float64(i)/float64(len(items))*100, items[i].Filename) updateProgress(filename, float64(i)/float64(len(items))*100, items[i].Filename, "running")
err := downloadFile(filename, items[i], jobInfo) err := downloadFile(filename, items[i], jobInfo)
if err != nil { if err != nil {
if err.Error() == "download paused" { if errors.Is(err, ErrDownloadPaused) {
logger.LogJobState(filename, "paused") logger.LogJobState(filename, "paused")
removeCompletedEpisodes(filename, items[:i]) if remErr := removeCompletedEpisodes(filename, items[:i]); remErr != nil {
logger.LogError("Process Items", fmt.Sprintf("Error updating partial progress file: %v", remErr))
}
i-- i--
continue continue
} }
if errors.Is(err, ErrDownloadAborted) {
updateProgress(filename, 100, "Aborted", "aborted")
logger.LogJobState(filename, "aborted")
return ErrDownloadAborted
}
updateProgress(filename, float64(i)/float64(len(items))*100, items[i].Filename, "failed")
logger.LogError("Process Items", fmt.Sprintf("Error downloading item %s: %v", items[i].Filename, err))
return fmt.Errorf("error downloading %s: %w", items[i].Filename, err)
} }
} }
} }
updateProgress(filename, 100, "") updateProgress(filename, 100, "", "completed")
logger.LogJobState(filename, "completed successfully") logger.LogJobState(filename, "completed successfully")
return nil return nil
} }

View File

@@ -3,19 +3,51 @@ package main
import ( import (
"fmt" "fmt"
"io" "io"
"log"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"time" "time"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
) )
func watchFolder() { func watchFolder() {
if config.WatchFolder.UsePolling {
go pollFolder()
}
if config.WatchFolder.UseInotify {
go inotifyWatch()
}
}
var watchedProcessing = struct {
mu sync.Mutex
files map[string]bool
}{files: make(map[string]bool)}
func beginWatching(filePath string) bool {
watchedProcessing.mu.Lock()
defer watchedProcessing.mu.Unlock()
if watchedProcessing.files[filePath] {
return false
}
watchedProcessing.files[filePath] = true
return true
}
func doneWatching(filePath string) {
watchedProcessing.mu.Lock()
defer watchedProcessing.mu.Unlock()
delete(watchedProcessing.files, filePath)
}
func inotifyWatch() {
watcher, err := fsnotify.NewWatcher() watcher, err := fsnotify.NewWatcher()
if err != nil { if err != nil {
log.Fatal(err) logger.LogError("Watcher", fmt.Sprintf("Failed to create inotify watcher: %v", err))
return
} }
defer watcher.Close() defer watcher.Close()
@@ -30,27 +62,56 @@ func watchFolder() {
} }
if event.Op&fsnotify.Create == fsnotify.Create { if event.Op&fsnotify.Create == fsnotify.Create {
if strings.HasSuffix(event.Name, ".drmd") { if strings.HasSuffix(event.Name, ".drmd") {
fmt.Println("New .drmd detected:", event.Name) logger.LogInfo("Watcher", fmt.Sprintf("New .drmd detected: %s", event.Name))
processWatchedFile(event.Name) go processWatchedFile(event.Name)
} }
} }
case err, ok := <-watcher.Errors: case err, ok := <-watcher.Errors:
if !ok { if !ok {
return return
} }
log.Println("Error:", err) logger.LogError("Watcher", fmt.Sprintf("Inotify error: %v", err))
} }
} }
}() }()
err = watcher.Add(config.WatchedFolder) err = watcher.Add(config.WatchFolder.Path)
if err != nil { if err != nil {
log.Fatal(err) logger.LogError("Watcher", fmt.Sprintf("Failed to add watch folder %s: %v", config.WatchFolder.Path, err))
return
} }
<-done <-done
} }
func pollFolder() {
ticker := time.NewTicker(time.Duration(config.WatchFolder.PollingInterval) * time.Second)
defer ticker.Stop()
for range ticker.C {
files, err := filepath.Glob(filepath.Join(config.WatchFolder.Path, "*.drmd"))
if err != nil {
logger.LogError("Watcher", fmt.Sprintf("Error polling folder: %v", err))
continue
}
for _, file := range files {
logger.LogInfo("Watcher", fmt.Sprintf("New .drmd detected via polling: %s", file))
go processWatchedFile(file)
}
}
}
func processWatchedFile(filePath string) { func processWatchedFile(filePath string) {
if !beginWatching(filePath) {
return
}
releaseDedupe := true
defer func() {
if releaseDedupe {
doneWatching(filePath)
}
}()
for { for {
initialSize, err := getFileSize(filePath) initialSize, err := getFileSize(filePath)
if err != nil { if err != nil {
@@ -84,16 +145,24 @@ func processWatchedFile(filePath string) {
logger.LogError("Watcher", fmt.Sprintf("Error creating temporary file: %v", err)) logger.LogError("Watcher", fmt.Sprintf("Error creating temporary file: %v", err))
return return
} }
defer tempFile.Close()
_, err = io.Copy(tempFile, file) if _, err := io.Copy(tempFile, file); err != nil {
if err != nil { _ = tempFile.Close()
_ = os.Remove(tempFile.Name())
logger.LogError("Watcher", fmt.Sprintf("Error copying file: %v", err)) logger.LogError("Watcher", fmt.Sprintf("Error copying file: %v", err))
return return
} }
if err := tempFile.Close(); err != nil {
_ = os.Remove(tempFile.Name())
logger.LogError("Watcher", fmt.Sprintf("Error closing temp file: %v", err))
return
}
if err := os.Remove(filePath); err != nil { if err := os.Remove(filePath); err != nil {
logger.LogError("Watcher", fmt.Sprintf("Error deleting original file: %v", err)) logger.LogError("Watcher", fmt.Sprintf("Error deleting original file; keeping dedupe entry to avoid reprocessing: %v", err))
_ = os.Remove(tempFile.Name())
releaseDedupe = false
return
} }
items, err := parseInputFile(tempFile.Name()) items, err := parseInputFile(tempFile.Name())
@@ -102,7 +171,11 @@ func processWatchedFile(filePath string) {
return return
} }
go processItems(filepath.Base(tempFile.Name()), items) go func(targetFilename string, targetItems []Item) {
if err := processItems(targetFilename, targetItems); err != nil {
logger.LogError("Watcher", fmt.Sprintf("Error processing watched file %s: %v", targetFilename, err))
}
}(filepath.Base(tempFile.Name()), items)
} }
func getFileSize(filePath string) (int64, error) { func getFileSize(filePath string) (int64, error) {