Changes from all commits
142 commits
07e4809
newlines
sshleifer Oct 16, 2019
906c925
still working, some more cleanup
sshleifer Oct 17, 2019
504e094
little if simplification
sshleifer Oct 31, 2019
ff9af42
slow ass utest
sshleifer Oct 31, 2019
e15e139
metrics cleanup
sshleifer Nov 1, 2019
e1f9524
test args
sshleifer Nov 1, 2019
5a1b19c
no expid
sshleifer Nov 1, 2019
a4baa86
more cleanup
sshleifer Nov 1, 2019
eafe6c6
more cleanup
sshleifer Nov 1, 2019
eda481f
test metric shapes
sshleifer Nov 1, 2019
bb977f0
small requirements
sshleifer Nov 1, 2019
fce44a2
README
sshleifer Nov 1, 2019
3a1dc52
Cleanup (#1)
sshleifer Nov 1, 2019
ee246b1
cheaper_metric fn
sshleifer Nov 2, 2019
c679945
call cheaper_metric
sshleifer Nov 2, 2019
0d5a515
tests pass
sshleifer Nov 2, 2019
a9c2044
delete other code
sshleifer Nov 2, 2019
100e4d3
util update
sshleifer Nov 3, 2019
abc18ab
item
sshleifer Nov 3, 2019
d7b9be1
factor out test metrics
sshleifer Nov 3, 2019
5b85366
utest for test script
sshleifer Nov 3, 2019
e387cb9
mkdir
sshleifer Nov 3, 2019
c877504
Cheaper metric (#2)
sshleifer Nov 3, 2019
5b820a6
test push permissions
invalid-email-address Nov 4, 2019
dce48c3
test push
vamsikc Nov 4, 2019
d6cd058
add test script args
sshleifer Nov 4, 2019
ca3e076
Merge branch 'master' of github.com:sshleifer/Graph-WaveNet
sshleifer Nov 4, 2019
75d3280
utest stuff
sshleifer Nov 4, 2019
a7c4d55
dont overwrite script args
sshleifer Nov 4, 2019
046cd2b
checkin baseline_ckpt.pth
sshleifer Nov 4, 2019
14d20bd
Better baseline
sshleifer Nov 5, 2019
6372489
Dont check in utest side effects
sshleifer Nov 6, 2019
58a378a
gitignore
sshleifer Nov 6, 2019
db176e8
utestfixes
sshleifer Nov 7, 2019
edfbe91
Hoist nconv
sshleifer Nov 7, 2019
a66878c
cleaning up test files
chmccreery Nov 7, 2019
3d92651
removing more test files
chmccreery Nov 7, 2019
fe96d62
New adjacency matrices with less thresholding and script to analyze them
chmccreery Nov 7, 2019
4a73304
add day of the week feature
chmccreery Nov 7, 2019
66409af
Default to include day of week one-hot feature
chmccreery Nov 7, 2019
02e97d8
Test metrics stuff
sshleifer Nov 7, 2019
b950a72
Merge branch 'master' of github.com:sshleifer/Graph-WaveNet
sshleifer Nov 7, 2019
209b6bc
use utils to unpickle and add heatmap, histogram to adj_mx_stats
chmccreery Nov 8, 2019
b345237
Merge branch 'master' of github.com:sshleifer/Graph-WaveNet
sshleifer Nov 8, 2019
15b62f4
more cov
sshleifer Nov 8, 2019
cfb72cf
register parameter
sshleifer Nov 8, 2019
0be46da
fix supports_len bug
sshleifer Nov 8, 2019
cf8b066
Fix
sshleifer Nov 8, 2019
1d9d19d
unit test on GPU (#3)
sshleifer Nov 22, 2019
00bfee0
fix test.py (#4)
sshleifer Nov 22, 2019
cabed88
warning
sshleifer Nov 22, 2019
35af7d1
idiot
sshleifer Nov 22, 2019
e7b6d6f
Hoist model
sshleifer Nov 29, 2019
cc50db2
early break
sshleifer Nov 29, 2019
ec24a8f
remove assert
sshleifer Nov 29, 2019
69ba145
model kw
sshleifer Nov 29, 2019
d093ef1
progbar
sshleifer Nov 29, 2019
63c0494
Major cleanup
sshleifer Nov 29, 2019
2c66c6a
more cleanup
sshleifer Nov 29, 2019
5e1e897
cleanup train.py a bit
sshleifer Nov 29, 2019
d4b3ba5
Smaller nhid
sshleifer Nov 30, 2019
64bed79
Unit-test speedup
sshleifer Nov 30, 2019
6799d61
better comment
sshleifer Nov 30, 2019
2a95723
more test.py improvements
sshleifer Nov 30, 2019
3610e9d
more test.py improvements
sshleifer Nov 30, 2019
0be74bc
Clamp
sshleifer Nov 30, 2019
25045b3
C
sshleifer Nov 30, 2019
0b433db
bango
sshleifer Nov 30, 2019
a8f21a9
new util
sshleifer Nov 30, 2019
da29d41
small sensible changes
sshleifer Dec 1, 2019
1474035
lr scheduling
sshleifer Dec 1, 2019
19b1bb4
style
sshleifer Dec 1, 2019
932bb6a
exp result analyzer
sshleifer Dec 1, 2019
78dc465
update requirements
sshleifer Dec 2, 2019
954ac45
Shared args
sshleifer Dec 2, 2019
96a7282
save import
sshleifer Dec 2, 2019
f4bf38a
Softmax-temp
sshleifer Dec 2, 2019
aeb22ca
style model.py
sshleifer Dec 2, 2019
b1ec325
lr_decay_rate_hoist
sshleifer Dec 2, 2019
bccf70d
better train ux
sshleifer Dec 2, 2019
5972db3
More cleanup
sshleifer Dec 2, 2019
452cbb7
dont save preds
sshleifer Dec 2, 2019
2db9f5b
more cleanup
sshleifer Dec 2, 2019
effada8
more cleanup
sshleifer Dec 2, 2019
2e2379b
more cleanup
sshleifer Dec 2, 2019
6fa65dd
More cleanup
sshleifer Dec 2, 2019
3f283f4
Dont register parameter
sshleifer Dec 2, 2019
a4bc70c
Cat feat plane gc (#5)
sshleifer Dec 3, 2019
f991c29
check in baseline args
sshleifer Dec 3, 2019
5c8fc47
plot grad flow
sshleifer Dec 3, 2019
932b865
Apex (#6)
sshleifer Dec 3, 2019
b248ba6
fix fp16 compat
sshleifer Dec 3, 2019
7fff6e7
hoist clip arg
sshleifer Dec 3, 2019
a8bb661
cleanup
sshleifer Dec 3, 2019
0001031
es patience 20
sshleifer Dec 3, 2019
3088d06
allows model kwarg passing
sshleifer Dec 4, 2019
65c2d23
bigger eval batch sizes
sshleifer Dec 4, 2019
99e29ee
lrfinder
sshleifer Dec 4, 2019
5766a99
(VAMSI) Seq length (#7)
sshleifer Dec 4, 2019
bed47e0
surgery func
sshleifer Dec 4, 2019
8b74eec
squeeze like master I guess
sshleifer Dec 4, 2019
0dee392
jk
sshleifer Dec 4, 2019
7fb7ae2
force squeeze1
sshleifer Dec 4, 2019
bf465c6
Merge branch 'surgery'
sshleifer Dec 4, 2019
11c83de
end_conv_lr arg
sshleifer Dec 4, 2019
df4c895
args update
sshleifer Dec 4, 2019
41fa3bf
freeze group b code
sshleifer Dec 4, 2019
0be1191
Freeze if lrate=0
sshleifer Dec 4, 2019
b777b96
call freezer
sshleifer Dec 4, 2019
102a21e
Revert batch size
sshleifer Dec 5, 2019
2678b90
Train loss progbar
sshleifer Dec 5, 2019
9cc8923
log test flag
sshleifer Dec 5, 2019
804428e
Uneven datasets
sshleifer Dec 5, 2019
48a0eae
y_start clarg
sshleifer Dec 5, 2019
29c9f3d
assert and squeeze
sshleifer Dec 5, 2019
c664dd7
gitignore
sshleifer Dec 5, 2019
ae06854
Some cleanup
sshleifer Dec 5, 2019
1cb00bf
bug fix
sshleifer Dec 5, 2019
c695c8e
Merge branch 'uneven-ds'
sshleifer Dec 5, 2019
cb6736b
replace 0 feat with 58.435
sshleifer Dec 5, 2019
03063d0
Skip zero batches
sshleifer Dec 6, 2019
3e85bde
test.main supports loader=val
sshleifer Dec 6, 2019
bc103a1
scaler prints
sshleifer Dec 6, 2019
6f7a926
scaler prints
sshleifer Dec 6, 2019
d4bdd9f
reduce fill-val
sshleifer Dec 6, 2019
64253c5
call .eval
sshleifer Dec 6, 2019
2e0778d
fix imports
sshleifer Dec 6, 2019
a91c888
fill val is mean
sshleifer Dec 6, 2019
371a784
copy of nb
sshleifer Dec 7, 2019
399a270
some surgery docs
sshleifer Dec 9, 2019
8ed396c
Merge branch 'master' of github.com:sshleifer/Graph-WaveNet
sshleifer Dec 9, 2019
4f48833
better baseline_args
sshleifer Dec 9, 2019
215640a
del unused
sshleifer Dec 9, 2019
7ce81fe
Merge branch 'master' of github.com:sshleifer/Graph-WaveNet
sshleifer Dec 9, 2019
38c2f78
ADJ_CHOICES
sshleifer Dec 10, 2019
2d5830a
fill_zeroes arg
sshleifer Dec 11, 2019
dae86d5
move surgery func
sshleifer Dec 11, 2019
7a2a7c2
README updates
sshleifer Dec 11, 2019
f06cc8c
Merge master
sshleifer Dec 17, 2019
03684b8
update README for new cli
sshleifer Dec 17, 2019
efbc44b
check in less data
sshleifer Dec 17, 2019
2237839
support cat_feat_gc and dow
sshleifer Dec 17, 2019
8 changes: 8 additions & 0 deletions .gitignore
@@ -0,0 +1,8 @@
utest_experiment/*
heatmap.png
last_test_metrics.csv
preds.csv
.ipynb_checkpoints/
data/
.DS_Store
*.pkl
72 changes: 38 additions & 34 deletions README.md
@@ -1,7 +1,9 @@
# Graph WaveNet for Deep Spatial-Temporal Graph Modeling

This is the original pytorch implementation of Graph WaveNet in the following paper:
[Graph WaveNet for Deep Spatial-Temporal Graph Modeling, IJCAI 2019](https://arxiv.org/abs/1906.00121).
[Graph WaveNet for Deep Spatial-Temporal Graph Modeling, IJCAI 2019](https://arxiv.org/abs/1906.00121),
with modifications presented in [Incrementally Improving Graph WaveNet Performance on Traffic Prediction](https://arxiv.org/abs/1912.07390):


<p align="center">
<img width="350" height="400" src=./fig/model.png>
@@ -14,9 +16,9 @@ This is the original pytorch implementation of Graph WaveNet in the following paper:

## Data Preparation

### Step1: Download METR-LA and PEMS-BAY data from [Google Drive](https://drive.google.com/open?id=10FOTa6HXPqX8Pf5WRoRwcFnW9BrNZEIX) or [Baidu Yun](https://pan.baidu.com/s/14Yy9isAIZYdU__OYEQGa_g) links provided by [DCRNN](https://github.com/liyaguang/DCRNN).
1) Download METR-LA and PEMS-BAY data from [Google Drive](https://drive.google.com/open?id=10FOTa6HXPqX8Pf5WRoRwcFnW9BrNZEIX) or [Baidu Yun](https://pan.baidu.com/s/14Yy9isAIZYdU__OYEQGa_g) links provided by [DCRNN](https://github.com/liyaguang/DCRNN).

### Step2:
2)

```
# Create data directories
@@ -29,40 +31,42 @@ python generate_training_data.py --output_dir=data/METR-LA --traffic_df_filename
python generate_training_data.py --output_dir=data/PEMS-BAY --traffic_df_filename=data/pems-bay.h5

```
## Experiments
Train models configured in Table 3 of the paper.

## Train Commands
Note: train.py saves metrics to `metrics.csv` and `test_metrics.csv` inside the directory given by the `--save` arg.

Model that gets 3.00 - 3.02 test MAE (~2.73 validation MAE):
```
python train.py --cat_feat_gc --fill_zeroes --do_graph_conv --addaptadj --randomadj --es_patience 20 --save logs/baseline_v2
```

Finetuning (2.99 - 3.00 MAE)
```
python generate_training_data.py --seq_length_y 6 --output_dir data/METR-LA_12_6
python train.py --data data/METR-LA_12_6 --cat_feat_gc --fill_zeroes --do_graph_conv --addaptadj --randomadj --es_patience 20 --save logs/front_6
python train.py --checkpoint logs/front_6/best_model.pth --cat_feat_gc --fill_zeroes --do_graph_conv --addaptadj --randomadj --es_patience 20 --save logs/finetuned

```
Original Graph WaveNet model (3.04 - 3.07 MAE):
```
python train.py --clip 5 --lr_decay_rate 1. --nhid 32 --do_graph_conv --addaptadj --randomadj --save logs/baseline
```
ep=100
dv=cuda:0
mkdir experiment
mkdir experiment/metr

#identity
expid=1
python train.py --device $dv --gcn_bool --adjtype identity --epoch $ep --expid $expid --save ./experiment/metr/metr > ./experiment/metr/train-$expid.log
rm ./experiment/metr/metr_epoch*

#forward-only
expid=2
python train.py --device $dv --gcn_bool --adjtype transition --epoch $ep --expid $expid --save ./experiment/metr/metr > ./experiment/metr/train-$expid.log
rm ./experiment/metr/metr_epoch*

#adaptive-only
expid=3
python train.py --device $dv --gcn_bool --adjtype transition --aptonly --addaptadj --randomadj --epoch $ep --expid $expid --save ./experiment/metr/metr > ./experiment/metr/train-$expid.log
rm ./experiment/metr/metr_epoch*

#forward-backward
expid=4
python train.py --device $dv --gcn_bool --adjtype doubletransition --epoch $ep --expid $expid --save ./experiment/metr/metr > ./experiment/metr/train-$expid.log
rm ./experiment/metr/metr_epoch*

#forward-backward-adaptive
expid=5
python train.py --device $dv --gcn_bool --adjtype doubletransition --addaptadj --randomadj --epoch $ep --expid $expid --save ./experiment/metr/metr > ./experiment/metr/train-$expid.log
rm ./experiment/metr/metr_epoch*

You can also train from a Jupyter notebook with
```{python}
from train import main
from durbango import pickle_load
args = pickle_load('baseline_args.pkl') # manipulate these in python
args.lr_decay_rate = .97
args.clip = 3
args.save = 'logs/from_jupyter'
main(args)  # takes roughly an hour, depending on nhid and early stopping
```

Train the models configured in Table 3 of the original Graph WaveNet paper by using the `--adjtype`, `--addaptadj`, and `--aptonly` command line arguments.
These flags are (somewhat) documented in util.py.

Run unit tests with `pytest`.

### Possible Improvements
* move redundant `.transpose(1,3)` to dataloader or `load_dataset`
Binary file added baseline_args.pkl
Binary file not shown.
45 changes: 22 additions & 23 deletions engine.py
@@ -1,43 +1,42 @@
import torch.optim as optim
from model import *
import util
class trainer():
    def __init__(self, scaler, in_dim, seq_length, num_nodes, nhid , dropout, lrate, wdecay, device, supports, gcn_bool, addaptadj, aptinit):
        self.model = gwnet(device, num_nodes, dropout, supports=supports, gcn_bool=gcn_bool, addaptadj=addaptadj, aptinit=aptinit, in_dim=in_dim, out_dim=seq_length, residual_channels=nhid, dilation_channels=nhid, skip_channels=nhid * 8, end_channels=nhid * 16)
        self.model.to(device)

class Trainer():
    def __init__(self, model: GWNet, scaler, lrate, wdecay, clip=3, lr_decay_rate=.97):
        self.model = model

        self.optimizer = optim.Adam(self.model.parameters(), lr=lrate, weight_decay=wdecay)
        self.loss = util.masked_mae
        self.scaler = scaler
        self.clip = 5
        self.clip = clip
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer, lr_lambda=lambda epoch: lr_decay_rate ** epoch)

    @classmethod
    def from_args(cls, model, scaler, args):
        return cls(model, scaler, args.learning_rate, args.weight_decay, clip=args.clip,
                   lr_decay_rate=args.lr_decay_rate)

    def train(self, input, real_val):
        self.model.train()
        self.optimizer.zero_grad()
        input = nn.functional.pad(input,(1,0,0,0))
        output = self.model(input)
        output = output.transpose(1,3)
        #output = [batch_size,12,num_nodes,1]
        real = torch.unsqueeze(real_val,dim=1)
        output = self.model(input).transpose(1,3)  # now, output = [batch_size,1,num_nodes,seq_length]
        predict = self.scaler.inverse_transform(output)

        loss = self.loss(predict, real, 0.0)
        loss.backward()
        assert predict.shape[1] == 1
        mae, mape, rmse = util.calc_metrics(predict.squeeze(1), real_val, null_val=0.0)
        mae.backward()
        if self.clip is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
        self.optimizer.step()
        mape = util.masked_mape(predict,real,0.0).item()
        rmse = util.masked_rmse(predict,real,0.0).item()
        return loss.item(),mape,rmse
        return mae.item(),mape.item(),rmse.item()

    def eval(self, input, real_val):
        self.model.eval()
        input = nn.functional.pad(input,(1,0,0,0))
        output = self.model(input)
        output = output.transpose(1,3)
        #output = [batch_size,12,num_nodes,1]
        output = self.model(input).transpose(1,3)  # [batch_size,seq_length,num_nodes,1]
        real = torch.unsqueeze(real_val,dim=1)
        predict = self.scaler.inverse_transform(output)
        loss = self.loss(predict, real, 0.0)
        mape = util.masked_mape(predict,real,0.0).item()
        rmse = util.masked_rmse(predict,real,0.0).item()
        return loss.item(),mape,rmse
        predict = torch.clamp(predict, min=0., max=70.)
        mae, mape, rmse = [x.item() for x in util.calc_metrics(predict, real, null_val=0.0)]
        return mae, mape, rmse
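
To make the refactor concrete, here is a minimal sketch of how the new `Trainer` could be driven, based only on the signatures visible in this diff. The `GWNet` constructor call (assumed to keep the kwargs of the old `gwnet(...)` call above), the `IdentityScaler` stand-in, and the random tensors are illustrative assumptions, not code from this PR.

```python
# Sketch only: everything marked "assumed" is not shown in this diff.
import torch
from argparse import Namespace
from model import GWNet    # assumed import path, mirroring `from model import *` above
from engine import Trainer

class IdentityScaler:
    """Stand-in for the repo's scaler; Trainer only needs inverse_transform here."""
    def inverse_transform(self, x):
        return x

args = Namespace(learning_rate=1e-3, weight_decay=1e-4, clip=3, lr_decay_rate=.97)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_nodes, nhid, seq_length = 207, 32, 12  # METR-LA sizes; nhid is illustrative

# Assumed: GWNet accepts the same kwargs as the old gwnet(...) call in this diff.
model = GWNet(device, num_nodes, dropout=0.3, supports=None, gcn_bool=True,
              addaptadj=True, aptinit=None, in_dim=2, out_dim=seq_length,
              residual_channels=nhid, dilation_channels=nhid,
              skip_channels=nhid * 8, end_channels=nhid * 16).to(device)
engine = Trainer.from_args(model, IdentityScaler(), args)

x = torch.randn(8, 2, num_nodes, seq_length, device=device)   # [batch, feat, node, time]
y = torch.rand(8, num_nodes, seq_length, device=device) * 70  # real_val, per the squeeze(1) above
mae, mape, rmse = engine.train(x, y)  # one optimizer step; returns scalar metrics
engine.scheduler.step()               # LambdaLR decays lr by lr_decay_rate per epoch
```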
38 changes: 38 additions & 0 deletions exp_results.py
@@ -0,0 +1,38 @@
"""Utilities for comparing metrics saved by train.py"""
import pandas as pd
import os
from glob import glob
import matplotlib.pyplot as plt


def summary(d):
    try:
        tr_val = pd.read_csv(f'{d}/metrics.csv', index_col=0)
        tr_ser = tr_val.loc[tr_val.valid_loss.idxmin()]
        tr_ser['best_epoch'] = tr_val.valid_loss.idxmin()
        tr_ser['min_train_loss'] = tr_val.train_loss.min()
    except FileNotFoundError:
        tr_ser = pd.Series()
    try:
        tmet = pd.read_csv(f'{d}/test_metrics.csv', index_col=0)
        tmean = tmet.add_prefix('test_').mean()

    except FileNotFoundError:
        tmean = pd.Series()
    tab = pd.concat([tr_ser, tmean]).round(3)
    return tab

def loss_curve(d):
    if 'logs' not in d: d = f'logs/{d}'
    tr_val = pd.read_csv(f'{d}/metrics.csv', index_col=0)
    return tr_val[['train_loss', 'valid_loss']]


def plot_loss_curve(log_dir):
    d = loss_curve(log_dir)
    ax = d.plot()
    plt.axhline(d.valid_loss.min())
    print(d.valid_loss.idxmin())

def make_results_table():
    return pd.DataFrame({os.path.basename(c): summary(c) for c in glob('logs/*')}).T.sort_values('valid_loss')
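
These helpers assume the `logs/<run>/metrics.csv` and `logs/<run>/test_metrics.csv` layout that `train.py --save` produces. A quick usage sketch; the run directory name below is hypothetical:

```python
# Usage sketch for exp_results.py; 'logs/baseline_v2' is a hypothetical run name.
from exp_results import summary, make_results_table, plot_loss_curve

print(summary('logs/baseline_v2'))  # best-epoch train/valid row plus mean test_* metrics
results = make_results_table()      # one row per logs/* run, sorted by valid_loss
print(results.head())
plot_loss_curve('baseline_v2')      # the 'logs/' prefix is prepended automatically
```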
63 changes: 63 additions & 0 deletions gen_adj_mx.py
@@ -0,0 +1,63 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import numpy as np
import pandas as pd
import pickle


def get_adjacency_matrix(distance_df, sensor_ids, normalized_k=0.1):
    """

    :param distance_df: data frame with three columns: [from, to, distance].
    :param sensor_ids: list of sensor ids.
    :param normalized_k: entries that become lower than normalized_k after normalization are set to zero for sparsity.
    :return:
    """
    num_sensors = len(sensor_ids)
    dist_mx = np.zeros((num_sensors, num_sensors), dtype=np.float32)
    dist_mx[:] = np.inf
    # Builds sensor id to index map.
    sensor_id_to_ind = {}
    for i, sensor_id in enumerate(sensor_ids):
        sensor_id_to_ind[sensor_id] = i

    # Fills cells in the matrix with distances.
    for row in distance_df.values:
        if row[0] not in sensor_id_to_ind or row[1] not in sensor_id_to_ind:
            continue
        dist_mx[sensor_id_to_ind[row[0]], sensor_id_to_ind[row[1]]] = row[2]

    # Calculates the standard deviation as theta.
    distances = dist_mx[~np.isinf(dist_mx)].flatten()
    std = distances.std()
    adj_mx = np.exp(-np.square(dist_mx / std))
    # Make the adjacency matrix symmetric by taking the max.
    # adj_mx = np.maximum.reduce([adj_mx, adj_mx.T])

    # Sets entries that are lower than a threshold, i.e., k, to zero for sparsity.
    adj_mx[adj_mx < normalized_k] = 0
    return sensor_ids, sensor_id_to_ind, adj_mx


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--sensor_ids_filename', type=str, default='data/sensor_graph/graph_sensor_ids.txt',
                        help='File containing sensor ids separated by comma.')
    parser.add_argument('--distances_filename', type=str, default='data/sensor_graph/distances_la_2012.csv',
                        help='CSV file containing sensor distances with three columns: [from, to, distance].')
    parser.add_argument('--normalized_k', type=float, default=0.1,
                        help='Entries that become lower than normalized_k after normalization are set to zero for sparsity.')
    parser.add_argument('--output_pkl_filename', type=str, default='data/sensor_graph/adj_mat.pkl',
                        help='Path of the output file.')
    args = parser.parse_args()

    with open(args.sensor_ids_filename) as f:
        sensor_ids = f.read().strip().split(',')
    distance_df = pd.read_csv(args.distances_filename, dtype={'from': 'str', 'to': 'str'})
    _, sensor_id_to_ind, adj_mx = get_adjacency_matrix(distance_df, sensor_ids, args.normalized_k)
    # Save to pickle file.
    with open(args.output_pkl_filename, 'wb') as f:
        pickle.dump([sensor_ids, sensor_id_to_ind, adj_mx], f, protocol=2)
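
The kernel above is the thresholded Gaussian used by DCRNN: each weight is exp(-(d_ij / sigma)^2), where sigma is the standard deviation of the observed distances, and anything below `normalized_k` is zeroed for sparsity. A toy invocation (the three sensors and their distances are made up):

```python
# Toy run of get_adjacency_matrix; sensor ids and distances are made up.
import pandas as pd
from gen_adj_mx import get_adjacency_matrix

distance_df = pd.DataFrame(
    [('a', 'b', 1000.0), ('b', 'a', 1000.0), ('a', 'c', 5000.0)],
    columns=['from', 'to', 'distance'])

_, id_to_ind, adj_mx = get_adjacency_matrix(distance_df, ['a', 'b', 'c'], normalized_k=0.1)
# sigma = std([1000, 1000, 5000]) ~= 1886, so:
#   adj[a][b] = exp(-(1000/1886)^2) ~= 0.75    (kept)
#   adj[a][c] = exp(-(5000/1886)^2) ~= 0.0009  (below 0.1, zeroed)
# Unfilled pairs (including the diagonal here) start at inf, so exp(-inf) = 0.
print(adj_mx.round(3))
```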