-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcommu_wrapper.py
70 lines (53 loc) · 2.28 KB
/
commu_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from collections import defaultdict
from pathlib import Path
from typing import Dict, List
import yaml
from tqdm import tqdm
from commu.midi_generator.generate_pipeline import MidiGenerationPipeline
from commu_dset import DSET
from commu_file import CommuFile
def make_midis(
        bpm: int,
        key: str,
        time_signature: str,
        num_measures: int,
        genre: str,
        rhythm: str,
        chord_progression: str,
        timestamp: str,
        *,
        cfg_path: str = 'cfg/inference.yaml',
        checkpoint_dir: str = 'ckpt/checkpoint_best.pt',
        output_root: str = 'out') -> Dict[str, List[CommuFile]]:
    """Generate one MIDI track per track role with the ComMU generation pipeline.

    For every role yielded by ``DSET.get_track_roles()``:

    1. A fresh ``MidiGenerationPipeline`` is built and its model is initialized.
    2. The song-level metadata passed in here is combined with per-role values
       obtained from ``DSET``: velocity range, instrument and pitch range.
    3. The preprocess, inference and postprocess tasks are run.

    The resulting ``{output_root}/{timestamp}/{role}.mid`` file is then wrapped
    in a ``CommuFile`` and deleted from disk.

    Args:
        bpm: Passed to the pipeline as ``'bpm'``.
        key: Passed to the pipeline as ``'audio_key'``.
        time_signature: Passed to the pipeline as ``'time_signature'``.
        num_measures: Passed to the pipeline as ``'num_measures'``.
        genre: Passed to the pipeline as ``'genre'``.
        rhythm: Passed to the pipeline as ``'rhythm'``.
        chord_progression: Expanded with ``DSET.unfold`` before being passed
            to the pipeline as ``'chord_progression'``.
        timestamp: Name of the sub-directory of ``output_root`` that the
            pipeline is told to write into.
        cfg_path: YAML file that must provide ``top_k`` and ``temperature``.
            Keyword-only; defaults to the previously hard-coded path.
        checkpoint_dir: Value given to the pipeline as ``'checkpoint_dir'``.
            Keyword-only; defaults to the previously hard-coded path.
        output_root: Root directory for generated files. Keyword-only;
            defaults to ``'out'``.

    Returns:
        A ``defaultdict(list)`` mapping each role to the ``CommuFile`` objects
        generated for it. There is one object per occurrence of the role in
        ``DSET.get_track_roles()``.

    Raises:
        OSError: If ``cfg_path`` cannot be opened.
        KeyError: If the loaded config lacks ``top_k`` or ``temperature``.
        Exceptions raised by the pipeline, ``DSET`` or ``CommuFile`` propagate
        unchanged.
    """
    with open(cfg_path, encoding='utf-8') as f:
        cfg = yaml.safe_load(f)

    # Same directory string for every role; handed to the pipeline and used
    # to locate the generated file.
    output_dir = f'{output_root}/{timestamp}'

    role_to_midis = defaultdict(list)
    for role in tqdm(DSET.get_track_roles()):
        # A new pipeline (and model) is built for every role.
        # NOTE(review): presumably because the pipeline's tasks keep per-run
        # state (e.g. preprocess_task.input_data). Confirm before hoisting
        # this out of the loop to avoid reloading the checkpoint each time.
        pipeline = MidiGenerationPipeline({'checkpoint_dir': checkpoint_dir})
        inference_cfg = pipeline.model_initialize_task.inference_cfg
        model = pipeline.model_initialize_task.execute()

        # Per-role parameters come from DSET. The instrument is kept so it can
        # be attached to the resulting CommuFile.
        min_v, max_v = DSET.sample_min_max_velocity(role)
        instrument = DSET.sample_instrument(role)

        # NOTE(review): `excecute` (sic) is what is called here, while the
        # other tasks use `execute`. Confirm the PreprocessTask API before
        # renaming this call.
        encoded_meta = pipeline.preprocess_task.excecute({
            'track_role': role,
            'bpm': bpm,
            'audio_key': key,
            'time_signature': time_signature,
            'num_measures': num_measures,
            'genre': genre,
            'rhythm': rhythm,
            'chord_progression': DSET.unfold(chord_progression),
            'pitch_range': DSET.sample_pitch_range(role),
            'inst': instrument,
            'min_velocity': min_v,
            'max_velocity': max_v,
            'top_k': cfg['top_k'],
            'temperature': cfg['temperature'],
            'output_dir': output_dir,
            'num_generate': 1})
        input_data = pipeline.preprocess_task.input_data
        meta_info_len = pipeline.preprocess_task.get_meta_info_length()

        # The tasks are configured by calling them, then run with execute().
        pipeline.inference_task(model=model, input_data=input_data, inference_cfg=inference_cfg)
        sequences = pipeline.inference_task.execute(encoded_meta)
        pipeline.postprocess_task(input_data=input_data)
        pipeline.postprocess_task.execute(sequences=sequences, meta_info_len=meta_info_len)

        # Assumes the postprocess task wrote the track to exactly this path.
        filepath = f'{output_dir}/{role}.mid'
        # NOTE(review): the file is deleted right after wrapping, so this
        # assumes CommuFile has finished reading it in its constructor.
        # The output directory itself is not removed here.
        role_to_midis[role].append(CommuFile(filepath, role, instrument))
        Path(filepath).unlink()
    return role_to_midis