forked from YosukeSugiura/SEGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
210 lines (161 loc) · 6.84 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
#
# Data Creator for SEGAN
#
#
from __future__ import absolute_import
import math
import wave
import array
import joblib
import glob
from numba import jit
import numpy as np
import numpy.random as rd
from scipy.signal import lfilter
from settings import settings
import os
# Low Pass Filter for de-emphasis
def de_emph(y, preemph=0.95):
    """
    De-emphasis filter (first-order IIR low-pass filter).

    Computes out[n] = y[n] + preemph * out[n-1], which undoes a
    pre-emphasis of the form x[n] - preemph * x[n-1].

    y       : 1-D signal to filter
    preemph : pre-emphasis coefficient; when <= 0 the input is returned
              unchanged (same object)

    Returns the filtered signal as produced by scipy.signal.lfilter.
    """
    # NOTE: this function is deliberately not wrapped in numba's @jit.
    # Its only work is a SciPy call, which Numba cannot compile in nopython
    # mode, so the decorator gave no speed-up and makes the call fail on
    # Numba versions where @jit no longer falls back to object mode.
    if preemph <= 0:
        return y
    return lfilter([1.0], [1.0, -preemph], y)
# Dataset loader
def data_loader(test=False, preemph=0.95):
    """
    Read wav files or Load pkl files

    If both cached pkl files exist they are loaded; otherwise every wav file
    in the clean / noisy folders is read, pre-emphasized, cut into blocks and
    the result is cached as pkl files.

    test    : False -> use the training paths of settings(), True -> test paths
    preemph : pre-emphasis coefficient applied when wav files are read

    Returns (cdata, ndata). When built from wav files each one is a float32
    array of shape (D, args.len // 2); when loaded from pkl it is whatever
    was cached.
    """

    ## Sub function : wav read & data shaping
    def wavloader(filename, length, name='wav'):
        """
        Read a list of wav files and return them as (D, length // 2) blocks.

        filename : list of wav file paths (read in the given order)
        length   : network input size in samples
        name     : label used in the progress messages only
        """
        # Error
        num = len(filename)
        if num == 0:
            # Fail here with a clear message; np.concatenate would otherwise
            # raise an unrelated ValueError on the empty list below.
            raise ValueError('Dataset Error : no wave files.')

        filedata = []
        for i, filename_ in enumerate(filename, start=1):
            # 'with' closes the file even when reading raises
            with wave.open(filename_, 'rb') as file_:
                # NOTE(review): frames are interpreted as 16-bit samples; the
                # wav header (channels / sample width) is not checked.
                filedata.append(np.frombuffer(file_.readframes(-1), dtype='int16'))
            print(' Loading {0} wav... #{1} / {2}'.format(name, i, num))

        # Serializing : all files become one continuous stream
        filedata = np.concatenate(filedata, axis=0).astype(np.float64)

        # Pre-emphasis : y[n] = x[n] - preemph * x[n-1], with x[-1] = 0.
        # (np.roll would wrap the LAST sample in front of the first one.)
        # The filter runs over the concatenated stream, i.e. across file borders.
        emphasized = filedata.copy()
        emphasized[1:] -= preemph * filedata[:-1]
        filedata = emphasized.astype(np.float32)    # Data Compressing (float64 -> float32)

        L = length // 2                             # Half of Input Size
        D = len(filedata) // L                      # No. of complete blocks
        return filedata[:D * L].reshape(D, L)       # (D*L,) --> (D, L), remainder dropped

    # Load settings
    args = settings()

    # Make folder (model / train pkl / test pkl)
    for folder in (args.model_save_path, args.train_pkl_path, args.test_pkl_path):
        os.makedirs(folder, exist_ok=True)

    # File name
    if not test:
        wav_clean = os.path.join(args.clean_train_path, '*.wav')
        wav_noisy = os.path.join(args.noisy_train_path, '*.wav')
        pkl_clean = os.path.join(args.train_pkl_path, args.train_pkl_clean)
        pkl_noisy = os.path.join(args.train_pkl_path, args.train_pkl_noisy)
    else:
        wav_clean = os.path.join(args.clean_test_path, '*.wav')
        wav_noisy = os.path.join(args.noisy_test_path, '*.wav')
        pkl_clean = os.path.join(args.test_pkl_path, args.test_pkl_clean)
        pkl_noisy = os.path.join(args.test_pkl_path, args.test_pkl_noisy)

    ## Pkl files exist -> Load
    ## -------------------------------------------------
    if os.path.exists(pkl_clean) and os.path.exists(pkl_noisy):
        # NOTE: joblib.load unpickles its input; only load cache files that
        # were created by this function.

        # Load clean pkl file
        print(' Load Clean Pkl...')
        with open(pkl_clean, 'rb') as f:
            cdata = joblib.load(f)

        # Load noisy pkl file
        print(' Load Noisy Pkl...')
        with open(pkl_noisy, 'rb') as f:
            ndata = joblib.load(f)

        return cdata, ndata

    ## No pkl files -> read wav + create pkl files
    ## -------------------------------------------------

    ## Wav files
    print(' Load wav file...')

    # Get file path.
    # glob returns files in arbitrary order; sorting both lists keeps the
    # clean and noisy streams concatenated in the same file order.
    # NOTE(review): this assumes matching files have names that sort alike.
    cname = sorted(glob.glob(wav_clean))
    nname = sorted(glob.glob(wav_noisy))

    # Get wave data
    cdata = wavloader(cname, args.len, name='clean')  # Clean wav
    ndata = wavloader(nname, args.len, name='noisy')  # Noisy wav

    ## Pkl files
    print(' Create Pkl file...')

    # Create clean pkl file
    with open(pkl_clean, 'wb') as f:
        joblib.dump(cdata, f, protocol=-1, compress=3)

    # Create noisy pkl file
    with open(pkl_noisy, 'wb') as f:
        joblib.dump(ndata, f, protocol=-1, compress=3)

    return cdata, ndata
