Instructions to use Superxixixi/LoCoNet_ASD with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Superxixixi/LoCoNet_ASD with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Superxixixi/LoCoNet_ASD", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Superxixixi/LoCoNet_ASD", trust_remote_code=True, dtype="auto")
```

- Notebooks
- Google Colab
- Kaggle
| import os, torch, numpy, cv2, random, glob, python_speech_features, json, math | |
| from scipy.io import wavfile | |
| from torchvision.transforms import RandomCrop | |
| from operator import itemgetter | |
| from torchvggish import vggish_input, vggish_params, mel_features | |
def overlap(audio, noiseAudio):
    """Mix ``noiseAudio`` into ``audio`` at a random signal-to-noise ratio.

    The noise is repeated (wrap-padded) or truncated to the length of
    ``audio``, scaled so that the clean-to-noise power ratio equals an SNR
    drawn uniformly from [-5, 5] dB, and then added to the clean signal.

    Args:
        audio: Clean waveform as a numpy array (for example int16 PCM samples).
        noiseAudio: Noise waveform as a numpy array; must not be empty when it
            is shorter than ``audio`` (wrap-padding needs at least one sample).

    Returns:
        The mixture as an int16 numpy array with the same length as ``audio``.
        Samples outside the int16 range are saturated rather than wrapped.
    """
    snr = random.uniform(-5, 5)
    if len(noiseAudio) < len(audio):
        shortage = len(audio) - len(noiseAudio)
        noiseAudio = numpy.pad(noiseAudio, (0, shortage), 'wrap')
    else:
        noiseAudio = noiseAudio[:len(audio)]

    # Work in float64: squaring int16 samples overflows (e.g. 200 ** 2 does
    # not fit in int16), which silently corrupts both power estimates.
    clean = numpy.asarray(audio, dtype=numpy.float64)
    noise = numpy.asarray(noiseAudio, dtype=numpy.float64)

    # The 1e-4 offset keeps log10 finite for all-zero signals.
    noiseDB = 10 * numpy.log10(numpy.mean(noise ** 2) + 1e-4)
    cleanDB = 10 * numpy.log10(numpy.mean(clean ** 2) + 1e-4)
    noise = numpy.sqrt(10 ** ((cleanDB - noiseDB - snr) / 10)) * noise

    mixed = clean + noise
    # Saturate before the cast so loud mixtures clip instead of wrapping
    # around to the opposite sign.
    limits = numpy.iinfo(numpy.int16)
    return numpy.clip(mixed, limits.min, limits.max).astype(numpy.int16)
def load_audio(data, dataPath, numFrames, audioAug, audioSet=None):
    """Return MFCC features for one entity, with four audio steps per video frame.

    Args:
        data: Sequence describing the sample; ``data[0]`` is the entity name
            used as key into ``audioSet`` and ``data[2]`` is the video fps.
        dataPath: Unused; kept for interface compatibility.
        numFrames: Number of video frames the features must be aligned with.
        audioAug: When truthy, the waveform is mixed with another clip from
            ``audioSet`` with probability 1/2.
        audioSet: Mapping from entity name to raw waveform. Despite the
            ``None`` default it is required, because the waveform is looked
            up in it.

    Returns:
        A 2-D numpy array of 13 MFCC coefficients per step, wrap-padded or
        truncated along the first axis to ``numFrames * 4`` steps.
    """
    dataName = data[0]
    fps = float(data[2])
    audio = audioSet[dataName]
    if audioAug:
        augType = random.randint(0, 1)
        if augType == 1:
            # ``overlap`` takes (clean, noise); choose the noise clip here from
            # the other entries of ``audioSet``. With no other entry available
            # the clean waveform is used unchanged.
            noiseNames = [name for name in audioSet if name != dataName]
            if noiseNames:
                audio = overlap(audio, audioSet[random.choice(noiseNames)])
    # fps is not always 25; to stay aligned with the video, the MFCC window
    # and step are rescaled by 25 / fps.
    audio = python_speech_features.mfcc(audio,
                                        16000,
                                        numcep=13,
                                        winlen=0.025 * 25 / fps,
                                        winstep=0.010 * 25 / fps)
    maxAudio = int(numFrames * 4)
    if audio.shape[0] < maxAudio:
        shortage = maxAudio - audio.shape[0]
        audio = numpy.pad(audio, ((0, shortage), (0, 0)), 'wrap')
    audio = audio[:int(round(numFrames * 4)), :]
    return audio
def load_single_audio(audio, fps, numFrames, audioAug=False):
    """Compute MFCC features for a raw waveform, aligned to the video frames.

    Args:
        audio: Raw waveform passed to ``python_speech_features.mfcc`` with a
            sample rate of 16000.
        fps: Video frame rate; the MFCC window and step are scaled by 25 / fps.
        numFrames: Number of video frames; the output has four steps per frame.
        audioAug: Accepted for interface compatibility; not used here.

    Returns:
        A 2-D numpy array of 13 MFCC coefficients per step, wrap-padded or
        truncated along the first axis to ``numFrames * 4`` steps.
    """
    window = 0.025 * 25 / fps
    hop = 0.010 * 25 / fps
    feats = python_speech_features.mfcc(audio,
                                        samplerate=16000,
                                        winlen=window,
                                        winstep=hop,
                                        numcep=13)

    wanted = int(numFrames * 4)
    missing = wanted - feats.shape[0]
    if missing > 0:
        # Too few steps: repeat the sequence from its start to fill the gap.
        feats = numpy.pad(feats, ((0, missing), (0, 0)), 'wrap')
    return feats[:int(round(numFrames * 4)), :]
def load_visual(data, dataPath, numFrames, visualAug):
    """Load the grayscale 112x112 face crops of one entity.

    Args:
        data: Sequence describing the sample; ``data[0]`` is the entity name
            and its first 11 characters are the video name.
        dataPath: Root directory holding ``<video>/<entity>/<timestamp>.jpg``.
        numFrames: Maximum number of frames to load (earliest timestamps first).
        visualAug: When truthy, one augmentation (none, horizontal flip, random
            crop or random rotation) is drawn and applied to every frame.

    Returns:
        A numpy array stacking the loaded frames; with at least one frame its
        shape is ``(frames, 112, 112)``.

    Raises:
        IOError: If a face image exists but cannot be decoded.
    """
    dataName = data[0]
    videoName = data[0][:11]
    faceFolderPath = os.path.join(dataPath, videoName, dataName)
    faceFiles = glob.glob("%s/*.jpg" % faceFolderPath)
    # File names are timestamps; sort numerically so "10.00" follows "9.96".
    # basename/splitext keep this independent of the path separator.
    sortedFaceFiles = sorted(
        faceFiles,
        key=lambda path: float(os.path.splitext(os.path.basename(path))[0]))
    faces = []
    H = 112
    if visualAug:
        new = int(H * random.uniform(0.7, 1))
        # randint's upper bound is exclusive: "+ 1" makes the last valid
        # offset reachable and avoids an empty range when new == H.
        x = numpy.random.randint(0, H - new + 1)
        y = numpy.random.randint(0, H - new + 1)
        M = cv2.getRotationMatrix2D((H / 2, H / 2), random.uniform(-15, 15), 1)
        augType = random.choice(['orig', 'flip', 'crop', 'rotate'])
    else:
        augType = 'orig'
    for faceFile in sortedFaceFiles[:numFrames]:
        face = cv2.imread(faceFile)
        if face is None:
            # imread signals failure by returning None instead of raising.
            raise IOError("could not read face image: %s" % faceFile)
        face = cv2.cvtColor(face, cv2.COLOR_BGR2GRAY)
        face = cv2.resize(face, (H, H))
        if augType == 'orig':
            faces.append(face)
        elif augType == 'flip':
            faces.append(cv2.flip(face, 1))
        elif augType == 'crop':
            faces.append(cv2.resize(face[y:y + new, x:x + new], (H, H)))
        elif augType == 'rotate':
            faces.append(cv2.warpAffine(face, M, (H, H)))
    faces = numpy.array(faces)
    return faces
def load_label(data, numFrames):
    """Parse the per-frame labels of a sample into an integer array.

    Args:
        data: Sequence describing the sample; ``data[3]`` is a string holding a
            bracketed, comma-separated list of integers such as ``"[0,1,1]"``.
        numFrames: Maximum number of labels to keep.

    Returns:
        A numpy array with the first ``numFrames`` labels.
    """
    stripped = data[3].replace('[', '').replace(']', '')
    parsed = [int(token) for token in stripped.split(',')]
    return numpy.array(parsed[:numFrames])
class train_loader(object):
    """Training dataset yielding audio, multi-speaker faces, labels and masks.

    Each line of the trial file is tab-separated: entity name, number of
    frames, fps. The first 11 characters of the entity name are the video name.
    """

    def __init__(self, cfg, trialFileName, audioPath, visualPath, num_speakers):
        """Load the annotation files and the shuffled list of training samples.

        Args:
            cfg: Config object; ``cfg.DATA.dataPathAVA`` is the dataset root and
                ``cfg.TRAIN.AUDIO_AUG`` toggles audio augmentation.
            trialFileName: Path of the text file listing one sample per line.
            audioPath: Root directory of ``<video>/<entity>.wav`` files.
            visualPath: Root directory of ``<video>/<entity>/<timestamp>.jpg``.
            num_speakers: Number of speaker slots returned per sample.
        """
        self.cfg = cfg
        self.audioPath = audioPath
        self.visualPath = visualPath
        self.candidate_speakers = num_speakers
        self.path = os.path.join(cfg.DATA.dataPathAVA, "csv")
        # Context managers make sure the file handles are closed right away.
        with open(os.path.join(self.path, 'train_entity.json')) as entity_file:
            self.entity_data = json.load(entity_file)
        with open(os.path.join(self.path, 'train_ts.json')) as ts_file:
            self.ts_to_entity = json.load(ts_file)
        with open(trialFileName) as trial_file:
            self.mixLst = trial_file.read().splitlines()
        self.list_length = len(self.mixLst)
        random.shuffle(self.mixLst)

    def load_single_audio(self, audio, fps, numFrames, audioAug=False, aug_audio=None):
        """Convert a waveform to VGGish-style examples, optionally adding noise.

        Args:
            audio: Raw waveform of the target entity.
            fps: Video frame rate, forwarded to ``waveform_to_examples``.
            numFrames: Number of video frames, forwarded likewise.
            audioAug: When truthy and ``aug_audio`` is given, the noise is
                mixed in with probability 1/2.
            aug_audio: Waveform used as noise, or ``None`` for no augmentation.

        Returns:
            Whatever ``vggish_input.waveform_to_examples`` returns with
            ``return_tensor=False``.
        """
        # Without a noise clip there is nothing to mix in, so skip quietly.
        if audioAug and aug_audio is not None:
            if random.randint(0, 1) == 1:
                audio = overlap(audio, aug_audio)
        return vggish_input.waveform_to_examples(audio, 16000, numFrames, fps,
                                                 return_tensor=False)

    def load_visual_label_mask(self, videoName, entityName, target_ts, context_ts, visualAug=True):
        """Load faces, labels and a validity mask for one entity.

        Args:
            videoName: Video the entity belongs to.
            entityName: Entity whose face crops are loaded.
            target_ts: Ordered timestamps defining the output length.
            context_ts: Timestamps at which this entity is present.
            visualAug: When truthy, one augmentation (none, flip, random crop
                or random rotation) is drawn and applied to every frame.

        Returns:
            ``(faces, labels, mask)``. ``faces`` stacks one 112x112 grayscale
            frame per target timestamp (all zeros where the entity is absent);
            ``labels`` and ``mask`` are float arrays of the same length, where
            ``mask`` is 1 for timestamps in ``context_ts`` and 0 elsewhere.

        Raises:
            IOError: If an expected face image cannot be read.
        """
        faceFolderPath = os.path.join(self.visualPath, videoName, entityName)
        faces = []
        H = 112
        if visualAug:
            new = int(H * random.uniform(0.7, 1))
            # randint's upper bound is exclusive: "+ 1" makes the last valid
            # offset reachable and avoids an empty range when new == H.
            x = numpy.random.randint(0, H - new + 1)
            y = numpy.random.randint(0, H - new + 1)
            M = cv2.getRotationMatrix2D((H / 2, H / 2), random.uniform(-15, 15), 1)
            augType = random.choice(['orig', 'flip', 'crop', 'rotate'])
        else:
            augType = 'orig'
        labels_dict = self.entity_data[videoName][entityName]
        labels = numpy.zeros(len(target_ts))
        mask = numpy.zeros(len(target_ts))
        # A set keeps the per-frame membership test O(1) when a list is passed.
        present = set(context_ts)
        for i, ts in enumerate(target_ts):
            if ts not in present:
                faces.append(numpy.zeros((H, H)))
                continue
            labels[i] = labels_dict[ts]
            mask[i] = 1
            # Face crops are stored under the timestamp with two decimals.
            faceFile = os.path.join(faceFolderPath, "%.2f" % float(ts) + '.jpg')
            face = cv2.imread(faceFile)
            if face is None:
                # imread signals failure by returning None instead of raising.
                raise IOError("could not read face image: %s" % faceFile)
            face = cv2.cvtColor(face, cv2.COLOR_BGR2GRAY)
            face = cv2.resize(face, (H, H))
            if augType == 'orig':
                faces.append(face)
            elif augType == 'flip':
                faces.append(cv2.flip(face, 1))
            elif augType == 'crop':
                faces.append(cv2.resize(face[y:y + new, x:x + new], (H, H)))
            elif augType == 'rotate':
                faces.append(cv2.warpAffine(face, M, (H, H)))
        faces = numpy.array(faces)
        return faces, labels, mask

    def get_speaker_context(self, videoName, target_entity, all_ts, center_ts):
        """Pick the speakers shown alongside the target entity.

        Candidates are the other entities listed at ``center_ts``; one is kept
        only if it shares more than half of the target's timestamps.

        Args:
            videoName: Video the target belongs to.
            target_entity: Entity the sample is built around.
            all_ts: All timestamps of the target entity.
            center_ts: Timestamp used to look up co-occurring entities.

        Returns:
            ``(context_speakers, context)``. ``context_speakers`` has the target
            at index 0 followed by shuffled candidates, padded with repeats or
            truncated to ``self.candidate_speakers`` entries. ``context`` maps
            each kept entity to the timestamps it shares with the target.
        """
        context = {target_entity: all_ts}
        num_frames = len(all_ts)
        target_ts = set(all_ts)
        # Copy before removing so the annotation data is not mutated.
        candidates = list(self.ts_to_entity[videoName][center_ts])
        candidates.remove(target_entity)
        chosen_speakers = []
        for candidate in candidates:
            shared_ts = target_ts.intersection(self.entity_data[videoName][candidate])
            if len(shared_ts) > (num_frames / 2):
                context[candidate] = shared_ts
                chosen_speakers.append(candidate)
        random.shuffle(chosen_speakers)

        # The target always occupies slot 0.
        context_speakers = [target_entity] + chosen_speakers
        if not chosen_speakers:
            # Nobody else qualifies: fill every slot with the target.
            while len(context_speakers) < self.candidate_speakers:
                context_speakers.append(random.choice(context_speakers))
        elif len(chosen_speakers) < self.candidate_speakers:
            # Too few candidates: pad with repeats of the non-target speakers.
            while len(context_speakers) < self.candidate_speakers:
                context_speakers.append(random.choice(context_speakers[1:]))
        else:
            context_speakers = context_speakers[:self.candidate_speakers]
        assert set(context_speakers).issubset(set(list(context.keys()))), target_entity
        assert target_entity in context_speakers, target_entity
        return context_speakers, context

    def __getitem__(self, index):
        """Build one training sample.

        Returns:
            ``(audio, visualFeatures, labels, masks)``: ``audio`` is a float
            tensor with a leading singleton dimension, ``visualFeatures`` is a
            float tensor stacking the faces of every speaker slot, and
            ``labels`` / ``masks`` are long tensors with one row per slot.
        """
        target_video = self.mixLst[index]
        data = target_video.split('\t')
        fps = float(data[2])
        videoName = data[0][:11]
        target_entity = data[0]
        all_ts = list(self.entity_data[videoName][target_entity].keys())
        numFrames = int(data[1])
        assert numFrames == len(all_ts)
        center_ts = all_ts[math.floor(numFrames / 2)]

        # Context speakers share more than half of the target's timestamps.
        context_speakers, context = self.get_speaker_context(videoName, target_entity, all_ts,
                                                             center_ts)

        aug_audio = None
        if self.cfg.TRAIN.AUDIO_AUG and self.list_length > 1:
            # Draw uniformly from every other sample without building an
            # O(N) index list: sample N - 1 slots and skip over ``index``.
            position = index if index >= 0 else index + self.list_length
            other = random.randrange(self.list_length - 1)
            if other >= position:
                other += 1
            augment_entity = self.mixLst[other].split('\t')[0]
            augment_videoname = augment_entity[:11]
            _, aug_audio = wavfile.read(
                os.path.join(self.audioPath, augment_videoname, augment_entity + '.wav'))

        _, audio = wavfile.read(os.path.join(self.audioPath, videoName, target_entity + '.wav'))
        audio = self.load_single_audio(audio,
                                       fps,
                                       numFrames,
                                       audioAug=self.cfg.TRAIN.AUDIO_AUG,
                                       aug_audio=aug_audio)

        visualFeatures, labels, masks = [], [], []
        # Load the target once; padded slots that repeat the target reuse it.
        visual, target_labels, target_masks = self.load_visual_label_mask(
            videoName, target_entity, all_ts, all_ts)
        for context_entity in context_speakers:
            if context_entity == target_entity:
                label = target_labels
                visualfeat = visual
                mask = target_masks
            else:
                visualfeat, label, mask = self.load_visual_label_mask(videoName, context_entity,
                                                                      all_ts,
                                                                      context[context_entity])
            visualFeatures.append(visualfeat)
            labels.append(label)
            masks.append(mask)

        audio = torch.FloatTensor(audio)[None, :, :]
        visualFeatures = torch.FloatTensor(numpy.array(visualFeatures))
        audio_t = audio.shape[1]
        video_t = visualFeatures.shape[1]
        # Audio is expected to have four steps per video frame; report samples
        # that break this so they can be inspected.
        if audio_t != video_t * 4:
            print(visualFeatures.shape, audio.shape, videoName, target_entity, numFrames)
        labels = torch.LongTensor(numpy.array(labels))
        masks = torch.LongTensor(numpy.array(masks))
        return audio, visualFeatures, labels, masks

    def __len__(self):
        """Return the number of samples listed in the trial file."""
        return len(self.mixLst)
class val_loader(object):
    """Validation dataset yielding audio, multi-speaker faces, labels and masks.

    Each line of the trial file is tab-separated: entity name, number of
    frames, fps. The first 11 characters of the entity name are the video name.
    No audio or visual augmentation is applied.
    """

    def __init__(self, cfg, trialFileName, audioPath, visualPath, num_speakers):
        """Load the annotation files and the list of validation samples.

        Args:
            cfg: Config object; ``cfg.DATA.dataPathAVA`` is the dataset root.
            trialFileName: Path of the text file listing one sample per line.
            audioPath: Root directory of ``<video>/<entity>.wav`` files.
            visualPath: Root directory of ``<video>/<entity>/<timestamp>.jpg``.
            num_speakers: Number of speaker slots returned per sample.
        """
        self.cfg = cfg
        self.audioPath = audioPath
        self.visualPath = visualPath
        self.candidate_speakers = num_speakers
        self.path = os.path.join(cfg.DATA.dataPathAVA, "csv")
        # Context managers make sure the file handles are closed right away.
        with open(os.path.join(self.path, 'val_entity.json')) as entity_file:
            self.entity_data = json.load(entity_file)
        with open(os.path.join(self.path, 'val_ts.json')) as ts_file:
            self.ts_to_entity = json.load(ts_file)
        with open(trialFileName) as trial_file:
            self.mixLst = trial_file.read().splitlines()

    def load_single_audio(self, audio, fps, numFrames, audioAug=False, aug_audio=None):
        """Convert a waveform to VGGish-style examples.

        ``audioAug`` and ``aug_audio`` are accepted for interface parity with
        the training loader and are ignored here.

        Returns:
            Whatever ``vggish_input.waveform_to_examples`` returns with
            ``return_tensor=False``.
        """
        return vggish_input.waveform_to_examples(audio, 16000, numFrames, fps,
                                                 return_tensor=False)

    def load_visual_label_mask(self, videoName, entityName, target_ts, context_ts):
        """Load faces, labels and a validity mask for one entity.

        Args:
            videoName: Video the entity belongs to.
            entityName: Entity whose face crops are loaded.
            target_ts: Ordered timestamps defining the output length.
            context_ts: Timestamps at which this entity is present.

        Returns:
            ``(faces, labels, mask)``. ``faces`` stacks one 112x112 grayscale
            frame per target timestamp (all zeros where the entity is absent);
            ``labels`` and ``mask`` are float arrays of the same length, where
            ``mask`` is 1 for timestamps in ``context_ts`` and 0 elsewhere.

        Raises:
            IOError: If an expected face image cannot be read.
        """
        faceFolderPath = os.path.join(self.visualPath, videoName, entityName)
        faces = []
        H = 112
        labels_dict = self.entity_data[videoName][entityName]
        labels = numpy.zeros(len(target_ts))
        mask = numpy.zeros(len(target_ts))
        # A set keeps the per-frame membership test O(1) when a list is passed.
        present = set(context_ts)
        for i, ts in enumerate(target_ts):
            if ts not in present:
                faces.append(numpy.zeros((H, H)))
                continue
            labels[i] = labels_dict[ts]
            mask[i] = 1
            # Face crops are stored under the timestamp with two decimals.
            faceFile = os.path.join(faceFolderPath, "%.2f" % float(ts) + '.jpg')
            face = cv2.imread(faceFile)
            if face is None:
                # imread signals failure by returning None instead of raising.
                raise IOError("could not read face image: %s" % faceFile)
            face = cv2.cvtColor(face, cv2.COLOR_BGR2GRAY)
            face = cv2.resize(face, (H, H))
            faces.append(face)
        faces = numpy.array(faces)
        return faces, labels, mask

    def get_speaker_context(self, videoName, target_entity, all_ts, center_ts):
        """Pick the speakers shown alongside the target entity.

        Every other entity listed at ``center_ts`` is kept as a candidate; no
        minimum-overlap threshold is applied.

        Args:
            videoName: Video the target belongs to.
            target_entity: Entity the sample is built around.
            all_ts: All timestamps of the target entity.
            center_ts: Timestamp used to look up co-occurring entities.

        Returns:
            ``(context_speakers, context)``. ``context_speakers`` has the target
            at index 0 followed by shuffled candidates, padded with repeats or
            truncated to ``self.candidate_speakers`` entries. ``context`` maps
            each entity to the timestamps it shares with the target.
        """
        context = {target_entity: all_ts}
        target_ts = set(all_ts)
        # Copy before removing so the annotation data is not mutated.
        candidates = list(self.ts_to_entity[videoName][center_ts])
        candidates.remove(target_entity)
        chosen_speakers = []
        for candidate in candidates:
            context[candidate] = target_ts.intersection(self.entity_data[videoName][candidate])
            chosen_speakers.append(candidate)
        random.shuffle(chosen_speakers)

        # The target always occupies slot 0.
        context_speakers = [target_entity] + chosen_speakers
        if not chosen_speakers:
            # Nobody else is present: fill every slot with the target.
            while len(context_speakers) < self.candidate_speakers:
                context_speakers.append(random.choice(context_speakers))
        elif len(chosen_speakers) < self.candidate_speakers:
            # Too few candidates: pad with repeats of the non-target speakers.
            while len(context_speakers) < self.candidate_speakers:
                context_speakers.append(random.choice(context_speakers[1:]))
        else:
            context_speakers = context_speakers[:self.candidate_speakers]
        assert set(context_speakers).issubset(set(list(context.keys()))), target_entity
        return context_speakers, context

    def __getitem__(self, index):
        """Build one validation sample.

        Returns:
            ``(audio, visualFeatures, labels, masks)``: ``audio`` is a float
            tensor with a leading singleton dimension, ``visualFeatures`` is a
            float tensor stacking the faces of every speaker slot, and
            ``labels`` / ``masks`` are long tensors with one row per slot.
        """
        target_video = self.mixLst[index]
        data = target_video.split('\t')
        fps = float(data[2])
        videoName = data[0][:11]
        target_entity = data[0]
        all_ts = list(self.entity_data[videoName][target_entity].keys())
        numFrames = int(data[1])
        assert numFrames == len(all_ts)
        center_ts = all_ts[math.floor(numFrames / 2)]

        # Context speakers are the entities co-occurring at the centre timestamp.
        context_speakers, context = self.get_speaker_context(videoName, target_entity, all_ts,
                                                             center_ts)

        _, audio = wavfile.read(os.path.join(self.audioPath, videoName, target_entity + '.wav'))
        audio = self.load_single_audio(audio, fps, numFrames, audioAug=False)

        visualFeatures, labels, masks = [], [], []
        # Load the target once; padded slots that repeat the target reuse it.
        target_visual, target_labels, target_masks = self.load_visual_label_mask(
            videoName, target_entity, all_ts, all_ts)
        for context_entity in context_speakers:
            if context_entity == target_entity:
                label = target_labels
                visualfeat = target_visual
                mask = target_masks
            else:
                visualfeat, label, mask = self.load_visual_label_mask(videoName, context_entity,
                                                                      all_ts,
                                                                      context[context_entity])
            visualFeatures.append(visualfeat)
            labels.append(label)
            masks.append(mask)

        audio = torch.FloatTensor(audio)[None, :, :]
        visualFeatures = torch.FloatTensor(numpy.array(visualFeatures))
        audio_t = audio.shape[1]
        video_t = visualFeatures.shape[1]
        # Audio is expected to have four steps per video frame; report samples
        # that break this so they can be inspected.
        if audio_t != video_t * 4:
            print(visualFeatures.shape, audio.shape, videoName, target_entity, numFrames)
        labels = torch.LongTensor(numpy.array(labels))
        masks = torch.LongTensor(numpy.array(masks))
        return audio, visualFeatures, labels, masks

    def __len__(self):
        """Return the number of samples listed in the trial file."""
        return len(self.mixLst)