| import os |
| from collections import defaultdict, namedtuple |
| from datetime import datetime, timedelta |
| from json import dumps |
| from typing import Any, AnyStr, Dict, List, NamedTuple, Union, Optional |
|
|
| import numpy as np |
| import requests |
| import tensorflow as tf |
| from fastapi import FastAPI |
| from kafka import KafkaProducer |
| from pydantic import BaseModel |
| from scipy.interpolate import interp1d |
|
|
| from model import ModelConfig, UNet |
| from postprocess import extract_picks |
|
|
# Run the TF1-style graph API under TF2: build the graph once, execute via a Session.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# Repository root (one level above this file); used to locate the model checkpoint.
PROJECT_ROOT = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
# Loose typing aliases for JSON-shaped payloads.
JSONObject = Dict[AnyStr, Any]
JSONArray = List[Any]
JSONStructure = Union[JSONArray, JSONObject]


app = FastAPI()
# Model input shape: (3000 samples, 1, 3 channels); waveforms sampled at 100 Hz.
X_SHAPE = [3000, 1, 3]
SAMPLING_RATE = 100
|
|
| |
| model = UNet(mode="pred") |
| sess_config = tf.compat.v1.ConfigProto() |
| sess_config.gpu_options.allow_growth = True |
|
|
| sess = tf.compat.v1.Session(config=sess_config) |
| saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables()) |
| init = tf.compat.v1.global_variables_initializer() |
| sess.run(init) |
| latest_check_point = tf.train.latest_checkpoint(f"{PROJECT_ROOT}/model/190703-214543") |
| print(f"restoring model {latest_check_point}") |
| saver.restore(sess, latest_check_point) |
|
|
| |
# GaMMA phase-association service; used by the *_phasenet2gamma endpoints.
GAMMA_API_URL = "http://gamma-api:8001"


# Flipped to True below once a Kafka broker connection succeeds.
use_kafka = False
|
|
# Best-effort Kafka setup: try the in-cluster broker first, fall back to a
# local broker, and run without Kafka if neither is reachable (use_kafka
# stays False and the endpoints skip the pushes).
try:
    print("Connecting to k8s kafka")
    BROKER_URL = "quakeflow-kafka-headless:9092"
    producer = KafkaProducer(
        bootstrap_servers=[BROKER_URL],
        key_serializer=lambda x: dumps(x).encode("utf-8"),
        value_serializer=lambda x: dumps(x).encode("utf-8"),
    )
    use_kafka = True
    print("k8s kafka connection success!")
except Exception:  # was BaseException: don't swallow KeyboardInterrupt/SystemExit at import time
    print("k8s Kafka connection error")
    try:
        print("Connecting to local kafka")
        producer = KafkaProducer(
            bootstrap_servers=["localhost:9092"],
            key_serializer=lambda x: dumps(x).encode("utf-8"),
            value_serializer=lambda x: dumps(x).encode("utf-8"),
        )
        use_kafka = True
        print("local kafka connection success!")
    except Exception:
        print("local Kafka connection error")
print(f"Kafka status: {use_kafka}")
|
|
|
|
def normalize_batch(data, window=3000):
    """Standardize waveforms with moving-window statistics.

    data: array of shape (nsta, nt, nch).

    Mean/std are computed on half-window-overlapping windows (reflect-padded
    at the edges), linearly interpolated to every sample, and used to
    standardize the data. Flat windows (std == 0) are left unscaled.
    Returns the normalized array; the input is not modified in place.
    """
    half = window // 2
    nsta, nt, nch = data.shape

    # Reflect-pad so every window near the edges is fully populated.
    padded = np.pad(data, ((0, 0), (half, half), (0, 0)), mode="reflect")
    anchors = np.arange(0, nt, half, dtype="int")
    std = np.zeros([nsta, len(anchors) + 1, nch])
    mean = np.zeros([nsta, len(anchors) + 1, nch])
    for i in range(1, len(anchors)):
        segment = padded[:, i * half : i * half + window, :]
        std[:, i, :] = np.std(segment, axis=1)
        mean[:, i, :] = np.mean(segment, axis=1)

    anchors = np.append(anchors, nt)
    # Extend edge statistics outward to the first/last anchors, then guard
    # against division by zero on constant segments.
    std[:, -1, :], mean[:, -1, :] = std[:, -2, :], mean[:, -2, :]
    std[:, 0, :], mean[:, 0, :] = std[:, 1, :], mean[:, 1, :]
    std[std == 0] = 1

    # Interpolate the anchor statistics to every time sample and standardize.
    samples = np.arange(nt, dtype="int")
    std_full = interp1d(anchors, std, axis=1, kind="slinear")(samples)
    mean_full = interp1d(anchors, mean, axis=1, kind="slinear")(samples)
    return (data - mean_full) / std_full
