From 800541a1410f5ba2e119da42f7f557f47f279e0b Mon Sep 17 00:00:00 2001 From: gael Date: Thu, 25 Sep 2025 20:53:22 -0400 Subject: [PATCH] Update gain_viz/app.py --- gain_viz/app.py | 351 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 267 insertions(+), 84 deletions(-) diff --git a/gain_viz/app.py b/gain_viz/app.py index 807b919..68bf12a 100644 --- a/gain_viz/app.py +++ b/gain_viz/app.py @@ -7,26 +7,36 @@ import os import threading import time import serial +import json app = Flask(__name__) - -# Path to save the plot image PLOT_PATH = os.path.join(os.getcwd(), "plot.png") +# ----------------- Shared Config ----------------- +config = { + "usrp_tx_gain": 60, + "usrp_rx_gain": 30, + "scm_tx_gain": 30, + "scm_rx_gain": 30, + "sample_rate": 23.04e6, + "window_ms": 20, + "center_freq": 3.415e9, + "NFFT": 1024, + "tcp_port": 5556, + "streaming": False, # Added streaming state +} +config_lock = threading.Lock() +# Global variables +usrp_tx_gain = config["usrp_tx_gain"] +usrp_rx_gain = config["usrp_rx_gain"] +scm_tx_gain = config["scm_tx_gain"] +scm_rx_gain = config["scm_rx_gain"] -# Global variables for gain values -usrp_tx_gain = 60 -usrp_rx_gain = 30 -scm_tx_gain = 30 -scm_rx_gain = 30 - -# Global variables for plot settings -sample_rate = 23.04e6 # Hz -window_ms = 20 -center_freq = 3.415e9 -NFFT = 1024 -tcp_port = 5556 +# Plotting thread control +plot_thread = None +stop_event = threading.Event() +pause_event = threading.Event() # ----------------- Serial / SCM ----------------- def connect_serial(port, baudrate=115200, timeout=1): @@ -46,11 +56,11 @@ def connect_serial(port, baudrate=115200, timeout=1): return None def send_command(ser, command): - if ser.is_open: + if ser and ser.is_open: ser.write(command.encode('utf-8')) def receive_feedback(ser): - if ser.is_open: + if ser and ser.is_open: try: ser.flush() raw_response = ser.readlines() @@ -75,6 +85,7 @@ def scm_conf(port, baudrate, rx_cmd, tx_cmd): send_command(ser, cmd + "\r") feedback = receive_feedback(ser) attempt += 1 + ser.close() return True return False @@ -105,48 +116,81 @@ def gain_update(usrp_tx, usrp_rx, scm_tx, scm_rx): if scm_change: scm_conf("/dev/ttyUSB0", 115200, r_cmd, t_cmd) scm_conf("/dev/ttyUSB1", 115200, r_cmd, t_cmd) + + with config_lock: + config["scm_tx_gain"] = scm_tx_gain + config["scm_rx_gain"] = scm_rx_gain + + with config_lock: + config["usrp_tx_gain"] = usrp_tx_gain + config["usrp_rx_gain"] = usrp_rx_gain return True -# ----------------- ZMQ Subscriber ----------------- -def zmq_subscriber(host, port): - context = zmq.Context() - socket = context.socket(zmq.SUB) - socket.setsockopt(zmq.CONFLATE, 1) - socket.setsockopt_string(zmq.SUBSCRIBE, "") - socket.setsockopt(zmq.RCVTIMEO, 1000) - socket.connect(f"tcp://{host}:{port}") - return socket - # ----------------- Plot Generation ----------------- def generate_spectrum_plot(): - socket = zmq_subscriber("localhost", tcp_port) - global sample_rate, window_ms, center_freq, NFFT - window_samples = int(sample_rate * window_ms / 1000) - noverlap = 512 - cmap = plt.get_cmap('twilight') + socket = None + iq_sample = np.zeros(1, dtype=np.complex64) + last_port = None - # Initial placeholder for first plot (zeros) - iq_sample = np.zeros(window_samples, dtype=np.complex64) + while not stop_event.is_set(): + # Check if we're paused + if pause_event.is_set(): + time.sleep(0.1) + continue + + with config_lock: + sample_rate = config["sample_rate"] + window_ms = config["window_ms"] + center_freq = config["center_freq"] + NFFT = config["NFFT"] + tcp_port = config["tcp_port"] + streaming = config["streaming"] + + # Only process if streaming is active + if not streaming: + time.sleep(0.1) + continue + + # Reconnect if port changed or socket is None + if socket is None or tcp_port != last_port: + if socket: + socket.close() + try: + context = zmq.Context() + socket = context.socket(zmq.SUB) + socket.setsockopt(zmq.CONFLATE, 1) + socket.setsockopt_string(zmq.SUBSCRIBE, "") + socket.setsockopt(zmq.RCVTIMEO, 1000) + socket.connect(f"tcp://localhost:{tcp_port}") + last_port = tcp_port + print(f"Connected to ZMQ on port {tcp_port}") + except Exception as e: + print(f"ZMQ connection error: {e}") + socket = None + time.sleep(1) + continue + + window_samples = int(sample_rate * window_ms / 1000) + if iq_sample.size != window_samples: + iq_sample = np.zeros(window_samples, dtype=np.complex64) - while True: try: - # Try to read ZMQ message msg = socket.recv(zmq.NOBLOCK) float_data = np.frombuffer(msg, dtype=np.float32) if float_data.size >= 2: complex_data = float_data.reshape(-1, 2) iq_all = complex_data[:, 0] + 1j * complex_data[:, 1] - iq_sample = ( - iq_all[-window_samples:] - if len(iq_all) >= window_samples - else np.pad(iq_all, (window_samples - len(iq_all), 0)) - ) - # --- Create plot --- + if len(iq_all) >= window_samples: + iq_sample = iq_all[-window_samples:] + else: + iq_sample = np.pad(iq_all, (window_samples - len(iq_all), 0)) + + # Create plot fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 6)) fig.subplots_adjust(hspace=0.4) - # Time-domain plot (ms) + # Time-domain plot times_ms = np.arange(len(iq_sample)) * 1000 / sample_rate ax1.plot(times_ms, np.real(iq_sample), label="Real", color='b') ax1.plot(times_ms, np.imag(iq_sample), label="Imag", color='r') @@ -156,40 +200,105 @@ def generate_spectrum_plot(): ax1.grid(True, which='both', linestyle='--', linewidth=0.5) ax1.legend() - # Spectrogram without grid + # Spectrogram + cmap = plt.get_cmap('twilight') ax2.specgram( iq_sample, Fs=sample_rate, Fc=center_freq, NFFT=NFFT, - noverlap=noverlap, + noverlap=512, cmap=cmap ) ax2.set_xlabel("Time (ms)") ax2.set_ylabel("Frequency (Hz)") ax2.grid(False) ax2.set_ylim(center_freq - sample_rate / 2, - center_freq + sample_rate / 2) - ax2.xaxis.set_major_formatter(ticker.FuncFormatter(lambda t, pos: '{0:g}'.format(t*1e3))) # kHz format - ax2.set_xlabel("Time (ms)") - ax2.set_ylabel("Frequency (Hz)") + center_freq + sample_rate / 2) + ax2.xaxis.set_major_formatter( + ticker.FuncFormatter(lambda t, pos: '{0:g}'.format(t*1e3)) + ) ax2.xaxis.set_minor_locator(ticker.AutoMinorLocator()) - - - # Save the plot - os.makedirs(os.path.dirname(PLOT_PATH), exist_ok=True) plt.savefig(PLOT_PATH, bbox_inches='tight') plt.close(fig) - except zmq.Again: - pass # No new data, keep last iq_sample + # No new data + fig, ax = plt.subplots(figsize=(12, 6)) + ax.text(0.5, 0.5, "Waiting for data...", + ha='center', va='center', transform=ax.transAxes, fontsize=16) + ax.set_title("Spectrum Analyzer - No Data (Streaming Active)") + plt.savefig(PLOT_PATH, bbox_inches='tight') + plt.close(fig) + except Exception as e: + print(f"Plot generation error: {e}") + fig, ax = plt.subplots(figsize=(12, 6)) + ax.text(0.5, 0.5, f"Error: {str(e)}", + ha='center', va='center', transform=ax.transAxes, fontsize=12) + ax.set_title("Spectrum Analyzer - Error") + plt.savefig(PLOT_PATH, bbox_inches='tight') + plt.close(fig) - - # Fast refresh (20ms = 50 fps) - time.sleep(0.5) + time.sleep(0.1) + # Cleanup when stopping + if socket: + socket.close() + print("Plotting thread stopped") + +def start_plotting(): + """Start the plotting thread""" + global plot_thread, stop_event, pause_event + + stop_event.clear() + pause_event.clear() + + with config_lock: + config["streaming"] = True + + if plot_thread is None or not plot_thread.is_alive(): + plot_thread = threading.Thread(target=generate_spectrum_plot, daemon=True) + plot_thread.start() + print("Plotting thread started") + + return True + +def stop_plotting(): + """Stop the plotting thread""" + global plot_thread, stop_event + + with config_lock: + config["streaming"] = False + + stop_event.set() + + if plot_thread and plot_thread.is_alive(): + plot_thread.join(timeout=2.0) + + # Create stopped message plot + fig, ax = plt.subplots(figsize=(12, 6)) + ax.text(0.5, 0.5, "Streaming Stopped\nClick Start to begin", + ha='center', va='center', transform=ax.transAxes, fontsize=16) + ax.set_title("Spectrum Analyzer - Stopped") + plt.savefig(PLOT_PATH, bbox_inches='tight') + plt.close(fig) + + print("Plotting thread stopped") + return True + +def pause_plotting(): + """Pause the plotting updates""" + global pause_event + + if pause_event.is_set(): + pause_event.clear() + print("Plotting resumed") + return "resumed" + else: + pause_event.set() + print("Plotting paused") + return "paused" # ----------------- Flask Routes ----------------- @app.route('/') @@ -199,17 +308,37 @@ def index(): @app.route('/update_gains', methods=['POST']) def update_gains(): global usrp_tx_gain, usrp_rx_gain, scm_tx_gain, scm_rx_gain - usrp_tx = request.form.get('usrp_tx_gain', usrp_tx_gain, type=float) - usrp_rx = request.form.get('usrp_rx_gain', usrp_rx_gain, type=float) - scm_tx = request.form.get('scm_tx_gain', scm_tx_gain, type=float) - scm_rx = request.form.get('scm_rx_gain', scm_rx_gain, type=float) + + try: + usrp_tx = request.form.get('usrp_tx_gain', type=float) + usrp_rx = request.form.get('usrp_rx_gain', type=float) + scm_tx = request.form.get('scm_tx_gain', type=float) + scm_rx = request.form.get('scm_rx_gain', type=float) + + if usrp_tx is None: + usrp_tx = usrp_tx_gain + if usrp_rx is None: + usrp_rx = usrp_rx_gain + if scm_tx is None: + scm_tx = scm_tx_gain + if scm_rx is None: + scm_rx = scm_rx_gain - gain_update(usrp_tx, usrp_rx, scm_tx, scm_rx) - return jsonify({"status": "success", "message": "Gains updated successfully"}) + success = gain_update(usrp_tx, usrp_rx, scm_tx, scm_rx) + if success: + return jsonify({"status": "success", "message": "Gains updated successfully"}) + else: + return jsonify({"status": "error", "message": "Failed to update gains"}), 500 + + except Exception as e: + return jsonify({"status": "error", "message": f"Error: {str(e)}"}), 500 @app.route('/plot') def plot(): - return send_file(PLOT_PATH, mimetype='image/png') + try: + return send_file(PLOT_PATH, mimetype='image/png') + except Exception as e: + return send_file(PLOT_PATH, mimetype='image/png') @app.route('/get_gains') def get_gains(): @@ -222,19 +351,28 @@ def get_gains(): @app.route('/update_params', methods=['POST']) def update_params(): - global sample_rate, window_ms, center_freq, NFFT, tcp_port try: - # Get parameters from form data center_freq = request.form.get('center_freq', type=float) sample_rate = request.form.get('sample_rate', type=float) NFFT = request.form.get('fft_size', type=int) window_ms = request.form.get('window_ms', type=float) tcp_port = request.form.get('tcp_port', type=int) + if not all([center_freq, sample_rate, NFFT, window_ms, tcp_port]): + return jsonify({ + 'status': 'error', + 'message': 'All parameters are required' + }), 400 - + with config_lock: + config["center_freq"] = center_freq + config["sample_rate"] = sample_rate + config["NFFT"] = NFFT + config["window_ms"] = window_ms + config["tcp_port"] = tcp_port + + print(f"Updated params: center_freq={center_freq}, sample_rate={sample_rate}, NFFT={NFFT}, window_ms={window_ms}, tcp_port={tcp_port}") - # Save to config file if needed save_config() return jsonify({ @@ -242,33 +380,78 @@ def update_params(): 'message': 'Parameters updated successfully' }) except Exception as e: + print(f"Error updating params: {e}") return jsonify({ 'status': 'error', 'message': str(e) }), 500 -# Add to your config handling +@app.route('/start_stream', methods=['POST']) +def start_stream(): + try: + success = start_plotting() + if success: + return jsonify({"status": "success", "message": "Streaming started"}) + else: + return jsonify({"status": "error", "message": "Failed to start streaming"}), 500 + except Exception as e: + return jsonify({"status": "error", "message": f"Error: {str(e)}"}), 500 + +@app.route('/stop_stream', methods=['POST']) +def stop_stream(): + try: + success = stop_plotting() + if success: + return jsonify({"status": "success", "message": "Streaming stopped"}) + else: + return jsonify({"status": "error", "message": "Failed to stop streaming"}), 500 + except Exception as e: + return jsonify({"status": "error", "message": f"Error: {str(e)}"}), 500 + +@app.route('/pause_stream', methods=['POST']) +def pause_stream(): + try: + result = pause_plotting() + return jsonify({"status": "success", "message": f"Streaming {result}", "state": result}) + except Exception as e: + return jsonify({"status": "error", "message": f"Error: {str(e)}"}), 500 + +@app.route('/get_stream_state', methods=['GET']) +def get_stream_state(): + with config_lock: + streaming = config["streaming"] + paused = pause_event.is_set() + + state = "stopped" + if streaming and not paused: + state = "running" + elif streaming and paused: + state = "paused" + + return jsonify({"state": state}) + def save_config(): - config = { - 'center_freq': center_freq, - 'sample_rate': sample_rate, - 'fft_size': NFFT, - 'window_ms': window_ms, - 'tcp_port' : tcp_port - } - with open('/opt/gain-viz/config.json', 'w') as f: - json.dump(config, f) + with config_lock: + cfg = dict(config) + + try: + with open(os.path.join(os.getcwd(), "gain_viz.json"), 'w') as f: + json.dump(cfg, f, indent=2) + except Exception as e: + print(f"Error saving config: {e}") # ----------------- Main ----------------- def main(): # Ensure placeholder image exists if not os.path.exists(PLOT_PATH): - plt.figure() - plt.text(0.5, 0.5, "Waiting for data...", ha='center', va='center') + fig, ax = plt.subplots(figsize=(12, 6)) + ax.text(0.5, 0.5, "Click Start to begin streaming", ha='center', va='center', fontsize=16) + ax.set_title("Gain-Viz Spectrum Analyzer - Ready") plt.savefig(PLOT_PATH) - plt.close() + plt.close(fig) - # Start plotting thread - threading.Thread(target=generate_spectrum_plot, daemon=True).start() + print("Gain-Viz server started. Use the web interface to control streaming.") + app.run(host="0.0.0.0", port=5000, debug=True, use_reloader=False) - app.run(host="0.0.0.0", debug=True) +if __name__ == '__main__': + main() \ No newline at end of file