《昇思25天学习打卡营第9天 | mindspore 使用静态图加速》

1. 背景:

使用 mindspore 学习神经网络,打卡第9天;

2. 训练的内容:

mindspore 框架分为两种运行模式,分别是动态图模式(PYNATIVE_MODE)以及静态图模式(GRAPH_MODE); 默认使用动态图模式; 但是,也提供静态图模式;

使用静态图模式时,编译器可以针对图进行全局的优化,获得较好的性能,因此比较适合网络固定且需要高性能的场景

3. 常见的用法小节:

静态图模式与动态图模式切换:

  • 使用 set_context 进行运行全局静态图模式与动态图模式
  • 使用 jit 装饰器将函数使用局部的静态图模式

3.1 全局静态图模式-动态图模式切换:

  • 设置全局动态图模式 ms.set_context(mode=ms.PYNATIVE_MODE)
# AI编译框架分为两种运行模式,分别是动态图模式以及静态图模式。
# MindSpore默认情况下是以动态图模式运行,但也支持手工切换为静态图模式
import time
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)  # 使用set_context进行动态图模式的配置

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

begin_time = time.time()
model = Network()
input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))
output = model(input)
print(output)
print(f'PYNATIVE_MODE cost time(ms): {(time.time() - begin_time) * 1000} ')

运行时间大约:

PYNATIVE_MODE cost time(ms): 56.162357330322266
  • 设置全局静态图模式:
    ms.set_context(mode=ms.GRAPH_MODE)
# 使用set_context进行运行静态图模式的配置
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

ms.set_context(mode=ms.GRAPH_MODE)

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits


begin_time = time.time()
model = Network()
input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))
output = model(input)
print(f' ---- GRAPH_MODE cost time(ms): {(time.time() - begin_time) * 1000} ')

begin_time = time.time()
output = model(input)
print(f' ---- GRAPH_MODE cost time(ms): {(time.time() - begin_time) * 1000} ')

运行时间大约:

 ---- GRAPH_MODE cost time(ms): 8907.071828842163 
 ---- GRAPH_MODE cost time(ms): 4.278421401977539 

第一次包含编译优化,运行时间长;
第二次是已经编译优化过,运行时间短

3.2 jit 装饰器将函数使用局部的静态图模式

@jit 装饰器修饰函数或者类的成员方法,所修饰的函数或方法将会被编译成静态计算图

@ms.jit  # 使用ms.jit装饰器,使被装饰的函数以静态图模式运行
def run(x):
    model = Network()
    return model(x)

begin_time = time.time()
output = run(input)
print(output)
print(f' ---- GRAPH_MODE jit cost time(ms): {(time.time() - begin_time) * 1000} ')

3.3 注意事项:

  • 静态图模式下,只支持 Python 部分语法,需要特别注意;
  • 静态图模式下,首次使用对象时需要编译图优化,因此,首次使用时间较长,后续使用时性能好。

4. 相关链接:

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-15 16:44:02       67 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-15 16:44:02       72 阅读
  3. 在Django里面运行非项目文件

    2024-07-15 16:44:02       58 阅读
  4. Python语言-面向对象

    2024-07-15 16:44:02       69 阅读

热门阅读

  1. Eureka是什么?

    2024-07-15 16:44:02       22 阅读
  2. 享元模式(大话设计模式)C/C++版本

    2024-07-15 16:44:02       19 阅读
  3. html 关闭信息窗口

    2024-07-15 16:44:02       22 阅读
  4. vue3+springboot+minio,实现文件上传功能

    2024-07-15 16:44:02       20 阅读
  5. 使用Python进行桌面应用程序开发

    2024-07-15 16:44:02       16 阅读
  6. 启动 zabbix 相关服务

    2024-07-15 16:44:02       19 阅读
  7. 【AI应用探讨】—KAN应用场景

    2024-07-15 16:44:02       23 阅读
  8. 【无标题】

    2024-07-15 16:44:02       19 阅读
  9. 租用海外服务器需要考虑哪些因素

    2024-07-15 16:44:02       18 阅读
  10. 1448. 统计二叉树中好节点的数目

    2024-07-15 16:44:02       21 阅读