Subversion Repositories Kolibri OS

Rev

Blame | Last modification | View Log | RSS feed

  1. /*
  2.  * Copyright © 2014 Intel Corporation
  3.  *
  4.  * Permission is hereby granted, free of charge, to any person obtaining a
  5.  * copy of this software and associated documentation files (the "Software"),
  6.  * to deal in the Software without restriction, including without limitation
  7.  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
  8.  * and/or sell copies of the Software, and to permit persons to whom the
  9.  * Software is furnished to do so, subject to the following conditions:
  10.  *
  11.  * The above copyright notice and this permission notice (including the next
  12.  * paragraph) shall be included in all copies or substantial portions of the
  13.  * Software.
  14.  *
  15.  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  16.  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  17.  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
  18.  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  19.  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
  20.  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
  21.  * IN THE SOFTWARE.
  22.  *
  23.  * Authors:
  24.  *    Jason Ekstrand (jason@jlekstrand.net)
  25.  *
  26.  */
  27.  
  28. #include "nir_search.h"
  29.  
  30. struct match_state {
  31.    unsigned variables_seen;
  32.    nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
  33. };
  34.  
  35. static bool
  36. match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
  37.                  unsigned num_components, const uint8_t *swizzle,
  38.                  struct match_state *state);
  39.  
  40. static const uint8_t identity_swizzle[] = { 0, 1, 2, 3 };
  41.  
  42. static bool alu_instr_is_bool(nir_alu_instr *instr);
  43.  
  44. static bool
  45. src_is_bool(nir_src src)
  46. {
  47.    if (!src.is_ssa)
  48.       return false;
  49.    if (src.ssa->parent_instr->type != nir_instr_type_alu)
  50.       return false;
  51.    return alu_instr_is_bool((nir_alu_instr *)src.ssa->parent_instr);
  52. }
  53.  
  54. static bool
  55. alu_instr_is_bool(nir_alu_instr *instr)
  56. {
  57.    switch (instr->op) {
  58.    case nir_op_iand:
  59.    case nir_op_ior:
  60.    case nir_op_ixor:
  61.       return src_is_bool(instr->src[0].src) && src_is_bool(instr->src[1].src);
  62.    case nir_op_inot:
  63.       return src_is_bool(instr->src[0].src);
  64.    default:
  65.       return nir_op_infos[instr->op].output_type == nir_type_bool;
  66.    }
  67. }
  68.  
  69. static bool
  70. match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
  71.             unsigned num_components, const uint8_t *swizzle,
  72.             struct match_state *state)
  73. {
  74.    uint8_t new_swizzle[4];
  75.  
  76.    /* If the source is an explicitly sized source, then we need to reset
  77.     * both the number of components and the swizzle.
  78.     */
  79.    if (nir_op_infos[instr->op].input_sizes[src] != 0) {
  80.       num_components = nir_op_infos[instr->op].input_sizes[src];
  81.       swizzle = identity_swizzle;
  82.    }
  83.  
  84.    for (int i = 0; i < num_components; ++i)
  85.       new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
  86.  
  87.    switch (value->type) {
  88.    case nir_search_value_expression:
  89.       if (!instr->src[src].src.is_ssa)
  90.          return false;
  91.  
  92.       if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
  93.          return false;
  94.  
  95.       return match_expression(nir_search_value_as_expression(value),
  96.                               nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
  97.                               num_components, new_swizzle, state);
  98.  
  99.    case nir_search_value_variable: {
  100.       nir_search_variable *var = nir_search_value_as_variable(value);
  101.       assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
  102.  
  103.       if (state->variables_seen & (1 << var->variable)) {
  104.          if (!nir_srcs_equal(state->variables[var->variable].src,
  105.                              instr->src[src].src))
  106.             return false;
  107.  
  108.          assert(!instr->src[src].abs && !instr->src[src].negate);
  109.  
  110.          for (int i = 0; i < num_components; ++i) {
  111.             if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
  112.                return false;
  113.          }
  114.  
  115.          return true;
  116.       } else {
  117.          if (var->is_constant &&
  118.              instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
  119.             return false;
  120.  
  121.          if (var->type != nir_type_invalid) {
  122.             if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
  123.                return false;
  124.  
  125.             nir_alu_instr *src_alu =
  126.                nir_instr_as_alu(instr->src[src].src.ssa->parent_instr);
  127.  
  128.             if (nir_op_infos[src_alu->op].output_type != var->type &&
  129.                 !(var->type == nir_type_bool && alu_instr_is_bool(src_alu)))
  130.                return false;
  131.          }
  132.  
  133.          state->variables_seen |= (1 << var->variable);
  134.          state->variables[var->variable].src = instr->src[src].src;
  135.          state->variables[var->variable].abs = false;
  136.          state->variables[var->variable].negate = false;
  137.  
  138.          for (int i = 0; i < 4; ++i) {
  139.             if (i < num_components)
  140.                state->variables[var->variable].swizzle[i] = new_swizzle[i];
  141.             else
  142.                state->variables[var->variable].swizzle[i] = 0;
  143.          }
  144.  
  145.          return true;
  146.       }
  147.    }
  148.  
  149.    case nir_search_value_constant: {
  150.       nir_search_constant *const_val = nir_search_value_as_constant(value);
  151.  
  152.       if (!instr->src[src].src.is_ssa)
  153.          return false;
  154.  
  155.       if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
  156.          return false;
  157.  
  158.       nir_load_const_instr *load =
  159.          nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
  160.  
  161.       switch (nir_op_infos[instr->op].input_types[src]) {
  162.       case nir_type_float:
  163.          for (unsigned i = 0; i < num_components; ++i) {
  164.             if (load->value.f[new_swizzle[i]] != const_val->data.f)
  165.                return false;
  166.          }
  167.          return true;
  168.       case nir_type_int:
  169.       case nir_type_unsigned:
  170.       case nir_type_bool:
  171.          for (unsigned i = 0; i < num_components; ++i) {
  172.             if (load->value.i[new_swizzle[i]] != const_val->data.i)
  173.                return false;
  174.          }
  175.          return true;
  176.       default:
  177.          unreachable("Invalid alu source type");
  178.       }
  179.    }
  180.  
  181.    default:
  182.       unreachable("Invalid search value type");
  183.    }
  184. }
  185.  
  186. static bool
  187. match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
  188.                  unsigned num_components, const uint8_t *swizzle,
  189.                  struct match_state *state)
  190. {
  191.    if (instr->op != expr->opcode)
  192.       return false;
  193.  
  194.    assert(!instr->dest.saturate);
  195.    assert(nir_op_infos[instr->op].num_inputs > 0);
  196.  
  197.    /* If we have an explicitly sized destination, we can only handle the
  198.     * identity swizzle.  While dot(vec3(a, b, c).zxy) is a valid
  199.     * expression, we don't have the information right now to propagate that
  200.     * swizzle through.  We can only properly propagate swizzles if the
  201.     * instruction is vectorized.
  202.     */
  203.    if (nir_op_infos[instr->op].output_size != 0) {
  204.       for (unsigned i = 0; i < num_components; i++) {
  205.          if (swizzle[i] != i)
  206.             return false;
  207.       }
  208.    }
  209.  
  210.    /* Stash off the current variables_seen bitmask.  This way we can
  211.     * restore it prior to matching in the commutative case below.
  212.     */
  213.    unsigned variables_seen_stash = state->variables_seen;
  214.  
  215.    bool matched = true;
  216.    for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
  217.       if (!match_value(expr->srcs[i], instr, i, num_components,
  218.                        swizzle, state)) {
  219.          matched = false;
  220.          break;
  221.       }
  222.    }
  223.  
  224.    if (matched)
  225.       return true;
  226.  
  227.    if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
  228.       assert(nir_op_infos[instr->op].num_inputs == 2);
  229.  
  230.       /* Restore the variables_seen bitmask.  If we don't do this, then we
  231.        * could end up with an erroneous failure due to variables found in the
  232.        * first match attempt above not matching those in the second.
  233.        */
  234.       state->variables_seen = variables_seen_stash;
  235.  
  236.       if (!match_value(expr->srcs[0], instr, 1, num_components,
  237.                        swizzle, state))
  238.          return false;
  239.  
  240.       return match_value(expr->srcs[1], instr, 0, num_components,
  241.                          swizzle, state);
  242.    } else {
  243.       return false;
  244.    }
  245. }
  246.  
  247. static nir_alu_src
  248. construct_value(const nir_search_value *value, nir_alu_type type,
  249.                 unsigned num_components, struct match_state *state,
  250.                 nir_instr *instr, void *mem_ctx)
  251. {
  252.    switch (value->type) {
  253.    case nir_search_value_expression: {
  254.       const nir_search_expression *expr = nir_search_value_as_expression(value);
  255.  
  256.       if (nir_op_infos[expr->opcode].output_size != 0)
  257.          num_components = nir_op_infos[expr->opcode].output_size;
  258.  
  259.       nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, expr->opcode);
  260.       nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components, NULL);
  261.       alu->dest.write_mask = (1 << num_components) - 1;
  262.       alu->dest.saturate = false;
  263.  
  264.       for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) {
  265.          /* If the source is an explicitly sized source, then we need to reset
  266.           * the number of components to match.
  267.           */
  268.          if (nir_op_infos[alu->op].input_sizes[i] != 0)
  269.             num_components = nir_op_infos[alu->op].input_sizes[i];
  270.  
  271.          alu->src[i] = construct_value(expr->srcs[i],
  272.                                        nir_op_infos[alu->op].input_types[i],
  273.                                        num_components,
  274.                                        state, instr, mem_ctx);
  275.       }
  276.  
  277.       nir_instr_insert_before(instr, &alu->instr);
  278.  
  279.       nir_alu_src val;
  280.       val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
  281.       val.negate = false;
  282.       val.abs = false,
  283.       memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
  284.  
  285.       return val;
  286.    }
  287.  
  288.    case nir_search_value_variable: {
  289.       const nir_search_variable *var = nir_search_value_as_variable(value);
  290.       assert(state->variables_seen & (1 << var->variable));
  291.  
  292.       nir_alu_src val = { NIR_SRC_INIT };
  293.       nir_alu_src_copy(&val, &state->variables[var->variable], mem_ctx);
  294.  
  295.       assert(!var->is_constant);
  296.  
  297.       return val;
  298.    }
  299.  
  300.    case nir_search_value_constant: {
  301.       const nir_search_constant *c = nir_search_value_as_constant(value);
  302.       nir_load_const_instr *load = nir_load_const_instr_create(mem_ctx, 1);
  303.  
  304.       switch (type) {
  305.       case nir_type_float:
  306.          load->def.name = ralloc_asprintf(mem_ctx, "%f", c->data.f);
  307.          load->value.f[0] = c->data.f;
  308.          break;
  309.       case nir_type_int:
  310.          load->def.name = ralloc_asprintf(mem_ctx, "%d", c->data.i);
  311.          load->value.i[0] = c->data.i;
  312.          break;
  313.       case nir_type_unsigned:
  314.       case nir_type_bool:
  315.          load->value.u[0] = c->data.u;
  316.          break;
  317.       default:
  318.          unreachable("Invalid alu source type");
  319.       }
  320.  
  321.       nir_instr_insert_before(instr, &load->instr);
  322.  
  323.       nir_alu_src val;
  324.       val.src = nir_src_for_ssa(&load->def);
  325.       val.negate = false;
  326.       val.abs = false,
  327.       memset(val.swizzle, 0, sizeof val.swizzle);
  328.  
  329.       return val;
  330.    }
  331.  
  332.    default:
  333.       unreachable("Invalid search value type");
  334.    }
  335. }
  336.  
  337. nir_alu_instr *
  338. nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search,
  339.                   const nir_search_value *replace, void *mem_ctx)
  340. {
  341.    uint8_t swizzle[4] = { 0, 0, 0, 0 };
  342.  
  343.    for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
  344.       swizzle[i] = i;
  345.  
  346.    assert(instr->dest.dest.is_ssa);
  347.  
  348.    struct match_state state;
  349.    state.variables_seen = 0;
  350.  
  351.    if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
  352.                          swizzle, &state))
  353.       return NULL;
  354.  
  355.    /* Inserting a mov may be unnecessary.  However, it's much easier to
  356.     * simply let copy propagation clean this up than to try to go through
  357.     * and rewrite swizzles ourselves.
  358.     */
  359.    nir_alu_instr *mov = nir_alu_instr_create(mem_ctx, nir_op_imov);
  360.    mov->dest.write_mask = instr->dest.write_mask;
  361.    nir_ssa_dest_init(&mov->instr, &mov->dest.dest,
  362.                      instr->dest.dest.ssa.num_components, NULL);
  363.  
  364.    mov->src[0] = construct_value(replace, nir_op_infos[instr->op].output_type,
  365.                                  instr->dest.dest.ssa.num_components, &state,
  366.                                  &instr->instr, mem_ctx);
  367.    nir_instr_insert_before(&instr->instr, &mov->instr);
  368.  
  369.    nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa,
  370.                             nir_src_for_ssa(&mov->dest.dest.ssa), mem_ctx);
  371.  
  372.    /* We know this one has no more uses because we just rewrote them all,
  373.     * so we can remove it.  The rest of the matched expression, however, we
  374.     * don't know so much about.  We'll just let dead code clean them up.
  375.     */
  376.    nir_instr_remove(&instr->instr);
  377.  
  378.    return mov;
  379. }
  380.