| | import streamlit as st |
| | from tensorflow import keras |
| | import os |
| | import matplotlib.pyplot as plt |
| | from io import BytesIO |
| | from NNVisualiser import NNVisualiser |
| | import glob |
| | import inspect |
| | from tensorflow.keras.models import save_model |
| | import tempfile |
| | import re |
| | import zipfile |
| | import io |
| |
|
| | |
def create_zip_of_png_files(directory=None):
    """Bundle every ``.png`` file in a directory into an in-memory ZIP archive.

    Args:
        directory: Folder to scan (non-recursively). Defaults to the current
            working directory, which is where this app globs for plot PNGs.

    Returns:
        io.BytesIO: The ZIP archive, rewound to position 0 so it can be read
        immediately. Entries are stored under their bare file names, in
        sorted order.
    """
    if directory is None:
        directory = os.getcwd()

    # Sort so the archive is deterministic (os.listdir order is arbitrary),
    # and skip anything that is not a regular file, e.g. a folder "x.png".
    png_files = sorted(
        name
        for name in os.listdir(directory)
        if name.endswith('.png') and os.path.isfile(os.path.join(directory, name))
    )

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, 'w') as zip_file:
        for png_file in png_files:
            zip_file.write(os.path.join(directory, png_file), arcname=png_file)

    # Rewind so callers can read the archive from the start.
    zip_buffer.seek(0)
    return zip_buffer
| |
|
def generate_title_from_method_name(method_name):
    """Turn a camelCase plot method name into a human-readable page title.

    A leading ``"plot"`` prefix is dropped, then each capitalised word
    (an uppercase letter followed by lowercase letters) is kept and
    space-joined. Digits and other characters are discarded.

    Example: ``"plotNeuronOutput"`` -> ``"Plotting Neuron Output Plot "``.
    """
    stem = method_name[len("plot"):] if method_name.startswith("plot") else method_name
    capitalised_words = re.findall(r'[A-Z][a-z]*', stem)
    return "Plotting {} Plot ".format(" ".join(capitalised_words))
| |
|
def downloadKerasModel(keras_model=None):
    """Serialise a Keras model to ``.keras`` format and return the raw bytes.

    Args:
        keras_model: Model to serialise. Defaults to the module-level
            ``model`` loaded by this script, so the original zero-argument
            call keeps working.

    Returns:
        bytes: Contents of the saved ``.keras`` file.
    """
    if keras_model is None:
        keras_model = model

    # Save into a private temporary directory that is removed on exit.
    # The previous NamedTemporaryFile(delete=False) was never deleted, which
    # leaked one file per rerun. Saving to a path whose handle is still open
    # also cannot work on Windows.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "model.keras")
        save_model(keras_model, model_path)
        with open(model_path, "rb") as model_file:
            return model_file.read()
| |
|
| | |
| | |
def generate_folder_hierarchy(root_folder, max_depth=7):
    """Build a nested dict mirroring the directory tree under *root_folder*.

    Each key is a directory name and each value is a dict of that
    directory's subdirectories; leaf directories map to ``{}``. Files are
    ignored. Hidden directories (names starting with '.') and directories
    named '1' are skipped together with everything beneath them.

    Args:
        root_folder: Directory to scan. The root itself is not a key.
        max_depth: Deepest nesting level to include, where the immediate
            subdirectories of *root_folder* are level 1.

    Returns:
        dict: The nested folder mapping.
    """
    folder_dict = {}

    for dirpath, dirnames, _filenames in os.walk(root_folder):
        rel_path = os.path.relpath(dirpath, root_folder)
        # Root is level 0, "a" is level 1, "a/b" is level 2, ...
        depth = 0 if rel_path == os.curdir else rel_path.count(os.sep) + 1

        if depth > max_depth:
            # Only reachable for the root when max_depth < 0.
            dirnames[:] = []
            continue

        if depth == max_depth:
            # Children would exceed max_depth. Stop os.walk from descending
            # instead of walking the whole subtree and discarding it.
            dirnames[:] = []
        else:
            # Prune in place so os.walk never visits hidden or '1' folders.
            dirnames[:] = [d for d in dirnames if not d.startswith('.') and d != '1']

        if depth == 0:
            continue  # the root folder is not itself a key

        sub_dict = folder_dict
        for part in rel_path.split(os.sep):
            sub_dict = sub_dict.setdefault(part, {})

    return folder_dict
| |
|
@st.cache_data
def getPlotMethods():
    """Return the names of NNVisualiser's functions whose names start with 'plot'.

    The result is cached by Streamlit so the class is only introspected once.
    """
    visualiser_functions = inspect.getmembers(NNVisualiser, inspect.isfunction)
    plot_names = []
    for function_name, _function in visualiser_functions:
        if function_name.startswith('plot'):
            plot_names.append(function_name)
    return plot_names
| |
|
| | |
# Page header. It does not depend on the folder scan, so render it first.
st.title("Repository : Simple ANN Models with UAT Architecture")
st.write("A Collection of ANN Models with a 1-xReLU-1 Architecture for Basic 1D Functions on Bounded Intervals")

