diff --git a/README.md b/README.md
index cf9522ead895a430f17869ce0f723c81ad675368..5395c734b41779da8bb4f4cd971f5f7ec3f53cec 100644
--- a/README.md
+++ b/README.md
@@ -42,7 +42,7 @@ You should expect to hear intelligible (but noisy) speech by ~8k steps (~1.5h on
 #### Training with trajectories
 You don't need to run any preprocessing.
 ```
-python src/diffwave/__main__.py /path/to/model/dir /path/to/file/containing/trajectories --data_type trajectories_x #or trajectories if you want to use 3D data
+python src/diffwave/__main__.py /path/to/model/dir /path/to/file/containing/trajectories --data_type trajectories_x # for 1D data
 ```
 
 #### Multi-GPU training
diff --git a/src/diffwave/inference.py b/src/diffwave/inference.py
index 18a3715a497331b6204a16801b27d967be847665..59bb1187ce43654a12df55bd12cc45d50446d5e3 100644
--- a/src/diffwave/inference.py
+++ b/src/diffwave/inference.py
@@ -31,8 +31,10 @@ from diffwave.model import DiffWave
 
 
 models = {}
+STD = 1.7358
+MEAN = -0.0003
 
-def predict(spectrogram=None, model_dir=None, params=None, device=torch.device('cuda'), fast_sampling=False):
+def predict(spectrogram=None, model_dir=None, params=None, device=torch.device('cuda'), fast_sampling=False, clamp=True):
   # Lazy load model.
   if not model_dir in models:
     if os.path.exists(f'{model_dir}/weights.pt'):
@@ -91,41 +93,56 @@ def predict(spectrogram=None, model_dir=None, params=None, device=torch.device('
         noise = torch.randn_like(audio)
         sigma = ((1.0 - alpha_cum[n-1]) / (1.0 - alpha_cum[n]) * beta[n])**0.5
         audio += sigma * noise
-        audio = torch.clamp(audio, -1.0, 1.0)
-        #audio = torch.clamp(audio, -1.0, 1.0) if(params.audio_len != 2000) else audio
+      if clamp:  # originally done for audio
+        audio = torch.clamp(audio, -1.0, 1.0) 
+        
     return audio, model.params.sample_rate
 
 
-def main(args):
-  if args.spectrogram_path:
-    spectrogram = torch.from_numpy(np.load(args.spectrogram_path))
-  else:
-    spectrogram = None
-    
+def predict_audio(spectrogram, args):
+  '''
+    Function that calls predict() to generate an audio sample and save it.
+    Note that we are not using a list of spectograms, but a single one.
+    So, if args.num_samples > 1, we will generate the same audio sample multiple times.
+  '''
   samples = []
-  for i in range(args.num_samples):
+  for _ in range(args.num_samples):
     audio, sr = predict(spectrogram, model_dir=args.model_dir, fast_sampling=args.fast, params=base_params, device=torch.device('cpu' if args.cpu else 'cuda'))
-    if base_params.audio_len !=2000:
-      samples.append(audio.cpu())
-    else:
-      # this is a lagrangian trajectory, we have to apply the inverse of 
-      # the transformations used when preprocessing
-      reverse_transform = Compose([
+    samples.append(audio.cpu())
+    # save data
+    torchaudio.save(args.output, audio, sample_rate=sr)
+
+def predict_trajectories(args):
+  '''
+    Function that calls predict() to generate a trajectory sample and save it.
+    Note we're not making the transformations something variable; they are fixed.
+  '''
+  samples = []
+  for _ in range(args.num_samples):
+    trajectory, _ = predict(model_dir=args.model_dir, fast_sampling=args.fast, params=base_params, device=torch.device('cpu' if args.cpu else 'cuda'), clamp=args.clamp)
+    reverse_transform = Compose([
+        Lambda(lambda t: (t*STD) + MEAN),
         Lambda(lambda t: t.numpy(force=True).astype(np.float64).transpose()),
       ])
-      trajectory = reverse_transform(audio)
-      samples.append(trajectory)
-
-  if base_params.audio_len !=2000:
-    for audio in samples:
-      torchaudio.save(args.output, audio, sample_rate=sr)
-  else:
-    # vertically stack all the trajectories
-    trajectories = np.stack(samples, axis=0)
-    print(trajectories.shape)
+    trajectory = reverse_transform(trajectory)
+    samples.append(trajectory)
+    # save data
+    trajectories = np.stack(samples, axis=0) # so size = (num_samples, num_timesteps, 1)
     with open(args.output, 'wb') as f:
       np.save(f, trajectories)
 
+def main(args):
+  if args.spectrogram_path:
+    spectrogram = torch.from_numpy(np.load(args.spectrogram_path))
+  else:
+    spectrogram = None
+  
+  if args.data_type == 'audio':
+    predict_audio(spectrogram, args)
+  elif args.data_type == 'trajectories_x':
+    predict_trajectories(args)
+  else:
+    raise NotImplementedError
 
 if __name__ == '__main__':
   parser = ArgumentParser(description='runs inference on a spectrogram file generated by diffwave.preprocess')
@@ -141,4 +158,14 @@ if __name__ == '__main__':
       help='use cpu instead of cuda')
   parser.add_argument('--num_samples', default=1, type=int,
       help='number of samples to generate')
+  parser.add_argument('--data_type', default='audio', type=str,
+      help='indicate what type of data is being trained on (audio or other with custom dataloader)')
+  parser.add_argument('--clamp', '-c', default=True, type=bool,
+      help='clamp in [-1,1] when generating data')
   main(parser.parse_args())
+
+'''
+Example usage:
+python src/diffwave/inference.py ./models/weigths.pt -o ./path_to_file/that/stores/new_samples.npy \
+--data_type trajectories_x --cpu --num_samples 100 --clamp False
+'''
\ No newline at end of file