package downloader import ( "crypto/md5" "crypto/sha1" "crypto/sha256" "errors" "fmt" "hash" "io" "os" "strings" "time" "github.com/dustin/go-humanize" req "github.com/levigross/grequests" ) var ( hashes = map[string]hash.Hash{ "md5": md5.New(), //nolint: gosec "sha1": sha1.New(), //nolint: gosec "sha256": sha256.New(), } ) type Context struct { Filename string Filesize int64 Reader io.Reader Chan chan *Response Closer []func() error hashWriter hash.Hash } func NewContext(filename string, filesize int64, hw hash.Hash, closer ...func() error) *Context { return &Context{ Filename: filename, Filesize: filesize, Chan: make(chan *Response), Closer: closer, hashWriter: hw, } } func (ctx *Context) Checksum() string { return fmt.Sprintf("%x", ctx.hashWriter.Sum(nil)) } func (ctx *Context) SetReader(rdr io.Reader) { ctx.Reader = io.TeeReader( rdr, io.MultiWriter( ctx.hashWriter, &ProgressWriter{ Filename: ctx.Filename, Filesize: ctx.Filesize, Channel: ctx.Chan, Start: time.Now(), }, ), ) } func (ctx *Context) Close() { for i := len(ctx.Closer) - 1; i >= 0; i-- { ctx.Closer[i]() } } type ProgressWriter struct { Channel chan *Response Start time.Time Filename string Filesize int64 Bytes int64 } func (pw *ProgressWriter) Write(data []byte) (int, error) { pw.Bytes = pw.Bytes + int64(len(data)) var progress float64 if pw.Filesize > 0 { progress = float64(pw.Bytes) / float64(pw.Filesize) * 100 } pw.Channel <- &Response{ Progress: progress, BytesPerSecond: float64(pw.Bytes) / time.Since(pw.Start).Seconds(), BytesTransfered: pw.Bytes, BytesTotal: pw.Filesize, } return len(data), nil } type File struct { URL string Filename string Size int64 } type Response struct { Error error Progress float64 BytesPerSecond float64 BytesTransfered int64 BytesTotal int64 } func handleError(err error, adds ...string) { msg := err.Error() if len(adds) != 0 { msg = fmt.Sprintf("%s %s", msg, strings.Join(adds, " ")) } fmt.Println(msg) os.Exit(0) } func printProgress(ctx *Context) { for { resp, ok := <-ctx.Chan if !ok { fmt.Println("\nSaved file to", ctx.Filename) for key, value := range hashes { fmt.Printf("%s: %x\n", key, value.Sum(nil)) } break } if resp.Error != nil { handleError(resp.Error) } if resp.BytesTotal > 0 { fmt.Printf("\rDownloading %7s/%7s %.02f%%", humanize.Bytes(uint64(resp.BytesTransfered)), humanize.Bytes(uint64(resp.BytesTotal)), resp.Progress) } else { fmt.Printf("\rDownloading %7s", humanize.Bytes(uint64(resp.BytesTransfered))) } } } func DownloadFile(path, url string) { file := File{ URL: url, Filename: path, } response, err := req.Head(file.URL, nil) if err != nil { if strings.Contains(err.Error(), "no such host") { handleError(errors.New("invalid url")) } else if strings.Contains(err.Error(), "certificate signed by unknown authority") { handleError(errors.New("certificate signed by unknown authority")) } handleError(err) } defer response.Close() if response.StatusCode == 404 { handleError(errors.New("invalid url")) } if response.StatusCode == 401 { handleError(errors.New("restriced access, credentials required")) } if response.StatusCode != 200 { handleError(errors.New(response.RawResponse.Status)) } response.Close() resp, err := req.Get(file.URL, nil) if err != nil { handleError(err) } var filesize int64 if resp.RawResponse.ContentLength != -1 { filesize = resp.RawResponse.ContentLength } ctx := NewContext(file.Filename, filesize, md5.New()) ctx.SetReader(resp.RawResponse.Body) go func(ctx *Context) { defer ctx.Close() dstFile, err := os.Create(ctx.Filename) if err != nil { ctx.Chan <- &Response{Error: err} close(ctx.Chan) return } defer dstFile.Close() writer := io.MultiWriter(dstFile, hashes["md5"], hashes["sha1"], hashes["sha256"]) _, err = io.Copy(writer, ctx.Reader) if err != nil { ctx.Chan <- &Response{Error: err} } close(ctx.Chan) }(ctx) printProgress(ctx) } func FileExists(path string) bool { _, err := os.Stat(path) return !errors.Is(err, os.ErrNotExist) }