Update gain_viz/app.py

This commit is contained in:
G gael 2025-09-25 20:53:22 -04:00
parent d5398a9c86
commit 800541a141

View File

@ -7,26 +7,36 @@ import os
import threading import threading
import time import time
import serial import serial
import json
app = Flask(__name__) app = Flask(__name__)
# Path to save the plot image
PLOT_PATH = os.path.join(os.getcwd(), "plot.png") PLOT_PATH = os.path.join(os.getcwd(), "plot.png")
# ----------------- Shared Config -----------------
config = {
"usrp_tx_gain": 60,
"usrp_rx_gain": 30,
"scm_tx_gain": 30,
"scm_rx_gain": 30,
"sample_rate": 23.04e6,
"window_ms": 20,
"center_freq": 3.415e9,
"NFFT": 1024,
"tcp_port": 5556,
"streaming": False, # Added streaming state
}
config_lock = threading.Lock()
# Global variables
usrp_tx_gain = config["usrp_tx_gain"]
usrp_rx_gain = config["usrp_rx_gain"]
scm_tx_gain = config["scm_tx_gain"]
scm_rx_gain = config["scm_rx_gain"]
# Global variables for gain values # Plotting thread control
usrp_tx_gain = 60 plot_thread = None
usrp_rx_gain = 30 stop_event = threading.Event()
scm_tx_gain = 30 pause_event = threading.Event()
scm_rx_gain = 30
# Global variables for plot settings
sample_rate = 23.04e6 # Hz
window_ms = 20
center_freq = 3.415e9
NFFT = 1024
tcp_port = 5556
# ----------------- Serial / SCM ----------------- # ----------------- Serial / SCM -----------------
def connect_serial(port, baudrate=115200, timeout=1): def connect_serial(port, baudrate=115200, timeout=1):
@ -46,11 +56,11 @@ def connect_serial(port, baudrate=115200, timeout=1):
return None return None
def send_command(ser, command): def send_command(ser, command):
if ser.is_open: if ser and ser.is_open:
ser.write(command.encode('utf-8')) ser.write(command.encode('utf-8'))
def receive_feedback(ser): def receive_feedback(ser):
if ser.is_open: if ser and ser.is_open:
try: try:
ser.flush() ser.flush()
raw_response = ser.readlines() raw_response = ser.readlines()
@ -75,6 +85,7 @@ def scm_conf(port, baudrate, rx_cmd, tx_cmd):
send_command(ser, cmd + "\r") send_command(ser, cmd + "\r")
feedback = receive_feedback(ser) feedback = receive_feedback(ser)
attempt += 1 attempt += 1
ser.close()
return True return True
return False return False
@ -105,48 +116,81 @@ def gain_update(usrp_tx, usrp_rx, scm_tx, scm_rx):
if scm_change: if scm_change:
scm_conf("/dev/ttyUSB0", 115200, r_cmd, t_cmd) scm_conf("/dev/ttyUSB0", 115200, r_cmd, t_cmd)
scm_conf("/dev/ttyUSB1", 115200, r_cmd, t_cmd) scm_conf("/dev/ttyUSB1", 115200, r_cmd, t_cmd)
with config_lock:
config["scm_tx_gain"] = scm_tx_gain
config["scm_rx_gain"] = scm_rx_gain
with config_lock:
config["usrp_tx_gain"] = usrp_tx_gain
config["usrp_rx_gain"] = usrp_rx_gain
return True return True
# ----------------- ZMQ Subscriber -----------------
def zmq_subscriber(host, port):
context = zmq.Context()
socket = context.socket(zmq.SUB)
socket.setsockopt(zmq.CONFLATE, 1)
socket.setsockopt_string(zmq.SUBSCRIBE, "")
socket.setsockopt(zmq.RCVTIMEO, 1000)
socket.connect(f"tcp://{host}:{port}")
return socket
# ----------------- Plot Generation ----------------- # ----------------- Plot Generation -----------------
def generate_spectrum_plot(): def generate_spectrum_plot():
socket = zmq_subscriber("localhost", tcp_port) socket = None
global sample_rate, window_ms, center_freq, NFFT iq_sample = np.zeros(1, dtype=np.complex64)
window_samples = int(sample_rate * window_ms / 1000) last_port = None
noverlap = 512
cmap = plt.get_cmap('twilight')
# Initial placeholder for first plot (zeros) while not stop_event.is_set():
iq_sample = np.zeros(window_samples, dtype=np.complex64) # Check if we're paused
if pause_event.is_set():
time.sleep(0.1)
continue
with config_lock:
sample_rate = config["sample_rate"]
window_ms = config["window_ms"]
center_freq = config["center_freq"]
NFFT = config["NFFT"]
tcp_port = config["tcp_port"]
streaming = config["streaming"]
# Only process if streaming is active
if not streaming:
time.sleep(0.1)
continue
# Reconnect if port changed or socket is None
if socket is None or tcp_port != last_port:
if socket:
socket.close()
try:
context = zmq.Context()
socket = context.socket(zmq.SUB)
socket.setsockopt(zmq.CONFLATE, 1)
socket.setsockopt_string(zmq.SUBSCRIBE, "")
socket.setsockopt(zmq.RCVTIMEO, 1000)
socket.connect(f"tcp://localhost:{tcp_port}")
last_port = tcp_port
print(f"Connected to ZMQ on port {tcp_port}")
except Exception as e:
print(f"ZMQ connection error: {e}")
socket = None
time.sleep(1)
continue
window_samples = int(sample_rate * window_ms / 1000)
if iq_sample.size != window_samples:
iq_sample = np.zeros(window_samples, dtype=np.complex64)
while True:
try: try:
# Try to read ZMQ message
msg = socket.recv(zmq.NOBLOCK) msg = socket.recv(zmq.NOBLOCK)
float_data = np.frombuffer(msg, dtype=np.float32) float_data = np.frombuffer(msg, dtype=np.float32)
if float_data.size >= 2: if float_data.size >= 2:
complex_data = float_data.reshape(-1, 2) complex_data = float_data.reshape(-1, 2)
iq_all = complex_data[:, 0] + 1j * complex_data[:, 1] iq_all = complex_data[:, 0] + 1j * complex_data[:, 1]
iq_sample = ( if len(iq_all) >= window_samples:
iq_all[-window_samples:] iq_sample = iq_all[-window_samples:]
if len(iq_all) >= window_samples else:
else np.pad(iq_all, (window_samples - len(iq_all), 0)) iq_sample = np.pad(iq_all, (window_samples - len(iq_all), 0))
)
# --- Create plot --- # Create plot
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 6)) fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 6))
fig.subplots_adjust(hspace=0.4) fig.subplots_adjust(hspace=0.4)
# Time-domain plot (ms) # Time-domain plot
times_ms = np.arange(len(iq_sample)) * 1000 / sample_rate times_ms = np.arange(len(iq_sample)) * 1000 / sample_rate
ax1.plot(times_ms, np.real(iq_sample), label="Real", color='b') ax1.plot(times_ms, np.real(iq_sample), label="Real", color='b')
ax1.plot(times_ms, np.imag(iq_sample), label="Imag", color='r') ax1.plot(times_ms, np.imag(iq_sample), label="Imag", color='r')
@ -156,40 +200,105 @@ def generate_spectrum_plot():
ax1.grid(True, which='both', linestyle='--', linewidth=0.5) ax1.grid(True, which='both', linestyle='--', linewidth=0.5)
ax1.legend() ax1.legend()
# Spectrogram without grid # Spectrogram
cmap = plt.get_cmap('twilight')
ax2.specgram( ax2.specgram(
iq_sample, iq_sample,
Fs=sample_rate, Fs=sample_rate,
Fc=center_freq, Fc=center_freq,
NFFT=NFFT, NFFT=NFFT,
noverlap=noverlap, noverlap=512,
cmap=cmap cmap=cmap
) )
ax2.set_xlabel("Time (ms)") ax2.set_xlabel("Time (ms)")
ax2.set_ylabel("Frequency (Hz)") ax2.set_ylabel("Frequency (Hz)")
ax2.grid(False) ax2.grid(False)
ax2.set_ylim(center_freq - sample_rate / 2, ax2.set_ylim(center_freq - sample_rate / 2,
center_freq + sample_rate / 2) center_freq + sample_rate / 2)
ax2.xaxis.set_major_formatter(ticker.FuncFormatter(lambda t, pos: '{0:g}'.format(t*1e3))) # kHz format ax2.xaxis.set_major_formatter(
ax2.set_xlabel("Time (ms)") ticker.FuncFormatter(lambda t, pos: '{0:g}'.format(t*1e3))
ax2.set_ylabel("Frequency (Hz)") )
ax2.xaxis.set_minor_locator(ticker.AutoMinorLocator()) ax2.xaxis.set_minor_locator(ticker.AutoMinorLocator())
# Save the plot
os.makedirs(os.path.dirname(PLOT_PATH), exist_ok=True)
plt.savefig(PLOT_PATH, bbox_inches='tight') plt.savefig(PLOT_PATH, bbox_inches='tight')
plt.close(fig) plt.close(fig)
except zmq.Again: except zmq.Again:
pass # No new data, keep last iq_sample # No new data
fig, ax = plt.subplots(figsize=(12, 6))
ax.text(0.5, 0.5, "Waiting for data...",
ha='center', va='center', transform=ax.transAxes, fontsize=16)
ax.set_title("Spectrum Analyzer - No Data (Streaming Active)")
plt.savefig(PLOT_PATH, bbox_inches='tight')
plt.close(fig)
except Exception as e:
print(f"Plot generation error: {e}")
fig, ax = plt.subplots(figsize=(12, 6))
ax.text(0.5, 0.5, f"Error: {str(e)}",
ha='center', va='center', transform=ax.transAxes, fontsize=12)
ax.set_title("Spectrum Analyzer - Error")
plt.savefig(PLOT_PATH, bbox_inches='tight')
plt.close(fig)
time.sleep(0.1)
# Fast refresh (20ms = 50 fps)
time.sleep(0.5)
# Cleanup when stopping
if socket:
socket.close()
print("Plotting thread stopped")
def start_plotting():
"""Start the plotting thread"""
global plot_thread, stop_event, pause_event
stop_event.clear()
pause_event.clear()
with config_lock:
config["streaming"] = True
if plot_thread is None or not plot_thread.is_alive():
plot_thread = threading.Thread(target=generate_spectrum_plot, daemon=True)
plot_thread.start()
print("Plotting thread started")
return True
def stop_plotting():
"""Stop the plotting thread"""
global plot_thread, stop_event
with config_lock:
config["streaming"] = False
stop_event.set()
if plot_thread and plot_thread.is_alive():
plot_thread.join(timeout=2.0)
# Create stopped message plot
fig, ax = plt.subplots(figsize=(12, 6))
ax.text(0.5, 0.5, "Streaming Stopped\nClick Start to begin",
ha='center', va='center', transform=ax.transAxes, fontsize=16)
ax.set_title("Spectrum Analyzer - Stopped")
plt.savefig(PLOT_PATH, bbox_inches='tight')
plt.close(fig)
print("Plotting thread stopped")
return True
def pause_plotting():
"""Pause the plotting updates"""
global pause_event
if pause_event.is_set():
pause_event.clear()
print("Plotting resumed")
return "resumed"
else:
pause_event.set()
print("Plotting paused")
return "paused"
# ----------------- Flask Routes ----------------- # ----------------- Flask Routes -----------------
@app.route('/') @app.route('/')
@ -199,17 +308,37 @@ def index():
@app.route('/update_gains', methods=['POST']) @app.route('/update_gains', methods=['POST'])
def update_gains(): def update_gains():
global usrp_tx_gain, usrp_rx_gain, scm_tx_gain, scm_rx_gain global usrp_tx_gain, usrp_rx_gain, scm_tx_gain, scm_rx_gain
usrp_tx = request.form.get('usrp_tx_gain', usrp_tx_gain, type=float)
usrp_rx = request.form.get('usrp_rx_gain', usrp_rx_gain, type=float) try:
scm_tx = request.form.get('scm_tx_gain', scm_tx_gain, type=float) usrp_tx = request.form.get('usrp_tx_gain', type=float)
scm_rx = request.form.get('scm_rx_gain', scm_rx_gain, type=float) usrp_rx = request.form.get('usrp_rx_gain', type=float)
scm_tx = request.form.get('scm_tx_gain', type=float)
scm_rx = request.form.get('scm_rx_gain', type=float)
if usrp_tx is None:
usrp_tx = usrp_tx_gain
if usrp_rx is None:
usrp_rx = usrp_rx_gain
if scm_tx is None:
scm_tx = scm_tx_gain
if scm_rx is None:
scm_rx = scm_rx_gain
gain_update(usrp_tx, usrp_rx, scm_tx, scm_rx) success = gain_update(usrp_tx, usrp_rx, scm_tx, scm_rx)
return jsonify({"status": "success", "message": "Gains updated successfully"}) if success:
return jsonify({"status": "success", "message": "Gains updated successfully"})
else:
return jsonify({"status": "error", "message": "Failed to update gains"}), 500
except Exception as e:
return jsonify({"status": "error", "message": f"Error: {str(e)}"}), 500
@app.route('/plot') @app.route('/plot')
def plot(): def plot():
return send_file(PLOT_PATH, mimetype='image/png') try:
return send_file(PLOT_PATH, mimetype='image/png')
except Exception as e:
return send_file(PLOT_PATH, mimetype='image/png')
@app.route('/get_gains') @app.route('/get_gains')
def get_gains(): def get_gains():
@ -222,19 +351,28 @@ def get_gains():
@app.route('/update_params', methods=['POST']) @app.route('/update_params', methods=['POST'])
def update_params(): def update_params():
global sample_rate, window_ms, center_freq, NFFT, tcp_port
try: try:
# Get parameters from form data
center_freq = request.form.get('center_freq', type=float) center_freq = request.form.get('center_freq', type=float)
sample_rate = request.form.get('sample_rate', type=float) sample_rate = request.form.get('sample_rate', type=float)
NFFT = request.form.get('fft_size', type=int) NFFT = request.form.get('fft_size', type=int)
window_ms = request.form.get('window_ms', type=float) window_ms = request.form.get('window_ms', type=float)
tcp_port = request.form.get('tcp_port', type=int) tcp_port = request.form.get('tcp_port', type=int)
if not all([center_freq, sample_rate, NFFT, window_ms, tcp_port]):
return jsonify({
'status': 'error',
'message': 'All parameters are required'
}), 400
with config_lock:
config["center_freq"] = center_freq
config["sample_rate"] = sample_rate
config["NFFT"] = NFFT
config["window_ms"] = window_ms
config["tcp_port"] = tcp_port
print(f"Updated params: center_freq={center_freq}, sample_rate={sample_rate}, NFFT={NFFT}, window_ms={window_ms}, tcp_port={tcp_port}")
# Save to config file if needed
save_config() save_config()
return jsonify({ return jsonify({
@ -242,33 +380,78 @@ def update_params():
'message': 'Parameters updated successfully' 'message': 'Parameters updated successfully'
}) })
except Exception as e: except Exception as e:
print(f"Error updating params: {e}")
return jsonify({ return jsonify({
'status': 'error', 'status': 'error',
'message': str(e) 'message': str(e)
}), 500 }), 500
# Add to your config handling @app.route('/start_stream', methods=['POST'])
def start_stream():
try:
success = start_plotting()
if success:
return jsonify({"status": "success", "message": "Streaming started"})
else:
return jsonify({"status": "error", "message": "Failed to start streaming"}), 500
except Exception as e:
return jsonify({"status": "error", "message": f"Error: {str(e)}"}), 500
@app.route('/stop_stream', methods=['POST'])
def stop_stream():
try:
success = stop_plotting()
if success:
return jsonify({"status": "success", "message": "Streaming stopped"})
else:
return jsonify({"status": "error", "message": "Failed to stop streaming"}), 500
except Exception as e:
return jsonify({"status": "error", "message": f"Error: {str(e)}"}), 500
@app.route('/pause_stream', methods=['POST'])
def pause_stream():
try:
result = pause_plotting()
return jsonify({"status": "success", "message": f"Streaming {result}", "state": result})
except Exception as e:
return jsonify({"status": "error", "message": f"Error: {str(e)}"}), 500
@app.route('/get_stream_state', methods=['GET'])
def get_stream_state():
with config_lock:
streaming = config["streaming"]
paused = pause_event.is_set()
state = "stopped"
if streaming and not paused:
state = "running"
elif streaming and paused:
state = "paused"
return jsonify({"state": state})
def save_config(): def save_config():
config = { with config_lock:
'center_freq': center_freq, cfg = dict(config)
'sample_rate': sample_rate,
'fft_size': NFFT, try:
'window_ms': window_ms, with open(os.path.join(os.getcwd(), "gain_viz.json"), 'w') as f:
'tcp_port' : tcp_port json.dump(cfg, f, indent=2)
} except Exception as e:
with open('/opt/gain-viz/config.json', 'w') as f: print(f"Error saving config: {e}")
json.dump(config, f)
# ----------------- Main ----------------- # ----------------- Main -----------------
def main(): def main():
# Ensure placeholder image exists # Ensure placeholder image exists
if not os.path.exists(PLOT_PATH): if not os.path.exists(PLOT_PATH):
plt.figure() fig, ax = plt.subplots(figsize=(12, 6))
plt.text(0.5, 0.5, "Waiting for data...", ha='center', va='center') ax.text(0.5, 0.5, "Click Start to begin streaming", ha='center', va='center', fontsize=16)
ax.set_title("Gain-Viz Spectrum Analyzer - Ready")
plt.savefig(PLOT_PATH) plt.savefig(PLOT_PATH)
plt.close() plt.close(fig)
# Start plotting thread print("Gain-Viz server started. Use the web interface to control streaming.")
threading.Thread(target=generate_spectrum_plot, daemon=True).start() app.run(host="0.0.0.0", port=5000, debug=True, use_reloader=False)
app.run(host="0.0.0.0", debug=True) if __name__ == '__main__':
main()