|
|
| import torch
|
| import torch.nn as nn
|
| import torch.optim as optim
|
| from torchvision import datasets, transforms
|
| from torch.utils.data import DataLoader
|
| import matplotlib.pyplot as plt
|
| import numpy as np
|
| from tqdm import tqdm
|
| import os
|
|
|
print("🚀 7Gen - Gelişmiş MNIST Üretici Sistemi 🚀")

# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Kullanılan cihaz: {device}')

# ---- Hyperparameters ----
batch_size = 64      # images per training step
latent_dim = 100     # length of the generator's noise vector
num_classes = 10     # number of labels the networks are conditioned on
num_epochs = 100     # passes over the training set
lr = 2e-4            # Adam learning rate shared by both networks

# ---- Data ----
# Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5, mapping [0, 1] pixel
# values to [-1, 1] -- the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
|
|
|
|
class Generator(nn.Module):
    """Label-conditioned MLP generator.

    Concatenates a noise vector with a learned embedding of the class label
    and maps the result through three Linear -> LeakyReLU -> BatchNorm stages
    and a final Linear -> Tanh, producing images of shape (N, 1, 28, 28)
    with values in (-1, 1).
    """

    def __init__(self):
        super().__init__()

        # One learned num_classes-dim vector per class label.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Modules are created in the same order as a hand-written Sequential,
        # so indices (model.0, model.2, ...) and parameter init order match.
        stages = []
        width_in = latent_dim + num_classes
        for width_out in (256, 512, 1024):
            stages += [
                nn.Linear(width_in, width_out),
                nn.LeakyReLU(0.2),
                nn.BatchNorm1d(width_out),
            ]
            width_in = width_out
        stages += [nn.Linear(width_in, 784), nn.Tanh()]
        self.model = nn.Sequential(*stages)

    def forward(self, noise, labels):
        """Generate one image per (noise row, label) pair.

        Args:
            noise: tensor of shape (N, latent_dim).
            labels: integer class labels of shape (N,).

        Returns:
            Tensor of shape (N, 1, 28, 28).
        """
        conditioned = torch.cat([noise, self.label_emb(labels)], dim=-1)
        flat = self.model(conditioned)
        return flat.view(flat.size(0), 1, 28, 28)