class create_batch:
    """
    Creating Batch Data for training

    Holds the normalized clean / noisy blocks and a shuffled list of block
    indices. Each sample of a batch is built from two consecutive blocks
    (index, index + 1) joined along the last axis.
    """

    ## Initialization
    def __init__(self, clean_data, noisy_data, batches):
        """
        clean_data : array of clean blocks, shape (D, L)
        noisy_data : array of noisy blocks, indexed with the same indices
        batches    : batch size (>= 1)

        Raises ValueError when there are fewer than 2 blocks (no pair of
        consecutive blocks can be formed) or when batches < 1.
        """

        # Normalization
        def normalize(data):
            return (1. / 32767.) * data             # int16 scale -> about [-1 ~ 1]

        blocks = len(clean_data)
        if blocks < 2:
            raise ValueError('create_batch needs at least 2 blocks, got {0}'.format(blocks))
        if batches < 1:
            raise ValueError('batch size must be >= 1, got {0}'.format(batches))

        # Data Shaping
        self.clean = np.expand_dims(normalize(clean_data), axis=1)  # (D,L) -> (D,1,L)
        self.noisy = np.expand_dims(normalize(noisy_data), axis=1)  # (D,L) -> (D,1,L)

        # Random index ( for data scrambling).
        # The last block is excluded because next() also reads index + 1.
        ind = np.arange(blocks - 1)
        rd.shuffle(ind)

        # Parameters
        self.batch = batches
        self.batch_num = math.ceil(blocks / batches)    # Batch num for each 1 Epoch
        # Reuse the beginning of the shuffled indices when there is not
        # enough data; np.resize repeats cyclically, so every batch is full
        # even when more padding than len(ind) is required.
        self.rnd = np.resize(ind, self.batch_num * batches)
        self.len = blocks                               # Data length
        self.index = 0                                  # Start Position for data loading

    ## Pop batch data
    def next(self, i):
        """
        Return the i-th batch as (clean, noisy), each of shape (batch, 1, 2*L).
        """
        # Index of extracting data
        index = self.rnd[i * self.batch:(i + 1) * self.batch]

        # Reconstructing clean & noisy batch : (*, 1, L) -> (*, 1, 2L)
        clean = np.concatenate((self.clean[index], self.clean[index + 1]), axis=2)
        noisy = np.concatenate((self.noisy[index], self.noisy[index + 1]), axis=2)
        return clean, noisy
class create_batch_test:
    """
    Creating Batch Data for test

    Joins consecutive pairs of blocks of the selected range into frames of
    twice the block length; a trailing unpaired block is dropped.
    """

    ## Initialization
    def __init__(self, clean_data, noisy_data, start_frame=None, stop_frame=None):
        """
        clean_data  : array of clean blocks, shape (D, L)
        noisy_data  : array of noisy blocks, sliced with the same range
        start_frame : first block to use (default: 0)
        stop_frame  : block index to stop before (default: all blocks)

        Sets self.clean / self.noisy with shape (frames, 1, 2*L) and
        self.len = number of blocks in clean_data (the whole input, not only
        the selected range).
        """

        def normalize(data):
            return (1. / 32767.) * data             # int16 scale -> about [-1 ~ 1]

        total = clean_data.shape[0]

        # Processing range
        if start_frame is None:                     # Start frame position
            start_frame = 0
        if stop_frame is None:                      # Stop frame position
            stop_frame = total
        # A stop position past the end would be clamped by slicing anyway;
        # clamp first so the pair count below matches the real slice.
        stop_frame = min(stop_frame, total)

        # Parameters
        f_len = clean_data.shape[1] * 2             # Input size : two blocks per frame

        # Truncate protruded frame: keep an even number of blocks, counted
        # FROM start_frame (the stop index is start + count, not the count).
        pairs = max(0, (stop_frame - start_frame) // 2)
        stop_frame = start_frame + 2 * pairs

        self.clean = np.expand_dims(
            normalize(clean_data[start_frame:stop_frame]).reshape(-1, f_len), axis=1)
        self.noisy = np.expand_dims(
            normalize(noisy_data[start_frame:stop_frame]).reshape(-1, f_len), axis=1)
        self.len = len(clean_data)
def wav_write(filename, x, fs=16000):
    """
    Write a normalized signal as a mono 16-bit PCM wav file.

    filename : output path (or file object accepted by wave.open)
    x        : samples scaled to about [-1, 1]; multi-dimensional input is
               flattened in C order
    fs       : sampling rate in Hz

    Samples outside the int16 range after scaling are clipped instead of
    wrapping around. De-emphasis is not applied here (the de_emph call was
    already disabled in the original code).

    Returns 0.
    """
    scaled = np.asarray(x).ravel() * 32767           # denormalized
    # Clip before the cast: casting an out-of-range float to int16 wraps
    # around and turns an overshoot into a full-scale click.
    pcm = np.clip(scaled, -32768, 32767).astype(np.int16)   # cast to int (truncates)

    # 'with' closes the file (and finalizes the header) even on error
    with wave.open(filename, 'wb') as w:
        w.setparams((1,                  # channel
                     2,                  # byte width
                     fs,                 # sampling rate
                     len(pcm),           # #. of frames
                     'NONE',
                     'not compressed'    # no compression
                     ))
        # The wave module expects native byte order and swaps it itself on
        # big-endian machines, so the native int16 buffer is passed as is.
        w.writeframes(pcm.tobytes())
    return 0