| | import typer |
| | import torch |
| | import subprocess |
| | from pathlib import Path |
| |
|
| | from expert import UpstreamExpert |
| |
|
# Files that must exist in the working directory for a submission to be valid.
SUBMISSION_FILES = ["expert.py", "model.pt"]

# Samples per second used to size the synthetic waveforms built in `validate`
# (number of samples = SAMPLE_RATE * duration).
SAMPLE_RATE = 16_000

# Durations, in seconds, of the synthetic waveforms; they are unequal, and
# `validate` compares the output length against the longest one.
SECONDS = [2, 1.8, 3.7]

# Typer application; each function decorated with @app.command() becomes a
# subcommand of this CLI.
app = typer.Typer()
| |
|
def _check(condition, message: str) -> None:
    """Raise ``AssertionError(message)`` unless *condition* is true.

    Used instead of ``assert`` statements, which Python removes under ``-O``;
    the validator must keep rejecting bad submissions regardless of flags.
    """
    if not condition:
        raise AssertionError(message)


def _check_submission_files() -> None:
    """Raise ``ValueError`` if a file in ``SUBMISSION_FILES`` is missing from the CWD."""
    for file in SUBMISSION_FILES:
        if not Path(file).is_file():
            raise ValueError(f"File {file} not found! Please include {file} in your submission")


def _check_upstream() -> None:
    """Load ``UpstreamExpert`` from ``model.pt``, run it on random waveforms, check its outputs.

    Raises:
        AssertionError: if the outputs do not have the expected structure.
    """
    upstream = UpstreamExpert(ckpt="model.pt")
    samples = [round(SAMPLE_RATE * sec) for sec in SECONDS]
    wavs = [torch.rand(sample) for sample in samples]
    # Only the structure of the outputs is checked, so no autograd graph is needed.
    with torch.no_grad():
        results = upstream(wavs)

    _check(isinstance(results, dict), "upstream(wavs) must return a dict")
    _check("hidden_states" in results, "the returned dict must contain the 'hidden_states' key")

    # A task without its own key in `results` falls back to "hidden_states".
    tasks = ["PR", "SID", "ER", "ASR", "ASV", "SD", "QbE", "ST", "SS", "SE", "secret"]
    for task in tasks:
        hidden_states = results.get(task, results["hidden_states"])
        _check(isinstance(hidden_states, list), f"[{task}] hidden states must be a list of tensors")
        _check(len(hidden_states) > 0, f"[{task}] hidden states must not be an empty list")

        for state in hidden_states:
            _check(isinstance(state, torch.Tensor), f"[{task}] every hidden state must be a torch.Tensor")
            _check(
                state.dim() == 3,
                f"[{task}] hidden states must be 3-D: "
                "(batch_size, max_sequence_length_of_batch, hidden_size)",
            )
            _check(
                state.shape == hidden_states[0].shape,
                f"[{task}] all hidden states must have the same shape",
            )

        downsample_rate = upstream.get_downsample_rates(task)
        _check(isinstance(downsample_rate, int), f"[{task}] get_downsample_rates must return an int")
        _check(downsample_rate > 0, f"[{task}] get_downsample_rates must return a positive int")
        # The sequence axis should match the longest waveform divided by the
        # downsample rate; a slack of fewer than 5 frames is tolerated
        # (presumably for framing/rounding at the edges).
        expected_frames = round(max(samples) / downsample_rate)
        _check(abs(expected_frames - hidden_states[0].size(1)) < 5, f"[{task}] wrong downsample rate")


@app.command()
def validate() -> None:
    """Validate the submission in the current directory.

    Checks that every file in ``SUBMISSION_FILES`` exists, then loads the
    upstream from ``model.pt`` and checks its outputs on random waveforms.

    Raises:
        ValueError: if a required submission file is missing.
        Exception: any error while loading, running or checking the upstream
            (including ``AssertionError`` for spec violations) is re-raised
            after pointing the user to the specification.
    """
    _check_submission_files()

    try:
        _check_upstream()
    except Exception:
        # Deliberately not a bare `except:` so Ctrl-C (KeyboardInterrupt)
        # stops the command without printing the spec hint.
        typer.echo(
            "Please check the Upstream Specification on "
            "https://superbbenchmark.org/challenge-slt2022/upstream",
            err=True,
        )
        raise

    typer.echo("All submission files validated!")
    typer.echo("Now you can upload these files to huggingface's Hub.")
