forked from magenta/symbolic-music-diffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
106 lines (94 loc) · 3.29 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Copyright 2021 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Model configurations."""
from magenta.models.music_vae import configs
from magenta.models.music_vae import data
from magenta.models.music_vae import data_hierarchical
# Registry of MusicVAE configs used by this project, keyed by short name.
MUSIC_VAE_CONFIG = {}

# 2-bar melody converter; polyphonic note sequences are not skipped.
melody_2bar_converter = data.OneHotMelodyConverter(
    slice_bars=2,
    gap_bars=None,
    steps_per_quarter=4,
    max_bars=100,  # Truncate long melodies before slicing.
    max_tensors_per_notesequence=None,
    skip_polyphony=False,
    dedupe_event_lists=False,
)
# 2-bar melody converter that skips polyphonic note sequences
# (skip_polyphony=True); otherwise identical to melody_2bar_converter.
mel_2bar_nopoly_converter = data.OneHotMelodyConverter(
    slice_bars=2,
    gap_bars=None,
    steps_per_quarter=4,
    max_bars=100,  # Truncate long melodies before slicing.
    max_tensors_per_notesequence=None,
    skip_polyphony=True,
    dedupe_event_lists=False,
)
# 16-bar melody converter (slice_bars=16, gap_bars=16); polyphony kept.
melody_16bar_converter = data.OneHotMelodyConverter(
    slice_bars=16,
    gap_bars=16,
    steps_per_quarter=4,
    max_bars=100,  # Truncate long melodies before slicing.
    max_tensors_per_notesequence=None,
    skip_polyphony=False,
    dedupe_event_lists=False,
)
# 1-bar multi-instrument performance converter requiring 2-8 instruments.
multitrack_default_1bar_converter = data_hierarchical.MultiInstrumentPerformanceConverter(
    hop_size_bars=1,
    num_velocity_bins=8,
    min_num_instruments=2,
    max_num_instruments=8,
    max_events_per_instrument=64,
)
# Variant of the 1-bar multitrack converter with no minimum instrument or
# event count, and with track dropping/truncation enabled.
multitrack_zero_1bar_converter = (
    data_hierarchical.MultiInstrumentPerformanceConverter(
        hop_size_bars=1,
        num_velocity_bins=8,
        min_num_instruments=0,
        max_num_instruments=8,
        min_total_events=0,
        max_events_per_instrument=64,
        drop_tracks_and_truncate=True,
    )
)
# Rebind stock MusicVAE configs to the data converters defined above.
_base_configs = configs.CONFIG_MAP

MUSIC_VAE_CONFIG["melody-2-big"] = _base_configs["cat-mel_2bar_big"]._replace(
    data_converter=melody_2bar_converter
)
MUSIC_VAE_CONFIG["melody-16-big"] = _base_configs["hierdec-mel_16bar"]._replace(
    data_converter=melody_16bar_converter
)
MUSIC_VAE_CONFIG["multi-1-big"] = _base_configs["hier-multiperf_vel_1bar_big"]._replace(
    data_converter=multitrack_default_1bar_converter
)
MUSIC_VAE_CONFIG["multi-0min-1-big"] = _base_configs["hier-multiperf_vel_1bar_big"]._replace(
    data_converter=multitrack_zero_1bar_converter
)
# Built-from-scratch 2-bar config whose converter rejects polyphony.
_nopoly_hparams = configs.merge_hparams(
    configs.lstm_models.get_default_hparams(),
    configs.HParams(
        batch_size=512,
        max_seq_len=32,  # 2 bars w/ 16 steps per bar
        z_size=512,
        enc_rnn_size=[2048],
        dec_rnn_size=[2048, 2048, 2048],
    ),
)
MUSIC_VAE_CONFIG["melody-2-big-nopoly"] = configs.Config(
    model=configs.MusicVAE(
        configs.lstm_models.BidirectionalLstmEncoder(),
        configs.lstm_models.CategoricalLstmDecoder(),
    ),
    hparams=_nopoly_hparams,
    note_sequence_augmenter=data.NoteSequenceAugmenter(transpose_range=(-5, 5)),
    data_converter=mel_2bar_nopoly_converter,
)