File size: 17,746 Bytes
a2cf83f
 
 
 
 
 
 
bddf009
 
aa965c4
c8e0302
a2cf83f
 
 
 
375bef0
 
 
a2cf83f
 
 
4ab9818
a2cf83f
 
 
 
 
 
 
 
4ab9818
 
 
a2cf83f
 
 
 
 
 
 
 
bddf009
a2cf83f
 
 
 
4ab9818
a2cf83f
 
 
 
 
 
 
 
 
 
 
 
 
4ab9818
a2cf83f
 
 
 
 
4ab9818
a2cf83f
 
 
 
 
 
 
4ab9818
 
 
 
a2cf83f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8e0302
a2cf83f
 
 
 
 
 
 
 
bddf009
a2cf83f
 
 
 
 
 
 
 
 
375bef0
4a31a2e
a2cf83f
 
 
 
 
 
 
bddf009
 
 
c8e0302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2cf83f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375bef0
4a31a2e
a2cf83f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bddf009
 
 
 
a2cf83f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bddf009
 
 
a2cf83f
 
bddf009
 
 
 
 
 
 
 
 
 
 
 
a2cf83f
bddf009
 
 
 
 
 
 
 
aa965c4
 
bddf009
 
a2cf83f
bddf009
 
 
a2cf83f
bddf009
 
 
 
 
 
 
a1b056c
bddf009
 
 
 
 
 
c8e0302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bddf009
c8e0302
bddf009
 
 
 
 
c8e0302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2cf83f
 
 
 
 
 
 
bddf009
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
import streamlit as st
import pandas as pd
import json
import os
import posixpath
from huggingface_hub import hf_hub_download
from huggingface_hub import list_repo_files
import io
import zipfile
import shutil
import tempfile, uuid

# Hugging Face dataset repository that hosts the PortPy patient data.
# Replace this with your actual Hugging Face repo ID.
REPO_ID = "PortPy-Project/PortPy_Dataset"

# Access token read from the HF_TOKEN environment variable (None when unset).
# It is passed as ``token=`` to the ``hf_hub_download`` calls in this module,
# which is needed when the dataset repo is private/gated.
token = os.getenv("HF_TOKEN")

@st.cache_data
def get_patient_ids():
    """Return a DataFrame of all patient IDs in the dataset and their disease sites.

    Downloads ``data_info.jsonl`` (one JSON object per line, each with a
    ``patient_id`` key) and derives ``disease_site`` from the patient ID prefix
    before the first underscore (e.g. ``Lung_Patient_1`` -> ``Lung``). IDs
    without an underscore get ``NaN`` as their disease site.

    Returns:
        pandas.DataFrame with columns ``patient_id`` and ``disease_site``.
    """
    file = hf_hub_download(REPO_ID, repo_type="dataset", filename="data_info.jsonl", token=token)
    with open(file, encoding="utf-8") as f:
        # Skip blank lines so a stray empty line does not make json.loads fail.
        data_info = [json.loads(line) for line in f if line.strip()]
    patient_ids = [pat["patient_id"] for pat in data_info]
    df = pd.DataFrame(patient_ids, columns=["patient_id"])
    df["disease_site"] = df["patient_id"].str.extract(r"^(.*?)_", expand=False)
    return df

@st.cache_data
def _list_all_repo_files():
    """Return the paths of all files in the dataset repo.

    Cached with ``st.cache_data`` so the listing is fetched once and reused by
    the per-patient beam-metadata lookups. The same ``token`` as the
    ``hf_hub_download`` calls is passed so the listing also works when the
    repo requires authentication.
    """
    return list_repo_files(repo_id=REPO_ID, repo_type="dataset", token=token)

@st.cache_data
def load_all_metadata(disease_site):
    """Collect structure, beam and planner-beam metadata for one disease site.

    Args:
        disease_site: Value of the ``disease_site`` column from
            ``get_patient_ids()`` used to select patients.

    Returns:
        dict mapping each matching patient ID to a dict with the keys
        ``"structures"``, ``"beams"`` and ``"planner_beam_ids"`` (the ``"IDs"``
        list from the patient's ``PlannerBeams.json``, or ``[]`` if absent).
    """
    all_patients = get_patient_ids()
    site_mask = all_patients["disease_site"] == disease_site

    result = {}
    for pid in all_patients.loc[site_mask, "patient_id"]:
        structures = load_structure_metadata(pid)
        beams = load_beam_metadata(pid)
        planner_path = hf_hub_download(
            REPO_ID,
            repo_type="dataset",
            filename=f"data/{pid}/PlannerBeams.json",
            token=token,
        )
        with open(planner_path) as fh:
            planner_ids = json.load(fh).get("IDs", [])
        result[pid] = {
            "structures": structures,
            "beams": beams,
            "planner_beam_ids": planner_ids,
        }
    return result

