# mxnet-gan

**Repository Path**: TimVerion/mxnet-gan

## Basic Information

- **Project Name**: mxnet-gan
- **Description**: MultiGPU enabled image generative models (GAN and DCGAN)
- **Primary Language**: Unknown
- **License**: Apache-2.0
- **Default Branch**: master
- **Homepage**: None
- **GVP Project**: No

## Statistics

- **Stars**: 0
- **Forks**: 0
- **Created**: 2020-05-16
- **Last Updated**: 2020-12-19

## Categories & Tags

**Categories**: Uncategorized

**Tags**: None

## README

# MXNet GAN

An [MXNet](https://github.com/dmlc/mxnet) Module-based implementation of multi-GPU-compatible generative models.

## List of Methods

- Unsupervised training
- Semi-supervised training
- Minibatch discrimination

## Usage

The example below trains a DCGAN on MNIST and logs the discriminator's error rate every 100 iterations.

```python
import logging

import numpy as np
import mxnet as mx
from mxgan import module, generator, encoder, viz


def ferr(label, pred):
    """Discriminator error rate: fraction of samples misclassified at threshold 0.5."""
    pred = pred.ravel()
    label = label.ravel()
    return np.abs(label - (pred > 0.5)).sum() / label.shape[0]


# Hyperparameters
lr = 0.0005
beta1 = 0.5
batch_size = 100
rand_shape = (batch_size, 100)        # shape of the random code fed to the generator
num_epoch = 100
data_shape = (batch_size, 1, 28, 28)  # MNIST images in NCHW layout
context = mx.gpu()

logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')

# DCGAN generator that outputs 28x28 images squashed to [0, 1] by a sigmoid
sym_gen = generator.dcgan28x28(oshape=data_shape, ngf=32, final_act="sigmoid")

# GAN module combining the generator with a LeNet-style encoder for the discriminator
gmod = module.GANModule(
    sym_gen,
    symbol_encoder=encoder.lenet(),
    context=context,
    data_shape=data_shape,
    code_shape=rand_shape)

gmod.init_params(mx.init.Xavier(factor_type="in", magnitude=2.34))

gmod.init_optimizer(
    optimizer="adam",
    optimizer_params={
        "learning_rate": lr,
        "wd": 0.,
        "beta1": beta1,
    })

# MNIST files, located relative to an MXNet source checkout
data_dir = './../../mxnet/example/image-classification/mnist/'
train = mx.io.MNISTIter(
    image=data_dir + "train-images-idx3-ubyte",
    label=data_dir + "train-labels-idx1-ubyte",
    input_shape=data_shape[1:],
    batch_size=batch_size,
    shuffle=True)

metric_acc = mx.metric.CustomMetric(ferr)

for epoch in range(num_epoch):
    train.reset()
    metric_acc.reset()
    for t, batch in enumerate(train):
        gmod.update(batch)
        # Score the discriminator: generated samples are labeled 0, real samples 1
        gmod.temp_label[:] = 0.0
        metric_acc.update([gmod.temp_label], gmod.outputs_fake)
        gmod.temp_label[:] = 1.0
        metric_acc.update([gmod.temp_label], gmod.outputs_real)

        if t % 100 == 0:
            logging.info("epoch: %d, iter %d, metric=%s", epoch, t, metric_acc.get())
            # Display generated images, the normalized temp_diffD array, and the real batch
            viz.imshow("gout", gmod.temp_outG[0].asnumpy(), 2)
            diff = gmod.temp_diffD[0].asnumpy()
            diff = (diff - diff.mean()) / diff.std() + 0.5
            viz.imshow("diff", diff)
            viz.imshow("data", batch.data[0].asnumpy(), 2)
```
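In the usage loop, generated samples are scored against label 0 and real samples against label 1, and `ferr` counts a sample as misclassified when the discriminator's output falls on the wrong side of 0.5. The standalone sketch below reuses `ferr` on made-up discriminator outputs (the arrays are illustrative, not produced by mxgan) to show what the logged metric measures.

```python
import numpy as np


def ferr(label, pred):
    """Fraction of samples misclassified when thresholding pred at 0.5."""
    pred = pred.ravel()
    label = label.ravel()
    return np.abs(label - (pred > 0.5)).sum() / label.shape[0]


# Illustrative discriminator outputs: probability that each input is real.
pred_fake = np.array([0.1, 0.7, 0.3, 0.2])  # second fake sample is judged real
pred_real = np.array([0.9, 0.8, 0.4, 0.6])  # third real sample is judged fake

err_fake = ferr(np.zeros(4), pred_fake)  # fake samples carry label 0
err_real = ferr(np.ones(4), pred_real)   # real samples carry label 1
print(err_fake, err_real)  # 0.25 0.25
```

A falling error rate means the discriminator is winning; a rate hovering near 0.5 suggests it can no longer tell real from generated samples.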
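The parameters are initialized with `mx.init.Xavier(factor_type="in", magnitude=2.34)`. As a sketch of what that setting computes, assuming MXNet's default uniform sampling: each weight is drawn from U(-s, s) with s = sqrt(magnitude / fan_in), where fan_in is the number of input channels times the kernel area. The helper `xavier_in_uniform` and the 32×1×5×5 weight shape are hypothetical, chosen only for illustration.

```python
import numpy as np


def xavier_in_uniform(shape, magnitude=2.34, rng=None):
    """Hypothetical helper mimicking Xavier init with factor_type="in", uniform sampling.

    For a weight of shape (out_channels, in_channels, *kernel), fan_in is
    in_channels * prod(kernel); samples come from U(-scale, scale) with
    scale = sqrt(magnitude / fan_in).
    """
    rng = np.random.default_rng() if rng is None else rng
    fan_in = shape[1] * int(np.prod(shape[2:]))  # prod(()) == 1 for 2-D weights
    scale = np.sqrt(magnitude / fan_in)
    return rng.uniform(-scale, scale, size=shape), scale


# Illustrative 5x5 conv weight with one input channel and 32 filters.
w, scale = xavier_in_uniform((32, 1, 5, 5), rng=np.random.default_rng(0))
```

Here fan_in is 25, so the bound is sqrt(2.34 / 25), roughly 0.306.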
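The optimizer is Adam with `beta1 = 0.5` rather than the usual default of 0.9. The sketch below is the textbook Adam update with bias correction, not MXNet's internal code; `adam_step` is a hypothetical name, and `beta2 = 0.999` and `eps = 1e-8` are assumed defaults since the example does not set them.

```python
import numpy as np


def adam_step(w, g, m, v, t, lr=0.0005, beta1=0.5, beta2=0.999, eps=1e-8):
    """One textbook Adam update; t is the 1-based step count."""
    m = beta1 * m + (1.0 - beta1) * g        # first-moment (momentum) estimate
    v = beta2 * v + (1.0 - beta2) * g * g    # second-moment estimate
    m_hat = m / (1.0 - beta1 ** t)           # bias-corrected moments
    v_hat = v / (1.0 - beta2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v


w = np.zeros(1)
m = np.zeros(1)
v = np.zeros(1)
w, m, v = adam_step(w, np.ones(1), m, v, t=1)
```

With `beta1 = 0.5`, a gradient from k steps ago enters the momentum estimate with weight proportional to 0.5^k, so the optimizer reacts quickly when the adversarial objective shifts between generator and discriminator updates.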
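`generator.dcgan28x28` upsamples a random code into 28×28 images. Its exact layer configuration is not documented here; assuming it follows the common DCGAN pattern of stride-2 transposed convolutions starting from a 7×7 feature map (kernel 4, padding 1), the spatial sizes work out as below. `deconv_out` is a hypothetical helper implementing the output-size formula MXNet documents for `Deconvolution` with dilation 1.

```python
def deconv_out(size, kernel, stride, pad, adj=0):
    """Output length of a transposed convolution along one spatial axis (dilation 1)."""
    return (size - 1) * stride - 2 * pad + kernel + adj


# Assumed layout: 7x7 -> 14x14 -> 28x28 via two stride-2 upsampling layers.
sizes = [7]
for _ in range(2):
    sizes.append(deconv_out(sizes[-1], kernel=4, stride=2, pad=1))
print(sizes)  # [7, 14, 28]
```

Kernel 4, stride 2, padding 1 is a convenient choice because it exactly doubles the spatial size at every layer without needing the `adj` term.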
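Minibatch discrimination, as formulated in Salimans et al., "Improved Techniques for Training GANs" (2016), lets the discriminator compare samples within a batch, which discourages the generator from collapsing onto a single mode. The NumPy sketch below follows that paper's formulation: features `f` are projected through a learned tensor `T`, pairwise L1 distances become closeness scores exp(-d), and the per-kernel sums are appended to the features. It is not mxgan's implementation; the function name and shapes are illustrative.

```python
import numpy as np


def minibatch_discrimination(f, T):
    """Sketch of minibatch discrimination.

    f: (N, A) intermediate discriminator features for one minibatch.
    T: (A, B, C) learned tensor defining B kernels of dimension C.
    Returns (N, A + B): features concatenated with B closeness statistics.
    """
    n, a = f.shape
    b, c = T.shape[1], T.shape[2]
    M = (f @ T.reshape(a, b * c)).reshape(n, b, c)                 # (N, B, C)
    # L1 distance between every pair of samples, separately per kernel.
    l1 = np.abs(M[:, None, :, :] - M[None, :, :, :]).sum(axis=3)  # (N, N, B)
    closeness = np.exp(-l1)                                        # (N, N, B)
    # Sum over the batch; as in the paper, this includes the self term exp(0) = 1.
    o = closeness.sum(axis=1)                                      # (N, B)
    return np.concatenate([f, o], axis=1)


rng = np.random.default_rng(0)
out = minibatch_discrimination(rng.normal(size=(4, 3)), rng.normal(size=(3, 5, 2)))
```

If every sample in the batch is identical, every closeness score is 1 and each statistic equals the batch size, which is exactly the signal the discriminator can use to flag mode collapse.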
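For semi-supervised training, one common formulation (also from Salimans et al., 2016) turns the discriminator into a K-class classifier with an implicit (K+1)-th "fake" class whose logit is fixed at 0, so that D(x) = Z(x) / (Z(x) + 1) with Z(x) = Σ_k exp(l_k(x)). Whether mxgan uses exactly this formulation is an assumption; the sketch below only shows how the supervised and unsupervised losses follow from the logits. All function names are hypothetical.

```python
import numpy as np


def logsumexp(logits):
    """Row-wise log(sum(exp(logits))), computed stably; returns shape (N,)."""
    mx = logits.max(axis=1, keepdims=True)
    return (mx + np.log(np.exp(logits - mx).sum(axis=1, keepdims=True))).ravel()


def softplus(x):
    """log(1 + exp(x)), computed stably."""
    return np.logaddexp(0.0, x)


def semi_supervised_losses(logits_lab, labels, logits_unl, logits_fake):
    # Supervised: cross-entropy over the K real classes.
    lse_lab = logsumexp(logits_lab)
    loss_sup = np.mean(lse_lab - logits_lab[np.arange(len(labels)), labels])
    # Unsupervised: -log D(x) on real unlabeled data, -log(1 - D(G(z))) on fakes.
    # With Z = exp(lse): log D = lse - softplus(lse), log(1 - D) = -softplus(lse).
    lse_unl = logsumexp(logits_unl)
    lse_fake = logsumexp(logits_fake)
    loss_unsup = np.mean(softplus(lse_unl) - lse_unl) + np.mean(softplus(lse_fake))
    return loss_sup, loss_unsup


zeros = np.zeros((2, 10))  # K = 10 classes, all logits zero
loss_sup, loss_unsup = semi_supervised_losses(zeros, np.array([0, 3]), zeros, zeros)
```

Fixing the fake-class logit at 0 removes a redundant degree of freedom: shifting all K+1 logits by a constant leaves the softmax unchanged, so one logit can be pinned without loss of generality.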