cycle GAN

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'#设置tensorflow的日志级别
from tensorflow.python.platform import build_info

import tensorflow as tf

# 列出所有物理GPU设备  
gpus = tf.config.list_physical_devices('GPU')  
if gpus:  
    # 如果有GPU,设置GPU资源使用率  
    try:  
        # 允许GPU内存按需增长  
        for gpu in gpus:  
            tf.config.experimental.set_memory_growth(gpu, True)  
        # 设置可见的GPU设备(这里实际上不需要,因为已经通过内存增长设置了每个GPU)  
        # tf.config.set_visible_devices(gpus, 'GPU')  
        print("GPU可用并已设置内存增长模式。")  
    except RuntimeError as e:  
        # 虚拟设备未就绪时可能无法设置GPU  
        print(f"设置GPU时发生错误: {e}")  
else:  
    # 如果没有GPU  
    print("没有检测到GPU设备。")

import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
AUTOTUNE = tf.data.AUTOTUNE
# tf.data.AUTOTUNE 是一个特殊的值,它告诉TensorFlow的tf.data API自动选择适当的并行度。
# 当使用tf.data API来构建输入管道时,经常需要决定并行
# 处理数据的方式,以最大化数据加载和预处理的速度,同时不浪费计算资源。

# 加载训练数据  
def load_and_preprocess_image(image_path):  
    image = tf.io.read_file(image_path)  
    image = tf.image.decode_jpeg(image, channels=3)  
    image = tf.image.resize(image, IMAGE_SIZE)  
    image /= 255.0  # 归一化到[0, 1]  
    return image 

BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256

#改变图片大小
def resize(image, height, width):
  image = tf.image.resize(image, [height, width],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  return image

#定义随机裁剪方法
def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])
  return cropped_image

# 标准化 to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image

def random_jitter(image):
  # 改变尺寸到 286x286
  image = resize(image, 286, 286)
  # 随机裁剪to 256 x 256 x 3
  image = random_crop(image)
  # 随机的水平翻转
  image = tf.image.random_flip_left_right(image)
  return image

def load(image_file):
    # 读取图片文件,并且解码转换成uint8
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.cast(image, tf.float32)
    return image

def preprocess_image_train(image_file):#定义预处理训练图片的方法
    # print(image_file)
    image = load(image_file)
    image = random_jitter(image)
    image = normalize(image)
    return image

import matpl

相关推荐

  1. 【深度学习】CycleGAN

    2024-04-07 06:02:06       25 阅读
  2. CycleGAN(Cycle-Consistent Generative Adversarial Network)

    2024-04-07 06:02:06       44 阅读
  3. 基于CycleGAN的图像风格转换

    2024-04-07 06:02:06       22 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-04-07 06:02:06       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-04-07 06:02:06       101 阅读
  3. 在Django里面运行非项目文件

    2024-04-07 06:02:06       82 阅读
  4. Python语言-面向对象

    2024-04-07 06:02:06       91 阅读

热门阅读

  1. Pytorch中的nn.Embedding()

    2024-04-07 06:02:06       38 阅读
  2. Redis过期删除策略和内存淘汰机制

    2024-04-07 06:02:06       45 阅读
  3. 前端node使用WebSocket实现实时通信例子

    2024-04-07 06:02:06       33 阅读
  4. Android ContentProvider基础知识学习笔记

    2024-04-07 06:02:06       39 阅读
  5. vue 生命周期

    2024-04-07 06:02:06       38 阅读
  6. [蓝桥杯 2023 国 B] 双子数

    2024-04-07 06:02:06       39 阅读
  7. ARXML处理 - C#的解析代码(一)

    2024-04-07 06:02:06       32 阅读
  8. Python常用算法--排序算法【附源码】

    2024-04-07 06:02:06       42 阅读