@st.cache_data
def load_structure_metadata(patient_id):
    """Download and parse ``data/<patient_id>/StructureSet_MetaData.json``.

    Returns:
        Whatever ``json.load`` produces for that file.
    """
    remote_name = f"data/{patient_id}/StructureSet_MetaData.json"
    local_path = hf_hub_download(REPO_ID, repo_type="dataset", filename=remote_name, token=token)
    with open(local_path) as handle:
        contents = json.load(handle)
    return contents

@st.cache_data
def load_beam_metadata(patient_id):
    """Return the parsed beam metadata JSON objects for *patient_id*.

    Selects every file in the cached repo listing that starts with
    ``data/<patient_id>/Beams/Beam_`` and ends with ``_MetaData.json``,
    downloads each one (without ``local_dir``) and parses it. The result
    follows the order of the repo listing.
    """
    prefix = f"data/{patient_id}/Beams/Beam_"
    suffix = "_MetaData.json"
    matching_paths = [
        path for path in _list_all_repo_files()
        if path.startswith(prefix) and path.endswith(suffix)
    ]

    collected = []
    for remote_path in matching_paths:
        cached_file = hf_hub_download(REPO_ID, repo_type="dataset", filename=remote_path, token=token)
        with open(cached_file) as fh:
            collected.append(json.load(fh))
    return collected

def get_patient_summary_from_cached_data(patient_id, all_metadata):
    """Summarise one patient's cached metadata.

    Args:
        patient_id: Key into *all_metadata*.
        all_metadata: Mapping (as built by ``load_all_metadata``) whose entry
            for *patient_id* has ``"structures"`` and ``"beams"`` lists.

    Returns:
        dict with ``"ptv_volume"`` (the ``volume_cc`` of the first structure
        whose name contains "PTV", case-insensitively; None if there is no such
        structure or it lacks ``volume_cc``), ``"num_beams"`` and ``"beams"``.
    """
    patient = all_metadata[patient_id]
    structures = patient["structures"]
    beam_list = patient["beams"]

    first_ptv = next(
        (entry for entry in structures if "PTV" in entry["name"].upper()),
        None,
    )
    volume = None if first_ptv is None else first_ptv.get("volume_cc")

    return {
        "ptv_volume": volume,
        "num_beams": len(beam_list),
        "beams": beam_list,
    }

def _parse_number_set(text):
    """Parse a comma-separated list of numbers into a set; None if nothing given.

    Blank entries (e.g. from a trailing comma) are ignored. Values are parsed
    as float so decimal angles are accepted; ``90 == 90.0`` with equal hashes,
    so integer metadata values still match.

    Raises:
        ValueError: if a non-blank entry is not a number.
    """
    if not text:
        return None
    values = {float(token_) for token_ in text.split(",") if token_.strip()}
    return values or None


def _parse_text_set(text):
    """Parse a comma-separated list of labels into a set; None if nothing given.

    All spaces are removed and blank entries are ignored.
    """
    if not text:
        return None
    values = {token_ for token_ in text.replace(" ", "").split(",") if token_}
    return values or None


