在 /home/your_name/TensorFlow/DCGAN/ 下新建文件 train.py,同时新建文件夹 logs 和文件夹 samples,前者用来保存训练过程中的日志和模型,后者用来保存训练过程中采样器的采样图片,在 train.py 中输入如下代码:

移动开发培训,Android培训,安卓培训,手机开发培训,手机维修培训,手机软件培训

# -*- coding: utf-8 -*-import tensorflow as tfimport osfrom read_data import *from utils import *from ops import *from model import *from model import BATCH_SIZEdef train():    # 设置 global_step ,用来记录训练过程中的 step        
    global_step = tf.Variable(0, name = 'global_step', trainable = False)    # 训练过程中的日志保存文件
    train_dir = '/home/your_name/TensorFlow/DCGAN/logs'

    # 放置三个 placeholder,y 表示约束条件,images 表示送入判别器的图片,
    # z 表示随机噪声
    y= tf.placeholder(tf.float32, [BATCH_SIZE, 10], name='y')
    images = tf.placeholder(tf.float32, [64, 28, 28, 1], name='real_images')
    z = tf.placeholder(tf.float32, [None, 100], name='z')    # 由生成器生成图像 G
    G = generator(z, y)    # 真实图像送入判别器
    D, D_logits  = discriminator(images, y)    # 采样器采样图像
    samples = sampler(z, y)    # 生成图像送入判别器
    D_, D_logits_ = discriminator(G, y, reuse = True)    

        
		

网友评论