Golang实现一个批量自动化执行树莓派指令的软件(3)下载

简介

话接上篇 Golang实现一个批量自动化执行树莓派指令的软件(2)指令, 这次实现文件的下载。

环境描述

运行环境: Windows, 基于Golang, 暂时没有使用什么不可跨平台接口, 理论上支持Linux/MacOS
目标终端: 树莓派 DebianOS(主要用它做测试)

实现

接口定义

// IDownloader copies a remote file or directory tree to the local machine.
type IDownloader interface {
	// Download is the synchronous download; it blocks until the download ends.
	//   from: remote path (file or directory) to download
	//   to:   local directory to save into
	Download(from, to string) error

	// DownloadWithCallback downloads synchronously or asynchronously and reports
	// progress through callbacks.
	//   from:             remote path (file or directory) to download
	//   to:               local directory to save into
	//   processCallback:  called each time a file has been downloaded, with
	//                       from:           remote path of that file
	//                       to:             local path it was saved to
	//                       downloadNumber: total number of files to download
	//                       downloaded:     number of files downloaded so far
	//   finishedCallback: called when the download ends, with its error (nil on success)
	//   background:       false runs synchronously, true runs asynchronously
	//
	// In this package's implementation both callbacks are invoked on new
	// goroutines, so they may run after the call has returned.
	DownloadWithCallback(from, to string,
		processCallback func(from, to string, downloadNumber, downloaded uint),
		finishedCallback func(err error), background bool) error
}

接口实现

package sshutil

import (
	"fmt"
	"github.com/pkg/sftp"
	"os"
	"path"
	"time"
)

// downloader downloads files and directory trees over an SFTP client. Its
// Download and DownloadWithCallback methods match the IDownloader interface.
//
// NOTE(review): the progress counters and started are plain fields with no
// locking, while Cancel may be called from another goroutine during a
// background download; running two downloads at once on one downloader, or
// calling Cancel concurrently, is a data race.
type downloader struct {
	client       *sftp.Client // SFTP session used for every remote Stat/ReadDir/Open
	downloadSize uint         // total bytes of the current download; assigned by download, not otherwise read in this file
	downNumber   uint         // total number of files in the current download
	downloaded   uint         // number of files completed so far in the current download

	started  bool          // true while download is running; Cancel only signals when set
	canceled chan struct{} // unbuffered; Cancel sends, downloadFile polls after each file, Destroy closes
}

// newDownloader returns a downloader that uses client for all remote access,
// with an unbuffered cancellation channel. The error result is always nil in
// this implementation.
func newDownloader(client *sftp.Client) (*downloader, error) {
	d := new(downloader)
	d.client = client
	d.canceled = make(chan struct{})
	return d, nil
}

// Download fetches from into the local directory to and blocks until it is
// done. It is DownloadWithCallback run in the foreground with no callbacks.
func (d *downloader) Download(from, to string) error {
	const foreground = false
	return d.DownloadWithCallback(from, to, nil, nil, foreground)
}

// DownloadWithCallback fetches from into the local directory to, reporting
// progress through processCallback and the final result through
// finishedCallback. With background set, the download runs on a new goroutine
// and this method returns nil immediately; errors then reach the caller only
// through finishedCallback.
func (d *downloader) DownloadWithCallback(from, to string,
	processCallback func(from, to string, downloadNumber, downloaded uint),
	finishedCallback func(err error), background bool) error {
	if background {
		go d.download(from, to, processCallback, finishedCallback)
		return nil
	}
	return d.download(from, to, processCallback, finishedCallback)
}

// Cancel asks a running download to stop. When no download is running it
// returns nil at once. Otherwise it offers a signal on d.canceled. downloadFile
// only polls that channel after finishing a file, so if nobody takes the
// signal within 2 seconds, Cancel gives up and returns an error.
//
// NOTE(review): d.started is read here without synchronization while download
// writes it on another goroutine in background mode.
func (d *downloader) Cancel() error {
	if !d.started {
		return nil
	}
	deadline := time.After(time.Second * 2)
	select {
	case d.canceled <- struct{}{}:
		return nil
	case <-deadline: // cancellation took too long, so it failed
		return fmt.Errorf("time out waiting for cancel")
	}
}