def filter_matched_data(filtered_patients, query_ptv_vol, beam_gantry_filter,
                        beam_collimator_filter, beam_energy_filter, beam_couch_filter,
                        only_planner, all_metadata):
    """Select the patients and beams that satisfy the sidebar filters.

    Args:
        filtered_patients: DataFrame with a ``patient_id`` column.
        query_ptv_vol: Minimum PTV volume (cc). Patients with no PTV volume or
            a smaller one are dropped.
        beam_gantry_filter, beam_collimator_filter, beam_couch_filter:
            Comma-separated angles; an empty string means "no filter".
        beam_energy_filter: Comma-separated energies, matched as strings
            against each beam's ``energy_MV`` value (spaces ignored).
        only_planner: If true, keep only beams listed in the patient's
            ``planner_beam_ids``.
        all_metadata: Mapping as built by ``load_all_metadata``.

    Returns:
        DataFrame with columns ``patient_id``, ``num_beams``, ``ptv_volume``
        and ``selected_beam_ids``, with one row per patient that has at least
        one matching beam. The columns are present even when nothing matches,
        so callers can always index ``["patient_id"]``.

    Raises:
        ValueError: if an angle filter contains a non-numeric entry.
    """
    columns = ["patient_id", "num_beams", "ptv_volume", "selected_beam_ids"]

    # Applied in this order; a None/empty set means "do not filter on this key".
    numeric_filters = [
        ("gantry_angle", _parse_number_set(beam_gantry_filter)),
        ("collimator_angle", _parse_number_set(beam_collimator_filter)),
        ("couch_angle", _parse_number_set(beam_couch_filter)),
    ]
    energies = _parse_text_set(beam_energy_filter)

    matched = []
    for pid in filtered_patients["patient_id"]:
        # Retrieve metadata for the patient from the pre-cached all_metadata
        summary = get_patient_summary_from_cached_data(pid, all_metadata)
        ptv_volume = summary["ptv_volume"]
        if ptv_volume is None or ptv_volume < query_ptv_vol:
            continue

        selected_beams = summary["beams"]
        for key, allowed in numeric_filters:
            if allowed:
                selected_beams = [b for b in selected_beams if b[key] in allowed]
        if energies:
            selected_beams = [b for b in selected_beams if b["energy_MV"] in energies]

        selected_beam_ids = [b["ID"] for b in selected_beams]
        if only_planner:
            planner_ids = set(all_metadata[pid]["planner_beam_ids"])
            # Filter in metadata order rather than using set intersection, so
            # the resulting beam order is stable and matches the metadata.
            selected_beam_ids = [bid for bid in selected_beam_ids if bid in planner_ids]
        if not selected_beam_ids:
            continue

        matched.append({
            "patient_id": pid,
            "num_beams": len(selected_beam_ids),
            "ptv_volume": ptv_volume,
            "selected_beam_ids": selected_beam_ids,
        })

    return pd.DataFrame(matched, columns=columns)

