Skip to content

Adding fid calculation at the end of epochs (at train.py) #1192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ def eval(self):
net = getattr(self, 'net' + name)
net.eval()

def train(self):
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
net.train()


def test(self):
"""Forward function used in test time.

Expand Down
55 changes: 51 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,35 @@
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""
import time
import os
from options.train_options import TrainOptions
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer
from util import html
from util.visualizer import Visualizer, save_images
from pytorch_fid.fid_score import calculate_fid_given_paths

if __name__ == '__main__':
opt = TrainOptions().parse() # get training options
opt = TrainOptions().parse() # get training options
val_opts = TestOptions().parse()
val_opts.phase = 'val'
val_opts.num_threads = 0 # test code only supports num_threads = 0
val_opts.batch_size = 1 # test code only supports batch_size = 1
val_opts.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
val_opts.no_flip = True # no flip; comment this line if results on flipped images are needed.
val_opts.display_id = -1

dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
dataset_size = len(dataset) # get the number of images in the dataset.
val_dataset = create_dataset(val_opts) # create a dataset given opt.dataset_mode and other options
web_dir = os.path.join(val_opts.results_dir, val_opts.name,
'{}_{}'.format(val_opts.phase, val_opts.epoch)) # define the website directory
if opt.load_iter > 0: # load_iter is 0 by default
web_dir = '{:s}_iter{:d}'.format(web_dir, opt.load_iter)
print('creating web directory', web_dir)
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))

dataset_size = len(dataset) # get the number of images in the dataset.
print('The number of training images = %d' % dataset_size)

model = create_model(opt) # create a model given opt.model and other options
Expand Down Expand Up @@ -74,4 +94,31 @@
model.save_networks('latest')
model.save_networks(epoch)

print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))
if epoch % opt.val_metric_freq == 0:
print('Evaluating FID for validation set at epoch %d, iters %d, at dataset %s' % (
epoch, total_iters, opt.name))
model.eval()
for i, data in enumerate(val_dataset):
model.set_input(data) # unpack data from data loader
model.test() # run inference

visuals = model.get_current_visuals() # get image results
if opt.direction == 'BtoA':
visuals = {'fake_B': visuals['fake_B']}
else:
visuals = {'fake_A': visuals['fake_A']}

img_path = model.get_image_paths() # get image paths
if i % 5 == 0: # save images to an HTML file
print('processing (%04d)-th image... %s' % (i, img_path))
save_images(webpage, visuals, img_path, aspect_ratio=val_opts.aspect_ratio,
width=val_opts.display_winsize)
fid_value = calculate_fid_given_paths(
paths=('./results/{d}/val_latest/images/'.format(d=opt.name), '{d}/val'.format(d=opt.dataroot)),
batch_size=64, cuda=True, dims=2048)
visualizer.print_current_fid(epoch, fid_value)
visualizer.plot_current_fid(epoch, fid_value)
model.train()

print('End of epoch %d / %d \t Time Taken: %d sec' % (
epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))
41 changes: 41 additions & 0 deletions util/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,14 @@ def __init__(self, opt):
util.mkdirs([self.web_dir, self.img_dir])
# create a logging file to store training losses
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
self.fid_log_name = os.path.join(opt.checkpoints_dir, opt.name, 'fid_log.txt')

with open(self.log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write('================ Training Loss (%s) ================\n' % now)
with open(self.fid_log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write('================ Validation FID (%s) ================\n' % now)

def reset(self):
"""Reset the self.saved status"""
Expand Down Expand Up @@ -176,6 +181,29 @@ def display_current_results(self, visuals, epoch, save_result):
webpage.add_images(ims, txts, links, width=self.win_size)
webpage.save()

def plot_current_fid(self, epoch, fid):
"""display the current fid on visdom display

Parameters:
epoch (int) -- current epoch
fid (float) -- validation fid
"""
if not hasattr(self, 'fid_plot_data'):
self.fid_plot_data = {'X': [], 'Y': []}
self.fid_plot_data['X'].append(epoch)
self.fid_plot_data['Y'].append(fid)
try:
self.vis.line(
X=np.array(self.fid_plot_data['X']),
Y=np.array(self.fid_plot_data['Y']),
opts={
'title': self.name + ' fid over time',
'xlabel': 'epoch',
'ylabel': 'fid'},
win=self.display_id + 4)
except VisdomExceptionBase:
self.create_visdom_connections()

def plot_current_losses(self, epoch, counter_ratio, losses):
"""display the current losses on visdom display: dictionary of error labels and values

Expand Down Expand Up @@ -219,3 +247,16 @@ def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
print(message) # print the message
with open(self.log_name, "a") as log_file:
log_file.write('%s\n' % message) # save the message

def print_current_fid(self, epoch, fid):
"""print current fid on console; also save the fid to the disk

Parameters:
epoch (int) -- current epoch
fid (float) - fid metric
"""
message = '(epoch: %d, fid: %.3f) ' % (epoch, fid)

print(message) # print the message
with open(self.fid_log_name, "a") as log_file:
log_file.write('%s\n' % message) # save the message