telegram191 committed on
Commit
8be5e63
·
verified ·
1 Parent(s): 44ea95d

Fallback to ffmpeg when PyAV is unavailable

Browse files
Files changed (1) hide show
  1. audiocraft/data/audio.py +69 -2
audiocraft/data/audio.py CHANGED
@@ -19,19 +19,25 @@ import soundfile
19
  import torch
20
  from torch.nn import functional as F
21
 
22
- import av
23
  import subprocess as sp
24
 
25
  from .audio_utils import f32_pcm, normalize_audio
26
 
27
 
28
  _av_initialized = False
 
 
 
 
29
 
30
 
31
  def _init_av():
32
  global _av_initialized
33
  if _av_initialized:
34
  return
 
 
 
35
  logger = logging.getLogger('libav.mp3')
36
  logger.setLevel(logging.ERROR)
37
  _av_initialized = True
@@ -46,6 +52,8 @@ class AudioFileInfo:
46
 
47
  def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
48
  _init_av()
 
 
49
  with av.open(str(filepath)) as af:
50
  stream = af.streams.audio[0]
51
  sample_rate = stream.codec_context.sample_rate
@@ -59,6 +67,24 @@ def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
59
  return AudioFileInfo(info.samplerate, info.duration, info.channels)
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
63
  # torchaudio no longer returns useful duration information for some formats like mp3s.
64
  filepath = Path(filepath)
@@ -66,6 +92,8 @@ def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
66
  # ffmpeg has some weird issue with flac.
67
  return _soundfile_info(filepath)
68
  else:
 
 
69
  return _av_info(filepath)
70
 
71
 
@@ -81,6 +109,8 @@ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: floa
81
  tuple of torch.Tensor, int: Tuple containing audio data and sample rate
82
  """
83
  _init_av()
 
 
84
  with av.open(str(filepath)) as af:
85
  stream = af.streams.audio[0]
86
  sr = stream.codec_context.sample_rate
@@ -113,6 +143,40 @@ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: floa
113
  return f32_pcm(wav), sr
114
 
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
117
  duration: float = -1.0, pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
118
  """Read audio by picking the most appropriate backend tool based on the audio format.
@@ -137,7 +201,10 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
137
  if len(wav.shape) == 1:
138
  wav = torch.unsqueeze(wav, 0)
139
  else:
140
- wav, sr = _av_read(filepath, seek_time, duration)
 
 
 
141
  if pad and duration > 0:
142
  expected_frames = int(duration * sr)
143
  wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
 
19
  import torch
20
  from torch.nn import functional as F
21
 
 
22
  import subprocess as sp
23
 
24
  from .audio_utils import f32_pcm, normalize_audio
25
 
26
 
27
  _av_initialized = False
28
+ try:
29
+ import av
30
+ except Exception:
31
+ av = None
32
 
33
 
34
  def _init_av():
35
  global _av_initialized
36
  if _av_initialized:
37
  return
38
+ if av is None:
39
+ _av_initialized = True
40
+ return
41
  logger = logging.getLogger('libav.mp3')
42
  logger.setLevel(logging.ERROR)
43
  _av_initialized = True
 
52
 
53
  def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
54
  _init_av()
55
+ if av is None:
56
+ raise RuntimeError("PyAV is not available")
57
  with av.open(str(filepath)) as af:
58
  stream = af.streams.audio[0]
59
  sample_rate = stream.codec_context.sample_rate
 
67
  return AudioFileInfo(info.samplerate, info.duration, info.channels)
68
 
69
 