# Index the folders under the working directory. The sidebar selectors are
# populated from this nested dict.
root_folder = os.getcwd()
folder_hierarchy = generate_folder_hierarchy(root_folder)
| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
# Cascading sidebar selectors: each choice narrows the folder tree offered to
# the next selector. The order matches the path joined into modelPath:
# <repo>/<initialisation>/<sample size>/<batch size>/<epochs>/<function>/<neurons>.
_selector_labels = (
    "Select Model Repository",
    "Select Initialisation",
    "Select Sample Size",
    "Select Batch Size",
    "Select Epochs Count",
    "Select Function",
    "Select Neurons Count",
)
_selections = []
_level = folder_hierarchy
for _label in _selector_labels:
    _choice = st.sidebar.selectbox(_label, list(_level.keys()))
    if _choice is None:
        # selectbox returns None when it has no options. Indexing the tree
        # with None would crash, so stop the run with a readable message.
        st.error(f"No folders available for '{_label}' under the current selection.")
        st.stop()
    _selections.append(_choice)
    _level = _level[_choice]

repo, initialisation, sampleSize, batchSize, epochs, functions, neurons = _selections

st.write(f"You selected: {repo} : {initialisation} : {sampleSize} : {batchSize} : {epochs} : {functions} : {neurons}")

# Resolve against the same root the folder hierarchy was scanned from.
modelPath = os.path.join(root_folder, *_selections)
model = keras.models.load_model(modelPath)

visualiser = NNVisualiser(model)
# NOTE(review): presumably makes the plot methods write PNG files into the
# working directory, which the later glob("*.png") relies on — confirm in
# NNVisualiser.
visualiser.setSavePlots(True)
| |
|
| | |
def get_layer_info(model):
    """Summarise each layer of *model* as a small dict.

    Returns a list with one entry per element of ``model.layers``, in order.
    Each entry holds the layer's position ('index'), its class name ('type')
    and its ``units`` attribute ('units', or None when the layer has none).
    """
    return [
        {
            'index': position,
            'type': layer.__class__.__name__,
            'units': getattr(layer, 'units', None),
        }
        for position, layer in enumerate(model.layers)
    ]
| |
|
layer_info = get_layer_info(model)

# Parallel lists derived from layer_info. Position i holds layer i's index and
# its 'units' value (None for layers without a units attribute).
layer_indices, neuron_counts = [], []
for layer_entry in layer_info:
    layer_indices.append(layer_entry['index'])
    neuron_counts.append(layer_entry['units'])
| |
|
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
plotMethods = getPlotMethods()
selectedPlotMethod = st.sidebar.selectbox("Select Plot", plotMethods)
if selectedPlotMethod is None:
    # No plot* methods were found. The title helper and getattr below would
    # both fail on None.
    st.error("No plot methods were found on NNVisualiser.")
    st.stop()

# Delete PNGs left over from the previous run, so only the current plot's
# files are displayed and zipped below.
image_files = glob.glob("*.png")
for stale_png in image_files:
    try:
        os.remove(stale_png)
    except OSError:
        st.write("Error in removing previous plots")

st.session_state.title_text = generate_title_from_method_name(selectedPlotMethod)
st.title(st.session_state.title_text)

visualiser.setSavePlots(True)
method = getattr(visualiser, selectedPlotMethod, None)

if method is not None:
    if 'Neuron' in selectedPlotMethod:
        # Only layers with a truthy 'units' have addressable neurons. Offering
        # any other layer would make range(None) raise below.
        neuron_layer_indices = [
            index for index, units in zip(layer_indices, neuron_counts) if units
        ]
        if not neuron_layer_indices:
            st.warning("The loaded model has no layers with neurons to plot.")
        else:
            selected_layer_index = st.sidebar.selectbox("Select Layer Index", neuron_layer_indices)
            # layer_info indices are sequential, so the index is also the list position.
            selected_layer_units = neuron_counts[selected_layer_index]
            neuron_indices = list(range(selected_layer_units))
            selected_neuron_index = st.sidebar.selectbox("Select Neuron Index", neuron_indices)
            method(selected_layer_index, selected_neuron_index)
    elif 'Layer' in selectedPlotMethod:
        selected_layer_index = st.sidebar.selectbox("Select Layer Index", layer_indices)
        method(selected_layer_index)
    else:
        method()

# Store the download payloads for the downloads fragment.
st.session_state.kerasModelToDownload = downloadKerasModel()
st.session_state.plotsToDownload = create_zip_of_png_files()
| |
|
@st.fragment()
def downloads():
    """Render the 'Download Model' and 'Download Plots' buttons.

    Reuses the payloads stored in ``st.session_state`` under
    'kerasModelToDownload' and 'plotsToDownload' rather than re-saving the
    model and re-zipping the PNGs each time the fragment renders. If those
    entries are missing, it builds them on demand.
    """
    model_bytes = st.session_state.get("kerasModelToDownload")
    if model_bytes is None:
        model_bytes = downloadKerasModel()

    plots_zip = st.session_state.get("plotsToDownload")
    if plots_zip is None:
        plots_zip = create_zip_of_png_files()

    st.download_button(
        label="Download Model",
        data=model_bytes,
        file_name="model.keras",
        mime="application/octet-stream",
    )

    st.download_button(
        label="Download Plots",
        # getvalue() returns the whole archive regardless of the buffer's
        # current read position, so a reused buffer is always complete.
        data=plots_zip.getvalue(),
        file_name="images.zip",
        mime="application/zip",
    )
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
# Render the download buttons inside the sidebar.
with st.sidebar:
    downloads()

# Show the PNGs produced by the selected plot method. glob's result order
# depends on the filesystem, so sort it to keep the display order stable
# across reruns.
image_files = sorted(glob.glob("*.png"))
if image_files:
    st.image(image_files)
else:
    st.info("The selected plot did not produce any images.")
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|