简介
接上篇《Golang实现一个批量自动化执行树莓派指令的软件(2)指令》，这次实现文件的下载功能。
环境描述
运行环境：Windows，基于 Golang。暂时没有使用不可跨平台的接口，理论上也支持 Linux/macOS。
目标终端：树莓派 Debian OS（主要用它做测试）
实现
接口定义
// IDownloader is the download API: fetch a remote file or directory to a local path.
type IDownloader interface {
/*
Download is the synchronous download entry point; it blocks until the download ends.
from : remote path to download
to   : local path to save into
*/
Download(from, to string) error
/*
DownloadWithCallback is the synchronous/asynchronous download entry point.
from : remote path to download
to   : local path to save into
processCallback : progress callback, invoked each time a file has been downloaded with the current progress:
    from           : remote path of the file just downloaded
    to             : local path the file was saved to
    downloadNumber : total number of files to download
    downloaded     : number of files downloaded so far
finishedCallback : invoked when the download finishes, with the resulting error (nil on success)
background : selects synchronous (false) or asynchronous (true) execution
*/
DownloadWithCallback(from, to string,
processCallback func(from, to string, downloadNumber, downloaded uint),
finishedCallback func(err error), background bool) error
}
接口实现
package sshutil
import (
"fmt"
"github.com/pkg/sftp"
"os"
"path"
"time"
)
// downloader implements downloads of remote files and directory trees over an
// SFTP client, with optional progress/finish callbacks and cancellation.
//
// NOTE(review): started and the counters are read and written without any
// synchronization, while a download may run on a background goroutine and
// Cancel may be called from another one; consider atomics or a mutex.
type downloader struct {
client *sftp.Client
downloadSize uint // total size in bytes of the regular files in the current download
downNumber uint // number of regular files the current download will fetch
downloaded uint // number of files fetched so far, reported to processCallback
started bool // true while download() runs; Cancel is a no-op otherwise
canceled chan struct{} // Cancel sends on it, downloadFile polls it after each file, Destroy closes it
}
// newDownloader builds a downloader on top of an existing SFTP client and
// allocates its (unbuffered) cancel channel. It currently never fails; the
// error result is kept for API stability.
func newDownloader(client *sftp.Client) (*downloader, error) {
	d := new(downloader)
	d.client = client
	d.canceled = make(chan struct{})
	return d, nil
}
// Download fetches from (a remote file or directory) into the local path to
// and blocks until it is done. No progress or finish callbacks are invoked.
func (d *downloader) Download(from, to string) error {
	const runInBackground = false
	return d.DownloadWithCallback(from, to, nil, nil, runInBackground)
}
// DownloadWithCallback fetches from into to, reporting progress through
// processCallback and completion through finishedCallback.
//
// With background == false it blocks and returns the download error. With
// background == true it starts the download on a new goroutine and returns
// nil immediately; the outcome is then only available via finishedCallback.
func (d *downloader) DownloadWithCallback(from, to string,
	processCallback func(from, to string, downloadNumber, downloaded uint),
	finishedCallback func(err error), background bool) error {
	if background {
		go d.download(from, to, processCallback, finishedCallback)
		return nil
	}
	return d.download(from, to, processCallback, finishedCallback)
}
// Cancel asks the download in progress to stop; it is a no-op when no
// download is running. The request is only picked up between files (after a
// file has been written), so Cancel waits at most 2 seconds for the running
// download to receive it and returns an error if it does not.
//
// NOTE(review): d.started is read here without synchronization while
// download() may write it from another goroutine.
func (d *downloader) Cancel() error {
	if !d.started {
		return nil
	}
	timeout := time.After(2 * time.Second)
	select {
	case d.canceled <- struct{}{}:
		return nil
	case <-timeout: // cancelling took too long: report the cancel as failed
		return fmt.Errorf("time out waiting for cancel")
	}
}
// Destroy cancels any running download (see Cancel) and then closes the
// cancel channel, returning Cancel's error.
//
// Call it at most once: closing the channel a second time panics. After
// Destroy, a still-running download sees the closed channel after its next
// file and stops with "user canceled".
func (d *downloader) Destroy() error {
	cancelErr := d.Cancel()
	close(d.canceled)
	return cancelErr
}
// downloadFolderCount walks the remote directory remotePath recursively and
// returns the number of regular (non-directory) entries and the sum of their
// sizes in bytes.
//
// Any ReadDir failure, including one in a nested directory, is returned
// instead of being ignored. downloadFolder would fail on the same directory
// anyway, so failing here aborts before any file is transferred. On error the
// counts are zero and must not be used.
func (d *downloader) downloadFolderCount(remotePath string) (needDownload, size uint, err error) {
	infos, err := d.client.ReadDir(remotePath)
	if nil != err {
		return 0, 0, err
	}
	for _, info := range infos {
		if !info.IsDir() {
			needDownload++
			size += uint(info.Size())
			continue
		}
		c, s, subErr := d.downloadFolderCount(path.Join(remotePath, info.Name()))
		if nil != subErr {
			return 0, 0, subErr
		}
		needDownload += c
		size += s
	}
	return needDownload, size, nil
}
// downloadFileCount reports how many files a download of remotePath will
// fetch and their total size in bytes. For a directory this is delegated to
// downloadFolderCount; for a single file it is (1, file size). A Stat failure
// is returned with zero counts.
func (d *downloader) downloadFileCount(remotePath string) (needDownload, size uint, err error) {
	stat, statErr := d.client.Stat(remotePath)
	switch {
	case statErr != nil:
		return 0, 0, statErr
	case stat.IsDir():
		return d.downloadFolderCount(remotePath)
	default:
		return 1, uint(stat.Size()), nil
	}
}
// download is the shared implementation behind Download and
// DownloadWithCallback. remotePath may be a file or a directory. localPath is
// always treated as a local directory and created if missing. A file is saved
// as localPath/<remote base name>, and a directory's contents are copied
// directly into localPath.
//
// The progress counters are reset on every call, so a reused downloader
// reports progress per download rather than cumulatively. finishedCallback,
// when non-nil, is invoked once on its own goroutine with the final result.
func (d *downloader) download(remotePath, localPath string,
	processCallback func(from, to string, downloadNumber, downloaded uint),
	finishedCallback func(err error)) (err error) {
	whenErrorCall := func(e error) error {
		if nil != finishedCallback {
			go finishedCallback(e)
		}
		return e
	}
	d.started = true
	defer func() {
		d.started = false
	}()
	// Without this reset a second download on the same downloader would
	// report downloaded > downloadNumber to processCallback.
	d.downloaded = 0
	d.downNumber, d.downloadSize, err = d.downloadFileCount(remotePath)
	if nil != err {
		return whenErrorCall(err)
	}
	if err = os.MkdirAll(localPath, 0777); nil != err && !os.IsExist(err) {
		return whenErrorCall(err)
	}
	info, err := d.client.Stat(remotePath)
	if nil != err {
		return whenErrorCall(err)
	}
	if info.IsDir() {
		return d.downloadFolder(remotePath, localPath, processCallback, finishedCallback)
	}
	return d.downloadFile(remotePath, localPath, processCallback, finishedCallback)
}
// downloadFile downloads the single remote file remotePath into the local
// directory localPath, keeping the remote base name.
//
// After the file is written it checks for a pending Cancel request, then
// increments d.downloaded and reports progress through processCallback (run
// on its own goroutine). This happens for empty files too, so the progress
// total matches downloadFolderCount, which counts every regular file.
// finishedCallback, when non-nil, is invoked once on its own goroutine with
// the final result (nil on success).
func (d *downloader) downloadFile(remotePath, localPath string,
	processCallback func(from, to string, downloadNumber, downloaded uint),
	finishedCallback func(err error)) (err error) {
	localFileName := path.Join(localPath, path.Base(remotePath))
	whenErrorCall := func(e error) error {
		if nil != finishedCallback {
			go finishedCallback(e)
		}
		return e
	}

	info, err := d.client.Stat(remotePath)
	if nil != err {
		return whenErrorCall(err)
	}
	if err = d.fetchFile(remotePath, localFileName, info.Size()); nil != err {
		return whenErrorCall(err)
	}
	// Cancel requests are only observed between files.
	select {
	case <-d.canceled:
		return whenErrorCall(fmt.Errorf("user canceled"))
	default:
	}
	d.downloaded++
	if nil != processCallback {
		go processCallback(remotePath, localFileName, d.downNumber, d.downloaded)
	}
	return whenErrorCall(nil)
}