70
def _ffmpeg_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Probe the first audio stream of a file with the `ffprobe` command line tool.

    The output is requested as JSON and read by key, so the result does not
    depend on the order in which ffprobe prints the fields. When the stream
    itself carries no usable duration (missing or "N/A"), the container-level
    duration is used instead.

    Args:
        filepath (str or Path): Path to the audio file.
    Returns:
        AudioFileInfo: Sample rate, duration and number of channels.
    Raises:
        RuntimeError: If ffprobe reports no audio stream, or if the sample
            rate, channel count or duration cannot be determined.
        subprocess.CalledProcessError: If ffprobe exits with a non-zero status.
        OSError: If the ffprobe executable cannot be launched.
    """
    import json  # local import: only this fallback path needs it

    command = [
        "ffprobe",
        "-v", "error",
        "-select_streams", "a:0",
        # Also ask for the container duration as a fallback for streams without one.
        "-show_entries", "stream=sample_rate,channels,duration:format=duration",
        "-of", "json",
        str(filepath),
    ]
    text = sp.check_output(command).decode("utf-8", "replace").strip()
    payload = json.loads(text or "{}")

    def _as_float(value) -> tp.Optional[float]:
        # ffprobe may omit a field or report it as "N/A".
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    streams = payload.get("streams") or []
    if not streams:
        raise RuntimeError("ffprobe did not return enough audio info")
    stream = streams[0]

    sample_rate = _as_float(stream.get("sample_rate"))
    channels = _as_float(stream.get("channels"))
    duration = _as_float(stream.get("duration"))
    if duration is None:
        duration = _as_float((payload.get("format") or {}).get("duration"))

    if sample_rate is None or channels is None or duration is None:
        raise RuntimeError("ffprobe did not return enough audio info")
    return AudioFileInfo(int(sample_rate), duration, int(channels))
86
+
87
+
88
  def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
89
  # torchaudio no longer returns useful duration information for some formats like mp3s.
90
  filepath = Path(filepath)
 
92
  # ffmpeg has some weird issue with flac.
93
  return _soundfile_info(filepath)
94
  else:
95
+ if av is None:
96
+ return _ffmpeg_info(filepath)
97
  return _av_info(filepath)
98
 
99
 
 
109
  tuple of torch.Tensor, int: Tuple containing audio data and sample rate
110
  """
111
  _init_av()
112
+ if av is None:
113
+ raise RuntimeError("PyAV is not available")
114
  with av.open(str(filepath)) as af:
115
  stream = af.streams.audio[0]
116
  sr = stream.codec_context.sample_rate
 
143
  return f32_pcm(wav), sr
144
 
145
 
146
def _ffmpeg_read(filepath: tp.Union[str, Path], seek_time: float = 0,
                 duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
    """Decode audio by piping raw float32 PCM out of the `ffmpeg` command line tool.

    The sample rate and channel count come from `_ffmpeg_info`. If probing
    fails, the read is still attempted with 44100 Hz stereo and a warning is
    logged, because ffmpeg resamples/remixes to the requested layout.

    Args:
        filepath (str or Path): Path to the audio file.
        seek_time (float): Start offset in seconds; only passed to ffmpeg when positive.
        duration (float): Length to read in seconds; a non-positive (or None)
            value reads until the end of the file.
    Returns:
        tuple of torch.Tensor, int: Audio of shape [channels, frames] and the sample rate.
    Raises:
        RuntimeError: If ffmpeg produced no complete audio frame.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    try:
        info = _ffmpeg_info(filepath)
        sr = info.sample_rate
        channels = info.channels
    except Exception as exc:
        # Deliberate best effort: keep going with a common layout, but say so.
        logging.getLogger(__name__).warning(
            "Could not probe %s (%s); assuming 44100 Hz stereo.", filepath, exc)
        sr = 44100
        channels = 2

    command = [
        "ffmpeg",
        "-loglevel", "error",
        "-nostdin",
    ]
    if seek_time > 0:
        # Placed before -i so that ffmpeg seeks in the input.
        command += ["-ss", str(seek_time)]
    command += ["-i", str(filepath)]
    if duration is not None and duration > 0:
        command += ["-t", str(duration)]
    command += [
        "-f", "f32le",
        "-acodec", "pcm_f32le",
        "-ar", str(sr),
        "-ac", str(channels),
        "-",
    ]
    raw = sp.check_output(command)

    # Keep whole frames only (4 bytes per sample, `channels` samples per frame),
    # so a truncated tail cannot break the reshape below.
    frame_bytes = 4 * channels
    usable_bytes = len(raw) - len(raw) % frame_bytes
    if usable_bytes == 0:
        raise RuntimeError("ffmpeg returned empty audio")

    # The stream is explicitly little-endian. astype() copies into a writable,
    # native-order array; np.frombuffer alone returns a read-only view that
    # torch.from_numpy warns about.
    audio = np.frombuffer(raw, dtype="<f4", count=usable_bytes // 4).astype(np.float32)
    audio = audio.reshape(-1, channels).T
    wav = torch.from_numpy(audio).contiguous()
    return f32_pcm(wav), sr
178
+
179
+
180
  def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
181
  duration: float = -1.0, pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
182
  """Read audio by picking the most appropriate backend tool based on the audio format.
 
201
  if len(wav.shape) == 1:
202
  wav = torch.unsqueeze(wav, 0)
203
  else:
204
+ if av is None:
205
+ wav, sr = _ffmpeg_read(filepath, seek_time, duration)
206
+ else:
207
+ wav, sr = _av_read(filepath, seek_time, duration)
208
  if pad and duration > 0:
209
  expected_frames = int(duration * sr)
210
  wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))