|
|
|
|
|
class Discriminator(nn.Module):
    """Label-conditioned MLP discriminator.

    Flattens the image, appends a learned embedding of its class label and
    returns a Sigmoid score in (0, 1) of shape (N, 1).
    """

    def __init__(self):
        super().__init__()

        # Separate label embedding from the generator's; learned independently.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Built in the same module order as an explicit Sequential so the
        # state_dict keys and parameter initialisation order are unchanged.
        stages = []
        width_in = 784 + num_classes
        for width_out in (512, 256):
            stages += [
                nn.Linear(width_in, width_out),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            ]
            width_in = width_out
        stages += [nn.Linear(width_in, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*stages)

    def forward(self, img, labels):
        """Score a batch of images against their labels.

        Args:
            img: tensor whose per-sample element count is 784
                (e.g. (N, 1, 28, 28)); it is flattened to (N, 784).
            labels: integer class labels of shape (N,).

        Returns:
            Tensor of shape (N, 1) with values in (0, 1).
        """
        flattened = img.view(img.size(0), -1)
        joint = torch.cat([flattened, self.label_emb(labels)], dim=-1)
        return self.model(joint)
|
|
|
|
|
# Build both networks (generator first) and move them onto the chosen device.
generator, discriminator = [net.to(device) for net in (Generator(), Discriminator())]

# Binary cross-entropy on the discriminator's Sigmoid output.
adversarial_loss = nn.BCELoss()

# One Adam optimiser per network, both using the shared learning rate.
optimizer_G, optimizer_D = [
    optim.Adam(net.parameters(), lr=lr) for net in (generator, discriminator)
]

# Destination folder for generated sample images.
os.makedirs('generated_images', exist_ok=True)
|
|
|
|
|
def _save_preview_grid(epoch, samples_per_class=10):
    """Save a grid of generated samples: one row per class label.

    Row r holds ``samples_per_class`` images conditioned on label r. The
    generator is put in eval mode for the preview so BatchNorm uses -- and
    does not update -- its running statistics (``torch.no_grad`` alone does
    not stop running-stat updates). Its previous mode is restored afterwards.

    Args:
        epoch: zero-based epoch index; used in the output file name.
        samples_per_class: number of columns in the grid.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(num_classes * samples_per_class, latent_dim, device=device)
            # [0]*k + [1]*k + ... : each grid row is a single class.
            sample_labels = torch.arange(num_classes, device=device).repeat_interleave(samples_per_class)
            # Tanh output in [-1, 1] -> [0, 1] for display.
            samples = ((generator(z, sample_labels) + 1) / 2).cpu().numpy()
    finally:
        generator.train(was_training)

    # squeeze=False keeps `axes` 2-D even for a single row/column.
    fig, axes = plt.subplots(num_classes, samples_per_class,
                             figsize=(samples_per_class, num_classes), squeeze=False)
    for idx, ax in enumerate(axes.flat):  # row-major, matching sample_labels
        ax.imshow(samples[idx, 0], cmap='gray')
        ax.axis('off')
    fig.savefig(f'generated_images/7gen_epoch_{epoch+1}.png')
    plt.close(fig)


print("\n🔥 7Gen Eğitimi Başlıyor...")

for epoch in range(num_epochs):
    for imgs, labels in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        imgs = imgs.to(device)
        labels = labels.to(device)
        # Size of this batch (the final one may be smaller). Named separately
        # so the `batch_size` hyperparameter is not overwritten.
        n_batch = imgs.size(0)

        # Targets for the BCE loss, allocated directly on the device.
        valid = torch.ones(n_batch, 1, device=device)
        fake = torch.zeros(n_batch, 1, device=device)

        # ---- Generator step: push D to label generated images as real ----
        optimizer_G.zero_grad()
        z = torch.randn(n_batch, latent_dim, device=device)
        gen_labels = torch.randint(0, num_classes, (n_batch,), device=device)
        gen_imgs = generator(z, gen_labels)
        g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), valid)
        g_loss.backward()
        optimizer_G.step()

        # ---- Discriminator step ----
        # g_loss.backward() also wrote gradients into D's parameters;
        # zero_grad() discards them before D's own update.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(imgs, labels), valid)
        # detach(): the discriminator update must not backpropagate into G.
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

    # Losses of the last batch of the epoch.
    print(f"Epoch {epoch+1}/{num_epochs} - D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")

    if (epoch + 1) % 10 == 0:
        _save_preview_grid(epoch)

os.makedirs('models', exist_ok=True)
torch.save(generator.state_dict(), 'models/7gen_generator.pth')
torch.save(discriminator.state_dict(), 'models/7gen_discriminator.pth')

print("\n✅ 7Gen eğitimi tamamlandı!")
|
|
|
|
|
def generate_digit(digit, num_samples=5):
    """Generate, save and display ``num_samples`` images of class ``digit``.

    The images are saved to ``generated_images/digit_{digit}_samples.png``
    and then shown with ``plt.show()``. The generator runs in eval mode, so
    BatchNorm uses its running statistics; its previous train/eval mode is
    restored before returning.

    Args:
        digit: class label to condition on; must be in ``range(num_classes)``.
        num_samples: number of images to generate (one subplot each).

    Raises:
        ValueError: if ``digit`` is not a valid class label.
    """
    # Fail with a clear message instead of an opaque embedding/device error.
    if not 0 <= digit < num_classes:
        raise ValueError(f"digit must be in [0, {num_classes}), got {digit}")

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(num_samples, latent_dim, device=device)
            # nn.Embedding needs integer (long) indices; make the dtype explicit.
            labels = torch.full((num_samples,), digit, dtype=torch.long, device=device)
            # Tanh output in [-1, 1] -> [0, 1] for display.
            samples = ((generator(z, labels) + 1) / 2).cpu().numpy()
    finally:
        generator.train(was_training)

    fig = plt.figure(figsize=(10, 2))
    for i in range(num_samples):
        ax = fig.add_subplot(1, num_samples, i + 1)
        ax.imshow(samples[i, 0], cmap='gray')
        ax.axis('off')
    fig.savefig(f'generated_images/digit_{digit}_samples.png')
    plt.show()
    # Release the figure; with non-interactive backends show() does not close it.
    plt.close(fig)
|
|
|
|
|
print("\n🎯 Test örnekleri üretiliyor...")

# Produce a strip of five samples for every digit label 0..9.
for label in range(10):
    generate_digit(label, num_samples=5)

print("\n🎉 7Gen hazır! generated_images klasörüne bak.")