|
|
|
|
def preprocess(data):
    """Normalize waveforms and reshape them to the model's input layout.

    Returns (normalized, raw) where `raw` is an unnormalized copy; both are
    expanded from (nsta, nt, nch) to (nsta, nt, 1, nch) to match X_SHAPE.
    """
    raw = data.copy()
    normalized = normalize_batch(data)
    if normalized.ndim == 3:
        normalized = normalized[:, :, np.newaxis, :]
        raw = raw[:, :, np.newaxis, :]
    return normalized, raw
|
|
|
|
def calc_timestamp(timestamp, sec):
    """Add `sec` seconds to a timestamp string.

    `timestamp` must match "%Y-%m-%dT%H:%M:%S.%f"; the result uses the same
    format, truncated to millisecond precision.
    """
    parsed = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f")
    shifted = parsed + timedelta(seconds=sec)
    return shifted.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3]
|
|
|
|
def format_picks(picks, dt, amplitudes):
    """Flatten per-station pick/amplitude results into JSON-able dicts.

    Each dict carries the station id (`fname`), the absolute pick time
    (t0 offset by sample index * dt), probability, amplitude, and phase type
    ("p" or "s"). P picks are emitted before S picks for each station,
    matching the input ordering.
    """
    formatted = []
    for pick, amplitude in zip(picks, amplitudes):
        phase_groups = (
            ("p", pick.p_idx, pick.p_prob, amplitude.p_amp),
            ("s", pick.s_idx, pick.s_prob, amplitude.s_amp),
        )
        for phase, idx_lists, prob_lists, amp_lists in phase_groups:
            for idxs, probs, amps in zip(idx_lists, prob_lists, amp_lists):
                for idx, prob, amp in zip(idxs, probs, amps):
                    formatted.append(
                        {
                            "id": pick.fname,
                            "timestamp": calc_timestamp(pick.t0, float(idx) * dt),
                            "prob": prob,
                            "amp": amp,
                            "type": phase,
                        }
                    )
    return formatted
|
|
|
|
def format_data(data):
    """Group streamed single-channel traces into 3-component station arrays.

    `data` carries parallel lists: `id` (channel ids whose last character is
    the component code), `timestamp` (trace start times formatted
    "%Y-%m-%dT%H:%M:%S.%f"), and `vec` (the samples). Traces sharing the same
    id-minus-component are aligned to the earliest start time, demeaned, and
    packed into a (X_SHAPE[0], 3) array. Returns a namedtuple-like object with
    grouped ids, per-group start timestamps, stacked vectors, and dt.
    """
    # Component code -> column index: E/N/Z naming or the numeric 3/2/1 convention.
    chn2idx = {"E": 0, "N": 1, "Z": 2, "3": 0, "2": 1, "1": 2}
    Data = NamedTuple("data", [("id", list), ("timestamp", list), ("vec", list), ("dt", float)])

    # Group traces by station key (channel id without its component suffix).
    # Start times are converted to sample counts for integer alignment.
    chn_ = defaultdict(list)
    t0_ = defaultdict(list)
    vv_ = defaultdict(list)
    for i in range(len(data.id)):
        key = data.id[i][:-1]
        chn_[key].append(data.id[i][-1])
        t0_[key].append(datetime.strptime(data.timestamp[i], "%Y-%m-%dT%H:%M:%S.%f").timestamp() * SAMPLING_RATE)
        vv_[key].append(np.array(data.vec[i]))

    # Assemble one fixed-size (X_SHAPE[0], 3) array per station.
    id_ = []
    timestamp_ = []
    vec_ = []
    for k in chn_:
        id_.append(k)
        min_t0 = min(t0_[k])
        timestamp_.append(datetime.fromtimestamp(min_t0 / SAMPLING_RATE).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3])
        vec = np.zeros([X_SHAPE[0], X_SHAPE[-1]])
        for i in range(len(chn_[k])):
            # Offset each component by its start-time lag (in samples) relative
            # to the earliest trace in the group, demean it, and clip to the
            # model window. Assumes the lag is below X_SHAPE[0] — TODO confirm.
            shift = int(t0_[k][i] - min_t0)
            vec[shift : len(vv_[k][i]) + shift, chn2idx[chn_[k][i]]] = vv_[k][i][: X_SHAPE[0] - shift] - np.mean(
                vv_[k][i][: X_SHAPE[0] - shift]
            )
        vec_.append(vec.tolist())

    return Data(id=id_, timestamp=timestamp_, vec=vec_, dt=1 / SAMPLING_RATE)
| |
|
|
|
|
def get_prediction(data, return_preds=False):
    """Run PhaseNet on `data.vec` and return formatted picks.

    When `return_preds` is True, the raw prediction array is returned as well.
    """
    vec, vec_raw = preprocess(np.array(data.vec))

    # Single forward pass through the restored TF1-style graph.
    feed = {model.X: vec, model.drop_rate: 0, model.is_training: False}
    preds = sess.run(model.preds, feed_dict=feed)

    picks = extract_picks(preds, station_ids=data.id, begin_times=data.timestamp, waveforms=vec_raw)

    # Keep only the JSON-friendly fields expected by downstream consumers.
    keep = ("station_id", "phase_time", "phase_score", "phase_type", "dt")
    picks = [{k: v for k, v in pick.items() if k in keep} for pick in picks]

    return (picks, preds) if return_preds else picks
