| | |
| | |
| | |
| | import panel as pn |
| | import os, yaml |
| | from panel.viewable import Viewer |
| |
|
| | from app_components import canvas, plots |
| | from app_utils import styles |
| |
|
| | pn.extension('plotly') |
| | FILE_PATH = os.path.dirname(__file__) |
| |
|
| |
|
| | |
| | |
| | |
| | class DigitClassifier(Viewer): |
| | ''' |
| | Builds and displays the UI for the classifier application. |
| | |
| | Args: |
| | mod_path (str): The absolute path to the saved TinyVGG model |
| | mod_kwargs (dict): A dictionary containing the keyword-arguments for the TinyVGG model. |
| | This should have the keys: num_blks, num_convs, in_channels, hidden_channels, and num_classes |
| | ''' |
| |
|
| | def __init__(self, mod_path: str, mod_kwargs: dict, **params): |
| | self.canvas = canvas.Canvas(sizing_mode = 'stretch_both', |
| | styles = {'border':'black solid 0.15rem'}) |
| | |
| | self.clear_btn = pn.widgets.Button(name = 'Clear', |
| | sizing_mode = 'stretch_width', |
| | stylesheets = [styles.BTN_STYLESHEET]) |
| |
|
| | self.plot_panels = plots.PlotPanels(canvas_info = self.canvas, mod_path = mod_path, mod_kwargs = mod_kwargs) |
| | |
| | super().__init__(**params) |
| | self.github_logo = pn.pane.PNG( |
| | object = FILE_PATH + '/assets/github-mark-white.png', |
| | alt_text = 'GitHub Repo', |
| | link_url = 'https://github.com/Jechen00/digit-classifier-app', |
| | height = 70, |
| | styles = {'margin':'0'} |
| | ) |
| | self.controls_col = pn.FlexBox( |
| | self.github_logo, |
| | self.clear_btn, |
| | self.plot_panels.pred_txt, |
| | gap = '60px', |
| | flex_direction = 'column', |
| | justify_content = 'center', |
| | align_items = 'center', |
| | flex_wrap = 'nowrap', |
| | styles = {'width':'40%', 'height':'100%'} |
| | ) |
| |
|
| | self.mod_input_txt = pn.pane.HTML( |
| | object = ''' |
| | <div> |
| | <b>MODEL INPUT</b> |
| | </div> |
| | ''', |
| | styles = {'margin':'0rem', 'padding-left':'0.15rem', 'color':'white', |
| | 'font-size':styles.FONTSIZES['mod_input_txt'], |
| | 'font-family':styles.FONTFAMILY, |
| | 'position':'absolute', 'z-index':'100'} |
| | ) |
| |
|
| | self.img_row = pn.FlexBox( |
| | self.canvas, |
| | self.controls_col, |
| | pn.FlexBox(self.mod_input_txt, |
| | self.plot_panels.img_pane, |
| | sizing_mode = 'stretch_both', |
| | styles = {'border':'solid 0.15rem white'}), |
| | gap = '1%', |
| | flex_wrap = 'nowrap', |
| | flex_direction = 'row', |
| | justify_content = 'center', |
| | sizing_mode = 'stretch_width', |
| | styles = {'height':'60%'} |
| | ) |
| |
|
| | self.prob_row = pn.FlexBox(self.plot_panels.prob_pane, |
| | sizing_mode = 'stretch_width', |
| | styles = {'height':'40%', |
| | 'border':'solid 0.15rem black'}) |
| |
|
| | self.page_info = pn.pane.HTML( |
| | object = f''' |
| | <style> |
| | .link {{ |
| | color: rgb(29, 161, 242); |
| | text-decoration: none; |
| | transition: text-decoration 0.2s ease; |
| | }} |
| | |
| | .link:hover {{ |
| | text-decoration: underline; |
| | }} |
| | </style> |
| | |
| | <div style="text-align:center; font-size:{styles.FONTSIZES['sidebar_title']};margin-top:0.2rem"> |
| | <b>Digit Classifier</b> |
| | </div> |
| | |
| | <div style="padding:0 2.5% 0 2.5%; text-align:left; font-size:{styles.FONTSIZES['sidebar_txt']}; width: 100%;"> |
| | <hr style="height:2px; background-color:rgb(200, 200, 200); border:none; margin-top:0"> |
| | <p style="margin:0"> |
| | This is a handwritten digit classifier that uses a <i>convolutional neural network (CNN)</i> |
| | to make predictions. The architecture of the model is a scaled-down version of |
| | the <i>Visual Geometry Group (VGG)</i> architecture from the paper: |
| | <a href="https://arxiv.org/pdf/1409.1556" |
| | class="link" |
| | target="_blank" |
| | rel="noopener noreferrer"> |
| | Very Deep Convolutional Networks for Large-Scale Image Recognition</a>. |
| | </p> |
| | </br> |
| | <p style="margin:0"> |
| | <b>How To Use:</b> Draw a digit (0-9) on the canvas |
| | and the model will produce a prediction for it in real time. |
| | Prediction probabilities (or confidences) for each digit are displayed in the bar chart, |
| | reflecting the model's softmax output distribution. |
| | To the right of the canvas, you'll also find the transformed input image, i.e. the canvas drawing after undergoing |
| | <a href="https://paperswithcode.com/dataset/mnist" |
| | class="link" |
| | target="_blank" |
| | rel="noopener noreferrer"> |
| | MNIST preprocessing</a>. |
| | This input image represents what the model receives prior to feature extraction and classification. |
| | </p> |
| | </br> |
| | <p style="margin:0"> |
| | <b>Note:</b> Due to resource limitations on HF Spaces (CPU basic), performance may vary. |
| | For optimal experience, it's recommended to run the app locally. |
| | </p> |
| | </div> |
| | ''', |
| | styles = {'margin':' 0rem', 'color': styles.CLRS['sidebar_txt'], |
| | 'width': '19.7%', 'height': '100%', |
| | 'font-family': styles.FONTFAMILY, |
| | 'background-color': styles.CLRS['sidebar'], |
| | 'overflow-y':'scroll', |
| | 'border': 'solid 0.15rem black'} |
| | ) |
| |
|
| | self.classifier_content = pn.FlexBox( |
| | self.img_row, |
| | self.prob_row, |
| | gap = '0.5%', |
| | flex_direction = 'column', |
| | flex_wrap = 'nowrap', |
| | sizing_mode = 'stretch_height', |
| | styles = {'width': '80%'} |
| | ) |
| |
|
| | self.page_content = pn.FlexBox( |
| | self.page_info, |
| | self.classifier_content, |
| | gap = '0.3%', |
| | flex_direction = 'row', |
| | justify_content = 'space-around', |
| | align_items = 'center', |
| | flex_wrap = 'nowrap', |
| | styles = { |
| | 'height':'100%', |
| | 'width':'100vw', |
| | 'padding': '1%', |
| | 'min-width': '1200px', |
| | 'min-height': '600px', |
| | 'max-width': '3600px', |
| | 'max-height': '1800px', |
| | 'background-color': styles.CLRS['page_bg'] |
| | }, |
| | ) |
| |
|
| | |
| | self.page_layout = pn.FlexBox( |
| | self.page_content, |
| | justify_content = 'center', |
| | flex_wrap = 'nowrap', |
| | sizing_mode = 'stretch_both', |
| | styles = { |
| | 'min-width': 'max-content', |
| | 'background-color': styles.CLRS['page_bg'], |
| | } |
| | ) |
| | |
| | self.clear_btn.on_click(self.canvas.toggle_clear) |
| | |
| | def __panel__(self): |
| | ''' |
| | Returns the main layout of the application to be rendered by Panel. |
| | ''' |
| | return self.page_layout |
| |
|
| |
|
| | def create_app(): |
| | ''' |
| | Creates the application, ensuring that each user gets a different instance of digit_classifier. |
| | Mostly used to keep things away from a global scope. |
| | ''' |
| | |
| | save_dir = FILE_PATH + '/saved_models' |
| | base_name = 'tiny_vgg_less_compute' |
| |
|
| | mod_path = f'{save_dir}/{base_name}_model.pth' |
| | settings_path = f'{save_dir}/{base_name}_settings.yaml' |
| |
|
| | |
| | with open( settings_path, 'r') as f: |
| | loaded_settings = yaml.load(f, Loader = yaml.FullLoader) |
| |
|
| | mod_kwargs = loaded_settings['mod_kwargs'] |
| |
|
| | digit_classifier = DigitClassifier(mod_path = mod_path, mod_kwargs = mod_kwargs) |
| | return digit_classifier |
| |
|
| | |
| | |
| | |
| | |
| | create_app().servable(title = 'CNN Digit Classifier') |
| |
|