def _download_with_retries(repo_id, hf_path, local_dir, max_retries, downloaded_files):
    """Download one dataset file, making up to *max_retries* attempts (at least one).

    On success, appends the local path to *downloaded_files* and returns it.
    After the last failed attempt, shows the error via ``st.error`` and
    returns None.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            local_path = hf_hub_download(
                repo_id=repo_id,
                repo_type="dataset",
                filename=hf_path,
                local_dir=local_dir,
                token=token,
            )
        except Exception as e:
            if attempt == attempts - 1:
                st.error(f"Failed to download {hf_path}: {e}")
        else:
            downloaded_files.append(local_path)
            return local_path
    return None


def download_data(repo_id, patient_ids, beam_ids=None, planner_beam_ids=True, max_retries=2, local_dir='./', download_dicom=True):
    """Download PortPy data files for the given patients from a Hub dataset repo.

    For each patient, this fetches:

    - the static CT / structure-set / optimization-voxel files and
      ``PlannerBeams.json``;
    - optionally, every file under ``data/<patient_id>/DicomFiles/``;
    - the ``Beams/Beam_<id>_Data.h5`` and ``Beams/Beam_<id>_MetaData.json``
      pair for each requested beam.

    Files are placed under *local_dir* using their repo-relative paths.
    Failures are reported with ``st.error`` and do not stop the remaining
    downloads.

    Args:
        repo_id: Hugging Face dataset repo ID.
        patient_ids: Iterable of patient IDs.
        beam_ids: Beam IDs to download for every patient. Ignored when
            *planner_beam_ids* is true. None downloads no beams.
        planner_beam_ids: If true, download the beams listed under ``"IDs"`` in
            each patient's ``PlannerBeams.json`` instead of *beam_ids*.
        max_retries: Attempts per file (at least one attempt is always made).
        local_dir: Root directory to download into.
        download_dicom: Whether to download the patient's DICOM files.

    Returns:
        List of local paths of the files that were downloaded successfully.
    """
    static_files = [
        "CT_Data.h5", "CT_MetaData.json",
        "StructureSet_Data.h5", "StructureSet_MetaData.json",
        "OptimizationVoxels_Data.h5", "OptimizationVoxels_MetaData.json",
        "PlannerBeams.json",
    ]
    downloaded_files = []
    # Repo listing, fetched lazily at most once (on success) and shared by all
    # patients instead of being re-listed per patient.
    repo_files = None

    for patient_id in patient_ids:
        # 1. Static per-patient files
        for filename in static_files:
            hf_path = posixpath.join("data", patient_id, filename)
            _download_with_retries(repo_id, hf_path, local_dir, max_retries, downloaded_files)

        # 2. All DICOM files under data/<patient_id>/DicomFiles/
        if download_dicom:
            try:
                if repo_files is None:
                    repo_files = list_repo_files(repo_id, repo_type="dataset", token=token)
                dicom_prefix = f"data/{patient_id}/DicomFiles/"
                for hf_path in repo_files:
                    if hf_path.startswith(dicom_prefix):
                        _download_with_retries(repo_id, hf_path, local_dir, max_retries, downloaded_files)
            except Exception as e:
                st.error(f"Error listing DICOM files for {patient_id}: {e}")

        # 3. Beams: per-patient variable so the caller's beam_ids is never overwritten
        patient_beam_ids = beam_ids
        if planner_beam_ids:
            planner_file = os.path.join(local_dir, "data", patient_id, "PlannerBeams.json")
            try:
                with open(planner_file, "r") as f:
                    patient_beam_ids = json.load(f).get("IDs", [])
            except Exception as e:
                st.error(f"Error reading PlannerBeams.json: {e}")
                patient_beam_ids = []

        if patient_beam_ids is not None:
            for bid in patient_beam_ids:
                for beam_file in (f"Beams/Beam_{bid}_Data.h5", f"Beams/Beam_{bid}_MetaData.json"):
                    hf_path = posixpath.join("data", patient_id, beam_file)
                    _download_with_retries(repo_id, hf_path, local_dir, max_retries, downloaded_files)

    return downloaded_files

from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode

def show_aggrid_table(df):
    """Render *df* as an AgGrid table with multi-row checkbox selection.

    Args:
        df: DataFrame to display (in this app, the result of
            ``filter_matched_data``).

    Returns:
        The response object from ``AgGrid``; the caller reads its
        ``"selected_rows"`` entry.
    """
    gb = GridOptionsBuilder.from_dataframe(df)
    gb.configure_default_column(groupable=True, value=True, enableRowGroup=True, aggFunc='sum', editable=False)
    gb.configure_grid_options(domLayout='normal')

    # Enable multiple row selection with checkboxes
    gb.configure_selection('multiple', use_checkbox=True)
    # Only configure the column if it exists, so an empty/column-less frame
    # does not get a phantom "patient_id" column definition.
    if "patient_id" in df.columns:
        gb.configure_column("patient_id", checkboxSelection=True)

    grid_options = gb.build()

    return AgGrid(
        df,
        gridOptions=grid_options,
        enable_enterprise_modules=False,
        # No JsCode callbacks are used, so don't let the grid execute
        # arbitrary JavaScript (least privilege).
        allow_unsafe_jscode=False,
        fit_columns_on_grid_load=True,
        theme='balham',
        update_mode=GridUpdateMode.SELECTION_CHANGED,
    )

def _selected_rows_as_dataframe(grid_response):
    """Return AgGrid's ``selected_rows`` as a DataFrame (empty if nothing is selected).

    Handles a DataFrame, a list of row dicts, or None. The type has varied
    across streamlit-aggrid versions; verify against the installed one.
    """
    rows = grid_response.get("selected_rows")
    if rows is None:
        return pd.DataFrame()
    if isinstance(rows, pd.DataFrame):
        return rows
    return pd.DataFrame(rows)


def _directory_size_bytes(path):
    """Return the total size in bytes of all files under *path*."""
    return sum(
        os.path.getsize(os.path.join(root, name))
        for root, _, files in os.walk(path)
        for name in files
    )


def _zip_directory(src_dir, zip_path):
    """Write every file under *src_dir* into an uncompressed (ZIP_STORED) archive.

    Archive member names are paths relative to *src_dir*.
    """
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_STORED, allowZip64=True) as zf:
        for root, _, files in os.walk(src_dir):
            for name in files:
                full_path = os.path.join(root, name)
                zf.write(full_path, os.path.relpath(full_path, src_dir))


def _download_and_offer_zip(results_df, to_download):
    """Download the chosen patients and offer them as one ZIP via st.download_button.

    Each call uses its own temporary directory and ZIP file, so concurrent
    sessions cannot delete or mix each other's downloads. Both are removed in
    the ``finally`` block, after the ZIP bytes have been read into memory.
    """
    progress = st.progress(0)
    status = st.empty()

    wanted = set(to_download)
    patient_to_beams = {
        row["patient_id"]: row["selected_beam_ids"]
        for _, row in results_df.iterrows()
        if row["patient_id"] in wanted
    }
    total = len(patient_to_beams)
    if total == 0:
        st.warning("No patients selected.")
        return

    local_dir = tempfile.mkdtemp(prefix="portpy_download_")
    zip_path = os.path.join(tempfile.gettempdir(), f"portpy_patients_{uuid.uuid4().hex}.zip")
    try:
        for i, (pid, beam_ids) in enumerate(patient_to_beams.items(), start=1):
            status.write(f"Downloading {pid} ({i}/{total})…")
            # selected_beam_ids already honours "only planner beams" (it is
            # intersected with the planner beams when filtering), so always pass
            # planner_beam_ids=False. Otherwise download_data would swap these
            # IDs for *all* planner beams and ignore the angle/energy filters.
            download_data(REPO_ID, [pid], beam_ids=beam_ids,
                          planner_beam_ids=False,
                          local_dir=local_dir, download_dicom=True)
            progress.progress(i / total)

        status.success("All downloads complete. Preparing zip…")

        # Guard the size to avoid crashes.
        total_gb = _directory_size_bytes(local_dir) / (1024 ** 3)
        status.write(f"Preparing zip (~{total_gb:.2f} GB)…")
        if total_gb > 40.0:
            st.error("Selection too large for a single zip. Please download fewer patients.")
            st.stop()  # the finally block still cleans up

        _zip_directory(local_dir, zip_path)
        # Read the bytes now so the temporary ZIP can be deleted in finally.
        with open(zip_path, "rb") as fp:
            zip_bytes = fp.read()
        st.download_button(
            label="Your download is ready! Click to save.",
            data=zip_bytes,
            file_name="portpy_patients.zip",
            mime="application/zip",
        )
    finally:
        shutil.rmtree(local_dir, ignore_errors=True)
        if os.path.exists(zip_path):
            os.remove(zip_path)


def main():
    """Streamlit entry point: filter PortPy patients by PTV/beam criteria and download them."""
    st.set_page_config(page_title="PortPy Metadata Explorer", layout="wide")
    st.title("📊 PortPy Metadata Explorer & Downloader")

    patient_df = get_patient_ids()
    disease_site = st.sidebar.selectbox("Select Disease Site", patient_df["disease_site"].unique())
    # Load and cache all metadata for the selected disease site.
    all_metadata = load_all_metadata(disease_site)
    filtered_patients = pd.DataFrame(list(all_metadata.keys()), columns=["patient_id"])

    beam_gantry_filter = st.sidebar.text_input("Gantry Angles (comma-separated)", "")
    beam_collimator_filter = st.sidebar.text_input("Collimator Angles (comma-separated)", "")
    beam_energy_filter = st.sidebar.text_input("Beam Energies (comma-separated)", "")
    beam_couch_filter = st.sidebar.text_input("Couch Angles (comma-separated)", "")
    query_ptv_vol = st.sidebar.number_input("Minimum PTV volume (cc):", value=0)

    # Checkbox: Only planner beams
    only_planner = st.sidebar.checkbox(
        "Show only planner beams (if selected it will download only planner beams)",
        value=True,
    )

    results_df = filter_matched_data(
        filtered_patients, query_ptv_vol, beam_gantry_filter,
        beam_collimator_filter, beam_energy_filter, beam_couch_filter,
        only_planner, all_metadata
    )

    # Summary table; skipped when nothing matches (the frame may then have no columns).
    if results_df.empty:
        st.info("No patients match the current filters.")
        selected_rows = pd.DataFrame()
    else:
        grid_response = show_aggrid_table(results_df)
        selected_rows = _selected_rows_as_dataframe(grid_response)

    for _, row in selected_rows.iterrows():
        pid = row["patient_id"]
        st.markdown(f"### Patient: {pid}")
        st.markdown("#### Structures")
        st.dataframe(pd.DataFrame(all_metadata[pid]["structures"]))
        st.markdown("#### Beams")
        st.dataframe(pd.DataFrame(all_metadata[pid]["beams"]))

    if "open_download_expander" not in st.session_state:
        st.session_state["open_download_expander"] = False
    with st.expander("Download matched patients", expanded=st.session_state["open_download_expander"]):
        patient_options = (
            results_df["patient_id"].tolist() if "patient_id" in results_df.columns else []
        )
        to_download = st.sidebar.multiselect("Select Patients to Download", patient_options)

        if st.sidebar.button("Download Selected Patients"):
            st.session_state["open_download_expander"] = True  # keep the expander open on later reruns
            if not to_download:
                st.warning("No patients selected.")
            else:
                _download_and_offer_zip(results_df, to_download)

# Script entry point. Launch the app with: streamlit run app.py
if __name__ == "__main__":
    main()