Source code for loss_landscape.landscape_utils

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import copy
from typing import Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import animation, cm
from mpl_toolkits.mplot3d import Axes3D

from utils.common_utils import unwrap_model_fn

# https://github.com/xxxnell/how-do-vits-work


[docs]def rand_basis(ws: Dict, device: Optional[str] = torch.device("cpu")): return {k: torch.randn(size=v.shape, device=device) for k, v in ws.items()}
[docs]def normalize_filter(bs: Dict, ws: Dict): bs = {k: v.float() for k, v in bs.items()} ws = {k: v.float() for k, v in ws.items()} norm_bs = {} for k in bs: ws_norm = torch.norm(ws[k], dim=0, keepdim=True) bs_norm = torch.norm(bs[k], dim=0, keepdim=True) norm_bs[k] = ws_norm / (bs_norm + 1e-7) * bs[k] return norm_bs
[docs]def ignore_bn(ws: Dict): ignored_ws = {} for k in ws: if len(ws[k].size()) < 2: ignored_ws[k] = torch.zeros(size=ws[k].size(), device=ws[k].device) else: ignored_ws[k] = ws[k] return ignored_ws
[docs]def create_bases( model: torch.nn.Module, device: Optional[str] = torch.device("cpu"), has_module: Optional[bool] = False, ): unwrapped_model = unwrap_model_fn(model) weight_state_0 = unwrapped_model.state_dict() bases = [rand_basis(weight_state_0, device) for _ in range(2)] # Use two bases bases = [normalize_filter(bs, weight_state_0) for bs in bases] bases = [ignore_bn(bs) for bs in bases] return bases
[docs]def generate_plots(xx, yy, zz, model_name, results_loc): zz = np.log(zz) plt.figure(figsize=(10, 10)) plt.contour(xx, yy, zz) plt.savefig(f"{results_loc}/{model_name}_log_contour.png", dpi=100) plt.close() ## 3D plot fig, ax = plt.subplots(subplot_kw={"projection": "3d"}) ax.set_axis_off() surf = ax.plot_surface(xx, yy, zz, cmap=cm.coolwarm, linewidth=0, antialiased=False) ax.set_xlim(-1, 1) ax.set_ylim(-1, 1) plt.savefig( f"{results_loc}/{model_name}_log_surface.png", dpi=100, format="png", bbox_inches="tight", ) plt.close() fig = plt.figure(figsize=(10, 10)) ax = Axes3D(fig) ax.set_axis_off() def init(): ax.plot_surface(xx, yy, zz, cmap=cm.coolwarm, linewidth=0, antialiased=False) ax.set_xlim(-1, 1) ax.set_ylim(-1, 1) return (fig,) def animate(i): ax.view_init(elev=(15 * (i // 15) + i % 15) + 0.0, azim=i) ax.set_xlim(-1, 1) ax.set_ylim(-1, 1) return (fig,) anim = animation.FuncAnimation( fig, animate, init_func=init, frames=100, interval=20, blit=True ) anim.save( f"{results_loc}/{model_name}_log_surface.gif", fps=15, writer="imagemagick" )
[docs]def plot_save_graphs( save_dir: str, model_name: str, grid_a: np.ndarray, grid_b: np.ndarray, loss_surface: np.ndarray, resolution: int, ): np.save(f"{save_dir}/{model_name}_xx.npy", grid_a) np.save(f"{save_dir}/{model_name}_yy.npy", grid_b) np.save(f"{save_dir}/{model_name}_zz.npy", loss_surface) plt.figure(figsize=(10, 10)) plt.contour(grid_a, grid_b, loss_surface) plt.savefig(f"{save_dir}/{model_name}_contour_res_{resolution}.png", dpi=100) plt.close() generate_plots( xx=grid_a, yy=grid_b, zz=loss_surface, model_name=model_name, results_loc=save_dir, )