tensorflow入门 自定义层

前面讲了自定义损失函数,自定义正则化,自定义评价函数。现在来讲自定义层,其实都差不多,继承重要的组件就可以了。自定义层就是基层keras.layers.Layer

class MyLayer(keras.layers.Layer):
    def __init__(self, units, activation = None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = keras.activations.get(activation)
        
    def build(self, batch_input_shape):
        self.kernel = self.add_weight(name = 'kernel', shape = [batch_input_shape[-1], self.units], initializer='glorot_normal')
        self.bais = self.add_weight(name = 'bias', shape = self.units, initializer='zeros')
        super().build(batch_input_shape)
        
    def call(self, x):
        return self.activation(X @ self.kernel + self.bias)
        
    def compute_output_shape(self, batch_input_shape):
        return tf.TensorShape(batch_input_shape.as_list()[:-1] + [self.units])
    
    def get_config(self):
        base_config = super().get_config()
        return {**base_config, 'units':self.units, 'activation': keras.activations.serialize(self.activation)}

还是老样子,继承keras已有的组件layers,然后实现几个必要的函数。

(1).构造函数将所有超参数用作参数,kwargs负责把所有的默认参数传递给父类,比如input_shape,trainable,name.

(2).build方法的作用是通过为每个权重调用add_weight()方法来创建层的变量,keras会自动推测输入的维度,也即是batch_input_shape

(3).call方法是每次计算矩阵相乘的时候,被自动调用的方法。

(4).compute_output_shape返回输出的维度,这个函数可有可无,keras会自动推断出输出的维度

(5).get_config是必须的,初始化父类的权重,以及自己的某些参数,当然不仅仅是unit和激活函数,值得注意的是,这里使用了keras.activation.serialize方法保存激活函数的完整配置。

上面创建的层可以直接拿来使用。比如dense = MyLayer(100,'relu')(input)

创建自定义的层也很灵活,可以多输入多输出,只不过需要再call返回的时候,分开返回,比如三个输入,两个输出的自定义层。

class MyLayer(keras.layers,Layer):

        def call(self, X):

                x1, x2, x3 = X

                return [x1+x2, x2+x3]

                

        

这里只是举一个简单的例子

如果再自定义层中需要加入一些操作,比如正则化,也需要再call函数中实现。

相关推荐

  1. tensorflow入门 定义

    2023-12-19 09:28:02       40 阅读
  2. GraphQL入门定义标量类型

    2023-12-19 09:28:02       18 阅读
  3. Kotlin语法入门-定义注解(7)

    2023-12-19 09:28:02       13 阅读
  4. vue3 | 定义遮罩组件

    2023-12-19 09:28:02       31 阅读

最近更新

  1. TCP协议是安全的吗?

    2023-12-19 09:28:02       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2023-12-19 09:28:02       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2023-12-19 09:28:02       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2023-12-19 09:28:02       20 阅读

热门阅读

  1. 传统服务器和云服务器的区别?

    2023-12-19 09:28:02       38 阅读
  2. Python装饰器

    2023-12-19 09:28:02       41 阅读
  3. Rabbitmq 死信取消超时订单

    2023-12-19 09:28:02       43 阅读
  4. 装饰器设计模式

    2023-12-19 09:28:02       31 阅读
  5. 【uniapp小程序-wesocket的使用】

    2023-12-19 09:28:02       39 阅读
  6. error: C2039: “qt_metacast“: 不是 “***“ 的成员

    2023-12-19 09:28:02       44 阅读
  7. 动态规划 - 1137.第N个泰波那契数(C#和C实现)

    2023-12-19 09:28:02       39 阅读
  8. python学习4

    2023-12-19 09:28:02       32 阅读
  9. 手机天线市场分析:预计2029年将达到576亿美元

    2023-12-19 09:28:02       36 阅读