We evaluated the performance of the fine-tuned models via linear probing. We fit a penalized logistic regression model to predict the brain layer label (WM, L1-L6) from the image embeddings. Specifically, we trained a simple linear classifier on 80\% of the data (holding out 10\% each for validation and testing), repeating the random split with five different seeds, using embeddings from both the fine-tuned models and the zero-shot baselines (CLIP, PLIP, and UNI). The code below is adapted from the PLIP GitHub repository.
# python
import glob
import numpy as np
import pandas as pd
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import auc, f1_score, roc_curve
from sklearn.preprocessing import LabelEncoder
def eval_metrics(y_true, y_pred, y_pred_proba=None, average_method='weighted'):
    """Compute classification metrics; AUROC is only computed for binary tasks."""
    assert len(y_true) == len(y_pred)
    if y_pred_proba is None:
        auroc = np.nan
    elif len(np.unique(y_true)) > 2:
        print('Multiclass AUC is not currently available.')
        auroc = np.nan
    else:
        fpr, tpr, thresholds = roc_curve(y_true, y_pred_proba)
        auroc = auc(fpr, tpr)
    f1 = f1_score(y_true, y_pred, average=average_method)
    # confusion-matrix counts (meaningful for binary 0/1 labels)
    tp, fp, tn, fn = 0, 0, 0, 0
    for i in range(len(y_pred)):
        if y_true[i] == y_pred[i] == 1:
            tp += 1
        if y_pred[i] == 1 and y_true[i] != y_pred[i]:
            fp += 1
        if y_true[i] == y_pred[i] == 0:
            tn += 1
        if y_pred[i] == 0 and y_true[i] != y_pred[i]:
            fn += 1
    if (tp + fn) == 0: sensitivity = np.nan
    else: sensitivity = tp / (tp + fn)  # recall
    if (tn + fp) == 0: specificity = np.nan
    else: specificity = tn / (tn + fp)
    if (tp + fp) == 0: ppv = np.nan
    else: ppv = tp / (tp + fp)  # precision or positive predictive value (PPV)
    if (tn + fn) == 0: npv = np.nan
    else: npv = tn / (tn + fn)  # negative predictive value (NPV)
    if (tp + tn + fp + fn) == 0: hitrate = np.nan
    else: hitrate = (tp + tn) / (tp + tn + fp + fn)  # accuracy (ACC)
    performance = {'AUC': auroc,
                   'WF1': f1,
                   'tp': tp,
                   'fp': fp,
                   'tn': tn,
                   'fn': fn,
                   'sensitivity': sensitivity,
                   'specificity': specificity,
                   'ppv': ppv,
                   'npv': npv,
                   'hitrate': hitrate,
                   'instances': len(y_true)}
    return performance
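# Illustrative sanity check of eval_metrics on a toy binary problem
# (hypothetical data, not part of the original pipeline): with these
# labels, tp=2, fp=0, tn=2, fn=1, hitrate=0.8, and AUC=1.0.
_toy_true = np.array([0, 1, 1, 0, 1])
_toy_pred = np.array([0, 1, 0, 0, 1])
_toy_proba = np.array([0.2, 0.9, 0.4, 0.1, 0.8])
print(eval_metrics(_toy_true, _toy_pred, y_pred_proba=_toy_proba))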
def run_classification(train_x, train_y, test_x, test_y, val_x, val_y,
                       seed=1, alpha=0.1, penalty="l2"):
    # SGDClassifier with log loss is a penalized logistic regression
    classifier = SGDClassifier(random_state=seed, loss="log_loss",
                               alpha=alpha, verbose=0,
                               penalty=penalty, max_iter=10000, class_weight="balanced")
    train_y = train_y.to_numpy()
    test_y = test_y.to_numpy()
    val_y = val_y.to_numpy()
    classifier.fit(train_x, train_y)
    test_pred = classifier.predict(test_x)
    train_pred = classifier.predict(train_x)
    val_pred = classifier.predict(val_x)
    train_metrics = eval_metrics(train_y, train_pred, average_method="macro")
    test_metrics = eval_metrics(test_y, test_pred, average_method="macro")
    val_metrics = eval_metrics(val_y, val_pred, average_method="macro")
    return {'train_f1': train_metrics['WF1'], 'test_f1': test_metrics['WF1'],
            'val_f1': val_metrics['WF1'], 'alpha': alpha}
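# Illustrative smoke test of run_classification on synthetic embeddings
# (hypothetical shapes; the real embeddings are loaded from CSV below).
_rng = np.random.default_rng(0)
_X = _rng.normal(size=(100, 16))
_y = pd.Series(_rng.integers(0, 2, size=100))
print(run_classification(_X[:80], _y.iloc[:80], _X[80:90], _y.iloc[80:90],
                         _X[90:], _y.iloc[90:], alpha=0.01))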
# path to the embeddings of the fine-tuned models
model_path = glob.glob('./embedding/withanno/*')
# path to the annotations
anno_path = glob.glob('./anno_forlp/*')
anno_path = sorted(anno_path)
anno_name = [i.replace('./anno_forlp/', '') for i in anno_path]
anno_name = [i.replace('_anno.csv', '') for i in anno_name]
data_table = pd.DataFrame({'anno_name': anno_name, 'anno_path': anno_path})

model_record_best = []
for j in range(len(model_path)):
    embedding_all = pd.read_csv(model_path[j], index_col=0)
    model_name = model_path[j].replace('./embedding/withanno/', '')
    model_name = model_name.replace('_withanno.csv', '')
    print(model_name)
    for i in range(data_table.shape[0]):
        print(data_table['anno_path'][i])
        anno_temp = pd.read_csv(data_table['anno_path'][i], sep='\t', index_col=0)
        # drop patches without a usable layer label
        anno_temp = anno_temp[anno_temp.V2 != '']
        anno_temp = anno_temp[anno_temp.V2.notnull()]
        anno_temp = anno_temp[anno_temp.V2 != 'undetermined']
        anno_temp = anno_temp[anno_temp.V2 != 'Exclude']
        # keep only patches that have both an embedding and an annotation
        index_keep = anno_temp.index.intersection(embedding_all.index)
        embedding = embedding_all.loc[index_keep]
        anno_temp = anno_temp.loc[index_keep]
        le = LabelEncoder()
        anno_temp['mapped_y'] = le.fit_transform(anno_temp.V2)
        all_records_dataset = []
        for k in range(5):
            # random split: 80% train, 10% test, 10% validation
            np.random.seed(k)
            train_index = np.random.choice(anno_temp.index, int(0.8 * anno_temp.shape[0]), replace=False)
            test_index = np.random.choice(list(set(anno_temp.index) - set(train_index)), int(0.1 * anno_temp.shape[0]), replace=False)
            val_index = list(set(anno_temp.index) - set(train_index) - set(test_index))
            train_y = anno_temp.loc[train_index].mapped_y
            train_x = embedding.loc[train_index].to_numpy()
            test_y = anno_temp.loc[test_index].mapped_y
            test_x = embedding.loc[test_index].to_numpy()
            val_y = anno_temp.loc[val_index].mapped_y
            val_x = embedding.loc[val_index].to_numpy()
            all_records = []
            for alpha in [1.0, 0.1, 0.01, 0.001, 0.0001]:
                metrics = run_classification(train_x, train_y, test_x, test_y, val_x, val_y, alpha=alpha, penalty='l2')
                metrics["alpha"] = alpha
                metrics["test_on"] = 'split' + str(k)
                metrics["model_name"] = model_name
                all_records.append(metrics)
            all_records_dataset.extend(all_records)
        # select the alpha with the best mean validation F1 and report test F1 at that alpha
        all_records_dataset_df = pd.DataFrame(all_records_dataset)
        best_alpha = all_records_dataset_df.groupby('alpha')['val_f1'].mean().idxmax()
        mean_wf1 = all_records_dataset_df[all_records_dataset_df['alpha'] == best_alpha]['test_f1'].mean()
        std_wf1 = all_records_dataset_df[all_records_dataset_df['alpha'] == best_alpha]['test_f1'].std()
        record_best = {'model_name': model_name, 'best_alpha': best_alpha, 'mean_wf1': mean_wf1,
                       'std_wf1': std_wf1, 'data_name': data_table['anno_name'][i]}
        model_record_best.append(record_best)
model_record_best_df = pd.DataFrame(model_record_best)
model_record_best_df.to_csv('linear_probing_result.csv', index=False, sep='\t')
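# The same split, training, and alpha-selection procedure is now repeated
# for the off-the-shelf (zero-shot) CLIP, PLIP, and UNI embeddings, so the
# fine-tuned and zero-shot models are compared under the same protocol.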
# path to the zero-shot embeddings
model_name_zero = ['CLIP', 'PLIP', 'uni']
model_record_best = []
for model_name in model_name_zero:
    embedding1 = pd.read_csv(f'./zero_shot_embedding/{model_name}_human_image_feature.csv', index_col=0)
    embedding2 = pd.read_csv(f'./zero_shot_embedding/{model_name}_mouse_image_feature.csv', index_col=0)
    embedding_all = pd.concat([embedding1, embedding2])
    print(model_name)
    for i in range(data_table.shape[0]):
        print(data_table['anno_path'][i])
        anno_temp = pd.read_csv(data_table['anno_path'][i], sep='\t', index_col=0)
        anno_temp = anno_temp[anno_temp.V2 != '']
        anno_temp = anno_temp[anno_temp.V2.notnull()]
        anno_temp = anno_temp[anno_temp.V2 != 'undetermined']
        anno_temp = anno_temp[anno_temp.V2 != 'Exclude']
        index_keep = anno_temp.index.intersection(embedding_all.index)
        embedding = embedding_all.loc[index_keep]
        anno_temp = anno_temp.loc[index_keep]
        le = LabelEncoder()
        anno_temp['mapped_y'] = le.fit_transform(anno_temp.V2)
        all_records_dataset = []
        for k in range(5):
            # random split: 80% train, 10% test, 10% validation
            np.random.seed(k)
            train_index = np.random.choice(anno_temp.index, int(0.8 * anno_temp.shape[0]), replace=False)
            test_index = np.random.choice(list(set(anno_temp.index) - set(train_index)), int(0.1 * anno_temp.shape[0]), replace=False)
            val_index = list(set(anno_temp.index) - set(train_index) - set(test_index))
            train_y = anno_temp.loc[train_index].mapped_y
            train_x = embedding.loc[train_index].to_numpy()
            test_y = anno_temp.loc[test_index].mapped_y
            test_x = embedding.loc[test_index].to_numpy()
            val_y = anno_temp.loc[val_index].mapped_y
            val_x = embedding.loc[val_index].to_numpy()
            all_records = []
            for alpha in [1.0, 0.1, 0.01, 0.001, 0.0001]:
                metrics = run_classification(train_x, train_y, test_x, test_y, val_x, val_y, alpha=alpha)
                metrics["alpha"] = alpha
                metrics["test_on"] = 'split' + str(k)
                metrics["model_name"] = model_name
                all_records.append(metrics)
            all_records_dataset.extend(all_records)
        all_records_dataset_df = pd.DataFrame(all_records_dataset)
        best_alpha = all_records_dataset_df.groupby('alpha')['val_f1'].mean().idxmax()
        mean_wf1 = all_records_dataset_df[all_records_dataset_df['alpha'] == best_alpha]['test_f1'].mean()
        std_wf1 = all_records_dataset_df[all_records_dataset_df['alpha'] == best_alpha]['test_f1'].std()
        record_best = {'model_name': model_name, 'best_alpha': best_alpha, 'mean_wf1': mean_wf1,
                       'std_wf1': std_wf1, 'data_name': data_table['anno_name'][i]}
        model_record_best.append(record_best)
model_record_best_df = pd.DataFrame(model_record_best)
model_record_best_df.to_csv('linear_probing_zero_shot_result.csv', index=False, sep='\t')
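# Optional: a quick look at the saved summaries (assumes the two result
# files written above exist in the working directory).
finetuned = pd.read_csv('linear_probing_result.csv', sep='\t')
zeroshot = pd.read_csv('linear_probing_zero_shot_result.csv', sep='\t')
print(finetuned.sort_values('mean_wf1', ascending=False).head())
print(zeroshot.sort_values('mean_wf1', ascending=False).head())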