| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | """ |
| | To run this script, use one of the commands below from the root of the repo. Make sure Flask is installed. |
| | |
| | FLASK_DEBUG=1 FLASK_APP=scripts.mos flask run -p 4567 |
| | # or if you have gunicorn |
| | gunicorn -w 4 -b 127.0.0.1:8895 -t 120 'scripts.mos:app' --access-logfile - |
| | |
| | """ |
| | from collections import defaultdict |
| | from functools import wraps |
| | from hashlib import sha1 |
| | import json |
| | import math |
| | from pathlib import Path |
| | import random |
| | import typing as tp |
| |
|
| | from flask import Flask, redirect, render_template, request, session, url_for |
| |
|
| | from audiocraft import train |
| | from audiocraft.utils.samples.manager import get_samples_for_xps |
| |
|
| |
|
# Maximum number of sample ids shown on a single survey page.
SAMPLES_PER_PAGE = 8
# Highest rating a user can give; valid ratings go from 1 to MAX_RATING.
MAX_RATING = 5
# All survey data is kept in a `mos_storage` folder inside the Dora directory.
storage = Path(train.main.dora.dir / 'mos_storage')
storage.mkdir(exist_ok=True)
# One sub-folder per survey, named after the survey signature.
surveys = storage / 'surveys'
surveys.mkdir(exist_ok=True)
# Two levels above the `audiocraft.train` module file; presumably the root of
# the repository, since the `scripts` folder is looked up from there.
magma_root = Path(train.__file__).parent.parent
# Static files and templates are read from `scripts/static` and `scripts/templates`.
app = Flask('mos', static_folder=str(magma_root / 'scripts/static'),
            template_folder=str(magma_root / 'scripts/templates'))
# NOTE(review): the key used to sign session cookies is hard-coded here, so
# anyone who can read this file can forge a session. Consider loading it from
# the environment before exposing this app beyond a trusted network.
app.secret_key = b'audiocraft makes the best songs'
| |
|
| |
|
def normalize_path(path: Path):
    """Express `path` relative to the `xps` folder of the Dora directory.

    Both `path` and the Dora directory are resolved first. `Path.relative_to`
    raises `ValueError` when `path` is not located under the `xps` folder.
    """
    resolved = path.resolve()
    xps_root = train.main.dora.dir.resolve() / 'xps'
    return resolved.relative_to(xps_root)
| |
|
| |
|
def get_full_path(normalized_path: Path):
    """Inverse of `normalize_path`: rebuild the path under the Dora `xps` folder.
    """
    xps_root = train.main.dora.dir.resolve() / 'xps'
    return xps_root / normalized_path
| |
|
| |
|
def get_signature(xps: tp.List[str]):
    """Compute a short signature identifying a list of XP signatures.

    The list is serialized to JSON and hashed with SHA-1; the first 10
    hexadecimal characters of the digest are returned. The result depends on
    the order of the entries in `xps`.
    """
    serialized = json.dumps(xps)
    digest = sha1(serialized.encode())
    return digest.hexdigest()[:10]
| |
|
| |
|
def ensure_logged(func):
    """Decorator for views that require a logged-in user.

    When the session holds no `user`, the request is redirected to the
    `login` view, passing the current URL as `redirect_to`. Otherwise the
    wrapped view is called with the original arguments.
    """
    @wraps(func)
    def _guarded(*args, **kwargs):
        if session.get('user') is not None:
            return func(*args, **kwargs)
        login_url = url_for('login', redirect_to=request.url)
        return redirect(login_url)
    return _guarded