| |
|
| |
|
def _run_git(args) -> None:
    """Run ``git`` with *args*; on a non-zero exit, report it and stop the CLI with that code."""
    command = ["git", *args]
    returncode = subprocess.call(command)
    if returncode != 0:
        typer.echo(f"Command failed (exit code {returncode}): {' '.join(command)}", err=True)
        raise typer.Exit(code=returncode)


@app.command()
def upload(commit_message: str) -> None:
    """Commit the working tree and push it to the ``origin`` remote.

    Runs ``git pull origin main``, ``git add .``, ``git commit`` and
    ``git push``. Each step's exit status is checked, so a failed pull,
    commit or push aborts with git's exit code instead of reporting success.

    Args:
        commit_message: Text appended to the ``Upload Upstream:`` commit prefix.
    """
    _run_git(["pull", "origin", "main"])
    _run_git(["add", "."])
    # `git diff --cached --quiet` exits 0 when nothing is staged. `git commit`
    # would then fail with "nothing to commit", so skip it and still push any
    # earlier local commits.
    if subprocess.call(["git", "diff", "--cached", "--quiet"]) == 0:
        typer.echo("Nothing new to commit; pushing existing commits.")
    else:
        _run_git(["commit", "-m", f"Upload Upstream: {commit_message} "])
    _run_git(["push"])

    typer.echo("Upload successful!")
    typer.echo("Please go to https://superbbenchmark.org/submit to make a submission with the following information:")
    typer.echo("1. Organization Name")
    typer.echo("2. Repository Name")
    typer.echo("3. Commit Hash (full 40 characters)")
    typer.echo("These information can be shown by: python cli.py info")
| |
|
def _parse_remote_url(url: str):
    """Split a git remote URL into ``(organization, repository)``.

    Accepts HTTPS (``https://huggingface.co/org/repo``), ``ssh://`` and
    scp-like (``git@host:org/repo``) forms, and ignores a trailing ``/`` or
    ``.git`` suffix.

    Raises:
        ValueError: if the URL does not contain at least two path components.
    """
    path = url.strip().rstrip("/")
    if path.endswith(".git"):
        path = path[: -len(".git")]
    # Treat ":" like "/" so scp-like "host:org/repo" splits the same way as
    # "https://host/org/repo"; empty pieces (e.g. from "//") are dropped.
    parts = [part for part in path.replace(":", "/").split("/") if part]
    if len(parts) < 2:
        raise ValueError(f"Cannot parse organization/repository from remote URL {url!r}")
    return parts[-2], parts[-1]


@app.command()
def info() -> None:
    """Print the organization, repository name and commit hash needed to submit.

    The organization and repository come from the ``origin`` remote URL; the
    commit hash is the current ``HEAD``. Exits with code 1 (after a message on
    stderr) if either cannot be determined.
    """
    remote = subprocess.run(
        ["git", "config", "--get", "remote.origin.url"],
        capture_output=True,
        encoding="utf-8",
    )
    url = remote.stdout.strip()
    if remote.returncode != 0 or not url:
        typer.echo("No 'origin' remote is configured for this repository.", err=True)
        raise typer.Exit(code=1)
    try:
        organization, repo = _parse_remote_url(url)
    except ValueError as err:
        typer.echo(str(err), err=True)
        raise typer.Exit(code=1) from err

    head = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, encoding="utf-8")
    if head.returncode != 0:
        typer.echo(f"Could not read the current commit: {head.stderr.strip()}", err=True)
        raise typer.Exit(code=1)
    commit_hash = head.stdout.strip()

    typer.echo(f"Organization Name: {organization}")
    typer.echo(f"Repository Name: {repo}")
    typer.echo(f"Commit Hash: {commit_hash}")
| |
|
if __name__ == "__main__":
    # Only when executed as a script (e.g. `python cli.py info`): let Typer
    # parse the command line and dispatch to the matching subcommand.
    app()
| |
|