diff --git a/main.go b/main.go index 0f038c0..86fd737 100644 --- a/main.go +++ b/main.go @@ -813,6 +813,7 @@ func ripAlbum(albumId string, token string, storefront string, mediaUserToken st err := album.GetResp(token, Config.Language) if err != nil { fmt.Println("Failed to get album response.") + return err } meta := album.Resp //debug mode diff --git a/utils/runv3/runv3.go b/utils/runv3/runv3.go index 6a8010b..771d019 100644 --- a/utils/runv3/runv3.go +++ b/utils/runv3/runv3.go @@ -26,7 +26,7 @@ import ( "os/exec" "strings" "sync" - "time" + //"time" "github.com/grafov/m3u8" "github.com/schollz/progressbar/v3" @@ -347,6 +347,95 @@ func Run(adamId string, trackpath string, authtoken string, mutoken string, mvmo return "", nil } +// Segment 结构体用于在 Channel 中传递分段数据 +type Segment struct { + Index int + Data []byte +} + +func downloadSegment(url string, index int, wg *sync.WaitGroup, segmentsChan chan<- Segment, client *http.Client, limiter chan struct{}) { + // 函数退出时,从 limiter 中接收一个值,释放一个并发槽位 + defer func() { + <-limiter + wg.Done() + }() + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + fmt.Printf("错误(分段 %d): 创建请求失败: %v\n", index, err) + return + } + + resp, err := client.Do(req) + if err != nil { + fmt.Printf("错误(分段 %d): 下载失败: %v\n", index, err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + fmt.Printf("错误(分段 %d): 服务器返回状态码 %d\n", index, resp.StatusCode) + return + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Printf("错误(分段 %d): 读取数据失败: %v\n", index, err) + return + } + + // 将下载好的分段(包含序号和数据)发送到 Channel + segmentsChan <- Segment{Index: index, Data: data} +} + +// fileWriter 从 Channel 接收分段并按顺序写入文件 +func fileWriter(wg *sync.WaitGroup, segmentsChan <-chan Segment, outputFile io.Writer, totalSegments int) { + defer wg.Done() + + // 缓冲区,用于存放乱序到达的分段 + // key 是分段序号,value 是分段数据 + segmentBuffer := make(map[int][]byte) + nextIndex := 0 // 期望写入的下一个分段的序号 + + for segment := range segmentsChan { + // 检查收到的分段是否是当前期望的 + if segment.Index == nextIndex { + //fmt.Printf("写入分段 %d\n", segment.Index) + _, err := outputFile.Write(segment.Data) + if err != nil { + fmt.Printf("错误(分段 %d): 写入文件失败: %v\n", segment.Index, err) + } + nextIndex++ + + // 检查缓冲区中是否有下一个连续的分段 + for { + data, ok := segmentBuffer[nextIndex] + if !ok { + break // 缓冲区里没有下一个,跳出循环,等待下一个分段到达 + } + + //fmt.Printf("从缓冲区写入分段 %d\n", nextIndex) + _, err := outputFile.Write(data) + if err != nil { + fmt.Printf("错误(分段 %d): 从缓冲区写入文件失败: %v\n", nextIndex, err) + } + // 从缓冲区删除已写入的分段,释放内存 + delete(segmentBuffer, nextIndex) + nextIndex++ + } + } else { + // 如果不是期望的分段,先存入缓冲区 + //fmt.Printf("缓冲分段 %d (等待 %d)\n", segment.Index, nextIndex) + segmentBuffer[segment.Index] = segment.Data + } + } + + // 确保所有分段都已写入 + if nextIndex != totalSegments { + fmt.Printf("警告: 写入完成,但似乎有分段丢失。期望 %d 个, 实际写入 %d 个。\n", totalSegments, nextIndex) + } +} + func ExtMvData(keyAndUrls string, savePath string) error { segments := strings.Split(keyAndUrls, ";") key := segments[0] @@ -360,95 +449,42 @@ func ExtMvData(keyAndUrls string, savePath string) error { defer os.Remove(tempFile.Name()) defer tempFile.Close() - // 创建上下文用于取消所有下载任务 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + var downloadWg, writerWg sync.WaitGroup + segmentsChan := make(chan Segment, len(urls)) + // --- 新增代码: 定义最大并发数 --- + const maxConcurrency = 10 + // --- 新增代码: 创建带缓冲的 Channel 作为信号量 --- + limiter := make(chan struct{}, maxConcurrency) + client := &http.Client{} // 初始化进度条 bar := progressbar.DefaultBytes(-1, "Downloading...") barWriter := io.MultiWriter(tempFile, bar) - // 预先创建所有管道 - pipeReaders := make([]*io.PipeReader, len(urls)) - pipeWriters := make([]*io.PipeWriter, len(urls)) - for i := range urls { - pr, pw := io.Pipe() - pipeReaders[i] = pr - pipeWriters[i] = pw - } + // 启动写入 Goroutine + writerWg.Add(1) + go fileWriter(&writerWg, segmentsChan, barWriter, len(urls)) - // 控制并发数(使用空结构体节省内存) - sem := make(chan struct{}, 5) - var wg sync.WaitGroup + // 启动下载 Goroutines + for i, url := range urls { + // 在启动 Goroutine 前,向 limiter 发送一个值来“获取”一个槽位 + // 如果 limiter 已满 (达到10个),这里会阻塞,直到有其他任务完成并释放槽位 + //fmt.Printf("请求启动任务 %d...\n", i) + limiter <- struct{}{} + //fmt.Printf("...任务 %d 已启动\n", i) - // 创建带超时的HTTP Client - client := &http.Client{ - Timeout: 30 * time.Second, - } - - // 启动下载任务 - go func() { - for i, url := range urls { - select { - case <-ctx.Done(): - return // 上下文已取消,直接返回 - default: - sem <- struct{}{} // 获取信号量 - wg.Add(1) - - go func(i int, url string, pw *io.PipeWriter) { - defer func() { - <-sem // 释放信号量 - wg.Done() - }() - - // 创建带上下文的请求 - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) - if err != nil { - pw.CloseWithError(err) - fmt.Printf("创建请求失败: %v\n", err) - return - } - - resp, err := client.Do(req) - if err != nil { - pw.CloseWithError(err) - fmt.Printf("下载失败: %v\n", err) - return - } - defer resp.Body.Close() - - // 检查HTTP状态码 - if resp.StatusCode != http.StatusOK { - err := fmt.Errorf("非200状态码: %d", resp.StatusCode) - pw.CloseWithError(err) - fmt.Printf("下载失败: %v\n", err) - return - } - - // 将响应体复制到管道 - if _, err := io.Copy(pw, resp.Body); err != nil { - pw.CloseWithError(err) - } else { - pw.Close() // 正常关闭 - } - }(i, url, pipeWriters[i]) - } - } - }() - - // 按顺序写入文件 - for i := 0; i < len(urls); i++ { - if _, err := io.Copy(barWriter, pipeReaders[i]); err != nil { - cancel() // 取消所有下载任务 - fmt.Printf("写入第 %d 部分失败: %v\n", i+1, err) - return err - } - pipeReaders[i].Close() // 关闭当前读取端 + downloadWg.Add(1) + // 将 limiter 传递给下载函数 + go downloadSegment(url, i, &downloadWg, segmentsChan, client, limiter) } // 等待所有下载任务完成 - wg.Wait() + downloadWg.Wait() + // 下载完成后,关闭 Channel。写入 Goroutine 会在处理完 Channel 中所有数据后退出。 + close(segmentsChan) + + // 等待写入 Goroutine 完成所有写入和缓冲处理 + writerWg.Wait() // 显式关闭文件(defer会再次调用,但重复关闭是安全的) if err := tempFile.Close(); err != nil {