| |
|
| |
|
@app.route('/login', methods=['GET', 'POST'])
def login():
    """Login user if not already, then redirect.

    On GET without a user in the session, or on POST with an empty `user`
    form field, the login page is rendered (with an error message in the
    latter case). Once a user is known it is stored in the session and the
    request is redirected to the `redirect_to` query parameter, falling back
    to the index page.

    `redirect_to` is only honoured when it points to this very host over
    http(s); any other target is replaced by the index page so that the login
    form cannot be used to bounce users to an arbitrary external site.
    """
    from urllib.parse import urljoin, urlparse

    user = session.get('user')
    if user is None:
        error = None
        if request.method == 'POST':
            user = request.form['user']
            if not user:
                error = 'User cannot be empty'
        if user is None or error:
            return render_template('login.html', error=error)
    # At this point `user` is a non-empty string, either from the session
    # or from the submitted form.
    session['user'] = user

    redirect_to = request.args.get('redirect_to')
    if redirect_to is not None:
        # Resolve relative targets against our own host, then make sure the
        # result still points at our own host before following it.
        resolved = urljoin(request.host_url, redirect_to)
        target = urlparse(resolved)
        own_host = urlparse(request.host_url).netloc
        if target.scheme in ('http', 'https') and target.netloc == own_host:
            redirect_to = resolved
        else:
            redirect_to = None
    if redirect_to is None:
        redirect_to = url_for('index')
    return redirect(redirect_to)
| |
|
| |
|
@app.route('/', methods=['GET', 'POST'])
@ensure_logged
def index():
    """Offer to create a new study.

    On POST, the `xps` form field is split on whitespace. Each entry is
    either the name of a folder under `<dora dir>/xps`, or the name of a
    folder under `<dora dir>/grids`, in which case every symlink found in it
    contributes its name as an XP. Entries matching neither are reported as
    errors. When there is no error, a manifest listing the XPs is written to
    the survey folder and the user is redirected to the survey page.

    The XPs are sorted before computing the signature, so that the same set
    of XPs always maps to the same survey, whatever the order they were typed
    in and whatever the iteration order of the intermediate set.
    """
    errors = []
    if request.method == 'POST':
        # `str.split()` without argument already drops surrounding whitespace.
        xps_or_grids = request.form['xps'].split()
        xps = set()
        for xp_or_grid in xps_or_grids:
            xp_path = train.main.dora.dir / 'xps' / xp_or_grid
            if xp_path.exists():
                xps.add(xp_or_grid)
                continue
            grid_path = train.main.dora.dir / 'grids' / xp_or_grid
            if grid_path.exists():
                for child in grid_path.iterdir():
                    if child.is_symlink():
                        xps.add(child.name)
                continue
            errors.append(f'{xp_or_grid} is neither an XP nor a grid!')
        if not xps and not errors:
            # Empty form field, or a grid folder without any symlink.
            errors.append('No XP found, please provide at least one XP or grid.')
        blind = 'true' if request.form.get('blind') == 'on' else 'false'
        if not errors:
            sorted_xps = sorted(xps)
            signature = get_signature(sorted_xps)
            manifest = {
                'xps': sorted_xps,
            }
            survey_path = surveys / signature
            survey_path.mkdir(exist_ok=True)
            with open(survey_path / 'manifest.json', 'w') as f:
                json.dump(manifest, f, indent=2)
            return redirect(url_for('survey', blind=blind, signature=signature))
    return render_template('index.html', errors=errors)
