This notebook is a basic demo of using the lartpc_mlreco3d API to train a network and run inference, as well as to display the events and the network output. We will focus on a simple UResNet network using the SCN (SparseConvNet) library. You can start by cloning the current code from the GitHub repository.

Note: Use the develop branch.

  1. Using YAML config file
  2. Using API functions

    2.1. Training loop

    2.2. Inference

    2.3. Event displays

    2.4. Analysis and display of the network output

In [1]:
import numpy as np
import matplotlib
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import pyplot as plt
import pandas as pd
import seaborn

seaborn.set(rc={
    'figure.figsize':(15, 10),
})
seaborn.set_context('talk') # or paper

import plotly.plotly as py
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
init_notebook_mode(connected=False)

Let's load the lartpc_mlreco3d module first (change the path to wherever you cloned the repository):

In [2]:
import sys
sys.path.append("/u/ki/ldomine/lartpc_mlreco3d/")
In [3]:
# If necessary reload the mlreco module after modifying it
# import importlib
# importlib.reload(mlreco.main_funcs)

1. Using YAML config file

The first option is to define the configuration in a YAML text file or string and load it directly into the training or inference loop.

In [4]:
from mlreco.main_funcs import process_config, train, inference
import yaml

We will be using the LArCV data files in the /gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/combined directory, more specifically the files whose names start with train_512px. For a more detailed explanation of this configuration, see section 2.

In [5]:
cfg = """
iotool:
  batch_size: 4
  shuffle: False
  num_workers: 4
  collate_fn: CollateSparse
  sampler:
    name: RandomSequenceSampler
    batch_size: 4
  dataset:
    name: LArCVDataset
    data_dirs:
      - /gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/combined
    data_key: train_512px
    limit_num_files: 10
    schema:
      input_data:
        - parse_sparse3d_scn
        - sparse3d_data
      segment_label:
        - parse_sparse3d_scn
        - sparse3d_fivetypes
model:
  name: uresnet
  modules:
    uresnet:
      num_strides: 5
      filters: 16
      num_classes: 5
      data_dim: 3
      spatial_size: 512
  network_input:
    - input_data
  loss_input:
    - segment_label
training:
  seed: -1
  learning_rate: 0.001
  gpus: '0'
  weight_prefix: weights_trash/snapshot
  iterations: 10
  report_step: 1
  checkpoint_step: 500
  log_dir: log_trash
  model_path: ''
  train: True
  debug: False
  minibatch_size: -1
"""
In [6]:
cfg = yaml.load(cfg, Loader=yaml.Loader)
process_config(cfg)
{   'iotool': {   'batch_size': 4,
                  'collate_fn': 'CollateSparse',
                  'dataset': {   'data_dirs': [   '/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/combined'],
                                 'data_key': 'train_512px',
                                 'limit_num_files': 10,
                                 'name': 'LArCVDataset',
                                 'schema': {   'input_data': [   'parse_sparse3d_scn',
                                                                 'sparse3d_data'],
                                               'segment_label': [   'parse_sparse3d_scn',
                                                                    'sparse3d_fivetypes']}},
                  'num_workers': 4,
                  'sampler': {'batch_size': 4, 'name': 'RandomSequenceSampler'},
                  'shuffle': False},
    'model': {   'loss_input': ['segment_label'],
                 'modules': {   'uresnet': {   'data_dim': 3,
                                               'filters': 16,
                                               'num_classes': 5,
                                               'num_strides': 5,
                                               'spatial_size': 512}},
                 'name': 'uresnet',
                 'network_input': ['input_data']},
    'training': {   'checkpoint_step': 500,
                    'debug': False,
                    'gpus': [0],
                    'iterations': 10,
                    'learning_rate': 0.001,
                    'log_dir': 'log_trash',
                    'minibatch_size': 4,
                    'model_path': '',
                    'report_step': 1,
                    'seed': 1559594062,
                    'train': True,
                    'weight_prefix': 'weights_trash/snapshot'}}
In [7]:
if cfg['training']['train']:
    train(cfg)
else:
    inference(cfg)
4
Welcome to JupyROOT 6.16/00
Loading file: /gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/combined/train_512px.root
Loading tree sparse3d_data
Loading tree sparse3d_fivetypes
Iter. 0 (epoch 0) @ 2019-06-03 13:34:28 ... train time 26.49% (0.957 [s]) mem. 0.305 GB 
   Segmentation: loss 1.7094 accuracy 0.0297855

Iter. 1 (epoch 0) @ 2019-06-03 13:34:32 ... train time 99.99% (0.562 [s]) mem. 0.305 GB 
   Segmentation: loss 1.6261 accuracy 0.0765092

Iter. 2 (epoch 0) @ 2019-06-03 13:34:32 ... train time 98.04% (0.785 [s]) mem. 0.346 GB 
   Segmentation: loss 1.699 accuracy 0.0631898

Iter. 3 (epoch 0) @ 2019-06-03 13:34:33 ... train time 98.85% (0.72 [s]) mem. 0.346 GB 
   Segmentation: loss 1.6291 accuracy 0.0439133

Iter. 4 (epoch 0) @ 2019-06-03 13:34:34 ... train time 99.99% (0.817 [s]) mem. 0.354 GB 
   Segmentation: loss 1.5805 accuracy 0.126151

Iter. 5 (epoch 0) @ 2019-06-03 13:34:34 ... train time 99.99% (0.753 [s]) mem. 0.354 GB 
   Segmentation: loss 1.5846 accuracy 0.106087

Iter. 6 (epoch 0) @ 2019-06-03 13:34:35 ... train time 98.16% (0.878 [s]) mem. 0.362 GB 
   Segmentation: loss 1.6011 accuracy 0.118148

Iter. 7 (epoch 0) @ 2019-06-03 13:34:36 ... train time 97.11% (0.688 [s]) mem. 0.362 GB 
   Segmentation: loss 1.5368 accuracy 0.138305

Iter. 8 (epoch 0) @ 2019-06-03 13:34:37 ... train time 99.98% (0.806 [s]) mem. 0.362 GB 
   Segmentation: loss 1.5003 accuracy 0.172841

Iter. 9 (epoch 0) @ 2019-06-03 13:34:38 ... train time 99.99% (0.861 [s]) mem. 0.366 GB 
   Segmentation: loss 1.5232 accuracy 0.184089

2. Using API functions

In [8]:
from mlreco.trainval import trainval
from mlreco.main_funcs import make_directories
from mlreco.main_funcs import get_data_minibatched
from mlreco.main_funcs import process_config
from mlreco.iotools.factories import loader_factory
from mlreco.main_funcs import cycle

2.1. Training loop

Let's start by defining the I/O configuration. We will be using the LArCV data files in the /gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/combined directory, more specifically the files whose names start with train_512px.

The schema field defines several data products. For each data product, the first string is the parser function name (e.g. parse_sparse3d_scn) and the following strings are TTree branch names to be fed to the parser.
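
For instance, adding another data product is just one more schema entry following the same pattern; a hypothetical sketch (the branch names below are made up for illustration):

# Hypothetical schema entry: the first string names the parser and every
# following string is a TTree branch fed to that parser (the branch names
# here are illustrative only).
extra_schema = {
    "my_product": ["parse_sparse3d_scn", "sparse3d_branch_a", "sparse3d_branch_b"]
}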

The batch size needs to be defined in two places: the top-level batch_size field and the sampler.batch_size field.

In [9]:
io_cfg = {
    "batch_size": 4,
    "shuffle": False,
    "num_workers": 4,
    "collate_fn": "CollateSparse",
    "sampler": {
        "name": "RandomSequenceSampler",
        "batch_size": 4
    },
    "dataset": {
        "name": "LArCVDataset",
        "data_dirs": ["/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/combined"],
        "data_key": "train_512px",
        "limit_num_files": 10,
        "schema": {
            "input_data": ["parse_sparse3d_scn", "sparse3d_data"],
            "segment_label": ["parse_sparse3d_scn", "sparse3d_fivetypes"]
        }
    }
}

Then the network configuration: here we will train a UResNet using the SCN library, with a depth of 5 and 16 base filters. We need to tell it the spatial size and dimensionality of the data. The network_input and loss_input fields refer to names defined in the I/O schema, so they need to match between io_cfg and model_cfg.

In [10]:
model_cfg = {
    "name": "uresnet",
    "modules": {
        "uresnet": {
            "num_strides": 5,
            "filters": 16,
            "num_classes": 5,
            "data_dim": 3,
            "spatial_size": 512
        }
    },
    "network_input": ["input_data"],
    "loss_input": ["segment_label"]
}

Finally, the training configuration defines some global parameters such as the learning rate and the weights directory. Note that minibatch_size: -1 lets process_config default the minibatch size to the batch size, as the processed configuration below shows:

In [11]:
train_cfg = {
    "seed": 123,
    "learning_rate": 0.001,
    "gpus": '0',
    "weight_prefix": "weights_trash/snapshot",
    "iterations": 10,
    "report_step": 1,
    "checkpoint_step": 500,
    "log_dir": "log_trash",
    "model_path": "",
    "train": True,
    "debug": False,
    "minibatch_size": -1
}

The final configuration groups the I/O, model and training blocks:

In [12]:
cfg = {
    "iotool": io_cfg,
    "model": model_cfg,
    "training": train_cfg
}
process_config(cfg)
{   'iotool': {   'batch_size': 4,
                  'collate_fn': 'CollateSparse',
                  'dataset': {   'data_dirs': [   '/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/combined'],
                                 'data_key': 'train_512px',
                                 'limit_num_files': 10,
                                 'name': 'LArCVDataset',
                                 'schema': {   'input_data': [   'parse_sparse3d_scn',
                                                                 'sparse3d_data'],
                                               'segment_label': [   'parse_sparse3d_scn',
                                                                    'sparse3d_fivetypes']}},
                  'num_workers': 4,
                  'sampler': {'batch_size': 4, 'name': 'RandomSequenceSampler'},
                  'shuffle': False},
    'model': {   'loss_input': ['segment_label'],
                 'modules': {   'uresnet': {   'data_dim': 3,
                                               'filters': 16,
                                               'num_classes': 5,
                                               'num_strides': 5,
                                               'spatial_size': 512}},
                 'name': 'uresnet',
                 'network_input': ['input_data']},
    'training': {   'checkpoint_step': 500,
                    'debug': False,
                    'gpus': [0],
                    'iterations': 10,
                    'learning_rate': 0.001,
                    'log_dir': 'log_trash',
                    'minibatch_size': 4,
                    'model_path': '',
                    'report_step': 1,
                    'seed': 123,
                    'train': True,
                    'weight_prefix': 'weights_trash/snapshot'}}

Let's load the dataset using the configuration:

In [13]:
loader, cfg['data_keys'] = loader_factory(cfg)
dataset = iter(cycle(loader))
4
Loading file: /gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/combined/train_512px.root
Loading tree sparse3d_data
Loading tree sparse3d_fivetypes

Initialize the trainer:

In [14]:
Trainer = trainval(cfg)
loaded_iteration = Trainer.initialize()

Create directories for weights and log files:

In [15]:
make_directories(cfg, loaded_iteration)

Now the actual training loop:

In [16]:
for i in range(10):
    data_blob = get_data_minibatched(dataset, cfg)
    res = Trainer.train_step(data_blob)
    print(i, res)
0 {'accuracy': 0.06412212701643931, 'loss_seg': 1.7459412813186646}
1 {'accuracy': 0.08063084899237064, 'loss_seg': 1.6731960773468018}
2 {'accuracy': 0.09538538903173195, 'loss_seg': 1.6820142269134521}
3 {'accuracy': 0.1340081820900174, 'loss_seg': 1.6716103553771973}
4 {'accuracy': 0.09003476534294935, 'loss_seg': 1.6943106651306152}
5 {'accuracy': 0.1569436772876498, 'loss_seg': 1.6227076053619385}
6 {'accuracy': 0.293962765198199, 'loss_seg': 1.5047603845596313}
7 {'accuracy': 0.27419106686415284, 'loss_seg': 1.5663695335388184}
8 {'accuracy': 0.3285321106821949, 'loss_seg': 1.5413806438446045}
9 {'accuracy': 0.2567255054530958, 'loss_seg': 1.5678668022155762}

2.1.1. Visualizing metrics

Now we want to visualize the loss and accuracy curves, so we load the CSV log file and plot these metrics.

In [17]:
log = np.genfromtxt('log_trash/train_log-0000000.csv', names=True, delimiter=',')
In [18]:
seaborn.set(rc={
    'figure.figsize':(15, 10),
})
seaborn.set_context('talk') # or paper
In [19]:
plt.plot(log['iter'], log['loss_seg'])
plt.title("UResNet loss")
plt.xlabel("Iterations")
plt.ylabel("Segmentation Loss")
Out[19]:
Text(0, 0.5, 'Segmentation Loss')
In [20]:
plt.plot(log['iter'], log['acc_seg'])
plt.title("UResNet nonzero accuracy")
plt.xlabel("Iterations")
plt.ylabel("Nonzero accuracy")
Out[20]:
Text(0, 0.5, 'Nonzero accuracy')

2.2. Inference

In [21]:
from mlreco.output_formatters import output

In order to record the event displays in CSV files, we need to specify output formatters:

In [22]:
model_cfg['outputs'] = ["input"]

The training configuration is almost the same; we only need to change the train and model_path parameters. Let's use weights from a previous training:

In [23]:
inference_cfg = train_cfg.copy()
inference_cfg['gpus'] = '0'
inference_cfg["train"] = False
inference_cfg["model_path"] = "/gpfs/slac/staas/fs1/g/neutrino/ldomine/ppn_uresnet/weights_uresnet2/snapshot-999.ckpt"

We rebuild the configuration:

In [24]:
cfg = {
    "iotool": io_cfg,
    "model": model_cfg,
    "training": inference_cfg
}
process_config(cfg)
{   'iotool': {   'batch_size': 4,
                  'collate_fn': 'CollateSparse',
                  'dataset': {   'data_dirs': [   '/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/combined'],
                                 'data_key': 'train_512px',
                                 'limit_num_files': 10,
                                 'name': 'LArCVDataset',
                                 'schema': {   'input_data': [   'parse_sparse3d_scn',
                                                                 'sparse3d_data'],
                                               'segment_label': [   'parse_sparse3d_scn',
                                                                    'sparse3d_fivetypes']}},
                  'num_workers': 4,
                  'sampler': {'batch_size': 4, 'name': 'RandomSequenceSampler'},
                  'shuffle': False},
    'model': {   'loss_input': ['segment_label'],
                 'modules': {   'uresnet': {   'data_dim': 3,
                                               'filters': 16,
                                               'num_classes': 5,
                                               'num_strides': 5,
                                               'spatial_size': 512}},
                 'name': 'uresnet',
                 'network_input': ['input_data'],
                 'outputs': ['input']},
    'training': {   'checkpoint_step': 500,
                    'debug': False,
                    'gpus': [0],
                    'iterations': 10,
                    'learning_rate': 0.001,
                    'log_dir': 'log_trash',
                    'minibatch_size': 4,
                    'model_path': '/gpfs/slac/staas/fs1/g/neutrino/ldomine/ppn_uresnet/weights_uresnet2/snapshot-999.ckpt',
                    'report_step': 1,
                    'seed': 123,
                    'train': False,
                    'weight_prefix': 'weights_trash/snapshot'}}

Reload the dataset and reinitialize the trainer:

In [25]:
loader, cfg['data_keys'] = loader_factory(cfg)
dataset = iter(cycle(loader))
Trainer = trainval(cfg)
loaded_iteration = Trainer.initialize()
make_directories(cfg, loaded_iteration)
4
Loading file: /gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/combined/train_512px.root
Loading tree sparse3d_data
Loading tree sparse3d_fivetypes
Restoring weights from /gpfs/slac/staas/fs1/g/neutrino/ldomine/ppn_uresnet/weights_uresnet2/snapshot-999.ckpt...
Done.

Now run the inference loop for a few iterations and use the output function to record the events in CSV files:

In [26]:
for i in range(10):
    data_blob = get_data_minibatched(dataset, cfg)
    res = Trainer.forward(data_blob)
    print(i, res)
    output(cfg['model']['outputs'], data_blob, res, cfg, i)
0 {'accuracy': 0.9867804777374559, 'loss_seg': 0.0526193231344223}
{'accuracy': 0.9867804777374559, 'loss_seg': 0.0526193231344223}
1 {'accuracy': 0.9868171002163086, 'loss_seg': 0.043108195066452026}
{'accuracy': 0.9868171002163086, 'loss_seg': 0.043108195066452026}
2 {'accuracy': 0.9623052696071026, 'loss_seg': 0.13719916343688965}
{'accuracy': 0.9623052696071026, 'loss_seg': 0.13719916343688965}
3 {'accuracy': 0.9390638987963343, 'loss_seg': 0.1717257797718048}
{'accuracy': 0.9390638987963343, 'loss_seg': 0.1717257797718048}
4 {'accuracy': 0.9772822168115907, 'loss_seg': 0.0752861350774765}
{'accuracy': 0.9772822168115907, 'loss_seg': 0.0752861350774765}
5 {'accuracy': 0.9800032949039785, 'loss_seg': 0.07227931916713715}
{'accuracy': 0.9800032949039785, 'loss_seg': 0.07227931916713715}
6 {'accuracy': 0.9783886821203386, 'loss_seg': 0.07452468574047089}
{'accuracy': 0.9783886821203386, 'loss_seg': 0.07452468574047089}
7 {'accuracy': 0.982328784751016, 'loss_seg': 0.07993170619010925}
{'accuracy': 0.982328784751016, 'loss_seg': 0.07993170619010925}
8 {'accuracy': 0.9796758921060514, 'loss_seg': 0.06401433795690536}
{'accuracy': 0.9796758921060514, 'loss_seg': 0.06401433795690536}
9 {'accuracy': 0.9796770820680686, 'loss_seg': 0.06422899663448334}
{'accuracy': 0.9796770820680686, 'loss_seg': 0.06422899663448334}

2.3. Event displays

Let's define some useful plotting functions using Plotly.

In [27]:
layout = go.Layout(margin=dict(l=0, r=0, b=0, t=0))

def get_trace(event, point_type, key="", point=False, cmax=0.05):
    """Build a Plotly 3D scatter trace from all points of a given type."""
    subevent = event[event['type'] == point_type]
    if point:
        # Larger markers, colored by the `key` column if one is given
        trace1 = go.Scatter3d(
            x=subevent['x'],
            y=subevent['y'],
            z=subevent['z'],
            mode='markers',
            marker=dict(
                size=5,
                color=subevent[key] if key else "orange",
                colorscale="Rainbow",
                opacity=0.8
            )
        )
    else:
        # Voxel-sized markers, colored by the recorded value (e.g. energy deposit)
        trace1 = go.Scatter3d(
            x=subevent['x'],
            y=subevent['y'],
            z=subevent['z'],
            mode='markers',
            marker=dict(
                size=1,
                color=subevent["value"],
                colorscale="Viridis",
                opacity=0.8,
                cmax=cmax,
                cmin=0.
            )
        )
    return trace1

def plot_event(data_products):
    fig = go.Figure(data=data_products, layout=layout)
    iplot(fig)

We load the output CSV file from one event:

In [28]:
event_id = 2
event = np.genfromtxt('log_trash/output-%07d.csv' % event_id, names=True, delimiter=',')

Here is how the events are recorded: each point has 5 columns: x, y, z for the point coordinates, type to distinguish the different kinds of points (data, label, output, etc.), and value (which can be an energy deposit, a class label, etc.).

In [29]:
event[:10]
Out[29]:
array([(196., 190., 0., 0., 0.022069), ( 94., 262., 4., 0., 0.019485),
       ( 94., 261., 5., 0., 0.019485), ( 93., 262., 5., 0., 0.019485),
       ( 94., 262., 5., 0., 0.087326), ( 95., 262., 5., 0., 0.019485),
       ( 94., 263., 5., 0., 0.019485), (127., 271., 5., 0., 0.011621),
       (128., 271., 5., 0., 0.05455 ), (129., 271., 5., 0., 0.021995)],
      dtype=[('x', '<f8'), ('y', '<f8'), ('z', '<f8'), ('type', '<f8'), ('value', '<f8')])

For our purposes, all we need are the energy deposits (type 0), the segmentation labels (type 2) and, later, the predictions (type 4).
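
As a quick sanity check, here is a small sketch (assuming the 5-column layout above) that counts how many points of each type were recorded:

# Count the recorded points per type (0 = energy deposits, 2 = labels,
# 4 = predictions, following the convention described above).
types, counts = np.unique(event['type'], return_counts=True)
for t, c in zip(types, counts):
    print("type %d: %d points" % (int(t), c))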

2.3.1. Energy deposits

In [30]:
plot_event([get_trace(event, 0)])

2.3.2. Labels for UResNet

In [31]:
plot_event([get_trace(event, 2, cmax=5), get_trace(event, 1, key="value", point=True)])

2.4. Analysis and display of the network output

Now you might want to compute more metrics from the output of the network. To do this you need to specify an analysis_keys field in the model configuration. It will fetch the corresponding outputs from the network and feed them to our analysis function. Here we say that we want the first output of the network (index 0) to be saved under the key segmentation in the results dictionary:

In [32]:
model_cfg["analysis_keys"] = {
    "segmentation": 0
}

Now we might also want to display the segmentation output of the network, so we will add the uresnet_ppn output formatter:

In [33]:
model_cfg['outputs'] = ["input", "uresnet_ppn"]

Now let's reload the configuration:

In [34]:
inference_cfg['gpus'] = '0'
cfg = {
    "iotool": io_cfg,
    "model": model_cfg,
    "training": inference_cfg
}
process_config(cfg)
{   'iotool': {   'batch_size': 4,
                  'collate_fn': 'CollateSparse',
                  'dataset': {   'data_dirs': [   '/gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/combined'],
                                 'data_key': 'train_512px',
                                 'limit_num_files': 10,
                                 'name': 'LArCVDataset',
                                 'schema': {   'input_data': [   'parse_sparse3d_scn',
                                                                 'sparse3d_data'],
                                               'segment_label': [   'parse_sparse3d_scn',
                                                                    'sparse3d_fivetypes']}},
                  'num_workers': 4,
                  'sampler': {'batch_size': 4, 'name': 'RandomSequenceSampler'},
                  'shuffle': False},
    'model': {   'analysis_keys': {'segmentation': 0},
                 'loss_input': ['segment_label'],
                 'modules': {   'uresnet': {   'data_dim': 3,
                                               'filters': 16,
                                               'num_classes': 5,
                                               'num_strides': 5,
                                               'spatial_size': 512}},
                 'name': 'uresnet',
                 'network_input': ['input_data'],
                 'outputs': ['input', 'uresnet_ppn']},
    'training': {   'checkpoint_step': 500,
                    'debug': False,
                    'gpus': [0],
                    'iterations': 10,
                    'learning_rate': 0.001,
                    'log_dir': 'log_trash',
                    'minibatch_size': 4,
                    'model_path': '/gpfs/slac/staas/fs1/g/neutrino/ldomine/ppn_uresnet/weights_uresnet2/snapshot-999.ckpt',
                    'report_step': 1,
                    'seed': 123,
                    'train': False,
                    'weight_prefix': 'weights_trash/snapshot'}}
In [35]:
loader, cfg['data_keys'] = loader_factory(cfg)
dataset = iter(cycle(loader))
Trainer = trainval(cfg)
loaded_iteration = Trainer.initialize()
make_directories(cfg, loaded_iteration)
4
Loading file: /gpfs/slac/staas/fs1/g/neutrino/kterao/data/dlprod_ppn_v10/combined/train_512px.root
Loading tree sparse3d_data
Loading tree sparse3d_fivetypes
Restoring weights from /gpfs/slac/staas/fs1/g/neutrino/ldomine/ppn_uresnet/weights_uresnet2/snapshot-999.ckpt...
Done.
In [36]:
data_blob = get_data_minibatched(dataset, cfg)
res = Trainer.forward(data_blob)
print(res)
output(cfg['model']['outputs'], data_blob, res, cfg, 0)
{'accuracy': 0.9867804777374559, 'loss_seg': 0.0526193231344223, 'segmentation': [array([[ 4.4756413 ,  0.9771209 , -1.7139757 , -2.7478278 , -2.4188387 ],
       [ 3.9867988 ,  0.85306716, -1.5087371 , -2.5764997 , -2.4625697 ],
       [ 4.8668814 ,  1.4653498 , -2.3174696 , -2.8363597 , -2.8884084 ],
       ...,
       [-1.1616472 ,  4.1530633 , -2.3052652 , -1.9561005 , -2.2434492 ],
       [-1.5729904 ,  4.558466  , -2.3279161 , -1.9647713 , -2.2248027 ],
       [-1.7966137 ,  4.699898  , -2.3987367 , -1.8705047 , -2.319155  ]],
      dtype=float32)]}
{'accuracy': 0.9867804777374559, 'loss_seg': 0.0526193231344223, 'segmentation': [array([[ 4.4756413 ,  0.9771209 , -1.7139757 , -2.7478278 , -2.4188387 ],
       [ 3.9867988 ,  0.85306716, -1.5087371 , -2.5764997 , -2.4625697 ],
       [ 4.8668814 ,  1.4653498 , -2.3174696 , -2.8363597 , -2.8884084 ],
       ...,
       [-1.1616472 ,  4.1530633 , -2.3052652 , -1.9561005 , -2.2434492 ],
       [-1.5729904 ,  4.558466  , -2.3279161 , -1.9647713 , -2.2248027 ],
       [-1.7966137 ,  4.699898  , -2.3987367 , -1.8705047 , -2.319155  ]],
      dtype=float32)]}

Now you can see that, in addition to the loss and accuracy, res['segmentation'] contains the output of the semantic segmentation.

In [37]:
res['segmentation'][0].shape
Out[37]:
(35608, 5)

You can now use it in your own advanced analysis script, for example to compute the accuracy per class.
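
For instance, here is a minimal sketch of a per-class accuracy computation. It assumes labels is a length-N integer array of true class indices aligned with the rows of res['segmentation'][0]; where exactly the labels sit in data_blob depends on the collate function, so that lookup is left to you:

# Hedged sketch: per-class accuracy from the raw segmentation scores.
# `labels` is assumed to be a length-N integer array of true classes
# aligned with the score rows (not shown: extracting it from data_blob).
def per_class_accuracy(scores, labels, num_classes=5):
    predictions = np.argmax(scores, axis=1)  # predicted class per voxel
    for c in range(num_classes):
        mask = labels == c
        if mask.sum() > 0:
            print("class %d accuracy: %.3f" % (c, (predictions[mask] == c).mean()))

# Example call (once you have the labels):
# per_class_accuracy(res['segmentation'][0], labels)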

Let's see what the uresnet_ppn formatter recorded for us, namely the predictions:

In [38]:
event_id = 0
event = np.genfromtxt('log_trash/output-%07d.csv' % event_id, names=True, delimiter=',')

Predictions for this event

In [39]:
plot_event([get_trace(event, 4, cmax=5)])  # , get_trace(event1, 5, key="value", point=True)

Corresponding segmentation labels for this event

In [40]:
plot_event([get_trace(event, 2, cmax=5), get_trace(event, 1, key="value", point=True)])