I'm trying to train a GAN on MNIST, but the results are very poor: the generator doesn't seem to learn. Can someone suggest how to improve it? Also, what is the best way to remember (and reuse) this code, and can the same thing be done with a built-in library?

import os
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Reproducibility: seed PyTorch's global RNG before any network or noise is created.
random_seed = 42
torch.manual_seed(random_seed)

# Hyper-parameters.
LATENT_DIM = 100  # length of each generator noise vector z
BATCH_SIZE = 128

# Hardware: this script uses at most one GPU, so AVAIL_GPUS is 0 or 1.
AVAIL_GPUS = min(1, torch.cuda.device_count())
DEVICE = torch.device("cuda") if AVAIL_GPUS else torch.device("cpu")


# ToTensor yields pixels in [0, 1]; Normalize(mean=0.5, std=0.5) maps them to
# [-1, 1] so real images share the range of the generator's tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

dataset = datasets.MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)

# drop_last=True discards the final incomplete batch, so every batch holds
# exactly BATCH_SIZE images.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)



class Generator(nn.Module):
    """DCGAN-style generator that maps a noise vector to a 28x28 image in [-1, 1].

    The spatial size grows 1x1 -> 7x7 -> 14x14 -> 28x28 through three
    transposed convolutions. The final tanh matches the [-1, 1] range of the
    normalised training images.

    Args:
        latent_dim: Length of each noise vector ``z``.
        img_channels: Number of channels in the generated image.
        feature_maps: Base width; the hidden layers use 4x and 2x this many channels.
    """

    def __init__(self, latent_dim=100, img_channels=1, feature_maps=64):
        super().__init__()
        wide = feature_maps * 4
        mid = feature_maps * 2
        layers = [
            # (latent_dim, 1, 1) -> (wide, 7, 7)
            *self._up_block(latent_dim, wide, kernel_size=7, stride=1, padding=0),
            # (wide, 7, 7) -> (mid, 14, 14)
            *self._up_block(wide, mid, kernel_size=4, stride=2, padding=1),
            # (mid, 14, 14) -> (img_channels, 28, 28), squashed into [-1, 1]
            nn.ConvTranspose2d(
                mid, img_channels, kernel_size=4, stride=2, padding=1, bias=False
            ),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    @staticmethod
    def _up_block(in_ch, out_ch, kernel_size, stride, padding):
        """Return ``[ConvTranspose2d, BatchNorm2d, ReLU]`` for one upsampling stage.

        The convolution has ``bias=False`` because the following BatchNorm
        supplies its own learnable shift.
        """
        return [
            nn.ConvTranspose2d(
                in_ch,
                out_ch,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    def forward(self, z):
        """Reshape ``z`` to ``(batch, latent_dim, 1, 1)`` and decode it into images."""
        batch = z.shape[0]
        return self.net(z.view(batch, -1, 1, 1))



class Discriminator(nn.Module):
    """DCGAN-style discriminator that scores 28x28 images as real (near 1) or fake (near 0).

    The spatial size shrinks 28x28 -> 14x14 -> 7x7 -> 1x1. The final sigmoid
    turns each score into a probability in (0, 1), suitable for ``nn.BCELoss``.

    Args:
        img_channels: Number of channels in the input images.
        feature_maps: Width of the first conv layer; the second uses 2x this.
    """

    def __init__(self, img_channels=1, feature_maps=64):
        super().__init__()
        hidden = feature_maps * 2
        # Positional Conv2d args: (in_channels, out_channels, kernel_size, stride, padding)
        self.net = nn.Sequential(
            # (img_channels, 28, 28) -> (feature_maps, 14, 14); no BatchNorm on the input layer
            nn.Conv2d(img_channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # (feature_maps, 14, 14) -> (hidden, 7, 7)
            nn.Conv2d(feature_maps, hidden, 4, 2, 1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # (hidden, 7, 7) -> (1, 1, 1): a single score per image
            nn.Conv2d(hidden, 1, 7, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a flat 1-D tensor of probabilities (one per image for 28x28 input)."""
        scores = self.net(x)
        return scores.view(-1)



def weights_init(m):
    """Apply DCGAN-style initialisation to one module; use as ``model.apply(weights_init)``.

    Modules are matched by class name:

    - name contains ``"Conv"`` (Conv2d, ConvTranspose2d, ...): weight ~ N(0, 0.02);
    - name contains ``"BatchNorm"``: weight ~ N(1, 0.02), bias = 0;
    - anything else is left untouched.

    A matching module that lacks the parameter is skipped rather than raising
    ``AttributeError``. Examples are ``BatchNorm2d(affine=False)``, whose
    weight/bias are ``None``, or a container class such as ``ConvBlock``.

    Args:
        m: The module currently being visited by ``nn.Module.apply``.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    # nn.init.* functions already run under torch.no_grad(), so the
    # discouraged `.data` access is not needed.
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)


# Build both networks on DEVICE first, then re-initialise their weights.
# Constructing both before re-initialising either fixes the order of random
# draws, so runs under the seed above stay reproducible.
generator = Generator(LATENT_DIM).to(DEVICE)
discriminator = Discriminator().to(DEVICE)
for _model in (generator, discriminator):
    _model.apply(weights_init)

# The discriminator ends in a Sigmoid (it outputs probabilities), so plain BCE fits.
criterion = nn.BCELoss()

# Both networks use the same Adam settings: lr=2e-4, betas=(0.5, 0.999).
_adam_kwargs = {"lr": 2e-4, "betas": (0.5, 0.999)}
opt_g = optim.Adam(generator.parameters(), **_adam_kwargs)
opt_d = optim.Adam(discriminator.parameters(), **_adam_kwargs)

# BCE targets for real and generated images.
real_label, fake_label = 1.0, 0.0


# Training loop

def train_gan(num_epochs=30):
    """Train the module-level ``generator`` / ``discriminator`` pair on ``dataloader``.

    Each step has two phases.

    1. The discriminator is trained with ``criterion`` to output ``real_label``
       for real images and ``fake_label`` for generated ones.
    2. The generator is trained so the discriminator outputs ``real_label``
       for its samples (the non-saturating GAN loss).

    After every epoch a grid generated from one fixed noise batch is shown via
    ``show_images``, so successive grids are directly comparable.

    Args:
        num_epochs: Number of full passes over ``dataloader``.

    Returns:
        ``(g_losses, d_losses)``: one float per epoch, each the MEAN loss over
        all steps of that epoch (``nan`` if the dataloader yielded no batches).
        The mean is used because a single batch's GAN loss is very noisy.
    """
    # Same noise every epoch so the visualised samples can be compared over time.
    fixed_noise = torch.randn(64, LATENT_DIM, device=DEVICE)
    g_losses, d_losses = [], []

    for epoch in range(num_epochs):
        # Sum detached loss tensors and call .item() once per epoch, so the
        # bookkeeping does not force a device sync on every step.
        g_sum = torch.zeros((), device=DEVICE)
        d_sum = torch.zeros((), device=DEVICE)
        n_steps = 0

        for i, (real_imgs, _) in enumerate(dataloader):
            real_imgs = real_imgs.to(DEVICE)
            bs = real_imgs.size(0)

            # (1) Discriminator: minimise BCE(D(x), 1) + BCE(D(G(z)), 0),
            #     i.e. maximise log D(x) + log(1 - D(G(z))).
            opt_d.zero_grad()

            labels_real = torch.full((bs,), real_label, device=DEVICE)
            loss_d_real = criterion(discriminator(real_imgs), labels_real)

            noise = torch.randn(bs, LATENT_DIM, device=DEVICE)
            fake_imgs = generator(noise)
            labels_fake = torch.full((bs,), fake_label, device=DEVICE)
            # detach(): the discriminator update must not push gradients into G.
            loss_d_fake = criterion(discriminator(fake_imgs.detach()), labels_fake)

            loss_d = loss_d_real + loss_d_fake
            loss_d.backward()
            opt_d.step()

            # (2) Generator: minimise BCE(D(G(z)), 1), i.e. maximise log D(G(z)).
            opt_g.zero_grad()
            labels_gen = torch.full((bs,), real_label, device=DEVICE)  # want D to think these are real
            loss_g = criterion(discriminator(fake_imgs), labels_gen)
            # This also writes gradients into D's parameters; they are discarded
            # by opt_d.zero_grad() at the start of the next step.
            loss_g.backward()
            opt_g.step()

            d_sum += loss_d.detach()
            g_sum += loss_g.detach()
            n_steps += 1

            if i % 200 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}] Step [{i}/{len(dataloader)}] "
                      f"D_loss: {loss_d.item():.4f} G_loss: {loss_g.item():.4f}")

        if n_steps:
            g_losses.append(g_sum.item() / n_steps)
            d_losses.append(d_sum.item() / n_steps)
        else:
            g_losses.append(float("nan"))
            d_losses.append(float("nan"))

        # Visualize progress. The generator is still in train mode, so its
        # BatchNorm layers normalise with this batch's statistics (and update
        # their running averages).
        with torch.no_grad():
            fake = generator(fixed_noise).cpu()
        show_images(fake, epoch + 1)

    return g_losses, d_losses


def show_images(images, epoch, nrow=8, save_path=None):
    """Display (and optionally save) a grid of generated images.

    Args:
        images: Image tensor with values in [-1, 1] (the generator's tanh
            range), typically ``(N, C, H, W)``. Single-channel images are
            drawn with a gray colormap.
        epoch: Epoch number shown in the figure title.
        nrow: Number of images per grid row.
        save_path: If given, the figure is also saved to this path before being shown.
    """
    # Undo Normalize((0.5,), (0.5,)): [-1, 1] -> [0, 1]. The clamp keeps values
    # inside the range imshow expects for float images.
    images = ((images.detach().cpu() + 1) / 2).clamp(0, 1)
    grid = torchvision.utils.make_grid(images, nrow=nrow)

    # make_grid copies a single input channel to 3 channels. For grayscale
    # input, keep one channel so imshow receives a 2-D array and cmap="gray"
    # actually takes effect.
    is_gray = images.dim() == 2 or images.size(-3) == 1
    pixels = grid[0] if is_gray else grid.permute(1, 2, 0)

    plt.figure(figsize=(8, 8))
    # Fixed vmin/vmax: otherwise imshow would stretch a 2-D array to its own min/max.
    plt.imshow(pixels.numpy(), cmap="gray", vmin=0.0, vmax=1.0)
    plt.axis("off")
    plt.title(f"Epoch {epoch}")
    if save_path is not None:
        plt.savefig(save_path, bbox_inches="tight")
    plt.show()


# Script entry point. The guard is required because the DataLoader uses worker
# processes (num_workers=2): under the "spawn" start method (the default on
# Windows and macOS) each worker re-imports this module. Without the guard, that
# import would reach the training call and fail.
if __name__ == "__main__":
    g_losses, d_losses = train_gan(num_epochs=30)

    # Plot losses
    plt.figure(figsize=(10, 5))
    plt.plot(g_losses, label="Generator")
    plt.plot(d_losses, label="Discriminator")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("GAN Training Losses")
    plt.show()