TCAV 101
TCAV Introduction
Understanding a deep learning model is an open topic and is quite subjective. In part, I’m still unsure what exactly it means to understand a model; this comes up in many papers as well, since what counts as understanding can differ a lot depending on your technical background.
What makes a model interpretable? There are some good blog posts that touch on this question and discuss different ways of understanding models, but the TCAV score is based around the idea that specific layers will be activated by certain features or “concepts” more than others.
While image models are relatively intuitive to interpret, there is not much research on interpreting deep learning models trained on datasets that are not easy to visualize. Take for instance fraud detection or medical diagnosis datasets that are not based on images.
The dataset I am using is SWaT, a time-series dataset with multiple features where the target is a column containing either 0 or 1, indicating whether or not an attack is occurring.
Understanding TCAV (Testing with Concept Activation Vectors)
The score for the sensitivity is given as
$$ \nabla \textcolor{green}{ h_{l, k} ( \textcolor{blue}{ f_l( \textcolor{black}{X_{input}}) } ) } \cdot \textcolor{red}{v_C^l} $$
Where $f_l$ is the model up to the $l$-th layer and $h_{l,k}$ is the model from the $l$-th layer to the $k$-th output class. The vector $v_C^l$ is the “vector orthogonal to the classification boundary”, i.e. the coefficients of the linear classifier we trained on the layer-$l$ activations.
What does this mean? The greater the dot product of this orthogonal vector with the gradients, the higher the sensitivity score. That much is obvious, but what it implies is that the more cleanly the linear classifier separates the concept at the layer we split on, and the larger the gradients of the second part of the model, the more we can say that this layer is picking up signal related to the concept it was trained against.
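From this sensitivity, the TCAV score that gets reported per class further down is simply the fraction of class-$k$ inputs whose sensitivity to the concept is positive, following the definition in the original paper:
$$ TCAV_{C,k,l} = \frac{\left|\{ x \in X_k : \nabla h_{l,k}(f_l(x)) \cdot v_C^l > 0 \}\right|}{\left|X_k\right|} $$
where $X_k$ is the set of inputs labelled with class $k$.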
For something like images this is quite intuitive: zebras have the concept of stripes, so images with stripes will be more sensitive at a layer that picks out stripes. For datasets where we cannot so easily map concepts to something visualizable, a concept is a high-level idea that a human can understand but that cannot be clearly delineated in the dataset (for the SWaT dataset, for instance, a concept could be something such as attacks on valves, where the attack may look similar to how the system behaves at another point in the process).
While I understand the mechanism for calculating the sensitivity and the TCAV score, I am still interested in better understanding what the sensitivity implies and how useful it is, both compared to other scores and when comparing it across layers and concepts.
Calculating TCAV
TCAV relies on the idea that some layers of a neural network will be more sensitive to specific ideas or concepts than others.
The rough breakdown of the entire process:
- Train a neural network.
- Split the trained model at some layer, your “bottleneck”.
- Feed a set of data that exemplifies a concept, as well as counterexamples, through the first part of the model and collect the activations at the bottleneck.
- Train a linear classifier where these activations (flattened if they come from convolutional or similar layers) are the features and the concept/counterexample indicator is the label.
- Calculate the gradient of the second part of the model with respect to the activations and take the dot product of this gradient with the coefficients of the linear classifier.

In this sense, this explainability technique is only useful if you have the correct labels as well, and it cannot directly be used to explain a model in production.
In general, the counterexamples are up to your discretion, and for something like an image the choice is fairly clear. The most straightforward counterexamples for image datasets are randomly generated images comprised of noise; an image of pure noise has essentially nothing in common with a concept such as “sky” or “stripes”. For datasets that are not images this is less clear to me, and the counterexamples generally have to be selected by hand.
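As a minimal sketch (the shapes here are just placeholders, not from any real dataset), noise counterexamples for an image concept set could be generated like this:

import numpy as np

# hypothetical concept set: e.g. 100 RGB 64x64 images showing "stripes"
concept_images = np.zeros((100, 64, 64, 3), dtype="float32")  # placeholder for real data

# counterexamples: uniform noise with the same shape and value range as the inputs
rng = np.random.default_rng(0)
counterexamples = rng.uniform(0.0, 1.0, size=concept_images.shape).astype("float32")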
Code
While there is an official implementation and a Keras implementation, both are based around tensorflow <= 2.0. Most implementations are done in TensorFlow, but it is possible to do all of this in PyTorch as well: rather than splitting the model and creating two new models, you add hooks into the model to capture the activations and then take the gradient from that point of the model to the end. This is one of the few times I found a TensorFlow implementation more understandable than the equivalent PyTorch one, but perhaps there is a better way to do it (a rough sketch of the hook approach is below).
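A minimal sketch of that hook-based approach in PyTorch (a toy model for illustration, not code from the repo):

import torch
import torch.nn as nn

# toy model standing in for the trained network
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
activations = {}

def hook(module, inputs, output):
    output.retain_grad()  # keep the gradient w.r.t. this intermediate tensor
    activations["bottleneck"] = output

model[1].register_forward_hook(hook)  # treat the output of the ReLU as the bottleneck

x = torch.randn(4, 8)
logits = model(x)
logits[:, 1].sum().backward()               # gradient of the class-1 output
grad_vals = activations["bottleneck"].grad  # what the TF code below calls grad_vals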
I chose to implement it myself and the code can be found here but I will walk through the basics of it.
In the first part of the code, we must split a trained neural network into $f_l$ and $h_{l,k}$; to do this, we use a functional or sequential Keras model.
def use_bottleneck(self, bottleneck: int):
    """split the model into pre and post models for the tcav linear model

    Args:
        bottleneck (int): index of the layer to split the nn model on
    """
    if bottleneck < 0 or bottleneck >= len(self.model.layers):
        raise ValueError("Bottleneck layer must be greater than or equal to 0 and less than the number of layers!")
    # f_l: everything up to and including the bottleneck layer
    self.model_f = tf.keras.Model(inputs=self.model.input, outputs=self.model.layers[bottleneck].output)
    # h_{l,k}: rebuild the rest of the network functionally from the bottleneck output onwards
    model_h_input = tf.keras.layers.Input(self.model.layers[bottleneck + 1].input_shape[1:])
    model_h = model_h_input
    for layer in self.model.layers[bottleneck + 1 :]:
        model_h = layer(model_h)
    self.model_h = tf.keras.Model(inputs=model_h_input, outputs=model_h)
    self.bottleneck_layer = self.model.layers[bottleneck]
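As a quick sanity check (not part of the repo), the two halves should compose back into the original model; here with a hypothetical small Sequential model and the TCAV class used in the example further down:

import numpy as np
import tensorflow as tf

# hypothetical toy model just to illustrate the split
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(2, activation="softmax"),
])

tcav = TCAV(model)
tcav.use_bottleneck(1)  # split after the second Dense layer

x = np.random.rand(4, 8).astype("float32")
full_output = model.predict(x)
composed_output = tcav.model_h.predict(tcav.model_f.predict(x))
assert np.allclose(full_output, composed_output, atol=1e-5)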
Now that we have the original model split in two, we need to train the linear classifier so that we can get the CAV:
def train_cav(self, concepts, counterexamples):
    """fit the linear classifier on the bottleneck activations to get the CAV"""
    # activations at the bottleneck (f_l) for the concept and counterexample inputs
    concept_activations = self.model_f.predict(concepts)
    counterexamples_activations = self.model_f.predict(counterexamples)
    x = np.concatenate([concept_activations, counterexamples_activations])
    x = x.reshape(x.shape[0], -1)  # flatten the activations per sample
    # labels: 1 for the concept, 0 for the counterexamples
    y = np.concatenate([np.ones(len(concept_activations)), np.zeros(len(counterexamples_activations))])
    # self.lm is a linear classifier (sklearn-style, exposing coef_ after fitting)
    self.lm.fit(x, y)
    self.coefs = self.lm.coef_
    self.cav = np.transpose(-1 * self.coefs)
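As an aside, before trusting a CAV it can help to check that the bottleneck activations actually separate the concept from the counterexamples; the official implementation reports something similar (an accuracy for each CAV). A hypothetical helper (not in the repo) might look like:

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def cav_separability(model_f, concepts, counterexamples):
    """Accuracy of a held-out linear classifier on the bottleneck activations.

    Scores near 0.5 suggest the concept is not separable at this layer.
    """
    acts = np.concatenate([model_f.predict(concepts), model_f.predict(counterexamples)])
    acts = acts.reshape(acts.shape[0], -1)
    labels = np.concatenate([np.ones(len(concepts)), np.zeros(len(counterexamples))])
    x_train, x_test, y_train, y_test = train_test_split(acts, labels, test_size=0.33)
    lm = LogisticRegression(max_iter=1000)
    lm.fit(x_train, y_train)
    return lm.score(x_test, y_test)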
The train_cav step gives us the Concept Activation Vector, which comes from the coefficients of our linear classifier. There are a few ways to interpret this, but the general idea is that the better the classifier separates the concept from the counterexamples, the more “signal” about the concept is present at this layer. This is still only part of TCAV though; we also need to calculate the gradient and then the sensitivity score:
def calculate_sensitivty(self, concepts, concepts_labels, counterexamples, counterexamples_labels):
    """the sensitivity scores come from the dot product of the gradients with the CAV"""
    activations = np.concatenate([self.model_f.predict(concepts), self.model_f.predict(counterexamples)])
    labels = np.concatenate([concepts_labels, counterexamples_labels])
    grad_vals = []
    for x, y in zip(activations, labels):
        x = tf.convert_to_tensor(np.expand_dims(x, axis=0), dtype=tf.float32)
        y = tf.convert_to_tensor(np.expand_dims(y, axis=0), dtype=tf.float32)
        # gradient of the loss of the second half of the model w.r.t. the bottleneck activations
        with tf.GradientTape() as tape:
            tape.watch(x)
            y_out = self.model_h(x)
            loss = tf.keras.backend.categorical_crossentropy(y, y_out)
        grad_vals.append(tape.gradient(loss, x).numpy())
    grad_vals = np.array(grad_vals).squeeze()
    # sensitivity: dot product of each flattened gradient with the CAV
    self.sensitivity = np.dot(grad_vals.reshape(grad_vals.shape[0], -1), self.cav)
    self.labels = labels
    self.grad_vals = grad_vals
def sensitivity_score(self):
    """Return the per-class TCAV scores in a readable way"""
    num_classes = self.labels.shape[-1]
    sens_for_class_k = {}
    for k in range(0, num_classes):
        class_idxs = np.where(self.labels[:, k] == 1)
        if len(class_idxs[0]) == 0:
            sens_for_class_k[k] = None
        else:
            # TCAV score: fraction of class-k examples with positive sensitivity
            sens_for_class = self.sensitivity[class_idxs[0]]
            sens_for_class_k[k] = len(sens_for_class[sens_for_class > 0]) / len(sens_for_class)
    return sens_for_class_k
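One small difference from the formula at the top: calculate_sensitivty takes the gradient of the categorical cross-entropy loss, whereas the written sensitivity uses the gradient of the $k$-th output $h_{l,k}$ itself. A sketch of what the latter could look like (this helper is hypothetical, not in the repo):

import numpy as np
import tensorflow as tf

def logit_gradients(model_h, activations, k):
    """Gradient of the k-th class output of model_h w.r.t. the bottleneck activations."""
    grads = []
    for act in activations:
        x = tf.convert_to_tensor(np.expand_dims(act, axis=0), dtype=tf.float32)
        with tf.GradientTape() as tape:
            tape.watch(x)
            target = model_h(x)[:, k]
        grads.append(tape.gradient(target, x).numpy())
    return np.array(grads).squeeze()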
The sensitivity_score method gives us, for each output class, the fraction of examples whose sensitivity to the concept is positive, i.e. the TCAV score (if 8 of 10 class-1 examples have positive sensitivity, the score for class 1 is 0.8). To put all of this to use, we can do something like the following:
attack_info_df = get_attack_info_df(pdf_path=model_df_dir / "docs/List_of_attacks_Final.pdf")
concept, counterexamples = create_concept(df, attack_info_df, [10, 11])
concepts_gen = tf.keras.preprocessing.sequence.TimeseriesGenerator(
    concept.drop(TARGETCOL, axis=1).values,
    concept[TARGETCOL].values,
    length=TIMESERIES_LENGTH,
    batch_size=1,
    shuffle=True,
)
# use stride to balance the number of samples somehow?
counterexamples_gen = tf.keras.preprocessing.sequence.TimeseriesGenerator(
    counterexamples.drop(TARGETCOL, axis=1).values,
    counterexamples[TARGETCOL].values,
    length=TIMESERIES_LENGTH,
    batch_size=1,
    stride=round(len(counterexamples) / len(concept)),
    shuffle=True,
)
concepts_x = []
concepts_y = []
for x, y in concepts_gen:
    concepts_x.append(x)
    concepts_y.append(y)
counterexamples_x = []
counterexamples_y = []
for x, y in counterexamples_gen:
    counterexamples_x.append(x)
    counterexamples_y.append(y)
concepts_x = np.array(concepts_x).squeeze()
concepts_y = np.array(concepts_y).squeeze()
counterexamples_x = np.array(counterexamples_x).squeeze()
counterexamples_y = np.array(counterexamples_y).squeeze()
model = tf.keras.models.load_model(model_path)
tcav = TCAV(model)
for layer_n in range(1, len(tcav.model.layers) - 1):
    tcav.use_bottleneck(layer_n)
    tcav.train_cav(concepts_x, counterexamples_x)
    tcav.calculate_sensitivty(concepts_x, concepts_y, counterexamples_x, counterexamples_y)
    sensitivity_score = tcav.sensitivity_score()
    logger.info(f"=== === ===")
    logger.info(f"sensitivity scores for LAYER: {layer_n} of type: {tcav.bottleneck_layer.name}")
    logger.info(f"[class 0 to concept] ==> {sensitivity_score[0]}")
    logger.info(f"[class 1 to concept] ==> {sensitivity_score[1]}")
    logger.info(f"=== === ===")
Full Implementation
An example code base and the included dataset can be found here: https://gitlab.com/besiktas/falcon_tcav
Experiments
Using the models trained on the SWaT dataset (which is a time series), I ran experiments based on the premise that certain attack types comprise a concept. While this shows some validity, what it means is much more ambiguous than for image-based models, where high-level concepts correspond to something visualizable.
From the experiments, one thing I came away with is that sensitivity to a concept is not monotonic across layers. I assumed that the sensitivity score would generally increase (or possibly decrease) from the beginning of the model to the end, but I don’t think this is true: I often saw peaks in the sensitivity scores in the early layers or in the later layers but not in the rest of the model.
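One way to eyeball this (a hypothetical sketch, assuming the per-layer scores from the loop above were collected into a dict keyed by layer index):

import matplotlib.pyplot as plt

# e.g. filled inside the layer loop: scores_per_layer[layer_n] = sensitivity_score[1]
scores_per_layer = {}
layers = sorted(scores_per_layer)
plt.plot(layers, [scores_per_layer[n] for n in layers], marker="o")
plt.xlabel("bottleneck layer index")
plt.ylabel("TCAV score (class 1)")
plt.title("TCAV score per layer for the attack concept")
plt.show()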
Further investigations
A few things I’d like to further investigate: is it possible to extract a sensitivity score from something such as a transformer? For instance, when predicting whether a comment or post is sincere, is the model sensitive to a concept such as abrasive language or cursing? Does this explain anything about how the model is “understanding” its input?
Another investigation is what the score itself means. While the TCAV score is the fraction of sensitivities that are $>0$, what does the magnitude of the sensitivity say about the concept in relation to the output? Perhaps some models have high sensitivities to specific types of attacks (or concepts), which would be useful in real-world environments where the model output is not the be-all and end-all but is used to guide human judgement.
Lastly, one of the ideas mentioned in the original paper had to do with adversarial attacks on the network and understanding whether TCAV can be used to detect these sorts of attacks.
Other
If there are any corrections or questions please let me know!
References/Links
- https://github.com/tensorflow/tcav
- https://github.com/pnxenopoulos/cav-keras
- https://shanzhenren.github.io/csci-699-replnlp-2019fall/lectures/W9-L1-Interpretability.pdf