// Destroy cancels any running download and then closes the cancellation
// channel, returning Cancel's error.
//
// NOTE(review): calling Destroy twice panics (close of a closed channel). So
// does calling Cancel afterwards while a download is running (send on a closed
// channel).
func (d *downloader) Destroy() error {
	cancelErr := d.Cancel()
	close(d.canceled)
	return cancelErr
}

// downloadFolderCount walks remotePath recursively and returns the number of
// non-directory entries under it and the sum of their sizes in bytes.
// A ReadDir failure at any level is returned rather than being counted as an
// empty directory; downloadFolder would fail on that same directory anyway.
func (d *downloader) downloadFolderCount(remotePath string) (needDownload, size uint, err error) {
	infos, err := d.client.ReadDir(remotePath)
	if nil != err {
		return 0, 0, err
	}
	for _, info := range infos {
		if !info.IsDir() {
			needDownload++
			size += uint(info.Size())
			continue
		}
		subCount, subSize, subErr := d.downloadFolderCount(path.Join(remotePath, info.Name()))
		if nil != subErr {
			return 0, 0, subErr
		}
		needDownload += subCount
		size += subSize
	}
	return needDownload, size, nil
}

// downloadFileCount reports how many files a download of remotePath involves
// and their total size in bytes: (1, file size) for a file, or the recursive
// totals from downloadFolderCount for a directory.
func (d *downloader) downloadFileCount(remotePath string) (needDownload, size uint, err error) {
	stat, statErr := d.client.Stat(remotePath)
	switch {
	case nil != statErr:
		return 0, 0, statErr
	case stat.IsDir():
		return d.downloadFolderCount(remotePath)
	default:
		return 1, uint(stat.Size()), nil
	}
}

// download is the shared implementation behind Download and
// DownloadWithCallback. It counts the files under remotePath for progress
// reporting, makes sure the local directory localPath exists, then hands off
// to downloadFolder (for a remote directory) or downloadFile (for a file).
// Failures detected here are passed to finishedCallback, if set, on a new
// goroutine, and are also returned.
func (d *downloader) download(remotePath, localPath string,
	processCallback func(from, to string, downloadNumber, downloaded uint),
	finishedCallback func(err error)) (err error) {

	whenErrorCall := func(e error) error {
		if nil != finishedCallback {
			go finishedCallback(e)
		}
		return e
	}

	// NOTE(review): Cancel reads started from other goroutines without synchronization.
	d.started = true
	defer func() {
		d.started = false
	}()

	// Progress is per call. Without this reset, a second download on the same
	// downloader would report "downloaded" values continuing from the previous run.
	d.downloaded = 0
	d.downNumber, d.downloadSize, err = d.downloadFileCount(remotePath)
	if nil != err {
		return whenErrorCall(err)
	}

	// MkdirAll returns nil for an existing directory; IsExist only tolerates the
	// rare "already exists" error it can still report.
	if err = os.MkdirAll(localPath, 0777); nil != err && !os.IsExist(err) {
		return whenErrorCall(err)
	}

	info, err := d.client.Stat(remotePath)
	if nil != err {
		return whenErrorCall(err)
	}
	if info.IsDir() {
		return d.downloadFolder(remotePath, localPath, processCallback, finishedCallback)
	}
	return d.downloadFile(remotePath, localPath, processCallback, finishedCallback)
}

// downloadFile downloads the single remote file remotePath into the local
// directory localPath, keeping its base name. After the copy it checks for a
// pending Cancel, bumps the progress counter, and reports progress via
// processCallback. This now happens for empty files too, so "downloaded" can
// reach the total. The result (nil on success) goes to finishedCallback, if
// set, on a new goroutine, and is also returned.
func (d *downloader) downloadFile(remotePath, localPath string,
	processCallback func(from, to string, downloadNumber, downloaded uint),
	finishedCallback func(err error)) (err error) {
	localFileName := path.Join(localPath, path.Base(remotePath))

	whenErrorCall := func(e error) error {
		if nil != finishedCallback {
			go finishedCallback(e)
		}
		return e
	}

	info, err := d.client.Stat(remotePath)
	if nil != err {
		return whenErrorCall(err)
	}

	if err = d.fetchRemoteFile(remotePath, localFileName, info.Size()); nil != err {
		return whenErrorCall(err)
	}

	// Cancellation is only observed between files: this is where Cancel's send is received.
	select {
	case <-d.canceled:
		return whenErrorCall(fmt.Errorf("user canceled"))
	default:
	}

	d.downloaded++
	if nil != processCallback {
		// Arguments are evaluated now, so the goroutine sees this file's counter values.
		go processCallback(remotePath, localFileName, d.downNumber, d.downloaded)
	}
	return whenErrorCall(nil)
}