// fetchFile copies remotePath to localFileName, creating or truncating the
// local file. size is the remote file size. For size <= 0 the remote file is
// not opened at all, because sftp Open was observed to block forever on 0 KB
// files; only an empty local file is created. The local file's Close error is
// returned, because a failed Close can mean the data never reached disk.
func (d *downloader) fetchFile(remotePath, localFileName string, size int64) (err error) {
	var srcFile *sftp.File
	if size > 0 {
		if srcFile, err = d.client.Open(remotePath); err != nil {
			return err
		}
		defer srcFile.Close()
	}
	dstFile, err := os.OpenFile(localFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0777)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := dstFile.Close(); err == nil {
			err = closeErr
		}
	}()
	if srcFile != nil {
		_, err = srcFile.WriteTo(dstFile)
	}
	return err
}
// downloadFolder recursively copies the contents of the remote directory
// remotePath into the local directory localPath, creating local directories
// as needed. It stops at the first error.
//
// Nested folders and files are downloaded with a nil finishedCallback. This
// call's finishedCallback, when non-nil, is invoked once on its own goroutine
// with the final result, for every failure path as well as success. Async
// callers waiting on it would otherwise hang.
func (d *downloader) downloadFolder(remotePath, localPath string,
	processCallback func(from, to string, downloadNumber, downloaded uint),
	finishedCallback func(err error)) (err error) {
	whenErrorCall := func(e error) error {
		if nil != finishedCallback {
			go finishedCallback(e)
		}
		return e
	}
	if err = os.MkdirAll(localPath, 0777); nil != err {
		return whenErrorCall(err)
	}
	// Check the error explicitly: ReadDir may return partial entries together
	// with an error, which must not be silently dropped.
	infos, err := d.client.ReadDir(remotePath)
	if nil != err {
		return whenErrorCall(err)
	}
	for _, info := range infos {
		remoteFilePath := path.Join(remotePath, info.Name())
		if info.IsDir() {
			err = d.downloadFolder(remoteFilePath, path.Join(localPath, info.Name()), processCallback, nil)
		} else {
			err = d.downloadFile(remoteFilePath, localPath, processCallback, nil)
		}
		if nil != err {
			return whenErrorCall(err)
		}
	}
	return whenErrorCall(nil)
}
测试用例
package sshutil
import (
"fmt"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
"sync"
"testing"
"time"
)
// downloaderTest bundles the SSH and SFTP connections used by the tests with
// the downloader under test, so that destroy can release them together.
type downloaderTest struct {
sshClient *ssh.Client
sftpClient *sftp.Client
downloader *downloader
}
// newDownloadTest connects to the test Raspberry Pi over SSH with hard-coded
// credentials, opens an SFTP session on that connection and wraps it in a
// downloader. On failure, partially opened connections are closed.
func newDownloadTest() (*downloaderTest, error) {
	config := &ssh.ClientConfig{
		User:            "pi",                                      // user name
		Auth:            []ssh.AuthMethod{ssh.Password("a123456")}, // password
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         10 * time.Second,
	}
	sshClient, err := ssh.Dial("tcp", "192.168.3.2:22", config) // IP + port
	if err != nil {
		fmt.Print(err)
		return nil, err
	}
	dTest := &downloaderTest{sshClient: sshClient}
	dTest.sftpClient, err = sftp.NewClient(sshClient)
	if err != nil {
		dTest.destroy()
		return nil, err
	}
	dTest.downloader, err = newDownloader(dTest.sftpClient)
	return dTest, err
}
// destroy closes the SFTP session and then the underlying SSH connection,
// clearing both fields. It is safe to call more than once.
func (d *downloaderTest) destroy() {
	if c := d.sftpClient; c != nil {
		c.Close()
	}
	d.sftpClient = nil
	if c := d.sshClient; c != nil {
		c.Close()
	}
	d.sshClient = nil
}
// TestDownloader_Download synchronously downloads /home/pi/ from the test host
// into ./download. The test is skipped when the host cannot be reached. It
// previously passed silently in that case, and also passed when the download
// itself failed.
func TestDownloader_Download(t *testing.T) {
	dTest, err := newDownloadTest()
	if nil != err {
		t.Skipf("fail to new download test: %v", err)
	}
	defer dTest.destroy()
	if err = dTest.downloader.Download("/home/pi/", "./download"); nil != err {
		t.Error(err)
	}
}
// TestDownloader_DownloadWithCallback downloads /home/pi/ into ./download1 in
// blocking mode, printing each progress report. The test is skipped when the
// host cannot be reached, and fails when the download returns an error.
func TestDownloader_DownloadWithCallback(t *testing.T) {
	dTest, err := newDownloadTest()
	if nil != err {
		t.Skipf("fail to new download test: %v", err)
	}
	defer dTest.destroy()
	err = dTest.downloader.DownloadWithCallback("/home/pi/", "./download1", func(from, to string, downloadNumber, downloaded uint) {
		fmt.Println(from, to, downloadNumber, downloaded)
	}, func(err error) {
		// Runs on its own goroutine and may outlive the test, so it only
		// prints instead of calling t.*.
		fmt.Println("finished!!!", err)
	}, false)
	if nil != err {
		t.Error(err)
	}
	fmt.Println("sleping...")
	// The callbacks are started as goroutines inside download; give their
	// prints time to run before the connection is torn down.
	time.Sleep(time.Second * 1)
}
// TestDownloader_DownloadWithCallbackAsync downloads /home/pi/ into
// ./download2/ in background mode. It waits for finishedCallback and fails if
// that callback reports an error. The test is skipped when the host cannot be
// reached.
func TestDownloader_DownloadWithCallbackAsync(t *testing.T) {
	dTest, err := newDownloadTest()
	if nil != err {
		t.Skipf("fail to new download test: %v", err)
	}
	defer dTest.destroy()
	var (
		waiter    sync.WaitGroup
		finishErr error
	)
	waiter.Add(1)
	err = dTest.downloader.DownloadWithCallback("/home/pi/", "./download2/", func(from, to string, downloadNumber, downloaded uint) {
		fmt.Println(from, to, downloadNumber, downloaded)
	}, func(err error) {
		fmt.Println("finished!!!")
		// Written before Done, so it is safely visible after Wait returns.
		finishErr = err
		waiter.Done()
	}, true)
	if nil != err {
		t.Error(err)
	}
	fmt.Println("waiting....")
	waiter.Wait()
	fmt.Println("done!!!")
	if nil != finishErr {
		t.Error(finishErr)
	}
	// Progress callbacks are separate goroutines; give their prints time to run.
	time.Sleep(time.Second * 1)
}
源码地址
https://gitee.com/grayhsu/ssh_remote_access