|
|
|
|
class Data(BaseModel):
    """Request payload for the prediction endpoints."""

    # Channel/station identifiers, one per trace.
    id: List[str]
    # Trace start times, formatted "%Y-%m-%dT%H:%M:%S.%f".
    timestamp: List[str]
    # Waveform samples: (nsta, nt, nch) or (nt, nch) nested lists.
    vec: Union[List[List[List[float]]], List[List[float]]]
    # Sample spacing in seconds (default matches SAMPLING_RATE = 100 Hz).
    dt: Optional[float] = 0.01

    # Optional station metadata and association config, forwarded verbatim
    # to the GaMMA service by the *_phasenet2gamma endpoints.
    stations: Optional[List[Dict[str, Union[float, str]]]] = None
    config: Optional[Dict[str, Union[List[float], List[int], List[str], float, int, str]]] = None
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
@app.post("/predict")
def predict(data: Data):
    """Run PhaseNet picking on the posted waveforms and return the picks."""
    return get_prediction(data)
|
|
|
|
@app.post("/predict_prob")
def predict(data: Data):
    """Return picks together with the raw per-sample prediction array."""
    picks, preds = get_prediction(data, True)
    return picks, preds.tolist()
|
|
|
|
@app.post("/predict_phasenet2gamma")
def predict(data: Data):
    """Run PhaseNet, then forward picks to the GaMMA service for association.

    Returns the GaMMA response on success, {} on any failure (best effort).
    """
    # NOTE(review): every handler in this file is named `predict`; FastAPI
    # binds routes at decoration time so this works, but distinct names would
    # be clearer.
    picks = get_prediction(data)

    try:
        catalog = requests.post(
            f"{GAMMA_API_URL}/predict",
            json={"picks": picks, "stations": data.stations, "config": data.config},
        )
        result = catalog.json()  # parse the response body once instead of twice
        print(result["catalog"])
        return result
    except Exception as error:
        print(error)

    return {}
|
|
@app.post("/predict_phasenet2gamma2ui")
def predict(data: Data):
    """Run PhaseNet + GaMMA and mirror picks/waveforms to Kafka for the UI.

    Returns the GaMMA response on success, {} otherwise; Kafka pushes are
    best-effort and happen regardless of the GaMMA outcome.
    """
    picks = get_prediction(data)

    return_value = {}
    try:
        catalog = requests.post(
            f"{GAMMA_API_URL}/predict",
            json={"picks": picks, "stations": data.stations, "config": data.config},
        )
        result = catalog.json()  # parse the response body once instead of twice
        print(result["catalog"])
        return_value = result
    except Exception as error:
        print(error)

    # Fix: the old code returned inside the `try` on success, so the Kafka
    # pushes below only ran when the GaMMA call failed. Push first, then
    # return, matching /predict_stream_phasenet2gamma.
    if use_kafka:
        print("Push picks to kafka...")
        for pick in picks:
            # get_prediction filters picks to station_id/phase_* keys, so the
            # old lookup pick["id"] raised KeyError.
            producer.send("phasenet_picks", key=pick["station_id"], value=pick)
        print("Push waveform to kafka...")
        for id, timestamp, vec in zip(data.id, data.timestamp, data.vec):
            producer.send("waveform_phasenet", key=id, value={"timestamp": timestamp, "vec": vec, "dt": data.dt})

    return return_value
|
|
|
|
@app.post("/predict_stream_phasenet2gamma")
def predict(data: Data):
    """Group streamed single-channel traces, pick with PhaseNet, associate
    with GaMMA, and mirror picks/waveforms to Kafka.

    Returns the GaMMA response on success, {} otherwise.
    """
    # Merge per-channel traces into aligned 3-component station arrays.
    data = format_data(data)

    picks = get_prediction(data)

    return_value = {}
    try:
        catalog = requests.post(f"{GAMMA_API_URL}/predict_stream", json={"picks": picks})
        result = catalog.json()  # parse the response body once instead of twice
        print("GMMA:", result["catalog"])
        return_value = result
    except Exception as error:
        print(error)

    if use_kafka:
        print("Push picks to kafka...")
        for pick in picks:
            # get_prediction filters picks to station_id/phase_* keys, so the
            # old lookup pick["id"] raised KeyError.
            producer.send("phasenet_picks", key=pick["station_id"], value=pick)
        print("Push waveform to kafka...")
        for id, timestamp, vec in zip(data.id, data.timestamp, data.vec):
            producer.send("waveform_phasenet", key=id, value={"timestamp": timestamp, "vec": vec, "dt": data.dt})

    return return_value
|
|
|
|
@app.get("/healthz")
def healthz():
    """Liveness/readiness probe endpoint."""
    return dict(status="ok")
|
|