
Mlvtk

Create 3D loss surface visualizations that include the optimizer path.


A loss surface visualization tool
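A loss surface plot like the ones below is generally produced by evaluating the loss on a 2D slice of the high-dimensional weight space. You fix a center point θ (for example, the final weights) and two direction vectors d1 and d2, then compute L(θ + a·d1 + b·d2) over a grid of (a, b) values. The sketch below illustrates that general technique with NumPy on a toy least-squares model. It is not mlvtk's implementation, and the names `loss_surface` and `mse_loss` are hypothetical.

```python
import numpy as np

def mse_loss(w, X, y):
    """Mean squared error of a linear model with weight vector w."""
    residual = X @ w - y
    return float(np.mean(residual ** 2))

def loss_surface(loss_fn, center, d1, d2, span=1.0, steps=21):
    """Evaluate loss_fn(center + a*d1 + b*d2) on a steps x steps grid.

    Returns the grid coordinates (A, B) and the loss values Z, each of
    shape (steps, steps), ready to pass to a 3D surface plot.
    """
    alphas = np.linspace(-span, span, steps)
    betas = np.linspace(-span, span, steps)
    A, B = np.meshgrid(alphas, betas)
    Z = np.empty_like(A)
    for i in range(steps):
        for j in range(steps):
            Z[i, j] = loss_fn(center + A[i, j] * d1 + B[i, j] * d2)
    return A, B, Z

# Toy data whose targets are an exact linear function of the inputs.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
true_w = rng.normal(size=5)
y = X @ true_w

# Center the slice on the least-squares solution and use two random directions.
center, *_ = np.linalg.lstsq(X, y, rcond=None)
d1 = rng.normal(size=5)
d2 = rng.normal(size=5)
A, B, Z = loss_surface(lambda w: mse_loss(w, X, y), center, d1, d2)
# Z[10, 10] (a = b = 0) is the minimum of this convex slice.
```

With a real network, `center` would be the flattened trained weights and `loss_fn` a pass over the data, so every grid point costs one full loss evaluation. For neural networks, the random directions are often rescaled per filter or per layer so that plots are comparable across models. This sketch omits that refinement.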

<img alt="Png" src="https://raw.githubusercontent.com/tm-schwartz/mlvtk/dev/visuals/elu-adam-chess.png" width="80%" />

Simple feed-forward network trained on chess data, using elu activation and Adam optimizer


<img alt="Gif" src="https://raw.githubusercontent.com/tm-schwartz/mlvtk/dev/visuals/gifs/tanh-sgd-chess.gif" width="80%" />

Simple feed-forward network trained on chess data, using tanh activation and SGD optimizer


<img alt="Gif" src="https://raw.githubusercontent.com/tm-schwartz/mlvtk/dev/visuals/gifs/500_500_500_sgd_relu_lettrs_2lr.gif" width="80%" />

Three-layer feed-forward network trained on handwritten-letters data, using relu activation, the SGD optimizer, and a learning rate of 2.0. It shows what happens to the optimizer path when the learning rate is too high.
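A one-dimensional quadratic shows most clearly why a large learning rate can throw the path off. For f(w) = (c/2)·w², a gradient-descent step is w ← w − lr·c·w = (1 − lr·c)·w. The iterates therefore shrink toward the minimum only when |1 − lr·c| < 1, that is, when 0 < lr < 2/c. The toy sketch below is pure Python. It uses a hypothetical `gd_path` helper and a curvature c = 1 chosen for illustration, and it records the path for a safe learning rate and for an unsafe one.

```python
def gd_path(lr, curvature=1.0, w0=1.0, steps=10):
    """Run gradient descent on f(w) = curvature / 2 * w**2 and return every iterate."""
    path = [w0]
    w = w0
    for _ in range(steps):
        grad = curvature * w      # f'(w)
        w = w - lr * grad         # same as multiplying w by (1 - lr * curvature)
        path.append(w)
    return path

stable = gd_path(lr=0.5)      # factor 1 - 0.5 = 0.5: shrinks toward 0
diverging = gd_path(lr=2.5)   # factor 1 - 2.5 = -1.5: flips sign and grows

print([round(w, 4) for w in stable])
print([round(w, 4) for w in diverging])
```

For a multi-dimensional quadratic, the safe range is set by the largest curvature (lr < 2/c_max). Past that point, the path bounces from one side of the valley to the other instead of settling, which matches the erratic path in the animation above.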


<img alt="Gif" src="https://raw.githubusercontent.com/tm-schwartz/mlvtk/dev/visuals/gifs/hard_sigmoid-rmsprop-chess.gif" width="80%" />

Simple feed-forward network trained on chess data, using hard-sigmoid activation and RMSprop optimizer
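The optimizer path drawn on each surface above is a projection of the weights saved during training onto a 2D plane. One common choice of plane is the top two singular directions of the trajectory's offsets from the final weights. This is a PCA-style choice that skips mean-centering so that the final point stays at the origin. It is an assumption here, not something taken from mlvtk's source. The NumPy sketch below uses a hypothetical `project_path` helper and a synthetic trajectory.

```python
import numpy as np

def project_path(weights):
    """Project a training trajectory onto its top two singular directions.

    weights: array of shape (n_checkpoints, n_params), one flattened weight
    vector per saved checkpoint (e.g. one per epoch), last row = final weights.
    Returns the 2D coordinates of every checkpoint, with the final weights at
    the origin, plus the two unit-length direction vectors.
    """
    final = weights[-1]
    offsets = weights - final                 # trajectory relative to the end point
    # Rows of vt are orthonormal directions, ordered by singular value.
    _, _, vt = np.linalg.svd(offsets, full_matrices=False)
    d1, d2 = vt[0], vt[1]
    coords = offsets @ np.stack([d1, d2], axis=1)   # shape (n_checkpoints, 2)
    return coords, d1, d2

# Synthetic trajectory in a 50-parameter space that spirals in toward a point.
rng = np.random.default_rng(1)
target = rng.normal(size=50)
basis = np.linalg.qr(rng.normal(size=(50, 2)))[0]   # two orthonormal columns
t = np.linspace(0, 3 * np.pi, 30)
radius = np.exp(-0.3 * t)
plane = np.stack([radius * np.cos(t), radius * np.sin(t)], axis=1)
trajectory = target + plane @ basis.T

coords, d1, d2 = project_path(trajectory)
```

These coordinates place each checkpoint on a plane spanned by d1 and d2. A loss grid evaluated along the same two directions then gives the surface the path is drawn on.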

Why?

  • :shipit: Simple: A single-line addition is all that is needed.
  • :question: Informative: Gain insight into what your model is seeing.
  • :notebook: Educational: See how your hyperparameters and architecture impact your model's perception.

Quick Start

Requires | Version
-------- | -------
python | >= 3.6.1
tensorflow | >= 2.3.1
plotly | >= 4.9.0
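The sketch below checks an existing environment against the requirements in the table above by comparing installed versions numerically. It assumes Python 3.8+ for `importlib.metadata`, which is newer than the 3.6.1 minimum listed. The helper names are hypothetical. Versions are compared as integer tuples because plain string comparison gets cases like "4.10.0" vs "4.9.0" wrong.

```python
import re
import sys
from importlib import metadata

REQUIREMENTS = {"tensorflow": "2.3.1", "plotly": "4.9.0"}

def parse_version(text):
    """Turn '4.10.0' (or '2.3.1rc0') into a tuple of leading integers, e.g. (4, 10, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"unrecognised version string: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))

def meets(installed, minimum):
    """True if installed >= minimum, compared component by component."""
    return parse_version(installed) >= parse_version(minimum)

def check_requirements():
    """Map each requirement to (installed version or None, requirement met?)."""
    report = {"python": (sys.version.split()[0], sys.version_info[:3] >= (3, 6, 1))}
    for name, minimum in REQUIREMENTS.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = (None, False)
        else:
            report[name] = (installed, meets(installed, minimum))
    return report

for name, (version, ok) in check_requirements().items():
    print(f"{name:<10} {version or 'not installed':<15} {'ok' if ok else 'missing or too old'}")
```

The regex keeps only the leading numeric part, so a pre-release such as "2.3.1rc0" is treated like "2.3.1". That is good enough for a quick sanity check, but it is not a full version-ordering scheme.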

Install locally (also works in Google Colab):

pip install mlvtk

Optionally, for use with Jupyter Notebook or JupyterLab:

Notebook

pip install "notebook>=5.3" "ipywidgets==7.5"

Lab

pip install jupyterlab "ipywidgets==7.5"

# Basic JupyterLab renderer support
jupyter labextension install jupyterlab-plotly@4.10.0

# OPTIONAL: Jupyter widgets extension for FigureWidget support
jupyter labextension install @jupyter-widgets/jupyterlab-manager plotlywidget@4.10.0

Basic Example

from mlvtk.base import Vmodel
import tensorflow as tf
import numpy as np

# NN with 1 hidden layer
inputs = tf.keras.layers.Input(shape=(None,100))
dense_1 = tf.keras.layers.Dense(50, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(10, activation='softmax')(dense_1)
_model = tf.keras.Model(inputs, outputs)

# Wrap with Vmodel
model = Vmodel(_model)
model.compile(optimizer=tf.keras.optimizers.SGD(),
              loss=tf.keras.losses.CategoricalCrossentropy(), metrics=['accuracy'])

# All tf.keras.(Model/Sequential/Functional) methods/properties are accessible
# from Vmodel

model.summary()
model.get_config()
model.get_weights()
model.layers

# Create random example data
x = np.random.rand(3, 10, 100)
y = np.random.randint(9, size=(3, 10, 10))
xval = np.random.rand(1, 10, 100)
yval = np.random.randint(9, size=(1,10,10))

# The only difference from tf.keras: model.fit requires validation_data
# (a tf.data.Dataset or another container)
history = model.fit(x, y, validation_data=(xval, yval), epochs=10, verbose=0)

# Calling model.surface_plot() returns a plotly.graph_objs.Figure
# model.surface_plot() will attempt to display the figure inline

fig = model.surface_plot()

# fig can save an interactive plot to an html file,
fig.write_html("surface_plot.html")

# or display the plot in jupyter notebook/lab or other compatible tool.
fig.show()