Subversion Repositories Kolibri OS

Rev

Blame | Last modification | View Log | RSS feed

  1. /*
  2.  * Copyright © 2014 Intel Corporation
  3.  *
  4.  * Permission is hereby granted, free of charge, to any person obtaining a
  5.  * copy of this software and associated documentation files (the "Software"),
  6.  * to deal in the Software without restriction, including without limitation
  7.  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
  8.  * and/or sell copies of the Software, and to permit persons to whom the
  9.  * Software is furnished to do so, subject to the following conditions:
  10.  *
  11.  * The above copyright notice and this permission notice (including the next
  12.  * paragraph) shall be included in all copies or substantial portions of the
  13.  * Software.
  14.  *
  15.  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  16.  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  17.  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
  18.  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  19.  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
  20.  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
  21.  * IN THE SOFTWARE.
  22.  *
  23.  * Authors:
  24.  *    Jason Ekstrand (jason@jlekstrand.net)
  25.  *
  26.  */
  27.  
  28. #include "nir.h"
  29.  
  30. /*
  31.  * Implements a small peephole optimization that looks for a multiply that
  32.  * is only ever used in an add and replaces both with an fma.
  33.  */
  34.  
  35. struct peephole_ffma_state {
  36.    void *mem_ctx;
  37.    nir_function_impl *impl;
  38.    bool progress;
  39. };
  40.  
  41. static inline bool
  42. are_all_uses_fadd(nir_ssa_def *def)
  43. {
  44.    if (!list_empty(&def->if_uses))
  45.       return false;
  46.  
  47.    nir_foreach_use(def, use_src) {
  48.       nir_instr *use_instr = use_src->parent_instr;
  49.  
  50.       if (use_instr->type != nir_instr_type_alu)
  51.          return false;
  52.  
  53.       nir_alu_instr *use_alu = nir_instr_as_alu(use_instr);
  54.       switch (use_alu->op) {
  55.       case nir_op_fadd:
  56.          break; /* This one's ok */
  57.  
  58.       case nir_op_imov:
  59.       case nir_op_fmov:
  60.       case nir_op_fneg:
  61.       case nir_op_fabs:
  62.          assert(use_alu->dest.dest.is_ssa);
  63.          if (!are_all_uses_fadd(&use_alu->dest.dest.ssa))
  64.             return false;
  65.          break;
  66.  
  67.       default:
  68.          return false;
  69.       }
  70.    }
  71.  
  72.    return true;
  73. }
  74.  
  75. static nir_alu_instr *
  76. get_mul_for_src(nir_alu_src *src, uint8_t swizzle[4], bool *negate, bool *abs)
  77. {
  78.    assert(src->src.is_ssa && !src->abs && !src->negate);
  79.  
  80.    nir_instr *instr = src->src.ssa->parent_instr;
  81.    if (instr->type != nir_instr_type_alu)
  82.       return NULL;
  83.  
  84.    nir_alu_instr *alu = nir_instr_as_alu(instr);
  85.    switch (alu->op) {
  86.    case nir_op_imov:
  87.    case nir_op_fmov:
  88.       alu = get_mul_for_src(&alu->src[0], swizzle, negate, abs);
  89.       break;
  90.  
  91.    case nir_op_fneg:
  92.       alu = get_mul_for_src(&alu->src[0], swizzle, negate, abs);
  93.       *negate = !*negate;
  94.       break;
  95.  
  96.    case nir_op_fabs:
  97.       alu = get_mul_for_src(&alu->src[0], swizzle, negate, abs);
  98.       *negate = false;
  99.       *abs = true;
  100.       break;
  101.  
  102.    case nir_op_fmul:
  103.       /* Only absorb a fmul into a ffma if the fmul is is only used in fadd
  104.        * operations.  This prevents us from being too aggressive with our
  105.        * fusing which can actually lead to more instructions.
  106.        */
  107.       if (!are_all_uses_fadd(&alu->dest.dest.ssa))
  108.          return NULL;
  109.       break;
  110.  
  111.    default:
  112.       return NULL;
  113.    }
  114.  
  115.    if (!alu)
  116.       return NULL;
  117.  
  118.    for (unsigned i = 0; i < 4; i++) {
  119.       if (!(alu->dest.write_mask & (1 << i)))
  120.          break;
  121.  
  122.       swizzle[i] = swizzle[src->swizzle[i]];
  123.    }
  124.  
  125.    return alu;
  126. }
  127.  
  128. static bool
  129. nir_opt_peephole_ffma_block(nir_block *block, void *void_state)
  130. {
  131.    struct peephole_ffma_state *state = void_state;
  132.  
  133.    nir_foreach_instr_safe(block, instr) {
  134.       if (instr->type != nir_instr_type_alu)
  135.          continue;
  136.  
  137.       nir_alu_instr *add = nir_instr_as_alu(instr);
  138.       if (add->op != nir_op_fadd)
  139.          continue;
  140.  
  141.       /* TODO: Maybe bail if this expression is considered "precise"? */
  142.  
  143.       assert(add->src[0].src.is_ssa && add->src[1].src.is_ssa);
  144.  
  145.       /* This, is the case a + a.  We would rather handle this with an
  146.        * algebraic reduction than fuse it.  Also, we want to only fuse
  147.        * things where the multiply is used only once and, in this case,
  148.        * it would be used twice by the same instruction.
  149.        */
  150.       if (add->src[0].src.ssa == add->src[1].src.ssa)
  151.          continue;
  152.  
  153.       nir_alu_instr *mul;
  154.       uint8_t add_mul_src, swizzle[4];
  155.       bool negate, abs;
  156.       for (add_mul_src = 0; add_mul_src < 2; add_mul_src++) {
  157.          for (unsigned i = 0; i < 4; i++)
  158.             swizzle[i] = i;
  159.  
  160.          negate = false;
  161.          abs = false;
  162.  
  163.          mul = get_mul_for_src(&add->src[add_mul_src], swizzle, &negate, &abs);
  164.  
  165.          if (mul != NULL)
  166.             break;
  167.       }
  168.  
  169.       if (mul == NULL)
  170.          continue;
  171.  
  172.       nir_ssa_def *mul_src[2];
  173.       mul_src[0] = mul->src[0].src.ssa;
  174.       mul_src[1] = mul->src[1].src.ssa;
  175.  
  176.       if (abs) {
  177.          for (unsigned i = 0; i < 2; i++) {
  178.             nir_alu_instr *abs = nir_alu_instr_create(state->mem_ctx,
  179.                                                       nir_op_fabs);
  180.             abs->src[0].src = nir_src_for_ssa(mul_src[i]);
  181.             nir_ssa_dest_init(&abs->instr, &abs->dest.dest,
  182.                               mul_src[i]->num_components, NULL);
  183.             abs->dest.write_mask = (1 << mul_src[i]->num_components) - 1;
  184.             nir_instr_insert_before(&add->instr, &abs->instr);
  185.             mul_src[i] = &abs->dest.dest.ssa;
  186.          }
  187.       }
  188.  
  189.       if (negate) {
  190.          nir_alu_instr *neg = nir_alu_instr_create(state->mem_ctx,
  191.                                                    nir_op_fneg);
  192.          neg->src[0].src = nir_src_for_ssa(mul_src[0]);
  193.          nir_ssa_dest_init(&neg->instr, &neg->dest.dest,
  194.                            mul_src[0]->num_components, NULL);
  195.          neg->dest.write_mask = (1 << mul_src[0]->num_components) - 1;
  196.          nir_instr_insert_before(&add->instr, &neg->instr);
  197.          mul_src[0] = &neg->dest.dest.ssa;
  198.       }
  199.  
  200.       nir_alu_instr *ffma = nir_alu_instr_create(state->mem_ctx, nir_op_ffma);
  201.       ffma->dest.saturate = add->dest.saturate;
  202.       ffma->dest.write_mask = add->dest.write_mask;
  203.  
  204.       for (unsigned i = 0; i < 2; i++) {
  205.          ffma->src[i].src = nir_src_for_ssa(mul_src[i]);
  206.          for (unsigned j = 0; j < add->dest.dest.ssa.num_components; j++)
  207.             ffma->src[i].swizzle[j] = mul->src[i].swizzle[swizzle[j]];
  208.       }
  209.       nir_alu_src_copy(&ffma->src[2], &add->src[1 - add_mul_src],
  210.                        state->mem_ctx);
  211.  
  212.       assert(add->dest.dest.is_ssa);
  213.  
  214.       nir_ssa_dest_init(&ffma->instr, &ffma->dest.dest,
  215.                         add->dest.dest.ssa.num_components,
  216.                         add->dest.dest.ssa.name);
  217.       nir_ssa_def_rewrite_uses(&add->dest.dest.ssa,
  218.                                nir_src_for_ssa(&ffma->dest.dest.ssa),
  219.                                state->mem_ctx);
  220.  
  221.       nir_instr_insert_before(&add->instr, &ffma->instr);
  222.       assert(list_empty(&add->dest.dest.ssa.uses));
  223.       nir_instr_remove(&add->instr);
  224.  
  225.       state->progress = true;
  226.    }
  227.  
  228.    return true;
  229. }
  230.  
  231. static bool
  232. nir_opt_peephole_ffma_impl(nir_function_impl *impl)
  233. {
  234.    struct peephole_ffma_state state;
  235.  
  236.    state.mem_ctx = ralloc_parent(impl);
  237.    state.impl = impl;
  238.    state.progress = false;
  239.  
  240.    nir_foreach_block(impl, nir_opt_peephole_ffma_block, &state);
  241.  
  242.    if (state.progress)
  243.       nir_metadata_preserve(impl, nir_metadata_block_index |
  244.                                   nir_metadata_dominance);
  245.  
  246.    return state.progress;
  247. }
  248.  
  249. bool
  250. nir_opt_peephole_ffma(nir_shader *shader)
  251. {
  252.    bool progress = false;
  253.  
  254.    nir_foreach_overload(shader, overload) {
  255.       if (overload->impl)
  256.          progress |= nir_opt_peephole_ffma_impl(overload->impl);
  257.    }
  258.  
  259.    return progress;
  260. }
  261.