| from flax import linen as nn |
| import jax |
| import jax.numpy as jnp |
| from local_response_norm import LocalResponseNorm |
|
|
# Module-wide hyper-parameters.
# NOTE(review): EPSILON, MAX_DISC_FEATURES and LATENT_DIM are not used in the
# code visible here; their meanings below are inferred from names — confirm.
EPSILON: float = 1e-8  # presumably a small constant for numerical stability
MAX_DISC_FEATURES: int = 128  # presumably a feature cap for a discriminator
MAX_GEN_FEATURES: int = 512  # upper bound on per-stage features in get_gen_layers
LATENT_DIM: int = 512  # presumably the size of the generator's latent input
MAX_LAYERS: int = 7  # reference depth used to scale per-stage feature counts
|
|
def get_gen_layers(layer):
    """Build the callables for one generator upsampling stage.

    Stage ``layer`` targets a square spatial size of ``4 * 2**layer`` and a
    feature count of ``32 * 2**(MAX_LAYERS - 1 - layer)``, clamped to at most
    ``MAX_GEN_FEATURES``.

    Args:
        layer: zero-based index of the stage.

    Returns:
        A list of three callables meant to be applied in order:
        linear resize -> ``nn.ConvTranspose`` (3x3 kernel) -> ReLU.
        The ConvTranspose submodule is instantiated only when its callable
        runs, so it has to be invoked from inside a Flax ``@nn.compact``
        method.
    """
    side = int(4 * 2 ** layer)
    width = min(int(32 * 2 ** (MAX_LAYERS - 1 - layer)), MAX_GEN_FEATURES)
    conv_name = f"ConvTranspose_{side}_{width}"

    def upsample(x):
        # Resize axes 1 and 2 to ``side``; axes 0 and 3 keep their size
        # (assumes an NHWC layout — TODO confirm).
        target_shape = (x.shape[0], side, side, x.shape[3])
        return jax.image.resize(x, shape=target_shape, method="linear")

    def convolve(x):
        block = nn.ConvTranspose(features=width, kernel_size=(3, 3), name=conv_name)
        return block(x)

    def activate(x):
        return nn.relu(x)

    return [upsample, convolve, activate]
|
|
def get_initial_gen_layers(num_layers):
    """Return the callables that open the generator.

    ``num_layers`` is accepted but not used here.

    Returns:
        A one-element list whose callable reshapes ``x`` to
        ``(x.shape[0], 1, 1, -1)``: a 1x1 spatial map whose last axis holds
        all remaining (flattened) values of each batch entry.
    """
    def to_spatial(x):
        batch = x.shape[0]
        return x.reshape(batch, 1, 1, -1)

    return [to_spatial]
|
|
def get_final_gen_layers(num_layers):
    """Return the callables that close the generator.

    Args:
        num_layers: number of upsampling stages; the final module's name
            embeds the last stage's resolution, ``4 * 2**(num_layers - 1)``.

    Returns:
        A one-element list whose callable applies a 3x3 ``nn.ConvTranspose``
        with 3 output features. The submodule is instantiated when the
        callable runs, so it has to be invoked from inside a Flax
        ``@nn.compact`` method.
    """
    final_side = int(4 * 2 ** (num_layers - 1))

    def output_conv(x):
        block = nn.ConvTranspose(
            features=3,
            kernel_size=(3, 3),
            name=f"ConvTranspose_{final_side}_3",
        )
        return block(x)

    return [output_conv]
|
|
class Generator(nn.Module):
    """Upsampling generator assembled from the ``get_*_gen_layers`` helpers.

    The input is first reshaped to a 1x1 spatial map. Then come
    ``num_layers`` stages of (linear resize -> ConvTranspose -> ReLU); stage
    ``i`` resizes to ``4 * 2**i`` pixels per side. A final 3-feature
    ConvTranspose closes the network.
    """

    # Number of upsampling stages. The ``None`` default is kept for interface
    # compatibility, but ``setup`` rejects anything that is not >= 1.
    num_layers: int = None

    def setup(self):
        # Fail fast with a clear message rather than the opaque TypeError that
        # ``range(None)`` would raise further down.
        if self.num_layers is None:
            raise TypeError(
                "Generator.num_layers must be set to a positive int, got None"
            )
        # num_layers < 1 would skip every upsampling stage and silently build
        # a degenerate network (reshape + a single 1x1-output ConvTranspose).
        if self.num_layers < 1:
            raise ValueError(
                f"Generator.num_layers must be >= 1, got {self.num_layers}"
            )

        layers = list(get_initial_gen_layers(self.num_layers))
        for layer in range(self.num_layers):
            layers.extend(get_gen_layers(layer))
        layers.extend(get_final_gen_layers(self.num_layers))
        # These are plain callables, not submodules: the ConvTranspose modules
        # inside them are created lazily when ``__call__`` runs them.
        self.layers = layers

    @nn.compact
    def __call__(self, x):
        """Apply every stage built in ``setup`` to ``x``, in order.

        Args:
            x: network input. The first stage reshapes it to
                ``(x.shape[0], 1, 1, -1)``, so it is presumably a
                ``(batch, latent)`` array — TODO confirm against callers.

        Returns:
            The output of the final 3-feature ConvTranspose.
        """
        # ``@nn.compact`` is required: the stage callables instantiate
        # ConvTranspose submodules inline when they are invoked here.
        result = x
        for stage in self.layers:
            result = stage(result)
        return result