bcueva commited on
Commit
9f13279
Β·
verified Β·
1 Parent(s): 7d64923

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib, shutil, zipfile, tempfile
2
+ import pandas
3
+ import gradio
4
+ import huggingface_hub
5
+ import autogluon.multimodal
6
+ import PIL.Image
7
+
8
+ MODEL_REPO_ID = "george2cool36/hw2_image_automl_autogluon"
9
+ ZIP_FILENAME = "ag_image_predictor_dir.zip"
10
+ CACHE_DIR = pathlib.Path("hf_assets")
11
+ EXTRACT_DIR = CACHE_DIR / "predictor_native"
12
+
13
+ CLASS_LABELS = {0: "πŸ›‘ Has Stop Sign", 1: "βœ… No Stop Sign"}
14
+
15
+ def _prepare_predictor_dir() -> str:
16
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
17
+ local_zip = huggingface_hub.hf_hub_download(
18
+ repo_id=MODEL_REPO_ID,
19
+ filename=ZIP_FILENAME,
20
+ repo_type="model",
21
+ local_dir=str(CACHE_DIR),
22
+ local_dir_use_symlinks=False,
23
+ )
24
+ if EXTRACT_DIR.exists():
25
+ shutil.rmtree(EXTRACT_DIR)
26
+ EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
27
+ with zipfile.ZipFile(local_zip, "r") as zf:
28
+ zf.extractall(str(EXTRACT_DIR))
29
+ contents = list(EXTRACT_DIR.iterdir())
30
+ predictor_root = contents[0] if (len(contents) == 1 and contents[0].is_dir()) else EXTRACT_DIR
31
+ return str(predictor_root)
32
+
33
+ PREDICTOR_DIR = _prepare_predictor_dir()
34
+ PREDICTOR = autogluon.multimodal.MultiModalPredictor.load(PREDICTOR_DIR)
35
+
36
+ def _human_label(c):
37
+ try:
38
+ ci = int(c)
39
+ return CLASS_LABELS.get(ci, str(c))
40
+ except Exception:
41
+ return CLASS_LABELS.get(c, str(c))
42
+
43
+ def do_predict(pil_img: PIL.Image.Image):
44
+ if pil_img is None:
45
+ return {}, "No image provided."
46
+
47
+ tmpdir = pathlib.Path(tempfile.mkdtemp())
48
+ img_path = tmpdir / "input.png"
49
+ pil_img.save(img_path)
50
+
51
+ df = pandas.DataFrame({"image": [str(img_path)]})
52
+
53
+ proba_df = PREDICTOR.predict_proba(df)
54
+ proba_df = proba_df.rename(columns={0: "πŸ›‘ Has Stop Sign (0)", 1: "βœ… No Stop Sign (1)"})
55
+ row = proba_df.iloc[0]
56
+
57
+ pretty_dict = {
58
+ "πŸ›‘ Has Stop Sign": float(row.get("πŸ›‘ Has Stop Sign (0)", 0.0)),
59
+ "βœ… No Stop Sign": float(row.get("βœ… No Stop Sign (1)", 0.0)),
60
+ }
61
+
62
+ predicted_class = PREDICTOR.predict(df).iloc[0]
63
+ pred_label = _human_label(predicted_class)
64
+
65
+ md = f"**Prediction:** {pred_label}"
66
+ if pretty_dict:
67
+ md += f" \n**Confidence:** {round(pretty_dict.get(pred_label, 0.0) * 100, 2)}%"
68
+
69
+
70
+ return pretty_dict, md
71
+
72
+ EXAMPLES = [
73
+ ["https://www.kingsrivercasting.com/images/stories/virtuemart/product/STOP%20SIGN%20(5).jpg"],
74
+ ["https://www.trafficsafetywarehouse.com/Resources/images/traffic-sign-shapes.jpeg"],
75
+ ["https://di-uploads-pod16.dealerinspire.com/toyotaofnorthcharlotte/uploads/2020/08/yield-road-sign.jpg"]
76
+ ]
77
+
78
+
79
+ with gradio.Blocks() as demo:
80
+ gradio.Markdown("# Has Stop Sign or Not?")
81
+ gradio.Markdown(
82
+ "This is a simple app that demonstrates how to use an autogluon multimodal"
83
+ "predictor in a gradio space to predict whether an image contains a stop sign. To use,"
84
+ "just upload a photo. The result should be generated automatically."
85
+ )
86
+
87
+ image_in = gradio.Image(type="pil", label="Input image", sources=["upload", "webcam"])
88
+
89
+ proba_pretty = gradio.Label(num_top_classes=2, label="Class probabilities")
90
+ prediction_output = gradio.Markdown()
91
+
92
+
93
+ inputs = [image_in]
94
+ outputs = [proba_pretty, prediction_output]
95
+ for comp in inputs:
96
+ comp.change(fn=do_predict, inputs=inputs, outputs=outputs)
97
+
98
+ gradio.Examples(
99
+ examples=EXAMPLES,
100
+ inputs=inputs,
101
+ label="Representative examples",
102
+ examples_per_page=8,
103
+ cache_examples=False,
104
+ )
105
+
106
+ if __name__ == "__main__":
107
+ demo.launch(debug=False)