// fetchRemoteFile copies the remote file remotePath (size bytes long) to
// localFileName, creating or truncating the local file. It returns any error
// from opening, copying, or closing the local file.
func (d *downloader) fetchRemoteFile(remotePath, localFileName string, size int64) (err error) {
	// Empty files are never opened remotely: the original author observed that
	// sftp Open blocks on 0 KB files, so only the empty local file is created.
	var srcFile *sftp.File
	if size > 0 {
		if srcFile, err = d.client.Open(remotePath); nil != err {
			return err
		}
		defer srcFile.Close()
	}

	dstFile, err := os.OpenFile(localFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0777)
	if nil != err {
		return err
	}
	defer func() {
		// A failed Close can mean the data never fully reached disk; report it
		// unless an earlier error is already being returned.
		if closeErr := dstFile.Close(); nil == err {
			err = closeErr
		}
	}()

	if nil != srcFile {
		_, err = srcFile.WriteTo(dstFile)
	}
	return err
}

// downloadFolder recursively downloads the remote directory remotePath into
// the local directory localPath, creating local directories as needed. It stops
// at the first failure. Children are downloaded with a nil finishedCallback, so
// this level is the one that passes the outcome (nil on success) to
// finishedCallback, if set, on a new goroutine. That outcome is also returned.
func (d *downloader) downloadFolder(remotePath, localPath string,
	processCallback func(from, to string, downloadNumber, downloaded uint),
	finishedCallback func(err error)) (err error) {

	whenErrorCall := func(e error) error {
		if nil != finishedCallback {
			go finishedCallback(e)
		}
		return e
	}

	if err = os.MkdirAll(localPath, 0777); nil != err {
		return whenErrorCall(err)
	}

	infos, err := d.client.ReadDir(remotePath)
	if nil != err {
		return whenErrorCall(err)
	}
	for _, info := range infos {
		remoteFilePath := path.Join(remotePath, info.Name())
		if info.IsDir() {
			err = d.downloadFolder(remoteFilePath, path.Join(localPath, info.Name()), processCallback, nil)
		} else {
			err = d.downloadFile(remoteFilePath, localPath, processCallback, nil)
		}
		if nil != err {
			// Report file failures (including "user canceled") as well as
			// directory failures; otherwise an async caller is never notified.
			return whenErrorCall(err)
		}
	}

	return whenErrorCall(nil)
}

测试用例

package sshutil

import (
	"fmt"
	"os"
	"sync"
	"testing"
	"time"

	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
)

// downloaderTest bundles the SSH connection, the SFTP session opened over it,
// and the downloader under test; destroy closes the clients.
type downloaderTest struct {
	sshClient  *ssh.Client  // underlying SSH connection
	sftpClient *sftp.Client // SFTP session on sshClient, handed to the downloader

	downloader *downloader // implementation under test
}

// newDownloadTest connects to the test Raspberry Pi over SSH, opens an SFTP
// session, and builds a downloader on it. The address, user, and password
// default to the author's device. They can be overridden with the
// SSHUTIL_TEST_ADDR, SSHUTIL_TEST_USER, and SSHUTIL_TEST_PASSWORD environment
// variables. On any failure the clients opened so far are closed and the
// error is returned.
func newDownloadTest() (*downloaderTest, error) {
	getenv := func(key, fallback string) string {
		if v := os.Getenv(key); v != "" {
			return v
		}
		return fallback
	}

	config := ssh.ClientConfig{
		User:            getenv("SSHUTIL_TEST_USER", "pi"),
		Auth:            []ssh.AuthMethod{ssh.Password(getenv("SSHUTIL_TEST_PASSWORD", "a123456"))},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(), // test only: accepts any host key
		Timeout:         10 * time.Second,
	}

	var err error
	dTest := &downloaderTest{}
	dTest.sshClient, err = ssh.Dial("tcp", getenv("SSHUTIL_TEST_ADDR", "192.168.3.2:22"), &config)
	if err != nil {
		return nil, fmt.Errorf("dialing ssh: %w", err)
	}
	if dTest.sftpClient, err = sftp.NewClient(dTest.sshClient); err != nil {
		dTest.destroy()
		return nil, fmt.Errorf("starting sftp session: %w", err)
	}
	if dTest.downloader, err = newDownloader(dTest.sftpClient); err != nil {
		dTest.destroy()
		return nil, err
	}
	return dTest, nil
}

