diff --git a/downloaders.go b/downloaders.go index 0aa2284..d92b1c6 100644 --- a/downloaders.go +++ b/downloaders.go @@ -5,9 +5,11 @@ import ( "encoding/json" "fmt" "io" + "net/http" "os" "os/exec" "path/filepath" + "strings" ) func processInputFile(inputFile string) error { @@ -45,7 +47,6 @@ func downloadFile(item Item) error { mpdPath := item.MPD if !isValidURL(item.MPD) { - decodedMPD, err := base64.StdEncoding.DecodeString(item.MPD) if err != nil { return fmt.Errorf("error decoding base64 MPD: %v", err) @@ -64,6 +65,37 @@ func downloadFile(item Item) error { return fmt.Errorf("error closing temporary MPD file: %v", err) } + mpdPath = tempFile.Name() + } else if strings.HasPrefix(item.MPD, "https://pubads.g.doubleclick.net") { + resp, err := http.Get(item.MPD) + if err != nil { + return fmt.Errorf("error downloading MPD: %v", err) + } + defer resp.Body.Close() + + mpdContent, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error reading MPD content: %v", err) + } + + fixedMPDContent, err := fixGoPlay(string(mpdContent)) + if err != nil { + return fmt.Errorf("error fixing MPD content: %v", err) + } + + tempFile, err := os.CreateTemp("", "fixed_mpd_*.mpd") + if err != nil { + return fmt.Errorf("error creating temporary MPD file: %v", err) + } + defer os.Remove(tempFile.Name()) + + if _, err := tempFile.WriteString(fixedMPDContent); err != nil { + return fmt.Errorf("error writing to temporary MPD file: %v", err) + } + if err := tempFile.Close(); err != nil { + return fmt.Errorf("error closing temporary MPD file: %v", err) + } + mpdPath = tempFile.Name() } diff --git a/go.mod b/go.mod index 28053d4..bc9873f 100644 --- a/go.mod +++ b/go.mod @@ -2,11 +2,14 @@ module DRMDTool go 1.23.0 -require github.com/BurntSushi/toml v1.4.0 +require ( + github.com/BurntSushi/toml v1.4.0 + github.com/asticode/go-astisub v0.26.2 + github.com/beevik/etree v1.4.1 +) require ( github.com/asticode/go-astikit v0.20.0 // indirect - github.com/asticode/go-astisub v0.26.2 // indirect github.com/asticode/go-astits v1.8.0 // indirect golang.org/x/net v0.0.0-20200904194848-62affa334b73 // indirect golang.org/x/text v0.3.2 // indirect diff --git a/go.sum b/go.sum index b0cd2c0..92cc1b6 100644 --- a/go.sum +++ b/go.sum @@ -6,10 +6,15 @@ github.com/asticode/go-astisub v0.26.2 h1:cdEXcm+SUSmYCEPTQYbbfCECnmQoIFfH6pF8wD github.com/asticode/go-astisub v0.26.2/go.mod h1:WTkuSzFB+Bp7wezuSf2Oxulj5A8zu2zLRVFf6bIFQK8= github.com/asticode/go-astits v1.8.0 h1:rf6aiiGn/QhlFjNON1n5plqF3Fs025XLUwiQ0NB6oZg= github.com/asticode/go-astits v1.8.0/go.mod h1:DkOWmBNQpnr9mv24KfZjq4JawCFX1FCqjLVGvO0DygQ= +github.com/beevik/etree v1.4.1 h1:PmQJDDYahBGNKDcpdX8uPy1xRCwoCGVUiW669MEirVI= +github.com/beevik/etree v1.4.1/go.mod h1:gPNJNaBGVZ9AwsidazFZyygnd+0pAU38N4D+WemwKNs= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pkg/profile v1.4.0/go.mod h1:NWz/XGvpEW1FyYQ7fCx4dqYBLlfTcE+A9FLAkNKqjFE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -24,4 +29,5 @@ golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/utils.go b/utils.go index f6c0c80..321a237 100644 --- a/utils.go +++ b/utils.go @@ -1,9 +1,13 @@ package main import ( + "fmt" "net/url" "regexp" + "strconv" "strings" + + "github.com/beevik/etree" ) func sanitizeFilename(filename string) string { @@ -18,3 +22,99 @@ func isValidURL(toTest string) bool { _, err := url.ParseRequestURI(toTest) return err == nil } + +func fixGoPlay(mpdContent string) (string, error) { + doc := etree.NewDocument() + if err := doc.ReadFromString(mpdContent); err != nil { + return "", fmt.Errorf("error parsing MPD content: %v", err) + } + + root := doc.Root() + + // Remove ad periods + for _, period := range root.SelectElements("Period") { + if strings.Contains(period.SelectAttrValue("id", ""), "-ad-") { + root.RemoveChild(period) + } + } + + // Find highest bandwidth for video + highestBandwidth := 0 + for _, adaptationSet := range root.FindElements("//AdaptationSet") { + if strings.Contains(adaptationSet.SelectAttrValue("mimeType", ""), "video") { + for _, representation := range adaptationSet.SelectElements("Representation") { + bandwidth, _ := strconv.Atoi(representation.SelectAttrValue("bandwidth", "0")) + if bandwidth > highestBandwidth { + highestBandwidth = bandwidth + } + } + } + } + + // Remove lower bitrate representations + for _, adaptationSet := range root.FindElements("//AdaptationSet") { + if strings.Contains(adaptationSet.SelectAttrValue("mimeType", ""), "video") { + for _, representation := range adaptationSet.SelectElements("Representation") { + bandwidth, _ := strconv.Atoi(representation.SelectAttrValue("bandwidth", "0")) + if bandwidth != highestBandwidth { + adaptationSet.RemoveChild(representation) + } + } + } + } + + // Combine periods + periods := root.SelectElements("Period") + if len(periods) > 1 { + firstPeriod := periods[0] + var newVideoTimeline, newAudioTimeline *etree.Element + + // Find or create SegmentTimeline elements + for _, adaptationSet := range firstPeriod.SelectElements("AdaptationSet") { + mimeType := adaptationSet.SelectAttrValue("mimeType", "") + if strings.Contains(mimeType, "video") && newVideoTimeline == nil { + newVideoTimeline = findOrCreateSegmentTimeline(adaptationSet) + } else if strings.Contains(mimeType, "audio") && newAudioTimeline == nil { + newAudioTimeline = findOrCreateSegmentTimeline(adaptationSet) + } + } + + for _, period := range periods[1:] { + for _, adaptationSet := range period.SelectElements("AdaptationSet") { + mimeType := adaptationSet.SelectAttrValue("mimeType", "") + var timeline *etree.Element + if strings.Contains(mimeType, "video") { + timeline = newVideoTimeline + } else if strings.Contains(mimeType, "audio") { + timeline = newAudioTimeline + } + + if timeline != nil { + segmentTimeline := findOrCreateSegmentTimeline(adaptationSet) + for _, s := range segmentTimeline.SelectElements("S") { + timeline.AddChild(s.Copy()) + } + } + } + root.RemoveChild(period) + } + } + + return doc.WriteToString() +} + +func findOrCreateSegmentTimeline(adaptationSet *etree.Element) *etree.Element { + for _, representation := range adaptationSet.SelectElements("Representation") { + for _, segmentTemplate := range representation.SelectElements("SegmentTemplate") { + timeline := segmentTemplate.SelectElement("SegmentTimeline") + if timeline != nil { + return timeline + } + } + } + + // If no SegmentTimeline found, create one + representation := adaptationSet.CreateElement("Representation") + segmentTemplate := representation.CreateElement("SegmentTemplate") + return segmentTemplate.CreateElement("SegmentTimeline") +}