diff --git a/main.go b/main.go index 193a1d2..f1f9df5 100644 --- a/main.go +++ b/main.go @@ -24,8 +24,13 @@ func main() { } handleGracefulShutdown() - if !downloader.FileExists(appSettings.DataFile) { - downloader.DownloadFile(appSettings.DataFile, appSettings.DataURL) + + ctx := downloader.NewContext(appSettings.DataFile, appSettings.DataURL) + if !ctx.FileExists() { + if err := ctx.Download(); err != nil { + panic(err) + } + fmt.Printf("Saved file to %s\n", ctx.Filename) } fmt.Printf("Import data from file...\r") diff --git a/pkg/downloader/downloader.go b/pkg/downloader/downloader.go index 62f2127..62dd3f3 100644 --- a/pkg/downloader/downloader.go +++ b/pkg/downloader/downloader.go @@ -11,34 +11,44 @@ import ( "github.com/schollz/progressbar/v3" ) -func handleError(err error) { - fmt.Println(err.Error()) - os.Exit(0) +type Context struct { + Filename string + Link string } -func DownloadFile(path, url string) { - resp, err := req.Head(url, nil) +func NewContext(filename, link string) *Context { + return &Context{ + Filename: filename, + Link: link, + } +} + +func (c *Context) Download() error { + resp, err := req.Head(c.Link, nil) if err != nil { if strings.Contains(err.Error(), "no such host") { - handleError(errors.New("invalid url")) - } else if strings.Contains(err.Error(), "certificate signed by unknown authority") { - handleError(errors.New("certificate from unknown authority")) + return fmt.Errorf("invalid url") } - handleError(err) + if strings.Contains(err.Error(), "certificate signed by unknown authority") { + return fmt.Errorf("certificate from unknown authority") + } + return err } defer resp.Close() if resp.StatusCode == 404 { - handleError(errors.New("invalid url")) - } else if resp.StatusCode == 401 { - handleError(errors.New("restricted access (credentials required)")) - } else if resp.StatusCode != 200 { - handleError(errors.New(resp.RawResponse.Status)) + return fmt.Errorf("invalid url") + } + if resp.StatusCode == 401 { + return fmt.Errorf("restricted access (credentials required)") + } + if resp.StatusCode != 200 { + return fmt.Errorf(resp.RawResponse.Status) } resp.Close() - resp, err = req.Get(url, nil) + resp, err = req.Get(c.Link, nil) if err != nil { - handleError(err) + return err } var filesize int64 @@ -46,9 +56,9 @@ func DownloadFile(path, url string) { filesize = resp.RawResponse.ContentLength } - destFile, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0644) + destFile, err := os.OpenFile(c.Filename, os.O_CREATE|os.O_WRONLY, 0644) if err != nil { - handleError(err) + return err } defer destFile.Close() @@ -59,13 +69,13 @@ func DownloadFile(path, url string) { _, err = io.Copy(io.MultiWriter(destFile, bar), resp.RawResponse.Body) if err != nil { - handleError(err) + return err } - fmt.Println("Saved file to", path) + return nil } -func FileExists(path string) bool { - _, err := os.Stat(path) +func (c *Context) FileExists() bool { + _, err := os.Stat(c.Filename) return !errors.Is(err, os.ErrNotExist) }