Add ktx
This commit is contained in:
@@ -0,0 +1,560 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# -----------------------------------------------------------------------------
|
||||
# Copyright 2021 Arm Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain a copy
|
||||
# of the License at:
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# -----------------------------------------------------------------------------
|
||||
"""
|
||||
The ``astc_trace_analysis`` utility provides a tool to analyze trace files.
|
||||
|
||||
WARNING: Trace files are an engineering tool, and not part of the standard
|
||||
product, so traces and their associated tools are volatile and may change
|
||||
significantly without notice.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from collections import defaultdict as ddict
|
||||
import json
|
||||
import numpy as np
|
||||
import sys
|
||||
|
||||
# Translates the "weight_quant" index stored on a trace candidate into the
# value this tool reports as the weight quant range.
QUANT_TABLE = dict(enumerate((2, 3, 4, 5, 6, 8, 10, 12, 16, 20, 24, 32)))

# Translates a component index (0-3) into its channel letter for reports.
CHANNEL_TABLE = dict(enumerate("RGBA"))
|
||||
|
||||
class Trace:
    """
    Top-level container for a parsed trace.

    Holds the three block size values and the list of blocks. Indexing,
    item deletion and ``len()`` all act on the block list.
    """

    def __init__(self, block_x, block_y, block_z):
        """
        Create a trace with no blocks.

        Args:
            block_x: The block size value stored as ``block_x``.
            block_y: The block size value stored as ``block_y``.
            block_z: The block size value stored as ``block_z``.
        """
        self.block_x, self.block_y, self.block_z = block_x, block_y, block_z
        self.blocks = list()

    def add_block(self, block):
        """
        Append a block to the end of the block list.

        Args:
            block: The block to store.
        """
        self.blocks.append(block)

    def __getitem__(self, i):
        """
        Return the block stored at index ``i``.
        """
        return self.blocks[i]

    def __delitem__(self, i):
        """
        Remove the block stored at index ``i``.
        """
        del self.blocks[i]

    def __len__(self):
        """
        Return the number of stored blocks.
        """
        return len(self.blocks)
|
||||
|
||||
class Block:
    """
    One traced block: its position, channel extents, and the passes tried.

    Indexing, item deletion and ``len()`` all act on the pass list.
    """

    def __init__(self, pos_x, pos_y, pos_z, error_target):
        """
        Create a block with no channel extents and no passes.

        Args:
            pos_x: The block position value stored as ``pos_x``.
            pos_y: The block position value stored as ``pos_y``.
            pos_z: The block position value stored as ``pos_z``.
            error_target: The error threshold stored for this block.
        """
        self.pos_x = pos_x
        self.pos_y = pos_y
        self.pos_z = pos_z

        self.raw_min = None
        self.raw_max = None

        self.ldr_min = None
        self.ldr_max = None

        self.error_target = error_target
        self.passes = []
        self.qualityHit = None

    @staticmethod
    def _to_ldr(value):
        """
        Rescale one raw channel value from the 0-65535 scale to 0-255.

        Args:
            value: The raw channel value.

        Returns:
            int: The rescaled value, truncated towards zero.
        """
        cmax = 65535.0
        return int((value / cmax) * 255.0)

    def add_minimums(self, r, g, b, a):
        """
        Store the per-channel minimums, both raw and rescaled to 0-255.

        Args:
            r: The raw red minimum.
            g: The raw green minimum.
            b: The raw blue minimum.
            a: The raw alpha minimum.
        """
        self.raw_min = (r, g, b, a)
        # Each channel is converted from its own value; the previous helper
        # read the red value for every channel.
        self.ldr_min = tuple(self._to_ldr(value) for value in self.raw_min)

    def add_maximums(self, r, g, b, a):
        """
        Store the per-channel maximums, both raw and rescaled to 0-255.

        Args:
            r: The raw red maximum.
            g: The raw green maximum.
            b: The raw blue maximum.
            a: The raw alpha maximum.
        """
        self.raw_max = (r, g, b, a)
        # Each channel is converted from its own value; the previous helper
        # read the red value for every channel.
        self.ldr_max = tuple(self._to_ldr(value) for value in self.raw_max)

    def add_pass(self, pas):
        """
        Append a pass to the end of the pass list.

        Args:
            pas: The pass to store.
        """
        self.passes.append(pas)

    def __getitem__(self, i):
        """
        Return the pass stored at index ``i``.
        """
        return self.passes[i]

    def __delitem__(self, i):
        """
        Remove the pass stored at index ``i``.
        """
        del self.passes[i]

    def __len__(self):
        """
        Return the number of stored passes.
        """
        return len(self.passes)
|
||||
|
||||
|
||||
class Pass:
    """
    One compression pass tried for a block, plus the candidates it produced.

    Indexing, item deletion and ``len()`` all act on the candidate list.
    """

    def __init__(self, partitions, partition, planes, target_hit, mode, component):
        """
        Create a pass with no candidates.

        Args:
            partitions: The partition count for the pass.
            partition: The partition index, or None; None is stored as 0.
            planes: The plane count for the pass.
            target_hit: Stored as ``target_hit``.
            mode: Stored as ``search_mode``.
            component: Stored as ``plane2_component``.
        """
        if partition is None:
            partition = 0

        self.partitions = partitions
        self.partition_index = partition
        self.planes = planes
        self.plane2_component = component
        self.target_hit = target_hit
        self.search_mode = mode
        self.candidates = list()

    def add_candidate(self, candidate):
        """
        Append a candidate to the end of the candidate list.

        Args:
            candidate: The candidate to store.
        """
        self.candidates.append(candidate)

    def __getitem__(self, i):
        """
        Return the candidate stored at index ``i``.
        """
        return self.candidates[i]

    def __delitem__(self, i):
        """
        Remove the candidate stored at index ``i``.
        """
        del self.candidates[i]

    def __len__(self):
        """
        Return the number of stored candidates.
        """
        return len(self.candidates)
|
||||
|
||||
|
||||
class Candidate:
    """
    One trial candidate of a pass, with the error values recorded for it.
    """

    def __init__(self, weight_x, weight_y, weight_z, weight_quant):
        """
        Create a candidate with no refinement errors recorded.

        Args:
            weight_x: Stored as ``weight_x``.
            weight_y: Stored as ``weight_y``.
            weight_z: Stored as ``weight_z``.
            weight_quant: Stored as ``weight_quant``.
        """
        self.weight_x, self.weight_y, self.weight_z = weight_x, weight_y, weight_z
        self.weight_quant = weight_quant
        self.refinement_errors = list()

    def add_refinement(self, errorval):
        """
        Append an error value to the end of the refinement error list.

        Args:
            errorval: The error value to record.
        """
        self.refinement_errors.append(errorval)
|
||||
|
||||
|
||||
def get_attrib(data, name, multiple=False, hard_fail=True):
    """
    Look up the value(s) of a named attribute in a trace node's entry list.

    An attribute is any two-element entry whose first element equals
    ``name``; its second element is the value.

    Args:
        data: The list of entries to search.
        name: The attribute name to look for.
        multiple: If True return a list of every match, otherwise require
            at most one match and return it directly.
        hard_fail: If True a missing attribute is an error, otherwise it
            yields an empty result.

    Returns:
        The single value, or a list of values if ``multiple`` is set. If
        nothing matched and ``hard_fail`` is not set, returns an empty list
        when ``multiple`` is set and None otherwise.

    Raises:
        AssertionError: If nothing matched and ``hard_fail`` is set, or if
            more than one entry matched and ``multiple`` is not set.
    """
    results = [entry[1] for entry in data
               if len(entry) == 2 and entry[0] == name]

    if not results:
        if hard_fail:
            # Dump the node being searched to make the failure diagnosable.
            # Raised explicitly so the check survives running with -O.
            print(json.dumps(data, indent=2))
            raise AssertionError("Attribute %s not found" % name)
        return list() if multiple else None

    if multiple:
        return results

    if len(results) > 1:
        print(json.dumps(data, indent=2))
        raise AssertionError(
            "Attribute %s found %u times" % (name, len(results)))

    return results[0]
|
||||
|
||||
|
||||
def rev_enumerate(seq):
    """
    Enumerate a sequence from its last element to its first.

    Args:
        seq: A sequence supporting ``len()`` and ``reversed()``.

    Returns:
        An iterator of ``(index, item)`` pairs, highest index first.
    """
    descending = range(len(seq) - 1, -1, -1)
    return zip(descending, reversed(seq))
|
||||
|
||||
def foreach_block(data):
    """
    Generate every block in the database, in order.

    Args:
        data: An iterable of blocks.

    Yields:
        Each block in turn.
    """
    yield from data
|
||||
|
||||
def foreach_pass(data):
    """
    Generate every pass in the database together with its owning block.

    Args:
        data: An iterable of blocks, each an iterable of passes.

    Yields:
        tuple: ``(block, pass)`` for each pass, in order.
    """
    for owner in data:
        for entry in owner:
            yield (owner, entry)
|
||||
|
||||
def foreach_candidate(data):
    """
    Generate every candidate in the database with its owning block and pass.

    A pass holding no candidates is still reported once, with None in the
    candidate slot.

    Args:
        data: An iterable of blocks, each an iterable of passes, each an
            iterable of candidates that also supports ``len()``.

    Yields:
        tuple: ``(block, pass, candidate)``; candidate is None for a pass
        with no candidates.
    """
    for owner in data:
        for entry in owner:
            # Special case - None candidates for 0 partition
            if len(entry) == 0:
                yield (owner, entry, None)

            for trial in entry:
                yield (owner, entry, trial)
|
||||
|
||||
def get_node(data, name, multiple=False, hard_fail=True):
    """
    Look up the child node(s) with a given name in a trace node's entry list.

    A child node is any three-element entry whose first element is the
    string ``"node"`` and whose second element equals ``name``; its third
    element is the node content.

    Args:
        data: The list of entries to search.
        name: The node name to look for.
        multiple: If True return a list of every match, otherwise require
            at most one match and return it directly.
        hard_fail: If True a missing node is an error, otherwise it yields
            an empty result.

    Returns:
        The single node content, or a list of contents if ``multiple`` is
        set. If nothing matched and ``hard_fail`` is not set, returns an
        empty list when ``multiple`` is set and None otherwise.

    Raises:
        AssertionError: If nothing matched and ``hard_fail`` is set, or if
            more than one entry matched and ``multiple`` is not set.
    """
    results = [entry[2] for entry in data
               if len(entry) == 3 and entry[0] == "node" and entry[1] == name]

    if not results:
        if hard_fail:
            # Dump the node being searched to make the failure diagnosable.
            # Raised explicitly so the check survives running with -O.
            print(json.dumps(data, indent=2))
            raise AssertionError("Node %s not found" % name)
        # An empty list keeps "multiple" results iterable, as get_attrib does
        return list() if multiple else None

    if multiple:
        return results

    if len(results) > 1:
        print(json.dumps(data, indent=2))
        raise AssertionError("Node %s found %u times" % (name, len(results)))

    return results[0]
|
||||
|
||||
|
||||
def find_best_pass_and_candidate(block):
    """
    Find the pass and candidate with the lowest final error in a block.

    The final error of a candidate is the last entry of its
    ``refinement_errors`` list. When several candidates share the lowest
    error the one encountered last is chosen. A pass that hit its target
    with a partition count of zero is returned immediately, as it has no
    trial candidates.

    Args:
        block: An iterable of passes, each an iterable of candidates.

    Returns:
        tuple: ``(pass, candidate)``; candidate is None for the zero
        partition case.

    Raises:
        AssertionError: If the block yields no candidate at all.
    """
    # Start from infinity so that any finite error, however large, is
    # accepted as the first best value.
    best_error = float("inf")
    best_pass = None
    best_candidate = None

    for pas in block:
        # Special case for constant color blocks - no trial candidates
        if pas.target_hit and pas.partitions == 0:
            return (pas, None)

        for candidate in pas:
            errorval = candidate.refinement_errors[-1]
            if errorval <= best_error:
                best_error = errorval
                best_pass = pas
                best_candidate = candidate

    # Every other return type must have both best pass and best candidate.
    # Tested against None explicitly so the result does not depend on the
    # truthiness of the pass object, and raised so it survives -O.
    if best_pass is None or best_candidate is None:
        raise AssertionError("No candidate found in block")

    return (best_pass, best_candidate)
|
||||
|
||||
|
||||
def generate_database(data):
    """
    Build the trace database from decoded trace JSON.

    Passes flagged with a ``skip`` attribute and candidates flagged with a
    ``failed`` attribute are left out of the database. Passes with a
    partition count of zero are stored without candidates.

    Args:
        data: The decoded JSON; must be a root node of the form
            ``["node", "root", [...]]``.

    Returns:
        Trace: The populated database.

    Raises:
        AssertionError: If the root node header is wrong, or a required
            attribute or node is missing or duplicated.
    """
    # Skip header; raised explicitly so the check survives running with -O
    if data[0] != "node" or data[1] != "root":
        raise AssertionError("Trace does not start with the root node")
    data = data[2]

    bx = get_attrib(data, "block_x")
    by = get_attrib(data, "block_y")
    bz = get_attrib(data, "block_z")
    db_struct = Trace(bx, by, bz)

    for block in get_node(data, "block", True):
        px = get_attrib(block, "pos_x")
        py = get_attrib(block, "pos_y")
        pz = get_attrib(block, "pos_z")

        minr = get_attrib(block, "min_r")
        ming = get_attrib(block, "min_g")
        minb = get_attrib(block, "min_b")
        mina = get_attrib(block, "min_a")

        maxr = get_attrib(block, "max_r")
        maxg = get_attrib(block, "max_g")
        maxb = get_attrib(block, "max_b")
        maxa = get_attrib(block, "max_a")

        et = get_attrib(block, "tune_error_threshold")

        block_struct = Block(px, py, pz, et)
        block_struct.add_minimums(minr, ming, minb, mina)
        block_struct.add_maximums(maxr, maxg, maxb, maxa)
        db_struct.add_block(block_struct)

        for pas in get_node(block, "pass", True):
            # Don't copy across passes we skipped due to heuristics
            if get_attrib(pas, "skip", False, False):
                continue

            prts = get_attrib(pas, "partition_count")
            prti = get_attrib(pas, "partition_index", False, False)
            plns = get_attrib(pas, "plane_count")
            # The second plane component is mandatory as soon as a second
            # plane exists; the statistics index CHANNEL_TABLE with it.
            chan = get_attrib(pas, "plane_component", False, plns > 1)
            mode = get_attrib(pas, "search_mode", False, False)
            ehit = get_attrib(pas, "exit", False, False) == "quality hit"

            pass_struct = Pass(prts, prti, plns, ehit, mode, chan)
            block_struct.add_pass(pass_struct)

            # Constant color blocks don't have any candidates
            if prts == 0:
                continue

            for candidate in get_node(pas, "candidate", True):
                # Don't copy across candidates we couldn't encode
                if get_attrib(candidate, "failed", False, False):
                    continue

                wx = get_attrib(candidate, "weight_x")
                wy = get_attrib(candidate, "weight_y")
                wz = get_attrib(candidate, "weight_z")
                wq = QUANT_TABLE[get_attrib(candidate, "weight_quant")]
                epre = get_attrib(candidate, "error_prerealign", True, False)
                epst = get_attrib(candidate, "error_postrealign", True, False)

                cand_struct = Candidate(wx, wy, wz, wq)
                pass_struct.add_candidate(cand_struct)

                # Pre-realign errors first, then post-realign errors, so the
                # last list entry is the final recorded error.
                for value in epre + epst:
                    cand_struct.add_refinement(value)

    return db_struct
|
||||
|
||||
|
||||
def filter_database(data):
    """
    Reduce every block, in place, to its best pass and best candidate.

    All passes other than the best one are deleted from each block. When
    the best pass comes with a best candidate, all of its other candidates
    are deleted as well.

    Args:
        data: The database to filter; modified in place.
    """
    for block in data:
        keep_pass, keep_candidate = find_best_pass_and_candidate(block)

        # Walk backwards so deletions do not disturb indices still to visit
        for pass_index, pas in rev_enumerate(block):
            if pas != keep_pass:
                del block[pass_index]
            elif keep_candidate is not None:
                for cand_index, candidate in rev_enumerate(pas):
                    if candidate != keep_candidate:
                        del pas[cand_index]
|
||||
|
||||
|
||||
def generate_pass_statistics(data):
    """
    Placeholder for pass statistics; currently does nothing.

    Args:
        data: Unused.
    """
|
||||
|
||||
|
||||
def generate_feature_statistics(data):
    """
    Print the feature usage report for a trace database to stdout.

    The report has one section each for the compressor config, channel
    range, partition count, plane usage, decimation usage, and refinement
    usage, printed in that order.

    Args:
        data: The trace database; provides ``block_x``, ``block_y`` and
            ``block_z`` and iterates over its blocks.
    """
    _print_config(data)
    _print_channel_range(data)
    _print_partition_usage(data)
    _print_plane_usage(data)
    _print_decimation_usage(data)
    _print_refinement_usage(data)


def _print_config(data):
    """
    Print the block size section; the Z size is shown only when above 1.
    """
    print("Compressor Config")
    print("=================")

    if data.block_z > 1:
        dat = (data.block_x, data.block_y, data.block_z)
        print(" - Block size: %ux%ux%u" % dat)
    else:
        dat = (data.block_x, data.block_y)
        print(" - Block size: %ux%u" % dat)

    print("")


def _print_channel_range(data, range_quant=16):
    """
    Print a histogram of each block's widest LDR channel range.

    Ranges are grouped into buckets that are ``range_quant`` wide.
    """
    result = ddict(int)

    for block in foreach_block(data):
        max_range = max(block.ldr_max[i] - block.ldr_min[i] for i in range(4))
        # Round down to the start of the bucket holding this range
        max_range = int(max_range / range_quant) * range_quant
        result[max_range] += 1

    print("Channel Range")
    print("=============")
    for key in sorted(result):
        dat = (key, key + range_quant - 1, result[key])
        print(" - %3u-%3u dynamic range = %6u blocks" % dat)

    print("")


def _print_partition_usage(data):
    """
    Print pass counts per partition count, with distinct indices seen.
    """
    result_totals = ddict(int)
    results = ddict(lambda: ddict(int))

    for _, pas in foreach_pass(data):
        result_totals[pas.partitions] += 1
        results[pas.partitions][pas.partition_index] += 1

    print("Partition Count")
    print("===============")
    for key in sorted(result_totals):
        dat = (key, result_totals[key], len(results[key]))
        print(" - %u partition(s) = %6u blocks / %4u indicies" % dat)

    print("")


def _print_plane_usage(data):
    """
    Print pass counts per partition and plane count, with the second plane
    channel breakdown listed under each two plane entry.
    """
    result_count = ddict(lambda: ddict(int))
    result_channel = ddict(lambda: ddict(int))

    for _, pas in foreach_pass(data):
        result_count[pas.partitions][pas.planes] += 1
        if pas.planes > 1:
            result_channel[pas.partitions][pas.plane2_component] += 1

    print("Plane Usage")
    print("===========")
    for key in sorted(result_count):
        for key2 in sorted(result_count[key]):
            dat = (key, key2, result_count[key][key2])
            print(" - %u partition(s) %u plane(s) = %6u blocks" % dat)

            if key2 == 2:
                for key3 in sorted(result_channel[key]):
                    dat = (CHANNEL_TABLE[key3], result_channel[key][key3])
                    print(" - %s plane = %6u blocks" % dat)

    print("")


def _print_decimation_usage(data, merge_rotations=True):
    """
    Print candidate counts per weight grid size and per weight quant value.

    With ``merge_rotations`` set, an X by Y grid and a Y by X grid are
    counted together under the larger dimension first.
    """
    decim_count = ddict(lambda: ddict(int))
    quant_count = ddict(lambda: ddict(lambda: ddict(int)))

    for _, pas, can in foreach_candidate(data):
        # Skip constant color blocks
        if can is None:
            continue

        wx = can.weight_x
        wy = can.weight_y

        if merge_rotations and wx < wy:
            wx, wy = wy, wx

        decim_count[wx][wy] += 1
        quant_count[wx][wy][can.weight_quant] += 1

    print("Decimation Usage")
    print("================")

    if merge_rotations:
        print(" - Note: data merging grid rotations")

    for x_key in sorted(decim_count):
        for y_key in sorted(decim_count[x_key]):
            dat = (x_key, y_key, decim_count[x_key][y_key])
            print(" - %ux%u weights = %6u blocks" % dat)

            for q_key in sorted(quant_count[x_key][y_key]):
                dat = (q_key, quant_count[x_key][y_key][q_key])
                print(" - %2u quant range = %6u blocks" % dat)

    print("")


def _print_refinement_usage(data):
    """
    Print how often refinement improved the error or reached the target.
    """
    total_count = 0
    better_count = 0
    could_have_count = 0
    success_count = 0

    refinement_step = []

    for block, pas, candidate in foreach_candidate(data):
        # Ignore zero partition blocks - they don't use refinement
        if candidate is None:
            continue

        errors = candidate.refinement_errors
        target_error = block.error_target
        start_error = errors[0]
        end_error = errors[-1]

        # Mean absolute change per recorded error, relative to the start
        rpf = abs(float(start_error - end_error) / float(len(errors)))
        # A zero starting error has no relative change to report; record
        # zero instead of dividing by zero.
        refinement_step.append(rpf / start_error if start_error else 0.0)

        total_count += 1
        if end_error <= start_error:
            better_count += 1

        if end_error <= target_error:
            success_count += 1
        elif any(refinement <= target_error for refinement in errors):
            # An earlier step met the target but the final one did not
            could_have_count += 1

    # Avoid the nan and warning that np.mean gives for an empty list
    mean_step = np.mean(refinement_step) if refinement_step else 0.0

    print("Refinement Usage")
    print("================")
    print(" - %u refinements(s)" % total_count)
    print(" - %u refinements(s) improved" % better_count)
    print(" - %u refinements(s) worsened" % (total_count - better_count))
    print(" - %u refinements(s) could hit target, but didn't" % could_have_count)
    print(" - %u refinements(s) hit target" % success_count)
    print(" - %f mean step improvement" % mean_step)
|
||||
|
||||
|
||||
def parse_command_line():
    """
    Parse the command line.

    Returns:
        Namespace: The parsed command line container; ``trace`` holds the
        trace file, already opened for reading.
    """
    cli = argparse.ArgumentParser()

    trace_options = {
        "type": argparse.FileType("r"),
        "help": "The trace file to analyze",
    }
    cli.add_argument("trace", **trace_options)

    return cli.parse_args()
|
||||
|
||||
|
||||
def main():
    """
    The main function.

    Loads the trace named on the command line, reduces each block to its
    best pass and candidate, and prints the feature statistics.

    Returns:
        int: The process return code.
    """
    args = parse_command_line()

    # argparse opened the trace file; close it as soon as it is decoded
    with args.trace as trace_file:
        data = json.load(trace_file)

    db = generate_database(data)
    filter_database(db)

    generate_feature_statistics(db)

    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
Reference in New Issue
Block a user