| |
|
| |
|
@app.route('/survey/<signature>', methods=['GET', 'POST'])
@ensure_logged
def survey(signature):
    """Display one page of the survey `signature`, and record the ratings on POST.

    Query parameters:
        seed: integer seed (default 4321) used to pick and order the samples
            of the page; it is also part of the result file name.
        blind: when 'true', 'on' or 'True', the models are shuffled for each sample.
        exclude_prompted, exclude_unprompted: same truthy values as `blind`,
            forwarded to `get_samples_for_xps`.
        max_epoch: integer (default -1), forwarded to `get_samples_for_xps`.
        success: forwarded as is to the template.

    On POST, one rating is expected per displayed model, in a form field
    named after its `model_id`. When every rating is a valid integer between
    1 and MAX_RATING, the results are written to
    `results/<user>_<seed>.json` inside the survey folder, and the user is
    redirected to the same survey with `seed + 1`. Otherwise the page is
    rendered again with an error attached to each faulty model.

    Returns a 404 response when the survey does not exist, and a 400 response
    when the user name would place the result file outside of the results folder.
    """
    success = request.args.get('success', False)
    seed = int(request.args.get('seed', 4321))
    blind = request.args.get('blind', 'false') in ['true', 'on', 'True']
    exclude_prompted = request.args.get('exclude_prompted', 'false') in ['true', 'on', 'True']
    exclude_unprompted = request.args.get('exclude_unprompted', 'false') in ['true', 'on', 'True']
    max_epoch = int(request.args.get('max_epoch', '-1'))
    survey_path = surveys / signature
    if not survey_path.exists():
        # The signature comes straight from the URL: answer with a proper
        # "not found" rather than failing an assertion.
        return 'Survey not found.', 404

    user = session['user']
    result_folder = survey_path / 'results'
    result_folder.mkdir(exist_ok=True)
    result_file = result_folder / f'{user}_{seed}.json'
    if result_file.resolve().parent != result_folder.resolve():
        # The user name is free text; refuse names containing path
        # components that would make us write outside the results folder.
        return 'Invalid user name.', 400

    with open(survey_path / 'manifest.json') as f:
        manifest = json.load(f)

    xps = [train.main.get_xp_from_sig(xp) for xp in manifest['xps']]
    names, ref_name = train.main.get_names(xps)

    samples_kwargs = {
        'exclude_prompted': exclude_prompted,
        'exclude_unprompted': exclude_unprompted,
        'max_epoch': max_epoch,
    }
    matched_samples = get_samples_for_xps(xps, epoch=-1, **samples_kwargs)
    # For each sample id, one entry per XP, in the same order as `xps`.
    models_by_id = {
        sample_id: [{
            'xp': xps[idx],
            'xp_name': names[idx],
            'model_id': f'{xps[idx].sig}-{sample.id}',
            'sample': sample,
            'is_prompted': sample.prompt is not None,
            'errors': [],
        } for idx, sample in enumerate(samples)]
        for sample_id, samples in matched_samples.items()
    }
    # The epoch reported for each XP is the one of its first matched sample.
    experiments = [
        {'xp': xp, 'name': names[idx], 'epoch': list(matched_samples.values())[0][idx].epoch}
        for idx, xp in enumerate(xps)
    ]

    # Sort before shuffling so that the selection only depends on the seed.
    keys = list(matched_samples.keys())
    keys.sort()
    rng = random.Random(seed)
    rng.shuffle(keys)
    model_ids = keys[:SAMPLES_PER_PAGE]

    if blind:
        # Uses the same seeded generator, so that the order of the models is
        # the same when rendering the form and when reading its answers.
        for key in model_ids:
            rng.shuffle(models_by_id[key])

    ok = True
    if request.method == 'POST':
        all_samples_results = []
        for sample_id in model_ids:
            models = models_by_id[sample_id]
            result = {
                'id': sample_id,
                'is_prompted': models[0]['is_prompted'],
                'models': {}
            }
            all_samples_results.append(result)
            for model in models:
                # A field missing from the form is treated as "not rated".
                raw_rating = request.form.get(model['model_id'])
                if not raw_rating:
                    ok = False
                    model['errors'].append('Please rate this model.')
                    continue
                try:
                    rating = int(raw_rating)
                except ValueError:
                    rating = None
                if rating is None or not 1 <= rating <= MAX_RATING:
                    ok = False
                    model['errors'].append(
                        f'Rating must be an integer between 1 and {MAX_RATING}.')
                    continue
                result['models'][model['xp'].sig] = rating
                model['rating'] = rating
        if ok:
            result = {
                'results': all_samples_results,
                'seed': seed,
                'user': user,
                'blind': blind,
                'exclude_prompted': exclude_prompted,
                'exclude_unprompted': exclude_unprompted,
            }
            print(result)
            with open(result_file, 'w') as f:
                json.dump(result, f)
            # Move on to the next page of samples.
            seed = seed + 1
            return redirect(url_for(
                'survey', signature=signature, blind=blind, seed=seed,
                exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted,
                max_epoch=max_epoch, success=True))

    ratings = list(range(1, MAX_RATING + 1))
    return render_template(
        'survey.html', ratings=ratings, blind=blind, seed=seed, signature=signature, success=success,
        exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted, max_epoch=max_epoch,
        experiments=experiments, models_by_id=models_by_id, model_ids=model_ids, errors=[],
        ref_name=ref_name, already_filled=result_file.exists())