// destroy closes the SFTP session first and then the SSH connection it runs
// over. Each field is cleared, so calling destroy again does nothing.
func (d *downloaderTest) destroy() {
	if c := d.sftpClient; c != nil {
		c.Close()
		d.sftpClient = nil
	}
	if c := d.sshClient; c != nil {
		c.Close()
		d.sshClient = nil
	}
}

// TestDownloader_Download synchronously downloads /home/pi/ into ./download.
// It is skipped, rather than silently passing, when the test device is unreachable.
func TestDownloader_Download(t *testing.T) {
	dTest, err := newDownloadTest()
	if nil != err {
		t.Skipf("fail to new download test: %v", err)
	}
	defer dTest.destroy()

	if err = dTest.downloader.Download("/home/pi/", "./download"); nil != err {
		t.Error(err)
	}
}

// TestDownloader_DownloadWithCallback runs a foreground download of /home/pi/
// into ./download1 with progress and completion callbacks.
func TestDownloader_DownloadWithCallback(t *testing.T) {
	dTest, err := newDownloadTest()
	if nil != err {
		t.Skipf("fail to new download test: %v", err)
	}
	defer dTest.destroy()

	// The callbacks print with fmt rather than t.Log: they run on their own
	// goroutines and may fire after the test returns, where t.Log would panic.
	err = dTest.downloader.DownloadWithCallback("/home/pi/", "./download1", func(from, to string, downloadNumber, downloaded uint) {
		fmt.Println(from, to, downloadNumber, downloaded)
	}, func(err error) {
		fmt.Println("finished!!!")
	}, false)
	if nil != err {
		t.Error(err)
	}

	fmt.Println("sleping...")
	// The downloader invokes the callbacks asynchronously, so give their prints time to run.
	time.Sleep(time.Second * 1)
}

// TestDownloader_DownloadWithCallbackAsync runs a background download of
// /home/pi/ into ./download2/. It waits for finishedCallback and fails if that
// callback reports an error.
func TestDownloader_DownloadWithCallbackAsync(t *testing.T) {
	dTest, err := newDownloadTest()
	if nil != err {
		t.Skipf("fail to new download test: %v", err)
	}
	defer dTest.destroy()

	var (
		waiter    sync.WaitGroup
		finishErr error // written in finishedCallback; visible after waiter.Wait
	)
	waiter.Add(1)
	err = dTest.downloader.DownloadWithCallback("/home/pi/", "./download2/", func(from, to string, downloadNumber, downloaded uint) {
		fmt.Println(from, to, downloadNumber, downloaded)
	}, func(err error) {
		finishErr = err
		fmt.Println("finished!!!")
		waiter.Done()
	}, true)
	if nil != err {
		t.Fatal(err)
	}

	fmt.Println("waiting....")
	waiter.Wait()
	fmt.Println("done!!!")
	if nil != finishErr {
		t.Error(finishErr)
	}
	// processCallback also runs asynchronously; give its prints time to run.
	time.Sleep(time.Second * 1)
}

代码源

https://gitee.com/grayhsu/ssh_remote_access

其他

参考

最近更新

  1. TCP协议是安全的吗?

    2024-04-26 11:02:05       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-04-26 11:02:05       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-04-26 11:02:05       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-04-26 11:02:05       20 阅读

热门阅读

  1. npm详解

    2024-04-26 11:02:05       14 阅读
  2. npm/yarm常用命令

    2024-04-26 11:02:05       12 阅读
  3. 企业网络安全的全方位解决方案

    2024-04-26 11:02:05       12 阅读
  4. 大数据任务运维方案

    2024-04-26 11:02:05       12 阅读
  5. 【13】编写shell-备份mysql数据

    2024-04-26 11:02:05       12 阅读