EnYa32 commited on
Commit
b138cbf
·
verified ·
1 Parent(s): 6e2976c

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +143 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,146 @@
1
- import altair as alt
 
 
2
  import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import json
2
+ from pathlib import Path
3
+
4
  import numpy as np
 
5
  import streamlit as st
6
+ from PIL import Image
7
+ import tensorflow as tf
8
+ import matplotlib.pyplot as plt
9
+
10
+
11
+ # -------------------------
12
+ # Page config
13
+ # -------------------------
14
+ st.set_page_config(
15
+ page_title='Facial Keypoints Predictor (CNN)',
16
+ page_icon='😊',
17
+ layout='centered'
18
+ )
19
+
20
+ st.title('😊 Facial Keypoints Predictor (CNN)')
21
+ st.write('Upload a face image and the CNN predicts 30 facial keypoint coordinates (x/y).')
22
+
23
+
24
+ # -------------------------
25
+ # Paths (HF-friendly: repo root)
26
+ # -------------------------
27
+ BASE_DIR = Path(__file__).resolve().parent
28
+
29
+ MODEL_PATH = BASE_DIR / 'final_keypoints_cnn.keras'
30
+ TARGET_COLS_PATH = BASE_DIR / 'target_cols.json'
31
+ PREPROCESS_PATH = BASE_DIR / 'preprocess_config.json'
32
+
33
+
34
+ # -------------------------
35
+ # Loaders
36
+ # -------------------------
37
+ @st.cache_resource
38
+ def load_model():
39
+ if not MODEL_PATH.exists():
40
+ raise FileNotFoundError(
41
+ f'Model not found: {MODEL_PATH.name}. Put it in the repo root (same folder as app.py).'
42
+ )
43
+ return tf.keras.models.load_model(MODEL_PATH, compile=False)
44
+
45
+
46
+ @st.cache_data
47
+ def load_json(path: Path):
48
+ if not path.exists():
49
+ raise FileNotFoundError(
50
+ f'File not found: {path.name}. Put it in the repo root (same folder as app.py).'
51
+ )
52
+ with open(path, 'r') as f:
53
+ return json.load(f)
54
+
55
+
56
+ model = load_model()
57
+ target_cols = load_json(TARGET_COLS_PATH)
58
+ pre_cfg = load_json(PREPROCESS_PATH)
59
+
60
+ IMG_H, IMG_W = pre_cfg.get('img_size', [96, 96])
61
+
62
+
63
+ # -------------------------
64
+ # Preprocess + Postprocess
65
+ # -------------------------
66
+ def preprocess_image(pil_img: Image.Image) -> np.ndarray:
67
+ # Convert to grayscale like training data
68
+ img = pil_img.convert('L')
69
+ img = img.resize((IMG_W, IMG_H))
70
+ arr = np.array(img, dtype=np.float32) # 0..255
71
+ arr = arr / 255.0
72
+ arr = arr.reshape(1, IMG_H, IMG_W, 1)
73
+ return arr
74
+
75
+
76
+ def denormalize_keypoints(pred_norm: np.ndarray) -> np.ndarray:
77
+ # Training normalization: (y - 48) / 48 -> invert: y = pred*48 + 48
78
+ pred = pred_norm * 48.0 + 48.0
79
+ return pred
80
+
81
+
82
+ def plot_keypoints(pil_img: Image.Image, keypoints_xy: np.ndarray):
83
+ img = pil_img.convert('L').resize((IMG_W, IMG_H))
84
+ xs = keypoints_xy[0::2]
85
+ ys = keypoints_xy[1::2]
86
+
87
+ fig, ax = plt.subplots(figsize=(5, 5))
88
+ ax.imshow(img, cmap='gray')
89
+ ax.scatter(xs, ys, s=25)
90
+ ax.axis('off')
91
+ return fig
92
+
93
+
94
+ # -------------------------
95
+ # UI
96
+ # -------------------------
97
+ uploaded = st.file_uploader('Upload an image (jpg/png)', type=['jpg', 'jpeg', 'png'])
98
+
99
+ with st.expander('Settings', expanded=False):
100
+ show_table = st.checkbox('Show predicted coordinates table', value=True)
101
+ show_overlay = st.checkbox('Show keypoints overlay', value=True)
102
+
103
+ if uploaded is None:
104
+ st.info('Upload an image to get predictions.')
105
+ st.stop()
106
+
107
+ pil_img = Image.open(uploaded)
108
+
109
+ st.subheader('Input Image')
110
+ st.image(pil_img, use_container_width=True)
111
+
112
+ x = preprocess_image(pil_img)
113
+
114
+ with st.spinner('Predicting keypoints...'):
115
+ pred_norm = model.predict(x, verbose=0)[0] # shape (30,)
116
+ pred_px = denormalize_keypoints(pred_norm)
117
+
118
+ # Build table
119
+ rows = []
120
+ for i, name in enumerate(target_cols):
121
+ rows.append({'feature': name, 'value_px': float(pred_px[i])})
122
+
123
+ if show_table:
124
+ st.subheader('Predicted Keypoints (Pixels)')
125
+ st.dataframe(rows, use_container_width=True)
126
+
127
+ if show_overlay:
128
+ st.subheader('Overlay')
129
+ fig = plot_keypoints(pil_img, pred_px)
130
+ st.pyplot(fig, clear_figure=True)
131
+
132
+ # Optional: downloadable single-image "submission-like" csv (ImageId=1)
133
+ st.subheader('Download (optional)')
134
+ st.caption('This creates a single-image submission-like file: ImageId=1, RowId=1..30.')
135
+
136
+ csv_lines = ['RowId,ImageId,FeatureName,Location']
137
+ for idx, (name, val) in enumerate(zip(target_cols, pred_px), start=1):
138
+ csv_lines.append(f'{idx},1,{name},{val}')
139
+ csv_text = '\n'.join(csv_lines)
140
 
141
+ st.download_button(
142
+ label='Download predicted_keypoints.csv',
143
+ data=csv_text.encode('utf-8'),
144
+ file_name='predicted_keypoints.csv',
145
+ mime='text/csv'
146
+ )