| |
|
| |
|
@app.route('/audio/<path:path>')
def audio(path: str):
    """Serve the audio file stored at the absolute path `/<path>`.

    Only `.mp3` and `.wav` files are served, each with its own content type.
    Returns a 400 response for any other suffix and a 404 response when the
    file does not exist.
    """
    # NOTE(review): this view is not behind `ensure_logged` and is not
    # restricted to a given folder, so any `.mp3`/`.wav` file readable by the
    # server can be fetched. Confirm this is acceptable for the deployment.
    content_types = {'.mp3': 'audio/mpeg', '.wav': 'audio/wav'}
    full_path = Path('/') / path
    content_type = content_types.get(full_path.suffix)
    if content_type is None:
        return 'Unsupported audio format.', 400
    if not full_path.is_file():
        return 'Audio file not found.', 404
    return full_path.read_bytes(), {'Content-Type': content_type}
| |
|
| |
|
def mean(x):
    """Return the arithmetic mean of the values in `x`.

    Raises `ZeroDivisionError` when `x` is empty.
    """
    total = sum(x)
    count = len(x)
    return total / count
| |
|
| |
|
def std(x):
    """Return the population standard deviation of the values in `x`.

    The sum of squared deviations from the mean is divided by `len(x)`
    (not `len(x) - 1`) before taking the square root.
    """
    center = mean(x)
    squared_deviations = ((value - center)**2 for value in x)
    variance = sum(squared_deviations) / len(x)
    return math.sqrt(variance)
| |
|
| |
|
@app.route('/results/<signature>')
@ensure_logged
def results(signature):
    """Show the aggregated ratings of the survey `signature`.

    Every `.json` file of the survey `results` folder is read, and the
    ratings are pooled per model signature, across all users and samples.
    For each model, the template receives the number of ratings, their mean,
    and `1.96 * std / sqrt(n)` under the `std_rating` key, all formatted with
    two decimals.

    Returns a 404 response when the survey does not exist.
    """
    survey_path = surveys / signature
    if not survey_path.exists():
        # The signature comes straight from the URL: answer with a proper
        # "not found" rather than failing an assertion.
        return 'Survey not found.', 404
    result_folder = survey_path / 'results'
    result_folder.mkdir(exist_ok=True)

    ratings_per_model = defaultdict(list)
    users = []
    # Sorted so that the list of users is the same from one request to the next.
    for result_file in sorted(result_folder.iterdir()):
        if result_file.suffix != '.json':
            continue
        with open(result_file) as f:
            payload = json.load(f)
        users.append(payload['user'])
        for result in payload['results']:
            for sig, rating in result['models'].items():
                ratings_per_model[sig].append(rating)

    fmt = '{:.2f}'
    models = []
    for model in sorted(ratings_per_model.keys()):
        ratings = ratings_per_model[model]

        models.append({
            'sig': model,
            'samples': len(ratings),
            'mean_rating': fmt.format(mean(ratings)),
            # Despite the key name, this is 1.96 times the standard error of
            # the mean, not the raw standard deviation.
            'std_rating': fmt.format(1.96 * std(ratings) / len(ratings)**0.5),
        })
    return render_template('results.html', signature=signature, models=models, users=users)
| |
|